brainwires_stores/memory/
summary_store.rs1use anyhow::{Context, Result};
6use std::sync::Arc;
7
8use brainwires_storage::CachedEmbeddingProvider;
9use brainwires_storage::databases::{
10 FieldDef, FieldType, FieldValue, Filter, Record, ScoredRecord, StorageBackend, record_get,
11};
12
13use super::tier_types::MessageSummary;
14
/// Name of the backend table that every summary-store operation in this file
/// reads from and writes to.
const TABLE_NAME: &str = "summaries";
16
17pub fn summaries_field_defs(embedding_dim: usize) -> Vec<FieldDef> {
21 vec![
22 FieldDef::required("summary_id", FieldType::Utf8),
23 FieldDef::required("original_message_id", FieldType::Utf8),
24 FieldDef::required("conversation_id", FieldType::Utf8),
25 FieldDef::required("role", FieldType::Utf8),
26 FieldDef::required("summary", FieldType::Utf8),
27 FieldDef::required("key_entities", FieldType::Utf8), FieldDef::required("vector", FieldType::Vector(embedding_dim)),
29 FieldDef::required("created_at", FieldType::Int64),
30 ]
31}
32
33pub fn summaries_schema(embedding_dim: usize) -> std::sync::Arc<arrow_schema::Schema> {
35 use arrow_schema::{DataType, Field};
36
37 std::sync::Arc::new(arrow_schema::Schema::new(vec![
38 Field::new(
39 "vector",
40 DataType::FixedSizeList(
41 std::sync::Arc::new(Field::new("item", DataType::Float32, true)),
42 embedding_dim as i32,
43 ),
44 false,
45 ),
46 Field::new("summary_id", DataType::Utf8, false),
47 Field::new("original_message_id", DataType::Utf8, false),
48 Field::new("conversation_id", DataType::Utf8, false),
49 Field::new("role", DataType::Utf8, false),
50 Field::new("summary", DataType::Utf8, false),
51 Field::new("key_entities", DataType::Utf8, false),
52 Field::new("created_at", DataType::Int64, false),
53 ]))
54}
55
56fn to_record(summary: &MessageSummary, embedding: Vec<f32>) -> Record {
59 let key_entities_json =
60 serde_json::to_string(&summary.key_entities).unwrap_or_else(|_| "[]".to_string());
61
62 vec![
63 (
64 "summary_id".into(),
65 FieldValue::Utf8(Some(summary.summary_id.clone())),
66 ),
67 (
68 "original_message_id".into(),
69 FieldValue::Utf8(Some(summary.original_message_id.clone())),
70 ),
71 (
72 "conversation_id".into(),
73 FieldValue::Utf8(Some(summary.conversation_id.clone())),
74 ),
75 ("role".into(), FieldValue::Utf8(Some(summary.role.clone()))),
76 (
77 "summary".into(),
78 FieldValue::Utf8(Some(summary.summary.clone())),
79 ),
80 (
81 "key_entities".into(),
82 FieldValue::Utf8(Some(key_entities_json)),
83 ),
84 ("vector".into(), FieldValue::Vector(embedding)),
85 (
86 "created_at".into(),
87 FieldValue::Int64(Some(summary.created_at)),
88 ),
89 ]
90}
91
92fn from_record(r: &Record) -> Result<MessageSummary> {
93 let key_entities: Vec<String> = record_get(r, "key_entities")
94 .and_then(|v| v.as_str())
95 .and_then(|json| serde_json::from_str(json).ok())
96 .unwrap_or_default();
97
98 Ok(MessageSummary {
99 summary_id: record_get(r, "summary_id")
100 .and_then(|v| v.as_str())
101 .context("missing summary_id")?
102 .to_string(),
103 original_message_id: record_get(r, "original_message_id")
104 .and_then(|v| v.as_str())
105 .context("missing original_message_id")?
106 .to_string(),
107 conversation_id: record_get(r, "conversation_id")
108 .and_then(|v| v.as_str())
109 .context("missing conversation_id")?
110 .to_string(),
111 role: record_get(r, "role")
112 .and_then(|v| v.as_str())
113 .context("missing role")?
114 .to_string(),
115 summary: record_get(r, "summary")
116 .and_then(|v| v.as_str())
117 .context("missing summary")?
118 .to_string(),
119 key_entities,
120 created_at: record_get(r, "created_at")
121 .and_then(|v| v.as_i64())
122 .context("missing created_at")?,
123 })
124}
125
126pub struct SummaryStore<B: StorageBackend = brainwires_storage::databases::lance::LanceDatabase> {
130 backend: Arc<B>,
131 embeddings: Arc<CachedEmbeddingProvider>,
132}
133
134impl<B: StorageBackend> SummaryStore<B> {
135 pub fn new(backend: Arc<B>, embeddings: Arc<CachedEmbeddingProvider>) -> Self {
137 Self {
138 backend,
139 embeddings,
140 }
141 }
142
143 pub async fn ensure_table(&self) -> Result<()> {
145 let dim = self.embeddings.dimension();
146 self.backend
147 .ensure_table(TABLE_NAME, &summaries_field_defs(dim))
148 .await
149 }
150
151 pub async fn add(&self, summary: MessageSummary) -> Result<()> {
153 let embedding = self.embeddings.embed(&summary.summary)?;
154 let record = to_record(&summary, embedding);
155
156 self.backend
157 .insert(TABLE_NAME, vec![record])
158 .await
159 .context("Failed to add summary")
160 }
161
162 pub async fn add_batch(&self, summaries: Vec<MessageSummary>) -> Result<()> {
164 if summaries.is_empty() {
165 return Ok(());
166 }
167
168 let contents: Vec<String> = summaries.iter().map(|s| s.summary.clone()).collect();
169 let embeddings = self.embeddings.embed_batch(&contents)?;
170
171 let records: Vec<Record> = summaries
172 .iter()
173 .zip(embeddings.into_iter())
174 .map(|(s, emb)| to_record(s, emb))
175 .collect();
176
177 self.backend
178 .insert(TABLE_NAME, records)
179 .await
180 .context("Failed to add summaries")
181 }
182
183 pub async fn get(&self, summary_id: &str) -> Result<Option<MessageSummary>> {
185 let filter = Filter::Eq(
186 "summary_id".into(),
187 FieldValue::Utf8(Some(summary_id.to_string())),
188 );
189 let records = self
190 .backend
191 .query(TABLE_NAME, Some(&filter), Some(1))
192 .await?;
193
194 match records.first() {
195 Some(r) => Ok(Some(from_record(r)?)),
196 None => Ok(None),
197 }
198 }
199
200 pub async fn get_by_conversation(&self, conversation_id: &str) -> Result<Vec<MessageSummary>> {
202 let filter = Filter::Eq(
203 "conversation_id".into(),
204 FieldValue::Utf8(Some(conversation_id.to_string())),
205 );
206 let records = self.backend.query(TABLE_NAME, Some(&filter), None).await?;
207
208 records.iter().map(from_record).collect()
209 }
210
211 pub async fn search(
213 &self,
214 query: &str,
215 limit: usize,
216 min_score: f32,
217 ) -> Result<Vec<(MessageSummary, f32)>> {
218 self.search_with_filter(query, limit, min_score, None).await
219 }
220
221 pub async fn search_conversation(
223 &self,
224 conversation_id: &str,
225 query: &str,
226 limit: usize,
227 min_score: f32,
228 ) -> Result<Vec<(MessageSummary, f32)>> {
229 let filter = Filter::Eq(
230 "conversation_id".into(),
231 FieldValue::Utf8(Some(conversation_id.to_string())),
232 );
233 self.search_with_filter(query, limit, min_score, Some(filter))
234 .await
235 }
236
237 async fn search_with_filter(
239 &self,
240 query: &str,
241 limit: usize,
242 min_score: f32,
243 filter: Option<Filter>,
244 ) -> Result<Vec<(MessageSummary, f32)>> {
245 let query_embedding = self.embeddings.embed_cached(query)?;
246
247 let scored = self
248 .backend
249 .vector_search(
250 TABLE_NAME,
251 "vector",
252 query_embedding,
253 limit,
254 filter.as_ref(),
255 )
256 .await?;
257
258 scored_records_to_summaries(&scored, min_score)
259 }
260
261 pub async fn delete(&self, summary_id: &str) -> Result<()> {
263 let filter = Filter::Eq(
264 "summary_id".into(),
265 FieldValue::Utf8(Some(summary_id.to_string())),
266 );
267 self.backend
268 .delete(TABLE_NAME, &filter)
269 .await
270 .context("Failed to delete summary")
271 }
272
273 pub async fn count(&self) -> Result<usize> {
275 self.backend.count(TABLE_NAME, None).await
276 }
277
278 pub async fn get_oldest(&self, limit: usize) -> Result<Vec<MessageSummary>> {
280 let records = self.backend.query(TABLE_NAME, None, None).await?;
281
282 let mut summaries: Vec<MessageSummary> =
283 records.iter().filter_map(|r| from_record(r).ok()).collect();
284
285 summaries.sort_by_key(|s| s.created_at);
287 summaries.truncate(limit);
288
289 Ok(summaries)
290 }
291}
292
293fn scored_records_to_summaries(
296 scored: &[ScoredRecord],
297 min_score: f32,
298) -> Result<Vec<(MessageSummary, f32)>> {
299 let mut results = Vec::new();
300 for sr in scored {
301 if sr.score >= min_score {
302 let summary = from_record(&sr.record)?;
303 results.push((summary, sr.score));
304 }
305 }
306 Ok(results)
307}