use chrono::{DateTime, Utc};
use cortex_core::MemoryId;
use rusqlite::{params, OptionalExtension, Row};
use crate::{Pool, StoreError, StoreResult};
/// Value of `memory_embeddings.encryption_kind` for rows whose `vector_blob` holds the
/// unencrypted little-endian `f32` encoding. This module writes only this kind and rejects
/// any other kind when reading rows back.
pub const EMBEDDING_ENCRYPTION_KIND_NONE: &str = "none";
/// One embedding vector for a memory, tagged with the backend that computed it.
///
/// The fields are public, so `dim` is only guaranteed to equal `vector.len()` when the
/// record was built with [`EmbedRecord::new`] or decoded from a stored row.
/// `EmbeddingRepo::write` re-checks the invariant before persisting.
#[derive(Debug, Clone, PartialEq)]
pub struct EmbedRecord {
    /// Memory this embedding is stored against.
    pub memory_id: MemoryId,
    /// Identifier of the embedding backend (tests use values such as `stub:v1`).
    pub backend_id: String,
    /// Number of components in `vector`.
    pub dim: u32,
    /// Embedding components; persisted as consecutive little-endian `f32` bytes.
    pub vector: Vec<f32>,
    /// When the embedding was computed; persisted as RFC 3339 text.
    pub computed_at: DateTime<Utc>,
}
impl EmbedRecord {
    /// Builds a record whose `dim` is taken from `vector.len()`.
    ///
    /// No other validation happens here. An empty vector or a blank `backend_id` is
    /// accepted and only rejected later by `EmbeddingRepo::write`.
    ///
    /// # Errors
    ///
    /// Returns [`StoreError::Validation`] when the vector length does not fit in `u32`.
    pub fn new(
        memory_id: MemoryId,
        backend_id: impl Into<String>,
        vector: Vec<f32>,
        computed_at: DateTime<Utc>,
    ) -> StoreResult<Self> {
        let backend_id: String = backend_id.into();
        let len = vector.len();
        match u32::try_from(len) {
            Ok(dim) => Ok(Self {
                memory_id,
                backend_id,
                dim,
                vector,
                computed_at,
            }),
            Err(_) => Err(StoreError::Validation(format!(
                "embedding vector length {len} does not fit in u32 dim column"
            ))),
        }
    }
}
#[derive(Debug)]
pub struct EmbeddingRepo<'a> {
pool: &'a Pool,
}
impl<'a> EmbeddingRepo<'a> {
#[must_use]
pub const fn new(pool: &'a Pool) -> Self {
Self { pool }
}
pub fn write(&self, record: &EmbedRecord) -> StoreResult<()> {
validate_record_dim_matches_vector(record)?;
validate_backend_id(&record.backend_id)?;
let blob = encode_vector_blob(&record.vector);
self.pool.execute(
"INSERT OR REPLACE INTO memory_embeddings (
memory_id,
backend_id,
dim,
vector_blob,
encryption_kind,
encryption_key_id,
computed_at
) VALUES (?1, ?2, ?3, ?4, ?5, NULL, ?6);",
params![
record.memory_id.to_string(),
record.backend_id,
record.dim,
blob,
EMBEDDING_ENCRYPTION_KIND_NONE,
record.computed_at.to_rfc3339(),
],
)?;
Ok(())
}
pub fn read(&self, memory_id: &MemoryId, backend_id: &str) -> StoreResult<Option<EmbedRecord>> {
let row = self
.pool
.query_row(
"SELECT memory_id, backend_id, dim, vector_blob, encryption_kind, computed_at
FROM memory_embeddings
WHERE memory_id = ?1 AND backend_id = ?2;",
params![memory_id.to_string(), backend_id],
embedding_row,
)
.optional()?;
row.map(TryInto::try_into).transpose()
}
pub fn list_by_backend(&self, backend_id: &str) -> StoreResult<Vec<EmbedRecord>> {
let mut stmt = self.pool.prepare(
"SELECT memory_id, backend_id, dim, vector_blob, encryption_kind, computed_at
FROM memory_embeddings
WHERE backend_id = ?1
ORDER BY memory_id;",
)?;
let rows = stmt
.query_map(params![backend_id], embedding_row)?
.collect::<Result<Vec<_>, _>>()?;
rows.into_iter().map(EmbedRecord::try_from).collect()
}
pub fn delete(&self, memory_id: &MemoryId, backend_id: &str) -> StoreResult<()> {
self.pool.execute(
"DELETE FROM memory_embeddings WHERE memory_id = ?1 AND backend_id = ?2;",
params![memory_id.to_string(), backend_id],
)?;
Ok(())
}
}
/// Raw, unvalidated column values of one `memory_embeddings` row, as extracted by
/// `embedding_row`. Validation happens when it is converted into an [`EmbedRecord`]
/// via `TryFrom`.
#[derive(Debug)]
struct EmbeddingRow {
    memory_id: String,
    backend_id: String,
    // Read as a SQLite integer (i64); narrowed to u32 during conversion.
    dim: i64,
    // Little-endian f32 bytes when `encryption_kind` is "none".
    vector_blob: Vec<u8>,
    encryption_kind: String,
    // RFC 3339 timestamp text, parsed during conversion.
    computed_at: String,
}
/// Row mapper for the `memory_embeddings` SELECTs.
///
/// Columns are read by position, so the SELECT lists in `EmbeddingRepo::read` and
/// `EmbeddingRepo::list_by_backend` must keep the order memory_id, backend_id, dim,
/// vector_blob, encryption_kind, computed_at. Extraction stops at the first column that
/// fails to convert.
fn embedding_row(row: &Row<'_>) -> rusqlite::Result<EmbeddingRow> {
    let memory_id = row.get(0)?;
    let backend_id = row.get(1)?;
    let dim = row.get(2)?;
    let vector_blob = row.get(3)?;
    let encryption_kind = row.get(4)?;
    let computed_at = row.get(5)?;
    Ok(EmbeddingRow {
        memory_id,
        backend_id,
        dim,
        vector_blob,
        encryption_kind,
        computed_at,
    })
}
impl TryFrom<EmbeddingRow> for EmbedRecord {
    type Error = StoreError;

