brainwires_stores/memory/
fact_store.rs1use anyhow::{Context, Result};
6use std::sync::Arc;
7
8use brainwires_storage::CachedEmbeddingProvider;
9use brainwires_storage::databases::{
10 FieldDef, FieldType, FieldValue, Filter, Record, ScoredRecord, StorageBackend, record_get,
11};
12
13use super::tier_types::{FactType, KeyFact};
14
15const TABLE_NAME: &str = "facts";
16
17pub fn facts_field_defs(embedding_dim: usize) -> Vec<FieldDef> {
21 vec![
22 FieldDef::required("fact_id", FieldType::Utf8),
23 FieldDef::required("original_message_ids", FieldType::Utf8), FieldDef::required("conversation_id", FieldType::Utf8),
25 FieldDef::required("fact", FieldType::Utf8),
26 FieldDef::required("fact_type", FieldType::Utf8),
27 FieldDef::required("vector", FieldType::Vector(embedding_dim)),
28 FieldDef::required("created_at", FieldType::Int64),
29 ]
30}
31
32pub fn facts_schema(embedding_dim: usize) -> std::sync::Arc<arrow_schema::Schema> {
34 use arrow_schema::{DataType, Field};
35
36 std::sync::Arc::new(arrow_schema::Schema::new(vec![
37 Field::new(
38 "vector",
39 DataType::FixedSizeList(
40 std::sync::Arc::new(Field::new("item", DataType::Float32, true)),
41 embedding_dim as i32,
42 ),
43 false,
44 ),
45 Field::new("fact_id", DataType::Utf8, false),
46 Field::new("original_message_ids", DataType::Utf8, false), Field::new("conversation_id", DataType::Utf8, false),
48 Field::new("fact", DataType::Utf8, false),
49 Field::new("fact_type", DataType::Utf8, false),
50 Field::new("created_at", DataType::Int64, false),
51 ]))
52}
53
54fn to_record(fact: &KeyFact, embedding: Vec<f32>) -> Record {
57 let original_message_ids_json =
58 serde_json::to_string(&fact.original_message_ids).unwrap_or_else(|_| "[]".to_string());
59
60 vec![
61 (
62 "fact_id".into(),
63 FieldValue::Utf8(Some(fact.fact_id.clone())),
64 ),
65 (
66 "original_message_ids".into(),
67 FieldValue::Utf8(Some(original_message_ids_json)),
68 ),
69 (
70 "conversation_id".into(),
71 FieldValue::Utf8(Some(fact.conversation_id.clone())),
72 ),
73 ("fact".into(), FieldValue::Utf8(Some(fact.fact.clone()))),
74 (
75 "fact_type".into(),
76 FieldValue::Utf8(Some(fact_type_to_string(fact.fact_type).to_string())),
77 ),
78 ("vector".into(), FieldValue::Vector(embedding)),
79 (
80 "created_at".into(),
81 FieldValue::Int64(Some(fact.created_at)),
82 ),
83 ]
84}
85
86fn from_record(r: &Record) -> Result<KeyFact> {
87 let original_message_ids: Vec<String> = record_get(r, "original_message_ids")
88 .and_then(|v| v.as_str())
89 .and_then(|json| serde_json::from_str(json).ok())
90 .unwrap_or_default();
91
92 let fact_type = record_get(r, "fact_type")
93 .and_then(|v| v.as_str())
94 .map(string_to_fact_type)
95 .unwrap_or(FactType::Other);
96
97 Ok(KeyFact {
98 fact_id: record_get(r, "fact_id")
99 .and_then(|v| v.as_str())
100 .context("missing fact_id")?
101 .to_string(),
102 original_message_ids,
103 conversation_id: record_get(r, "conversation_id")
104 .and_then(|v| v.as_str())
105 .context("missing conversation_id")?
106 .to_string(),
107 fact: record_get(r, "fact")
108 .and_then(|v| v.as_str())
109 .context("missing fact")?
110 .to_string(),
111 fact_type,
112 created_at: record_get(r, "created_at")
113 .and_then(|v| v.as_i64())
114 .context("missing created_at")?,
115 })
116}
117
118fn fact_type_to_string(fact_type: FactType) -> &'static str {
120 match fact_type {
121 FactType::Decision => "decision",
122 FactType::Definition => "definition",
123 FactType::Requirement => "requirement",
124 FactType::CodeChange => "code_change",
125 FactType::Configuration => "configuration",
126 FactType::Other => "other",
127 }
128}
129
130fn string_to_fact_type(s: &str) -> FactType {
132 match s {
133 "decision" => FactType::Decision,
134 "definition" => FactType::Definition,
135 "requirement" => FactType::Requirement,
136 "code_change" => FactType::CodeChange,
137 "configuration" => FactType::Configuration,
138 _ => FactType::Other,
139 }
140}
141
142pub struct FactStore<B: StorageBackend = brainwires_storage::databases::lance::LanceDatabase> {
146 backend: Arc<B>,
147 embeddings: Arc<CachedEmbeddingProvider>,
148}
149
150impl<B: StorageBackend> FactStore<B> {
151 pub fn new(backend: Arc<B>, embeddings: Arc<CachedEmbeddingProvider>) -> Self {
153 Self {
154 backend,
155 embeddings,
156 }
157 }
158
159 pub async fn ensure_table(&self) -> Result<()> {
161 let dim = self.embeddings.dimension();
162 self.backend
163 .ensure_table(TABLE_NAME, &facts_field_defs(dim))
164 .await
165 }
166
167 pub async fn add(&self, fact: KeyFact) -> Result<()> {
169 let embedding = self.embeddings.embed(&fact.fact)?;
170 let record = to_record(&fact, embedding);
171
172 self.backend
173 .insert(TABLE_NAME, vec![record])
174 .await
175 .context("Failed to add fact")
176 }
177
178 pub async fn add_batch(&self, facts: Vec<KeyFact>) -> Result<()> {
180 if facts.is_empty() {
181 return Ok(());
182 }
183
184 let contents: Vec<String> = facts.iter().map(|f| f.fact.clone()).collect();
185 let embeddings = self.embeddings.embed_batch(&contents)?;
186
187 let records: Vec<Record> = facts
188 .iter()
189 .zip(embeddings.into_iter())
190 .map(|(f, emb)| to_record(f, emb))
191 .collect();
192
193 self.backend
194 .insert(TABLE_NAME, records)
195 .await
196 .context("Failed to add facts")
197 }
198
199 pub async fn get(&self, fact_id: &str) -> Result<Option<KeyFact>> {
201 let filter = Filter::Eq(
202 "fact_id".into(),
203 FieldValue::Utf8(Some(fact_id.to_string())),
204 );
205 let records = self
206 .backend
207 .query(TABLE_NAME, Some(&filter), Some(1))
208 .await?;
209
210 match records.first() {
211 Some(r) => Ok(Some(from_record(r)?)),
212 None => Ok(None),
213 }
214 }
215
216 pub async fn get_by_conversation(&self, conversation_id: &str) -> Result<Vec<KeyFact>> {
218 let filter = Filter::Eq(
219 "conversation_id".into(),
220 FieldValue::Utf8(Some(conversation_id.to_string())),
221 );
222 let records = self.backend.query(TABLE_NAME, Some(&filter), None).await?;
223
224 records.iter().map(from_record).collect()
225 }
226
227 pub async fn search(
229 &self,
230 query: &str,
231 limit: usize,
232 min_score: f32,
233 ) -> Result<Vec<(KeyFact, f32)>> {
234 self.search_with_filter(query, limit, min_score, None).await
235 }
236
237 pub async fn search_conversation(
239 &self,
240 conversation_id: &str,
241 query: &str,
242 limit: usize,
243 min_score: f32,
244 ) -> Result<Vec<(KeyFact, f32)>> {
245 let filter = Filter::Eq(
246 "conversation_id".into(),
247 FieldValue::Utf8(Some(conversation_id.to_string())),
248 );
249 self.search_with_filter(query, limit, min_score, Some(filter))
250 .await
251 }
252
253 async fn search_with_filter(
255 &self,
256 query: &str,
257 limit: usize,
258 min_score: f32,
259 filter: Option<Filter>,
260 ) -> Result<Vec<(KeyFact, f32)>> {
261 let query_embedding = self.embeddings.embed_cached(query)?;
262
263 let scored = self
264 .backend
265 .vector_search(
266 TABLE_NAME,
267 "vector",
268 query_embedding,
269 limit,
270 filter.as_ref(),
271 )
272 .await?;
273
274 scored_records_to_facts(&scored, min_score)
275 }
276
277 pub async fn delete(&self, fact_id: &str) -> Result<()> {
279 let filter = Filter::Eq(
280 "fact_id".into(),
281 FieldValue::Utf8(Some(fact_id.to_string())),
282 );
283 self.backend
284 .delete(TABLE_NAME, &filter)
285 .await
286 .context("Failed to delete fact")
287 }
288
289 pub async fn count(&self) -> Result<usize> {
291 self.backend.count(TABLE_NAME, None).await
292 }
293}
294
295fn scored_records_to_facts(scored: &[ScoredRecord], min_score: f32) -> Result<Vec<(KeyFact, f32)>> {
298 let mut results = Vec::new();
299 for sr in scored {
300 if sr.score >= min_score {
301 let fact = from_record(&sr.record)?;
302 results.push((fact, sr.score));
303 }
304 }
305 Ok(results)
306}