use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
#[cfg(feature = "storage")]
use crate::error::Error;
use crate::error::Result;
#[cfg(feature = "storage")]
use crate::recursive::db::Connection;
/// Convenience constructor: a fresh in-memory `Memory` backed by the
/// default `HashEmbedder`.
#[inline]
pub fn memory() -> Memory<HashEmbedder> {
    Memory::default()
}
/// Converts text into a fixed-length embedding vector.
///
/// `Send + Sync` is required so an embedder can back a `Memory` that is
/// shared across threads.
pub trait Embedder: Send + Sync {
/// Embed `text`; the returned vector has length `dimension()`.
fn embed(&self, text: &str) -> Vec<f32>;
/// Length of the vectors produced by `embed`.
fn dimension(&self) -> usize;
}
/// Cheap, deterministic embedder: hashes whitespace-separated tokens
/// into a fixed-size bucket vector. Needs no model weights, so it works
/// as a zero-dependency default and in tests.
#[derive(Debug, Clone)]
pub struct HashEmbedder {
    dimension: usize,
}

impl HashEmbedder {
    /// Create an embedder producing vectors of length `dimension`.
    pub fn new(dimension: usize) -> Self {
        HashEmbedder { dimension }
    }
}

impl Default for HashEmbedder {
    /// Defaults to 64-dimensional vectors.
    fn default() -> Self {
        HashEmbedder::new(64)
    }
}
impl Embedder for HashEmbedder {
    /// Hash every whitespace token into a bucket, accumulate a
    /// hash-derived weight per bucket, then L2-normalize the result.
    /// Deterministic within a build (relies on `DefaultHasher::new()`).
    fn embed(&self, text: &str) -> Vec<f32> {
        let mut vector = vec![0.0f32; self.dimension];
        for token in text.split_whitespace() {
            let mut hasher = DefaultHasher::new();
            token.hash(&mut hasher);
            let h = hasher.finish();
            let bucket = (h as usize) % self.dimension;
            // Pseudo-random weight in [0, 1] derived from the high bits,
            // independent of the bucket chosen from the low bits.
            vector[bucket] += ((h >> 32) as f32) / (u32::MAX as f32);
        }
        let magnitude = vector.iter().map(|v| v * v).sum::<f32>().sqrt();
        if magnitude > 0.0 {
            vector.iter_mut().for_each(|v| *v /= magnitude);
        }
        vector
    }

    fn dimension(&self) -> usize {
        self.dimension
    }
}
/// A stored item: the original text plus its precomputed embedding.
#[derive(Debug, Clone)]
pub struct Document {
// Unique id, e.g. "doc:0" or "upsert:<hash>".
pub id: String,
// The raw text that was embedded.
pub content: String,
// Embedding produced by the owning store's `Embedder`.
pub embedding: Vec<f32>,
// Optional user-supplied grouping label.
pub tag: Option<String>,
}
/// A search hit returned to callers: document data plus its score.
#[derive(Debug, Clone)]
pub struct Recall {
pub id: String,
pub content: String,
// Cosine similarity to the query; fixed at 1.0 when listing all docs.
pub score: f64,
pub tag: Option<String>,
}
/// Simple vector store kept entirely in process memory.
struct InMemoryStore {
// Documents in insertion order; ids unique by construction.
documents: Vec<Document>,
// Monotonic counter used to mint "doc:<n>" ids.
next_id: u64,
}
impl InMemoryStore {
fn new() -> Self {
Self {
documents: Vec::new(),
next_id: 0,
}
}
fn add(&mut self, content: String, embedding: Vec<f32>, tag: Option<String>) -> String {
let id = format!("doc:{}", self.next_id);
self.next_id += 1;
self.documents.push(Document {
id: id.clone(),
content,
embedding,
tag,
});
id
}
fn add_with_id(
&mut self,
id: String,
content: String,
embedding: Vec<f32>,
tag: Option<String>,
) {
self.documents.retain(|d| d.id != id);
self.documents.push(Document {
id,
content,
embedding,
tag,
});
}
fn get(&self, id: &str) -> Option<&Document> {
self.documents.iter().find(|d| d.id == id)
}
fn search(&self, query_embedding: &[f32], k: usize) -> Vec<Recall> {
let mut scored: Vec<_> = self
.documents
.iter()
.map(|doc| {
let score = cosine_similarity(query_embedding, &doc.embedding);
(doc, score)
})
.collect();
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored
.into_iter()
.take(k)
.map(|(doc, score)| Recall {
id: doc.id.clone(),
content: doc.content.clone(),
score,
tag: doc.tag.clone(),
})
.collect()
}
fn update(&mut self, id: &str, content: String, embedding: Vec<f32>) -> bool {
if let Some(doc) = self.documents.iter_mut().find(|d| d.id == id) {
doc.content = content;
doc.embedding = embedding;
true
} else {
false
}
}
fn remove(&mut self, id: &str) -> bool {
let len_before = self.documents.len();
self.documents.retain(|d| d.id != id);
self.documents.len() < len_before
}
fn all(&self) -> Vec<Recall> {
self.documents
.iter()
.map(|doc| Recall {
id: doc.id.clone(),
content: doc.content.clone(),
score: 1.0,
tag: doc.tag.clone(),
})
.collect()
}
fn len(&self) -> usize {
self.documents.len()
}
#[allow(dead_code)]
fn is_empty(&self) -> bool {
self.documents.is_empty()
}
fn tags(&self) -> Vec<String> {
let mut tags: Vec<_> = self
.documents
.iter()
.filter_map(|d| d.tag.clone())
.collect();
tags.sort();
tags.dedup();
tags
}
}
/// Cosine similarity of two vectors, widened to `f64`.
///
/// Returns 0.0 (rather than NaN) for mismatched lengths, empty input,
/// or a zero-magnitude vector.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    f64::from(dot / (norm_a * norm_b))
}
/// Abstraction over nearest-neighbour indexes (exact or approximate).
pub trait VectorIndex: Send + Sync {
/// Insert or replace the embedding stored under `id`.
fn insert(&mut self, id: usize, embedding: &[f32]);
/// Return up to `k` `(id, similarity)` pairs, best first.
fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f64)>;
/// Remove `id` if present; no-op otherwise.
fn remove(&mut self, id: usize);
/// Number of stored embeddings.
fn len(&self) -> usize;
fn is_empty(&self) -> bool {
self.len() == 0
}
}
/// Exhaustive-scan vector index: keeps every embedding and compares the
/// query against all of them. Exact results, O(n) per search.
#[derive(Debug, Clone, Default)]
pub struct LinearIndex {
    embeddings: Vec<(usize, Vec<f32>)>,
}

impl LinearIndex {
    /// An empty index.
    pub fn new() -> Self {
        LinearIndex {
            embeddings: Vec::new(),
        }
    }
}
impl VectorIndex for LinearIndex {
    /// Upsert: any previous entry for `id` is dropped before the push.
    fn insert(&mut self, id: usize, embedding: &[f32]) {
        self.remove(id);
        self.embeddings.push((id, embedding.to_vec()));
    }

