// clawdb 0.1.2
//
// The cognitive database for AI agents — unified memory, semantic retrieval, branching, sync, and governance.
// Documentation
//! Production gRPC server: tonic implementation of `ClawDBService`.
//!
//! # Proto-compilation gate
//! All handler code is guarded behind `#[cfg(proto_compiled)]` so the crate
//! still compiles on machines where `protoc` is absent.  The `build.rs` sets
//! the `proto_compiled` cfg flag when proto compilation succeeds.
//!
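//! # Auth
//! Session-scoped RPCs read a bearer token from the `authorization` metadata
//! header (format: `Bearer <token>`) or, when that header is absent, from the
//! `session_token` field on the request message; a present header always takes
//! precedence.  The token is validated via `SessionManager::validate`.
//! Bootstrap and probe RPCs such as `CreateSession` and `Health` are served
//! without a token.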

use std::sync::Arc;

use tokio_util::sync::CancellationToken;

use crate::{config::ServerSubConfig, engine::ClawDB, error::ClawDBResult};

// ── Proto types (generated by build.rs when protoc is available) ──────────────

/// Protobuf message types and the `ClawDbService` server trait for package
/// `clawdb.v1`, included from the code that `build.rs` generates.
///
/// Only present when the `proto_compiled` cfg flag is set, i.e. when proto
/// compilation succeeded at build time (see the module-level docs).
#[cfg(proto_compiled)]
pub mod proto {
    // Expands to an `include!` of the tonic-build output for package `clawdb.v1`.
    tonic::include_proto!("clawdb.v1");
}

// ── gRPC service implementation ───────────────────────────────────────────────

/// tonic handlers that adapt `ClawDbService` RPCs onto the shared engine.
/// Compiled only when protobuf codegen ran (`proto_compiled`).
#[cfg(proto_compiled)]
mod handlers {
    use std::pin::Pin;
    use std::sync::Arc;

    use tokio_stream::{wrappers::BroadcastStream, StreamExt};
    use tonic::{Request, Response, Status};
    use uuid::Uuid;

    use super::proto::{
        claw_db_service_server::ClawDbService, BranchRequest, BranchResponse, DiffRequest,
        DiffResponse, EventMessage, HealthRequest, HealthResponse, MemoryEntry, MergeRequest,
        MergeResponse, RecallRequest, RecallResponse, ReflectRequest, ReflectResponse,
        RememberRequest, RememberResponse, SearchRequest, SearchResponse, SessionRequest,
        SessionResponse, StatusRequest, StatusResponse, SyncRequest, SyncResponse,
    };
    use crate::{engine::ClawDB, error::ClawDBError, session::manager::ClawDBSession};

    /// gRPC service object; every handler delegates to `engine`.
    pub struct ClawDBGrpcServer {
        pub engine: Arc<ClawDB>,
    }

    // ── Token extraction helpers ──────────────────────────────────────────────

    /// Returns the session token for a request.
    ///
    /// The `authorization` metadata header takes precedence. When it is present
    /// it must be valid ASCII of the form `Bearer <token>` (scheme matched
    /// case-insensitively, token non-empty). Otherwise the request is rejected;
    /// `field_token` is *not* used as a fallback in that case. When the header is
    /// absent, a non-empty `field_token` is used.
    ///
    /// # Errors
    /// `Status::unauthenticated` if the header is malformed or no token is supplied.
    fn extract_token<T>(req: &Request<T>, field_token: &str) -> Result<String, Status> {
        if let Some(val) = req.metadata().get("authorization") {
            let header = val
                .to_str()
                .map_err(|_| Status::unauthenticated("invalid authorization header"))?;
            return parse_bearer(header)
                .map(str::to_owned)
                .ok_or_else(|| Status::unauthenticated("expected Bearer token"));
        }
        if !field_token.is_empty() {
            return Ok(field_token.to_string());
        }
        Err(Status::unauthenticated("no session token provided"))
    }

    /// Parses an `Authorization` value of the form `Bearer <token>`.
    ///
    /// The auth scheme is compared case-insensitively (RFC 7235 §2.1) and
    /// surrounding whitespace is ignored. Returns `None` for any other scheme or
    /// for an empty token.
    fn parse_bearer(header: &str) -> Option<&str> {
        let (scheme, token) = header.trim().split_once(' ')?;
        let token = token.trim();
        if scheme.eq_ignore_ascii_case("bearer") && !token.is_empty() {
            Some(token)
        } else {
            None
        }
    }

    /// Converts an engine error into a gRPC status using the crate's existing
    /// `ClawDBError` → `tonic::Status` conversion.
    fn map_err(e: ClawDBError) -> Status {
        e.into()
    }

    /// Builds a proto `MemoryEntry` from a JSON memory record.
    ///
    /// Only the `id` and `content` string fields are copied; a missing or
    /// non-string value becomes `""`. All other fields keep their proto defaults.
    fn memory_entry(v: &serde_json::Value) -> MemoryEntry {
        MemoryEntry {
            id: v["id"].as_str().unwrap_or("").to_string(),
            content: v["content"].as_str().unwrap_or("").to_string(),
            ..Default::default()
        }
    }

    /// Reads a JSON integer as an `i32` count.
    ///
    /// Out-of-range values saturate at the `i32` bounds instead of wrapping.
    /// Missing or non-integer values yield `0`.
    fn json_count(v: &serde_json::Value) -> i32 {
        v.as_i64()
            // The cast is lossless once the value is clamped into i32 range.
            .map_or(0, |n| n.clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32)
    }

    impl ClawDBGrpcServer {
        /// Validates `token` with the engine's session manager and wraps the
        /// resulting context in a `ClawDBSession` for engine calls.
        async fn authenticate(&self, token: &str) -> Result<ClawDBSession, Status> {
            let ctx = self.engine.session_manager.validate(token).await.map_err(map_err)?;
            Ok(ClawDBSession::from_context(ctx))
        }
    }

    // ── Service implementation ────────────────────────────────────────────────

    #[tonic::async_trait]
    impl ClawDbService for ClawDBGrpcServer {
        /// Stores a memory. An empty `metadata_json` means "no metadata" (`null`);
        /// a non-empty value that is not valid JSON is rejected.
        async fn remember(
            &self,
            req: Request<RememberRequest>,
        ) -> Result<Response<RememberResponse>, Status> {
            let token = extract_token(&req, &req.get_ref().session_token)?;
            let sess = self.authenticate(&token).await?;
            let r = req.into_inner();
            let metadata: serde_json::Value = if r.metadata_json.is_empty() {
                serde_json::Value::Null
            } else {
                // Reject malformed metadata instead of silently storing `null`.
                serde_json::from_slice(&r.metadata_json).map_err(|e| {
                    Status::invalid_argument(format!("invalid metadata_json: {e}"))
                })?
            };
            let result = self
                .engine
                .remember_typed(&sess, &r.content, &r.memory_type, &r.tags, metadata)
                .await
                .map_err(map_err)?;
            Ok(Response::new(RememberResponse {
                memory_id: result.memory_id,
                importance_score: result.importance_score,
                // NOTE(review): hard-coded; not derived from the engine result.
                guard_applied: true,
            }))
        }

        /// Searches memories. `top_k` must be non-negative. An empty
        /// `filter_json` means "no filter"; malformed JSON is rejected rather than
        /// silently widening the search to unfiltered results.
        async fn search(
            &self,
            req: Request<SearchRequest>,
        ) -> Result<Response<SearchResponse>, Status> {
            let token = extract_token(&req, &req.get_ref().session_token)?;
            let sess = self.authenticate(&token).await?;
            let r = req.into_inner();
            // `as usize` would turn a negative value into an enormous limit.
            let top_k = usize::try_from(r.top_k)
                .map_err(|_| Status::invalid_argument("top_k must be non-negative"))?;
            let filter = if r.filter_json.is_empty() {
                None
            } else {
                Some(serde_json::from_slice(&r.filter_json).map_err(|e| {
                    Status::invalid_argument(format!("invalid filter_json: {e}"))
                })?)
            };
            let start = std::time::Instant::now();
            let results = self
                .engine
                .search_with_options(&sess, &r.query, top_k, r.semantic, filter)
                .await
                .map_err(map_err)?;
            let latency_ms = start.elapsed().as_secs_f32() * 1000.0;
            let entries: Vec<_> = results.into_iter().map(|v| memory_entry(&v)).collect();
            Ok(Response::new(SearchResponse {
                results: entries,
                latency_ms,
                search_type: if r.semantic { "semantic".into() } else { "keyword".into() },
            }))
        }

        /// Fetches memories by id.
        async fn recall(
            &self,
            req: Request<RecallRequest>,
        ) -> Result<Response<RecallResponse>, Status> {
            let token = extract_token(&req, &req.get_ref().session_token)?;
            let sess = self.authenticate(&token).await?;
            let r = req.into_inner();
            let mems = self.engine.recall(&sess, &r.memory_ids).await.map_err(map_err)?;
            let entries: Vec<_> = mems.into_iter().map(|v| memory_entry(&v)).collect();
            Ok(Response::new(RecallResponse {
                memories: entries,
                denied_ids: vec![],
            }))
        }

        /// Creates a new branch named `new_branch_name`.
        async fn branch(
            &self,
            req: Request<BranchRequest>,
        ) -> Result<Response<BranchResponse>, Status> {
            let token = extract_token(&req, &req.get_ref().session_token)?;
            let sess = self.authenticate(&token).await?;
            let r = req.into_inner();
            let id = self.engine.branch(&sess, &r.new_branch_name).await.map_err(map_err)?;
            Ok(Response::new(BranchResponse {
                branch_id: id.to_string(),
                branch_name: r.new_branch_name,
                created_at: chrono::Utc::now().timestamp(),
            }))
        }

        /// Merges `source_branch` into `target_branch` (both UUID strings).
        async fn merge(
            &self,
            req: Request<MergeRequest>,
        ) -> Result<Response<MergeResponse>, Status> {
            let token = extract_token(&req, &req.get_ref().session_token)?;
            let sess = self.authenticate(&token).await?;
            let r = req.into_inner();
            let source = Uuid::parse_str(&r.source_branch)
                .map_err(|_| Status::invalid_argument("invalid source_branch UUID"))?;
            let target = Uuid::parse_str(&r.target_branch)
                .map_err(|_| Status::invalid_argument("invalid target_branch UUID"))?;
            self.engine.merge(&sess, source, target).await.map_err(map_err)?;
            // NOTE(review): the engine's merge result is discarded and these
            // counters are fixed placeholders.
            Ok(Response::new(MergeResponse {
                success: true,
                applied: 1,
                conflicts: 0,
                conflict_ids: vec![],
            }))
        }

        /// Diffs two branches (UUID strings). Counts are read from the engine's
        /// JSON diff and saturate at the `i32` bounds.
        async fn diff(
            &self,
            req: Request<DiffRequest>,
        ) -> Result<Response<DiffResponse>, Status> {
            let token = extract_token(&req, &req.get_ref().session_token)?;
            let sess = self.authenticate(&token).await?;
            let r = req.into_inner();
            let a = Uuid::parse_str(&r.branch_a)
                .map_err(|_| Status::invalid_argument("invalid branch_a UUID"))?;
            let b = Uuid::parse_str(&r.branch_b)
                .map_err(|_| Status::invalid_argument("invalid branch_b UUID"))?;
            let diff = self.engine.diff(&sess, a, b).await.map_err(map_err)?;
            let diff_json = serde_json::to_vec(&diff).unwrap_or_default();
            Ok(Response::new(DiffResponse {
                added: json_count(&diff["added"]),
                removed: json_count(&diff["removed"]),
                modified: json_count(&diff["modified"]),
                divergence_score: 0.0,
                diff_json,
            }))
        }

        /// Triggers a sync. The pushed/pulled counts come from the engine's JSON
        /// result and saturate at the `i32` bounds.
        async fn sync(
            &self,
            req: Request<SyncRequest>,
        ) -> Result<Response<SyncResponse>, Status> {
            let token = extract_token(&req, &req.get_ref().session_token)?;
            let sess = self.authenticate(&token).await?;
            let result = self.engine.sync(&sess).await.map_err(map_err)?;
            Ok(Response::new(SyncResponse {
                success: true,
                pushed: json_count(&result["pushed"]),
                pulled: json_count(&result["pulled"]),
                conflicts: 0,
                synced_at: chrono::Utc::now().timestamp(),
            }))
        }

        /// Starts a reflection job and returns its id.
        async fn reflect(
            &self,
            req: Request<ReflectRequest>,
        ) -> Result<Response<ReflectResponse>, Status> {
            let token = extract_token(&req, &req.get_ref().session_token)?;
            let sess = self.authenticate(&token).await?;
            let job_id = self.engine.reflect(&sess).await.map_err(map_err)?;
            // NOTE(review): counters are placeholders; only `job_id` is real.
            Ok(Response::new(ReflectResponse {
                job_id,
                processed: 0,
                archived: 0,
                promoted: 0,
            }))
        }

        /// Creates a session for `agent_id` (a UUID string). This RPC takes no
        /// bearer token; it is how a caller obtains one.
        async fn create_session(
            &self,
            req: Request<SessionRequest>,
        ) -> Result<Response<SessionResponse>, Status> {
            let r = req.into_inner();
            let agent_id = Uuid::parse_str(&r.agent_id)
                .map_err(|_| Status::invalid_argument("invalid agent_id UUID"))?;
            let sess = self
                .engine
                .session_manager
                .create(agent_id, &r.role, r.scopes, Some(r.task_type))
                .await
                .map_err(map_err)?;
            Ok(Response::new(SessionResponse {
                session_token: sess.guard_token,
                expires_at: sess.expires_at.timestamp(),
                granted_scopes: sess.scopes,
            }))
        }

        /// Reports engine health. Unauthenticated so probes can call it.
        async fn health(
            &self,
            _req: Request<HealthRequest>,
        ) -> Result<Response<HealthResponse>, Status> {
            let report = self.engine.health().await.map_err(map_err)?;
            let ok = matches!(
                report.overall,
                crate::lifecycle::health::HealthStatus::Healthy
            );
            let component_status: std::collections::HashMap<String, String> = report
                .components
                .into_iter()
                .map(|(k, v)| (k, format!("{:?}", v.status)))
                .collect();
            Ok(Response::new(HealthResponse {
                ok,
                component_status,
                version: env!("CARGO_PKG_VERSION").to_string(),
                uptime_secs: self.engine.uptime_secs() as i64,
            }))
        }

        /// Validates the caller's token, then returns a default `StatusResponse`.
        async fn status(
            &self,
            req: Request<StatusRequest>,
        ) -> Result<Response<StatusResponse>, Status> {
            let token = extract_token(&req, &req.get_ref().session_token)?;
            let _ctx = self.engine.session_manager.validate(&token).await.map_err(map_err)?;
            Ok(Response::new(StatusResponse::default()))
        }

        type StreamEventsStream =
            Pin<Box<dyn tokio_stream::Stream<Item = Result<EventMessage, Status>> + Send>>;

        /// Streams engine events to an authenticated subscriber.
        ///
        /// The token is read only from the `authorization` header; this handler
        /// does not consult any request-body token field. Events dropped because
        /// the subscriber lagged behind the broadcast channel are logged and
        /// skipped.
        async fn stream_events(
            &self,
            req: Request<SessionRequest>,
        ) -> Result<Response<Self::StreamEventsStream>, Status> {
            // Without this check any caller could tap the whole event bus.
            let token = extract_token(&req, "")?;
            let _ctx = self.engine.session_manager.validate(&token).await.map_err(map_err)?;
            let agent_id = req.into_inner().agent_id;
            let rx = self.engine.event_bus.subscribe();
            let stream = BroadcastStream::new(rx).filter_map(move |res| match res {
                Ok(ev) => Some(Ok(EventMessage {
                    event_type: ev.event_type().to_string(),
                    // NOTE(review): this is the subscriber's requested agent_id,
                    // not necessarily the agent that produced the event.
                    agent_id: agent_id.clone(),
                    payload_json: serde_json::to_vec(&*ev).unwrap_or_default(),
                    timestamp: chrono::Utc::now().timestamp(),
                })),
                Err(err) => {
                    tracing::warn!(error = %err, "event stream subscriber lagged; events dropped");
                    None
                }
            });
            Ok(Response::new(Box::pin(stream)))
        }
    }
}

// ── Serve function ────────────────────────────────────────────────────────────

/// Runs the gRPC server until `shutdown` is cancelled.
///
/// With the `proto_compiled` cfg, this serves `ClawDbService` on
/// `0.0.0.0:<config.grpc_port>`. It returns once `shutdown` fires and the
/// server has stopped.
///
/// Without that cfg (proto compilation was unavailable at build time), the
/// server is disabled. The function logs a warning, waits for `shutdown`, and
/// then returns `Ok(())`, so callers can drive both builds the same way.
///
/// # Errors
/// Returns `ClawDBError::ComponentFailed` with component `"grpc"` if the listen
/// address cannot be parsed or the tonic transport fails to bind or serve.
pub async fn serve(
    engine: Arc<ClawDB>,
    config: &ServerSubConfig,
    shutdown: CancellationToken,
) -> ClawDBResult<()> {
    #[cfg(proto_compiled)]
    {
        use proto::claw_db_service_server::ClawDbServiceServer;
        use tonic::transport::Server;

        // All failures of this component are reported the same way.
        let grpc_failed = |reason: String| crate::error::ClawDBError::ComponentFailed {
            component: "grpc".to_string(),
            reason,
        };

        // Report a bad address as an error instead of panicking the caller.
        let addr = format!("0.0.0.0:{}", config.grpc_port)
            .parse::<std::net::SocketAddr>()
            .map_err(|e| grpc_failed(format!("invalid listen address: {e}")))?;
        let svc = ClawDbServiceServer::new(handlers::ClawDBGrpcServer { engine });

        tracing::info!(port = config.grpc_port, "gRPC server listening");

        Server::builder()
            .add_service(svc)
            .serve_with_shutdown(addr, async move { shutdown.cancelled().await })
            .await
            .map_err(|e| grpc_failed(e.to_string()))?;
    }

    #[cfg(not(proto_compiled))]
    {
        // Nothing to serve; consume the parameters so this build stays warning-free.
        let _ = (engine, config);
        tracing::warn!("gRPC server disabled: proto compilation was not available at build time");
        shutdown.cancelled().await;
    }

    Ok(())
}