1use chrono::{DateTime, Utc};
35use cortex_core::MemoryId;
36use rusqlite::{params, OptionalExtension, Row};
37
38use crate::{Pool, StoreError, StoreResult};
39
/// Value stored in the `encryption_kind` column for embeddings written as plain
/// (unencrypted) blobs. `EmbeddingRepo::write` always stores this value, and the
/// row decoder in this file rejects rows carrying any other kind.
pub const EMBEDDING_ENCRYPTION_KIND_NONE: &str = "none";
47
48#[derive(Debug, Clone, PartialEq)]
54pub struct EmbedRecord {
55 pub memory_id: MemoryId,
57 pub backend_id: String,
59 pub dim: u32,
63 pub vector: Vec<f32>,
65 pub computed_at: DateTime<Utc>,
67}
68
69impl EmbedRecord {
70 pub fn new(
75 memory_id: MemoryId,
76 backend_id: impl Into<String>,
77 vector: Vec<f32>,
78 computed_at: DateTime<Utc>,
79 ) -> StoreResult<Self> {
80 let backend_id = backend_id.into();
81 let dim_usize = vector.len();
82 let dim = u32::try_from(dim_usize).map_err(|_| {
83 StoreError::Validation(format!(
84 "embedding vector length {dim_usize} does not fit in u32 dim column"
85 ))
86 })?;
87 Ok(Self {
88 memory_id,
89 backend_id,
90 dim,
91 vector,
92 computed_at,
93 })
94 }
95}
96
97#[derive(Debug)]
99pub struct EmbeddingRepo<'a> {
100 pool: &'a Pool,
101}
102
103impl<'a> EmbeddingRepo<'a> {
104 #[must_use]
106 pub const fn new(pool: &'a Pool) -> Self {
107 Self { pool }
108 }
109
110 pub fn write(&self, record: &EmbedRecord) -> StoreResult<()> {
117 validate_record_dim_matches_vector(record)?;
118 validate_backend_id(&record.backend_id)?;
119
120 let blob = encode_vector_blob(&record.vector);
121
122 self.pool.execute(
123 "INSERT OR REPLACE INTO memory_embeddings (
124 memory_id,
125 backend_id,
126 dim,
127 vector_blob,
128 encryption_kind,
129 encryption_key_id,
130 computed_at
131 ) VALUES (?1, ?2, ?3, ?4, ?5, NULL, ?6);",
132 params![
133 record.memory_id.to_string(),
134 record.backend_id,
135 record.dim,
136 blob,
137 EMBEDDING_ENCRYPTION_KIND_NONE,
138 record.computed_at.to_rfc3339(),
139 ],
140 )?;
141 Ok(())
142 }
143
144 pub fn read(&self, memory_id: &MemoryId, backend_id: &str) -> StoreResult<Option<EmbedRecord>> {
146 let row = self
147 .pool
148 .query_row(
149 "SELECT memory_id, backend_id, dim, vector_blob, encryption_kind, computed_at
150 FROM memory_embeddings
151 WHERE memory_id = ?1 AND backend_id = ?2;",
152 params![memory_id.to_string(), backend_id],
153 embedding_row,
154 )
155 .optional()?;
156 row.map(TryInto::try_into).transpose()
157 }
158
159 pub fn list_by_backend(&self, backend_id: &str) -> StoreResult<Vec<EmbedRecord>> {
164 let mut stmt = self.pool.prepare(
165 "SELECT memory_id, backend_id, dim, vector_blob, encryption_kind, computed_at
166 FROM memory_embeddings
167 WHERE backend_id = ?1
168 ORDER BY memory_id;",
169 )?;
170 let rows = stmt
171 .query_map(params![backend_id], embedding_row)?
172 .collect::<Result<Vec<_>, _>>()?;
173 rows.into_iter().map(EmbedRecord::try_from).collect()
174 }
175
176 pub fn delete(&self, memory_id: &MemoryId, backend_id: &str) -> StoreResult<()> {
182 self.pool.execute(
183 "DELETE FROM memory_embeddings WHERE memory_id = ?1 AND backend_id = ?2;",
184 params![memory_id.to_string(), backend_id],
185 )?;
186 Ok(())
187 }
188}
189
/// Raw column values of one `memory_embeddings` row as fetched from SQLite,
/// before validation and decoding into an `EmbedRecord`.
#[derive(Debug)]
struct EmbeddingRow {
    /// `memory_id` column as text; parsed into a `MemoryId` during conversion.
    memory_id: String,
    /// `backend_id` column, carried over unchanged.
    backend_id: String,
    /// `dim` column as SQLite returns it; narrowed to `u32` during conversion.
    dim: i64,
    /// `vector_blob` column; expected to hold `dim * 4` bytes.
    vector_blob: Vec<u8>,
    /// `encryption_kind` column; only `EMBEDDING_ENCRYPTION_KIND_NONE` is decodable.
    encryption_kind: String,
    /// `computed_at` column as text; parsed as RFC 3339 during conversion.
    computed_at: String,
}
199
200fn embedding_row(row: &Row<'_>) -> rusqlite::Result<EmbeddingRow> {
201 Ok(EmbeddingRow {
202 memory_id: row.get(0)?,
203 backend_id: row.get(1)?,
204 dim: row.get(2)?,
205 vector_blob: row.get(3)?,
206 encryption_kind: row.get(4)?,
207 computed_at: row.get(5)?,
208 })
209}
210
211impl TryFrom<EmbeddingRow> for EmbedRecord {
212 type Error = StoreError;
213
214 fn try_from(row: EmbeddingRow) -> StoreResult<Self> {
215 if row.encryption_kind != EMBEDDING_ENCRYPTION_KIND_NONE {
216 return Err(StoreError::Validation(format!(
217 "memory_embeddings row carries encryption_kind {kind:?}; the Phase 4.C foundation \
218 only reads {expected:?} rows. A future at-rest encryption slice introduces \
219 additional decoders.",
220 kind = row.encryption_kind,
221 expected = EMBEDDING_ENCRYPTION_KIND_NONE,
222 )));
223 }
224
225 let dim = u32::try_from(row.dim).map_err(|_| {
226 StoreError::Validation(format!(
227 "memory_embeddings.dim {} is not a valid u32 (CHECK dim > 0 enforced at write)",
228 row.dim,
229 ))
230 })?;
231 let expected_bytes = (dim as usize).checked_mul(4).ok_or_else(|| {
232 StoreError::Validation(format!("memory_embeddings.dim {dim} * 4 overflows usize"))
233 })?;
234 if row.vector_blob.len() != expected_bytes {
235 return Err(StoreError::Validation(format!(
236 "memory_embeddings.vector_blob length {} does not match dim {dim} * 4 = {expected_bytes}",
237 row.vector_blob.len(),
238 )));
239 }
240 let vector = decode_vector_blob(&row.vector_blob);
241
242 Ok(Self {
243 memory_id: row.memory_id.parse()?,
244 backend_id: row.backend_id,
245 dim,
246 vector,
247 computed_at: DateTime::parse_from_rfc3339(&row.computed_at)?.with_timezone(&Utc),
248 })
249 }
250}
251
/// Serializes `vector` as consecutive little-endian `f32` values, four bytes per component.
///
/// An empty slice yields an empty blob.
fn encode_vector_blob(vector: &[f32]) -> Vec<u8> {
    let mut blob = Vec::with_capacity(vector.len() * 4);
    blob.extend(vector.iter().flat_map(|component| component.to_le_bytes()));
    blob
}
259
/// Decodes a blob of consecutive little-endian `f32` values back into a vector.
///
/// Bytes are consumed four at a time; any trailing bytes that do not fill a
/// complete four-byte group are ignored.
fn decode_vector_blob(bytes: &[u8]) -> Vec<f32> {
    let mut components = Vec::with_capacity(bytes.len() / 4);
    for word in bytes.chunks_exact(4) {
        let le_bytes = <[u8; 4]>::try_from(word).expect("chunks_exact yields four bytes");
        components.push(f32::from_le_bytes(le_bytes));
    }
    components
}
269
270fn validate_record_dim_matches_vector(record: &EmbedRecord) -> StoreResult<()> {
271 if record.dim as usize != record.vector.len() {
272 return Err(StoreError::Validation(format!(
273 "embedding record dim {} does not match vector length {} (backend `{}`)",
274 record.dim,
275 record.vector.len(),
276 record.backend_id,
277 )));
278 }
279 if record.dim == 0 {
280 return Err(StoreError::Validation(
281 "embedding record dim must be > 0 (CHECK constraint on memory_embeddings.dim)"
282 .to_string(),
283 ));
284 }
285 Ok(())
286}
287
288fn validate_backend_id(backend_id: &str) -> StoreResult<()> {
289 if backend_id.trim().is_empty() {
290 return Err(StoreError::Validation(
291 "embedding record requires non-empty backend_id".to_string(),
292 ));
293 }
294 Ok(())
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding then decoding must reproduce the input exactly, at four bytes per component.
    #[test]
    fn encode_then_decode_roundtrips_known_values() {
        let original = vec![0.0f32, 1.0, -1.5, 4.2_f32, f32::MIN_POSITIVE];
        let blob = encode_vector_blob(&original);
        assert_eq!(blob.len(), original.len() * 4);
        assert_eq!(decode_vector_blob(&blob), original);
    }

    /// A record whose `dim` disagrees with its vector length must be rejected.
    #[test]
    fn validate_record_rejects_dim_mismatch() {
        let mismatched = EmbedRecord {
            memory_id: "mem_01ARZ3NDEKTSV4RRFFQ69G5FAV".parse().unwrap(),
            backend_id: "stub:v1".into(),
            dim: 4,
            vector: vec![0.0; 3],
            computed_at: Utc::now(),
        };
        let outcome = validate_record_dim_matches_vector(&mismatched);
        assert!(outcome.is_err());
    }

    /// Empty and whitespace-only backend ids are rejected; a real id is accepted.
    #[test]
    fn validate_backend_id_rejects_blank() {
        for blank in ["", " "] {
            assert!(validate_backend_id(blank).is_err());
        }
        assert!(validate_backend_id("stub:v1").is_ok());
    }
}