use std::sync::Arc;
use tokio_util::sync::CancellationToken;
use crate::{config::ServerSubConfig, engine::ClawDB, error::ClawDBResult};
/// Message and service types generated by tonic from the `clawdb.v1` protobuf package.
///
/// Only compiled when the build defines the `proto_compiled` cfg; otherwise the gRPC
/// server is unavailable (see `serve`).
#[cfg(proto_compiled)]
pub mod proto {
tonic::include_proto!("clawdb.v1");
}
#[cfg(proto_compiled)]
mod handlers {
    //! tonic `ClawDbService` implementation that adapts gRPC requests onto the [`ClawDB`]
    //! engine.

    use std::pin::Pin;
    use std::sync::Arc;

    use tokio_stream::{wrappers::BroadcastStream, StreamExt};
    use tonic::{Request, Response, Status};
    use uuid::Uuid;

    use super::proto::{
        claw_db_service_server::ClawDbService, BranchRequest, BranchResponse, DiffRequest,
        DiffResponse, EventMessage, HealthRequest, HealthResponse, MemoryEntry, MergeRequest,
        MergeResponse, RecallRequest, RecallResponse, ReflectRequest, ReflectResponse,
        RememberRequest, RememberResponse, SearchRequest, SearchResponse, SessionRequest,
        SessionResponse, StatusRequest, StatusResponse, SyncRequest, SyncResponse,
    };
    use crate::{engine::ClawDB, error::ClawDBError, session::manager::ClawDBSession};

    /// gRPC front-end for [`ClawDB`]; every RPC delegates to the wrapped engine.
    pub struct ClawDBGrpcServer {
        /// Shared handle to the engine that services every request.
        pub engine: Arc<ClawDB>,
    }

    impl ClawDBGrpcServer {
        /// Validates `token` with the engine's session manager and builds the session object
        /// the engine APIs expect.
        ///
        /// # Errors
        ///
        /// Returns the engine's validation error converted into a gRPC [`Status`].
        async fn session_for_token(&self, token: String) -> Result<ClawDBSession, Status> {
            let ctx = self
                .engine
                .session_manager
                .validate(&token)
                .await
                .map_err(map_err)?;
            Ok(ClawDBSession::from_context(ctx))
        }
    }

    /// Resolves the session token for a request.
    ///
    /// The `authorization` metadata header takes precedence. When it is present it must
    /// carry a non-empty `Bearer` credential, and the request-body token is then ignored.
    /// A malformed header is rejected even if the body also has a token. Without the header,
    /// `field_token` (the token from the request message) is used when non-empty.
    ///
    /// # Errors
    ///
    /// Returns `UNAUTHENTICATED` if the header is not valid ASCII, is not a usable `Bearer`
    /// credential, or if no token is supplied at all.
    fn extract_token<T>(req: &Request<T>, field_token: &str) -> Result<String, Status> {
        if let Some(val) = req.metadata().get("authorization") {
            let header = val
                .to_str()
                .map_err(|_| Status::unauthenticated("invalid authorization header"))?;
            let token = bearer_token(header)
                .ok_or_else(|| Status::unauthenticated("expected Bearer token"))?;
            return Ok(token.to_string());
        }
        if !field_token.is_empty() {
            return Ok(field_token.to_string());
        }
        Err(Status::unauthenticated("no session token provided"))
    }

    /// Extracts the credential from a `Bearer <token>` header value.
    ///
    /// The scheme is matched case-insensitively, as HTTP authentication schemes are
    /// (RFC 7235 §2.1). Surrounding whitespace is ignored. Returns `None` for any other
    /// scheme or an empty credential.
    fn bearer_token(header: &str) -> Option<&str> {
        let (scheme, token) = header.trim().split_once(' ')?;
        let token = token.trim();
        (scheme.eq_ignore_ascii_case("bearer") && !token.is_empty()).then_some(token)
    }

    /// Converts an engine error into a gRPC [`Status`] via its `Into<Status>` conversion.
    fn map_err(e: ClawDBError) -> Status {
        e.into()
    }

    /// Projects an engine memory record onto the wire [`MemoryEntry`].
    ///
    /// Only `id` and `content` are copied. Missing or non-string values become empty strings,
    /// and every other field keeps its protobuf default.
    fn memory_entry(v: &serde_json::Value) -> MemoryEntry {
        let text = |key: &str| v[key].as_str().unwrap_or_default().to_owned();
        MemoryEntry {
            id: text("id"),
            content: text("content"),
            ..Default::default()
        }
    }

