1use crate::core::{Blob, Id, PlacedPoint, Point};
13use crate::core::config::ArmsConfig;
14use crate::ports::{Near, NearResult, Place, PlaceResult, SearchResult};
15use crate::adapters::storage::MemoryStorage;
16use crate::adapters::index::FlatIndex;
17
/// Facade that ties an [`ArmsConfig`] to a storage adapter and a search index.
///
/// Points and their blobs are kept in `storage`; `index` answers `near`/`within`
/// queries over the same ids.
pub struct Arms {
    /// Configuration the instance was created with; read for dimensionality,
    /// insert-time normalization, proximity and merge behaviour.
    config: ArmsConfig,

    /// Holds the placed points together with their blobs.
    storage: Box<dyn Place>,

    /// Search structure kept in step with `storage` by `place` and `remove`.
    index: Box<dyn Near>,
}
31
32impl Arms {
33 pub fn new(config: ArmsConfig) -> Self {
38 let storage = Box::new(MemoryStorage::new(config.dimensionality));
39 let index = Box::new(FlatIndex::new(
40 config.dimensionality,
41 config.proximity.clone(),
42 true, ));
44
45 Self {
46 config,
47 storage,
48 index,
49 }
50 }
51
52 pub fn with_adapters(
54 config: ArmsConfig,
55 storage: Box<dyn Place>,
56 index: Box<dyn Near>,
57 ) -> Self {
58 Self {
59 config,
60 storage,
61 index,
62 }
63 }
64
65 pub fn config(&self) -> &ArmsConfig {
67 &self.config
68 }
69
70 pub fn dimensionality(&self) -> usize {
72 self.config.dimensionality
73 }
74
75 pub fn place(&mut self, point: Point, blob: Blob) -> PlaceResult<Id> {
84 let point = if self.config.normalize_on_insert {
86 point.normalize()
87 } else {
88 point
89 };
90
91 let id = self.storage.place(point.clone(), blob)?;
93
94 if let Err(e) = self.index.add(id, &point) {
96 self.storage.remove(id);
98 return Err(crate::ports::PlaceError::StorageError(format!(
99 "Index error: {:?}",
100 e
101 )));
102 }
103
104 Ok(id)
105 }
106
107 pub fn place_batch(&mut self, items: Vec<(Point, Blob)>) -> Vec<PlaceResult<Id>> {
109 items
110 .into_iter()
111 .map(|(point, blob)| self.place(point, blob))
112 .collect()
113 }
114
115 pub fn remove(&mut self, id: Id) -> Option<PlacedPoint> {
117 let _ = self.index.remove(id);
119
120 self.storage.remove(id)
122 }
123
124 pub fn get(&self, id: Id) -> Option<&PlacedPoint> {
126 self.storage.get(id)
127 }
128
129 pub fn contains(&self, id: Id) -> bool {
131 self.storage.contains(id)
132 }
133
134 pub fn len(&self) -> usize {
136 self.storage.len()
137 }
138
139 pub fn is_empty(&self) -> bool {
141 self.storage.is_empty()
142 }
143
144 pub fn clear(&mut self) {
146 self.storage.clear();
147 let _ = self.index.rebuild(); }
149
150 pub fn near(&self, query: &Point, k: usize) -> NearResult<Vec<SearchResult>> {
156 let query = if self.config.normalize_on_insert {
158 query.normalize()
159 } else {
160 query.clone()
161 };
162
163 self.index.near(&query, k)
164 }
165
166 pub fn within(&self, query: &Point, threshold: f32) -> NearResult<Vec<SearchResult>> {
168 let query = if self.config.normalize_on_insert {
169 query.normalize()
170 } else {
171 query.clone()
172 };
173
174 self.index.within(&query, threshold)
175 }
176
177 pub fn near_with_data(&self, query: &Point, k: usize) -> NearResult<Vec<(&PlacedPoint, f32)>> {
179 let results = self.near(query, k)?;
180
181 Ok(results
182 .into_iter()
183 .filter_map(|r| self.storage.get(r.id).map(|p| (p, r.score)))
184 .collect())
185 }
186
187 pub fn merge(&self, points: &[Point]) -> Point {
193 self.config.merge.merge(points)
194 }
195
196 pub fn proximity(&self, a: &Point, b: &Point) -> f32 {
198 self.config.proximity.proximity(a, b)
199 }
200
201 pub fn size_bytes(&self) -> usize {
207 self.storage.size_bytes()
208 }
209
210 pub fn index_len(&self) -> usize {
212 self.index.len()
213 }
214
215 pub fn is_ready(&self) -> bool {
217 self.index.is_ready()
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 3-dimensional instance from `ArmsConfig::new(3)`.
    fn create_test_arms() -> Arms {
        Arms::new(ArmsConfig::new(3))
    }

    #[test]
    fn test_arms_place_and_get() {
        let mut arms = create_test_arms();

        let point = Point::new(vec![1.0, 0.0, 0.0]);
        let blob = Blob::from_str("test data");

        let id = arms.place(point, blob).unwrap();

        let retrieved = arms.get(id).unwrap();
        assert_eq!(retrieved.blob.as_str(), Some("test data"));
    }

    #[test]
    fn test_arms_near() {
        let mut arms = create_test_arms();

        arms.place(Point::new(vec![1.0, 0.0, 0.0]), Blob::from_str("x")).unwrap();
        arms.place(Point::new(vec![0.0, 1.0, 0.0]), Blob::from_str("y")).unwrap();
        arms.place(Point::new(vec![0.0, 0.0, 1.0]), Blob::from_str("z")).unwrap();

        let query = Point::new(vec![1.0, 0.0, 0.0]);
        let results = arms.near(&query, 2).unwrap();

        assert_eq!(results.len(), 2);
        assert!(results[0].score > results[1].score);

        // Descending scores alone do not prove the right point ranked first; the best hit
        // must be the stored point identical to the query.
        let best = arms
            .get(results[0].id)
            .expect("top search result should exist in storage");
        assert_eq!(best.blob.as_str(), Some("x"));
    }

    #[test]
    fn test_arms_near_with_data() {
        let mut arms = create_test_arms();

        arms.place(Point::new(vec![1.0, 0.0, 0.0]), Blob::from_str("x")).unwrap();
        arms.place(Point::new(vec![0.0, 1.0, 0.0]), Blob::from_str("y")).unwrap();

        let query = Point::new(vec![1.0, 0.0, 0.0]);
        let results = arms.near_with_data(&query, 1).unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0.blob.as_str(), Some("x"));
    }

    #[test]
    fn test_arms_remove() {
        let mut arms = create_test_arms();

        let id = arms.place(Point::new(vec![1.0, 0.0, 0.0]), Blob::empty()).unwrap();

        assert!(arms.contains(id));
        assert_eq!(arms.len(), 1);

        // The first removal hands back the stored entry; a repeat finds nothing.
        assert!(arms.remove(id).is_some());
        assert!(arms.remove(id).is_none());

        assert!(!arms.contains(id));
        assert_eq!(arms.len(), 0);
    }

    #[test]
    fn test_arms_merge() {
        let arms = create_test_arms();

        let points = vec![
            Point::new(vec![1.0, 0.0, 0.0]),
            Point::new(vec![0.0, 1.0, 0.0]),
        ];

        let merged = arms.merge(&points);

        assert!((merged.dims()[0] - 0.5).abs() < 0.0001);
        assert!((merged.dims()[1] - 0.5).abs() < 0.0001);
        assert!((merged.dims()[2] - 0.0).abs() < 0.0001);
    }

    #[test]
    fn test_arms_clear() {
        let mut arms = create_test_arms();

        let ids: Vec<Id> = (0..10)
            .map(|i| {
                arms.place(Point::new(vec![i as f32, 0.0, 0.0]), Blob::empty())
                    .unwrap()
            })
            .collect();

        assert_eq!(arms.len(), 10);

        arms.clear();

        assert_eq!(arms.len(), 0);
        assert!(arms.is_empty());
        assert!(ids.iter().all(|&id| !arms.contains(id)));
    }

    #[test]
    fn test_arms_normalizes_on_insert() {
        let mut arms = create_test_arms();
        // This test relies on the default configuration enabling insert-time normalization.
        assert!(arms.config().normalize_on_insert);

        let point = Point::new(vec![3.0, 4.0, 0.0]);
        let id = arms.place(point, Blob::empty()).unwrap();

        let retrieved = arms.get(id).unwrap();

        assert!(retrieved.point.is_normalized());
    }
}