1use crate::error::{QueryError, Result};
8use crate::types::QueryResult;
9use manifold::column_family::ColumnFamilyDatabase;
10use manifold_graph::{GraphTable, GraphTableRead};
11use manifold_properties::{PropertyTable, PropertyTableRead, PropertyValue};
12use manifold_vectors::{VectorTable, VectorTableRead};
13use std::path::{Path, PathBuf};
14use uuid::Uuid;
15
16pub struct Database {
32 path: PathBuf,
34
35 cf_db: ColumnFamilyDatabase,
37}
38
39impl Database {
40 pub async fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
54 let path = path.as_ref().to_path_buf();
55
56 if let Some(parent) = path.parent() {
58 if !parent.exists() {
59 std::fs::create_dir_all(parent).map_err(|e| QueryError::ConnectionError {
60 message: format!("Failed to create parent directory: {}", e),
61 })?;
62 }
63 }
64
65 let cf_db = ColumnFamilyDatabase::builder().open(&path).map_err(|e| {
67 QueryError::ConnectionError {
68 message: format!("Failed to open Manifold database: {}", e),
69 }
70 })?;
71
72 Ok(Self { path, cf_db })
73 }
74
75 pub async fn close(self) -> Result<()> {
80 drop(self.cf_db);
82 Ok(())
83 }
84
85 pub fn path(&self) -> &Path {
87 &self.path
88 }
89
90 pub fn collection(&self, name: &str) -> Result<Collection> {
104 let cf =
105 self.cf_db
106 .column_family_or_create(name)
107 .map_err(|e| QueryError::ConnectionError {
108 message: format!("Failed to get column family '{}': {}", name, e),
109 })?;
110
111 Ok(Collection { cf })
112 }
113
114 pub async fn execute_hyperql(&self, _query: &str) -> Result<QueryResult> {
118 Ok(QueryResult {
120 rows: Vec::new(),
121 affected_rows: 0,
122 })
123 }
124
125 pub async fn execute_sql(&self, _query: &str) -> Result<QueryResult> {
127 Ok(QueryResult {
129 rows: Vec::new(),
130 affected_rows: 0,
131 })
132 }
133
134 pub async fn execute_cypher(&self, _query: &str) -> Result<QueryResult> {
136 Ok(QueryResult {
138 rows: Vec::new(),
139 affected_rows: 0,
140 })
141 }
142
143 pub async fn execute_custom(&self, language: &str, query: &str) -> Result<QueryResult> {
145 match language {
146 "hyperql" => self.execute_hyperql(query).await,
147 "sql" => self.execute_sql(query).await,
148 "cypher" => self.execute_cypher(query).await,
149 _ => Err(QueryError::UnsupportedLanguage {
150 language: language.to_string(),
151 }),
152 }
153 }
154}
155
156pub struct Collection {
161 cf: manifold::column_family::ColumnFamily,
162}
163
164impl Collection {
165 pub fn create_entity(&self, id: Uuid, data: serde_json::Value) -> Result<()> {
188 let write_txn = self
189 .cf
190 .begin_write()
191 .map_err(|e| QueryError::ExecutionError {
192 message: format!("Failed to begin transaction: {}", e),
193 })?;
194
195 let mut props = PropertyTable::open(&write_txn, "properties").map_err(|e| {
196 QueryError::ExecutionError {
197 message: format!("Failed to open PropertyTable: {}", e),
198 }
199 })?;
200
201 if let serde_json::Value::Object(map) = data {
203 for (key, value) in map {
204 let prop_value = match value {
205 serde_json::Value::Number(n) if n.is_i64() => {
206 PropertyValue::new_integer(n.as_i64().unwrap())
207 }
208 serde_json::Value::Number(n) if n.is_f64() => {
209 PropertyValue::new_float(n.as_f64().unwrap())
210 }
211 serde_json::Value::Bool(b) => PropertyValue::new_boolean(b),
212 serde_json::Value::String(s) => PropertyValue::new_string(s),
213 serde_json::Value::Null => PropertyValue::new_null(),
214 _ => PropertyValue::new_string(value.to_string()),
216 };
217
218 props.set(&id, key.as_str(), prop_value).map_err(|e| {
219 QueryError::ExecutionError {
220 message: format!("Failed to set property: {}", e),
221 }
222 })?;
223 }
224 }
225
226 drop(props);
227 write_txn.commit().map_err(|e| QueryError::ExecutionError {
228 message: format!("Failed to commit transaction: {}", e),
229 })?;
230
231 Ok(())
232 }
233
234 pub fn get_entity(&self, id: Uuid) -> Result<Option<serde_json::Value>> {
238 let read_txn = self
239 .cf
240 .begin_read()
241 .map_err(|e| QueryError::ExecutionError {
242 message: format!("Failed to begin read transaction: {}", e),
243 })?;
244
245 let props = PropertyTableRead::open(&read_txn, "properties").map_err(|e| {
246 QueryError::ExecutionError {
247 message: format!("Failed to open PropertyTable: {}", e),
248 }
249 })?;
250
251 let mut map = serde_json::Map::new();
253 let properties = props.get_all(&id).map_err(|e| QueryError::ExecutionError {
254 message: format!("Failed to read properties: {}", e),
255 })?;
256
257 for (key, value_guard) in properties {
258 let value = value_guard.value();
259
260 let json_value = if let Some(i) = value.as_integer() {
261 serde_json::Value::Number(i.into())
262 } else if let Some(f) = value.as_float() {
263 serde_json::Number::from_f64(f)
264 .map(serde_json::Value::Number)
265 .unwrap_or(serde_json::Value::Null)
266 } else if let Some(b) = value.as_boolean() {
267 serde_json::Value::Bool(b)
268 } else if let Some(s) = value.as_string() {
269 serde_json::from_str(s).unwrap_or_else(|_| serde_json::Value::String(s.to_string()))
271 } else {
272 serde_json::Value::Null
273 };
274
275 map.insert(key.to_string(), json_value);
276 }
277
278 if map.is_empty() {
279 Ok(None)
280 } else {
281 Ok(Some(serde_json::Value::Object(map)))
282 }
283 }
284
285 pub fn update_entity(&self, id: Uuid, data: serde_json::Value) -> Result<()> {
289 self.create_entity(id, data)
291 }
292
293 pub fn delete_entity(&self, id: Uuid) -> Result<()> {
297 let write_txn = self
298 .cf
299 .begin_write()
300 .map_err(|e| QueryError::ExecutionError {
301 message: format!("Failed to begin transaction: {}", e),
302 })?;
303
304 let mut props = PropertyTable::open(&write_txn, "properties").map_err(|e| {
305 QueryError::ExecutionError {
306 message: format!("Failed to open PropertyTable: {}", e),
307 }
308 })?;
309
310 let read_txn = self
312 .cf
313 .begin_read()
314 .map_err(|e| QueryError::ExecutionError {
315 message: format!("Failed to begin read transaction: {}", e),
316 })?;
317
318 let read_props = PropertyTableRead::open(&read_txn, "properties").map_err(|e| {
319 QueryError::ExecutionError {
320 message: format!("Failed to open PropertyTable: {}", e),
321 }
322 })?;
323
324 let keys_to_delete: Vec<String> = read_props
326 .get_all(&id)
327 .map_err(|e| QueryError::ExecutionError {
328 message: format!("Failed to read properties: {}", e),
329 })?
330 .iter()
331 .map(|(key, _)| key.to_string())
332 .collect();
333
334 drop(read_props);
335 drop(read_txn);
336
337 let keys_refs: Vec<(Uuid, &str)> = keys_to_delete
339 .iter()
340 .map(|key| (id, key.as_str()))
341 .collect();
342
343 props
345 .remove_bulk(&keys_refs)
346 .map_err(|e| QueryError::ExecutionError {
347 message: format!("Failed to delete properties: {}", e),
348 })?;
349
350 drop(props);
351 write_txn.commit().map_err(|e| QueryError::ExecutionError {
352 message: format!("Failed to commit transaction: {}", e),
353 })?;
354
355 Ok(())
356 }
357
358 pub fn add_edge(&self, source: Uuid, edge_type: &str, target: Uuid) -> Result<()> {
362 let write_txn = self
363 .cf
364 .begin_write()
365 .map_err(|e| QueryError::ExecutionError {
366 message: format!("Failed to begin transaction: {}", e),
367 })?;
368
369 let mut graph =
370 GraphTable::open(&write_txn, "edges").map_err(|e| QueryError::ExecutionError {
371 message: format!("Failed to open GraphTable: {}", e),
372 })?;
373
374 graph
375 .add_edge(&source, edge_type, &target, true, 1.0, None)
376 .map_err(|e| QueryError::ExecutionError {
377 message: format!("Failed to add edge: {}", e),
378 })?;
379
380 drop(graph);
381 write_txn.commit().map_err(|e| QueryError::ExecutionError {
382 message: format!("Failed to commit transaction: {}", e),
383 })?;
384
385 Ok(())
386 }
387
388 pub fn get_outgoing_edges(&self, source: Uuid, edge_type: &str) -> Result<Vec<Uuid>> {
392 let read_txn = self
393 .cf
394 .begin_read()
395 .map_err(|e| QueryError::ExecutionError {
396 message: format!("Failed to begin read transaction: {}", e),
397 })?;
398
399 let graph =
400 GraphTableRead::open(&read_txn, "edges").map_err(|e| QueryError::ExecutionError {
401 message: format!("Failed to open GraphTable: {}", e),
402 })?;
403
404 let mut targets = Vec::new();
405 let edges = graph
406 .outgoing_edges(&source)
407 .map_err(|e| QueryError::ExecutionError {
408 message: format!("Failed to read outgoing edges: {}", e),
409 })?;
410
411 for edge_result in edges {
412 let edge = edge_result.map_err(|e| QueryError::ExecutionError {
413 message: format!("Failed to read edge: {}", e),
414 })?;
415
416 if edge.edge_type == edge_type && edge.is_active {
417 targets.push(edge.target);
418 }
419 }
420
421 Ok(targets)
422 }
423
424 pub fn get_incoming_edges(&self, target: Uuid, edge_type: &str) -> Result<Vec<Uuid>> {
428 let read_txn = self
429 .cf
430 .begin_read()
431 .map_err(|e| QueryError::ExecutionError {
432 message: format!("Failed to begin read transaction: {}", e),
433 })?;
434
435 let graph =
436 GraphTableRead::open(&read_txn, "edges").map_err(|e| QueryError::ExecutionError {
437 message: format!("Failed to open GraphTable: {}", e),
438 })?;
439
440 let mut sources = Vec::new();
441 let edges = graph
442 .incoming_edges(&target)
443 .map_err(|e| QueryError::ExecutionError {
444 message: format!("Failed to read incoming edges: {}", e),
445 })?;
446
447 for edge_result in edges {
448 let edge = edge_result.map_err(|e| QueryError::ExecutionError {
449 message: format!("Failed to read edge: {}", e),
450 })?;
451
452 if edge.edge_type == edge_type && edge.is_active {
453 sources.push(edge.source);
454 }
455 }
456
457 Ok(sources)
458 }
459
460 pub fn list_all_ids(&self) -> Result<Vec<Uuid>> {
476 let read_txn = self
477 .cf
478 .begin_read()
479 .map_err(|e| QueryError::ExecutionError {
480 message: format!("Failed to begin read transaction: {}", e),
481 })?;
482
483 let props = PropertyTableRead::open(&read_txn, "properties").map_err(|e| {
484 QueryError::ExecutionError {
485 message: format!("Failed to open PropertyTable: {}", e),
486 }
487 })?;
488
489 let mut entity_ids = std::collections::HashSet::new();
490
491 for result in props.iter().map_err(|e| QueryError::ExecutionError {
493 message: format!("Failed to iterate properties: {}", e),
494 })? {
495 let ((entity_id, _key), _value) = result.map_err(|e| QueryError::ExecutionError {
496 message: format!("Failed to read property: {}", e),
497 })?;
498 entity_ids.insert(entity_id);
499 }
500
501 Ok(entity_ids.into_iter().collect())
502 }
503
504 pub fn list_all_entities(&self) -> Result<Vec<(Uuid, serde_json::Value)>> {
522 let ids = self.list_all_ids()?;
523 let mut entities = Vec::new();
524
525 for id in ids {
526 if let Some(data) = self.get_entity(id)? {
527 entities.push((id, data));
528 }
529 }
530
531 Ok(entities)
532 }
533
534 pub fn vectors<const DIM: usize>(&self, table_name: &str) -> Result<VectorTableWrapper<DIM>> {
555 Ok(VectorTableWrapper {
556 cf: self.cf.clone(),
557 table_name: table_name.to_string(),
558 })
559 }
560}
561
562pub struct VectorTableWrapper<const DIM: usize> {
566 cf: manifold::column_family::ColumnFamily,
567 table_name: String,
568}
569
570impl<const DIM: usize> VectorTableWrapper<DIM> {
571 pub fn insert(&self, id: &Uuid, vector: &[f32]) -> Result<()> {
573 if vector.len() != DIM {
574 return Err(QueryError::ExecutionError {
575 message: format!(
576 "Vector dimension mismatch: expected {}, got {}",
577 DIM,
578 vector.len()
579 ),
580 });
581 }
582
583 let write_txn = self
584 .cf
585 .begin_write()
586 .map_err(|e| QueryError::ExecutionError {
587 message: format!("Failed to begin transaction: {}", e),
588 })?;
589
590 let mut vectors = VectorTable::<DIM>::open(&write_txn, &self.table_name).map_err(|e| {
591 QueryError::ExecutionError {
592 message: format!("Failed to open VectorTable: {}", e),
593 }
594 })?;
595
596 let mut arr = [0.0f32; DIM];
598 arr.copy_from_slice(vector);
599
600 vectors
601 .insert(id, &arr)
602 .map_err(|e| QueryError::ExecutionError {
603 message: format!("Failed to insert vector: {}", e),
604 })?;
605
606 drop(vectors);
607 write_txn.commit().map_err(|e| QueryError::ExecutionError {
608 message: format!("Failed to commit transaction: {}", e),
609 })?;
610
611 Ok(())
612 }
613
614 pub fn get(&self, id: &Uuid) -> Result<Option<Vec<f32>>> {
616 let read_txn = self
617 .cf
618 .begin_read()
619 .map_err(|e| QueryError::ExecutionError {
620 message: format!("Failed to begin read transaction: {}", e),
621 })?;
622
623 let vectors = VectorTableRead::<DIM>::open(&read_txn, &self.table_name).map_err(|e| {
624 QueryError::ExecutionError {
625 message: format!("Failed to open VectorTable: {}", e),
626 }
627 })?;
628
629 let result = vectors.get(id).map_err(|e| QueryError::ExecutionError {
630 message: format!("Failed to read vector: {}", e),
631 })?;
632
633 Ok(result.map(|guard| guard.value().to_vec()))
634 }
635
636 pub fn search_similar(&self, query: &[f32], limit: usize) -> Result<Vec<Uuid>> {
650 if query.len() != DIM {
651 return Err(QueryError::ExecutionError {
652 message: format!(
653 "Query vector dimension mismatch: expected {}, got {}",
654 DIM,
655 query.len()
656 ),
657 });
658 }
659
660 let read_txn = self
661 .cf
662 .begin_read()
663 .map_err(|e| QueryError::ExecutionError {
664 message: format!("Failed to begin read transaction: {}", e),
665 })?;
666
667 let vectors = VectorTableRead::<DIM>::open(&read_txn, &self.table_name).map_err(|e| {
668 QueryError::ExecutionError {
669 message: format!("Failed to open VectorTable: {}", e),
670 }
671 })?;
672
673 let mut similarities: Vec<(Uuid, f32)> = Vec::new();
675
676 let iter = vectors
679 .all_vectors()
680 .map_err(|e| QueryError::ExecutionError {
681 message: format!("Failed to iterate vectors: {}", e),
682 })?;
683
684 for result in iter {
685 let (id, vector_guard) = result.map_err(|e| QueryError::ExecutionError {
686 message: format!("Failed to read vector entry: {}", e),
687 })?;
688
689 let vector = vector_guard.value();
690 let similarity = cosine_similarity(query, vector);
691 similarities.push((id, similarity));
692 }
693
694 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
696
697 let results: Vec<Uuid> = similarities
699 .into_iter()
700 .take(limit)
701 .map(|(id, _)| id)
702 .collect();
703
704 Ok(results)
705 }
706}
707
/// Cosine similarity of two equal-length vectors, clamped to `[-1.0, 1.0]`.
///
/// Returns `0.0` when either vector has zero magnitude (including empty
/// inputs). Sums are accumulated in `f64` and each norm is square-rooted
/// separately. Squaring large `f32` components in `f32` would overflow to
/// infinity and turn the result into `inf / inf = NaN`. NaN components still
/// yield NaN.
///
/// Equal lengths are a precondition, checked only in debug builds. In release
/// builds the extra tail of the longer slice is ignored.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "Vectors must have same length");

    let (dot_product, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, sum_a, sum_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sum_a + x * x, sum_b + y * y)
        },
    );

    let magnitude = norm_a.sqrt() * norm_b.sqrt();
    if magnitude == 0.0 {
        0.0
    } else {
        // Clamp away rounding drift just outside [-1, 1]; NaN passes through.
        (dot_product / magnitude).clamp(-1.0, 1.0) as f32
    }
}
731
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A database file and its `.wal` sibling in the system temp directory.
    ///
    /// Both files are removed when the guard is created (clearing leftovers
    /// from an earlier run) and again when it is dropped, including while a
    /// failing assertion unwinds. Declare the guard before the `Database` it
    /// backs: locals drop in reverse order, so the database is closed before
    /// its files are deleted.
    struct TempDbFiles {
        db: PathBuf,
        wal: PathBuf,
    }

    impl TempDbFiles {
        fn new(file_name: &str) -> Self {
            let db = std::env::temp_dir().join(file_name);
            let wal = db.with_extension("wal");
            let files = Self { db, wal };
            files.remove();
            files
        }

        /// Best-effort removal; a missing file is not an error.
        fn remove(&self) {
            let _ = std::fs::remove_file(&self.db);
            let _ = std::fs::remove_file(&self.wal);
        }
    }

    impl Drop for TempDbFiles {
        fn drop(&mut self) {
            self.remove();
        }
    }

    #[tokio::test]
    async fn test_database_open() {
        let files = TempDbFiles::new("audb_test_manifold_open.manifold");

        let db = Database::open(&files.db).await;
        assert!(db.is_ok());

        let db = db.unwrap();
        assert_eq!(db.path(), files.db.as_path());
    }

    #[tokio::test]
    async fn test_create_and_get_entity() {
        let files = TempDbFiles::new("audb_test_manifold_crud.manifold");

        let db = Database::open(&files.db).await.unwrap();
        let collection = db.collection("users").unwrap();

        let id = Uuid::new_v4();
        let data = json!({
            "name": "Alice",
            "age": 42,
            "active": true
        });

        collection.create_entity(id, data.clone()).unwrap();

        let retrieved = collection.get_entity(id).unwrap();
        assert!(retrieved.is_some());

        let retrieved = retrieved.unwrap();
        assert_eq!(retrieved["name"], "Alice");
        assert_eq!(retrieved["age"], 42);
        assert_eq!(retrieved["active"], true);
    }

    #[tokio::test]
    async fn test_update_entity() {
        let files = TempDbFiles::new("audb_test_manifold_update.manifold");

        let db = Database::open(&files.db).await.unwrap();
        let collection = db.collection("users").unwrap();

        let id = Uuid::new_v4();
        let data = json!({ "name": "Alice", "age": 42 });

        collection.create_entity(id, data).unwrap();

        let updated = json!({ "name": "Alice", "age": 43 });
        collection.update_entity(id, updated).unwrap();

        let retrieved = collection.get_entity(id).unwrap().unwrap();
        assert_eq!(retrieved["age"], 43);
    }

    #[tokio::test]
    async fn test_delete_entity() {
        let files = TempDbFiles::new("audb_test_manifold_delete.manifold");

        let db = Database::open(&files.db).await.unwrap();
        let collection = db.collection("users").unwrap();

        let id = Uuid::new_v4();
        let data = json!({ "name": "Bob" });

        collection.create_entity(id, data).unwrap();
        assert!(collection.get_entity(id).unwrap().is_some());

        collection.delete_entity(id).unwrap();
        assert!(collection.get_entity(id).unwrap().is_none());
    }

    #[tokio::test]
    async fn test_edges() {
        let files = TempDbFiles::new("audb_test_manifold_edges.manifold");

        let db = Database::open(&files.db).await.unwrap();
        let collection = db.collection("test").unwrap();

        let user = Uuid::new_v4();
        let post1 = Uuid::new_v4();
        let post2 = Uuid::new_v4();

        collection.add_edge(post1, "authored_by", user).unwrap();
        collection.add_edge(post2, "authored_by", user).unwrap();

        let authors1 = collection.get_outgoing_edges(post1, "authored_by").unwrap();
        assert_eq!(authors1, vec![user]);

        let authored_posts = collection.get_incoming_edges(user, "authored_by").unwrap();
        assert_eq!(authored_posts.len(), 2);
        assert!(authored_posts.contains(&post1));
        assert!(authored_posts.contains(&post2));
    }

    #[tokio::test]
    async fn test_vectors() {
        let files = TempDbFiles::new("audb_test_manifold_vectors.manifold");

        let db = Database::open(&files.db).await.unwrap();
        let collection = db.collection("documents").unwrap();

        let id = Uuid::new_v4();
        let embedding = vec![0.1f32, 0.2, 0.3, 0.4];
        let vectors = collection.vectors::<4>("embeddings").unwrap();

        vectors.insert(&id, &embedding).unwrap();

        let retrieved = vectors.get(&id).unwrap();
        assert!(retrieved.is_some());
        let retrieved_vec = retrieved.unwrap();
        assert_eq!(retrieved_vec.len(), 4);
        assert_eq!(retrieved_vec, embedding);

        // A vector of the wrong dimension must be rejected.
        let wrong_dim = vec![0.1f32; 8];
        let result = vectors.insert(&id, &wrong_dim);
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_vector_similarity_search() {
        let files = TempDbFiles::new("audb_test_manifold_similarity.manifold");

        let db = Database::open(&files.db).await.unwrap();
        let collection = db.collection("documents").unwrap();

        let vectors = collection.vectors::<4>("embeddings").unwrap();

        let id1 = Uuid::new_v4();
        let id2 = Uuid::new_v4();
        let id3 = Uuid::new_v4();

        let vec1 = vec![1.0f32, 0.0, 0.0, 0.0];
        let vec2 = vec![0.7f32, 0.7, 0.0, 0.0];
        let vec3 = vec![0.0f32, 0.0, 1.0, 0.0];

        vectors.insert(&id1, &vec1).unwrap();
        vectors.insert(&id2, &vec2).unwrap();
        vectors.insert(&id3, &vec3).unwrap();

        // Closest to vec1, then vec2; vec3 is orthogonal.
        let query = vec![0.9f32, 0.1, 0.0, 0.0];
        let results = vectors.search_similar(&query, 2).unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0], id1);
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0f32, 0.0, 0.0];
        let b = vec![1.0f32, 0.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim - 1.0).abs() < 0.001);

        let a = vec![1.0f32, 0.0, 0.0];
        let b = vec![0.0f32, 1.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 0.001);

        let a = vec![1.0f32, 0.0, 0.0];
        let b = vec![-1.0f32, 0.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim + 1.0).abs() < 0.001);

        let a = vec![1.0f32, 0.5, 0.0];
        let b = vec![0.9f32, 0.4, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim > 0.99);
    }
}