use crate::core::{Blob, Id, PlacedPoint, Point};
use crate::core::config::ArmsConfig;
use crate::ports::{Near, NearResult, Place, PlaceResult, SearchResult};
use crate::adapters::storage::MemoryStorage;
use crate::adapters::index::FlatIndex;
/// Facade that pairs a point-storage backend with a proximity-search backend
/// under a single configuration.
pub struct Arms {
/// Settings read by the methods of this type (dimensionality, the
/// normalize-on-insert flag, and the proximity and merge strategies).
config: ArmsConfig,
/// Backend that holds placed points and their blobs; any `Place` implementation.
storage: Box<dyn Place>,
/// Backend that answers nearest-neighbour and threshold queries; any `Near`
/// implementation.
index: Box<dyn Near>,
}
impl Arms {
    /// Builds an `Arms` on the default in-process backends: a [`MemoryStorage`]
    /// and a [`FlatIndex`], both sized from `config.dimensionality`.
    ///
    /// The index is handed a clone of `config.proximity`, so `config` itself can
    /// be moved into the returned value.
    pub fn new(config: ArmsConfig) -> Self {
        let storage = Box::new(MemoryStorage::new(config.dimensionality));
        // NOTE(review): the meaning of the trailing `true` is defined by
        // `FlatIndex::new` and is not visible from this file; confirm it before
        // changing it.
        let index = Box::new(FlatIndex::new(
            config.dimensionality,
            config.proximity.clone(),
            true,
        ));
        Self {
            config,
            storage,
            index,
        }
    }

    /// Builds an `Arms` from caller-supplied storage and index backends.
    ///
    /// No check is made here that the backends agree with
    /// `config.dimensionality`; that is the caller's responsibility.
    pub fn with_adapters(
        config: ArmsConfig,
        storage: Box<dyn Place>,
        index: Box<dyn Near>,
    ) -> Self {
        Self {
            config,
            storage,
            index,
        }
    }

    /// Returns the configuration this instance was built with.
    pub fn config(&self) -> &ArmsConfig {
        &self.config
    }

    /// Returns `config.dimensionality`.
    pub fn dimensionality(&self) -> usize {
        self.config.dimensionality
    }

    /// Stores `point` together with `blob` and registers the point in the index.
    ///
    /// When `config.normalize_on_insert` is set, the point is normalized first,
    /// and the normalized form is what gets stored and indexed.
    ///
    /// # Errors
    ///
    /// Propagates any error from the storage backend. If the index rejects the
    /// point, the freshly stored entry is removed again and a
    /// `PlaceError::StorageError` carrying the index error's debug text is
    /// returned.
    pub fn place(&mut self, point: Point, blob: Blob) -> PlaceResult<Id> {
        let point = if self.config.normalize_on_insert {
            point.normalize()
        } else {
            point
        };

        // Storage takes ownership of the point, while the index still needs to
        // read it afterwards, hence the clone.
        let id = self.storage.place(point.clone(), blob)?;

        if let Err(e) = self.index.add(id, &point) {
            // Roll back so storage never holds an entry the index does not know.
            self.storage.remove(id);
            return Err(crate::ports::PlaceError::StorageError(format!(
                "Index error: {:?}",
                e
            )));
        }
        Ok(id)
    }

    /// Places every `(point, blob)` pair in order via [`Arms::place`].
    ///
    /// The returned vector has one result per input item, in input order; a
    /// failure for one item does not stop the remaining items.
    pub fn place_batch(&mut self, items: Vec<(Point, Blob)>) -> Vec<PlaceResult<Id>> {
        items
            .into_iter()
            .map(|(point, blob)| self.place(point, blob))
            .collect()
    }

    /// Removes `id` from the index and from storage, returning whatever the
    /// storage backend returns for the removal.
    pub fn remove(&mut self, id: Id) -> Option<PlacedPoint> {
        // Best effort: the outcome of the index removal is deliberately ignored
        // so the storage removal always runs.
        let _ = self.index.remove(id);
        self.storage.remove(id)
    }

    /// Looks up the placed point stored under `id`.
    pub fn get(&self, id: Id) -> Option<&PlacedPoint> {
        self.storage.get(id)
    }

    /// Reports whether storage holds an entry for `id`.
    pub fn contains(&self, id: Id) -> bool {
        self.storage.contains(id)
    }

    /// Number of entries reported by the storage backend.
    pub fn len(&self) -> usize {
        self.storage.len()
    }

    /// Whether the storage backend reports itself as empty.
    pub fn is_empty(&self) -> bool {
        self.storage.is_empty()
    }

    /// Clears storage and then asks the index to rebuild itself.
    ///
    /// The outcome of the rebuild is ignored.
    pub fn clear(&mut self) {
        self.storage.clear();
        // NOTE(review): whether `rebuild()` actually drops the entries already
        // added to the index is not visible from this file; confirm the index
        // cannot keep returning ids that storage no longer holds.
        let _ = self.index.rebuild();
    }

    /// Asks the index for the `k` results nearest to `query`.
    ///
    /// The query is normalized first when `config.normalize_on_insert` is set,
    /// mirroring what [`Arms::place`] does to inserted points.
    ///
    /// # Errors
    ///
    /// Returns whatever error the index backend reports.
    pub fn near(&self, query: &Point, k: usize) -> NearResult<Vec<SearchResult>> {
        self.with_search_query(query, |q| self.index.near(q, k))
    }

    /// Asks the index for the results that satisfy `threshold` relative to
    /// `query`; the threshold is interpreted by the index backend.
    ///
    /// The query is normalized first when `config.normalize_on_insert` is set.
    ///
    /// # Errors
    ///
    /// Returns whatever error the index backend reports.
    pub fn within(&self, query: &Point, threshold: f32) -> NearResult<Vec<SearchResult>> {
        self.with_search_query(query, |q| self.index.within(q, threshold))
    }