    /// Score every stored vector against `query`, best first, keep `k`.
    fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f64)> {
        let mut ranked: Vec<(usize, f64)> = self
            .embeddings
            .iter()
            .map(|(id, stored)| (*id, cosine_similarity(query, stored)))
            .collect();
        ranked.sort_by(|lhs, rhs| rhs.1.partial_cmp(&lhs.1).unwrap_or(std::cmp::Ordering::Equal));
        ranked.truncate(k);
        ranked
    }

    fn remove(&mut self, id: usize) {
        self.embeddings.retain(|(stored_id, _)| *stored_id != id);
    }

    fn len(&self) -> usize {
        self.embeddings.len()
    }
}
/// Maximal Marginal Relevance (MMR) selection.
///
/// Greedily picks up to `k` documents, balancing relevance against
/// redundancy with documents already picked:
/// `score = lambda * relevance - (1 - lambda) * max_sim_to_selected`.
///
/// * `doc_embeddings` — `(doc_id, embedding, precomputed relevance)`
///   triples; `_query_embedding` is unused because relevance is supplied
///   by the caller.
/// * `lambda` is clamped to `[0, 1]`: 1.0 = pure relevance,
///   0.0 = pure diversity.
///
/// Returns `(doc_id, mmr_score)` pairs in selection order. Empty input
/// or `k == 0` yields an empty result.
pub fn mmr_select(
    _query_embedding: &[f32],
    doc_embeddings: &[(usize, Vec<f32>, f64)],
    k: usize,
    lambda: f64,
) -> Vec<(usize, f64)> {
    if doc_embeddings.is_empty() || k == 0 {
        return Vec::new();
    }
    let lambda = lambda.clamp(0.0, 1.0);
    let mut selected: Vec<(usize, f64)> = Vec::with_capacity(k);
    // Embeddings of picked documents, kept alongside `selected` so the
    // inner loop can score redundancy directly. The previous version
    // re-found each selected embedding via an O(n) search by doc_id on
    // every candidate (O(k*n) per candidate overall) and would compare
    // against the wrong vector if doc ids were duplicated.
    let mut selected_embs: Vec<&[f32]> = Vec::with_capacity(k);
    let mut remaining: Vec<&(usize, Vec<f32>, f64)> = doc_embeddings.iter().collect();
    while selected.len() < k && !remaining.is_empty() {
        // Score every remaining candidate and take the best; ties keep
        // the same resolution as before (last maximum wins in max_by).
        let best = remaining
            .iter()
            .enumerate()
            .map(|(idx, (doc_id, doc_emb, relevance))| {
                let max_sim_to_selected = selected_embs
                    .iter()
                    .map(|sel_emb| cosine_similarity(doc_emb, sel_emb))
                    .fold(0.0f64, f64::max);
                let mmr_score = lambda * relevance - (1.0 - lambda) * max_sim_to_selected;
                (idx, *doc_id, mmr_score)
            })
            .max_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal));
        match best {
            Some((idx, doc_id, score)) => {
                // `picked` borrows from `doc_embeddings`, so the slice
                // reference stays valid after removal from `remaining`.
                let picked = remaining.remove(idx);
                selected.push((doc_id, score));
                selected_embs.push(picked.1.as_slice());
            }
            None => break,
        }
    }
    selected
}
/// Storage backend for `Memory`: an in-process vector store, or (with
/// the `storage` feature) a persistent database table.
enum MemoryStore {
InMemory(InMemoryStore),
#[cfg(feature = "storage")]
Persistent {
conn: Connection,
// Embedding length, needed to decode stored embedding BLOBs.
dimension: usize,
// Path the database was opened from; exposed via `db_path`/`package`.
db_path: String,
},
}
/// Derive a stable document id from the content itself, so upserting
/// identical text always maps to the same id. Stable within one build;
/// `DefaultHasher`'s algorithm is not guaranteed across Rust releases.
fn content_hash_id(content: &str) -> String {
    let digest = {
        let mut hasher = DefaultHasher::new();
        content.hash(&mut hasher);
        hasher.finish()
    };
    format!("upsert:{:016x}", digest)
}
impl MemoryStore {
/// Insert or replace a document under a caller-chosen id.
fn insert(
&mut self,
id: &str,
content: &str,
embedding: &[f32],
tag: Option<&str>,
) -> Result<()> {
match self {
MemoryStore::InMemory(store) => {
store.add_with_id(
id.to_string(),
content.to_string(),
embedding.to_vec(),
tag.map(|t| t.to_string()),
);
Ok(())
}
#[cfg(feature = "storage")]
MemoryStore::Persistent { conn, .. } => {
let embedding_bytes = embedding_to_bytes(embedding);
conn.execute(
"INSERT OR REPLACE INTO documents (id, content, embedding, tag) VALUES (?, ?, ?, ?)",
duckdb::params![id, content, &embedding_bytes, &tag],
)
.map_err(|e| Error::memory("insert", e.to_string(), "Ensure the database file is not locked by another process and has write permissions."))?;
Ok(())
}
}
}
/// Insert under a freshly generated id and return that id.
fn add(&mut self, content: &str, embedding: &[f32], tag: Option<&str>) -> Result<String> {
match self {
MemoryStore::InMemory(store) => Ok(store.add(
content.to_string(),
embedding.to_vec(),
tag.map(|t| t.to_string()),
)),
#[cfg(feature = "storage")]
MemoryStore::Persistent { conn, .. } => {
// NOTE(review): INSERT OR REPLACE with a generated id — a plain
// INSERT would surface id collisions instead of overwriting.
let id = format!("doc:{}", uuid_v4());
let embedding_bytes = embedding_to_bytes(embedding);
conn.execute(
"INSERT OR REPLACE INTO documents (id, content, embedding, tag) VALUES (?, ?, ?, ?)",
duckdb::params![&id, content, &embedding_bytes, &tag],
)
.map_err(|e| Error::memory("add", e.to_string(), "Ensure the database file is not locked by another process and has write permissions."))?;
Ok(id)
}
}
}
/// Load every document together with its embedding; used by the
/// brute-force search paths. The persistent branch reads all rows.
fn fetch_all_with_embeddings(&self) -> Result<Vec<(String, String, Vec<f32>, Option<String>)>> {
match self {
MemoryStore::InMemory(store) => Ok(store
.documents
.iter()
.map(|d| {
(
d.id.clone(),
d.content.clone(),
d.embedding.clone(),
d.tag.clone(),
)
})
.collect()),
#[cfg(feature = "storage")]
MemoryStore::Persistent {
conn, dimension, ..
} => {
let mut stmt = conn
.prepare("SELECT id, content, embedding, tag FROM documents")
.map_err(|e| {
Error::memory("search", e.to_string(), "Check database integrity.")
})?;
Ok(stmt
.query_map([], |row| {
let id: String = row.get(0)?;
let content: String = row.get(1)?;
let embedding_bytes: Vec<u8> = row.get(2)?;
let tag: Option<String> = row.get(3)?;
let embedding = bytes_to_embedding(&embedding_bytes, *dimension);
Ok((id, content, embedding, tag))
})
.map_err(|e| {
Error::memory("search", e.to_string(), "Check database integrity.")
})?
// Rows that fail to decode are silently skipped.
.filter_map(|r| r.ok())
.collect())
}
}
}
/// Replace content + embedding of an existing document; Ok(false)
/// when no document has that id.
fn update_doc(&mut self, id: &str, content: &str, embedding: &[f32]) -> Result<bool> {
match self {
MemoryStore::InMemory(store) => {
Ok(store.update(id, content.to_string(), embedding.to_vec()))
}
#[cfg(feature = "storage")]
MemoryStore::Persistent { conn, .. } => {
let embedding_bytes = embedding_to_bytes(embedding);
conn.execute(
"UPDATE documents SET content = ?, embedding = ? WHERE id = ?",
duckdb::params![content, &embedding_bytes, id],
)
// `execute` returns the affected-row count.
.map(|n| n > 0)
.map_err(|e| {
Error::memory(
"update",
e.to_string(),
"The document may have been removed concurrently.",
)
})
}
}
}
/// Delete by id; Ok(true) when something was actually removed.
fn remove_doc(&mut self, id: &str) -> Result<bool> {
match self {
MemoryStore::InMemory(store) => Ok(store.remove(id)),
#[cfg(feature = "storage")]
MemoryStore::Persistent { conn, .. } => conn
.execute("DELETE FROM documents WHERE id = ?", [id])
.map(|n| n > 0)
.map_err(|e| {
Error::memory("remove", e.to_string(), "Ensure the database is writable.")
}),
}
}
/// Every document as a `Recall` with a neutral score of 1.0.
/// Embeddings are intentionally not loaded here.
fn all_docs(&self) -> Result<Vec<Recall>> {
match self {
MemoryStore::InMemory(store) => Ok(store.all()),
#[cfg(feature = "storage")]
MemoryStore::Persistent { conn, .. } => {
let mut stmt = conn
.prepare("SELECT id, content, tag FROM documents")
.map_err(|e| {
Error::memory("all", e.to_string(), "Check database integrity.")
})?;
Ok(stmt
.query_map([], |row| {
Ok(Recall {
id: row.get(0)?,
content: row.get(1)?,
score: 1.0,
tag: row.get(2)?,
})
})
.map_err(|e| Error::memory("all", e.to_string(), "Check database integrity."))?
.filter_map(|r| r.ok())
.collect())
}
}
}
/// Number of stored documents.
fn count(&self) -> Result<usize> {
match self {
MemoryStore::InMemory(store) => Ok(store.len()),
#[cfg(feature = "storage")]
MemoryStore::Persistent { conn, .. } => {
let mut stmt = conn
.prepare("SELECT COUNT(*) FROM documents")
.map_err(|e| {
Error::memory(
"len",
e.to_string(),
"The database may be corrupt; try recreating it.",
)
})?;
let count = stmt
.query_row([], |row| row.get::<_, i64>(0))
.map_err(|e| {
Error::memory(
"len",
e.to_string(),
"The database may be corrupt; try recreating it.",
)
})?;
Ok(count as usize)
}
}
}
/// Sorted, de-duplicated list of every tag in use.
fn unique_tags(&self) -> Result<Vec<String>> {
match self {
MemoryStore::InMemory(store) => Ok(store.tags()),
#[cfg(feature = "storage")]
MemoryStore::Persistent { conn, .. } => {
let mut stmt = conn
.prepare(
"SELECT DISTINCT tag FROM documents WHERE tag IS NOT NULL ORDER BY tag",
)
.map_err(|e| {
Error::memory("tags", e.to_string(), "Check database integrity.")
})?;
Ok(stmt
.query_map([], |row| row.get(0))
.map_err(|e| Error::memory("tags", e.to_string(), "Check database integrity."))?
.filter_map(|r| r.ok())
.collect())
}
}
}
}
/// High-level semantic memory: an embedder plus a storage backend.
/// Defaults to the zero-dependency `HashEmbedder`.
pub struct Memory<E: Embedder = HashEmbedder> {
embedder: E,
store: MemoryStore,
// Default result count used by `search_default`.
k: usize,
// When set, `learn` only stores Q/A pairs scoring at or above this.
learn_threshold: Option<f64>,
// When set, `search` routes through MMR diversification.
mmr_lambda: Option<f64>,
}
impl<E: Embedder> Memory<E> {
/// Borrow the embedder backing this memory.
pub fn embedder(&self) -> &E {
&self.embedder
}
/// Default number of results used by `search_default`.
pub fn k(&self) -> usize {
self.k
}
/// Change the default result count.
pub fn set_k(&mut self, k: usize) {
self.k = k;
}
}
impl Memory<HashEmbedder> {
pub fn new() -> Self {
Self {
embedder: HashEmbedder::default(),
store: MemoryStore::InMemory(InMemoryStore::new()),
k: 3,
learn_threshold: None,
mmr_lambda: None,
}
}
}
impl Default for Memory<HashEmbedder> {
/// Equivalent to `Memory::new()`.
fn default() -> Self {
Self::new()
}
}
impl<E: Embedder> Memory<E> {
/// Fresh in-memory store with a custom embedder (k = 3, no learning
/// threshold, no MMR diversity).
pub fn with_embedder(embedder: E) -> Self {
Self {
embedder,
store: MemoryStore::InMemory(InMemoryStore::new()),
k: 3,
learn_threshold: None,
mmr_lambda: None,
}
}
/// Swap in a different embedder, keeping k / threshold / lambda.
/// NOTE(review): this resets the store to a fresh in-memory one, so any
/// already-stored documents are discarded — presumably intentional
/// since old embeddings would be incompatible, but confirm with callers.
pub fn embedder_with<E2: Embedder>(self, embedder: E2) -> Memory<E2> {
Memory {
embedder,
store: MemoryStore::InMemory(InMemoryStore::new()),
k: self.k,
learn_threshold: self.learn_threshold,
mmr_lambda: self.mmr_lambda,
}
}
/// Enable MMR-diversified search; `lambda` clamped to [0, 1]
/// (1.0 = pure relevance, 0.0 = pure diversity).
pub fn diversity(mut self, lambda: f64) -> Self {
self.mmr_lambda = Some(lambda.clamp(0.0, 1.0));
self
}
/// Builder-style setter for the default result count.
pub fn with_k(mut self, k: usize) -> Self {
self.k = k;
self
}
/// Only store `learn` interactions scoring at or above `threshold`.
pub fn learn_above(mut self, threshold: f64) -> Self {
self.learn_threshold = Some(threshold);
self
}
#[cfg(feature = "storage")]
/// Switch to a persistent database at `path`, creating the schema when
/// missing. NOTE(review): documents already held in memory are NOT
/// migrated — the previous store is simply dropped.
pub fn persist(mut self, path: &str) -> Result<Self> {
let conn = Connection::open(path).map_err(|e| Error::storage(e.to_string()))?;
let dimension = self.embedder.dimension();
conn.execute_batch(
r#"
CREATE TABLE IF NOT EXISTS documents (
id TEXT PRIMARY KEY,
content TEXT NOT NULL,
embedding BLOB NOT NULL,
tag TEXT
);
CREATE INDEX IF NOT EXISTS idx_documents_tag ON documents(tag);
"#,
)
.map_err(|e| Error::storage(e.to_string()))?;
self.store = MemoryStore::Persistent {
conn,
dimension,
db_path: path.to_string(),
};
Ok(self)
}
/// Seed "Q: …\nA: …" documents, but only when the store currently has
/// no documents at all (idempotent startup seeding).
pub fn seed_if_empty<I, S1, S2>(mut self, items: I) -> Result<Self>
where
I: IntoIterator<Item = (S1, S2)>,
S1: Into<String>,
S2: Into<String>,
{
if self.is_empty()? {
for (question, answer) in items {
let content = format!("Q: {}\nA: {}", question.into(), answer.into());
self.add(&content)?;
}
}
Ok(self)
}
/// Shared embed-then-store path used by the public add/upsert methods.
fn insert_doc(&mut self, id: Option<&str>, content: &str, tag: Option<&str>) -> Result<String> {
let embedding = self.embedder.embed(content);
match id {
Some(id) => {
self.store.insert(id, content, &embedding, tag)?;
Ok(id.to_string())
}
None => self.store.add(content, &embedding, tag),
}
}
/// Store `content` under a generated id; returns that id.
pub fn add(&mut self, content: &str) -> Result<String> {
self.insert_doc(None, content, None)
}
/// Store (or replace) `content` under a caller-chosen id.
pub fn add_with_id(&mut self, id: impl Into<String>, content: &str) -> Result<()> {
let id = id.into();
let embedding = self.embedder.embed(content);
self.store.insert(&id, content, &embedding, None)
}
/// Store `content` with a tag, under a generated id.
pub fn add_tagged(&mut self, tag: &str, content: &str) -> Result<String> {
self.insert_doc(None, content, Some(tag))
}
/// Fetch a document's content by id. Returns None when absent — and,
/// for the persistent store, also on any query error.
pub fn get(&self, id: &str) -> Option<String> {
match &self.store {
MemoryStore::InMemory(store) => store.get(id).map(|d| d.content.clone()),
#[cfg(feature = "storage")]
MemoryStore::Persistent { conn, .. } => {
let mut stmt = conn
.prepare("SELECT content FROM documents WHERE id = ?")
.ok()?;
stmt.query_row([id], |row| row.get(0)).ok()
}
}
}
/// Top-k cosine search; routes through MMR when `diversity` was set.
/// The persistent path loads every row and scores it in memory.
pub fn search(&self, query: &str, k: usize) -> Result<Vec<Recall>> {
if let Some(lambda) = self.mmr_lambda {
return self.search_diverse(query, k, lambda);
}
let query_embedding = self.embedder.embed(query);
match &self.store {
MemoryStore::InMemory(store) => Ok(store.search(&query_embedding, k)),
#[cfg(feature = "storage")]
_ => {
let all = self.store.fetch_all_with_embeddings()?;
let mut results: Vec<Recall> = all
.into_iter()
.map(|(id, content, embedding, tag)| {
let score = cosine_similarity(&query_embedding, &embedding);
Recall {
id,
content,
score,
tag,
}
})
.collect();
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
results.truncate(k);
Ok(results)
}
}
}
/// `search` using the configured default `k`.
pub fn search_default(&self, query: &str) -> Result<Vec<Recall>> {
self.search(query, self.k)
}
/// Maximal-Marginal-Relevance search: trades query relevance against
/// redundancy among the results. Loads all documents into memory.
pub fn search_diverse(&self, query: &str, k: usize, lambda: f64) -> Result<Vec<Recall>> {
let query_embedding = self.embedder.embed(query);
let docs = self.store.fetch_all_with_embeddings()?;
let doc_data: Vec<_> = docs
.iter()
.enumerate()
.map(|(idx, (_, _, emb, _))| {
let relevance = cosine_similarity(&query_embedding, emb);
(idx, emb.clone(), relevance)
})
.collect();
// mmr_select works on positional indices into `docs`.
let selected = mmr_select(&query_embedding, &doc_data, k, lambda);
Ok(selected
.into_iter()
.filter_map(|(idx, score)| {
docs.get(idx).map(|(id, content, _, tag)| Recall {
id: id.clone(),
content: content.clone(),
score,
tag: tag.clone(),
})
})
.collect())
}
/// Re-embed and replace an existing document; Ok(false) when absent.
pub fn update(&mut self, id: &str, content: &str) -> Result<bool> {
let embedding = self.embedder.embed(content);
self.store.update_doc(id, content, &embedding)
}
/// Insert keyed by a hash of the content itself: identical text maps
/// to the same id, so repeats overwrite instead of duplicating.
pub fn upsert(&mut self, content: &str) -> Result<String> {
let id = content_hash_id(content);
let embedding = self.embedder.embed(content);
self.store.insert(&id, content, &embedding, None)?;
Ok(id)
}
/// Tagged variant of `upsert`.
pub fn upsert_tagged(&mut self, tag: &str, content: &str) -> Result<String> {
let id = content_hash_id(content);
let embedding = self.embedder.embed(content);
self.store.insert(&id, content, &embedding, Some(tag))?;
Ok(id)
}
/// `search`, then drop hits scoring below `min_score`.
pub fn search_above(&self, query: &str, k: usize, min_score: f64) -> Result<Vec<Recall>> {
Ok(self
.search(query, k)?
.into_iter()
.filter(|r| r.score >= min_score)
.collect())
}
/// Delete by id; Ok(true) when something was removed.
pub fn remove(&mut self, id: &str) -> Result<bool> {
self.store.remove_doc(id)
}
/// Every stored document (score fixed at 1.0).
pub fn all(&self) -> Result<Vec<Recall>> {
self.store.all_docs()
}
/// Number of stored documents.
pub fn len(&self) -> Result<usize> {
self.store.count()
}
pub fn is_empty(&self) -> Result<bool> {
Ok(self.store.count()? == 0)
}
/// Sorted, de-duplicated list of tags in use.
pub fn tags(&self) -> Result<Vec<String>> {
self.store.unique_tags()
}
/// Store a "Q: …\nA: …" pair when `score` meets the `learn_above`
/// threshold. A silent no-op (Ok) when no threshold was configured.
pub fn learn(&mut self, question: &str, output: &str, score: f64) -> Result<()> {
if let Some(threshold) = self.learn_threshold {
if score >= threshold {
let content = format!("Q: {}\nA: {}", question, output);
self.add(&content)?;
}
}
Ok(())
}
#[cfg(feature = "storage")]
/// Path of the backing database file, if this store is persistent.
pub fn db_path(&self) -> Option<&str> {
match &self.store {
MemoryStore::Persistent { db_path, .. } => Some(db_path.as_str()),
_ => None,
}
}
#[cfg(feature = "storage")]
/// Begin packaging the persistent database for export; errors for
/// in-memory stores.
pub fn package(&self, name: &str) -> Result<crate::recursive::packager::PackagerBuilder<'_>> {
match &self.store {
MemoryStore::Persistent { db_path, .. } => Ok(
crate::recursive::packager::PackagerBuilder::new(std::path::Path::new(
db_path.as_str(),
))
.name_owned(name.to_string()),
),
_ => Err(Error::storage(
"Cannot package in-memory store. Call .persist(path) first.",
)),
}
}
}
/// Serialize an embedding as packed little-endian f32 bytes (4 bytes
/// per component), the on-disk BLOB format.
#[cfg(feature = "storage")]
fn embedding_to_bytes(embedding: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(embedding.len() * 4);
    for value in embedding {
        bytes.extend_from_slice(&value.to_le_bytes());
    }
    bytes
}
/// Decode packed little-endian f32 bytes back into an embedding,
/// reading at most `dimension` values; any trailing partial chunk
/// (fewer than 4 bytes) is ignored.
#[cfg(feature = "storage")]
fn bytes_to_embedding(bytes: &[u8], dimension: usize) -> Vec<f32> {
    let mut embedding = Vec::with_capacity(dimension);
    for chunk in bytes.chunks_exact(4) {
        if embedding.len() == dimension {
            break;
        }
        embedding.push(f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));
    }
    embedding
}
/// Generate a unique 32-hex-character document id.
///
/// Despite the name this is not an RFC 4122 UUID: it combines the
/// current UNIX time in nanoseconds with a process-wide atomic counter.
/// The counter fixes a real collision hazard in the previous version,
/// which used the timestamp alone — two calls within the same clock
/// tick produced identical ids, and `INSERT OR REPLACE` would then
/// silently overwrite the earlier document.
#[cfg(feature = "storage")]
fn uuid_v4() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};
    static COUNTER: AtomicU64 = AtomicU64::new(0);
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_nanos())
        .unwrap_or(0);
    let seq = COUNTER.fetch_add(1, Ordering::Relaxed);
    // 24 hex chars of timestamp + 8 hex chars of sequence = 32 chars,
    // matching the original id width. The mask guarantees the width
    // even for implausibly large timestamps; the counter wraps at u32.
    format!(
        "{:024x}{:08x}",
        nanos & ((1u128 << 96) - 1),
        seq as u32
    )
}
/// Embedder backed by an ONNX transformer model run through `ort`,
/// with a HuggingFace `tokenizers` tokenizer.
#[cfg(feature = "embeddings-onnx")]
pub struct OnnxEmbedder {
// Mutex because running the session needs `&mut`, while the
// `Embedder` trait only gives `&self`.
session: std::sync::Mutex<ort::session::Session>,
tokenizer: tokenizers::Tokenizer,
// Output embedding length (hard-coded to 384 in `from_dir`; assumes a
// MiniLM-class model — TODO confirm for other models).
dimension: usize,
// Token-count truncation limit applied to model inputs.
max_length: usize,
// L2-normalize the pooled embedding when true.
normalize: bool,
}
#[cfg(feature = "embeddings-onnx")]
// Manual Debug that shows only the plain config fields — presumably
// because the session/tokenizer have no useful Debug output.
impl std::fmt::Debug for OnnxEmbedder {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OnnxEmbedder")
.field("dimension", &self.dimension)
.field("max_length", &self.max_length)
.field("normalize", &self.normalize)
.finish()
}
}
#[cfg(feature = "embeddings-onnx")]
impl OnnxEmbedder {
/// Load `model.onnx` and `tokenizer.json` from a directory.
///
/// The embedding dimension is hard-coded to 384 (MiniLM-class models);
/// NOTE(review): confirm before using models with a different width.
pub fn from_dir(path: impl AsRef<std::path::Path>) -> Result<Self, OnnxEmbedderError> {
let path = path.as_ref();
let model_path = path.join("model.onnx");
let tokenizer_path = path.join("tokenizer.json");
if !model_path.exists() {
return Err(OnnxEmbedderError::ModelNotFound(
model_path.display().to_string(),
));
}
if !tokenizer_path.exists() {
return Err(OnnxEmbedderError::TokenizerNotFound(
tokenizer_path.display().to_string(),
));
}
let session = ort::session::Session::builder()
.map_err(|e| OnnxEmbedderError::OrtError(e.to_string()))?
.with_optimization_level(ort::session::builder::GraphOptimizationLevel::Level3)
.map_err(|e| OnnxEmbedderError::OrtError(e.to_string()))?
.with_intra_threads(1)
.map_err(|e| OnnxEmbedderError::OrtError(e.to_string()))?
.commit_from_file(&model_path)
.map_err(|e| OnnxEmbedderError::OrtError(e.to_string()))?;
let dimension = 384;
let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
.map_err(|e| OnnxEmbedderError::TokenizerError(e.to_string()))?;
Ok(Self {
session: std::sync::Mutex::new(session),
tokenizer,
dimension,
max_length: 512,
normalize: true,
})
}
/// Builder: cap the tokenized input length (default 512).
pub fn with_max_length(mut self, max_length: usize) -> Self {
self.max_length = max_length;
self
}
/// Builder: return raw pooled embeddings instead of L2-normalized ones.
pub fn without_normalization(mut self) -> Self {
self.normalize = false;
self
}
/// Tokenize `text`, run the model, and mean-pool the token embeddings
/// into a single vector.
fn run_inference(&self, text: &str) -> Result<Vec<f32>, OnnxEmbedderError> {
let encoding = self
.tokenizer
.encode(text, true)
.map_err(|e| OnnxEmbedderError::TokenizerError(e.to_string()))?;
// Truncate both ids and mask to the configured max length.
let input_ids: Vec<i64> = encoding
.get_ids()
.iter()
.take(self.max_length)
.map(|&id| id as i64)
.collect();
let attention_mask: Vec<i64> = encoding
.get_attention_mask()
.iter()
.take(self.max_length)
.map(|&m| m as i64)
.collect();
let seq_len = input_ids.len();
// Batch size 1: all model inputs are shaped (1, seq_len).
let input_ids_array =
ndarray::Array2::from_shape_vec((1, seq_len), input_ids).map_err(|e| {
OnnxEmbedderError::OrtError(format!("Failed to create input_ids array: {}", e))
})?;
let attention_mask_array =
ndarray::Array2::from_shape_vec((1, seq_len), attention_mask.clone()).map_err(|e| {
OnnxEmbedderError::OrtError(format!("Failed to create attention_mask array: {}", e))
})?;
// Single-segment input: token_type_ids are all zero.
let token_type_ids: Vec<i64> = vec![0; seq_len];
let token_type_ids_array = ndarray::Array2::from_shape_vec((1, seq_len), token_type_ids)
.map_err(|e| {
OnnxEmbedderError::OrtError(format!("Failed to create token_type_ids array: {}", e))
})?;
let input_ids_value = ort::value::Tensor::from_array(input_ids_array)
.map_err(|e| OnnxEmbedderError::OrtError(e.to_string()))?;
let attention_mask_value = ort::value::Tensor::from_array(attention_mask_array)
.map_err(|e| OnnxEmbedderError::OrtError(e.to_string()))?;
let token_type_ids_value = ort::value::Tensor::from_array(token_type_ids_array)
.map_err(|e| OnnxEmbedderError::OrtError(e.to_string()))?;
// Poisoned-lock errors are surfaced as OrtError rather than panicking.
let mut session = self
.session
.lock()
.map_err(|e| OnnxEmbedderError::OrtError(format!("Session lock poisoned: {}", e)))?;
// The first declared model output is assumed to be the hidden states.
let output_name = session
.outputs()
.first()
.map(|o| o.name().to_string())
.ok_or_else(|| OnnxEmbedderError::OrtError("No output in model".to_string()))?;
let inputs: Vec<(
std::borrow::Cow<'_, str>,
ort::session::SessionInputValue<'_>,
)> = vec![
(
std::borrow::Cow::Borrowed("input_ids"),
input_ids_value.into(),
),
(
std::borrow::Cow::Borrowed("attention_mask"),
attention_mask_value.into(),
),
(
std::borrow::Cow::Borrowed("token_type_ids"),
token_type_ids_value.into(),
),
];
let outputs = session
.run(inputs)
.map_err(|e| OnnxEmbedderError::OrtError(e.to_string()))?;
let output_value = outputs
.get(&output_name)
.ok_or_else(|| OnnxEmbedderError::OrtError("No output tensor".to_string()))?;
// Extracts (shape, data) from the output tensor.
let tensor_data = output_value
.try_extract_tensor::<f32>()
.map_err(|e| OnnxEmbedderError::OrtError(e.to_string()))?;
let shape: Vec<i64> = tensor_data.0.iter().map(|&x| x as i64).collect();
let data: Vec<f32> = tensor_data.1.to_vec();
let embedding = Self::mean_pool_static(
self.dimension,
self.normalize,
&shape,
&data,
&attention_mask,
);
Ok(embedding)
}
/// Attention-mask-weighted mean pooling over the token axis, with
/// optional L2 normalization. Static (no `&self`) so it cannot touch
/// the session lock.
///
/// `shape` is expected to be (batch, seq_len, hidden); anything else
/// falls back to truncating the raw data to `dimension` values.
fn mean_pool_static(
dimension: usize,
normalize: bool,
shape: &[i64],
data: &[f32],
mask: &[i64],
) -> Vec<f32> {
if shape.len() != 3 {
return data.iter().take(dimension).copied().collect();
}
let seq_len = shape[1] as usize;
let hidden_size = shape[2] as usize;
let mut pooled = vec![0.0f32; hidden_size];
let mut total_weight = 0.0f32;
// Sum hidden states of non-masked tokens, weighted by the mask value.
for (t, &m) in mask.iter().enumerate().take(seq_len) {
if m > 0 {
let weight = m as f32;
for h in 0..hidden_size {
let idx = t * hidden_size + h;
if let Some(&val) = data.get(idx) {
pooled[h] += val * weight;
}
}
total_weight += weight;
}
}
if total_weight > 0.0 {
for p in &mut pooled {
*p /= total_weight;
}
}
if normalize {
let norm: f32 = pooled.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
for p in &mut pooled {
*p /= norm;
}
}
}
pooled
}
}
#[cfg(feature = "embeddings-onnx")]
impl Embedder for OnnxEmbedder {
/// Runs model inference. NOTE(review): on ANY inference failure this
/// silently falls back to an all-zero vector, so callers cannot tell
/// "empty text" from "inference failed" — consider logging here.
fn embed(&self, text: &str) -> Vec<f32> {
self.run_inference(text).unwrap_or_else(|_| {
vec![0.0; self.dimension]
})
}
fn dimension(&self) -> usize {
self.dimension
}
}
/// Errors raised while loading or running an `OnnxEmbedder`.
#[cfg(feature = "embeddings-onnx")]
#[derive(Debug, Clone)]
pub enum OnnxEmbedderError {
// model.onnx missing at the given path.
ModelNotFound(String),
// tokenizer.json missing at the given path.
TokenizerNotFound(String),
// Any failure reported by the ONNX runtime (stringified).
OrtError(String),
// Any failure reported by the tokenizer (stringified).
TokenizerError(String),
}
#[cfg(feature = "embeddings-onnx")]
impl std::fmt::Display for OnnxEmbedderError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::ModelNotFound(path) => write!(f, "Model file not found: {}", path),
Self::TokenizerNotFound(path) => write!(f, "Tokenizer file not found: {}", path),
Self::OrtError(msg) => write!(f, "ONNX Runtime error: {}", msg),
Self::TokenizerError(msg) => write!(f, "Tokenizer error: {}", msg),
}
}
}
#[cfg(feature = "embeddings-onnx")]
impl std::error::Error for OnnxEmbedderError {}
/// Hierarchical Navigable Small World index: approximate
/// nearest-neighbour search over layered proximity graphs.
#[cfg(feature = "hnsw")]
#[derive(Debug, Clone)]
pub struct HnswIndex {
// Node storage: (external id, embedding); node index = position here.
vectors: Vec<(usize, Vec<f32>)>,
// layers[l][node] = adjacency list of `node` on layer `l`.
layers: Vec<Vec<Vec<usize>>>,
// Node index used as the search entry point (top of the hierarchy).
entry_point: Option<usize>,
// Max connections per node on layers above 0.
m: usize,
// Max connections on layer 0 (kept at 2*m by `with_m`).
m0: usize,
// Level-generation multiplier, 1/ln(m).
ml: f64,
// Beam width while building the graph.
ef_construction: usize,
// Beam width while querying.
ef_search: usize,
dimension: usize,
// External id -> node index in `vectors`.
id_to_index: std::collections::HashMap<usize, usize>,
// Xorshift PRNG state for level assignment (fixed seed, deterministic).
rng_state: u64,
}
#[cfg(feature = "hnsw")]
impl HnswIndex {
/// Empty index with standard HNSW defaults (m=16, ef_construction=200,
/// ef_search=50) for vectors of length `dimension`.
pub fn new(dimension: usize) -> Self {
Self {
vectors: Vec::new(),
layers: Vec::new(),
entry_point: None,
m: 16, m0: 32, ml: 1.0 / (16.0_f64).ln(),
ef_construction: 200,
ef_search: 50,
dimension,
id_to_index: std::collections::HashMap::new(),
rng_state: 42,
}
}
/// Builder: set max connections per node; m0 and ml are derived.
pub fn with_m(mut self, m: usize) -> Self {
self.m = m;
self.m0 = 2 * m;
self.ml = 1.0 / (m as f64).ln();
self
}
/// Builder: beam width used while constructing the graph.
pub fn with_ef_construction(mut self, ef: usize) -> Self {
self.ef_construction = ef;
self
}
/// Builder: beam width used while searching.
pub fn with_ef_search(mut self, ef: usize) -> Self {
self.ef_search = ef;
self
}
/// Draw a node's top layer from the HNSW exponential distribution,
/// using an inline xorshift64 PRNG; capped at 15 layers.
fn random_level(&mut self) -> usize {
self.rng_state ^= self.rng_state << 13;
self.rng_state ^= self.rng_state >> 7;
self.rng_state ^= self.rng_state << 17;
let random = (self.rng_state as f64) / (u64::MAX as f64);
let level = (-random.ln() * self.ml).floor() as usize;
level.min(15) }
fn get_embedding(&self, idx: usize) -> Option<&[f32]> {
self.vectors.get(idx).map(|(_, v)| v.as_slice())
}
/// Greedy beam search within one layer: expand the closest unvisited
/// candidate until no candidate can improve the worst of the best `ef`.
/// Distance is 1 - cosine similarity; results come back sorted with
/// the similarity restored (idx, 1 - dist).
fn search_layer(
&self,
query: &[f32],
entry_points: &[usize],
ef: usize,
layer: usize,
) -> Vec<(usize, f64)> {
use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashSet};
let mut visited: HashSet<usize> = HashSet::new();
// Min-heap of candidates to expand (closest first, via Reverse).
let mut candidates: BinaryHeap<Reverse<(OrderedFloat, usize)>> = BinaryHeap::new();
// Max-heap of current best results (worst on top, for easy eviction).
let mut results: BinaryHeap<(OrderedFloat, usize)> = BinaryHeap::new();
for &ep in entry_points {
if let Some(emb) = self.get_embedding(ep) {
let dist = 1.0 - cosine_similarity(query, emb);
visited.insert(ep);
candidates.push(Reverse((OrderedFloat(dist), ep)));
results.push((OrderedFloat(dist), ep));
}
}
while let Some(Reverse((OrderedFloat(c_dist), c_idx))) = candidates.pop() {
let f_dist = results.peek().map(|(d, _)| d.0).unwrap_or(f64::MAX);
// Stop once the closest remaining candidate cannot beat the
// current worst kept result.
if c_dist > f_dist && results.len() >= ef {
break;
}
if let Some(neighbors) = self.layers.get(layer).and_then(|l| l.get(c_idx)) {
for &neighbor in neighbors {
if visited.insert(neighbor) {
if let Some(emb) = self.get_embedding(neighbor) {
let dist = 1.0 - cosine_similarity(query, emb);
let f_dist = results.peek().map(|(d, _)| d.0).unwrap_or(f64::MAX);
if dist < f_dist || results.len() < ef {
candidates.push(Reverse((OrderedFloat(dist), neighbor)));
results.push((OrderedFloat(dist), neighbor));
// Keep only the best ef results.
while results.len() > ef {
results.pop();
}
}
}
}
}
}
}
results
.into_sorted_vec()
.into_iter()
.map(|(OrderedFloat(dist), idx)| (idx, 1.0 - dist))
.collect()
}
/// Simple neighbor selection: take the first m candidates (candidates
/// arrive sorted from `search_layer`); no diversity heuristic.
fn select_neighbors(&self, candidates: &[(usize, f64)], m: usize) -> Vec<usize> {
candidates.iter().take(m).map(|(idx, _)| *idx).collect()
}
/// Wire `node_idx` to `neighbors` bidirectionally on `layer`, growing
/// the layer tables as needed and pruning any neighbor that exceeds
/// `max_connections` back down to its closest links.
fn connect_neighbors(
&mut self,
node_idx: usize,
neighbors: &[usize],
layer: usize,
max_connections: usize,
) {
while self.layers.len() <= layer {
self.layers.push(Vec::new());
}
while self.layers[layer].len() <= node_idx {
self.layers[layer].push(Vec::new());
}
self.layers[layer][node_idx] = neighbors.to_vec();
for &neighbor in neighbors {
while self.layers[layer].len() <= neighbor {
self.layers[layer].push(Vec::new());
}
let needs_add = !self.layers[layer][neighbor].contains(&node_idx);
if needs_add {
self.layers[layer][neighbor].push(node_idx);
if self.layers[layer][neighbor].len() > max_connections {
// Over capacity: keep only the most similar connections.
let neighbor_emb = self.vectors.get(neighbor).map(|(_, v)| v.clone());
if let Some(emb) = neighbor_emb {
let current_connections = self.layers[layer][neighbor].clone();
let mut scored: Vec<_> = current_connections
.iter()
.filter_map(|&n| {
self.vectors
.get(n)
.map(|(_, e)| (n, cosine_similarity(&emb, e)))
})
.collect();
scored.sort_by(|a, b| {
b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
});
self.layers[layer][neighbor] = scored
.into_iter()
.take(max_connections)
.map(|(n, _)| n)
.collect();
}
}
}
}
}
}
/// Minimal totally-ordered f64 wrapper so distances can live in a
/// `BinaryHeap`; NaN compares as equal to everything via the fallback.
#[cfg(feature = "hnsw")]
#[derive(Debug, Clone, Copy, PartialEq)]
struct OrderedFloat(f64);
#[cfg(feature = "hnsw")]
impl Eq for OrderedFloat {}
#[cfg(feature = "hnsw")]
impl PartialOrd for OrderedFloat {
// Delegates to `cmp` so the two orderings can never disagree.
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
#[cfg(feature = "hnsw")]
impl Ord for OrderedFloat {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.0
.partial_cmp(&other.0)
.unwrap_or(std::cmp::Ordering::Equal)
}
}
#[cfg(feature = "hnsw")]
impl VectorIndex for HnswIndex {
/// Insert `embedding` under `id`, wiring it into the layer graphs.
/// Wrong-dimension embeddings are silently ignored; an existing id is
/// removed first (upsert semantics).
fn insert(&mut self, id: usize, embedding: &[f32]) {
if embedding.len() != self.dimension {
return;
}
if self.id_to_index.contains_key(&id) {
self.remove(id);
}
let node_idx = self.vectors.len();
self.vectors.push((id, embedding.to_vec()));
self.id_to_index.insert(id, node_idx);
let node_level = self.random_level();
// First node: just allocate its layer slots and become entry point.
if self.entry_point.is_none() {
self.entry_point = Some(node_idx);
for l in 0..=node_level {
while self.layers.len() <= l {
self.layers.push(Vec::new());
}
while self.layers[l].len() <= node_idx {
self.layers[l].push(Vec::new());
}
}
return;
}
let entry = self.entry_point.unwrap();
let max_layer = self.layers.len().saturating_sub(1);
let mut curr_entry = vec![entry];
// Greedy descent through layers above the node's level (beam of 1).
for layer in (node_level + 1..=max_layer).rev() {
let results = self.search_layer(embedding, &curr_entry, 1, layer);
if let Some((best, _)) = results.first() {
curr_entry = vec![*best];
}
}
// On the node's own layers, search wider and connect bidirectionally.
for layer in (0..=node_level.min(max_layer)).rev() {
let max_conn = if layer == 0 { self.m0 } else { self.m };
let candidates = self.search_layer(embedding, &curr_entry, self.ef_construction, layer);
let neighbors = self.select_neighbors(&candidates, max_conn);
self.connect_neighbors(node_idx, &neighbors, layer, max_conn);
curr_entry = neighbors;
}
// A node drawn above every existing layer becomes the new entry point.
if node_level > max_layer {
self.entry_point = Some(node_idx);
}
}
/// Approximate top-k: greedy descent to layer 1, then a beam search of
/// width max(ef_search, k) on layer 0.
fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f64)> {
if self.vectors.is_empty() || self.entry_point.is_none() {
return Vec::new();
}
let entry = self.entry_point.unwrap();
let max_layer = self.layers.len().saturating_sub(1);
let mut curr_entry = vec![entry];
for layer in (1..=max_layer).rev() {
let results = self.search_layer(query, &curr_entry, 1, layer);
if let Some((best, _)) = results.first() {
curr_entry = vec![*best];
}
}
let mut results = self.search_layer(query, &curr_entry, self.ef_search.max(k), 0);
results.truncate(k);
// Map internal node indices back to external ids.
results
.into_iter()
.filter_map(|(idx, score)| self.vectors.get(idx).map(|(id, _)| (*id, score)))
.collect()
}
/// Unlink `id` from every layer. The vector slot itself is left in
/// place as a tombstone (node indices stay stable); `len` counts only
/// live ids via `id_to_index`.
fn remove(&mut self, id: usize) {
if let Some(&idx) = self.id_to_index.get(&id) {
for layer in &mut self.layers {
if idx < layer.len() {
layer[idx].clear();
}
for neighbors in layer.iter_mut() {
neighbors.retain(|&n| n != idx);
}
}
self.id_to_index.remove(&id);
// NOTE(review): the replacement entry point is an arbitrary live
// node (HashMap iteration order), not necessarily a top-layer one.
if self.entry_point == Some(idx) {
self.entry_point = self.id_to_index.values().copied().next();
}
}
}
fn len(&self) -> usize {
self.id_to_index.len()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hash_embedder() {
        // The embedding has the requested dimension and unit L2 norm.
        let embedder = HashEmbedder::new(64);
        let vector = embedder.embed("hello world");
        assert_eq!(vector.len(), 64);
        let magnitude: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((magnitude - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_cosine_similarity() {
        // Identical vectors score 1; orthogonal vectors score 0.
        let unit_x = vec![1.0, 0.0, 0.0];
        let unit_x_copy = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&unit_x, &unit_x_copy) - 1.0).abs() < f64::EPSILON);
        let unit_y = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&unit_x, &unit_y).abs() < f64::EPSILON);
    }

    #[test]
    fn test_memory_add_and_get() {
        let mut store = memory();
        let id = store.add("test content").unwrap();
        assert!(store.get(&id).is_some());
        assert_eq!(store.get(&id).unwrap(), "test content");
    }

    #[test]
    fn test_memory_search() {
        // The JSON document should rank first for a JSON-flavoured query.
        let mut store = memory();
        store
            .add("How to parse JSON in Rust? Use serde_json")
            .unwrap();
        store.add("How to read a file? Use std::fs").unwrap();
        store.add("How to make HTTP requests? Use reqwest").unwrap();
        let hits = store.search("parse JSON Rust", 2).unwrap();
        assert_eq!(hits.len(), 2);
        assert!(hits[0].content.contains("JSON"));
    }

    #[test]
    fn test_memory_update() {
        let mut store = memory();
        let id = store.add("original").unwrap();
        assert!(store.update(&id, "updated").unwrap());
        assert_eq!(store.get(&id).unwrap(), "updated");
    }

    #[test]
    fn test_memory_remove() {
        let mut store = memory();
        let id = store.add("to delete").unwrap();
        assert!(store.remove(&id).unwrap());
        assert!(store.get(&id).is_none());
    }

    #[test]
    fn test_memory_tagged() {
        let mut store = memory();
        store.add_tagged("rust", "Rust content").unwrap();
        store.add_tagged("python", "Python content").unwrap();
        let tags = store.tags().unwrap();
        assert!(tags.contains(&"rust".to_string()));
        assert!(tags.contains(&"python".to_string()));
    }

    #[test]
    fn test_memory_seed_if_empty() {
        let store = memory()
            .seed_if_empty([("What is Rust?", "A systems programming language")])
            .unwrap();
        assert_eq!(store.len().unwrap(), 1);
        let hits = store.search("Rust", 1).unwrap();
        assert!(!hits.is_empty());
    }

    #[test]
    fn test_memory_learn() {
        // Only answers scoring at or above the threshold are retained.
        let mut store = memory().learn_above(0.8);
        store.learn("question", "bad answer", 0.5).unwrap();
        assert_eq!(store.len().unwrap(), 0);
        store.learn("question", "good answer", 0.9).unwrap();
        assert_eq!(store.len().unwrap(), 1);
    }

    #[test]
    fn test_memory_len_and_empty() {
        let mut store = memory();
        assert!(store.is_empty().unwrap());
        assert_eq!(store.len().unwrap(), 0);
        store.add("doc1").unwrap();
        assert!(!store.is_empty().unwrap());
        assert_eq!(store.len().unwrap(), 1);
    }

    #[test]
    fn test_memory_all() {
        let mut store = memory();
        store.add("doc1").unwrap();
        store.add("doc2").unwrap();
        store.add("doc3").unwrap();
        let everything = store.all().unwrap();
        assert_eq!(everything.len(), 3);
    }

    #[test]
    fn test_memory_with_k() {
        let store = memory().with_k(5);
        assert_eq!(store.k(), 5);
    }

    #[test]
    fn test_linear_index() {
        let mut index = LinearIndex::new();
        assert!(index.is_empty());
        index.insert(0, &[1.0, 0.0, 0.0]);
        index.insert(1, &[0.0, 1.0, 0.0]);
        index.insert(2, &[0.5, 0.5, 0.0]);
        assert_eq!(index.len(), 3);
        // The exact-match vector must come back first.
        let hits = index.search(&[1.0, 0.0, 0.0], 2);
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].0, 0);
    }

    #[test]
    fn test_linear_index_remove() {
        let mut index = LinearIndex::new();
        index.insert(0, &[1.0, 0.0]);
        index.insert(1, &[0.0, 1.0]);
        assert_eq!(index.len(), 2);
        index.remove(0);
        assert_eq!(index.len(), 1);
    }

    #[test]
    fn test_mmr_select_basic() {
        // (id, embedding, relevance): two near-duplicates plus an outlier.
        let doc_embeddings = vec![
            (0, vec![1.0, 0.0, 0.0], 0.9),
            (1, vec![0.9, 0.1, 0.0], 0.85),
            (2, vec![0.0, 1.0, 0.0], 0.5),
        ];
        // lambda = 1.0 means pure relevance ordering.
        let selected = mmr_select(&[1.0, 0.0, 0.0], &doc_embeddings, 2, 1.0);
        assert_eq!(selected.len(), 2);
        assert_eq!(selected[0].0, 0);
        // A low lambda favours diversity but still yields k results.
        let selected_diverse = mmr_select(&[1.0, 0.0, 0.0], &doc_embeddings, 2, 0.3);
        assert_eq!(selected_diverse.len(), 2);
    }

    #[test]
    fn test_mmr_select_empty() {
        let selected = mmr_select(&[1.0, 0.0], &[], 3, 0.5);
        assert!(selected.is_empty());
    }

    #[test]
    fn test_memory_diversity() {
        let mut store = memory().diversity(0.5);
        store.add("Document about Rust programming").unwrap();
        store.add("Another document about Rust language").unwrap();
        store.add("Python is great for ML").unwrap();
        let hits = store.search("Rust programming language", 2).unwrap();
        assert_eq!(hits.len(), 2);
    }

    #[test]
    fn test_memory_search_diverse() {
        let mut store = memory();
        store.add("Rust systems programming").unwrap();
        store.add("Rust memory safety").unwrap();
        store.add("Python data science").unwrap();
        let hits = store.search_diverse("programming language", 2, 0.5).unwrap();
        assert_eq!(hits.len(), 2);
    }

    #[cfg(feature = "hnsw")]
    mod hnsw_tests {
        use super::*;

        #[test]
        fn test_hnsw_basic() {
            let mut index = HnswIndex::new(3);
            assert!(index.is_empty());
            index.insert(0, &[1.0, 0.0, 0.0]);
            index.insert(1, &[0.0, 1.0, 0.0]);
            index.insert(2, &[0.0, 0.0, 1.0]);
            assert_eq!(index.len(), 3);
        }

        #[test]
        fn test_hnsw_search() {
            let mut index = HnswIndex::new(3);
            index.insert(0, &[1.0, 0.0, 0.0]);
            index.insert(1, &[0.9, 0.1, 0.0]);
            index.insert(2, &[0.0, 1.0, 0.0]);
            let hits = index.search(&[1.0, 0.0, 0.0], 2);
            assert_eq!(hits.len(), 2);
            assert_eq!(hits[0].0, 0);
        }

        #[test]
        fn test_hnsw_remove() {
            let mut index = HnswIndex::new(3);
            index.insert(0, &[1.0, 0.0, 0.0]);
            index.insert(1, &[0.0, 1.0, 0.0]);
            assert_eq!(index.len(), 2);
            index.remove(0);
            assert_eq!(index.len(), 1);
            // The surviving vector is still reachable after removal.
            let hits = index.search(&[0.0, 1.0, 0.0], 1);
            assert_eq!(hits.len(), 1);
            assert_eq!(hits[0].0, 1);
        }

        #[test]
        fn test_hnsw_update() {
            // Re-inserting an existing id replaces its vector; both ids
            // remain searchable afterwards.
            let mut index = HnswIndex::new(3);
            index.insert(0, &[1.0, 0.0, 0.0]);
            index.insert(1, &[0.0, 1.0, 0.0]);
            index.insert(0, &[0.0, 0.9, 0.1]);
            let hits = index.search(&[0.0, 1.0, 0.0], 2);
            assert!(hits.iter().any(|(id, _)| *id == 0));
            assert!(hits.iter().any(|(id, _)| *id == 1));
        }

        #[test]
        fn test_hnsw_builder_pattern() {
            let index = HnswIndex::new(64)
                .with_m(32)
                .with_ef_construction(400)
                .with_ef_search(100);
            assert_eq!(index.dimension, 64);
            assert_eq!(index.m, 32);
            assert_eq!(index.ef_construction, 400);
            assert_eq!(index.ef_search, 100);
        }

        #[test]
        fn test_hnsw_many_vectors() {
            // Bulk-load 100 sparse vectors, then run a broad query.
            let mut index = HnswIndex::new(16);
            for i in 0..100 {
                let mut v = vec![0.0f32; 16];
                v[i % 16] = 1.0;
                v[(i + 1) % 16] = 0.5;
                index.insert(i, &v);
            }
            assert_eq!(index.len(), 100);
            let query = vec![1.0f32; 16];
            let hits = index.search(&query, 10);
            assert_eq!(hits.len(), 10);
        }

        #[test]
        fn test_hnsw_empty_search() {
            let index = HnswIndex::new(3);
            let hits = index.search(&[1.0, 0.0, 0.0], 5);
            assert!(hits.is_empty());
        }

        #[test]
        fn test_hnsw_wrong_dimension() {
            // A mismatched dimension is rejected; the correct one is accepted.
            let mut index = HnswIndex::new(3);
            index.insert(0, &[1.0, 0.0]);
            assert_eq!(index.len(), 0);
            index.insert(0, &[1.0, 0.0, 0.0]);
            assert_eq!(index.len(), 1);
        }
    }
}