Skip to main content

heartbit_core/knowledge/
in_memory.rs

//! In-memory knowledge base with keyword match-count scoring.
2
3use std::collections::HashMap;
4use std::future::Future;
5use std::pin::Pin;
6
7use tokio::sync::RwLock;
8
9use crate::auth::TenantScope;
10use crate::error::Error;
11
12use super::{Chunk, KnowledgeBase, KnowledgeQuery, SearchResult};
13
/// In-memory knowledge base backed by a `tokio::sync::RwLock<HashMap>`.
///
/// Search is keyword-based: the query is tokenized into lowercase words, each
/// chunk is scored by how many distinct query words occur in its content, and
/// results are sorted by that count in descending order.
///
/// Always used behind `Arc<dyn KnowledgeBase>`, so no inner `Arc` needed.
pub struct InMemoryKnowledgeBase {
    // Every indexed chunk, for all tenants, behind an async read/write lock.
    // Tenant separation is applied by the trait methods when the map is read.
    chunks: RwLock<HashMap<String, Chunk>>,
}
23
24impl InMemoryKnowledgeBase {
25    /// Create an empty in-memory knowledge base.
26    pub fn new() -> Self {
27        Self {
28            chunks: RwLock::new(HashMap::new()),
29        }
30    }
31}
32
33impl Default for InMemoryKnowledgeBase {
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
/// Break `text` into unique lowercase words for keyword matching.
///
/// The text is split on whitespace. Each piece is lowercased and then has
/// leading and trailing non-alphanumeric characters removed; pieces that end
/// up empty are dropped. A word is returned only the first time it occurs, so
/// the output keeps first-occurrence order.
fn tokenize(text: &str) -> Vec<String> {
    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
    let mut tokens = Vec::new();

    for piece in text.split_whitespace() {
        // Lowercase first, then trim: the order is kept as-is because
        // lowercasing can change which characters sit at the edges.
        let lowered = piece.to_lowercase();
        let word = lowered.trim_matches(|c: char| !c.is_alphanumeric());

        if !word.is_empty() && !seen.contains(word) {
            seen.insert(word.to_owned());
            tokens.push(word.to_owned());
        }
    }

    tokens
}
51
/// Return the number of `query_tokens` that occur in `content`.
///
/// `content` is lowercased once and each token is looked up as a substring of
/// it, so a token also counts when it appears inside a longer word. Every
/// token contributes at most one to the total, however often it occurs.
fn count_matches(query_tokens: &[String], content: &str) -> usize {
    let haystack = content.to_lowercase();
    let mut hits = 0;
    for token in query_tokens {
        if haystack.contains(token.as_str()) {
            hits += 1;
        }
    }
    hits
}
60
61impl KnowledgeBase for InMemoryKnowledgeBase {
62    fn index(
63        &self,
64        scope: &TenantScope,
65        mut chunk: Chunk,
66    ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
67        // SECURITY (F-KB-1): stamp the chunk with the caller's tenant_id at
68        // index time. The argument is &str → owned String for the async block.
69        let tid = scope.tenant_id.clone();
70        Box::pin(async move {
71            chunk.tenant_id = if tid.is_empty() { None } else { Some(tid) };
72            let mut data = self.chunks.write().await;
73            data.insert(chunk.id.clone(), chunk);
74            Ok(())
75        })
76    }
77
78    fn search(
79        &self,
80        scope: &TenantScope,
81        query: KnowledgeQuery,
82    ) -> Pin<Box<dyn Future<Output = Result<Vec<SearchResult>, Error>> + Send + '_>> {
83        let tid = scope.tenant_id.clone();
84        Box::pin(async move {
85            let data = self.chunks.read().await;
86            let tokens = tokenize(&query.text);
87
88            if tokens.is_empty() {
89                return Ok(vec![]);
90            }
91
92            // Tenant filter: keep chunks whose tenant_id matches scope.
93            // Single-tenant scope (`""`) matches `None` and `""`.
94            let tenant_match = move |chunk: &Chunk| -> bool {
95                let chunk_tid = chunk.tenant_id.as_deref().unwrap_or("");
96                chunk_tid == tid.as_str()
97            };
98
99            let mut results: Vec<SearchResult> = data
100                .values()
101                .filter(|chunk| tenant_match(chunk))
102                .filter(|chunk| {
103                    if let Some(ref filter) = query.source_filter {
104                        chunk.source.uri.starts_with(filter)
105                    } else {
106                        true
107                    }
108                })
109                .filter_map(|chunk| {
110                    let matches = count_matches(&tokens, &chunk.content);
111                    if matches > 0 {
112                        Some(SearchResult {
113                            chunk: chunk.clone(),
114                            match_count: matches,
115                        })
116                    } else {
117                        None
118                    }
119                })
120                .collect();
121
122            // Sort by match count descending, then chunk_index, then source URI for full stability
123            results.sort_by(|a, b| {
124                b.match_count
125                    .cmp(&a.match_count)
126                    .then_with(|| a.chunk.chunk_index.cmp(&b.chunk.chunk_index))
127                    .then_with(|| a.chunk.source.uri.cmp(&b.chunk.source.uri))
128            });
129
130            if query.limit > 0 {
131                results.truncate(query.limit);
132            }
133
134            Ok(results)
135        })
136    }
137
138    fn chunk_count(
139        &self,
140        scope: &TenantScope,
141    ) -> Pin<Box<dyn Future<Output = Result<usize, Error>> + Send + '_>> {
142        let tid = scope.tenant_id.clone();
143        Box::pin(async move {
144            let data = self.chunks.read().await;
145            let count = data
146                .values()
147                .filter(|c| c.tenant_id.as_deref().unwrap_or("") == tid.as_str())
148                .count();
149            Ok(count)
150        })
151    }
152}
153
// Unit tests for `InMemoryKnowledgeBase`: indexing, keyword search, result
// ordering, limits, tenant isolation and concurrent access.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::knowledge::DocumentSource;
    use std::sync::Arc;

    /// Build a chunk without a tenant id; `uri` is used both as the source
    /// URI and as the source title.
    fn make_chunk(id: &str, content: &str, uri: &str, index: usize) -> Chunk {
        Chunk {
            id: id.into(),
            content: content.into(),
            source: DocumentSource {
                uri: uri.into(),
                title: uri.into(),
            },
            chunk_index: index,
            tenant_id: None,
        }
    }

    /// The default scope, used by every test that does not care about
    /// tenants (presumably the single-tenant scope with an empty tenant id).
    fn s() -> TenantScope {
        TenantScope::default()
    }

    /// An indexed chunk is found again and reports one match per query word
    /// that occurs in it.
    #[tokio::test]
    async fn index_and_search_roundtrip() {
        let kb = InMemoryKnowledgeBase::new();
        kb.index(
            &s(),
            make_chunk(
                "c1",
                "Rust is a systems programming language",
                "docs/rust.md",
                0,
            ),
        )
        .await
        .unwrap();

        let results = kb
            .search(
                &s(),
                KnowledgeQuery {
                    text: "rust programming".into(),
                    source_filter: None,
                    limit: 5,
                },
            )
            .await
            .unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].chunk.id, "c1");
        assert_eq!(results[0].match_count, 2); // "rust" + "programming"
    }

    /// Lowercase query words match uppercase chunk content.
    #[tokio::test]
    async fn search_is_case_insensitive() {
        let kb = InMemoryKnowledgeBase::new();
        kb.index(&s(), make_chunk("c1", "RUST is GREAT", "f.md", 0))
            .await
            .unwrap();

        let results = kb
            .search(
                &s(),
                KnowledgeQuery {
                    text: "rust great".into(),
                    source_filter: None,
                    limit: 5,
                },
            )
            .await
            .unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].match_count, 2);
    }

    /// `source_filter` keeps only chunks whose source URI starts with it.
    #[tokio::test]
    async fn source_filter_restricts_results() {
        let kb = InMemoryKnowledgeBase::new();
        kb.index(&s(), make_chunk("c1", "Rust language", "docs/rust.md", 0))
            .await
            .unwrap();
        kb.index(&s(), make_chunk("c2", "Rust compiler", "api/rust.md", 0))
            .await
            .unwrap();

        let results = kb
            .search(
                &s(),
                KnowledgeQuery {
                    text: "rust".into(),
                    source_filter: Some("docs/".into()),
                    limit: 10,
                },
            )
            .await
            .unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].chunk.source.uri, "docs/rust.md");
    }

    /// With ten matching chunks and `limit: 3`, exactly three are returned.
    #[tokio::test]
    async fn limit_truncates_results() {
        let kb = InMemoryKnowledgeBase::new();
        for i in 0..10 {
            kb.index(
                &s(),
                make_chunk(
                    &format!("c{i}"),
                    "rust programming language",
                    "docs/rust.md",
                    i,
                ),
            )
            .await
            .unwrap();
        }

        let results = kb
            .search(
                &s(),
                KnowledgeQuery {
                    text: "rust".into(),
                    source_filter: None,
                    limit: 3,
                },
            )
            .await
            .unwrap();

        assert_eq!(results.len(), 3);
    }

    /// Chunks matching more query words come before chunks matching fewer.
    #[tokio::test]
    async fn sorted_by_match_count_descending() {
        let kb = InMemoryKnowledgeBase::new();
        kb.index(&s(), make_chunk("c1", "rust", "f.md", 0))
            .await
            .unwrap();
        kb.index(
            &s(),
            make_chunk("c2", "rust programming rust systems", "f.md", 1),
        )
        .await
        .unwrap();
        kb.index(&s(), make_chunk("c3", "rust programming", "f.md", 2))
            .await
            .unwrap();

        let results = kb
            .search(
                &s(),
                KnowledgeQuery {
                    text: "rust programming systems".into(),
                    source_filter: None,
                    limit: 10,
                },
            )
            .await
            .unwrap();

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].chunk.id, "c2"); // 3 matches
        assert_eq!(results[1].chunk.id, "c3"); // 2 matches
        assert_eq!(results[2].chunk.id, "c1"); // 1 match
    }

    /// Indexing a second chunk under an existing id replaces the first one
    /// instead of adding another entry.
    #[tokio::test]
    async fn reindex_replaces_chunk() {
        let kb = InMemoryKnowledgeBase::new();
        kb.index(&s(), make_chunk("c1", "old content", "f.md", 0))
            .await
            .unwrap();
        kb.index(&s(), make_chunk("c1", "new content about rust", "f.md", 0))
            .await
            .unwrap();

        assert_eq!(kb.chunk_count(&s()).await.unwrap(), 1);

        let results = kb
            .search(
                &s(),
                KnowledgeQuery {
                    text: "rust".into(),
                    source_filter: None,
                    limit: 5,
                },
            )
            .await
            .unwrap();

        assert_eq!(results.len(), 1);
        assert!(results[0].chunk.content.contains("new content"));
    }

    /// An empty query text yields no results even when chunks are indexed.
    #[tokio::test]
    async fn empty_query_returns_no_results() {
        let kb = InMemoryKnowledgeBase::new();
        kb.index(&s(), make_chunk("c1", "some content", "f.md", 0))
            .await
            .unwrap();

        let results = kb
            .search(
                &s(),
                KnowledgeQuery {
                    text: "".into(),
                    source_filter: None,
                    limit: 5,
                },
            )
            .await
            .unwrap();

        assert!(results.is_empty());
    }

    /// A query word that occurs in no chunk yields an empty result.
    #[tokio::test]
    async fn no_match_returns_empty() {
        let kb = InMemoryKnowledgeBase::new();
        kb.index(&s(), make_chunk("c1", "hello world", "f.md", 0))
            .await
            .unwrap();

        let results = kb
            .search(
                &s(),
                KnowledgeQuery {
                    text: "zzzznotfound".into(),
                    source_filter: None,
                    limit: 5,
                },
            )
            .await
            .unwrap();

        assert!(results.is_empty());
    }

    /// `chunk_count` is zero for a new knowledge base and grows with each
    /// chunk indexed under a new id.
    #[tokio::test]
    async fn chunk_count_tracks_size() {
        let kb = InMemoryKnowledgeBase::new();
        assert_eq!(kb.chunk_count(&s()).await.unwrap(), 0);

        kb.index(&s(), make_chunk("c1", "a", "f.md", 0))
            .await
            .unwrap();
        kb.index(&s(), make_chunk("c2", "b", "f.md", 1))
            .await
            .unwrap();
        assert_eq!(kb.chunk_count(&s()).await.unwrap(), 2);
    }

    /// Compile-time checks: the type is `Send + Sync`, and `KnowledgeBase`
    /// can be used as a trait object.
    #[test]
    fn is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<InMemoryKnowledgeBase>();
        fn _accepts_dyn(_kb: &dyn KnowledgeBase) {}
    }

    /// Twenty writer tasks and ten reader tasks share one knowledge base on a
    /// multi-threaded runtime; afterwards all twenty chunks are present.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn concurrent_index_and_search() {
        let kb = Arc::new(InMemoryKnowledgeBase::new());
        let mut handles = Vec::new();

        // Spawn writers
        for i in 0..20 {
            let kb = kb.clone();
            handles.push(tokio::spawn(async move {
                kb.index(
                    &s(),
                    make_chunk(
                        &format!("c{i}"),
                        &format!("rust content item {i}"),
                        "f.md",
                        i,
                    ),
                )
                .await
                .unwrap();
            }));
        }

        // Spawn readers concurrently
        for _ in 0..10 {
            let kb = kb.clone();
            handles.push(tokio::spawn(async move {
                let _ = kb
                    .search(
                        &s(),
                        KnowledgeQuery {
                            text: "rust".into(),
                            source_filter: None,
                            limit: 5,
                        },
                    )
                    .await
                    .unwrap();
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        assert_eq!(kb.chunk_count(&s()).await.unwrap(), 20);
    }

    /// Repeating a word in the query does not raise the match count.
    #[tokio::test]
    async fn duplicate_query_terms_not_inflated() {
        let kb = InMemoryKnowledgeBase::new();
        kb.index(&s(), make_chunk("c1", "rust is great", "f.md", 0))
            .await
            .unwrap();

        let results = kb
            .search(
                &s(),
                KnowledgeQuery {
                    text: "rust rust rust".into(),
                    source_filter: None,
                    limit: 5,
                },
            )
            .await
            .unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].match_count, 1); // deduplicated, not 3
    }

    /// SECURITY (F-KB-1): tenant A's chunks must not be visible to tenant B
    /// when both share an `Arc<dyn KnowledgeBase>`. The trait now requires a
    /// `&TenantScope` for both index and search; without filtering, a daemon
    /// shared across tenants would leak documents.
    #[tokio::test]
    async fn search_isolates_by_tenant() {
        let kb = InMemoryKnowledgeBase::new();
        let scope_a = TenantScope::new("tenant-a");
        let scope_b = TenantScope::new("tenant-b");

        kb.index(
            &scope_a,
            make_chunk("a1", "alice secret rust note", "a/notes.md", 0),
        )
        .await
        .unwrap();
        kb.index(
            &scope_b,
            make_chunk("b1", "bob secret rust note", "b/notes.md", 0),
        )
        .await
        .unwrap();

        // Tenant A search must NOT return B's chunk.
        let results_a = kb
            .search(
                &scope_a,
                KnowledgeQuery {
                    text: "rust".into(),
                    source_filter: None,
                    limit: 10,
                },
            )
            .await
            .unwrap();
        assert_eq!(results_a.len(), 1);
        assert_eq!(results_a[0].chunk.id, "a1");

        // Tenant B sees only B's chunk.
        let results_b = kb
            .search(
                &scope_b,
                KnowledgeQuery {
                    text: "rust".into(),
                    source_filter: None,
                    limit: 10,
                },
            )
            .await
            .unwrap();
        assert_eq!(results_b.len(), 1);
        assert_eq!(results_b[0].chunk.id, "b1");

        // chunk_count is also tenant-scoped.
        assert_eq!(kb.chunk_count(&scope_a).await.unwrap(), 1);
        assert_eq!(kb.chunk_count(&scope_b).await.unwrap(), 1);
    }

    /// Chunks that tie on match count and chunk index are ordered by source
    /// URI, regardless of the order in which they were indexed.
    #[tokio::test]
    async fn sort_stable_across_sources() {
        let kb = InMemoryKnowledgeBase::new();
        kb.index(&s(), make_chunk("c1", "rust programming", "z_file.md", 0))
            .await
            .unwrap();
        kb.index(&s(), make_chunk("c2", "rust programming", "a_file.md", 0))
            .await
            .unwrap();

        let results = kb
            .search(
                &s(),
                KnowledgeQuery {
                    text: "rust".into(),
                    source_filter: None,
                    limit: 10,
                },
            )
            .await
            .unwrap();

        assert_eq!(results.len(), 2);
        // Same match_count and chunk_index → sorted by source URI
        assert_eq!(results[0].chunk.source.uri, "a_file.md");
        assert_eq!(results[1].chunk.source.uri, "z_file.md");
    }
}
573}