    /// Like [`Arms::near`], but pairs each hit's score with the stored
    /// [`PlacedPoint`] instead of the bare id.
    ///
    /// Hits whose id is not found in storage are skipped, so fewer pairs than
    /// the index returned may come back.
    ///
    /// # Errors
    ///
    /// Returns whatever error the index backend reports.
    pub fn near_with_data(&self, query: &Point, k: usize) -> NearResult<Vec<(&PlacedPoint, f32)>> {
        let hits = self.near(query, k)?;
        Ok(hits
            .into_iter()
            .filter_map(|hit| self.storage.get(hit.id).map(|placed| (placed, hit.score)))
            .collect())
    }

    /// Combines `points` into one point using the configured merge strategy.
    pub fn merge(&self, points: &[Point]) -> Point {
        self.config.merge.merge(points)
    }

    /// Scores `a` against `b` using the configured proximity strategy.
    pub fn proximity(&self, a: &Point, b: &Point) -> f32 {
        self.config.proximity.proximity(a, b)
    }

    /// Size in bytes as reported by the storage backend.
    pub fn size_bytes(&self) -> usize {
        self.storage.size_bytes()
    }

    /// Number of entries reported by the index backend.
    pub fn index_len(&self) -> usize {
        self.index.len()
    }

    /// Whether the index backend reports itself ready.
    pub fn is_ready(&self) -> bool {
        self.index.is_ready()
    }

    /// Runs `search` with the query in the form the index should see: a
    /// normalized copy when `config.normalize_on_insert` is set, otherwise the
    /// caller's point itself, borrowed rather than cloned.
    fn with_search_query<T>(&self, query: &Point, search: impl FnOnce(&Point) -> T) -> T {
        if self.config.normalize_on_insert {
            search(&query.normalize())
        } else {
            search(query)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: an `Arms` built from `ArmsConfig::new(3)`.
    fn create_test_arms() -> Arms {
        let config = ArmsConfig::new(3);
        Arms::new(config)
    }

    /// A blob can be read back through the id returned by `place`.
    #[test]
    fn test_arms_place_and_get() {
        let mut store = create_test_arms();
        let id = store
            .place(Point::new(vec![1.0, 0.0, 0.0]), Blob::from_str("test data"))
            .unwrap();

        let stored = store.get(id).unwrap();
        assert_eq!(stored.blob.as_str(), Some("test data"));
    }

    /// `near` honours `k` and ranks the first hit strictly above the second.
    #[test]
    fn test_arms_near() {
        let mut store = create_test_arms();
        let unit_axes = vec![
            (vec![1.0, 0.0, 0.0], "x"),
            (vec![0.0, 1.0, 0.0], "y"),
            (vec![0.0, 0.0, 1.0], "z"),
        ];
        for (dims, label) in unit_axes {
            store.place(Point::new(dims), Blob::from_str(label)).unwrap();
        }

        let probe = Point::new(vec![1.0, 0.0, 0.0]);
        let hits = store.near(&probe, 2).unwrap();
        assert_eq!(hits.len(), 2);
        assert!(hits[0].score > hits[1].score);
    }

    /// `near_with_data` returns the stored entry belonging to the best hit.
    #[test]
    fn test_arms_near_with_data() {
        let mut store = create_test_arms();
        let entries = vec![(vec![1.0, 0.0, 0.0], "x"), (vec![0.0, 1.0, 0.0], "y")];
        for (dims, label) in entries {
            store.place(Point::new(dims), Blob::from_str(label)).unwrap();
        }

        let probe = Point::new(vec![1.0, 0.0, 0.0]);
        let hits = store.near_with_data(&probe, 1).unwrap();
        assert_eq!(hits.len(), 1);

        let (placed, _score) = hits[0];
        assert_eq!(placed.blob.as_str(), Some("x"));
    }

    /// Removing an id makes it disappear from `contains` and from `len`.
    #[test]
    fn test_arms_remove() {
        let mut store = create_test_arms();
        let id = store
            .place(Point::new(vec![1.0, 0.0, 0.0]), Blob::empty())
            .unwrap();
        assert!(store.contains(id));
        assert_eq!(store.len(), 1);

        store.remove(id);
        assert!(!store.contains(id));
        assert_eq!(store.len(), 0);
    }

    /// Merging two unit vectors under the test configuration yields their midpoint.
    #[test]
    fn test_arms_merge() {
        let store = create_test_arms();
        let merged = store.merge(&[
            Point::new(vec![1.0, 0.0, 0.0]),
            Point::new(vec![0.0, 1.0, 0.0]),
        ]);

        let dims = merged.dims();
        assert!((dims[0] - 0.5).abs() < 0.0001);
        assert!((dims[1] - 0.5).abs() < 0.0001);
        assert!((dims[2] - 0.0).abs() < 0.0001);
    }

    /// `clear` leaves the store empty after ten insertions.
    #[test]
    fn test_arms_clear() {
        let mut store = create_test_arms();
        (0..10).for_each(|i| {
            store
                .place(Point::new(vec![i as f32, 0.0, 0.0]), Blob::empty())
                .unwrap();
        });
        assert_eq!(store.len(), 10);

        store.clear();
        assert_eq!(store.len(), 0);
        assert!(store.is_empty());
    }

    /// A non-unit point is stored in normalized form under the test configuration.
    #[test]
    fn test_arms_normalizes_on_insert() {
        let mut store = create_test_arms();
        let id = store
            .place(Point::new(vec![3.0, 4.0, 0.0]), Blob::empty())
            .unwrap();

        assert!(store.get(id).unwrap().point.is_normalized());
    }
}