Skip to main content

brainos_grpcadapter/handlers/
memory.rs

//! `MemoryService` RPC handlers — search, store, get_facts, stream_signals.
2
3use std::pin::Pin;
4
5use tokio_stream::Stream;
6use tonic::{Request, Response, Status};
7
8use signal::{Signal, SignalSource};
9
10use crate::errors::public_status;
11use crate::helpers::{non_empty, response_to_string};
12use crate::memory_proto::{
13    memory_service_server::MemoryService, Fact, GetFactsRequest, GetFactsResponse, SearchRequest,
14    SearchResponse, SignalEvent, SignalRequest as MemorySignalRequest, StoreRequest, StoreResponse,
15};
16use crate::state::MemoryServiceImpl;
17
18/// Stream type alias for the server-streaming `StreamSignals` RPC.
19type SignalEventStream = Pin<Box<dyn Stream<Item = Result<SignalEvent, Status>> + Send + 'static>>;
20
21#[tonic::async_trait]
22impl MemoryService for MemoryServiceImpl {
23    /// Search semantic memory using a text query.
24    async fn search(
25        &self,
26        request: Request<SearchRequest>,
27    ) -> Result<Response<SearchResponse>, Status> {
28        let req = request.into_inner();
29        let top_k = if req.top_k == 0 {
30            10
31        } else {
32            req.top_k as usize
33        };
34
35        let namespace = non_empty(req.namespace);
36
37        let results = self
38            .processor
39            .search_facts(&req.query, top_k, namespace.as_deref())
40            .await;
41
42        let facts = results
43            .into_iter()
44            .map(|r| Fact {
45                id: r.fact.id,
46                category: r.fact.category,
47                subject: r.fact.subject,
48                predicate: r.fact.predicate,
49                object: r.fact.object,
50                confidence: r.fact.confidence,
51                distance: r.distance,
52            })
53            .collect();
54
55        Ok(Response::new(SearchResponse { facts }))
56    }
57
58    /// Store a structured fact in semantic memory.
59    async fn store(
60        &self,
61        request: Request<StoreRequest>,
62    ) -> Result<Response<StoreResponse>, Status> {
63        let req = request.into_inner();
64        let category = non_empty(req.category).unwrap_or_else(|| "general".to_string());
65        let namespace = non_empty(req.namespace).unwrap_or_else(|| "personal".to_string());
66
67        match self
68            .processor
69            .store_fact_direct(
70                &namespace,
71                &category,
72                &req.subject,
73                &req.predicate,
74                &req.object,
75                None,
76            )
77            .await
78        {
79            Ok(fact_id) => Ok(Response::new(StoreResponse {
80                fact_id,
81                success: true,
82                message: "Fact stored successfully".to_string(),
83            })),
84            Err(e) => {
85                tracing::error!(error = %e, "gRPC store_fact failed");
86                Err(public_status(&e))
87            }
88        }
89    }
90
91    /// List all active facts, optionally filtered by subject and/or namespace.
92    async fn get_facts(
93        &self,
94        request: Request<GetFactsRequest>,
95    ) -> Result<Response<GetFactsResponse>, Status> {
96        let req = request.into_inner();
97
98        let namespace = non_empty(req.namespace);
99
100        let raw_facts = if req.subject.is_empty() {
101            self.processor.list_facts(namespace.as_deref())
102        } else {
103            self.processor
104                .facts_about(&req.subject, namespace.as_deref())
105        };
106
107        let facts = raw_facts
108            .into_iter()
109            .map(|f| Fact {
110                id: f.id,
111                category: f.category,
112                subject: f.subject,
113                predicate: f.predicate,
114                object: f.object,
115                confidence: f.confidence,
116                distance: 0.0,
117            })
118            .collect();
119
120        Ok(Response::new(GetFactsResponse { facts }))
121    }
122
123    type StreamSignalsStream = SignalEventStream;
124
125    /// Process a signal and stream the response event(s).
126    async fn stream_signals(
127        &self,
128        request: Request<MemorySignalRequest>,
129    ) -> Result<Response<Self::StreamSignalsStream>, Status> {
130        let principal = self.resolve_principal(&request).await;
131        let req = request.into_inner();
132        let source = SignalSource::parse(Some(&req.source), SignalSource::Grpc);
133
134        let sig = Signal::from_adapter_request(signal::AdapterRequest {
135            source,
136            content: req.content,
137            channel: non_empty(req.channel),
138            sender: non_empty(req.sender),
139            metadata: Some(req.metadata),
140            namespace: non_empty(req.namespace),
141            agent: non_empty(req.agent),
142            session_id: non_empty(req.session_id),
143            default_channel: "grpc".to_string(),
144            default_sender: "grpcclient".to_string(),
145        })
146        .with_principal_opt(principal);
147
148        let processor = self.processor.clone();
149        let (tx, rx) = tokio::sync::mpsc::channel(4);
150
151        tokio::spawn(async move {
152            match processor.process(sig).await {
153                Ok(resp) => {
154                    let event = SignalEvent {
155                        signal_id: resp.signal_id.to_string(),
156                        status: format!("{:?}", resp.status),
157                        response: response_to_string(resp.response),
158                        facts_used: resp.memory_context.facts_used as u32,
159                        episodes_used: resp.memory_context.episodes_used as u32,
160                        session_id: resp.session_id.unwrap_or_default(),
161                    };
162                    let _ = tx.send(Ok(event)).await;
163                }
164                Err(e) => {
165                    tracing::error!(error = %e, "gRPC stream_signals processing failed");
166                    let _ = tx.send(Err(public_status(&e))).await;
167                }
168            }
169        });
170
171        let stream: SignalEventStream = Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx));
172        Ok(Response::new(stream))
173    }
174}