use std::path::{Path, PathBuf};
use std::sync::{Mutex, MutexGuard, OnceLock};
use rusqlite::{params, Connection};
use serde_json::{json, Value};
use crate::api::{resolve_ollama_url, resolve_openai_compat};
/// Author of an indexed snippet.
///
/// Persisted in the `recall.role` column as `"user"` or `"assistant"`
/// (see `Role::as_str` / `Role::parse`); rows with any other value are
/// skipped at query time.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role {
/// Snippet attributed to the user (stored as `"user"`).
User,
/// Snippet attributed to the assistant (stored as `"assistant"`).
Assistant,
}
impl Role {
fn as_str(self) -> &'static str {
match self {
Self::User => "user",
Self::Assistant => "assistant",
}
}
fn parse(s: &str) -> Option<Self> {
match s {
"user" => Some(Self::User),
"assistant" => Some(Self::Assistant),
_ => None,
}
}
}
/// One result returned by `RecallStore::query`.
#[derive(Debug, Clone, PartialEq)]
pub struct RecallHit {
/// Row timestamp; `RecallStore::index` writes `chrono::Utc::now().to_rfc3339()`.
pub ts: String,
/// Author of the snippet.
pub role: Role,
/// Stored snippet text (trimmed, and capped at 8 KiB by `RecallStore::index`).
pub snippet: String,
/// Cosine similarity between the query embedding and this row's embedding.
pub score: f32,
}
/// Heap entry used by `RecallStore::query` to keep the best `k` hits.
///
/// `score` duplicates `hit.score`; the comparison impls look only at this
/// field, so the heap orders entries without inspecting the payload.
#[derive(Debug, Clone)]
struct ScoredHit {
score: f32,
hit: RecallHit,
}
impl PartialEq for ScoredHit {
    /// Two entries are equal exactly when [`Ord::cmp`] reports `Equal`.
    ///
    /// Delegating to `cmp` keeps `PartialEq`, `Eq` and `Ord` consistent, as
    /// the `Ord` contract requires. A bare `f32 ==` would make a NaN-scored
    /// entry unequal to itself (violating `Eq` reflexivity) while `cmp`
    /// reports two NaN scores as `Equal`.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}
// Marker impl so `ScoredHit` can implement `Ord` and live in a `BinaryHeap`.
// NOTE(review): only sound while `eq` treats a NaN score as equal to itself.
impl Eq for ScoredHit {}
// Canonical form: delegate to `Ord::cmp` so the partial and total orders
// can never disagree.
impl PartialOrd for ScoredHit {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for ScoredHit {
    /// Total order on `score`.
    ///
    /// Ordinary floats compare numerically (`0.0` and `-0.0` are equal). Any
    /// NaN score sorts below every non-NaN score, and two NaN scores compare
    /// equal, so NaN entries are the first to be evicted from a min-heap.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        use std::cmp::Ordering;

