cerebro 0.1.7

Blazing-fast, storage-agnostic semantic memory engine for AI agents, written in pure Rust.
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use cerebro::prelude::*;

/// One JSON-RPC 2.0 message received from the MCP client on stdin.
///
/// `handle_request` treats a message whose `id` is absent (or `null`) as a notification and
/// sends nothing back for it.
#[derive(Debug, Serialize, Deserialize)]
struct RpcRequest {
    /// Protocol version string sent by the client; not currently checked by this server.
    jsonrpc: String,
    /// Request identifier to echo in the reply; `None` for notifications.
    id: Option<Value>,
    /// Method name, e.g. `initialize`, `tools/list` or `tools/call`.
    method: String,
    /// Optional method parameters as raw JSON.
    params: Option<Value>,
}

/// One JSON-RPC 2.0 reply written to stdout.
///
/// Exactly one of `result` / `error` is expected to be `Some`; whichever is `None` is left
/// out of the serialized JSON entirely.
#[derive(Debug, Serialize, Deserialize)]
struct RpcResponse {
    /// Protocol version string; the handlers in this file always set `"2.0"`.
    jsonrpc: String,
    /// Identifier copied from the request being answered.
    id: Value,
    /// Successful result payload.
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<Value>,
    /// Error object (`code` + `message`) when the request failed.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<Value>,
}

async fn handle_request(engine: &MemoryEngine, req: RpcRequest) -> Option<RpcResponse> {
    let id = req.id.clone()?;

    match req.method.as_str() {
        "initialize" => {
            Some(RpcResponse {
                jsonrpc: "2.0".into(), id,
                result: Some(json!({
                    "protocolVersion": "2024-11-05",
                    "capabilities": { "tools": { "listChanged": false } },
                    "serverInfo": { "name": "cerebro-mcp", "version": "0.1.0" }
                })),
                error: None,
            })
        }
        "tools/list" => {
            Some(RpcResponse {
                jsonrpc: "2.0".into(), id,
                result: Some(json!({
                    "tools": [{
                        "name": "ingest_memory",
                        "description": "Ingests a document into the long term semantic memory.",
                        "inputSchema": {
                            "type": "object",
                            "properties": { "content": { "type": "string" } },
                            "required": ["content"]
                        }
                    }, {
                        "name": "search_memory",
                        "description": "Searches the semantic memory for relevant chunks.",
                        "inputSchema": {
                            "type": "object",
                            "properties": { "query": { "type": "string" }, "top_k": { "type": "number" } },
                            "required": ["query"]
                        }
                    }]
                })),
                error: None,
            })
        }
        "tools/call" => {
            let params = req.params.unwrap_or(json!({}));
            let name = params["name"].as_str().unwrap_or("");
            let args = params["arguments"].clone();

            match name {
                "ingest_memory" => {
                    let text = args["content"].as_str().unwrap_or("");
                    let doc = Document::new(text);
                    if let Err(e) = engine.ingest_document(doc).await {
                        return Some(RpcResponse { jsonrpc: "2.0".into(), id, result: None, error: Some(json!({"code": -32603, "message": e.to_string()})) });
                    }
                    Some(RpcResponse { jsonrpc: "2.0".into(), id, result: Some(json!({ "content": [{ "type": "text", "text": "Success" }] })), error: None })
                }
                "search_memory" => {
                    let query = args["query"].as_str().unwrap_or("");
                    let top_k = args["top_k"].as_u64().unwrap_or(5) as usize;
                    match engine.query(query, top_k).await {
                        Ok(results) => {
                            let mut txt = String::new();
                            for (node, score) in results { txt.push_str(&format!("Score {}: {}\n", score, node.chunk.text)); }
                            Some(RpcResponse { jsonrpc: "2.0".into(), id, result: Some(json!({ "content": [{ "type": "text", "text": txt }] })), error: None })
                        }
                        Err(e) => Some(RpcResponse { jsonrpc: "2.0".into(), id, result: None, error: Some(json!({"code": -32603, "message": e.to_string()})) })
                    }
                }
                _ => Some(RpcResponse { jsonrpc: "2.0".into(), id, result: None, error: Some(json!({"code": -32601, "message": "Method not found"})) })
            }
        }
        "notifications/initialized" => None,
        _ => Some(RpcResponse { jsonrpc: "2.0".into(), id, result: None, error: Some(json!({"code": -32601, "message": "Method not found"})) })
    }
}

/// Runs the MCP server over stdio.
///
/// Reads one JSON-RPC message per line from stdin and writes one JSON reply per line to
/// stdout. Diagnostics go to stderr. The loop ends on EOF, on a stdin read error, or when
/// stdout can no longer be written.
#[tokio::main]
async fn main() {
    let chunker = Arc::new(RecursiveCharacterChunker::new(512, 50));

    // Use OpenAI embeddings only when a non-blank key is configured. Otherwise fall back to the
    // mock embedder, whose 1536 dimensions match text-embedding-3-small's default output size.
    let embedder: Arc<dyn Embedder> = match std::env::var("OPENAI_API_KEY") {
        Ok(api_key) if !api_key.trim().is_empty() => {
            Arc::new(OpenAIEmbedder::new(api_key, "text-embedding-3-small"))
        }
        _ => Arc::new(MockEmbedder::new(1536)),
    };

    let store = Arc::new(MemoryVectorStore::new());
    let engine = MemoryEngine::new(chunker, embedder, store);

    let mut reader = BufReader::new(tokio::io::stdin());
    let mut stdout = tokio::io::stdout();
    // Reused across iterations to avoid allocating a fresh String per message.
    let mut line = String::new();

    loop {
        line.clear();
        match reader.read_line(&mut line).await {
            Ok(0) => break, // EOF: the client closed our stdin.
            Ok(_) => {}
            Err(e) => {
                eprintln!("cerebro-mcp: failed to read from stdin: {}", e);
                break;
            }
        }

        let message = line.trim();
        if message.is_empty() {
            continue;
        }

        let req = match serde_json::from_str::<RpcRequest>(message) {
            Ok(req) => req,
            Err(e) => {
                // stdout is reserved for protocol messages, so report the problem on stderr.
                eprintln!("cerebro-mcp: ignoring malformed JSON-RPC message: {}", e);
                continue;
            }
        };

        let resp = match handle_request(&engine, req).await {
            Some(resp) => resp,
            None => continue, // Notification: nothing to send back.
        };

        let mut out = serde_json::to_string(&resp)
            .expect("RpcResponse holds only strings and serde_json::Values, which always serialize");
        out.push('\n');
        if stdout.write_all(out.as_bytes()).await.is_err() || stdout.flush().await.is_err() {
            // The client can no longer receive replies; stop serving.
            break;
        }
    }
}