1use super::knowledge_service_server::KnowledgeService;
4use super::*;
5use crate::core::{KnowledgeBase, KnowledgeEntry as CoreEntry, SearchOptions};
6
7use std::sync::Arc;
8use std::time::Instant;
9use tokio::sync::RwLock;
10use tokio_stream::wrappers::ReceiverStream;
11use tonic::{Request, Response, Status};
12use uuid::Uuid;
13
14pub struct KnowledgeServiceImpl {
16 kb: Arc<RwLock<KnowledgeBase>>,
17 start_time: Instant,
18}
19
20impl KnowledgeServiceImpl {
21 pub fn new(kb: KnowledgeBase) -> Self {
23 Self {
24 kb: Arc::new(RwLock::new(kb)),
25 start_time: Instant::now(),
26 }
27 }
28
29 pub fn from_shared(kb: Arc<RwLock<KnowledgeBase>>) -> Self {
31 Self {
32 kb,
33 start_time: Instant::now(),
34 }
35 }
36
37 fn to_proto_entry(entry: &CoreEntry) -> KnowledgeEntry {
39 KnowledgeEntry {
40 id: entry.id.to_string(),
41 title: entry.title.clone(),
42 content: entry.content.clone(),
43 category: entry.category.clone(),
44 tags: entry.tags.clone(),
45 source: entry.source.clone(),
46 metadata: entry
47 .metadata
48 .iter()
49 .map(|(k, v)| (k.to_string(), v.to_string()))
50 .collect(),
51 created_at: entry.created_at.to_rfc3339(),
52 updated_at: entry.updated_at.to_rfc3339(),
53 access_count: entry.access_count,
54 learned_relevance: entry.learned_relevance,
55 related_entries: entry
56 .related_entries
57 .iter()
58 .map(|id| id.to_string())
59 .collect(),
60 }
61 }
62
63 fn from_add_request(req: &AddEntryRequest) -> CoreEntry {
65 let mut entry = CoreEntry::new(&req.title, &req.content);
66
67 if let Some(cat) = &req.category {
68 entry = entry.with_category(cat);
69 }
70
71 if !req.tags.is_empty() {
72 entry = entry.with_tags(req.tags.clone());
73 }
74
75 if let Some(src) = &req.source {
76 entry = entry.with_source(src);
77 }
78
79 for (k, v) in &req.metadata {
80 entry = entry.with_metadata(k, v);
81 }
82
83 entry
84 }
85}
86
87#[tonic::async_trait]
88impl KnowledgeService for KnowledgeServiceImpl {
89 async fn add_entry(
90 &self,
91 request: Request<AddEntryRequest>,
92 ) -> std::result::Result<Response<AddEntryResponse>, Status> {
93 let req = request.into_inner();
94 let entry = Self::from_add_request(&req);
95
96 let kb = self.kb.read().await;
97 match kb.add_entry(entry).await {
98 Ok(id) => Ok(Response::new(AddEntryResponse {
99 id: id.to_string(),
100 success: true,
101 error: None,
102 })),
103 Err(e) => Ok(Response::new(AddEntryResponse {
104 id: String::new(),
105 success: false,
106 error: Some(e.to_string()),
107 })),
108 }
109 }
110
111 async fn add_entries(
112 &self,
113 request: Request<AddEntriesRequest>,
114 ) -> std::result::Result<Response<AddEntriesResponse>, Status> {
115 let req = request.into_inner();
116 let entries: Vec<CoreEntry> = req.entries.iter().map(Self::from_add_request).collect();
117
118 let kb = self.kb.read().await;
119 match kb.add_entries(entries).await {
120 Ok(ids) => Ok(Response::new(AddEntriesResponse {
121 ids: ids.iter().map(|id| id.to_string()).collect(),
122 success: true,
123 error: None,
124 })),
125 Err(e) => Ok(Response::new(AddEntriesResponse {
126 ids: Vec::new(),
127 success: false,
128 error: Some(e.to_string()),
129 })),
130 }
131 }
132
133 async fn get_entry(
134 &self,
135 request: Request<GetEntryRequest>,
136 ) -> std::result::Result<Response<GetEntryResponse>, Status> {
137 let req = request.into_inner();
138 let id = Uuid::parse_str(&req.id)
139 .map_err(|e| Status::invalid_argument(format!("Invalid UUID: {}", e)))?;
140
141 let kb = self.kb.read().await;
142 match kb.get(id) {
143 Some(entry) => Ok(Response::new(GetEntryResponse {
144 entry: Some(Self::to_proto_entry(&entry)),
145 found: true,
146 })),
147 None => Ok(Response::new(GetEntryResponse {
148 entry: None,
149 found: false,
150 })),
151 }
152 }
153
154 async fn update_entry(
155 &self,
156 request: Request<UpdateEntryRequest>,
157 ) -> std::result::Result<Response<UpdateEntryResponse>, Status> {
158 let req = request.into_inner();
159 let id = Uuid::parse_str(&req.id)
160 .map_err(|e| Status::invalid_argument(format!("Invalid UUID: {}", e)))?;
161
162 let kb = self.kb.read().await;
163
164 let mut entry = match kb.get(id) {
166 Some(e) => e,
167 None => {
168 return Ok(Response::new(UpdateEntryResponse {
169 success: false,
170 error: Some("Entry not found".to_string()),
171 }));
172 }
173 };
174
175 if let Some(title) = req.title {
177 entry.title = title;
178 }
179 if let Some(content) = req.content {
180 entry.content = content;
181 }
182 if let Some(category) = req.category {
183 entry.category = Some(category);
184 }
185 if !req.tags.is_empty() {
186 entry.tags = req.tags;
187 }
188 if let Some(source) = req.source {
189 entry.source = Some(source);
190 }
191 for (k, v) in req.metadata {
192 entry.metadata.insert(k, v);
193 }
194
195 match kb.update_entry(entry).await {
196 Ok(_) => Ok(Response::new(UpdateEntryResponse {
197 success: true,
198 error: None,
199 })),
200 Err(e) => Ok(Response::new(UpdateEntryResponse {
201 success: false,
202 error: Some(e.to_string()),
203 })),
204 }
205 }
206
207 async fn delete_entry(
208 &self,
209 request: Request<DeleteEntryRequest>,
210 ) -> std::result::Result<Response<DeleteEntryResponse>, Status> {
211 let req = request.into_inner();
212 let id = Uuid::parse_str(&req.id)
213 .map_err(|e| Status::invalid_argument(format!("Invalid UUID: {}", e)))?;
214
215 let kb = self.kb.read().await;
216 match kb.delete_entry(id).await {
217 Ok(_) => Ok(Response::new(DeleteEntryResponse {
218 success: true,
219 error: None,
220 })),
221 Err(e) => Ok(Response::new(DeleteEntryResponse {
222 success: false,
223 error: Some(e.to_string()),
224 })),
225 }
226 }
227
228 async fn search(
229 &self,
230 request: Request<SearchRequest>,
231 ) -> std::result::Result<Response<SearchResponse>, Status> {
232 let req = request.into_inner();
233 let start = Instant::now();
234
235 let options = SearchOptions {
236 limit: req.limit as usize,
237 min_similarity: req.min_similarity,
238 category: req.category,
239 tags: req.tags,
240 use_learning: req.use_learning,
241 include_related: req.include_related,
242 diversity: req.diversity,
243 hybrid: req.hybrid,
244 keyword_weight: req.keyword_weight,
245 };
246
247 let kb = self.kb.read().await;
248 match kb.search(&req.query, options).await {
249 Ok(results) => {
250 let elapsed = start.elapsed().as_secs_f32() * 1000.0;
251 let proto_results: Vec<SearchResult> = results
252 .iter()
253 .map(|r| SearchResult {
254 entry: Some(Self::to_proto_entry(&r.entry)),
255 similarity: r.similarity,
256 relevance_boost: r.relevance_boost,
257 score: r.score,
258 distance: r.distance,
259 })
260 .collect();
261
262 Ok(Response::new(SearchResponse {
263 results: proto_results.clone(),
264 total_results: proto_results.len() as u32,
265 search_time_ms: elapsed,
266 }))
267 }
268 Err(e) => Err(Status::internal(e.to_string())),
269 }
270 }
271
272 type SearchStreamStream = ReceiverStream<std::result::Result<SearchResult, Status>>;
273
274 async fn search_stream(
275 &self,
276 request: Request<SearchRequest>,
277 ) -> std::result::Result<Response<Self::SearchStreamStream>, Status> {
278 let req = request.into_inner();
279 let kb = self.kb.clone();
280
281 let (tx, rx) = tokio::sync::mpsc::channel(100);
282
283 tokio::spawn(async move {
284 let options = SearchOptions {
285 limit: req.limit as usize,
286 min_similarity: req.min_similarity,
287 category: req.category,
288 tags: req.tags,
289 use_learning: req.use_learning,
290 include_related: req.include_related,
291 diversity: req.diversity,
292 hybrid: req.hybrid,
293 keyword_weight: req.keyword_weight,
294 };
295
296 let kb = kb.read().await;
297 if let Ok(results) = kb.search(&req.query, options).await {
298 for result in results {
299 let proto_result = SearchResult {
300 entry: Some(KnowledgeServiceImpl::to_proto_entry(&result.entry)),
301 similarity: result.similarity,
302 relevance_boost: result.relevance_boost,
303 score: result.score,
304 distance: result.distance,
305 };
306
307 if tx.send(Ok(proto_result)).await.is_err() {
308 break;
309 }
310 }
311 }
312 });
313
314 Ok(Response::new(ReceiverStream::new(rx)))
315 }
316
317 async fn record_feedback(
318 &self,
319 request: Request<FeedbackRequest>,
320 ) -> std::result::Result<Response<FeedbackResponse>, Status> {
321 let req = request.into_inner();
322 let id = Uuid::parse_str(&req.entry_id)
323 .map_err(|e| Status::invalid_argument(format!("Invalid UUID: {}", e)))?;
324
325 let kb = self.kb.read().await;
326 match kb.record_feedback(id, req.positive).await {
327 Ok(_) => Ok(Response::new(FeedbackResponse { success: true })),
328 Err(_) => Ok(Response::new(FeedbackResponse { success: false })),
329 }
330 }
331
332 async fn get_related(
333 &self,
334 request: Request<GetRelatedRequest>,
335 ) -> std::result::Result<Response<GetRelatedResponse>, Status> {
336 let req = request.into_inner();
337 let id = Uuid::parse_str(&req.id)
338 .map_err(|e| Status::invalid_argument(format!("Invalid UUID: {}", e)))?;
339
340 let kb = self.kb.read().await;
341 let related = kb.get_related(id, req.limit as usize);
342
343 Ok(Response::new(GetRelatedResponse {
344 entries: related.iter().map(Self::to_proto_entry).collect(),
345 }))
346 }
347
348 async fn link_entries(
349 &self,
350 request: Request<LinkEntriesRequest>,
351 ) -> std::result::Result<Response<LinkEntriesResponse>, Status> {
352 let req = request.into_inner();
353 let id1 = Uuid::parse_str(&req.id1)
354 .map_err(|e| Status::invalid_argument(format!("Invalid UUID: {}", e)))?;
355 let id2 = Uuid::parse_str(&req.id2)
356 .map_err(|e| Status::invalid_argument(format!("Invalid UUID: {}", e)))?;
357
358 let kb = self.kb.read().await;
359 match kb.link_entries(id1, id2).await {
360 Ok(_) => Ok(Response::new(LinkEntriesResponse {
361 success: true,
362 error: None,
363 })),
364 Err(e) => Ok(Response::new(LinkEntriesResponse {
365 success: false,
366 error: Some(e.to_string()),
367 })),
368 }
369 }
370
371 async fn get_stats(
372 &self,
373 _request: Request<GetStatsRequest>,
374 ) -> std::result::Result<Response<GetStatsResponse>, Status> {
375 let kb = self.kb.read().await;
376 let stats = kb.stats();
377
378 Ok(Response::new(GetStatsResponse {
379 total_entries: stats.total_entries as u64,
380 unique_categories: stats.unique_categories as u64,
381 unique_tags: stats.unique_tags as u64,
382 total_access_count: stats.total_access_count,
383 dimensions: stats.dimensions as u32,
384 learning_enabled: stats.learning_enabled,
385 learning_stats: None, }))
387 }
388
389 async fn health(
390 &self,
391 _request: Request<HealthRequest>,
392 ) -> std::result::Result<Response<HealthResponse>, Status> {
393 Ok(Response::new(HealthResponse {
394 healthy: true,
395 version: env!("CARGO_PKG_VERSION").to_string(),
396 uptime_seconds: self.start_time.elapsed().as_secs(),
397 }))
398 }
399}