use crate::error::StorageError;
use crate::storage::DbPool;
/// One row of the `chunk_embeddings` table, with every selected column owned.
#[derive(Debug, Clone)]
pub struct ChunkEmbeddingRow {
    /// Value of the row's `id` column.
    pub id: i64,
    /// Id of the chunk this embedding belongs to (`chunk_id` column).
    pub chunk_id: i64,
    /// Account the embedding is stored for (`account_id` column).
    pub account_id: String,
    /// Raw bytes of the `embedding` column.
    // NOTE(review): presumably the little-endian f32 encoding produced by
    // `vec_to_bytes` — confirm against the writers of this column.
    pub embedding: Vec<u8>,
    /// Value of the `model_id` column.
    pub model_id: String,
    /// Value of the `dimension` column.
    pub dimension: i64,
    /// Value of the `embedding_hash` column; the dirty-chunk query compares it
    /// with `content_chunks.chunk_hash`.
    pub embedding_hash: String,
    /// Value of the `generation` column.
    pub generation: i64,
    /// Value of the `created_at` column, as text.
    pub created_at: String,
    /// Value of the `updated_at` column, as text.
    pub updated_at: String,
}

/// Column tuple decoded by the embedding `SELECT`; positions follow the field
/// order of [`ChunkEmbeddingRow`].
type EmbeddingRowTuple = (
    i64,
    i64,
    String,
    Vec<u8>,
    String,
    i64,
    String,
    i64,
    String,
    String,
);

