raphtory 0.17.0

raphtory, a temporal graph library
Documentation
use async_openai::types::{CreateEmbeddingResponse, Embedding, EmbeddingUsage};
use axum::{
    extract::{Json, State},
    http::StatusCode,
    routing::post,
    Router,
};
use serde::Deserialize;
use std::{
    net::{IpAddr, SocketAddr},
    panic::{catch_unwind, AssertUnwindSafe},
    sync::Arc,
};
use tokio::{signal, sync::mpsc, task::JoinHandle};

#[derive(Deserialize, Debug)]
struct EmbeddingRequest {
    input: Vec<String>,
}

async fn embeddings(
    State(function): State<Arc<dyn EmbeddingFunction + Send + Sync>>,
    Json(req): Json<EmbeddingRequest>,
) -> Result<Json<CreateEmbeddingResponse>, (StatusCode, String)> {
    let data = req
        .input
        .iter()
        .enumerate()
        .map(|(i, t)| {
            catch_unwind(AssertUnwindSafe(|| function.call(t)))
                .map(|embedding| Embedding {
                    index: i as u32,
                    object: "embedding".into(),
                    embedding,
                })
                .map_err(|_| {
                    (
                        StatusCode::INTERNAL_SERVER_ERROR,
                        "embedding function panicked".to_owned(),
                    )
                })
        })
        .collect::<Result<Vec<_>, _>>()?;
    Ok(Json(CreateEmbeddingResponse {
        object: "list".into(),
        data,
        model: "".to_owned(),
        usage: EmbeddingUsage {
            prompt_tokens: 0,
            total_tokens: 0,
        },
    }))
}

/// Handle to a running embedding server.
///
/// The server itself runs on a spawned Tokio task. This handle owns the only sender of the
/// server's stop channel. Dropping the handle closes that channel, and the server's shutdown
/// future reacts to a closed channel exactly as it does to an explicit stop request.
pub struct EmbeddingServer {
    /// Join handle of the task that drives the HTTP server.
    execution: JoinHandle<()>,
    /// Sending `()` here (or dropping it) starts a graceful shutdown.
    stop_signal: tokio::sync::mpsc::Sender<()>,
}

impl EmbeddingServer {
    /// Waits until the server task has finished.
    ///
    /// The stop sender stays alive until this future completes, so waiting does not by itself
    /// trigger a shutdown.
    ///
    /// # Panics
    ///
    /// Panics if the server task did not complete normally, for example because it panicked
    /// while serving.
    pub async fn wait(self) {
        let outcome = self.execution.await;
        outcome.unwrap();
    }

    /// Requests a graceful shutdown without waiting for it to finish.
    ///
    /// If the server is no longer listening for the signal, the failure is reported on stderr
    /// and otherwise ignored.
    pub async fn stop(&self) {
        match self.stop_signal.send(()).await {
            Ok(()) => {}
            Err(err) => eprintln!("Failed to send stop signal to embedding server: {}", err),
        }
    }
}

/// Runs the embedding server on the given host and port based on the provided function. Host is "0.0.0.0" by default.
///
/// The server exposes `POST /embeddings` and answers OpenAI-style embedding requests by calling
/// `function` on every input string. The TCP listener is bound before this function returns, so
/// connections are accepted as soon as the returned [`EmbeddingServer`] exists.
///
/// The server shuts down gracefully on any of:
/// - SIGTERM (Unix only);
/// - Ctrl-C;
/// - [`EmbeddingServer::stop`];
/// - the returned handle being dropped, which closes the stop channel.
///
/// # Panics
///
/// Panics if `host` is not a valid IP address, or if the address cannot be bound (for example
/// because the port is already in use). The spawned server task panics if the SIGTERM handler
/// cannot be installed or if serving fails.
pub async fn serve_custom_embedding(
    host: Option<&str>,
    port: u16,
    function: impl EmbeddingFunction,
) -> EmbeddingServer {
    let host = host.unwrap_or("0.0.0.0");
    let ip_addr = host
        .parse::<IpAddr>()
        .unwrap_or_else(|err| panic!("invalid IP address {:?}: {}", host, err));
    let state = Arc::new(function);
    let app = Router::new()
        .route("/embeddings", post(embeddings)) // TODO: this should be /v1/embeddings if we were to support multiple versions
        .with_state(state);
    // since the listener is created at this point, when this function returns the server is already available,
    // might just take some time to answer for the first time, but no requests should be rejected
    let socket_addr = SocketAddr::new(ip_addr, port);
    let listener = tokio::net::TcpListener::bind(socket_addr)
        .await
        .unwrap_or_else(|err| {
            panic!(
                "failed to bind embedding server to {}: {}",
                socket_addr, err
            )
        });
    let (sender, mut receiver) = mpsc::channel(1);
    let execution = tokio::spawn(async move {
        axum::serve(listener, app)
            .with_graceful_shutdown(async move {
                #[cfg(unix)]
                let terminate = async {
                    signal::unix::signal(signal::unix::SignalKind::terminate())
                        .expect("failed to install signal handler")
                        .recv()
                        .await;
                };
                #[cfg(not(unix))]
                let terminate = std::future::pending::<()>();

                // `ctrl_c()` resolves immediately with an error if the handler cannot be
                // registered. Without this guard that error would shut the server down right
                // after startup, so the failure is reported and Ctrl-C is ignored instead.
                let ctrl_c = async {
                    if let Err(err) = signal::ctrl_c().await {
                        eprintln!("Failed to listen for Ctrl-C in embedding server: {}", err);
                        std::future::pending::<()>().await;
                    }
                };

                tokio::select! {
                    _ = terminate => {},
                    _ = ctrl_c => {},
                    // `Some(())` comes from `EmbeddingServer::stop`; `None` means the sender held
                    // by the returned handle was dropped. Both trigger shutdown.
                    _ = receiver.recv() => {},
                }
            })
            .await
            .expect("embedding server failed while serving requests");
    });
    EmbeddingServer {
        execution,
        stop_signal: sender,
    }
}

/// A synchronous function that turns a piece of text into an embedding vector.
///
/// The bounds `Send + Sync + 'static` let a single instance be shared behind an `Arc` by
/// concurrent request handlers.
pub trait EmbeddingFunction: Send + Sync + 'static {
    /// Computes the embedding of `text`.
    fn call(&self, text: &str) -> Vec<f32>;
}

/// Any thread-safe closure or function from `&str` to `Vec<f32>` is an [`EmbeddingFunction`].
impl<F> EmbeddingFunction for F
where
    F: Fn(&str) -> Vec<f32> + Send + Sync + 'static,
{
    fn call(&self, text: &str) -> Vec<f32> {
        let embed = self;
        embed(text)
    }
}