arms/engine/
arms.rs

//! # Arms Engine
//!
//! The main ARMS orchestrator.
//!
//! The `Arms` struct wires together:
//! - Storage (Place port)
//! - Index (Near port)
//! - Configuration
//!
//! and exposes a unified API for storing and retrieving points.
11
12use crate::core::{Blob, Id, PlacedPoint, Point};
13use crate::core::config::ArmsConfig;
14use crate::ports::{Near, NearResult, Place, PlaceResult, SearchResult};
15use crate::adapters::storage::MemoryStorage;
16use crate::adapters::index::FlatIndex;
17
18/// The main ARMS engine
19///
20/// Orchestrates storage and indexing with a unified API.
21pub struct Arms {
22    /// Configuration
23    config: ArmsConfig,
24
25    /// Storage backend (Place port)
26    storage: Box<dyn Place>,
27
28    /// Index backend (Near port)
29    index: Box<dyn Near>,
30}
31
32impl Arms {
33    /// Create a new ARMS instance with default adapters
34    ///
35    /// Uses MemoryStorage and FlatIndex.
36    /// For production, use `Arms::with_adapters` with appropriate backends.
37    pub fn new(config: ArmsConfig) -> Self {
38        let storage = Box::new(MemoryStorage::new(config.dimensionality));
39        let index = Box::new(FlatIndex::new(
40            config.dimensionality,
41            config.proximity.clone(),
42            true, // Assuming cosine-like similarity by default
43        ));
44
45        Self {
46            config,
47            storage,
48            index,
49        }
50    }
51
52    /// Create with custom adapters
53    pub fn with_adapters(
54        config: ArmsConfig,
55        storage: Box<dyn Place>,
56        index: Box<dyn Near>,
57    ) -> Self {
58        Self {
59            config,
60            storage,
61            index,
62        }
63    }
64
65    /// Get the configuration
66    pub fn config(&self) -> &ArmsConfig {
67        &self.config
68    }
69
70    /// Get the dimensionality of this space
71    pub fn dimensionality(&self) -> usize {
72        self.config.dimensionality
73    }
74
75    // ========================================================================
76    // PLACE OPERATIONS
77    // ========================================================================
78
79    /// Place a point in the space
80    ///
81    /// The point will be normalized if configured to do so.
82    /// Returns the assigned ID.
83    pub fn place(&mut self, point: Point, blob: Blob) -> PlaceResult<Id> {
84        // Normalize if configured
85        let point = if self.config.normalize_on_insert {
86            point.normalize()
87        } else {
88            point
89        };
90
91        // Store in storage
92        let id = self.storage.place(point.clone(), blob)?;
93
94        // Add to index
95        if let Err(e) = self.index.add(id, &point) {
96            // Rollback storage if index fails
97            self.storage.remove(id);
98            return Err(crate::ports::PlaceError::StorageError(format!(
99                "Index error: {:?}",
100                e
101            )));
102        }
103
104        Ok(id)
105    }
106
107    /// Place multiple points at once
108    pub fn place_batch(&mut self, items: Vec<(Point, Blob)>) -> Vec<PlaceResult<Id>> {
109        items
110            .into_iter()
111            .map(|(point, blob)| self.place(point, blob))
112            .collect()
113    }
114
115    /// Remove a point from the space
116    pub fn remove(&mut self, id: Id) -> Option<PlacedPoint> {
117        // Remove from index first
118        let _ = self.index.remove(id);
119
120        // Then from storage
121        self.storage.remove(id)
122    }
123
124    /// Get a point by ID
125    pub fn get(&self, id: Id) -> Option<&PlacedPoint> {
126        self.storage.get(id)
127    }
128
129    /// Check if a point exists
130    pub fn contains(&self, id: Id) -> bool {
131        self.storage.contains(id)
132    }
133
134    /// Get the number of stored points
135    pub fn len(&self) -> usize {
136        self.storage.len()
137    }
138
139    /// Check if the space is empty
140    pub fn is_empty(&self) -> bool {
141        self.storage.is_empty()
142    }
143
144    /// Clear all points
145    pub fn clear(&mut self) {
146        self.storage.clear();
147        let _ = self.index.rebuild(); // Reset index
148    }
149
150    // ========================================================================
151    // NEAR OPERATIONS
152    // ========================================================================
153
154    /// Find k nearest points to query
155    pub fn near(&self, query: &Point, k: usize) -> NearResult<Vec<SearchResult>> {
156        // Normalize query if configured
157        let query = if self.config.normalize_on_insert {
158            query.normalize()
159        } else {
160            query.clone()
161        };
162
163        self.index.near(&query, k)
164    }
165
166    /// Find all points within threshold
167    pub fn within(&self, query: &Point, threshold: f32) -> NearResult<Vec<SearchResult>> {
168        let query = if self.config.normalize_on_insert {
169            query.normalize()
170        } else {
171            query.clone()
172        };
173
174        self.index.within(&query, threshold)
175    }
176
177    /// Find and retrieve k nearest points (with full data)
178    pub fn near_with_data(&self, query: &Point, k: usize) -> NearResult<Vec<(&PlacedPoint, f32)>> {
179        let results = self.near(query, k)?;
180
181        Ok(results
182            .into_iter()
183            .filter_map(|r| self.storage.get(r.id).map(|p| (p, r.score)))
184            .collect())
185    }
186
187    // ========================================================================
188    // MERGE OPERATIONS
189    // ========================================================================
190
191    /// Merge multiple points into one using the configured merge function
192    pub fn merge(&self, points: &[Point]) -> Point {
193        self.config.merge.merge(points)
194    }
195
196    /// Compute proximity between two points
197    pub fn proximity(&self, a: &Point, b: &Point) -> f32 {
198        self.config.proximity.proximity(a, b)
199    }
200
201    // ========================================================================
202    // STATS
203    // ========================================================================
204
205    /// Get storage size in bytes
206    pub fn size_bytes(&self) -> usize {
207        self.storage.size_bytes()
208    }
209
210    /// Get index stats
211    pub fn index_len(&self) -> usize {
212        self.index.len()
213    }
214
215    /// Check if index is ready
216    pub fn is_ready(&self) -> bool {
217        self.index.is_ready()
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 3-dimensional engine on the default adapters.
    fn create_test_arms() -> Arms {
        Arms::new(ArmsConfig::new(3))
    }

    /// Places every `(coordinates, label)` pair in order, panicking on failure.
    fn place_labelled(arms: &mut Arms, items: &[([f32; 3], &str)]) {
        for &(coords, label) in items {
            arms.place(Point::new(coords.to_vec()), Blob::from_str(label))
                .unwrap();
        }
    }

    /// A placed blob can be read back through the returned ID.
    #[test]
    fn test_arms_place_and_get() {
        let mut arms = create_test_arms();

        let id = arms
            .place(Point::new(vec![1.0, 0.0, 0.0]), Blob::from_str("test data"))
            .unwrap();

        let stored = arms.get(id).unwrap();
        assert_eq!(stored.blob.as_str(), Some("test data"));
    }

    /// `near` returns `k` hits with the best score first.
    #[test]
    fn test_arms_near() {
        let mut arms = create_test_arms();

        // One point per axis
        place_labelled(
            &mut arms,
            &[
                ([1.0, 0.0, 0.0], "x"),
                ([0.0, 1.0, 0.0], "y"),
                ([0.0, 0.0, 1.0], "z"),
            ],
        );

        // Query along the x axis
        let hits = arms.near(&Point::new(vec![1.0, 0.0, 0.0]), 2).unwrap();

        assert_eq!(hits.len(), 2);
        // First result should have highest similarity
        assert!(hits[0].score > hits[1].score);
    }

    /// `near_with_data` hands back the stored blob of the closest point.
    #[test]
    fn test_arms_near_with_data() {
        let mut arms = create_test_arms();

        place_labelled(
            &mut arms,
            &[([1.0, 0.0, 0.0], "x"), ([0.0, 1.0, 0.0], "y")],
        );

        let hits = arms
            .near_with_data(&Point::new(vec![1.0, 0.0, 0.0]), 1)
            .unwrap();

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].0.blob.as_str(), Some("x"));
    }

    /// Removing a point makes it disappear from `contains` and `len`.
    #[test]
    fn test_arms_remove() {
        let mut arms = create_test_arms();

        let id = arms
            .place(Point::new(vec![1.0, 0.0, 0.0]), Blob::empty())
            .unwrap();

        assert!(arms.contains(id));
        assert_eq!(arms.len(), 1);

        arms.remove(id);

        assert!(!arms.contains(id));
        assert_eq!(arms.len(), 0);
    }

    /// The default merge of two axis vectors is their component-wise mean.
    #[test]
    fn test_arms_merge() {
        let arms = create_test_arms();

        let merged = arms.merge(&[
            Point::new(vec![1.0, 0.0, 0.0]),
            Point::new(vec![0.0, 1.0, 0.0]),
        ]);

        // Mean of [1,0,0] and [0,1,0] = [0.5, 0.5, 0]
        let expected: [f32; 3] = [0.5, 0.5, 0.0];
        for (axis, want) in expected.iter().copied().enumerate() {
            assert!((merged.dims()[axis] - want).abs() < 0.0001);
        }
    }

    /// `clear` leaves the engine reporting zero stored points.
    #[test]
    fn test_arms_clear() {
        let mut arms = create_test_arms();

        // NOTE(review): i == 0 inserts the zero vector, which goes through
        // normalize-on-insert if the default config enables it - confirm that
        // `Point::normalize` handles a zero magnitude.
        (0..10).for_each(|i| {
            arms.place(Point::new(vec![i as f32, 0.0, 0.0]), Blob::empty())
                .unwrap();
        });

        assert_eq!(arms.len(), 10);

        arms.clear();

        assert_eq!(arms.len(), 0);
        assert!(arms.is_empty());
    }

    /// With the default config, a stored point comes back normalized.
    #[test]
    fn test_arms_normalizes_on_insert() {
        let mut arms = create_test_arms();

        // [3, 4, 0] has magnitude 5, so it is not normalized going in
        let id = arms
            .place(Point::new(vec![3.0, 4.0, 0.0]), Blob::empty())
            .unwrap();

        let stored = arms.get(id).unwrap();

        // Should be normalized
        assert!(stored.point.is_normalized());
    }
}