    /// Reads an integer counter from an engine JSON report as an `i32`.
    ///
    /// A missing or non-integer field reads as `0`. Out-of-range values saturate at
    /// `i32::MIN`/`i32::MAX` instead of wrapping, so a large count never turns negative on
    /// the wire.
    fn json_count(v: &serde_json::Value, key: &str) -> i32 {
        v[key].as_i64().map_or(0, |n| {
            i32::try_from(n).unwrap_or(if n < 0 { i32::MIN } else { i32::MAX })
        })
    }

    #[tonic::async_trait]
    impl ClawDbService for ClawDBGrpcServer {
        /// Stores a memory for the authenticated session.
        ///
        /// An empty `metadata_json` means "no metadata" (`null`). Malformed JSON is rejected
        /// with `INVALID_ARGUMENT` rather than silently discarded.
        async fn remember(
            &self,
            req: Request<RememberRequest>,
        ) -> Result<Response<RememberResponse>, Status> {
            let token = extract_token(&req, &req.get_ref().session_token)?;
            let sess = self.session_for_token(token).await?;
            let r = req.into_inner();
            let metadata = if r.metadata_json.is_empty() {
                serde_json::Value::Null
            } else {
                serde_json::from_slice(&r.metadata_json).map_err(|e| {
                    Status::invalid_argument(format!("metadata_json is not valid JSON: {e}"))
                })?
            };
            let result = self
                .engine
                .remember_typed(&sess, &r.content, &r.memory_type, &r.tags, metadata)
                .await
                .map_err(map_err)?;
            Ok(Response::new(RememberResponse {
                memory_id: result.memory_id,
                importance_score: result.importance_score,
                // NOTE(review): hard-coded; nothing from the engine result is consulted.
                guard_applied: true,
            }))
        }

        /// Runs a semantic or keyword search and reports the engine-side latency.
        async fn search(
            &self,
            req: Request<SearchRequest>,
        ) -> Result<Response<SearchResponse>, Status> {
            let token = extract_token(&req, &req.get_ref().session_token)?;
            let sess = self.session_for_token(token).await?;
            let r = req.into_inner();
            // An empty filter means "no filter". Malformed JSON is a client error, not a
            // silent fallback to an unfiltered (broader) search.
            let filter = if r.filter_json.is_empty() {
                None
            } else {
                let parsed = serde_json::from_slice(&r.filter_json).map_err(|e| {
                    Status::invalid_argument(format!("filter_json is not valid JSON: {e}"))
                })?;
                Some(parsed)
            };
            // `try_from` instead of `as`: a negative wire value must not wrap to a huge limit.
            let top_k = usize::try_from(r.top_k)
                .map_err(|_| Status::invalid_argument("top_k is out of range"))?;
            let start = std::time::Instant::now();
            let results = self
                .engine
                .search_with_options(&sess, &r.query, top_k, r.semantic, filter)
                .await
                .map_err(map_err)?;
            let latency_ms = start.elapsed().as_secs_f32() * 1000.0;
            Ok(Response::new(SearchResponse {
                results: results.into_iter().map(|v| memory_entry(&v)).collect(),
                latency_ms,
                search_type: if r.semantic { "semantic".into() } else { "keyword".into() },
            }))
        }

        /// Fetches memories by id for the authenticated session.
        async fn recall(
            &self,
            req: Request<RecallRequest>,
        ) -> Result<Response<RecallResponse>, Status> {
            let token = extract_token(&req, &req.get_ref().session_token)?;
            let sess = self.session_for_token(token).await?;
            let r = req.into_inner();
            let memories = self
                .engine
                .recall(&sess, &r.memory_ids)
                .await
                .map_err(map_err)?;
            Ok(Response::new(RecallResponse {
                memories: memories.into_iter().map(|v| memory_entry(&v)).collect(),
                // NOTE(review): always empty; ids the engine withheld are not reported.
                denied_ids: vec![],
            }))
        }

        /// Creates a new branch named `new_branch_name`.
        async fn branch(
            &self,
            req: Request<BranchRequest>,
        ) -> Result<Response<BranchResponse>, Status> {
            let token = extract_token(&req, &req.get_ref().session_token)?;
            let sess = self.session_for_token(token).await?;
            let r = req.into_inner();
            let id = self
                .engine
                .branch(&sess, &r.new_branch_name)
                .await
                .map_err(map_err)?;
            Ok(Response::new(BranchResponse {
                branch_id: id.to_string(),
                branch_name: r.new_branch_name,
                // Server time when the response is built, not a value from the engine.
                created_at: chrono::Utc::now().timestamp(),
            }))
        }