        let (lhs, rhs) = (self.score, other.score);
        lhs.partial_cmp(&rhs)
            .unwrap_or_else(|| match (lhs.is_nan(), rhs.is_nan()) {
                (true, true) => Ordering::Equal,
                (true, false) => Ordering::Less,
                // `partial_cmp` only fails when a NaN is involved, so `rhs` is NaN here.
                (false, _) => Ordering::Greater,
            })
    }
}
/// Turns text into a dense embedding vector.
///
/// Implementations must be `Send` because the process-wide store, which owns
/// a `Box<dyn Embedder>`, lives inside a `static` `Mutex`.
pub trait Embedder: Send {
/// Embed `text`, returning the vector or a human-readable error message.
/// `RecallStore::query` only compares vectors whose lengths match.
fn embed(&mut self, text: &str) -> Result<Vec<f32>, String>;
}
/// Default maximum number of rows kept by a `RecallStore`; once an insert
/// pushes the table past the cap, the rows with the lowest `id` are deleted.
pub const RECALL_ROW_CAP: usize = 50_000;
/// Embedding model used when `CLAUDETTE_RECALL_MODEL` is unset or empty.
pub const DEFAULT_EMBED_MODEL: &str = "nomic-embed-text";
/// Location of the recall database.
///
/// Resolution order:
/// 1. `CLAUDETTE_RECALL_DB`, if set and non-empty;
/// 2. `<USERPROFILE or HOME>/.claudette/recall.sqlite`, using the first of
///    those variables that is set and non-empty;
/// 3. `./.claudette/recall.sqlite` as a last resort.
///
/// Variables are read with [`std::env::var_os`], so non-UTF-8 paths are
/// honoured rather than silently ignored, and empty values are skipped so
/// an empty `USERPROFILE` cannot shadow a valid `HOME`.
#[must_use]
pub fn default_recall_db_path() -> PathBuf {
    let non_empty = |key: &str| std::env::var_os(key).filter(|value| !value.is_empty());

    if let Some(explicit) = non_empty("CLAUDETTE_RECALL_DB") {
        return PathBuf::from(explicit);
    }
    let home = non_empty("USERPROFILE")
        .or_else(|| non_empty("HOME"))
        .map_or_else(|| PathBuf::from("."), PathBuf::from);
    home.join(".claudette").join("recall.sqlite")
}
/// Serialise `v` as a tightly packed little-endian `f32` byte string
/// (4 bytes per element), the format stored in the `recall.vec` BLOB column.
#[must_use]
pub fn encode_vec(v: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(v.len() * 4);
    bytes.extend(v.iter().flat_map(|value| value.to_le_bytes()));
    bytes
}
/// Decode a little-endian `f32` BLOB (as produced by `encode_vec`) into a
/// freshly allocated vector.
///
/// # Errors
/// Returns a `recall: …` message if `bytes.len()` is not a multiple of 4.
pub fn decode_vec(bytes: &[u8]) -> Result<Vec<f32>, String> {
    let mut floats = Vec::new();
    decode_vec_into(bytes, &mut floats)?;
    Ok(floats)
}
/// Decode a little-endian `f32` BLOB into `dst`, replacing its contents.
///
/// Reusing `dst` across calls avoids a fresh allocation per row.
///
/// # Errors
/// Returns a `recall: …` message if `bytes.len()` is not a multiple of 4;
/// `dst` is left untouched in that case.
pub fn decode_vec_into(bytes: &[u8], dst: &mut Vec<f32>) -> Result<(), String> {
    let words = bytes.chunks_exact(4);
    if !words.remainder().is_empty() {
        return Err(format!(
            "recall: BLOB length {} is not a multiple of 4 — corrupt vector",
            bytes.len()
        ));
    }
    dst.clear();
    dst.reserve(words.len());
    dst.extend(words.map(|w| f32::from_le_bytes([w[0], w[1], w[2], w[3]])));
    Ok(())
}
/// Cosine similarity of `a` and `b`, in `[-1, 1]` for ordinary inputs.
///
/// Returns `0.0` when the slices differ in length, are empty, or either has
/// zero magnitude. Sums are accumulated left to right with plain
/// multiply-add (no fused `mul_add`), so results are bit-for-bit stable.
#[must_use]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    // Single pass over zipped pairs: no per-index bounds checks.
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(dot, na, nb), (&x, &y)| (dot + x * y, na + x * x, nb + y * y),
    );
    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        dot / denom
    }
}
/// SQLite-backed store of conversation snippets and their embeddings.
/// Queries scan every row and rank by cosine similarity.
pub struct RecallStore {
/// Connection to the database holding the `recall` table.
conn: Connection,
/// Embedder used both when indexing snippets and when embedding queries.
embedder: Box<dyn Embedder>,
/// Maximum number of rows retained; enforced after every insert.
cap: usize,
}
impl RecallStore {
    /// Maximum number of bytes of snippet text persisted per row. Longer
    /// snippets are cut back to the nearest preceding UTF-8 char boundary.
    const SNIPPET_BYTE_CAP: usize = 8 * 1024;

    /// Upper bound on the heap capacity preallocated by [`Self::query`], so a
    /// very large `k` (e.g. `usize::MAX` meaning "everything") neither
    /// overflows `k + 1` nor attempts a huge up-front allocation.
    const QUERY_PREALLOC_MAX: usize = 1024;

