1use std::collections::HashMap;
4use std::future::Future;
5use std::pin::Pin;
6
7use tokio::sync::RwLock;
8
9use crate::auth::TenantScope;
10use crate::error::Error;
11
12use super::{Chunk, KnowledgeBase, KnowledgeQuery, SearchResult};
13
/// A [`KnowledgeBase`] implementation that keeps every indexed chunk in a
/// process-local map guarded by an async read/write lock.
pub struct InMemoryKnowledgeBase {
    /// All indexed chunks, for every tenant. Each chunk records its owning tenant in
    /// its own `tenant_id` field, which is what lookups filter on.
    chunks: RwLock<HashMap<String, Chunk>>,
}
23
24impl InMemoryKnowledgeBase {
25 pub fn new() -> Self {
27 Self {
28 chunks: RwLock::new(HashMap::new()),
29 }
30 }
31}
32
33impl Default for InMemoryKnowledgeBase {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
/// Splits `text` into unique, lowercase search tokens.
///
/// Each whitespace-separated word is lowercased and then stripped of leading and
/// trailing non-alphanumeric characters (interior punctuation is kept). Words that
/// end up empty are dropped, and only the first occurrence of each token is
/// returned, in order of first appearance.
fn tokenize(text: &str) -> Vec<String> {
    let mut emitted = std::collections::HashSet::new();
    let mut tokens = Vec::new();

    for raw in text.split_whitespace() {
        // Lowercase before trimming: the trim predicate is applied to the lowercased form.
        let lowered = raw.to_lowercase();
        let word = lowered.trim_matches(|c: char| !c.is_alphanumeric());
        if word.is_empty() {
            continue;
        }
        if emitted.insert(word.to_string()) {
            tokens.push(word.to_string());
        }
    }

    tokens
}
51
/// Counts how many of `query_tokens` occur in `content`.
///
/// `content` is lowercased once and each token is looked up as a plain substring,
/// so a token also matches inside a longer word. Every token contributes at most 1
/// to the result, however often it occurs.
fn count_matches(query_tokens: &[String], content: &str) -> usize {
    let haystack = content.to_lowercase();
    let mut hits = 0;
    for token in query_tokens {
        if haystack.contains(token.as_str()) {
            hits += 1;
        }
    }
    hits
}
60
61impl KnowledgeBase for InMemoryKnowledgeBase {
62 fn index(
63 &self,
64 scope: &TenantScope,
65 mut chunk: Chunk,
66 ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
67 let tid = scope.tenant_id.clone();
70 Box::pin(async move {
71 chunk.tenant_id = if tid.is_empty() { None } else { Some(tid) };
72 let mut data = self.chunks.write().await;
73 data.insert(chunk.id.clone(), chunk);
74 Ok(())
75 })
76 }
77
78 fn search(
79 &self,
80 scope: &TenantScope,
81 query: KnowledgeQuery,
82 ) -> Pin<Box<dyn Future<Output = Result<Vec<SearchResult>, Error>> + Send + '_>> {
83 let tid = scope.tenant_id.clone();
84 Box::pin(async move {
85 let data = self.chunks.read().await;
86 let tokens = tokenize(&query.text);
87
88 if tokens.is_empty() {
89 return Ok(vec![]);
90 }
91
92 let tenant_match = move |chunk: &Chunk| -> bool {
95 let chunk_tid = chunk.tenant_id.as_deref().unwrap_or("");
96 chunk_tid == tid.as_str()
97 };
98
99 let mut results: Vec<SearchResult> = data
100 .values()
101 .filter(|chunk| tenant_match(chunk))
102 .filter(|chunk| {
103 if let Some(ref filter) = query.source_filter {
104 chunk.source.uri.starts_with(filter)
105 } else {
106 true
107 }
108 })
109 .filter_map(|chunk| {
110 let matches = count_matches(&tokens, &chunk.content);
111 if matches > 0 {
112 Some(SearchResult {
113 chunk: chunk.clone(),
114 match_count: matches,
115 })
116 } else {
117 None
118 }
119 })
120 .collect();
121
122 results.sort_by(|a, b| {
124 b.match_count
125 .cmp(&a.match_count)
126 .then_with(|| a.chunk.chunk_index.cmp(&b.chunk.chunk_index))
127 .then_with(|| a.chunk.source.uri.cmp(&b.chunk.source.uri))
128 });
129
130 if query.limit > 0 {
131 results.truncate(query.limit);
132 }
133
134 Ok(results)
135 })
136 }
137
138 fn chunk_count(
139 &self,
140 scope: &TenantScope,
141 ) -> Pin<Box<dyn Future<Output = Result<usize, Error>> + Send + '_>> {
142 let tid = scope.tenant_id.clone();
143 Box::pin(async move {
144 let data = self.chunks.read().await;
145 let count = data
146 .values()
147 .filter(|c| c.tenant_id.as_deref().unwrap_or("") == tid.as_str())
148 .count();
149 Ok(count)
150 })
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::knowledge::DocumentSource;
    use std::sync::Arc;

    /// Builds a chunk with no tenant assigned; the source title mirrors the URI.
    fn make_chunk(id: &str, content: &str, uri: &str, index: usize) -> Chunk {
        let source = DocumentSource {
            uri: uri.into(),
            title: uri.into(),
        };
        Chunk {
            id: id.into(),
            content: content.into(),
            source,
            chunk_index: index,
            tenant_id: None,
        }
    }

    /// The `TenantScope::default()` scope shared by the single-tenant tests.
    fn default_scope() -> TenantScope {
        TenantScope::default()
    }

    /// Builds a query from its three parts, sparing each test the struct literal.
    fn query(text: &str, source_filter: Option<&str>, limit: usize) -> KnowledgeQuery {
        KnowledgeQuery {
            text: text.into(),
            source_filter: source_filter.map(Into::into),
            limit,
        }
    }

    #[tokio::test]
    async fn index_and_search_roundtrip() {
        let kb = InMemoryKnowledgeBase::new();
        let chunk = make_chunk(
            "c1",
            "Rust is a systems programming language",
            "docs/rust.md",
            0,
        );
        kb.index(&default_scope(), chunk).await.unwrap();

        let hits = kb
            .search(&default_scope(), query("rust programming", None, 5))
            .await
            .unwrap();

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].chunk.id, "c1");
        // Both distinct query tokens occur in the content.
        assert_eq!(hits[0].match_count, 2);
    }

    #[tokio::test]
    async fn search_is_case_insensitive() {
        let kb = InMemoryKnowledgeBase::new();
        let chunk = make_chunk("c1", "RUST is GREAT", "f.md", 0);
        kb.index(&default_scope(), chunk).await.unwrap();

        let hits = kb
            .search(&default_scope(), query("rust great", None, 5))
            .await
            .unwrap();

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].match_count, 2);
    }

    #[tokio::test]
    async fn source_filter_restricts_results() {
        let kb = InMemoryKnowledgeBase::new();
        let docs_chunk = make_chunk("c1", "Rust language", "docs/rust.md", 0);
        let api_chunk = make_chunk("c2", "Rust compiler", "api/rust.md", 0);
        kb.index(&default_scope(), docs_chunk).await.unwrap();
        kb.index(&default_scope(), api_chunk).await.unwrap();

        let hits = kb
            .search(&default_scope(), query("rust", Some("docs/"), 10))
            .await
            .unwrap();

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].chunk.source.uri, "docs/rust.md");
    }

    #[tokio::test]
    async fn limit_truncates_results() {
        let kb = InMemoryKnowledgeBase::new();
        for i in 0..10 {
            let chunk = make_chunk(
                &format!("c{i}"),
                "rust programming language",
                "docs/rust.md",
                i,
            );
            kb.index(&default_scope(), chunk).await.unwrap();
        }

        let hits = kb
            .search(&default_scope(), query("rust", None, 3))
            .await
            .unwrap();

        assert_eq!(hits.len(), 3);
    }

    #[tokio::test]
    async fn sorted_by_match_count_descending() {
        let kb = InMemoryKnowledgeBase::new();
        let fixtures = [
            ("c1", "rust", 0),
            ("c2", "rust programming rust systems", 1),
            ("c3", "rust programming", 2),
        ];
        for (id, content, index) in fixtures {
            kb.index(&default_scope(), make_chunk(id, content, "f.md", index))
                .await
                .unwrap();
        }

        let hits = kb
            .search(
                &default_scope(),
                query("rust programming systems", None, 10),
            )
            .await
            .unwrap();

        assert_eq!(hits.len(), 3);
        // Three, two and one distinct tokens matched, respectively.
        assert_eq!(hits[0].chunk.id, "c2");
        assert_eq!(hits[1].chunk.id, "c3");
        assert_eq!(hits[2].chunk.id, "c1");
    }

    #[tokio::test]
    async fn reindex_replaces_chunk() {
        let kb = InMemoryKnowledgeBase::new();
        let first = make_chunk("c1", "old content", "f.md", 0);
        let second = make_chunk("c1", "new content about rust", "f.md", 0);
        kb.index(&default_scope(), first).await.unwrap();
        kb.index(&default_scope(), second).await.unwrap();

        assert_eq!(kb.chunk_count(&default_scope()).await.unwrap(), 1);

        let hits = kb
            .search(&default_scope(), query("rust", None, 5))
            .await
            .unwrap();

        assert_eq!(hits.len(), 1);
        assert!(hits[0].chunk.content.contains("new content"));
    }

    #[tokio::test]
    async fn empty_query_returns_no_results() {
        let kb = InMemoryKnowledgeBase::new();
        let chunk = make_chunk("c1", "some content", "f.md", 0);
        kb.index(&default_scope(), chunk).await.unwrap();

        let hits = kb
            .search(&default_scope(), query("", None, 5))
            .await
            .unwrap();

        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn no_match_returns_empty() {
        let kb = InMemoryKnowledgeBase::new();
        let chunk = make_chunk("c1", "hello world", "f.md", 0);
        kb.index(&default_scope(), chunk).await.unwrap();

        let hits = kb
            .search(&default_scope(), query("zzzznotfound", None, 5))
            .await
            .unwrap();

        assert!(hits.is_empty());
    }

    #[tokio::test]
    async fn chunk_count_tracks_size() {
        let kb = InMemoryKnowledgeBase::new();
        assert_eq!(kb.chunk_count(&default_scope()).await.unwrap(), 0);

        let first = make_chunk("c1", "a", "f.md", 0);
        let second = make_chunk("c2", "b", "f.md", 1);
        kb.index(&default_scope(), first).await.unwrap();
        kb.index(&default_scope(), second).await.unwrap();

        assert_eq!(kb.chunk_count(&default_scope()).await.unwrap(), 2);
    }

    #[test]
    fn is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<InMemoryKnowledgeBase>();
        // Compiles only while the trait can be used as a trait object.
        fn _accepts_dyn(_kb: &dyn KnowledgeBase) {}
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn concurrent_index_and_search() {
        let kb = Arc::new(InMemoryKnowledgeBase::new());
        let mut tasks = Vec::new();

        // Twenty writers, each indexing its own chunk.
        for i in 0..20 {
            let shared = kb.clone();
            tasks.push(tokio::spawn(async move {
                let chunk = make_chunk(
                    &format!("c{i}"),
                    &format!("rust content item {i}"),
                    "f.md",
                    i,
                );
                shared.index(&default_scope(), chunk).await.unwrap();
            }));
        }

        // Ten readers searching while the writers run.
        for _ in 0..10 {
            let shared = kb.clone();
            tasks.push(tokio::spawn(async move {
                let _ = shared
                    .search(&default_scope(), query("rust", None, 5))
                    .await
                    .unwrap();
            }));
        }

        for task in tasks {
            task.await.unwrap();
        }

        assert_eq!(kb.chunk_count(&default_scope()).await.unwrap(), 20);
    }

    #[tokio::test]
    async fn duplicate_query_terms_not_inflated() {
        let kb = InMemoryKnowledgeBase::new();
        let chunk = make_chunk("c1", "rust is great", "f.md", 0);
        kb.index(&default_scope(), chunk).await.unwrap();

        let hits = kb
            .search(&default_scope(), query("rust rust rust", None, 5))
            .await
            .unwrap();

        assert_eq!(hits.len(), 1);
        // The repeated term counts once.
        assert_eq!(hits[0].match_count, 1);
    }

    #[tokio::test]
    async fn search_isolates_by_tenant() {
        let kb = InMemoryKnowledgeBase::new();
        let scope_a = TenantScope::new("tenant-a");
        let scope_b = TenantScope::new("tenant-b");

        let chunk_a = make_chunk("a1", "alice secret rust note", "a/notes.md", 0);
        let chunk_b = make_chunk("b1", "bob secret rust note", "b/notes.md", 0);
        kb.index(&scope_a, chunk_a).await.unwrap();
        kb.index(&scope_b, chunk_b).await.unwrap();

        // Each tenant sees only its own chunk for the same query.
        let hits_a = kb
            .search(&scope_a, query("rust", None, 10))
            .await
            .unwrap();
        assert_eq!(hits_a.len(), 1);
        assert_eq!(hits_a[0].chunk.id, "a1");

        let hits_b = kb
            .search(&scope_b, query("rust", None, 10))
            .await
            .unwrap();
        assert_eq!(hits_b.len(), 1);
        assert_eq!(hits_b[0].chunk.id, "b1");

        // Counts are per tenant as well.
        assert_eq!(kb.chunk_count(&scope_a).await.unwrap(), 1);
        assert_eq!(kb.chunk_count(&scope_b).await.unwrap(), 1);
    }

    #[tokio::test]
    async fn sort_stable_across_sources() {
        let kb = InMemoryKnowledgeBase::new();
        let late = make_chunk("c1", "rust programming", "z_file.md", 0);
        let early = make_chunk("c2", "rust programming", "a_file.md", 0);
        kb.index(&default_scope(), late).await.unwrap();
        kb.index(&default_scope(), early).await.unwrap();

        let hits = kb
            .search(&default_scope(), query("rust", None, 10))
            .await
            .unwrap();

        assert_eq!(hits.len(), 2);
        // Same match count and chunk index, so the source URI decides the order.
        assert_eq!(hits[0].chunk.source.uri, "a_file.md");
        assert_eq!(hits[1].chunk.source.uri, "z_file.md");
    }
}