Skip to main content

converge_knowledge/grpc/
server.rs

1//! gRPC server implementation.
2
3use super::knowledge_service_server::KnowledgeService;
4use super::*;
5use crate::core::{KnowledgeBase, KnowledgeEntry as CoreEntry, SearchOptions};
6
7use std::sync::Arc;
8use std::time::Instant;
9use tokio::sync::RwLock;
10use tokio_stream::wrappers::ReceiverStream;
11use tonic::{Request, Response, Status};
12use uuid::Uuid;
13
14/// gRPC service implementation for the knowledge base.
15pub struct KnowledgeServiceImpl {
16    kb: Arc<RwLock<KnowledgeBase>>,
17    start_time: Instant,
18}
19
20impl KnowledgeServiceImpl {
21    /// Create a new service instance.
22    pub fn new(kb: KnowledgeBase) -> Self {
23        Self {
24            kb: Arc::new(RwLock::new(kb)),
25            start_time: Instant::now(),
26        }
27    }
28
29    /// Create from shared knowledge base.
30    pub fn from_shared(kb: Arc<RwLock<KnowledgeBase>>) -> Self {
31        Self {
32            kb,
33            start_time: Instant::now(),
34        }
35    }
36
37    /// Convert core entry to proto entry.
38    fn to_proto_entry(entry: &CoreEntry) -> KnowledgeEntry {
39        KnowledgeEntry {
40            id: entry.id.to_string(),
41            title: entry.title.clone(),
42            content: entry.content.clone(),
43            category: entry.category.clone(),
44            tags: entry.tags.clone(),
45            source: entry.source.clone(),
46            metadata: entry
47                .metadata
48                .iter()
49                .map(|(k, v)| (k.to_string(), v.to_string()))
50                .collect(),
51            created_at: entry.created_at.to_rfc3339(),
52            updated_at: entry.updated_at.to_rfc3339(),
53            access_count: entry.access_count,
54            learned_relevance: entry.learned_relevance,
55            related_entries: entry
56                .related_entries
57                .iter()
58                .map(|id| id.to_string())
59                .collect(),
60        }
61    }
62
63    /// Convert add request to core entry.
64    fn from_add_request(req: &AddEntryRequest) -> CoreEntry {
65        let mut entry = CoreEntry::new(&req.title, &req.content);
66
67        if let Some(cat) = &req.category {
68            entry = entry.with_category(cat);
69        }
70
71        if !req.tags.is_empty() {
72            entry = entry.with_tags(req.tags.clone());
73        }
74
75        if let Some(src) = &req.source {
76            entry = entry.with_source(src);
77        }
78
79        for (k, v) in &req.metadata {
80            entry = entry.with_metadata(k, v);
81        }
82
83        entry
84    }
85}
86
87#[tonic::async_trait]
88impl KnowledgeService for KnowledgeServiceImpl {
89    async fn add_entry(
90        &self,
91        request: Request<AddEntryRequest>,
92    ) -> std::result::Result<Response<AddEntryResponse>, Status> {
93        let req = request.into_inner();
94        let entry = Self::from_add_request(&req);
95
96        let kb = self.kb.read().await;
97        match kb.add_entry(entry).await {
98            Ok(id) => Ok(Response::new(AddEntryResponse {
99                id: id.to_string(),
100                success: true,
101                error: None,
102            })),
103            Err(e) => Ok(Response::new(AddEntryResponse {
104                id: String::new(),
105                success: false,
106                error: Some(e.to_string()),
107            })),
108        }
109    }
110
111    async fn add_entries(
112        &self,
113        request: Request<AddEntriesRequest>,
114    ) -> std::result::Result<Response<AddEntriesResponse>, Status> {
115        let req = request.into_inner();
116        let entries: Vec<CoreEntry> = req.entries.iter().map(Self::from_add_request).collect();
117
118        let kb = self.kb.read().await;
119        match kb.add_entries(entries).await {
120            Ok(ids) => Ok(Response::new(AddEntriesResponse {
121                ids: ids.iter().map(|id| id.to_string()).collect(),
122                success: true,
123                error: None,
124            })),
125            Err(e) => Ok(Response::new(AddEntriesResponse {
126                ids: Vec::new(),
127                success: false,
128                error: Some(e.to_string()),
129            })),
130        }
131    }
132
133    async fn get_entry(
134        &self,
135        request: Request<GetEntryRequest>,
136    ) -> std::result::Result<Response<GetEntryResponse>, Status> {
137        let req = request.into_inner();
138        let id = Uuid::parse_str(&req.id)
139            .map_err(|e| Status::invalid_argument(format!("Invalid UUID: {}", e)))?;
140
141        let kb = self.kb.read().await;
142        match kb.get(id) {
143            Some(entry) => Ok(Response::new(GetEntryResponse {
144                entry: Some(Self::to_proto_entry(&entry)),
145                found: true,
146            })),
147            None => Ok(Response::new(GetEntryResponse {
148                entry: None,
149                found: false,
150            })),
151        }
152    }
153
154    async fn update_entry(
155        &self,
156        request: Request<UpdateEntryRequest>,
157    ) -> std::result::Result<Response<UpdateEntryResponse>, Status> {
158        let req = request.into_inner();
159        let id = Uuid::parse_str(&req.id)
160            .map_err(|e| Status::invalid_argument(format!("Invalid UUID: {}", e)))?;
161
162        let kb = self.kb.read().await;
163
164        // Get existing entry
165        let mut entry = match kb.get(id) {
166            Some(e) => e,
167            None => {
168                return Ok(Response::new(UpdateEntryResponse {
169                    success: false,
170                    error: Some("Entry not found".to_string()),
171                }));
172            }
173        };
174
175        // Update fields
176        if let Some(title) = req.title {
177            entry.title = title;
178        }
179        if let Some(content) = req.content {
180            entry.content = content;
181        }
182        if let Some(category) = req.category {
183            entry.category = Some(category);
184        }
185        if !req.tags.is_empty() {
186            entry.tags = req.tags;
187        }
188        if let Some(source) = req.source {
189            entry.source = Some(source);
190        }
191        for (k, v) in req.metadata {
192            entry.metadata.insert(k, v);
193        }
194
195        match kb.update_entry(entry).await {
196            Ok(_) => Ok(Response::new(UpdateEntryResponse {
197                success: true,
198                error: None,
199            })),
200            Err(e) => Ok(Response::new(UpdateEntryResponse {
201                success: false,
202                error: Some(e.to_string()),
203            })),
204        }
205    }
206
207    async fn delete_entry(
208        &self,
209        request: Request<DeleteEntryRequest>,
210    ) -> std::result::Result<Response<DeleteEntryResponse>, Status> {
211        let req = request.into_inner();
212        let id = Uuid::parse_str(&req.id)
213            .map_err(|e| Status::invalid_argument(format!("Invalid UUID: {}", e)))?;
214
215        let kb = self.kb.read().await;
216        match kb.delete_entry(id).await {
217            Ok(_) => Ok(Response::new(DeleteEntryResponse {
218                success: true,
219                error: None,
220            })),
221            Err(e) => Ok(Response::new(DeleteEntryResponse {
222                success: false,
223                error: Some(e.to_string()),
224            })),
225        }
226    }
227
228    async fn search(
229        &self,
230        request: Request<SearchRequest>,
231    ) -> std::result::Result<Response<SearchResponse>, Status> {
232        let req = request.into_inner();
233        let start = Instant::now();
234
235        let options = SearchOptions {
236            limit: req.limit as usize,
237            min_similarity: req.min_similarity,
238            category: req.category,
239            tags: req.tags,
240            use_learning: req.use_learning,
241            include_related: req.include_related,
242            diversity: req.diversity,
243            hybrid: req.hybrid,
244            keyword_weight: req.keyword_weight,
245        };
246
247        let kb = self.kb.read().await;
248        match kb.search(&req.query, options).await {
249            Ok(results) => {
250                let elapsed = start.elapsed().as_secs_f32() * 1000.0;
251                let proto_results: Vec<SearchResult> = results
252                    .iter()
253                    .map(|r| SearchResult {
254                        entry: Some(Self::to_proto_entry(&r.entry)),
255                        similarity: r.similarity,
256                        relevance_boost: r.relevance_boost,
257                        score: r.score,
258                        distance: r.distance,
259                    })
260                    .collect();
261
262                Ok(Response::new(SearchResponse {
263                    results: proto_results.clone(),
264                    total_results: proto_results.len() as u32,
265                    search_time_ms: elapsed,
266                }))
267            }
268            Err(e) => Err(Status::internal(e.to_string())),
269        }
270    }
271
272    type SearchStreamStream = ReceiverStream<std::result::Result<SearchResult, Status>>;
273
274    async fn search_stream(
275        &self,
276        request: Request<SearchRequest>,
277    ) -> std::result::Result<Response<Self::SearchStreamStream>, Status> {
278        let req = request.into_inner();
279        let kb = self.kb.clone();
280
281        let (tx, rx) = tokio::sync::mpsc::channel(100);
282
283        tokio::spawn(async move {
284            let options = SearchOptions {
285                limit: req.limit as usize,
286                min_similarity: req.min_similarity,
287                category: req.category,
288                tags: req.tags,
289                use_learning: req.use_learning,
290                include_related: req.include_related,
291                diversity: req.diversity,
292                hybrid: req.hybrid,
293                keyword_weight: req.keyword_weight,
294            };
295
296            let kb = kb.read().await;
297            if let Ok(results) = kb.search(&req.query, options).await {
298                for result in results {
299                    let proto_result = SearchResult {
300                        entry: Some(KnowledgeServiceImpl::to_proto_entry(&result.entry)),
301                        similarity: result.similarity,
302                        relevance_boost: result.relevance_boost,
303                        score: result.score,
304                        distance: result.distance,
305                    };
306
307                    if tx.send(Ok(proto_result)).await.is_err() {
308                        break;
309                    }
310                }
311            }
312        });
313
314        Ok(Response::new(ReceiverStream::new(rx)))
315    }
316
317    async fn record_feedback(
318        &self,
319        request: Request<FeedbackRequest>,
320    ) -> std::result::Result<Response<FeedbackResponse>, Status> {
321        let req = request.into_inner();
322        let id = Uuid::parse_str(&req.entry_id)
323            .map_err(|e| Status::invalid_argument(format!("Invalid UUID: {}", e)))?;
324
325        let kb = self.kb.read().await;
326        match kb.record_feedback(id, req.positive).await {
327            Ok(_) => Ok(Response::new(FeedbackResponse { success: true })),
328            Err(_) => Ok(Response::new(FeedbackResponse { success: false })),
329        }
330    }
331
332    async fn get_related(
333        &self,
334        request: Request<GetRelatedRequest>,
335    ) -> std::result::Result<Response<GetRelatedResponse>, Status> {
336        let req = request.into_inner();
337        let id = Uuid::parse_str(&req.id)
338            .map_err(|e| Status::invalid_argument(format!("Invalid UUID: {}", e)))?;
339
340        let kb = self.kb.read().await;
341        let related = kb.get_related(id, req.limit as usize);
342
343        Ok(Response::new(GetRelatedResponse {
344            entries: related.iter().map(Self::to_proto_entry).collect(),
345        }))
346    }
347
348    async fn link_entries(
349        &self,
350        request: Request<LinkEntriesRequest>,
351    ) -> std::result::Result<Response<LinkEntriesResponse>, Status> {
352        let req = request.into_inner();
353        let id1 = Uuid::parse_str(&req.id1)
354            .map_err(|e| Status::invalid_argument(format!("Invalid UUID: {}", e)))?;
355        let id2 = Uuid::parse_str(&req.id2)
356            .map_err(|e| Status::invalid_argument(format!("Invalid UUID: {}", e)))?;
357
358        let kb = self.kb.read().await;
359        match kb.link_entries(id1, id2).await {
360            Ok(_) => Ok(Response::new(LinkEntriesResponse {
361                success: true,
362                error: None,
363            })),
364            Err(e) => Ok(Response::new(LinkEntriesResponse {
365                success: false,
366                error: Some(e.to_string()),
367            })),
368        }
369    }
370
371    async fn get_stats(
372        &self,
373        _request: Request<GetStatsRequest>,
374    ) -> std::result::Result<Response<GetStatsResponse>, Status> {
375        let kb = self.kb.read().await;
376        let stats = kb.stats();
377
378        Ok(Response::new(GetStatsResponse {
379            total_entries: stats.total_entries as u64,
380            unique_categories: stats.unique_categories as u64,
381            unique_tags: stats.unique_tags as u64,
382            total_access_count: stats.total_access_count,
383            dimensions: stats.dimensions as u32,
384            learning_enabled: stats.learning_enabled,
385            learning_stats: None, // TODO: Add learning stats
386        }))
387    }
388
389    async fn health(
390        &self,
391        _request: Request<HealthRequest>,
392    ) -> std::result::Result<Response<HealthResponse>, Status> {
393        Ok(Response::new(HealthResponse {
394            healthy: true,
395            version: env!("CARGO_PKG_VERSION").to_string(),
396            uptime_seconds: self.start_time.elapsed().as_secs(),
397        }))
398    }
399}