    /// Open (creating if necessary) the recall database at `path`.
    ///
    /// Missing parent directories are created first, and the schema is then
    /// initialised idempotently. The row cap starts at [`RECALL_ROW_CAP`].
    ///
    /// # Errors
    /// Returns a `recall: …` message if the directory cannot be created, the
    /// database cannot be opened, or schema creation fails.
    pub fn open(path: impl AsRef<Path>, embedder: Box<dyn Embedder>) -> Result<Self, String> {
        let path = path.as_ref();
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)
                .map_err(|e| format!("recall: create_dir_all {}: {e}", parent.display()))?;
        }
        let conn =
            Connection::open(path).map_err(|e| format!("recall: open {}: {e}", path.display()))?;
        Self::from_connection(conn, embedder)
    }

    /// Open a throw-away in-memory store (used by tests and probes).
    ///
    /// # Errors
    /// Returns a `recall: …` message if SQLite or schema creation fails.
    pub fn open_in_memory(embedder: Box<dyn Embedder>) -> Result<Self, String> {
        let conn =
            Connection::open_in_memory().map_err(|e| format!("recall: open in-memory: {e}"))?;
        Self::from_connection(conn, embedder)
    }

    /// Shared tail of the constructors: ensure the schema exists, then wrap
    /// the connection with the default row cap.
    fn from_connection(conn: Connection, embedder: Box<dyn Embedder>) -> Result<Self, String> {
        Self::init_schema(&conn)?;
        Ok(Self {
            conn,
            embedder,
            cap: RECALL_ROW_CAP,
        })
    }

    /// Override the maximum row count.
    ///
    /// The new cap is enforced on the next [`Self::index`] call; existing rows
    /// are not evicted immediately.
    #[must_use]
    pub fn with_cap(mut self, cap: usize) -> Self {
        self.cap = cap;
        self
    }

    /// Create the `recall` table and its timestamp index if they are missing.
    fn init_schema(conn: &Connection) -> Result<(), String> {
        conn.execute_batch(
            "CREATE TABLE IF NOT EXISTS recall (
                id INTEGER PRIMARY KEY,
                ts TEXT NOT NULL,
                role TEXT NOT NULL,
                snippet TEXT NOT NULL,
                vec BLOB NOT NULL
            );
            CREATE INDEX IF NOT EXISTS idx_recall_ts ON recall(ts);",
        )
        .map_err(|e| format!("recall: init schema: {e}"))
    }

    /// Longest prefix of `text` that fits in [`Self::SNIPPET_BYTE_CAP`] bytes
    /// and ends on a char boundary, so slicing can never panic mid-character.
    fn cap_snippet(text: &str) -> &str {
        if text.len() <= Self::SNIPPET_BYTE_CAP {
            return text;
        }
        let mut end = Self::SNIPPET_BYTE_CAP;
        while end > 0 && !text.is_char_boundary(end) {
            end -= 1;
        }
        &text[..end]
    }

    /// Embed and store `snippet`, then evict the oldest rows beyond the cap.
    ///
    /// Whitespace-only snippets are ignored. The stored text is trimmed and
    /// capped at 8 KiB.
    ///
    /// # Errors
    /// Propagates embedder errors and returns `recall: …` messages for
    /// SQLite insert/evict failures.
    pub fn index(&mut self, role: Role, snippet: &str) -> Result<(), String> {
        let trimmed = snippet.trim();
        if trimmed.is_empty() {
            return Ok(());
        }
        let stored = Self::cap_snippet(trimmed);
        // The embedder receives the full trimmed text; only the persisted
        // snippet is capped.
        let vec = self.embedder.embed(trimmed)?;
        let ts = chrono::Utc::now().to_rfc3339();
        let blob = encode_vec(&vec);
        self.conn
            .execute(
                "INSERT INTO recall (ts, role, snippet, vec) VALUES (?1, ?2, ?3, ?4)",
                params![ts, role.as_str(), stored, blob],
            )
            .map_err(|e| format!("recall: insert: {e}"))?;
        self.evict_to_cap()
    }

    /// Return up to `k` stored snippets most similar to `query`, ordered by
    /// descending cosine similarity.
    ///
    /// Rows with an unknown role, a corrupt vector BLOB, or an embedding
    /// dimension different from the query's are skipped. An empty query,
    /// `k == 0`, or a zero/NaN query embedding yields no hits.
    ///
    /// # Errors
    /// Propagates embedder errors and returns `recall: …` messages for
    /// SQLite failures.
    pub fn query(&mut self, query: &str, k: usize) -> Result<Vec<RecallHit>, String> {
        use std::cmp::Reverse;
        use std::collections::BinaryHeap;

        let trimmed = query.trim();
        if trimmed.is_empty() || k == 0 {
            return Ok(Vec::new());
        }
        let qvec = self.embedder.embed(trimmed)?;
        let qnorm_sq = qvec.iter().fold(0.0_f32, |acc, &x| x.mul_add(x, acc));
        // A zero or NaN query vector cannot be ranked against anything.
        if qnorm_sq.is_nan() || qnorm_sq <= 0.0 {
            return Ok(Vec::new());
        }
        let qnorm = qnorm_sq.sqrt();

        let mut stmt = self
            .conn
            .prepare("SELECT ts, role, snippet, vec FROM recall")
            .map_err(|e| format!("recall: prepare select: {e}"))?;
        let rows = stmt
            .query_map([], |row| {
                let ts: String = row.get(0)?;
                let role: String = row.get(1)?;
                let snippet: String = row.get(2)?;
                let vec_blob: Vec<u8> = row.get(3)?;
                Ok((ts, role, snippet, vec_blob))
            })
            .map_err(|e| format!("recall: query_map: {e}"))?;

        // Scratch buffer reused for every decoded row vector.
        let mut vbuf: Vec<f32> = Vec::with_capacity(qvec.len());
        // Min-heap (via `Reverse`) holding the best `k` hits seen so far.
        let mut heap: BinaryHeap<Reverse<ScoredHit>> =
            BinaryHeap::with_capacity(k.min(Self::QUERY_PREALLOC_MAX) + 1);
        for row in rows {
            let (ts, role_str, snippet, blob) =
                row.map_err(|e| format!("recall: row error: {e}"))?;
            // Skip rows that cannot be scored: unknown role, corrupt BLOB, or
            // a vector from a model with a different dimensionality.
            let Some(role) = Role::parse(&role_str) else {
                continue;
            };
            if decode_vec_into(&blob, &mut vbuf).is_err() || vbuf.len() != qvec.len() {
                continue;
            }
            let mut dot = 0.0_f32;
            let mut bnorm_sq = 0.0_f32;
            for (qx, &bx) in qvec.iter().zip(&vbuf) {
                dot = qx.mul_add(bx, dot);
                bnorm_sq = bx.mul_add(bx, bnorm_sq);
            }
            let denom = qnorm * bnorm_sq.sqrt();
            let score = if denom == 0.0 { 0.0 } else { dot / denom };
            heap.push(Reverse(ScoredHit {
                score,
                hit: RecallHit {
                    ts,
                    role,
                    snippet,
                    score,
                },
            }));
            if heap.len() > k {
                // Drop the current worst hit (NaN scores are evicted first).
                heap.pop();
            }
        }
        // `into_sorted_vec` is ascending in `Reverse<ScoredHit>`, i.e.
        // descending by score, using `ScoredHit`'s total order (NaN last).
        // A `partial_cmp(..).unwrap_or(Equal)` comparator is not a total
        // order once NaN appears, and `slice::sort_by` may panic on that.
        Ok(heap
            .into_sorted_vec()
            .into_iter()
            .map(|Reverse(scored)| scored.hit)
            .collect())
    }

    /// Number of rows currently stored.
    ///
    /// # Errors
    /// Returns a `recall: count: …` message if the SQL query fails.
    pub fn count(&self) -> Result<usize, String> {
        let n = self
            .conn
            .query_row("SELECT COUNT(*) FROM recall", [], |r| r.get::<_, i64>(0))
            .map_err(|e| format!("recall: count: {e}"))?;
        // COUNT(*) is never negative; saturate instead of truncating on
        // targets where `usize` is narrower than `i64`.
        Ok(usize::try_from(n.max(0)).unwrap_or(usize::MAX))
    }

    /// Delete the oldest rows (lowest `id`) until at most `cap` remain.
    fn evict_to_cap(&mut self) -> Result<(), String> {
        let n = self.count()?;
        if n <= self.cap {
            return Ok(());
        }
        let to_remove = n - self.cap;
        let to_remove_i64 = i64::try_from(to_remove).unwrap_or(i64::MAX);
        self.conn
            .execute(
                "DELETE FROM recall WHERE id IN (SELECT id FROM recall ORDER BY id ASC LIMIT ?1)",
                params![to_remove_i64],
            )
            .map_err(|e| format!("recall: evict: {e}"))?;
        Ok(())
    }
}
/// `Embedder` backed by an Ollama server's `/api/embeddings` endpoint.
pub struct OllamaEmbedder {
/// Blocking HTTP client (built with a 60 s request timeout).
client: reqwest::blocking::Client,
/// Server base URL; endpoint paths such as `/api/show` are appended to it.
base_url: String,
/// Embedding model name sent with every request.
model: String,
/// Set once the model has been confirmed present (or pulled) on the server.
ready: bool,
}
impl OllamaEmbedder {
    /// Build an embedder for the Ollama server at `resolve_ollama_url()`.
    ///
    /// The model comes from `CLAUDETTE_RECALL_MODEL` when it is set and
    /// non-empty, otherwise `DEFAULT_EMBED_MODEL`. No network traffic
    /// happens here; model availability is checked lazily on first embed.
    ///
    /// # Errors
    /// Fails only if the blocking HTTP client cannot be built.
    pub fn new() -> Result<Self, String> {
        let model = match std::env::var("CLAUDETTE_RECALL_MODEL") {
            Ok(name) if !name.is_empty() => name,
            _ => DEFAULT_EMBED_MODEL.to_string(),
        };
        let client = reqwest::blocking::Client::builder()
            .timeout(std::time::Duration::from_secs(60))
            .build()
            .map_err(|e| format!("recall: build http client: {e}"))?;
        Ok(Self {
            client,
            base_url: resolve_ollama_url(),
            model,
            ready: false,
        })
    }