    /// Validates a raw `memory_embeddings` row and decodes it into an [`EmbedRecord`].
    ///
    /// On success the returned `vector` has exactly `dim` components.
    ///
    /// # Errors
    ///
    /// Returns [`StoreError::Validation`] when the row:
    /// - carries an `encryption_kind` other than [`EMBEDDING_ENCRYPTION_KIND_NONE`];
    /// - has a `dim` that is negative, larger than `u32::MAX`, or zero;
    /// - has a `vector_blob` whose length is not exactly `dim * 4` bytes.
    ///
    /// Errors from parsing `memory_id` or the RFC 3339 `computed_at` are propagated.
    fn try_from(row: EmbeddingRow) -> StoreResult<Self> {
        if row.encryption_kind != EMBEDDING_ENCRYPTION_KIND_NONE {
            return Err(StoreError::Validation(format!(
                "memory_embeddings row carries encryption_kind {kind:?}; the Phase 4.C foundation \
                 only reads {expected:?} rows. A future at-rest encryption slice introduces \
                 additional decoders.",
                kind = row.encryption_kind,
                expected = EMBEDDING_ENCRYPTION_KIND_NONE,
            )));
        }
        let dim = u32::try_from(row.dim).map_err(|_| {
            StoreError::Validation(format!(
                "memory_embeddings.dim {} is not a valid u32 (CHECK dim > 0 enforced at write)",
                row.dim,
            ))
        })?;
        // Zero fits in u32 and would match an empty blob, silently yielding an empty vector.
        // The write path rejects dim == 0, so treat such a row as corrupt rather than valid.
        if dim == 0 {
            return Err(StoreError::Validation(
                "memory_embeddings.dim is 0 (CHECK dim > 0 enforced at write)".to_string(),
            ));
        }
        let expected_bytes = usize::try_from(dim)
            .ok()
            .and_then(|d| d.checked_mul(4))
            .ok_or_else(|| {
                StoreError::Validation(format!("memory_embeddings.dim {dim} * 4 overflows usize"))
            })?;
        if row.vector_blob.len() != expected_bytes {
            return Err(StoreError::Validation(format!(
                "memory_embeddings.vector_blob length {} does not match dim {dim} * 4 = {expected_bytes}",
                row.vector_blob.len(),
            )));
        }
        // The length check above guarantees the blob splits into exactly `dim` chunks.
        let vector = decode_vector_blob(&row.vector_blob);
        Ok(Self {
            memory_id: row.memory_id.parse()?,
            backend_id: row.backend_id,
            dim,
            vector,
            computed_at: DateTime::parse_from_rfc3339(&row.computed_at)?.with_timezone(&Utc),
        })
    }
}
/// Serialises `vector` into the `vector_blob` column format. Each component becomes four
/// little-endian bytes, concatenated in input order. An empty slice yields an empty blob.
fn encode_vector_blob(vector: &[f32]) -> Vec<u8> {
    let mut blob = Vec::with_capacity(vector.len() * 4);
    blob.extend(vector.iter().flat_map(|component| component.to_le_bytes()));
    blob
}
/// Inverse of `encode_vector_blob`: reads consecutive four-byte little-endian chunks as
/// `f32` values.
///
/// Trailing bytes that do not form a whole chunk are ignored, so callers must validate
/// the blob length themselves (the `TryFrom<EmbeddingRow>` conversion does).
fn decode_vector_blob(bytes: &[u8]) -> Vec<f32> {
    let mut components = Vec::with_capacity(bytes.len() / 4);
    for chunk in bytes.chunks_exact(4) {
        let word = <[u8; 4]>::try_from(chunk).expect("chunks_exact yields four bytes");
        components.push(f32::from_le_bytes(word));
    }
    components
}
/// Checks the shape invariants `EmbeddingRepo::write` needs before persisting a record.
/// `dim` must equal `vector.len()` and must be non-zero.
///
/// The length comparison runs first. A record with `dim == 0` and a non-empty vector
/// therefore reports the mismatch, not the zero-dim error.
fn validate_record_dim_matches_vector(record: &EmbedRecord) -> StoreResult<()> {
    let declared = record.dim as usize;
    let actual = record.vector.len();
    if declared != actual {
        return Err(StoreError::Validation(format!(
            "embedding record dim {} does not match vector length {} (backend `{}`)",
            record.dim, actual, record.backend_id,
        )));
    }
    if record.dim > 0 {
        Ok(())
    } else {
        Err(StoreError::Validation(
            "embedding record dim must be > 0 (CHECK constraint on memory_embeddings.dim)"
                .to_string(),
        ))
    }
}
/// Rejects a `backend_id` that is empty or consists only of whitespace.
fn validate_backend_id(backend_id: &str) -> StoreResult<()> {
    let is_blank = backend_id.chars().all(char::is_whitespace);
    if is_blank {
        Err(StoreError::Validation(String::from(
            "embedding record requires non-empty backend_id",
        )))
    } else {
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const MEMORY_ID: &str = "mem_01ARZ3NDEKTSV4RRFFQ69G5FAV";

    /// Builds a record with an explicit (possibly inconsistent) `dim`, bypassing `new`.
    fn record_with(dim: u32, vector: Vec<f32>) -> EmbedRecord {
        EmbedRecord {
            memory_id: MEMORY_ID.parse().unwrap(),
            backend_id: "stub:v1".into(),
            dim,
            vector,
            computed_at: Utc::now(),
        }
    }

    #[test]
    fn encode_then_decode_roundtrips_known_values() {
        let vector = vec![0.0f32, 1.0, -1.5, 4.2_f32, f32::MIN_POSITIVE];
        let bytes = encode_vector_blob(&vector);
        assert_eq!(bytes.len(), vector.len() * 4);
        let decoded = decode_vector_blob(&bytes);
        assert_eq!(decoded, vector);
    }

    /// Pins the persisted layout: each component is four little-endian IEEE-754 bytes.
    #[test]
    fn encode_uses_little_endian_f32_layout() {
        assert_eq!(
            encode_vector_blob(&[1.0, -2.0]),
            vec![0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0xc0]
        );
    }

    #[test]
    fn embed_record_new_derives_dim_from_vector_len() {
        match EmbedRecord::new(MEMORY_ID.parse().unwrap(), "stub:v1", vec![0.5; 3], Utc::now()) {
            Ok(record) => assert_eq!(record.dim, 3),
            Err(_) => panic!("a three-element vector must fit in the u32 dim column"),
        }
    }

    #[test]
    fn validate_record_rejects_dim_mismatch() {
        let record = record_with(4, vec![0.0; 3]);
        assert!(validate_record_dim_matches_vector(&record).is_err());
    }

    #[test]
    fn validate_record_rejects_zero_dim() {
        let record = record_with(0, Vec::new());
        assert!(validate_record_dim_matches_vector(&record).is_err());
    }

    #[test]
    fn validate_record_accepts_matching_dim() {
        let record = record_with(3, vec![0.25; 3]);
        assert!(validate_record_dim_matches_vector(&record).is_ok());
    }

    #[test]
    fn validate_backend_id_rejects_blank() {
        assert!(validate_backend_id("").is_err());
        assert!(validate_backend_id(" ").is_err());
        assert!(validate_backend_id("stub:v1").is_ok());
    }
}