Skip to main content

brainwires_stores/memory/
fact_store.rs

//! Persistent storage for cold-tier key facts.
//!
//! Uses a [`StorageBackend`](brainwires_storage::StorageBackend) for persistence, with semantic search capability.
4
5use anyhow::{Context, Result};
6use std::sync::Arc;
7
8use brainwires_storage::CachedEmbeddingProvider;
9use brainwires_storage::databases::{
10    FieldDef, FieldType, FieldValue, Filter, Record, ScoredRecord, StorageBackend, record_get,
11};
12
13use super::tier_types::{FactType, KeyFact};
14
15const TABLE_NAME: &str = "facts";
16
17// ── Schema ──────────────────────────────────────────────────────────────
18
19/// Return the backend-agnostic field definitions for the facts table.
20pub fn facts_field_defs(embedding_dim: usize) -> Vec<FieldDef> {
21    vec![
22        FieldDef::required("fact_id", FieldType::Utf8),
23        FieldDef::required("original_message_ids", FieldType::Utf8), // JSON array
24        FieldDef::required("conversation_id", FieldType::Utf8),
25        FieldDef::required("fact", FieldType::Utf8),
26        FieldDef::required("fact_type", FieldType::Utf8),
27        FieldDef::required("vector", FieldType::Vector(embedding_dim)),
28        FieldDef::required("created_at", FieldType::Int64),
29    ]
30}
31
32/// Arrow schema for the facts table, used by `LanceDatabase` table creation.
33pub fn facts_schema(embedding_dim: usize) -> std::sync::Arc<arrow_schema::Schema> {
34    use arrow_schema::{DataType, Field};
35
36    std::sync::Arc::new(arrow_schema::Schema::new(vec![
37        Field::new(
38            "vector",
39            DataType::FixedSizeList(
40                std::sync::Arc::new(Field::new("item", DataType::Float32, true)),
41                embedding_dim as i32,
42            ),
43            false,
44        ),
45        Field::new("fact_id", DataType::Utf8, false),
46        Field::new("original_message_ids", DataType::Utf8, false), // JSON array
47        Field::new("conversation_id", DataType::Utf8, false),
48        Field::new("fact", DataType::Utf8, false),
49        Field::new("fact_type", DataType::Utf8, false),
50        Field::new("created_at", DataType::Int64, false),
51    ]))
52}
53
54// ── Record conversion helpers ───────────────────────────────────────────
55
56fn to_record(fact: &KeyFact, embedding: Vec<f32>) -> Record {
57    let original_message_ids_json =
58        serde_json::to_string(&fact.original_message_ids).unwrap_or_else(|_| "[]".to_string());
59
60    vec![
61        (
62            "fact_id".into(),
63            FieldValue::Utf8(Some(fact.fact_id.clone())),
64        ),
65        (
66            "original_message_ids".into(),
67            FieldValue::Utf8(Some(original_message_ids_json)),
68        ),
69        (
70            "conversation_id".into(),
71            FieldValue::Utf8(Some(fact.conversation_id.clone())),
72        ),
73        ("fact".into(), FieldValue::Utf8(Some(fact.fact.clone()))),
74        (
75            "fact_type".into(),
76            FieldValue::Utf8(Some(fact_type_to_string(fact.fact_type).to_string())),
77        ),
78        ("vector".into(), FieldValue::Vector(embedding)),
79        (
80            "created_at".into(),
81            FieldValue::Int64(Some(fact.created_at)),
82        ),
83    ]
84}
85
86fn from_record(r: &Record) -> Result<KeyFact> {
87    let original_message_ids: Vec<String> = record_get(r, "original_message_ids")
88        .and_then(|v| v.as_str())
89        .and_then(|json| serde_json::from_str(json).ok())
90        .unwrap_or_default();
91
92    let fact_type = record_get(r, "fact_type")
93        .and_then(|v| v.as_str())
94        .map(string_to_fact_type)
95        .unwrap_or(FactType::Other);
96
97    Ok(KeyFact {
98        fact_id: record_get(r, "fact_id")
99            .and_then(|v| v.as_str())
100            .context("missing fact_id")?
101            .to_string(),
102        original_message_ids,
103        conversation_id: record_get(r, "conversation_id")
104            .and_then(|v| v.as_str())
105            .context("missing conversation_id")?
106            .to_string(),
107        fact: record_get(r, "fact")
108            .and_then(|v| v.as_str())
109            .context("missing fact")?
110            .to_string(),
111        fact_type,
112        created_at: record_get(r, "created_at")
113            .and_then(|v| v.as_i64())
114            .context("missing created_at")?,
115    })
116}
117
118/// Convert fact type to string for storage
119fn fact_type_to_string(fact_type: FactType) -> &'static str {
120    match fact_type {
121        FactType::Decision => "decision",
122        FactType::Definition => "definition",
123        FactType::Requirement => "requirement",
124        FactType::CodeChange => "code_change",
125        FactType::Configuration => "configuration",
126        FactType::Other => "other",
127    }
128}
129
130/// Convert string to fact type
131fn string_to_fact_type(s: &str) -> FactType {
132    match s {
133        "decision" => FactType::Decision,
134        "definition" => FactType::Definition,
135        "requirement" => FactType::Requirement,
136        "code_change" => FactType::CodeChange,
137        "configuration" => FactType::Configuration,
138        _ => FactType::Other,
139    }
140}
141
142// ── FactStore ───────────────────────────────────────────────────────────
143
144/// Store for cold tier key facts with semantic search
145pub struct FactStore<B: StorageBackend = brainwires_storage::databases::lance::LanceDatabase> {
146    backend: Arc<B>,
147    embeddings: Arc<CachedEmbeddingProvider>,
148}
149
150impl<B: StorageBackend> FactStore<B> {
151    /// Create a new fact store
152    pub fn new(backend: Arc<B>, embeddings: Arc<CachedEmbeddingProvider>) -> Self {
153        Self {
154            backend,
155            embeddings,
156        }
157    }
158
159    /// Ensure the underlying table exists.
160    pub async fn ensure_table(&self) -> Result<()> {
161        let dim = self.embeddings.dimension();
162        self.backend
163            .ensure_table(TABLE_NAME, &facts_field_defs(dim))
164            .await
165    }
166
167    /// Add a fact to the store
168    pub async fn add(&self, fact: KeyFact) -> Result<()> {
169        let embedding = self.embeddings.embed(&fact.fact)?;
170        let record = to_record(&fact, embedding);
171
172        self.backend
173            .insert(TABLE_NAME, vec![record])
174            .await
175            .context("Failed to add fact")
176    }
177
178    /// Add multiple facts in batch
179    pub async fn add_batch(&self, facts: Vec<KeyFact>) -> Result<()> {
180        if facts.is_empty() {
181            return Ok(());
182        }
183
184        let contents: Vec<String> = facts.iter().map(|f| f.fact.clone()).collect();
185        let embeddings = self.embeddings.embed_batch(&contents)?;
186
187        let records: Vec<Record> = facts
188            .iter()
189            .zip(embeddings.into_iter())
190            .map(|(f, emb)| to_record(f, emb))
191            .collect();
192
193        self.backend
194            .insert(TABLE_NAME, records)
195            .await
196            .context("Failed to add facts")
197    }
198
199    /// Get a fact by ID
200    pub async fn get(&self, fact_id: &str) -> Result<Option<KeyFact>> {
201        let filter = Filter::Eq(
202            "fact_id".into(),
203            FieldValue::Utf8(Some(fact_id.to_string())),
204        );
205        let records = self
206            .backend
207            .query(TABLE_NAME, Some(&filter), Some(1))
208            .await?;
209
210        match records.first() {
211            Some(r) => Ok(Some(from_record(r)?)),
212            None => Ok(None),
213        }
214    }
215
216    /// Get all facts for a conversation
217    pub async fn get_by_conversation(&self, conversation_id: &str) -> Result<Vec<KeyFact>> {
218        let filter = Filter::Eq(
219            "conversation_id".into(),
220            FieldValue::Utf8(Some(conversation_id.to_string())),
221        );
222        let records = self.backend.query(TABLE_NAME, Some(&filter), None).await?;
223
224        records.iter().map(from_record).collect()
225    }
226
227    /// Search facts by semantic similarity
228    pub async fn search(
229        &self,
230        query: &str,
231        limit: usize,
232        min_score: f32,
233    ) -> Result<Vec<(KeyFact, f32)>> {
234        self.search_with_filter(query, limit, min_score, None).await
235    }
236
237    /// Search facts within a specific conversation
238    pub async fn search_conversation(
239        &self,
240        conversation_id: &str,
241        query: &str,
242        limit: usize,
243        min_score: f32,
244    ) -> Result<Vec<(KeyFact, f32)>> {
245        let filter = Filter::Eq(
246            "conversation_id".into(),
247            FieldValue::Utf8(Some(conversation_id.to_string())),
248        );
249        self.search_with_filter(query, limit, min_score, Some(filter))
250            .await
251    }
252
253    /// Search facts with optional filter
254    async fn search_with_filter(
255        &self,
256        query: &str,
257        limit: usize,
258        min_score: f32,
259        filter: Option<Filter>,
260    ) -> Result<Vec<(KeyFact, f32)>> {
261        let query_embedding = self.embeddings.embed_cached(query)?;
262
263        let scored = self
264            .backend
265            .vector_search(
266                TABLE_NAME,
267                "vector",
268                query_embedding,
269                limit,
270                filter.as_ref(),
271            )
272            .await?;
273
274        scored_records_to_facts(&scored, min_score)
275    }
276
277    /// Delete a fact by ID
278    pub async fn delete(&self, fact_id: &str) -> Result<()> {
279        let filter = Filter::Eq(
280            "fact_id".into(),
281            FieldValue::Utf8(Some(fact_id.to_string())),
282        );
283        self.backend
284            .delete(TABLE_NAME, &filter)
285            .await
286            .context("Failed to delete fact")
287    }
288
289    /// Get count of facts
290    pub async fn count(&self) -> Result<usize> {
291        self.backend.count(TABLE_NAME, None).await
292    }
293}
294
295// ── Helpers ─────────────────────────────────────────────────────────────
296
297fn scored_records_to_facts(scored: &[ScoredRecord], min_score: f32) -> Result<Vec<(KeyFact, f32)>> {
298    let mut results = Vec::new();
299    for sr in scored {
300        if sr.score >= min_score {
301            let fact = from_record(&sr.record)?;
302            results.push((fact, sr.score));
303        }
304    }
305    Ok(results)
306}