use crate::{
config::{
app_config::AppConfigBuilder, auth_config::PUBLIC_KEY_DECODING_ERR_MSG,
otlp_config::TracingLevel,
},
python::server::{running_server::PyRunningGraphServer, wait_server, BridgeCommand},
GraphServer,
};
use pyo3::{
exceptions::{PyAttributeError, PyException, PyValueError},
prelude::*,
};
use raphtory::{
python::{
packages::vectors::{PyOpenAIEmbeddings, TemplateConfig},
utils::block_on,
},
vectors::template::{DocumentTemplate, DEFAULT_EDGE_TEMPLATE, DEFAULT_NODE_TEMPLATE},
};
use std::{path::PathBuf, thread};
#[pyclass(name = "GraphServer", module = "raphtory.graphql")]
pub struct PyGraphServer(GraphServer);
impl<'py> IntoPyObject<'py> for GraphServer {
type Target = PyGraphServer;
type Output = Bound<'py, Self::Target>;
type Error = <Self::Target as IntoPyObject<'py>>::Error;
fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
PyGraphServer(self).into_pyobject(py)
}
}
fn template_from_python(
nodes: TemplateConfig,
edges: TemplateConfig,
) -> PyResult<DocumentTemplate> {
if nodes.is_disabled() && edges.is_disabled() {
Err(PyAttributeError::new_err(
"at least one of nodes and edges has to be set to True or some string",
))
} else {
Ok(DocumentTemplate {
node_template: nodes.get_template_or(DEFAULT_NODE_TEMPLATE),
edge_template: edges.get_template_or(DEFAULT_EDGE_TEMPLATE),
})
}
}
#[pymethods]
impl PyGraphServer {
#[new]
#[pyo3(
signature = (work_dir, cache_capacity = None, cache_tti_seconds = None, log_level = None, tracing=None, tracing_level=None, otlp_agent_host=None, otlp_agent_port=None, otlp_tracing_service_name=None, auth_public_key=None, auth_enabled_for_reads=None, config_path = None, create_index = None)
)]
fn py_new(
work_dir: PathBuf,
cache_capacity: Option<u64>,
cache_tti_seconds: Option<u64>,
log_level: Option<String>,
tracing: Option<bool>,
tracing_level: Option<String>,
otlp_agent_host: Option<String>,
otlp_agent_port: Option<String>,
otlp_tracing_service_name: Option<String>,
auth_public_key: Option<String>,
auth_enabled_for_reads: Option<bool>,
config_path: Option<PathBuf>,
create_index: Option<bool>,
) -> PyResult<Self> {
let mut app_config_builder = AppConfigBuilder::new();
if let Some(log_level) = log_level {
app_config_builder = app_config_builder.with_log_level(log_level);
}
if let Some(tracing) = tracing {
app_config_builder = app_config_builder.with_tracing(tracing);
}
if let Some(tracing_level) = tracing_level {
let json = format!("\"{}\"", tracing_level).to_uppercase();
let tl: TracingLevel = serde_json::from_str(json.as_str()).map_err(|_| {
PyValueError::new_err(format!(
"Invalid tracing level. Allowed levels {} ",
TracingLevel::all_levels_string()
))
})?;
app_config_builder = app_config_builder.with_tracing_level(tl);
}
if let Some(otlp_agent_host) = otlp_agent_host {
app_config_builder = app_config_builder.with_otlp_agent_host(otlp_agent_host);
}
if let Some(otlp_agent_port) = otlp_agent_port {
app_config_builder = app_config_builder.with_otlp_agent_port(otlp_agent_port);
}
if let Some(otlp_tracing_service_name) = otlp_tracing_service_name {
app_config_builder =
app_config_builder.with_otlp_tracing_service_name(otlp_tracing_service_name);
}
if let Some(cache_capacity) = cache_capacity {
app_config_builder = app_config_builder.with_cache_capacity(cache_capacity);
}
if let Some(cache_tti_seconds) = cache_tti_seconds {
app_config_builder = app_config_builder.with_cache_tti_seconds(cache_tti_seconds);
}
app_config_builder = app_config_builder
.with_auth_public_key(auth_public_key)
.map_err(|_| PyValueError::new_err(PUBLIC_KEY_DECODING_ERR_MSG))?;
if let Some(auth_enabled_for_reads) = auth_enabled_for_reads {
app_config_builder =
app_config_builder.with_auth_enabled_for_reads(auth_enabled_for_reads);
}
#[cfg(feature = "search")]
if let Some(create_index) = create_index {
app_config_builder = app_config_builder.with_create_index(create_index);
}
let app_config = Some(app_config_builder.build());
let server = block_on(GraphServer::new(work_dir, app_config, config_path))?;
Ok(PyGraphServer(server))
}
fn turn_off_index(mut slf: PyRefMut<Self>) {
slf.0.turn_off_index()
}
#[pyo3(
signature = (name, embeddings, nodes = TemplateConfig::Bool(true), edges = TemplateConfig::Bool(true))
)]
fn vectorise_graph(
&self,
py: Python,
name: &str,
embeddings: PyOpenAIEmbeddings,
nodes: TemplateConfig,
edges: TemplateConfig,
) -> PyResult<()> {
let template = template_from_python(nodes, edges)?;
py.allow_threads(|| {
block_on(async move {
self.0
.vectorise_graph(name, &template, embeddings.into())
.await?;
Ok(())
})
})
}
#[pyo3(
signature = (embeddings, nodes = TemplateConfig::Bool(true), edges = TemplateConfig::Bool(true))
)]
fn vectorise_all_graphs(
&self,
py: Python,
embeddings: PyOpenAIEmbeddings,
nodes: TemplateConfig,
edges: TemplateConfig,
) -> PyResult<()> {
let template = template_from_python(nodes, edges)?;
py.allow_threads(|| {
block_on(async move {
self.0
.vectorise_all_graphs(&template, embeddings.into())
.await?;
Ok(())
})
})
}
#[pyo3(
signature = (port = 1736, timeout_ms = 5000)
)]
pub fn start(&self, py: Python, port: u16, timeout_ms: u64) -> PyResult<PyRunningGraphServer> {
let (sender, receiver) = crossbeam_channel::bounded::<BridgeCommand>(1);
let cloned_sender = sender.clone();
let server = self.0.clone();
let join_handle = thread::spawn(move || {
block_on(async move {
let handler = server.start_with_port(port);
let running_server = handler.await?;
let tokio_sender = running_server._get_sender().clone();
tokio::task::spawn_blocking(move || {
match receiver.recv().expect("Failed to wait for cancellation") {
BridgeCommand::StopServer => tokio_sender
.blocking_send(())
.expect("Failed to send cancellation signal"),
BridgeCommand::StopListening => (),
}
});
let result = running_server.wait().await;
_ = cloned_sender.send(BridgeCommand::StopListening);
result
})
});
let mut server = PyRunningGraphServer::new(join_handle, sender, port)?;
if let Some(_server_handler) = &server.server_handler {
let url = format!("http://localhost:{port}");
let result = server.wait_for_server_online(&url, timeout_ms);
match result {
Ok(_) => return Ok(server),
Err(e) => {
PyRunningGraphServer::stop_server(&mut server, py)?;
Err(e)
}
}
} else {
Err(PyException::new_err("Failed to start server"))
}
}
#[pyo3(
signature = (port = 1736, timeout_ms = 180000)
)]
pub fn run(&self, py: Python, port: u16, timeout_ms: u64) -> PyResult<()> {
let mut server = self.start(py, port, timeout_ms)?.server_handler;
py.allow_threads(|| wait_server(&mut server))
}
}