impl ChunkEmbeddingRow {
    /// Builds a row from the positional tuple, moving each element into the
    /// field at the same position.
    fn from_row(r: EmbeddingRowTuple) -> Self {
        let (
            id,
            chunk_id,
            account_id,
            embedding,
            model_id,
            dimension,
            embedding_hash,
            generation,
            created_at,
            updated_at,
        ) = r;

        Self {
            id,
            chunk_id,
            account_id,
            embedding,
            model_id,
            dimension,
            embedding_hash,
            generation,
            created_at,
            updated_at,
        }
    }
}
/// A content chunk returned by `get_dirty_chunks_for`, i.e. one that has no
/// embedding row for the account or whose stored embedding hash differs from
/// the chunk's current hash.
#[derive(Debug, Clone)]
pub struct DirtyChunk {
/// `content_chunks.id` of the chunk.
pub chunk_id: i64,
/// `content_chunks.chunk_text` of the chunk.
pub chunk_text: String,
/// `content_chunks.chunk_hash` of the chunk.
pub chunk_hash: String,
/// `content_chunks.node_id` of the chunk.
pub node_id: i64,
}
/// Serializable summary of one account's embedding index, as computed by
/// `get_index_stats_for`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct IndexStats {
/// Number of `content_chunks` rows for the account with status `'active'`.
pub total_chunks: i64,
/// Number of the account's embedding rows joined to an active chunk whose
/// `chunk_hash` equals the stored `embedding_hash`.
pub embedded_chunks: i64,
/// `total_chunks - embedded_chunks`.
pub dirty_chunks: i64,
/// `embedded_chunks / total_chunks * 100`, or `100.0` when `total_chunks`
/// is not positive.
pub freshness_pct: f64,
/// Greatest `updated_at` among the account's embedding rows, if reported.
pub last_indexed_at: Option<String>,
/// `model_id` taken from one of the account's embedding rows, if any.
pub model_id: Option<String>,
}
/// Stores the embedding for a `(chunk_id, account_id)` pair.
///
/// A new row is inserted when the pair is not present yet. On a conflict on
/// `(chunk_id, account_id)` the existing row's `embedding`, `model_id`,
/// `dimension`, `embedding_hash` and `generation` are overwritten and its
/// `updated_at` is set to `datetime('now')`.
///
/// # Errors
///
/// Returns [`StorageError::Query`] wrapping the driver error when the
/// statement fails.
// Each argument maps onto exactly one bound column, hence the lint opt-out.
#[allow(clippy::too_many_arguments)]
pub async fn upsert_chunk_embedding(
    pool: &DbPool,
    chunk_id: i64,
    account_id: &str,
    embedding_bytes: &[u8],
    model_id: &str,
    dimension: i64,
    embedding_hash: &str,
    generation: i64,
) -> Result<(), StorageError> {
    const UPSERT_SQL: &str = "INSERT INTO chunk_embeddings \
         (chunk_id, account_id, embedding, model_id, dimension, embedding_hash, generation) \
         VALUES (?, ?, ?, ?, ?, ?, ?) \
         ON CONFLICT(chunk_id, account_id) DO UPDATE SET \
         embedding = excluded.embedding, \
         model_id = excluded.model_id, \
         dimension = excluded.dimension, \
         embedding_hash = excluded.embedding_hash, \
         generation = excluded.generation, \
         updated_at = datetime('now')";

    // Bind order must follow the column list in the INSERT clause.
    let statement = sqlx::query(UPSERT_SQL)
        .bind(chunk_id)
        .bind(account_id)
        .bind(embedding_bytes)
        .bind(model_id)
        .bind(dimension)
        .bind(embedding_hash)
        .bind(generation);

    match statement.execute(pool).await {
        Ok(_) => Ok(()),
        Err(source) => Err(StorageError::Query { source }),
    }
}
/// Loads every embedding row stored for `account_id`, in ascending `id` order.
///
/// # Errors
///
/// Returns [`StorageError::Query`] wrapping the driver error when the query
/// fails.
pub async fn get_all_embeddings_for(
    pool: &DbPool,
    account_id: &str,
) -> Result<Vec<ChunkEmbeddingRow>, StorageError> {
    // Column order must match the positions of `EmbeddingRowTuple`.
    const SELECT_SQL: &str = "SELECT id, chunk_id, account_id, embedding, model_id, dimension, \
         embedding_hash, generation, created_at, updated_at \
         FROM chunk_embeddings \
         WHERE account_id = ? \
         ORDER BY id";

    let fetched = sqlx::query_as::<_, EmbeddingRowTuple>(SELECT_SQL)
        .bind(account_id)
        .fetch_all(pool)
        .await
        .map_err(|source| StorageError::Query { source })?;

    let mut embeddings = Vec::with_capacity(fetched.len());
    embeddings.extend(fetched.into_iter().map(ChunkEmbeddingRow::from_row));
    Ok(embeddings)
}
/// Returns up to `limit` active chunks of `account_id` that need to be
/// embedded or re-embedded.
///
/// A chunk qualifies when it has no `chunk_embeddings` row for this account,
/// or when its `chunk_hash` differs from the stored `embedding_hash`. Results
/// are ordered by chunk id so that the batch picked under `LIMIT` is
/// deterministic for unchanged data rather than left to the query planner.
///
/// # Errors
///
/// Returns [`StorageError::Query`] wrapping the driver error when the query
/// fails.
pub async fn get_dirty_chunks_for(
    pool: &DbPool,
    account_id: &str,
    limit: u32,
) -> Result<Vec<DirtyChunk>, StorageError> {
    let rows: Vec<(i64, String, String, i64)> = sqlx::query_as(
        "SELECT c.id, c.chunk_text, c.chunk_hash, c.node_id \
         FROM content_chunks c \
         LEFT JOIN chunk_embeddings e \
         ON c.id = e.chunk_id AND e.account_id = ? \
         WHERE c.account_id = ? \
         AND c.status = 'active' \
         AND (e.id IS NULL OR c.chunk_hash != e.embedding_hash) \
         ORDER BY c.id \
         LIMIT ?",
    )
    // First placeholder: embedding side of the join; second: chunk side.
    .bind(account_id)
    .bind(account_id)
    .bind(limit)
    .fetch_all(pool)
    .await
    .map_err(|e| StorageError::Query { source: e })?;

    Ok(rows
        .into_iter()
        .map(|(chunk_id, chunk_text, chunk_hash, node_id)| DirtyChunk {
            chunk_id,
            chunk_text,
            chunk_hash,
            node_id,
        })
        .collect())
}
/// Computes embedding-index statistics for `account_id`.
///
/// * `total_chunks` counts the account's active `content_chunks` rows.
/// * `embedded_chunks` counts the account's embeddings that belong to one of
///   those active chunks and whose `embedding_hash` equals the chunk's
///   current `chunk_hash`.
/// * `dirty_chunks` is the difference of the two.
/// * `freshness_pct` is the embedded share in percent, or `100.0` when there
///   are no active chunks.
/// * `last_indexed_at` is the greatest `updated_at` among the account's
///   embeddings, or `None` when the account has no embeddings.
/// * `model_id` is the `model_id` of one of the account's embedding rows
///   (which row is not specified by the query), or `None` when there are none.
///
/// The counts come from separate statements that are not wrapped in a
/// transaction here, so concurrent writes can make them momentarily disagree.
///
/// # Errors
///
/// Returns [`StorageError::Query`] wrapping the driver error when any of the
/// queries fails.
pub async fn get_index_stats_for(
    pool: &DbPool,
    account_id: &str,
) -> Result<IndexStats, StorageError> {
    let (total_chunks,): (i64,) = sqlx::query_as(
        "SELECT COUNT(*) FROM content_chunks \
         WHERE account_id = ? AND status = 'active'",
    )
    .bind(account_id)
    .fetch_one(pool)
    .await
    .map_err(|e| StorageError::Query { source: e })?;

    // The join is restricted to chunks of the same account so this count is a
    // subset of `total_chunks`, matching the filter in `get_dirty_chunks_for`.
    let (embedded_chunks,): (i64,) = sqlx::query_as(
        "SELECT COUNT(*) FROM chunk_embeddings e \
         INNER JOIN content_chunks c \
         ON c.id = e.chunk_id AND c.account_id = e.account_id \
         WHERE e.account_id = ? AND c.status = 'active' \
         AND c.chunk_hash = e.embedding_hash",
    )
    .bind(account_id)
    .fetch_one(pool)
    .await
    .map_err(|e| StorageError::Query { source: e })?;

    let dirty_chunks = total_chunks - embedded_chunks;
    let freshness_pct = if total_chunks > 0 {
        (embedded_chunks as f64 / total_chunks as f64) * 100.0
    } else {
        100.0
    };

    // An aggregate without GROUP BY yields one row even when nothing matches,
    // and MAX over zero rows is NULL. Decoding through `Option<String>` turns
    // that NULL into `None` instead of forcing it into a plain `String`.
    let last_indexed_at: Option<String> = sqlx::query_as::<_, (Option<String>,)>(
        "SELECT MAX(updated_at) FROM chunk_embeddings WHERE account_id = ?",
    )
    .bind(account_id)
    .fetch_optional(pool)
    .await
    .map_err(|e| StorageError::Query { source: e })?
    .and_then(|(latest,)| latest);

    let model_id: Option<(String,)> =
        sqlx::query_as("SELECT model_id FROM chunk_embeddings WHERE account_id = ? LIMIT 1")
            .bind(account_id)
            .fetch_optional(pool)
            .await
            .map_err(|e| StorageError::Query { source: e })?;

    Ok(IndexStats {
        total_chunks,
        embedded_chunks,
        dirty_chunks,
        freshness_pct,
        last_indexed_at,
        model_id: model_id.map(|r| r.0),
    })
}
/// Deletes every embedding of `account_id` that was produced with `model_id`.
///
/// Returns the number of rows the `DELETE` affected.
///
/// # Errors
///
/// Returns [`StorageError::Query`] wrapping the driver error when the
/// statement fails.
pub async fn delete_embeddings_by_model(
    pool: &DbPool,
    account_id: &str,
    model_id: &str,
) -> Result<u64, StorageError> {
    let statement =
        sqlx::query("DELETE FROM chunk_embeddings WHERE account_id = ? AND model_id = ?")
            .bind(account_id)
            .bind(model_id);

    match statement.execute(pool).await {
        Ok(outcome) => Ok(outcome.rows_affected()),
        Err(source) => Err(StorageError::Query { source }),
    }
}
/// Deletes the embedding stored for the `(chunk_id, account_id)` pair.
///
/// Succeeds whether or not a matching row existed; the affected-row count is
/// not reported.
///
/// # Errors
///
/// Returns [`StorageError::Query`] wrapping the driver error when the
/// statement fails.
pub async fn delete_embedding_for_chunk(
    pool: &DbPool,
    chunk_id: i64,
    account_id: &str,
) -> Result<(), StorageError> {
    let statement =
        sqlx::query("DELETE FROM chunk_embeddings WHERE chunk_id = ? AND account_id = ?")
            .bind(chunk_id)
            .bind(account_id);

    statement
        .execute(pool)
        .await
        .map(|_| ())
        .map_err(|source| StorageError::Query { source })
}
/// Serializes a slice of `f32` values into a byte vector.
///
/// Each value contributes its four little-endian bytes, in slice order, so the
/// result is `4 * v.len()` bytes long.
pub fn vec_to_bytes(v: &[f32]) -> Vec<u8> {
    // `size_of_val` on the slice is exactly the number of bytes to emit.
    let mut encoded = Vec::with_capacity(std::mem::size_of_val(v));
    encoded.extend(v.iter().flat_map(|component| component.to_le_bytes()));
    encoded
}
/// Deserializes little-endian bytes into `f32` values.
///
/// The input is read in consecutive 4-byte groups. If the length is not a
/// multiple of four, the trailing 1–3 bytes are ignored without an error.
pub fn bytes_to_vec(bytes: &[u8]) -> Vec<f32> {
    let mut decoded = Vec::with_capacity(bytes.len() / 4);
    for word in bytes.chunks_exact(4) {
        // `chunks_exact(4)` only yields slices of length four, so the copy
        // into the fixed-size buffer cannot panic.
        let mut raw = [0u8; 4];
        raw.copy_from_slice(word);
        decoded.push(f32::from_le_bytes(raw));
    }
    decoded
}