1use anyhow::{Context, Result};
2use std::sync::Arc;
3
4use crate::databases::{
5 FieldDef, FieldType, FieldValue, Filter, Record, StorageBackend, record_get,
6};
7use crate::embeddings::EmbeddingProvider;
8
9const TABLE_NAME: &str = "messages";
10
11#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
13pub struct MessageMetadata {
14 pub message_id: String,
16 pub conversation_id: String,
18 pub role: String,
20 pub content: String,
22 pub token_count: Option<i32>,
24 pub model_id: Option<String>,
26 pub images: Option<String>, pub created_at: i64,
30 pub expires_at: Option<i64>,
36}
37
38fn table_schema(embedding_dim: usize) -> Vec<FieldDef> {
40 vec![
41 FieldDef::required("vector", FieldType::Vector(embedding_dim)),
42 FieldDef::required("message_id", FieldType::Utf8),
43 FieldDef::required("conversation_id", FieldType::Utf8),
44 FieldDef::required("role", FieldType::Utf8),
45 FieldDef::required("content", FieldType::Utf8),
46 FieldDef::optional("token_count", FieldType::Int32),
47 FieldDef::optional("model_id", FieldType::Utf8),
48 FieldDef::optional("images", FieldType::Utf8),
49 FieldDef::required("created_at", FieldType::Int64),
50 FieldDef::optional("expires_at", FieldType::Int64),
51 ]
52}
53
/// Arrow schema for the `messages` table.
///
/// Column names, order, and nullability match the backend-agnostic
/// `table_schema`: required columns are non-nullable, optional ones nullable.
/// The `vector` column is a `FixedSizeList<Float32>` of length `embedding_dim`.
///
/// # Panics
/// Panics if `embedding_dim` does not fit in an `i32`, which is the length type
/// of Arrow's `FixedSizeList`.
#[cfg(feature = "native")]
pub fn messages_schema(embedding_dim: usize) -> Arc<arrow_schema::Schema> {
    use arrow_schema::{DataType, Field, Schema};

    // A plain `as i32` would silently wrap an oversized dimension into a bogus
    // (possibly negative) list length and only fail much later, far from the cause.
    let list_len = i32::try_from(embedding_dim)
        .expect("embedding dimension must fit in an i32 for an Arrow FixedSizeList");

    Arc::new(Schema::new(vec![
        Field::new(
            "vector",
            DataType::FixedSizeList(
                Arc::new(Field::new("item", DataType::Float32, true)),
                list_len,
            ),
            false,
        ),
        Field::new("message_id", DataType::Utf8, false),
        Field::new("conversation_id", DataType::Utf8, false),
        Field::new("role", DataType::Utf8, false),
        Field::new("content", DataType::Utf8, false),
        Field::new("token_count", DataType::Int32, true),
        Field::new("model_id", DataType::Utf8, true),
        Field::new("images", DataType::Utf8, true),
        Field::new("created_at", DataType::Int64, false),
        Field::new("expires_at", DataType::Int64, true),
    ]))
}
79
80fn to_record(m: &MessageMetadata, embedding: Vec<f32>) -> Record {
81 vec![
82 ("vector".into(), FieldValue::Vector(embedding)),
83 (
84 "message_id".into(),
85 FieldValue::Utf8(Some(m.message_id.clone())),
86 ),
87 (
88 "conversation_id".into(),
89 FieldValue::Utf8(Some(m.conversation_id.clone())),
90 ),
91 ("role".into(), FieldValue::Utf8(Some(m.role.clone()))),
92 ("content".into(), FieldValue::Utf8(Some(m.content.clone()))),
93 ("token_count".into(), FieldValue::Int32(m.token_count)),
94 ("model_id".into(), FieldValue::Utf8(m.model_id.clone())),
95 ("images".into(), FieldValue::Utf8(m.images.clone())),
96 ("created_at".into(), FieldValue::Int64(Some(m.created_at))),
97 ("expires_at".into(), FieldValue::Int64(m.expires_at)),
98 ]
99}
100
101fn from_record(r: &Record) -> Result<MessageMetadata> {
102 Ok(MessageMetadata {
103 message_id: record_get(r, "message_id")
104 .and_then(|v| v.as_str())
105 .context("missing message_id")?
106 .to_string(),
107 conversation_id: record_get(r, "conversation_id")
108 .and_then(|v| v.as_str())
109 .context("missing conversation_id")?
110 .to_string(),
111 role: record_get(r, "role")
112 .and_then(|v| v.as_str())
113 .context("missing role")?
114 .to_string(),
115 content: record_get(r, "content")
116 .and_then(|v| v.as_str())
117 .context("missing content")?
118 .to_string(),
119 token_count: record_get(r, "token_count").and_then(|v| v.as_i32()),
120 model_id: record_get(r, "model_id")
121 .and_then(|v| v.as_str())
122 .map(String::from),
123 images: record_get(r, "images")
124 .and_then(|v| v.as_str())
125 .map(String::from),
126 created_at: record_get(r, "created_at")
127 .and_then(|v| v.as_i64())
128 .context("missing created_at")?,
129 expires_at: record_get(r, "expires_at").and_then(|v| v.as_i64()),
130 })
131}
132
133pub struct MessageStore<B: StorageBackend = crate::databases::lance::LanceDatabase> {
135 backend: Arc<B>,
136 embeddings: Arc<EmbeddingProvider>,
137}
138
139impl<B: StorageBackend> MessageStore<B> {
140 pub fn new(backend: Arc<B>, embeddings: Arc<EmbeddingProvider>) -> Self {
142 Self {
143 backend,
144 embeddings,
145 }
146 }
147
148 pub async fn ensure_table(&self) -> Result<()> {
150 self.backend
151 .ensure_table(TABLE_NAME, &table_schema(self.embeddings.dimension()))
152 .await
153 }
154
155 pub async fn add(&self, message: MessageMetadata) -> Result<()> {
157 let embedding = self.embeddings.embed(&message.content)?;
159 let record = to_record(&message, embedding);
160
161 self.backend
162 .insert(TABLE_NAME, vec![record])
163 .await
164 .context("Failed to add message")?;
165
166 Ok(())
167 }
168
169 pub async fn add_batch(&self, messages: Vec<MessageMetadata>) -> Result<()> {
171 if messages.is_empty() {
172 return Ok(());
173 }
174
175 let contents: Vec<String> = messages.iter().map(|m| m.content.clone()).collect();
177 let embeddings = self.embeddings.embed_batch(&contents)?;
178
179 let records: Vec<Record> = messages
180 .iter()
181 .zip(embeddings.into_iter())
182 .map(|(m, emb)| to_record(m, emb))
183 .collect();
184
185 self.backend
186 .insert(TABLE_NAME, records)
187 .await
188 .context("Failed to add messages")?;
189
190 Ok(())
191 }
192
193 pub async fn get(&self, message_id: &str) -> Result<Option<MessageMetadata>> {
195 let filter = Filter::Eq(
196 "message_id".into(),
197 FieldValue::Utf8(Some(message_id.to_string())),
198 );
199 let records = self
200 .backend
201 .query(TABLE_NAME, Some(&filter), Some(1))
202 .await?;
203
204 match records.first() {
205 Some(r) => Ok(Some(from_record(r)?)),
206 None => Ok(None),
207 }
208 }
209
210 pub async fn get_by_conversation(&self, conversation_id: &str) -> Result<Vec<MessageMetadata>> {
212 let filter = Filter::Eq(
213 "conversation_id".into(),
214 FieldValue::Utf8(Some(conversation_id.to_string())),
215 );
216 let records = self.backend.query(TABLE_NAME, Some(&filter), None).await?;
217
218 records.iter().map(from_record).collect()
219 }
220
221 pub async fn search(
223 &self,
224 query: &str,
225 limit: usize,
226 min_score: f32,
227 ) -> Result<Vec<(MessageMetadata, f32)>> {
228 self.search_with_filter(query, limit, min_score, None).await
229 }
230
231 pub async fn search_conversation(
233 &self,
234 conversation_id: &str,
235 query: &str,
236 limit: usize,
237 min_score: f32,
238 ) -> Result<Vec<(MessageMetadata, f32)>> {
239 let filter = Filter::Eq(
240 "conversation_id".into(),
241 FieldValue::Utf8(Some(conversation_id.to_string())),
242 );
243 self.search_with_filter(query, limit, min_score, Some(filter))
244 .await
245 }
246
247 async fn search_with_filter(
249 &self,
250 query: &str,
251 limit: usize,
252 min_score: f32,
253 filter: Option<Filter>,
254 ) -> Result<Vec<(MessageMetadata, f32)>> {
255 let query_embedding = self.embeddings.embed_cached(query)?;
257
258 let scored = self
259 .backend
260 .vector_search(
261 TABLE_NAME,
262 "vector",
263 query_embedding,
264 limit,
265 filter.as_ref(),
266 )
267 .await?;
268
269 let mut messages_with_scores = Vec::new();
270
271 for sr in scored {
272 if sr.score >= min_score {
273 let message = from_record(&sr.record)?;
274 messages_with_scores.push((message, sr.score));
275 }
276 }
277
278 Ok(messages_with_scores)
279 }
280
281 pub async fn delete_by_conversation(&self, conversation_id: &str) -> Result<()> {
283 let filter = Filter::Eq(
284 "conversation_id".into(),
285 FieldValue::Utf8(Some(conversation_id.to_string())),
286 );
287 self.backend.delete(TABLE_NAME, &filter).await?;
288 Ok(())
289 }
290
291 pub async fn delete(&self, message_id: &str) -> Result<()> {
293 let filter = Filter::Eq(
294 "message_id".into(),
295 FieldValue::Utf8(Some(message_id.to_string())),
296 );
297 self.backend.delete(TABLE_NAME, &filter).await?;
298 Ok(())
299 }
300
301 pub async fn delete_expired(&self) -> Result<usize> {
309 use chrono::Utc;
310 let now = Utc::now().timestamp();
311
312 let filter = Filter::And(vec![
313 Filter::NotNull("expires_at".into()),
314 Filter::Lte("expires_at".into(), FieldValue::Int64(Some(now))),
315 ]);
316
317 let count = self.backend.count(TABLE_NAME, Some(&filter)).await?;
318 if count > 0 {
319 self.backend.delete(TABLE_NAME, &filter).await?;
320 }
321 Ok(count)
322 }
323}