Skip to main content

brainwires_stores/memory/
summary_store.rs

//! Persistent storage for warm-tier message summaries.
//!
//! Uses a [`StorageBackend`](brainwires_storage::databases::StorageBackend) for persistence, with semantic search over summary embeddings.
4
5use anyhow::{Context, Result};
6use std::sync::Arc;
7
8use brainwires_storage::CachedEmbeddingProvider;
9use brainwires_storage::databases::{
10    FieldDef, FieldType, FieldValue, Filter, Record, ScoredRecord, StorageBackend, record_get,
11};
12
13use super::tier_types::MessageSummary;
14
/// Name of the backend table that holds every warm-tier summary record.
const TABLE_NAME: &str = "summaries";
16
17// ── Schema ──────────────────────────────────────────────────────────────
18
19/// Return the backend-agnostic field definitions for the summaries table.
20pub fn summaries_field_defs(embedding_dim: usize) -> Vec<FieldDef> {
21    vec![
22        FieldDef::required("summary_id", FieldType::Utf8),
23        FieldDef::required("original_message_id", FieldType::Utf8),
24        FieldDef::required("conversation_id", FieldType::Utf8),
25        FieldDef::required("role", FieldType::Utf8),
26        FieldDef::required("summary", FieldType::Utf8),
27        FieldDef::required("key_entities", FieldType::Utf8), // JSON array
28        FieldDef::required("vector", FieldType::Vector(embedding_dim)),
29        FieldDef::required("created_at", FieldType::Int64),
30    ]
31}
32
33/// Arrow schema for the summaries table, used by `LanceDatabase` table creation.
34pub fn summaries_schema(embedding_dim: usize) -> std::sync::Arc<arrow_schema::Schema> {
35    use arrow_schema::{DataType, Field};
36
37    std::sync::Arc::new(arrow_schema::Schema::new(vec![
38        Field::new(
39            "vector",
40            DataType::FixedSizeList(
41                std::sync::Arc::new(Field::new("item", DataType::Float32, true)),
42                embedding_dim as i32,
43            ),
44            false,
45        ),
46        Field::new("summary_id", DataType::Utf8, false),
47        Field::new("original_message_id", DataType::Utf8, false),
48        Field::new("conversation_id", DataType::Utf8, false),
49        Field::new("role", DataType::Utf8, false),
50        Field::new("summary", DataType::Utf8, false),
51        Field::new("key_entities", DataType::Utf8, false),
52        Field::new("created_at", DataType::Int64, false),
53    ]))
54}
55
56// ── Record conversion helpers ───────────────────────────────────────────
57
58fn to_record(summary: &MessageSummary, embedding: Vec<f32>) -> Record {
59    let key_entities_json =
60        serde_json::to_string(&summary.key_entities).unwrap_or_else(|_| "[]".to_string());
61
62    vec![
63        (
64            "summary_id".into(),
65            FieldValue::Utf8(Some(summary.summary_id.clone())),
66        ),
67        (
68            "original_message_id".into(),
69            FieldValue::Utf8(Some(summary.original_message_id.clone())),
70        ),
71        (
72            "conversation_id".into(),
73            FieldValue::Utf8(Some(summary.conversation_id.clone())),
74        ),
75        ("role".into(), FieldValue::Utf8(Some(summary.role.clone()))),
76        (
77            "summary".into(),
78            FieldValue::Utf8(Some(summary.summary.clone())),
79        ),
80        (
81            "key_entities".into(),
82            FieldValue::Utf8(Some(key_entities_json)),
83        ),
84        ("vector".into(), FieldValue::Vector(embedding)),
85        (
86            "created_at".into(),
87            FieldValue::Int64(Some(summary.created_at)),
88        ),
89    ]
90}
91
92fn from_record(r: &Record) -> Result<MessageSummary> {
93    let key_entities: Vec<String> = record_get(r, "key_entities")
94        .and_then(|v| v.as_str())
95        .and_then(|json| serde_json::from_str(json).ok())
96        .unwrap_or_default();
97
98    Ok(MessageSummary {
99        summary_id: record_get(r, "summary_id")
100            .and_then(|v| v.as_str())
101            .context("missing summary_id")?
102            .to_string(),
103        original_message_id: record_get(r, "original_message_id")
104            .and_then(|v| v.as_str())
105            .context("missing original_message_id")?
106            .to_string(),
107        conversation_id: record_get(r, "conversation_id")
108            .and_then(|v| v.as_str())
109            .context("missing conversation_id")?
110            .to_string(),
111        role: record_get(r, "role")
112            .and_then(|v| v.as_str())
113            .context("missing role")?
114            .to_string(),
115        summary: record_get(r, "summary")
116            .and_then(|v| v.as_str())
117            .context("missing summary")?
118            .to_string(),
119        key_entities,
120        created_at: record_get(r, "created_at")
121            .and_then(|v| v.as_i64())
122            .context("missing created_at")?,
123    })
124}
125
126// ── SummaryStore ────────────────────────────────────────────────────────
127
128/// Store for warm tier message summaries with semantic search
129pub struct SummaryStore<B: StorageBackend = brainwires_storage::databases::lance::LanceDatabase> {
130    backend: Arc<B>,
131    embeddings: Arc<CachedEmbeddingProvider>,
132}
133
134impl<B: StorageBackend> SummaryStore<B> {
135    /// Create a new summary store
136    pub fn new(backend: Arc<B>, embeddings: Arc<CachedEmbeddingProvider>) -> Self {
137        Self {
138            backend,
139            embeddings,
140        }
141    }
142
143    /// Ensure the underlying table exists.
144    pub async fn ensure_table(&self) -> Result<()> {
145        let dim = self.embeddings.dimension();
146        self.backend
147            .ensure_table(TABLE_NAME, &summaries_field_defs(dim))
148            .await
149    }
150
151    /// Add a summary to the store
152    pub async fn add(&self, summary: MessageSummary) -> Result<()> {
153        let embedding = self.embeddings.embed(&summary.summary)?;
154        let record = to_record(&summary, embedding);
155
156        self.backend
157            .insert(TABLE_NAME, vec![record])
158            .await
159            .context("Failed to add summary")
160    }
161
162    /// Add multiple summaries in batch
163    pub async fn add_batch(&self, summaries: Vec<MessageSummary>) -> Result<()> {
164        if summaries.is_empty() {
165            return Ok(());
166        }
167
168        let contents: Vec<String> = summaries.iter().map(|s| s.summary.clone()).collect();
169        let embeddings = self.embeddings.embed_batch(&contents)?;
170
171        let records: Vec<Record> = summaries
172            .iter()
173            .zip(embeddings.into_iter())
174            .map(|(s, emb)| to_record(s, emb))
175            .collect();
176
177        self.backend
178            .insert(TABLE_NAME, records)
179            .await
180            .context("Failed to add summaries")
181    }
182
183    /// Get a summary by ID
184    pub async fn get(&self, summary_id: &str) -> Result<Option<MessageSummary>> {
185        let filter = Filter::Eq(
186            "summary_id".into(),
187            FieldValue::Utf8(Some(summary_id.to_string())),
188        );
189        let records = self
190            .backend
191            .query(TABLE_NAME, Some(&filter), Some(1))
192            .await?;
193
194        match records.first() {
195            Some(r) => Ok(Some(from_record(r)?)),
196            None => Ok(None),
197        }
198    }
199
200    /// Get all summaries for a conversation
201    pub async fn get_by_conversation(&self, conversation_id: &str) -> Result<Vec<MessageSummary>> {
202        let filter = Filter::Eq(
203            "conversation_id".into(),
204            FieldValue::Utf8(Some(conversation_id.to_string())),
205        );
206        let records = self.backend.query(TABLE_NAME, Some(&filter), None).await?;
207
208        records.iter().map(from_record).collect()
209    }
210
211    /// Search summaries by semantic similarity
212    pub async fn search(
213        &self,
214        query: &str,
215        limit: usize,
216        min_score: f32,
217    ) -> Result<Vec<(MessageSummary, f32)>> {
218        self.search_with_filter(query, limit, min_score, None).await
219    }
220
221    /// Search summaries within a specific conversation
222    pub async fn search_conversation(
223        &self,
224        conversation_id: &str,
225        query: &str,
226        limit: usize,
227        min_score: f32,
228    ) -> Result<Vec<(MessageSummary, f32)>> {
229        let filter = Filter::Eq(
230            "conversation_id".into(),
231            FieldValue::Utf8(Some(conversation_id.to_string())),
232        );
233        self.search_with_filter(query, limit, min_score, Some(filter))
234            .await
235    }
236
237    /// Search summaries with optional filter
238    async fn search_with_filter(
239        &self,
240        query: &str,
241        limit: usize,
242        min_score: f32,
243        filter: Option<Filter>,
244    ) -> Result<Vec<(MessageSummary, f32)>> {
245        let query_embedding = self.embeddings.embed_cached(query)?;
246
247        let scored = self
248            .backend
249            .vector_search(
250                TABLE_NAME,
251                "vector",
252                query_embedding,
253                limit,
254                filter.as_ref(),
255            )
256            .await?;
257
258        scored_records_to_summaries(&scored, min_score)
259    }
260
261    /// Delete a summary by ID
262    pub async fn delete(&self, summary_id: &str) -> Result<()> {
263        let filter = Filter::Eq(
264            "summary_id".into(),
265            FieldValue::Utf8(Some(summary_id.to_string())),
266        );
267        self.backend
268            .delete(TABLE_NAME, &filter)
269            .await
270            .context("Failed to delete summary")
271    }
272
273    /// Get count of summaries
274    pub async fn count(&self) -> Result<usize> {
275        self.backend.count(TABLE_NAME, None).await
276    }
277
278    /// Get oldest summaries (for demotion to cold tier)
279    pub async fn get_oldest(&self, limit: usize) -> Result<Vec<MessageSummary>> {
280        let records = self.backend.query(TABLE_NAME, None, None).await?;
281
282        let mut summaries: Vec<MessageSummary> =
283            records.iter().filter_map(|r| from_record(r).ok()).collect();
284
285        // Sort by created_at ascending (oldest first)
286        summaries.sort_by_key(|s| s.created_at);
287        summaries.truncate(limit);
288
289        Ok(summaries)
290    }
291}
292
293// ── Helpers ─────────────────────────────────────────────────────────────
294
295fn scored_records_to_summaries(
296    scored: &[ScoredRecord],
297    min_score: f32,
298) -> Result<Vec<(MessageSummary, f32)>> {
299    let mut results = Vec::new();
300    for sr in scored {
301        if sr.score >= min_score {
302            let summary = from_record(&sr.record)?;
303            results.push((summary, sr.score));
304        }
305    }
306    Ok(results)
307}