use std::collections::HashMap;
use std::sync::Arc;
/// Immutable in-memory store of embeddings addressable by numeric id or by path.
///
/// Rows keep the order they were supplied in to [`EmbeddingStore::from_data`]. Two hash
/// indexes map each path and each id to a row position, so single-row lookups do not scan.
pub struct EmbeddingStore {
    /// Rows as `(id, path, embedding)`, in the order given to `from_data`.
    data: Vec<(i64, String, Vec<f32>)>,
    /// Path -> position in `data`; when paths repeat, the last row with that path wins.
    path_index: HashMap<String, usize>,
    /// Id -> position in `data`; when ids repeat, the last row with that id wins.
    id_index: HashMap<i64, usize>,
}
impl EmbeddingStore {
    /// Builds a store from `(id, path, embedding)` rows and wraps it in an [`Arc`] for sharing.
    ///
    /// Rows are kept in the given order. If several rows share a path (or an id), keyed
    /// lookups resolve to the **last** such row; earlier rows stay visible through
    /// [`all`](Self::all) and [`iter_arc`](Self::iter_arc). Embedding lengths are not
    /// validated against each other.
    pub fn from_data(data: Vec<(i64, String, Vec<f32>)>) -> Arc<Self> {
        // Collecting into a HashMap inserts in iteration order, so a later duplicate key
        // overwrites the position recorded for an earlier one.
        let path_index: HashMap<String, usize> = data
            .iter()
            .enumerate()
            .map(|(i, (_, path, _))| (path.clone(), i))
            .collect();
        let id_index: HashMap<i64, usize> = data
            .iter()
            .enumerate()
            .map(|(i, (id, _, _))| (*id, i))
            .collect();
        Arc::new(Self { data, path_index, id_index })
    }

    /// Returns every stored row as `(id, path, embedding)`, in insertion order.
    pub fn all(&self) -> &[(i64, String, Vec<f32>)] {
        &self.data
    }

    /// Borrows the embedding stored under `path` without copying it.
    ///
    /// Returns `None` if no row has that path.
    pub fn embedding_by_path(&self, path: &str) -> Option<&[f32]> {
        // Indices come from `from_data` over the same, never-mutated `data`, so they are in bounds.
        self.path_index
            .get(path)
            .map(|&i| self.data[i].2.as_slice())
    }

    /// Borrows the embedding stored under `id` without copying it.
    ///
    /// Returns `None` if no row has that id.
    pub fn embedding_by_id(&self, id: i64) -> Option<&[f32]> {
        // Indices come from `from_data` over the same, never-mutated `data`, so they are in bounds.
        self.id_index.get(&id).map(|&i| self.data[i].2.as_slice())
    }