        /// Merges `source_branch` into `target_branch` (both given as UUID strings).
        async fn merge(
            &self,
            req: Request<MergeRequest>,
        ) -> Result<Response<MergeResponse>, Status> {
            let token = extract_token(&req, &req.get_ref().session_token)?;
            let sess = self.session_for_token(token).await?;
            let r = req.into_inner();
            let source = Uuid::parse_str(&r.source_branch)
                .map_err(|_| Status::invalid_argument("invalid source_branch UUID"))?;
            let target = Uuid::parse_str(&r.target_branch)
                .map_err(|_| Status::invalid_argument("invalid target_branch UUID"))?;
            self.engine
                .merge(&sess, source, target)
                .await
                .map_err(map_err)?;
            Ok(Response::new(MergeResponse {
                success: true,
                // NOTE(review): placeholder counts; the engine's merge result is not consulted.
                applied: 1,
                conflicts: 0,
                conflict_ids: vec![],
            }))
        }

        /// Diffs two branches (UUID strings) and returns both counts and the raw JSON diff.
        async fn diff(
            &self,
            req: Request<DiffRequest>,
        ) -> Result<Response<DiffResponse>, Status> {
            let token = extract_token(&req, &req.get_ref().session_token)?;
            let sess = self.session_for_token(token).await?;
            let r = req.into_inner();
            let a = Uuid::parse_str(&r.branch_a)
                .map_err(|_| Status::invalid_argument("invalid branch_a UUID"))?;
            let b = Uuid::parse_str(&r.branch_b)
                .map_err(|_| Status::invalid_argument("invalid branch_b UUID"))?;
            let diff = self.engine.diff(&sess, a, b).await.map_err(map_err)?;
            let diff_json = serde_json::to_vec(&diff).unwrap_or_default();
            Ok(Response::new(DiffResponse {
                added: json_count(&diff, "added"),
                removed: json_count(&diff, "removed"),
                modified: json_count(&diff, "modified"),
                // NOTE(review): not computed here.
                divergence_score: 0.0,
                diff_json,
            }))
        }

        /// Triggers an engine sync and reports pushed/pulled counts.
        async fn sync(
            &self,
            req: Request<SyncRequest>,
        ) -> Result<Response<SyncResponse>, Status> {
            let token = extract_token(&req, &req.get_ref().session_token)?;
            let sess = self.session_for_token(token).await?;
            let result = self.engine.sync(&sess).await.map_err(map_err)?;
            Ok(Response::new(SyncResponse {
                success: true,
                pushed: json_count(&result, "pushed"),
                pulled: json_count(&result, "pulled"),
                // NOTE(review): conflicts are not read from the sync result.
                conflicts: 0,
                synced_at: chrono::Utc::now().timestamp(),
            }))
        }

        /// Starts a reflection job and returns its id.
        ///
        /// The counters are always zero in this response.
        async fn reflect(
            &self,
            req: Request<ReflectRequest>,
        ) -> Result<Response<ReflectResponse>, Status> {
            let token = extract_token(&req, &req.get_ref().session_token)?;
            let sess = self.session_for_token(token).await?;
            let job_id = self.engine.reflect(&sess).await.map_err(map_err)?;
            Ok(Response::new(ReflectResponse {
                job_id,
                processed: 0,
                archived: 0,
                promoted: 0,
            }))
        }

        /// Creates a session for `agent_id` with the requested role, scopes and task type.
        ///
        /// NOTE(review): no token is checked here, and the caller-supplied role/scopes go
        /// straight to `session_manager.create`. Confirm that `create` enforces the
        /// issuance policy.
        async fn create_session(
            &self,
            req: Request<SessionRequest>,
        ) -> Result<Response<SessionResponse>, Status> {
            let r = req.into_inner();
            let agent_id = Uuid::parse_str(&r.agent_id)
                .map_err(|_| Status::invalid_argument("invalid agent_id UUID"))?;
            // `r` is owned, so `scopes` and `task_type` can be moved out; no clone needed.
            let sess = self
                .engine
                .session_manager
                .create(agent_id, &r.role, r.scopes, Some(r.task_type))
                .await
                .map_err(map_err)?;
            Ok(Response::new(SessionResponse {
                session_token: sess.guard_token.clone(),
                expires_at: sess.expires_at.timestamp(),
                granted_scopes: sess.scopes,
            }))
        }

