//! [`StorageBackend`] implementation for [`LanceDatabase`].

use anyhow::{Context, Result};
use arrow_array::{Float32Array, RecordBatch, RecordBatchIterator, RecordBatchReader};
use futures::stream::TryStreamExt;
use lancedb::query::{ExecutableQuery, QueryBase};
use std::sync::Arc;

use crate::databases::traits::StorageBackend;
use crate::databases::types::{FieldDef, Filter, Record, ScoredRecord};

use super::arrow_convert::{
    batch_to_records, extract_field_value, field_defs_to_schema, filter_to_sql, records_to_batch,
};
use super::database::LanceDatabase;

#[async_trait::async_trait]
impl StorageBackend for LanceDatabase {
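    // Creates the table with the given schema if it does not already exist;
    // existing tables are left untouched and their schema is not validated.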
    async fn ensure_table(&self, table_name: &str, schema: &[FieldDef]) -> Result<()> {
        let table_names = self.connection.table_names().execute().await?;
        if table_names.contains(&table_name.to_string()) {
            return Ok(());
        }

        let arrow_schema = Arc::new(field_defs_to_schema(schema));
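        // Creating the table from an empty batch reader writes only the schema.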
        let batches: Box<dyn RecordBatchReader + Send> =
            Box::new(RecordBatchIterator::new(vec![], arrow_schema));
        self.connection
            .create_table(table_name, batches)
            .execute()
            .await
            .with_context(|| format!("Failed to create table '{table_name}'"))?;
        Ok(())
    }

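    // Appends the records to the table as a single Arrow batch; empty input is a no-op.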
    async fn insert(&self, table_name: &str, records: Vec<Record>) -> Result<()> {
        if records.is_empty() {
            return Ok(());
        }

        let table = self
            .connection
            .open_table(table_name)
            .execute()
            .await
            .with_context(|| format!("Failed to open table '{table_name}'"))?;

        let batch = records_to_batch(&records)?;
        let schema = batch.schema();
        let batches: Box<dyn RecordBatchReader + Send> =
            Box::new(RecordBatchIterator::new(vec![Ok(batch)], schema));
        table
            .add(batches)
            .execute()
            .await
            .with_context(|| format!("Failed to insert into '{table_name}'"))?;
        Ok(())
    }

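    // Scans the table, translating the optional filter into a SQL predicate and
    // applying the optional row limit.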
    async fn query(
        &self,
        table_name: &str,
        filter: Option<&Filter>,
        limit: Option<usize>,
    ) -> Result<Vec<Record>> {
        let table = self
            .connection
            .open_table(table_name)
            .execute()
            .await
            .with_context(|| format!("Failed to open table '{table_name}'"))?;

        let mut q = table.query();
        if let Some(f) = filter {
            q = q.only_if(filter_to_sql(f));
        }
        if let Some(n) = limit {
            q = q.limit(n);
        }

        let batches: Vec<RecordBatch> = q
            .execute()
            .await
            .with_context(|| format!("Failed to query '{table_name}'"))?
            .try_collect()
            .await?;

        let mut results = Vec::new();
        for batch in &batches {
            batch_to_records(batch, &mut results)?;
        }
        Ok(results)
    }

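    // Deletes every row matching the filter, translated into a SQL predicate.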
    async fn delete(&self, table_name: &str, filter: &Filter) -> Result<()> {
        let table = self
            .connection
            .open_table(table_name)
            .execute()
            .await
            .with_context(|| format!("Failed to open table '{table_name}'"))?;

        table
            .delete(&filter_to_sql(filter))
            .await
            .with_context(|| format!("Failed to delete from '{table_name}'"))?;
        Ok(())
    }

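    // Counts rows matching the optional filter.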
    async fn count(&self, table_name: &str, filter: Option<&Filter>) -> Result<usize> {
        let table = self
            .connection
            .open_table(table_name)
            .execute()
            .await
            .with_context(|| format!("Failed to open table '{table_name}'"))?;

        let mut q = table.query();
        if let Some(f) = filter {
            q = q.only_if(filter_to_sql(f));
        }
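        // Materialize the matching batches and sum their row counts.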
        let batches: Vec<RecordBatch> = q.execute().await?.try_collect().await?;
        Ok(batches.iter().map(|b| b.num_rows()).sum())
    }

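    // Nearest-neighbour search over `vector_column`, returning records scored by
    // 1 / (1 + distance) so that higher scores mean closer matches.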
    async fn vector_search(
        &self,
        table_name: &str,
        vector_column: &str,
        vector: Vec<f32>,
        limit: usize,
        filter: Option<&Filter>,
    ) -> Result<Vec<ScoredRecord>> {
        let table = self
            .connection
            .open_table(table_name)
            .execute()
            .await
            .with_context(|| format!("Failed to open table '{table_name}'"))?;

        let mut q = table
            .vector_search(vector)?
            .column(vector_column)
            .limit(limit);
        if let Some(f) = filter {
            q = q.only_if(filter_to_sql(f));
        }

        let batches: Vec<RecordBatch> = q.execute().await?.try_collect().await?;

        let mut results = Vec::new();
        for batch in &batches {
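            // LanceDB reports the raw distance for each hit in a `_distance` column.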
            let distance_col = batch
                .column_by_name("_distance")
                .and_then(|c| c.as_any().downcast_ref::<Float32Array>());

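            // Rebuild each row as a record, skipping the internal `_distance` column.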
            let schema = batch.schema();
            for row in 0..batch.num_rows() {
                let mut record = Vec::new();
                for (col_idx, field) in schema.fields().iter().enumerate() {
                    if field.name() == "_distance" {
                        continue;
                    }
                    let val = extract_field_value(batch, col_idx, row, field)?;
                    record.push((field.name().clone(), val));
                }

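                // A missing `_distance` column is treated as distance 0.0, i.e. score 1.0.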
                let distance = distance_col.map_or(0.0, |c| c.value(row));
                let score = 1.0 / (1.0 + distance);

                results.push(ScoredRecord { record, score });
            }
        }
        Ok(results)
    }
}