    /// Returns an owned copy of the embedding stored under `path`, or `None` if absent.
    #[deprecated(
        since = "0.1.4",
        note = "clones the embedding; use embedding_by_path to borrow it, or get_arc_by_path for a shareable copy"
    )]
    pub fn get_by_path(&self, path: &str) -> Option<Vec<f32>> {
        self.embedding_by_path(path).map(<[f32]>::to_vec)
    }

    /// Returns an owned copy of the embedding stored under `id`, or `None` if absent.
    #[deprecated(
        since = "0.1.4",
        note = "clones the embedding; use embedding_by_id to borrow it, or get_arc_by_id for a shareable copy"
    )]
    pub fn get_by_id(&self, id: i64) -> Option<Vec<f32>> {
        self.embedding_by_id(id).map(<[f32]>::to_vec)
    }

    /// Returns a reference-counted copy of the embedding stored under `path`.
    ///
    /// Rows are stored as `Vec<f32>`, so each call allocates a fresh `Arc<[f32]>` and copies
    /// the values into it. Cloning the *returned* `Arc` is cheap, but obtaining it is not.
    /// Prefer [`embedding_by_path`](Self::embedding_by_path) when a borrow suffices.
    pub fn get_arc_by_path(&self, path: &str) -> Option<Arc<[f32]>> {
        self.embedding_by_path(path).map(Arc::from)
    }

    /// Returns a reference-counted copy of the embedding stored under `id`.
    ///
    /// Rows are stored as `Vec<f32>`, so each call allocates a fresh `Arc<[f32]>` and copies
    /// the values into it. Cloning the *returned* `Arc` is cheap, but obtaining it is not.
    /// Prefer [`embedding_by_id`](Self::embedding_by_id) when a borrow suffices.
    pub fn get_arc_by_id(&self, id: i64) -> Option<Arc<[f32]>> {
        self.embedding_by_id(id).map(Arc::from)
    }

    /// Iterates over all rows in insertion order as `(id, path, embedding)`.
    ///
    /// Each yielded embedding is a freshly allocated `Arc<[f32]>` copy of the stored row.
    /// Use [`all`](Self::all) to iterate without copying.
    pub fn iter_arc(&self) -> impl Iterator<Item = (i64, &str, Arc<[f32]>)> + '_ {
        self.data
            .iter()
            .map(|(id, path, emb)| (*id, path.as_str(), Arc::from(emb.as_slice())))
    }

    /// Number of stored rows, counting rows whose path or id duplicates an earlier one.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Returns `true` when the store holds no rows.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Length of the first row's embedding, or `0` for an empty store.
    ///
    /// Only the first row is inspected; `from_data` does not check that every embedding
    /// has the same length.
    pub fn dim(&self) -> usize {
        self.data.first().map(|(_, _, emb)| emb.len()).unwrap_or(0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-row fixture: ids 1 and 2, paths "a.md" and "b.md", 3-dimensional unit vectors.
    fn build_fixture() -> Arc<EmbeddingStore> {
        let rows: Vec<(i64, String, Vec<f32>)> = vec![
            (1, String::from("a.md"), vec![1.0, 0.0, 0.0]),
            (2, String::from("b.md"), vec![0.0, 1.0, 0.0]),
        ];
        EmbeddingStore::from_data(rows)
    }

    /// True when `actual` lies within 1e-6 of `expected`.
    fn close_to(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < 1e-6
    }

    #[test]
    fn test_store_from_data() {
        let fixture = build_fixture();
        assert!(!fixture.is_empty());
        assert_eq!(fixture.len(), 2);
        assert_eq!(fixture.dim(), 3);

        // The cloning getters are deprecated but still public, so they stay covered here.
        #[allow(deprecated)]
        {
            assert!(fixture.get_by_id(1).is_some());
            assert!(fixture.get_by_id(999).is_none());
            assert!(fixture.get_by_path("a.md").is_some());
            assert!(fixture.get_by_path("nonexistent.md").is_none());
        }
    }

    #[test]
    fn test_get_arc_by_path() {
        let fixture = build_fixture();
        assert!(fixture.get_arc_by_path("missing.md").is_none());
        let emb = fixture.get_arc_by_path("a.md").expect("a.md should exist");
        assert_eq!(emb.len(), 3);
        assert!(close_to(emb[0], 1.0));
    }

    #[test]
    fn test_get_arc_by_id() {
        let fixture = build_fixture();
        assert!(fixture.get_arc_by_id(999).is_none());
        let emb = fixture.get_arc_by_id(2).expect("id 2 should exist");
        assert_eq!(emb.len(), 3);
        assert!(close_to(emb[1], 1.0));
    }

    #[test]
    fn test_iter_arc() {
        let fixture = build_fixture();
        let collected: Vec<(i64, &str, Arc<[f32]>)> = fixture.iter_arc().collect();
        assert_eq!(collected.len(), 2);

        let (first_id, first_path, first_emb) = &collected[0];
        assert_eq!(*first_id, 1);
        assert_eq!(*first_path, "a.md");
        assert_eq!(first_emb.len(), 3);
    }
}