use crate::{
model::algorithms::{
algorithm_entry_point::AlgorithmEntryPoint, document::GqlDocument,
global_plugins::GlobalPlugins, vector_algorithms::VectorAlgorithms,
},
python::{
adapt_graphql_value,
global_plugins::PyGlobalPlugins,
server::{
running_server::PyRunningGraphServer, take_server_ownership, wait_server, BridgeCommand,
},
},
server_config::AppConfigBuilder,
GraphServer,
};
use async_graphql::dynamic::{Field, FieldFuture, FieldValue, InputValue, Object, TypeRef};
use dynamic_graphql::internal::{Registry, TypeName};
use itertools::intersperse;
use pyo3::{
exceptions::{PyAttributeError, PyException},
pyclass, pymethods,
types::{IntoPyDict, PyFunction, PyList},
IntoPy, Py, PyObject, PyRefMut, PyResult, Python,
};
use raphtory::{
python::{
packages::vectors::PyDocumentTemplate, types::wrappers::document::PyDocument,
utils::execute_async_task,
},
vectors::{embeddings::openai_embedding, EmbeddingFunction},
};
use std::{collections::HashMap, path::PathBuf, thread};
/// Python-exposed wrapper around a not-yet-started [`GraphServer`].
///
/// The inner `Option` holds the server until ownership is taken by one of the
/// builder-style `with_*` methods or by `start`/`run` (via
/// `take_server_ownership`, which presumably errors if the server was already
/// consumed — confirm in its definition). Exposed to Python as `GraphServer`.
#[pyclass(name = "GraphServer")]
pub struct PyGraphServer(pub(crate) Option<GraphServer>);
// Private (non-#[pymethods]) helpers shared by the Python-facing builder methods.
impl PyGraphServer {
// Wrap an owned `GraphServer` back into the Python wrapper.
fn new(server: GraphServer) -> Self {
Self(Some(server))
}
/// Take the server out of `slf`, vectorise it with the given embedding
/// function, and return a fresh wrapper around the vectorised server.
///
/// Generic over `F` so the same code path serves both a user-supplied
/// Python callable (`Py<PyFunction>`) and the built-in `openai_embedding`.
///
/// * `graph_names` - graphs to vectorise; `None` semantics are decided by
///   `GraphServer::with_vectorised` (not visible here).
/// * `cache` - path (as a string) to the embedding cache directory/file.
/// * `graph_document`/`node_document`/`edge_document` - optional template
///   strings bundled into a `PyDocumentTemplate`.
fn with_vectorised_generic_embedding<F: EmbeddingFunction + Clone + 'static>(
slf: PyRefMut<Self>,
graph_names: Option<Vec<String>>,
embedding: F,
cache: String,
graph_document: Option<String>,
node_document: Option<String>,
edge_document: Option<String>,
) -> PyResult<Self> {
let template = PyDocumentTemplate::new(graph_document, node_document, edge_document);
// Consume the inner GraphServer; `slf` is left empty after this.
let server = take_server_ownership(slf)?;
// `with_vectorised` is async, but this binding is synchronous, so the
// future is driven to completion on a blocking helper runtime.
execute_async_task(move || async move {
let new_server = server
.with_vectorised(
graph_names,
embedding,
&PathBuf::from(cache),
Some(template),
)
.await?;
Ok(Self::new(new_server))
})
}
/// Register a Python callable as a custom GraphQL query field named `name`
/// on the entry point `E` (e.g. per-graph vector algorithms or the global
/// plugin namespace), then return a new wrapper around the (otherwise
/// unchanged) server.
///
/// * `input` - maps GraphQL argument name -> type name; only `"str"`,
///   `"int"` and `"float"` are accepted, otherwise `PyAttributeError`.
/// * `function` - Python callable invoked per query; it receives the
///   adapted entry point plus the GraphQL arguments as keyword args and
///   must return a list of `PyDocument`s.
/// * `adapter` - converts the Rust entry point `&E` into the `PyObject`
///   handed to the Python callable.
fn with_generic_document_search_function<
'a,
E: AlgorithmEntryPoint<'a> + 'static,
F: Fn(&E, Python) -> PyObject + Send + Sync + 'static,
>(
slf: PyRefMut<Self>,
name: String,
input: HashMap<String, String>,
function: &PyFunction,
adapter: F,
) -> PyResult<Self> {
// Keep the callable alive beyond this call (it is moved into the
// long-lived resolver closure below).
let function: Py<PyFunction> = function.into();
// Python type-name strings -> non-null GraphQL scalar type refs.
let input_mapper = HashMap::from([
("str", TypeRef::named_nn(TypeRef::STRING)),
("int", TypeRef::named_nn(TypeRef::INT)),
("float", TypeRef::named_nn(TypeRef::FLOAT)),
]);
// Validate every declared argument up front; fail fast with the list of
// accepted type names if one is unknown.
let input_values = input
.into_iter()
.map(|(name, type_name)| {
let type_ref = input_mapper.get(&type_name.as_str()).cloned();
type_ref
.map(|type_ref| InputValue::new(name, type_ref))
.ok_or_else(|| {
let valid_types = input_mapper.keys().map(|key| key.to_owned());
let valid_types_string: String = intersperse(valid_types, ", ").collect();
let msg = format!("types in input have to be one of: {valid_types_string}");
PyAttributeError::new_err(msg)
})
})
.collect::<PyResult<Vec<InputValue>>>()?;
// Closure stored in the plugin registry; runs at schema-build time to
// add the new field (returning a non-null list of GqlDocument) to the
// entry point's GraphQL object.
let register_function = |name: &str, registry: Registry, parent: Object| {
let registry = registry.register::<GqlDocument>();
let output_type = TypeRef::named_nn_list_nn(GqlDocument::get_type_name());
// Per-request resolver: call the Python function under the GIL and
// convert its returned documents into GraphQL values.
let mut field = Field::new(name, output_type, move |ctx| {
let documents = Python::with_gil(|py| {
let entry_point = adapter(ctx.parent_value.downcast_ref().unwrap(), py);
// Re-shape the GraphQL arguments into Python keyword args.
let kw_args: HashMap<&str, PyObject> = ctx
.args
.iter()
.map(|(name, value)| (name.as_str(), adapt_graphql_value(&value, py)))
.collect();
let py_kw_args = kw_args.into_py_dict(py);
// NOTE(review): these unwraps panic the resolver if the Python
// function raises, or returns something other than a list of
// PyDocument — consider surfacing a GraphQL error instead.
let result = function.call(py, (entry_point,), Some(py_kw_args)).unwrap();
let list = result.downcast::<PyList>(py).unwrap();
let py_documents = list.iter().map(|doc| doc.extract::<PyDocument>().unwrap());
py_documents
.map(|doc| doc.extract_rust_document(py).unwrap())
.collect::<Vec<_>>()
});
let gql_documents = documents
.into_iter()
.map(|doc| FieldValue::owned_any(GqlDocument::from(doc)));
// Result is already available (no async work) — wrap as a value.
FieldFuture::Value(Some(FieldValue::list(gql_documents)))
});
for input_value in input_values {
field = field.argument(input_value);
}
let parent = parent.field(field);
(registry, parent)
};
// Publish the registration closure under `name`; applied when the
// schema for `E` is next built (see AlgorithmEntryPoint::lock_plugins).
E::lock_plugins().insert(name, Box::new(register_function));
// The server itself is unchanged; re-wrap it to keep the builder style.
let new_server = take_server_ownership(slf)?;
Ok(Self::new(new_server))
}
}
#[pymethods]
impl PyGraphServer {
/// Create a new server rooted at `work_dir`.
///
/// Each optional argument, when given, overrides the corresponding
/// `AppConfigBuilder` default; `config_path` is passed through to
/// `GraphServer::new` (precedence between it and the built config is
/// decided there, not here).
#[new]
#[pyo3(
signature = (work_dir, cache_capacity = None, cache_tti_seconds = None, log_level = None, config_path = None)
)]
fn py_new(
work_dir: PathBuf,
cache_capacity: Option<u64>,
cache_tti_seconds: Option<u64>,
log_level: Option<String>,
config_path: Option<PathBuf>,
) -> PyResult<Self> {
let mut app_config_builder = AppConfigBuilder::new();
if let Some(log_level) = log_level {
app_config_builder = app_config_builder.with_log_level(log_level);
}
if let Some(cache_capacity) = cache_capacity {
app_config_builder = app_config_builder.with_cache_capacity(cache_capacity);
}
if let Some(cache_tti_seconds) = cache_tti_seconds {
app_config_builder = app_config_builder.with_cache_tti_seconds(cache_tti_seconds);
}
let app_config = Some(app_config_builder.build());
let server = GraphServer::new(work_dir, app_config, config_path)?;
Ok(PyGraphServer::new(server))
}
/// Vectorise the server's graphs, consuming `self` and returning a new
/// server (builder style).
///
/// If `embedding` is a Python callable it is used as the embedding
/// function; otherwise the built-in `openai_embedding` is the default.
/// Both branches delegate to the same generic helper.
fn with_vectorised(
slf: PyRefMut<Self>,
cache: String,
graph_names: Option<Vec<String>>,
embedding: Option<&PyFunction>,
graph_document: Option<String>,
node_document: Option<String>,
edge_document: Option<String>,
) -> PyResult<Self> {
match embedding {
Some(embedding) => {
let embedding: Py<PyFunction> = embedding.into();
Self::with_vectorised_generic_embedding(
slf,
graph_names,
embedding,
cache,
graph_document,
node_document,
edge_document,
)
}
None => Self::with_vectorised_generic_embedding(
slf,
graph_names,
openai_embedding,
cache,
graph_document,
node_document,
edge_document,
),
}
}
/// Register `function` as a per-graph document-search GraphQL field named
/// `name` on the vector-algorithms entry point. The Python callable
/// receives the graph (converted to a Python object) as its first argument.
pub fn with_document_search_function(
slf: PyRefMut<Self>,
name: String,
input: HashMap<String, String>,
function: &PyFunction,
) -> PyResult<Self> {
let adapter =
|entry_point: &VectorAlgorithms, py: Python| entry_point.graph.clone().into_py(py);
PyGraphServer::with_generic_document_search_function(slf, name, input, function, adapter)
}
/// Register `function` as a global (cross-graph) search GraphQL field named
/// `name`. The Python callable receives the global plugins wrapper as its
/// first argument.
pub fn with_global_search_function(
slf: PyRefMut<Self>,
name: String,
input: HashMap<String, String>,
function: &PyFunction,
) -> PyResult<Self> {
let adapter = |entry_point: &GlobalPlugins, py: Python| {
PyGlobalPlugins(entry_point.clone()).into_py(py)
};
PyGraphServer::with_generic_document_search_function(slf, name, input, function, adapter)
}
/// Start the server on `port` in a background OS thread and return a
/// handle to the running server.
///
/// The thread builds its own multi-threaded tokio runtime. A crossbeam
/// channel bridges the Python side and the runtime:
/// - `StopServer` (sent from Python via the returned handle) is forwarded
///   to the running server's internal cancellation sender;
/// - `StopListening` is sent from this thread itself once the server
///   finishes, so the blocking `recv` task can unwind without a stop
///   request ever arriving.
/// After spawning, this blocks until the server answers at its URL or
/// `timeout_ms` elapses; on timeout the server is stopped and the error
/// is propagated.
#[pyo3(
signature = (port = 1736, timeout_ms = None)
)]
pub fn start(
slf: PyRefMut<Self>,
py: Python,
port: u16,
timeout_ms: Option<u64>,
) -> PyResult<PyRunningGraphServer> {
// Capacity 1 is enough: at most one command crosses the bridge.
let (sender, receiver) = crossbeam_channel::bounded::<BridgeCommand>(1);
let server = take_server_ownership(slf)?;
let cloned_sender = sender.clone();
let join_handle = thread::spawn(move || {
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap()
.block_on(async move {
let handler = server.start_with_port(port);
let running_server = handler.await?;
let tokio_sender = running_server._get_sender().clone();
// Blocking bridge: wait for a command from the Python side
// off the async workers, then translate it for tokio.
tokio::task::spawn_blocking(move || {
match receiver.recv().expect("Failed to wait for cancellation") {
BridgeCommand::StopServer => tokio_sender
.blocking_send(())
.expect("Failed to send cancellation signal"),
BridgeCommand::StopListening => (),
}
});
let result = running_server.wait().await;
// Server ended on its own: unblock the bridge task above.
// Send failure is ignored — the receiver may already be gone.
_ = cloned_sender.send(BridgeCommand::StopListening);
result
})
});
let mut server = PyRunningGraphServer::new(join_handle, sender, port)?;
if let Some(server_handler) = &server.server_handler {
// Presumably polls the health/URL until reachable or timeout —
// confirm in PyRunningGraphServer::wait_for_server_online.
match PyRunningGraphServer::wait_for_server_online(
&server_handler.client.url,
timeout_ms,
) {
Ok(_) => return Ok(server),
Err(e) => {
// Came up but never became reachable: shut it down before
// reporting the error so the thread is not leaked.
PyRunningGraphServer::stop_server(&mut server, py)?;
Err(e)
}
}
} else {
Err(PyException::new_err("Failed to start server"))
}
}
/// Start the server on `port` and block the calling Python thread until it
/// terminates. The GIL is released while waiting so other Python threads
/// can run. Default online-check timeout is 180 s.
#[pyo3(
signature = (port = 1736, timeout_ms = Some(180000))
)]
pub fn run(
slf: PyRefMut<Self>,
py: Python,
port: u16,
timeout_ms: Option<u64>,
) -> PyResult<()> {
let mut server = Self::start(slf, py, port, timeout_ms)?.server_handler;
py.allow_threads(|| wait_server(&mut server))
}
}