1use anyhow::{Context, Result};
7use sha2::{Digest, Sha256};
8use std::sync::Arc;
9
10use brainwires_storage::databases::{
11 FieldDef, FieldType, FieldValue, Filter, Record, StorageBackend, record_get,
12};
13use brainwires_storage::embeddings::CachedEmbeddingProvider;
14use brainwires_storage::image_types::{
15 ImageFormat, ImageMetadata, ImageSearchRequest, ImageSearchResult, ImageStorage,
16};
17
18const TABLE_NAME: &str = "images";
19
20fn table_schema(dimension: usize) -> Vec<FieldDef> {
23 vec![
24 FieldDef::required("vector", FieldType::Vector(dimension)),
25 FieldDef::required("image_id", FieldType::Utf8),
26 FieldDef::optional("message_id", FieldType::Utf8),
27 FieldDef::required("conversation_id", FieldType::Utf8),
28 FieldDef::optional("file_name", FieldType::Utf8),
29 FieldDef::required("format", FieldType::Utf8),
30 FieldDef::required("mime_type", FieldType::Utf8),
31 FieldDef::optional("width", FieldType::UInt32),
32 FieldDef::optional("height", FieldType::UInt32),
33 FieldDef::required("file_size_bytes", FieldType::UInt64),
34 FieldDef::required("file_hash", FieldType::Utf8),
35 FieldDef::required("analysis", FieldType::Utf8),
36 FieldDef::optional("extracted_text", FieldType::Utf8),
37 FieldDef::required("tags", FieldType::Utf8), FieldDef::required("storage_type", FieldType::Utf8),
39 FieldDef::required("storage_value", FieldType::Utf8),
40 FieldDef::required("created_at", FieldType::Int64),
41 ]
42}
43
44fn to_record(m: &ImageMetadata, storage: &ImageStorage, embedding: Vec<f32>) -> Record {
47 let tags_json = serde_json::to_string(&m.tags).unwrap_or_else(|_| "[]".to_string());
48
49 vec![
50 ("vector".into(), FieldValue::Vector(embedding)),
51 (
52 "image_id".into(),
53 FieldValue::Utf8(Some(m.image_id.clone())),
54 ),
55 ("message_id".into(), FieldValue::Utf8(m.message_id.clone())),
56 (
57 "conversation_id".into(),
58 FieldValue::Utf8(Some(m.conversation_id.clone())),
59 ),
60 ("file_name".into(), FieldValue::Utf8(m.file_name.clone())),
61 (
62 "format".into(),
63 FieldValue::Utf8(Some(m.format.as_str().to_string())),
64 ),
65 (
66 "mime_type".into(),
67 FieldValue::Utf8(Some(m.mime_type.clone())),
68 ),
69 ("width".into(), FieldValue::UInt32(m.width)),
70 ("height".into(), FieldValue::UInt32(m.height)),
71 (
72 "file_size_bytes".into(),
73 FieldValue::UInt64(Some(m.file_size_bytes)),
74 ),
75 (
76 "file_hash".into(),
77 FieldValue::Utf8(Some(m.file_hash.clone())),
78 ),
79 (
80 "analysis".into(),
81 FieldValue::Utf8(Some(m.analysis.clone())),
82 ),
83 (
84 "extracted_text".into(),
85 FieldValue::Utf8(m.extracted_text.clone()),
86 ),
87 ("tags".into(), FieldValue::Utf8(Some(tags_json))),
88 (
89 "storage_type".into(),
90 FieldValue::Utf8(Some(storage.storage_type().to_string())),
91 ),
92 (
93 "storage_value".into(),
94 FieldValue::Utf8(Some(storage.value().to_string())),
95 ),
96 ("created_at".into(), FieldValue::Int64(Some(m.created_at))),
97 ]
98}
99
100fn from_record(r: &Record) -> Result<ImageMetadata> {
101 let image_id = record_get(r, "image_id")
102 .and_then(|v| v.as_str())
103 .context("missing image_id")?
104 .to_string();
105
106 let message_id = record_get(r, "message_id")
107 .and_then(|v| v.as_str())
108 .map(String::from);
109
110 let conversation_id = record_get(r, "conversation_id")
111 .and_then(|v| v.as_str())
112 .context("missing conversation_id")?
113 .to_string();
114
115 let file_name = record_get(r, "file_name")
116 .and_then(|v| v.as_str())
117 .filter(|s| !s.is_empty())
118 .map(String::from);
119
120 let format_str = record_get(r, "format")
121 .and_then(|v| v.as_str())
122 .unwrap_or("unknown");
123 let format: ImageFormat = format_str.parse().unwrap_or(ImageFormat::Unknown);
124
125 let mime_type = record_get(r, "mime_type")
126 .and_then(|v| v.as_str())
127 .unwrap_or("application/octet-stream")
128 .to_string();
129
130 let width = record_get(r, "width").and_then(|v| match v {
131 FieldValue::UInt32(Some(n)) => Some(*n).filter(|&n| n > 0),
132 _ => None,
133 });
134
135 let height = record_get(r, "height").and_then(|v| match v {
136 FieldValue::UInt32(Some(n)) => Some(*n).filter(|&n| n > 0),
137 _ => None,
138 });
139
140 let file_size_bytes = record_get(r, "file_size_bytes")
141 .and_then(|v| match v {
142 FieldValue::UInt64(Some(n)) => Some(*n),
143 _ => None,
144 })
145 .unwrap_or(0);
146
147 let file_hash = record_get(r, "file_hash")
148 .and_then(|v| v.as_str())
149 .unwrap_or("")
150 .to_string();
151
152 let analysis = record_get(r, "analysis")
153 .and_then(|v| v.as_str())
154 .unwrap_or("")
155 .to_string();
156
157 let extracted_text = record_get(r, "extracted_text")
158 .and_then(|v| v.as_str())
159 .filter(|s| !s.is_empty())
160 .map(String::from);
161
162 let tags_json = record_get(r, "tags")
163 .and_then(|v| v.as_str())
164 .unwrap_or("[]");
165 let tags: Vec<String> = serde_json::from_str(tags_json).unwrap_or_default();
166
167 let created_at = record_get(r, "created_at")
168 .and_then(|v| v.as_i64())
169 .unwrap_or(0);
170
171 Ok(ImageMetadata {
172 image_id,
173 message_id,
174 conversation_id,
175 file_name,
176 format,
177 mime_type,
178 width,
179 height,
180 file_size_bytes,
181 file_hash,
182 analysis,
183 extracted_text,
184 tags,
185 created_at,
186 })
187}
188
189fn storage_from_record(r: &Record) -> Option<ImageStorage> {
190 let storage_type = record_get(r, "storage_type").and_then(|v| v.as_str())?;
191 let storage_value = record_get(r, "storage_value")
192 .and_then(|v| v.as_str())
193 .unwrap_or("")
194 .to_string();
195
196 Some(match storage_type {
197 "base64" => ImageStorage::Base64(storage_value),
198 "file" => ImageStorage::FilePath(storage_value),
199 "url" => ImageStorage::Url(storage_value),
200 _ => ImageStorage::Base64(storage_value),
201 })
202}
203
/// Store for image metadata and image payloads, kept in the `images` table
/// of a [`StorageBackend`].
///
/// The backend type `B` defaults to the Lance database backend. Rows are
/// written together with an embedding produced by the embedding provider.
pub struct ImageStore<B: StorageBackend = brainwires_storage::databases::lance::LanceDatabase> {
    /// Backend that holds the `images` table.
    backend: Arc<B>,
    /// Provider used to embed searchable text and search queries.
    embeddings: Arc<CachedEmbeddingProvider>,
}
211
212impl<B: StorageBackend> ImageStore<B> {
213 pub fn new(backend: Arc<B>, embeddings: Arc<CachedEmbeddingProvider>) -> Self {
215 Self {
216 backend,
217 embeddings,
218 }
219 }
220
221 pub async fn ensure_table(&self) -> Result<()> {
223 let dimension = self.embeddings.dimension();
224 self.backend
225 .ensure_table(TABLE_NAME, &table_schema(dimension))
226 .await
227 }
228
229 pub fn compute_hash(bytes: &[u8]) -> String {
231 let mut hasher = Sha256::new();
232 hasher.update(bytes);
233 format!("{:x}", hasher.finalize())
234 }
235
236 pub async fn store(
242 &self,
243 metadata: ImageMetadata,
244 storage: ImageStorage,
245 ) -> Result<ImageMetadata> {
246 let searchable_text = metadata.searchable_text();
248 let embedding = self.embeddings.embed(&searchable_text)?;
249
250 let record = to_record(&metadata, &storage, embedding);
251
252 self.backend
253 .insert(TABLE_NAME, vec![record])
254 .await
255 .context("Failed to store image")?;
256
257 Ok(metadata)
258 }
259
260 pub async fn store_from_bytes(
268 &self,
269 bytes: &[u8],
270 analysis: String,
271 conversation_id: String,
272 format: ImageFormat,
273 ) -> Result<ImageMetadata> {
274 let file_hash = Self::compute_hash(bytes);
275
276 if let Some(existing) = self.get_by_hash(&file_hash).await? {
278 return Ok(existing);
279 }
280
281 let image_id = format!("img_{}", uuid::Uuid::new_v4());
282 let metadata = ImageMetadata::new(
283 image_id,
284 conversation_id,
285 format,
286 bytes.len() as u64,
287 file_hash,
288 analysis,
289 );
290
291 let storage = ImageStorage::from_bytes(bytes);
292 self.store(metadata, storage).await
293 }
294
295 pub async fn get_by_hash(&self, file_hash: &str) -> Result<Option<ImageMetadata>> {
297 let filter = Filter::Eq(
298 "file_hash".into(),
299 FieldValue::Utf8(Some(file_hash.to_string())),
300 );
301 let records = self
302 .backend
303 .query(TABLE_NAME, Some(&filter), Some(1))
304 .await
305 .context("Failed to query images by hash")?;
306
307 match records.first() {
308 Some(r) => Ok(Some(from_record(r)?)),
309 None => Ok(None),
310 }
311 }
312
313 pub async fn get(&self, image_id: &str) -> Result<Option<ImageMetadata>> {
315 let filter = Filter::Eq(
316 "image_id".into(),
317 FieldValue::Utf8(Some(image_id.to_string())),
318 );
319 let records = self
320 .backend
321 .query(TABLE_NAME, Some(&filter), Some(1))
322 .await
323 .context("Failed to query image by ID")?;
324
325 match records.first() {
326 Some(r) => Ok(Some(from_record(r)?)),
327 None => Ok(None),
328 }
329 }
330
331 pub async fn search(&self, request: ImageSearchRequest) -> Result<Vec<ImageSearchResult>> {
333 let query_embedding = self.embeddings.embed(&request.query)?;
335
336 let mut filters = Vec::new();
338
339 if let Some(ref conv_id) = request.conversation_id {
340 filters.push(Filter::Eq(
341 "conversation_id".into(),
342 FieldValue::Utf8(Some(conv_id.clone())),
343 ));
344 }
345
346 if let Some(format) = request.format {
347 filters.push(Filter::Eq(
348 "format".into(),
349 FieldValue::Utf8(Some(format.as_str().to_string())),
350 ));
351 }
352
353 let filter = if filters.is_empty() {
354 None
355 } else if filters.len() == 1 {
356 Some(filters.remove(0))
357 } else {
358 Some(Filter::And(filters))
359 };
360
361 let scored_records = self
363 .backend
364 .vector_search(
365 TABLE_NAME,
366 "vector",
367 query_embedding,
368 request.limit,
369 filter.as_ref(),
370 )
371 .await
372 .context("Failed to execute image search")?;
373
374 let mut search_results = Vec::new();
375
376 for scored in &scored_records {
377 if scored.score < request.min_score {
378 continue;
379 }
380
381 let metadata = from_record(&scored.record)?;
382 search_results.push(ImageSearchResult::from_metadata(metadata, scored.score));
383 }
384
385 search_results.sort_by(|a, b| {
387 b.score
388 .partial_cmp(&a.score)
389 .unwrap_or(std::cmp::Ordering::Equal)
390 });
391
392 Ok(search_results)
393 }
394
395 pub async fn list_by_conversation(&self, conversation_id: &str) -> Result<Vec<ImageMetadata>> {
397 let filter = Filter::Eq(
398 "conversation_id".into(),
399 FieldValue::Utf8(Some(conversation_id.to_string())),
400 );
401 let records = self
402 .backend
403 .query(TABLE_NAME, Some(&filter), None)
404 .await
405 .context("Failed to list images by conversation")?;
406
407 let mut images: Vec<ImageMetadata> =
408 records.iter().filter_map(|r| from_record(r).ok()).collect();
409
410 images.sort_by(|a, b| b.created_at.cmp(&a.created_at));
412
413 Ok(images)
414 }
415
416 pub async fn list_by_message(&self, message_id: &str) -> Result<Vec<ImageMetadata>> {
418 let filter = Filter::Eq(
419 "message_id".into(),
420 FieldValue::Utf8(Some(message_id.to_string())),
421 );
422 let records = self
423 .backend
424 .query(TABLE_NAME, Some(&filter), None)
425 .await
426 .context("Failed to list images by message")?;
427
428 let images: Vec<ImageMetadata> =
429 records.iter().filter_map(|r| from_record(r).ok()).collect();
430
431 Ok(images)
432 }
433
434 pub async fn delete(&self, image_id: &str) -> Result<bool> {
436 let filter = Filter::Eq(
437 "image_id".into(),
438 FieldValue::Utf8(Some(image_id.to_string())),
439 );
440 self.backend
441 .delete(TABLE_NAME, &filter)
442 .await
443 .context("Failed to delete image")?;
444
445 Ok(true)
446 }
447
448 pub async fn delete_by_conversation(&self, conversation_id: &str) -> Result<usize> {
450 let images = self.list_by_conversation(conversation_id).await?;
451 let count = images.len();
452
453 let filter = Filter::Eq(
454 "conversation_id".into(),
455 FieldValue::Utf8(Some(conversation_id.to_string())),
456 );
457 self.backend
458 .delete(TABLE_NAME, &filter)
459 .await
460 .context("Failed to delete images by conversation")?;
461
462 Ok(count)
463 }
464
465 pub async fn get_image_data(&self, image_id: &str) -> Result<Option<ImageStorage>> {
467 let filter = Filter::Eq(
468 "image_id".into(),
469 FieldValue::Utf8(Some(image_id.to_string())),
470 );
471 let records = self
472 .backend
473 .query(TABLE_NAME, Some(&filter), Some(1))
474 .await
475 .context("Failed to query image data")?;
476
477 match records.first() {
478 Some(r) => Ok(storage_from_record(r)),
479 None => Ok(None),
480 }
481 }
482
483 pub async fn count_by_conversation(&self, conversation_id: &str) -> Result<usize> {
485 let images = self.list_by_conversation(conversation_id).await?;
486 Ok(images.len())
487 }
488}