    /// Make sure `self.model` exists on the server.
    ///
    /// The model is pulled once if `/api/show` answers 404. After the first
    /// success, `ready` is set and later calls return immediately.
    fn ensure_model(&mut self) -> Result<(), String> {
        if self.ready {
            return Ok(());
        }
        let show_url = format!("{}/api/show", self.base_url);
        let show_status = self
            .client
            .post(&show_url)
            .json(&json!({ "name": self.model }))
            .send()
            .map_err(|e| {
                format!(
                    "recall: cannot reach Ollama at {} ({e}). Start it with `ollama serve`.",
                    self.base_url
                )
            })?
            .status();
        if !show_status.is_success() {
            // Only 404 means "model not present"; any other failure is fatal.
            if show_status != reqwest::StatusCode::NOT_FOUND {
                return Err(format!(
                    "recall: /api/show returned {} for {}",
                    show_status, self.model
                ));
            }
            self.pull_model()?;
        }
        self.ready = true;
        Ok(())
    }

    /// Download `self.model` via a non-streaming `/api/pull` (600 s timeout),
    /// announcing start and completion on stderr.
    fn pull_model(&self) -> Result<(), String> {
        eprintln!(
            "{} pulling embed model {} (~270MB, one-time) ...",
            crate::theme::SAVE,
            crate::theme::accent(&self.model)
        );
        let pull_url = format!("{}/api/pull", self.base_url);
        let pull_status = self
            .client
            .post(&pull_url)
            .timeout(std::time::Duration::from_secs(600))
            .json(&json!({ "name": self.model, "stream": false }))
            .send()
            .map_err(|e| format!("recall: /api/pull request failed: {e}"))?
            .status();
        if !pull_status.is_success() {
            return Err(format!(
                "recall: /api/pull returned {} for {} — try `ollama pull {}` manually",
                pull_status, self.model, self.model
            ));
        }
        eprintln!(
            "{} {} ready",
            crate::theme::ok(crate::theme::OK_GLYPH),
            crate::theme::ok(&self.model)
        );
        Ok(())
    }
}
impl Embedder for OllamaEmbedder {
    /// Embed `text` through Ollama's `/api/embeddings`, first making sure
    /// the configured model is available on the server.
    fn embed(&mut self, text: &str) -> Result<Vec<f32>, String> {
        self.ensure_model()?;
        let endpoint = format!("{}/api/embeddings", self.base_url);
        let payload = json!({ "model": self.model, "prompt": text });
        let response = self
            .client
            .post(&endpoint)
            .json(&payload)
            .send()
            .map_err(|e| format!("recall: /api/embeddings request: {e}"))?;
        let status = response.status();
        if !status.is_success() {
            return Err(format!("recall: /api/embeddings HTTP {status}"));
        }
        let body = response
            .json::<Value>()
            .map_err(|e| format!("recall: /api/embeddings parse: {e}"))?;
        parse_ollama_embedding(&body)
    }
}
/// Extract the `embedding` array from an Ollama `/api/embeddings` response.
fn parse_ollama_embedding(body: &Value) -> Result<Vec<f32>, String> {
    match body.get("embedding").and_then(Value::as_array) {
        Some(values) => json_array_to_f32s(values),
        None => Err(format!("recall: response missing 'embedding': {body}")),
    }
}
/// Extract `data[0].embedding` from an OpenAI-style `/v1/embeddings` response.
/// Only the first entry is read; `data` must be a JSON array.
fn parse_openai_compat_embedding(body: &Value) -> Result<Vec<f32>, String> {
    let Some(values) = body
        .get("data")
        .and_then(Value::as_array)
        .and_then(|items| items.first())
        .and_then(|first| first.get("embedding"))
        .and_then(Value::as_array)
    else {
        return Err(format!(
            "recall: response missing 'data[0].embedding': {body}"
        ));
    };
    json_array_to_f32s(values)
}
/// Convert a JSON array of numbers to `f32`s (narrowing from `f64` with `as`).
///
/// Fails on the first non-numeric element, or if the array is empty.
fn json_array_to_f32s(arr: &[Value]) -> Result<Vec<f32>, String> {
    let floats = arr
        .iter()
        .map(|item| {
            item.as_f64()
                .map(|f| f as f32)
                .ok_or_else(|| "recall: non-numeric value in 'embedding'".to_string())
        })
        .collect::<Result<Vec<f32>, String>>()?;
    if floats.is_empty() {
        Err("recall: empty embedding returned".to_string())
    } else {
        Ok(floats)
    }
}
/// `Embedder` for OpenAI-compatible servers (e.g. LM Studio) exposing
/// `/v1/embeddings`.
pub struct OpenAICompatEmbedder {
/// Blocking HTTP client (built with a 60 s request timeout).
client: reqwest::blocking::Client,
/// Server base URL; `/v1/embeddings` is appended per request.
/// NOTE(review): populated from `resolve_ollama_url()` — confirm that is the
/// intended endpoint for OpenAI-compatible servers.
base_url: String,
/// Model id sent in each request's `model` field.
model: String,
}
impl OpenAICompatEmbedder {
    /// Build an embedder for the server at `resolve_ollama_url()`.
    ///
    /// The model id comes from `CLAUDETTE_RECALL_MODEL` when it is set and
    /// non-empty, otherwise `DEFAULT_EMBED_MODEL`.
    ///
    /// # Errors
    /// Fails only if the blocking HTTP client cannot be built.
    pub fn new() -> Result<Self, String> {
        let model = match std::env::var("CLAUDETTE_RECALL_MODEL") {
            Ok(id) if !id.is_empty() => id,
            _ => DEFAULT_EMBED_MODEL.to_string(),
        };
        let client = reqwest::blocking::Client::builder()
            .timeout(std::time::Duration::from_secs(60))
            .build()
            .map_err(|e| format!("recall: build http client: {e}"))?;
        Ok(Self {
            client,
            base_url: resolve_ollama_url(),
            model,
        })
    }
}
impl Embedder for OpenAICompatEmbedder {
    /// Embed `text` via `/v1/embeddings`. On an HTTP error, the response body
    /// is included in the message to help diagnose an unloaded model.
    fn embed(&mut self, text: &str) -> Result<Vec<f32>, String> {
        let endpoint = format!("{}/v1/embeddings", self.base_url);
        let payload = json!({ "model": self.model, "input": text });
        let response = self
            .client
            .post(&endpoint)
            .json(&payload)
            .send()
            .map_err(|e| {
                format!(
                    "recall: cannot reach OpenAI-compat server at {} ({e}). \
                     Is LM Studio running on this port?",
                    self.base_url
                )
            })?;
        let status = response.status();
        if !status.is_success() {
            let body = response.text().unwrap_or_default();
            return Err(format!(
                "recall: /v1/embeddings HTTP {status} — load `{}` in LM Studio's \
                 Local Server tab (or set CLAUDETTE_RECALL_MODEL to a model id you have loaded). \
                 Body: {body}",
                self.model
            ));
        }
        let body = response
            .json::<Value>()
            .map_err(|e| format!("recall: /v1/embeddings parse: {e}"))?;
        parse_openai_compat_embedding(&body)
    }
}
/// Process-wide slot for the lazily opened global store (initially empty).
fn store_cell() -> &'static Mutex<Option<RecallStore>> {
    static STORE: OnceLock<Mutex<Option<RecallStore>>> = OnceLock::new();
    STORE.get_or_init(Mutex::default)
}
/// Lock the global store slot, reporting a poisoned mutex as an error
/// instead of panicking.
fn lock_store() -> Result<MutexGuard<'static, Option<RecallStore>>, String> {
    match store_cell().lock() {
        Ok(guard) => Ok(guard),
        Err(poisoned) => Err(format!("recall: store lock poisoned: {poisoned}")),
    }
}
/// Populate the locked slot with a store at `default_recall_db_path()` if it
/// is still empty.
///
/// The OpenAI-compatible embedder is used when `resolve_openai_compat()`
/// returns true; otherwise the Ollama embedder is used.
fn ensure_store(guard: &mut MutexGuard<'static, Option<RecallStore>>) -> Result<(), String> {
    if guard.is_none() {
        let embedder: Box<dyn Embedder> = if resolve_openai_compat() {
            Box::new(OpenAICompatEmbedder::new()?)
        } else {
            Box::new(OllamaEmbedder::new()?)
        };
        let db_path = default_recall_db_path();
        **guard = Some(RecallStore::open(db_path, embedder)?);
    }
    Ok(())
}
/// Test helper: drop the global store so the next call reopens it.
/// A poisoned lock is silently ignored.
#[cfg(test)]
pub fn reset_global() {
    let Ok(mut slot) = lock_store() else {
        return;
    };
    *slot = None;
}
/// Test helper: replace the global store with `store`.
/// A poisoned lock is silently ignored and `store` is dropped.
#[cfg(test)]
pub fn install_global_for_test(store: RecallStore) {
    let Ok(mut slot) = lock_store() else {
        return;
    };
    *slot = Some(store);
}
/// Index `snippet` into the process-wide store, opening it on first use.
///
/// # Errors
/// Lock poisoning, store/embedder construction failures, and any error
/// from `RecallStore::index`.
pub fn global_index(role: Role, snippet: &str) -> Result<(), String> {
    let mut slot = lock_store()?;
    ensure_store(&mut slot)?;
    let store = slot.as_mut().expect("ensure_store left store None");
    store.index(role, snippet)
}
/// Query the process-wide store for the top `k` hits, opening it on first use.
///
/// # Errors
/// Lock poisoning, store/embedder construction failures, and any error
/// from `RecallStore::query`.
pub fn global_query(query: &str, k: usize) -> Result<Vec<RecallHit>, String> {
    let mut slot = lock_store()?;
    ensure_store(&mut slot)?;
    let store = slot.as_mut().expect("ensure_store left store None");
    store.query(query, k)
}
/// Health check: embed the literal text `"probe"` through the global store's
/// embedder, discarding the vector.
///
/// # Errors
/// Any failure opening the store or reaching the embedding backend.
pub fn probe() -> Result<(), String> {
    let mut slot = lock_store()?;
    ensure_store(&mut slot)?;
    let store = slot.as_mut().expect("ensure_store left store None");
    store.embedder.embed("probe")?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic toy embedder: counts chars into `dim` buckets by code
    /// point, then L2-normalises. Similar character mixes score higher.
    struct HashEmbedder {
        dim: usize,
    }

    impl HashEmbedder {
        fn new() -> Self {
            Self { dim: 8 }
        }
    }

    impl Embedder for HashEmbedder {
        fn embed(&mut self, text: &str) -> Result<Vec<f32>, String> {
            let mut v = vec![0.0_f32; self.dim];
            for ch in text.chars() {
                let bucket = (ch as usize) % self.dim;
                v[bucket] += 1.0;
            }
            let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
            if norm > 0.0 {
                for x in &mut v {
                    *x /= norm;
                }
            }
            Ok(v)
        }
    }

    /// Embedder returning the same unit vector for every input, so every
    /// row scores identically.
    struct ConstEmbedder;

    impl Embedder for ConstEmbedder {
        fn embed(&mut self, _text: &str) -> Result<Vec<f32>, String> {
            Ok(vec![1.0, 0.0, 0.0, 0.0])
        }
    }

    /// Embedder that always fails with an LM Studio-style HTTP 400 message.
    struct FailingEmbedder;

    impl Embedder for FailingEmbedder {
        fn embed(&mut self, _text: &str) -> Result<Vec<f32>, String> {
            Err(
                "recall: /v1/embeddings HTTP 400 — load `nomic-embed-text` in LM Studio's \
                 Local Server tab"
                    .to_string(),
            )
        }
    }

    #[test]
    fn embedder_failure_propagates_as_error_string() {
        let mut e = FailingEmbedder;
        let err = e.embed("probe").expect_err("FailingEmbedder must fail");
        assert!(err.contains("Local Server tab"), "got: {err}");
    }

    #[test]
    fn probe_through_store_returns_err_on_embedder_failure() {
        let mut store = RecallStore::open_in_memory(Box::new(FailingEmbedder)).expect("open");
        let err = store
            .embedder
            .embed("probe")
            .expect_err("FailingEmbedder must fail");
        assert!(err.contains("HTTP 400"), "got: {err}");
    }

    #[test]
    fn encode_decode_roundtrip() {
        let v = vec![0.0, 1.0, -1.5, std::f32::consts::PI, f32::EPSILON];
        let bytes = encode_vec(&v);
        assert_eq!(bytes.len(), v.len() * 4);
        let back = decode_vec(&bytes).expect("decode");
        assert_eq!(back, v);
    }

    #[test]
    fn decode_rejects_misaligned_bytes() {
        let err = decode_vec(&[1, 2, 3]).expect_err("should reject 3-byte input");
        assert!(err.contains("multiple of 4"), "got: {err}");
    }

    #[test]
    fn cosine_handles_zero_vectors() {
        assert!(cosine_similarity(&[0.0, 0.0], &[0.0, 0.0]).abs() < 1e-9);
        assert!(cosine_similarity(&[1.0, 0.0], &[0.0, 0.0]).abs() < 1e-9);
    }

    #[test]
    fn cosine_is_one_for_identical() {
        let a = vec![0.5, 0.5, 0.5];
        assert!((cosine_similarity(&a, &a) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_is_zero_for_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(cosine_similarity(&a, &b).abs() < 1e-6);
    }

    #[test]
    fn cosine_returns_zero_for_mismatched_length() {
        assert!(cosine_similarity(&[1.0, 0.0], &[1.0, 0.0, 0.0]).abs() < 1e-9);
    }

    #[test]
    fn store_roundtrip_indexes_and_queries() {
        let mut store = RecallStore::open_in_memory(Box::new(HashEmbedder::new())).expect("open");
        store
            .index(Role::User, "the meeting with brian is on tuesday")
            .unwrap();
        store
            .index(Role::Assistant, "got it, brian + tuesday noted")
            .unwrap();
        store
            .index(Role::User, "completely unrelated content about weather")
            .unwrap();
        let hits = store.query("when is brian's meeting", 2).unwrap();
        assert_eq!(hits.len(), 2);
        for hit in &hits {
            assert!(
                !hit.snippet.contains("weather"),
                "weather should not be in top-2: {hits:?}"
            );
        }
    }

    #[test]
    fn store_skips_empty_snippets() {
        let mut store = RecallStore::open_in_memory(Box::new(HashEmbedder::new())).expect("open");
        store.index(Role::User, "").unwrap();
        store.index(Role::User, " \t\n ").unwrap();
        assert_eq!(store.count().unwrap(), 0);
    }

    #[test]
    fn empty_query_returns_empty_results() {
        let mut store = RecallStore::open_in_memory(Box::new(HashEmbedder::new())).expect("open");
        store.index(Role::User, "hello").unwrap();
        assert!(store.query("", 5).unwrap().is_empty());
        assert!(store.query(" ", 5).unwrap().is_empty());
        assert!(store.query("hello", 0).unwrap().is_empty());
    }

    #[test]
    fn fifo_eviction_at_cap() {
        let mut store = RecallStore::open_in_memory(Box::new(ConstEmbedder))
            .expect("open")
            .with_cap(3);
        store.index(Role::User, "first").unwrap();
        store.index(Role::User, "second").unwrap();
        store.index(Role::User, "third").unwrap();
        assert_eq!(store.count().unwrap(), 3);
        store.index(Role::User, "fourth").unwrap();
        assert_eq!(store.count().unwrap(), 3, "cap should hold");
        let hits = store.query("any", 10).unwrap();
        let snippets: Vec<&str> = hits.iter().map(|h| h.snippet.as_str()).collect();
        assert!(!snippets.contains(&"first"), "oldest evicted: {snippets:?}");
        assert!(snippets.contains(&"second"));
        assert!(snippets.contains(&"third"));
        assert!(snippets.contains(&"fourth"));
    }

    #[test]
    fn long_snippet_is_truncated() {
        let mut store = RecallStore::open_in_memory(Box::new(HashEmbedder::new())).expect("open");
        let huge = "x".repeat(20_000);
        store.index(Role::User, &huge).unwrap();
        let hits = store.query("xxxx", 1).unwrap();
        assert_eq!(hits.len(), 1);
        assert!(
            hits[0].snippet.len() <= 8 * 1024,
            "snippet should be capped at 8KB, got {}",
            hits[0].snippet.len()
        );
    }

    #[test]
    fn long_multibyte_snippet_does_not_panic_on_cap() {
        let mut store = RecallStore::open_in_memory(Box::new(HashEmbedder::new())).expect("open");
        // The 1-byte prefix shifts every 2-byte 'é' onto odd offsets, so byte
        // 8192 falls mid-character and the cap must back off to 8191.
        let huge = format!("a{}", "é".repeat(20_000));
        store.index(Role::User, &huge).unwrap();
        let hits = store.query("é", 1).unwrap();
        assert_eq!(hits.len(), 1);
        let snippet = &hits[0].snippet;
        assert_eq!(
            snippet.len(),
            8 * 1024 - 1,
            "cap must back off to the previous char boundary"
        );
        assert!(
            huge.starts_with(snippet.as_str()),
            "stored snippet must be a prefix of the input"
        );
    }

    #[test]
    fn results_are_sorted_descending_by_score() {
        let mut store = RecallStore::open_in_memory(Box::new(HashEmbedder::new())).expect("open");
        for snippet in [
            "the cat sat on the mat",
            "weather forecast for next tuesday",
            "the cat stretched across the rug",
            "currency exchange rates today",
        ] {
            store.index(Role::User, snippet).unwrap();
        }
        let hits = store.query("cat on mat", 4).unwrap();
        assert!(
            hits[0].snippet.contains("cat"),
            "top hit should be cat-related: {:?}",
            hits[0]
        );
        for w in hits.windows(2) {
            assert!(
                w[0].score >= w[1].score,
                "results must be descending by score"
            );
        }
    }

    #[test]
    fn query_returns_only_top_k_via_heap() {
        let mut store = RecallStore::open_in_memory(Box::new(HashEmbedder::new())).expect("open");
        for i in 0..20 {
            store
                .index(
                    Role::User,
                    &format!("snippet number {i} talking about cats"),
                )
                .unwrap();
        }
        let hits = store.query("cats", 3).unwrap();
        assert_eq!(hits.len(), 3, "exactly k results");
        for w in hits.windows(2) {
            assert!(
                w[0].score >= w[1].score,
                "results must be descending: {:?}",
                hits
            );
        }
    }

    #[test]
    fn query_skips_rows_with_mismatched_dim() {
        let mut store = RecallStore::open_in_memory(Box::new(HashEmbedder::new())).expect("open");
        // Simulate a row written by an older model with a 3-dim embedding.
        store
            .conn
            .execute(
                "INSERT INTO recall (ts, role, snippet, vec) VALUES ('2026-01-01T00:00:00Z', 'user', 'old-model row', ?1)",
                params![encode_vec(&[1.0_f32, 0.0, 0.0])],
            )
            .expect("seed insert");
        store
            .index(Role::User, "modern cat content matching the embedder dim")
            .unwrap();
        let hits = store.query("cat", 5).unwrap();
        assert_eq!(hits.len(), 1);
        assert!(hits[0].snippet.contains("modern cat"));
    }

    #[test]
    fn role_parse_roundtrip() {
        for r in [Role::User, Role::Assistant] {
            assert_eq!(Role::parse(r.as_str()), Some(r));
        }
        assert_eq!(Role::parse("system"), None);
    }

    #[test]
    fn parse_ollama_embedding_happy_path() {
        let body = json!({ "embedding": [0.1, 0.2, -0.3, 0.0, 1.5] });
        let v = parse_ollama_embedding(&body).expect("parse");
        assert_eq!(v.len(), 5);
        assert!((v[0] - 0.1).abs() < 1e-6);
        assert!((v[2] - -0.3).abs() < 1e-6);
    }

    #[test]
    fn parse_ollama_embedding_rejects_missing_field() {
        let body = json!({ "data": [] });
        let err = parse_ollama_embedding(&body).expect_err("should fail");
        assert!(err.contains("missing 'embedding'"), "got: {err}");
    }

    #[test]
    fn parse_openai_compat_embedding_happy_path() {
        let body = json!({
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "index": 0,
                    "embedding": [0.42, -0.17, 0.99]
                }
            ],
            "model": "nomic-embed-text-v1.5",
            "usage": { "prompt_tokens": 4, "total_tokens": 4 }
        });
        let v = parse_openai_compat_embedding(&body).expect("parse");
        assert_eq!(v.len(), 3);
        assert!((v[0] - 0.42).abs() < 1e-6);
    }

    #[test]
    fn parse_openai_compat_embedding_rejects_missing_data() {
        let body = json!({ "embedding": [0.1, 0.2] });
        let err = parse_openai_compat_embedding(&body).expect_err("should fail");
        assert!(err.contains("'data[0].embedding'"), "got: {err}");
    }

    #[test]
    fn parse_openai_compat_embedding_rejects_empty_data_array() {
        let body = json!({ "object": "list", "data": [] });
        let err = parse_openai_compat_embedding(&body).expect_err("should fail");
        assert!(err.contains("'data[0].embedding'"), "got: {err}");
    }

    #[test]
    fn json_array_to_f32s_rejects_empty() {
        let err = json_array_to_f32s(&[]).expect_err("should fail");
        assert!(err.contains("empty embedding"), "got: {err}");
    }

    #[test]
    fn json_array_to_f32s_rejects_non_numeric() {
        let arr = vec![json!(0.5), json!("not a number")];
        let err = json_array_to_f32s(&arr).expect_err("should fail");
        assert!(err.contains("non-numeric"), "got: {err}");
    }

    #[test]
    #[ignore = "requires live LM Studio with embed model loaded"]
    fn recall_live_openai_compat_embed_is_deterministic() {
        let mut e = OpenAICompatEmbedder::new().expect("construct embedder");
        let v1 = e
            .embed("hello from claudette recall smoke")
            .expect("embed 1");
        let v2 = e
            .embed("hello from claudette recall smoke")
            .expect("embed 2");
        assert!(!v1.is_empty(), "got empty vector");
        assert!(
            v1.len() >= 256,
            "expected an embedding ≥256 dims, got {}",
            v1.len()
        );
        assert_eq!(v1.len(), v2.len(), "dim should be stable across calls");
        let cos = cosine_similarity(&v1, &v2);
        assert!(
            (cos - 1.0).abs() < 1e-3,
            "same input should produce ~identical vectors, cos={cos}"
        );
    }

    #[test]
    #[ignore = "requires live LM Studio with embed model loaded"]
    fn recall_live_full_index_query_roundtrip() {
        let embedder: Box<dyn Embedder> =
            Box::new(OpenAICompatEmbedder::new().expect("construct embedder"));
        let mut store = RecallStore::open_in_memory(embedder).expect("open store");
        store
            .index(Role::User, "the meeting with brian is on tuesday at 3pm")
            .unwrap();
        store
            .index(Role::Assistant, "got it — brian, tuesday 3pm noted")
            .unwrap();
        store
            .index(
                Role::User,
                "completely unrelated content about the weather forecast for next week",
            )
            .unwrap();
        store
            .index(
                Role::User,
                "another tangent about currency exchange rates today",
            )
            .unwrap();
        let hits = store.query("when is brian's meeting", 2).expect("query");
        assert_eq!(hits.len(), 2, "asked for top-2: {hits:?}");
        for h in &hits {
            assert!(
                !h.snippet.contains("weather") && !h.snippet.contains("currency"),
                "off-topic snippet leaked into top-2: {h:?}"
            );
        }
        assert!(
            hits[0].snippet.contains("brian") || hits[0].snippet.contains("tuesday"),
            "top hit should be brian-related, got: {:?}",
            hits[0]
        );
        assert!(
            hits[0].score > 0.5,
            "top hit score too low ({}); embedder may be returning noise",
            hits[0].score
        );
    }
}