        /// Reports engine health, per-component status, crate version and uptime.
        ///
        /// This RPC does not require authentication.
        async fn health(
            &self,
            _req: Request<HealthRequest>,
        ) -> Result<Response<HealthResponse>, Status> {
            let report = self.engine.health().await.map_err(map_err)?;
            let ok = matches!(
                report.overall,
                crate::lifecycle::health::HealthStatus::Healthy
            );
            let component_status: std::collections::HashMap<String, String> = report
                .components
                .into_iter()
                .map(|(k, v)| (k, format!("{:?}", v.status)))
                .collect();
            Ok(Response::new(HealthResponse {
                ok,
                component_status,
                version: env!("CARGO_PKG_VERSION").to_string(),
                uptime_secs: self.engine.uptime_secs() as i64,
            }))
        }

        /// Validates the caller's token and returns an all-default status.
        async fn status(
            &self,
            req: Request<StatusRequest>,
        ) -> Result<Response<StatusResponse>, Status> {
            let token = extract_token(&req, &req.get_ref().session_token)?;
            let _ctx = self
                .engine
                .session_manager
                .validate(&token)
                .await
                .map_err(map_err)?;
            Ok(Response::new(StatusResponse::default()))
        }

        type StreamEventsStream =
            Pin<Box<dyn tokio_stream::Stream<Item = Result<EventMessage, Status>> + Send>>;

        /// Streams every event published on the engine's event bus.
        ///
        /// If the subscriber falls behind the broadcast buffer, the missed events are skipped
        /// and a warning is logged.
        ///
        /// NOTE(review): this RPC performs no token validation. The request's `agent_id` is
        /// stamped onto every event rather than used to filter them. Confirm both are
        /// intended.
        async fn stream_events(
            &self,
            req: Request<SessionRequest>,
        ) -> Result<Response<Self::StreamEventsStream>, Status> {
            let agent_id = req.into_inner().agent_id;
            let rx = self.engine.event_bus.subscribe();
            let stream = BroadcastStream::new(rx).filter_map(move |res| match res {
                Ok(ev) => Some(Ok(EventMessage {
                    event_type: ev.event_type().to_string(),
                    agent_id: agent_id.clone(),
                    payload_json: serde_json::to_vec(&*ev).unwrap_or_default(),
                    // Delivery time, not the time the event was produced.
                    timestamp: chrono::Utc::now().timestamp(),
                })),
                Err(err) => {
                    tracing::warn!(error = %err, "gRPC event stream lagged; missed events dropped");
                    None
                }
            });
            Ok(Response::new(Box::pin(stream)))
        }
    }
}
/// Runs the ClawDB gRPC server until `shutdown` is cancelled.
///
/// The server listens on `0.0.0.0:<config.grpc_port>`. If the crate was built without the
/// `proto_compiled` cfg, no gRPC service exists. In that case a warning is logged and the
/// function just waits for `shutdown`, so callers can treat both builds the same way.
///
/// # Errors
///
/// Returns [`ClawDBError::ComponentFailed`](crate::error::ClawDBError::ComponentFailed) with
/// `component = "grpc"` if the listen address cannot be formed or the transport fails.
pub async fn serve(
    engine: Arc<ClawDB>,
    config: &ServerSubConfig,
    shutdown: CancellationToken,
) -> ClawDBResult<()> {
    #[cfg(proto_compiled)]
    {
        use proto::claw_db_service_server::ClawDbServiceServer;
        use tonic::transport::Server;

        let grpc_error = |reason: String| crate::error::ClawDBError::ComponentFailed {
            component: "grpc".to_string(),
            reason,
        };
        // Report a bad address as an error instead of panicking inside an async task.
        let addr = format!("0.0.0.0:{}", config.grpc_port)
            .parse::<std::net::SocketAddr>()
            .map_err(|e| grpc_error(format!("invalid gRPC listen address: {e}")))?;
        let svc = ClawDbServiceServer::new(handlers::ClawDBGrpcServer { engine });
        tracing::info!(port = config.grpc_port, "gRPC server listening");
        Server::builder()
            .add_service(svc)
            .serve_with_shutdown(addr, async move { shutdown.cancelled().await })
            .await
            .map_err(|e| grpc_error(e.to_string()))?;
    }
    #[cfg(not(proto_compiled))]
    {
        // Without generated stubs, `engine` and `config` have no consumer. Reference them so
        // the build stays warning-free (e.g. under `-D warnings`); `engine` is kept alive
        // until shutdown, as before.
        let _ = (&engine, config);
        tracing::warn!("gRPC server disabled: proto compilation was not available at build time");
        shutdown.cancelled().await;
    }
    Ok(())
}