1use crate::{EmbeddingModel, Vector};
41use anyhow::{anyhow, Result};
42use serde::{Deserialize, Serialize};
43use std::collections::{HashMap, HashSet};
44use std::sync::Arc;
45use tokio::sync::RwLock;
46use tracing::{debug, info, trace};
47
/// Tuning parameters for the vector-backed SPARQL extension functions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SparqlExtensionConfig {
    /// Similarity cutoff used by `vec_nearest` when the caller supplies none.
    pub default_similarity_threshold: f32,
    /// Maximum number of expansions collected per query entity/relation.
    pub max_expansions_per_element: usize,
    /// When true, `expand_query_semantically` also rewrites the query text.
    pub enable_query_rewriting: bool,
    /// When true, similarity results are memoized in an in-memory cache.
    pub enable_semantic_caching: bool,
    /// Maximum number of entries kept in the semantic cache.
    pub semantic_cache_size: usize,
    /// When true, `fuzzy_match_entity` performs string-based matching.
    pub enable_fuzzy_matching: bool,
    /// Minimum similarity an expansion candidate needs to be kept.
    pub min_expansion_confidence: f32,
    /// When true, entity similarity scans are computed in parallel via rayon.
    pub enable_parallel_processing: bool,
}
68
69impl Default for SparqlExtensionConfig {
70 fn default() -> Self {
71 Self {
72 default_similarity_threshold: 0.7,
73 max_expansions_per_element: 10,
74 enable_query_rewriting: true,
75 enable_semantic_caching: true,
76 semantic_cache_size: 1000,
77 enable_fuzzy_matching: true,
78 min_expansion_confidence: 0.6,
79 enable_parallel_processing: true,
80 }
81 }
82}
83
/// SPARQL extension functions backed by a knowledge-graph embedding model.
pub struct SparqlExtension {
    /// The embedding model, behind an async `RwLock` so it can be mutated
    /// (e.g. trained — see the tests) while normal queries share read access.
    model: Arc<RwLock<Box<dyn EmbeddingModel>>>,
    /// Tuning parameters for the extension functions.
    config: SparqlExtensionConfig,
    /// Memoized similarity results, consulted when caching is enabled.
    semantic_cache: Arc<RwLock<SemanticCache>>,
    /// Running counters for the different query kinds.
    query_statistics: Arc<RwLock<QueryStatistics>>,
}
91
92impl SparqlExtension {
93 pub fn new(model: Box<dyn EmbeddingModel>) -> Self {
95 Self {
96 model: Arc::new(RwLock::new(model)),
97 config: SparqlExtensionConfig::default(),
98 semantic_cache: Arc::new(RwLock::new(SemanticCache::new(1000))),
99 query_statistics: Arc::new(RwLock::new(QueryStatistics::default())),
100 }
101 }
102
103 pub fn with_config(model: Box<dyn EmbeddingModel>, config: SparqlExtensionConfig) -> Self {
105 let cache_size = config.semantic_cache_size;
106 Self {
107 model: Arc::new(RwLock::new(model)),
108 config,
109 semantic_cache: Arc::new(RwLock::new(SemanticCache::new(cache_size))),
110 query_statistics: Arc::new(RwLock::new(QueryStatistics::default())),
111 }
112 }
113
114 pub async fn vec_similarity(&self, entity1: &str, entity2: &str) -> Result<f32> {
123 trace!("Computing similarity between {} and {}", entity1, entity2);
124
125 if self.config.enable_semantic_caching {
127 let cache = self.semantic_cache.read().await;
128 let cache_key = format!("sim:{}:{}", entity1, entity2);
129 if let Some(cached_result) = cache.get(&cache_key) {
130 debug!("Cache hit for similarity computation");
131 return Ok(cached_result);
132 }
133 }
134
135 let model = self.model.read().await;
136 let emb1 = model.get_entity_embedding(entity1)?;
137 let emb2 = model.get_entity_embedding(entity2)?;
138
139 let similarity = normalized_cosine_similarity(&emb1, &emb2)?;
140
141 if self.config.enable_semantic_caching {
143 let mut cache = self.semantic_cache.write().await;
144 let cache_key = format!("sim:{}:{}", entity1, entity2);
145 cache.put(cache_key, similarity);
146 }
147
148 let mut stats = self.query_statistics.write().await;
150 stats.similarity_computations += 1;
151
152 Ok(similarity)
153 }
154
155 pub async fn vec_nearest(
165 &self,
166 entity: &str,
167 k: usize,
168 min_similarity: Option<f32>,
169 ) -> Result<Vec<(String, f32)>> {
170 info!("Finding {} nearest neighbors for {}", k, entity);
171
172 let model = self.model.read().await;
173 let target_emb = model.get_entity_embedding(entity)?;
174 let all_entities = model.get_entities();
175
176 let threshold = min_similarity.unwrap_or(self.config.default_similarity_threshold);
177
178 let mut similarities: Vec<(String, f32)> = if self.config.enable_parallel_processing {
180 self.compute_similarities_parallel(&all_entities, &target_emb, entity)
181 .await?
182 } else {
183 self.compute_similarities_sequential(&all_entities, &target_emb, entity, &**model)
184 .await?
185 };
186
187 similarities.retain(|(_, sim)| *sim >= threshold);
189 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
190
191 let result: Vec<(String, f32)> = similarities.into_iter().take(k).collect();
193
194 let mut stats = self.query_statistics.write().await;
196 stats.nearest_neighbor_queries += 1;
197
198 Ok(result)
199 }
200
201 pub async fn vec_similar_entities(
210 &self,
211 entity: &str,
212 threshold: f32,
213 ) -> Result<Vec<(String, f32)>> {
214 debug!(
215 "Finding entities similar to {} (threshold: {})",
216 entity, threshold
217 );
218
219 let model = self.model.read().await;
220 let target_emb = model.get_entity_embedding(entity)?;
221 let all_entities = model.get_entities();
222
223 let similarities = if self.config.enable_parallel_processing {
224 self.compute_similarities_parallel(&all_entities, &target_emb, entity)
225 .await?
226 } else {
227 self.compute_similarities_sequential(&all_entities, &target_emb, entity, &**model)
228 .await?
229 };
230
231 let result: Vec<(String, f32)> = similarities
232 .into_iter()
233 .filter(|(_, sim)| *sim >= threshold)
234 .collect();
235
236 Ok(result)
237 }
238
239 pub async fn vec_similar_relations(
248 &self,
249 relation: &str,
250 threshold: f32,
251 ) -> Result<Vec<(String, f32)>> {
252 debug!(
253 "Finding relations similar to {} (threshold: {})",
254 relation, threshold
255 );
256
257 let model = self.model.read().await;
258 let target_emb = model.get_relation_embedding(relation)?;
259 let all_relations = model.get_relations();
260
261 let mut similarities = Vec::new();
262 for rel in &all_relations {
263 if rel == relation {
264 continue; }
266
267 let rel_emb = model.get_relation_embedding(rel)?;
268 let sim = cosine_similarity(&target_emb, &rel_emb)?;
269
270 if sim >= threshold {
271 similarities.push((rel.clone(), sim));
272 }
273 }
274
275 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
276
277 Ok(similarities)
278 }
279
280 pub async fn expand_query_semantically(&self, query: &str) -> Result<ExpandedQuery> {
288 info!("Performing semantic query expansion");
289
290 let mut stats = self.query_statistics.write().await;
291 stats.query_expansions += 1;
292 drop(stats);
293
294 let model = self.model.read().await;
295
296 let parsed = parse_sparql_query(query)?;
298
299 let mut entity_expansions = HashMap::new();
300 let mut relation_expansions = HashMap::new();
301
302 for entity in &parsed.entities {
304 let similar = self
305 .vec_similar_entities(entity, self.config.min_expansion_confidence)
306 .await?;
307
308 let expansions: Vec<Expansion> = similar
309 .into_iter()
310 .take(self.config.max_expansions_per_element)
311 .map(|(uri, confidence)| Expansion {
312 original: entity.clone(),
313 expanded: uri,
314 confidence,
315 expansion_type: ExpansionType::Entity,
316 })
317 .collect();
318
319 if !expansions.is_empty() {
320 entity_expansions.insert(entity.clone(), expansions);
321 }
322 }
323
324 for relation in &parsed.relations {
326 let similar = self
327 .vec_similar_relations(relation, self.config.min_expansion_confidence)
328 .await?;
329
330 let expansions: Vec<Expansion> = similar
331 .into_iter()
332 .take(self.config.max_expansions_per_element)
333 .map(|(uri, confidence)| Expansion {
334 original: relation.clone(),
335 expanded: uri,
336 confidence,
337 expansion_type: ExpansionType::Relation,
338 })
339 .collect();
340
341 if !expansions.is_empty() {
342 relation_expansions.insert(relation.clone(), expansions);
343 }
344 }
345
346 drop(model);
347
348 let expanded_query = if self.config.enable_query_rewriting {
349 self.rewrite_query_with_expansions(query, &entity_expansions, &relation_expansions)
350 .await?
351 } else {
352 query.to_string()
353 };
354
355 let expansion_count = entity_expansions.len() + relation_expansions.len();
356
357 Ok(ExpandedQuery {
358 original_query: query.to_string(),
359 expanded_query,
360 entity_expansions,
361 relation_expansions,
362 expansion_count,
363 })
364 }
365
366 pub async fn fuzzy_match_entity(
375 &self,
376 entity_name: &str,
377 k: usize,
378 ) -> Result<Vec<(String, f32)>> {
379 if !self.config.enable_fuzzy_matching {
380 return Ok(vec![]);
381 }
382
383 debug!("Performing fuzzy match for entity: {}", entity_name);
384
385 let model = self.model.read().await;
386 let all_entities = model.get_entities();
387
388 let mut matches = Vec::new();
389
390 for entity in &all_entities {
391 let score = fuzzy_match_score(entity_name, entity);
392 if score > 0.5 {
393 matches.push((entity.clone(), score));
395 }
396 }
397
398 matches.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
399
400 Ok(matches.into_iter().take(k).collect())
401 }
402
403 pub async fn get_statistics(&self) -> QueryStatistics {
405 self.query_statistics.read().await.clone()
406 }
407
408 pub async fn clear_cache(&self) {
410 let mut cache = self.semantic_cache.write().await;
411 cache.clear();
412 info!("Semantic cache cleared");
413 }
414
415 async fn compute_similarities_parallel(
418 &self,
419 entities: &[String],
420 target_emb: &Vector,
421 exclude_entity: &str,
422 ) -> Result<Vec<(String, f32)>> {
423 use rayon::prelude::*;
424
425 let model = self.model.read().await;
426 let embeddings: Vec<_> = entities
427 .iter()
428 .filter(|e| e.as_str() != exclude_entity)
429 .filter_map(|e| {
430 model
431 .get_entity_embedding(e)
432 .ok()
433 .map(|emb| (e.clone(), emb))
434 })
435 .collect();
436 drop(model);
437
438 let target_emb_clone = target_emb.clone();
439 let similarities: Vec<(String, f32)> = embeddings
440 .par_iter()
441 .filter_map(|(entity, emb)| {
442 cosine_similarity(&target_emb_clone, emb)
443 .ok()
444 .map(|sim| (entity.clone(), sim))
445 })
446 .collect();
447
448 Ok(similarities)
449 }
450
451 async fn compute_similarities_sequential(
452 &self,
453 entities: &[String],
454 target_emb: &Vector,
455 exclude_entity: &str,
456 model: &dyn EmbeddingModel,
457 ) -> Result<Vec<(String, f32)>> {
458 let mut similarities = Vec::new();
459
460 for entity in entities {
461 if entity == exclude_entity {
462 continue;
463 }
464
465 if let Ok(entity_emb) = model.get_entity_embedding(entity) {
466 if let Ok(sim) = cosine_similarity(target_emb, &entity_emb) {
467 similarities.push((entity.clone(), sim));
468 }
469 }
470 }
471
472 Ok(similarities)
473 }
474
475 async fn rewrite_query_with_expansions(
476 &self,
477 original_query: &str,
478 entity_expansions: &HashMap<String, Vec<Expansion>>,
479 relation_expansions: &HashMap<String, Vec<Expansion>>,
480 ) -> Result<String> {
481 let mut rewritten = original_query.to_string();
484
485 for (original, expansions) in entity_expansions {
487 if let Some(first_expansion) = expansions.first() {
488 let union_clause = format!(
489 "\n UNION {{ # Semantic expansion for {}\n # Similar entity: {} (confidence: {:.2})\n }}",
490 original, first_expansion.expanded, first_expansion.confidence
491 );
492 rewritten.push_str(&union_clause);
493 }
494 }
495
496 for (original, expansions) in relation_expansions {
498 if let Some(first_expansion) = expansions.first() {
499 let comment = format!(
500 "\n # Relation '{}' can be expanded to '{}' (confidence: {:.2})",
501 original, first_expansion.expanded, first_expansion.confidence
502 );
503 rewritten.push_str(&comment);
504 }
505 }
506
507 Ok(rewritten)
508 }
509}
510
/// A bounded similarity-score cache.
///
/// NOTE(review): eviction picks the entry with the smallest access count, so
/// this behaves as LFU rather than true LRU. `get` takes `&self` (it is
/// called under a read lock) and therefore cannot record reads — the counts
/// only reflect writes.
struct SemanticCache {
    /// Cached scores keyed by a caller-built string key.
    cache: HashMap<String, f32>,
    /// Maximum number of entries kept before eviction kicks in.
    max_size: usize,
    /// Per-key write counts used to choose an eviction victim.
    access_count: HashMap<String, u64>,
}

impl SemanticCache {
    /// Creates an empty cache holding at most `max_size` entries.
    fn new(max_size: usize) -> Self {
        Self {
            cache: HashMap::new(),
            max_size,
            access_count: HashMap::new(),
        }
    }

    /// Returns the cached value for `key`, if any.
    fn get(&self, key: &str) -> Option<f32> {
        self.cache.get(key).copied()
    }

    /// Inserts (or refreshes) `key`, evicting the least-frequently written
    /// entry when the cache is full. Re-inserting an existing key only
    /// replaces its value and never triggers an eviction, since replacement
    /// does not grow the cache.
    fn put(&mut self, key: String, value: f32) {
        if self.cache.len() >= self.max_size && !self.cache.contains_key(&key) {
            if let Some(lfu_key) = self
                .access_count
                .iter()
                .min_by_key(|(_, &count)| count)
                .map(|(k, _)| k.clone())
            {
                self.cache.remove(&lfu_key);
                self.access_count.remove(&lfu_key);
            }
        }

        self.cache.insert(key.clone(), value);
        *self.access_count.entry(key).or_insert(0) += 1;
    }

    /// Removes every entry and its bookkeeping.
    fn clear(&mut self) {
        self.cache.clear();
        self.access_count.clear();
    }
}
554
/// Running counters describing how the extension has been used.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct QueryStatistics {
    /// Counter for pairwise similarity computations.
    pub similarity_computations: u64,
    /// Counter for `vec_nearest` calls.
    pub nearest_neighbor_queries: u64,
    /// Counter for semantic query expansions.
    pub query_expansions: u64,
    /// Counter reserved for fuzzy entity matches.
    pub fuzzy_matches: u64,
    /// Counter reserved for semantic-cache hits.
    pub cache_hits: u64,
    /// Counter reserved for semantic-cache misses.
    pub cache_misses: u64,
}
565
/// Result of a semantic query expansion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpandedQuery {
    /// The query text exactly as supplied by the caller.
    pub original_query: String,
    /// The (possibly annotated) rewritten query; equals the original when
    /// query rewriting is disabled.
    pub expanded_query: String,
    /// Expansions found per entity URI in the query.
    pub entity_expansions: HashMap<String, Vec<Expansion>>,
    /// Expansions found per relation URI in the query.
    pub relation_expansions: HashMap<String, Vec<Expansion>>,
    /// Number of query elements (entities plus relations) that received at
    /// least one expansion.
    pub expansion_count: usize,
}
575
/// A single semantic expansion of one query element.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Expansion {
    /// The URI that appeared in the original query.
    pub original: String,
    /// The semantically similar URI proposed as an alternative.
    pub expanded: String,
    /// Similarity score backing this expansion.
    pub confidence: f32,
    /// Whether the expanded element is an entity, relation, or pattern.
    pub expansion_type: ExpansionType,
}
584
/// Kind of query element an [`Expansion`] applies to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExpansionType {
    /// Expansion of an entity URI.
    Entity,
    /// Expansion of a relation (predicate) URI.
    Relation,
    /// Expansion of a graph pattern (not produced by the code in this file).
    Pattern,
}
592
/// URIs and variables extracted from a SPARQL query by `parse_sparql_query`.
#[derive(Debug, Clone)]
struct ParsedQuery {
    /// URIs judged (heuristically) to be in entity position.
    entities: Vec<String>,
    /// URIs judged (heuristically) to be in predicate position.
    relations: Vec<String>,
    /// Variable names, without the leading `?`.
    variables: HashSet<String>,
}
600
601fn parse_sparql_query(query: &str) -> Result<ParsedQuery> {
603 let mut entities = Vec::new();
604 let mut relations = Vec::new();
605 let mut variables = HashSet::new();
606
607 let uri_pattern =
609 regex::Regex::new(r"<(https?://[^>]+)>").expect("regex should compile for valid pattern");
610 let var_pattern =
611 regex::Regex::new(r"\?(\w+)").expect("regex should compile for valid pattern");
612
613 for line in query.lines() {
614 if line.contains("http://") || line.contains("https://") {
616 for cap in uri_pattern.captures_iter(line) {
618 let uri = cap[1].to_string();
619 if line.contains(&format!(" <{uri}> ")) {
621 relations.push(uri.clone());
622 } else {
623 entities.push(uri);
624 }
625 }
626 }
627
628 for cap in var_pattern.captures_iter(line) {
630 variables.insert(cap[1].to_string());
631 }
632 }
633
634 Ok(ParsedQuery {
635 entities,
636 relations,
637 variables,
638 })
639}
640
641fn cosine_similarity(v1: &Vector, v2: &Vector) -> Result<f32> {
644 if v1.dimensions != v2.dimensions {
645 return Err(anyhow!(
646 "Vector dimensions must match: {} vs {}",
647 v1.dimensions,
648 v2.dimensions
649 ));
650 }
651
652 let dot_product: f32 = v1
653 .values
654 .iter()
655 .zip(v2.values.iter())
656 .map(|(a, b)| a * b)
657 .sum();
658
659 let norm1: f32 = v1.values.iter().map(|x| x * x).sum::<f32>().sqrt();
660 let norm2: f32 = v2.values.iter().map(|x| x * x).sum::<f32>().sqrt();
661
662 if norm1 == 0.0 || norm2 == 0.0 {
663 return Ok(0.0);
664 }
665
666 let cosine_sim = dot_product / (norm1 * norm2);
668
669 Ok(cosine_sim)
670}
671
672fn normalized_cosine_similarity(v1: &Vector, v2: &Vector) -> Result<f32> {
676 let cosine_sim = cosine_similarity(v1, v2)?;
677 Ok((cosine_sim + 1.0) / 2.0)
679}
680
/// Scores how well two strings match, in `[0, 1]` (1.0 means case-insensitive
/// equality). Containment of one string in the other scores by length ratio;
/// otherwise the score is one minus the normalized Levenshtein distance.
fn fuzzy_match_score(s1: &str, s2: &str) -> f32 {
    let s1_lower = s1.to_lowercase();
    let s2_lower = s2.to_lowercase();

    if s1_lower == s2_lower {
        return 1.0;
    }

    // Measure in characters, not bytes, so multi-byte UTF-8 text does not
    // skew the normalization (and stays consistent with the char-based
    // Levenshtein distance below).
    let len1 = s1_lower.chars().count() as f32;
    let len2 = s2_lower.chars().count() as f32;
    let max_len = len1.max(len2);

    if max_len == 0.0 {
        return 1.0;
    }

    if s1_lower.contains(&s2_lower) || s2_lower.contains(&s1_lower) {
        return len1.min(len2) / max_len;
    }

    let distance = levenshtein_distance(&s1_lower, &s2_lower);
    1.0 - (distance as f32 / max_len)
}

/// Classic dynamic-programming Levenshtein (edit) distance, computed over
/// Unicode scalar values rather than bytes so multi-byte input cannot cause
/// an out-of-bounds panic.
fn levenshtein_distance(s1: &str, s2: &str) -> usize {
    let chars1: Vec<char> = s1.chars().collect();
    let chars2: Vec<char> = s2.chars().collect();
    let len1 = chars1.len();
    let len2 = chars2.len();

    if len1 == 0 {
        return len2;
    }
    if len2 == 0 {
        return len1;
    }

    // Rolling two-row DP: `prev[j]` holds the distance between the first i
    // chars of s1 and the first j chars of s2; only two rows are live at once.
    let mut prev: Vec<usize> = (0..=len2).collect();
    let mut curr = vec![0usize; len2 + 1];

    for (i, &c1) in chars1.iter().enumerate() {
        curr[0] = i + 1;
        for (j, &c2) in chars2.iter().enumerate() {
            let cost = usize::from(c1 != c2);
            curr[j + 1] = (prev[j + 1] + 1) // deletion
                .min(curr[j] + 1) // insertion
                .min(prev[j] + cost); // substitution
        }
        std::mem::swap(&mut prev, &mut curr);
    }

    prev[len2]
}
750
#[cfg(test)]
mod tests {
    //! Integration-style tests built on a tiny in-memory TransE model.
    use super::*;
    use crate::models::TransE;
    use crate::{ModelConfig, NamedNode, Triple};

    /// Builds a small TransE model over a four-triple toy social graph.
    fn create_test_model() -> TransE {
        let config = ModelConfig::default().with_dimensions(10);
        let mut model = TransE::new(config);

        let triples = vec![
            ("alice", "knows", "bob"),
            ("bob", "knows", "charlie"),
            ("alice", "likes", "music"),
            ("charlie", "likes", "art"),
        ];

        for (s, p, o) in triples {
            let triple = Triple::new(
                NamedNode::new(&format!("http://example.org/{s}")).unwrap(),
                NamedNode::new(&format!("http://example.org/{p}")).unwrap(),
                NamedNode::new(&format!("http://example.org/{o}")).unwrap(),
            );
            model.add_triple(triple).unwrap();
        }

        model
    }

    #[tokio::test]
    async fn test_vec_similarity() -> Result<()> {
        let model = create_test_model();
        let ext = SparqlExtension::new(Box::new(model));

        // Train briefly so entity embeddings exist before querying.
        {
            let mut model = ext.model.write().await;
            model.train(Some(10)).await?;
        }

        let sim = ext
            .vec_similarity("http://example.org/alice", "http://example.org/bob")
            .await?;

        // Normalized cosine similarity is always within [0, 1].
        assert!((0.0..=1.0).contains(&sim));
        Ok(())
    }

    #[tokio::test]
    async fn test_vec_nearest() -> Result<()> {
        let model = create_test_model();
        let ext = SparqlExtension::new(Box::new(model));

        {
            let mut model = ext.model.write().await;
            model.train(Some(10)).await?;
        }

        // Threshold 0.0 keeps non-negative similarities only.
        let neighbors = ext
            .vec_nearest("http://example.org/alice", 2, Some(0.0))
            .await?;

        // At most k results come back.
        assert!(neighbors.len() <= 2);

        for (entity, sim) in neighbors {
            assert!(!entity.is_empty());
            assert!((0.0..=1.0).contains(&sim));
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_semantic_query_expansion() -> Result<()> {
        let model = create_test_model();
        let ext = SparqlExtension::new(Box::new(model));

        {
            let mut model = ext.model.write().await;
            model.train(Some(10)).await?;
        }

        let query = r#"
            SELECT ?s ?o WHERE {
                ?s <http://example.org/knows> ?o
            }
        "#;

        let expanded = ext.expand_query_semantically(query).await?;

        // The original text is preserved verbatim alongside the expansion.
        assert_eq!(expanded.original_query, query);
        assert!(!expanded.expanded_query.is_empty());

        Ok(())
    }

    #[tokio::test]
    async fn test_fuzzy_match() -> Result<()> {
        let model = create_test_model();
        let ext = SparqlExtension::new(Box::new(model));

        // No training needed: fuzzy matching only uses entity URI strings.
        let matches = ext.fuzzy_match_entity("alice", 3).await?;

        assert!(matches.len() <= 3);
        for (entity, score) in matches {
            assert!(!entity.is_empty());
            assert!((0.0..=1.0).contains(&score));
        }

        Ok(())
    }

    #[test]
    fn test_parse_sparql_query() -> Result<()> {
        let query = r#"
            SELECT ?s ?o WHERE {
                ?s <http://example.org/knows> ?o .
                <http://example.org/alice> <http://example.org/likes> ?o .
            }
        "#;

        let parsed = parse_sparql_query(query)?;

        // Variables are recorded without their leading '?'.
        assert!(parsed.variables.contains("s"));
        assert!(parsed.variables.contains("o"));

        // Entity/relation classification is heuristic, so only assert that
        // some URIs were extracted at all.
        assert!(
            !parsed.entities.is_empty() || !parsed.relations.is_empty(),
            "Should extract at least some URIs from the query"
        );

        Ok(())
    }

    #[test]
    fn test_cosine_similarity() -> Result<()> {
        // Identical vectors -> similarity 1.
        let v1 = Vector::new(vec![1.0, 0.0, 0.0]);
        let v2 = Vector::new(vec![1.0, 0.0, 0.0]);
        let sim = cosine_similarity(&v1, &v2)?;
        assert!((sim - 1.0).abs() < 1e-6);

        // Orthogonal vectors -> similarity 0.
        let v3 = Vector::new(vec![0.0, 1.0, 0.0]);
        let sim2 = cosine_similarity(&v1, &v3)?;
        assert!((sim2 - 0.0).abs() < 1e-6);

        Ok(())
    }

    #[test]
    fn test_levenshtein_distance() {
        assert_eq!(levenshtein_distance("alice", "alice"), 0);
        assert_eq!(levenshtein_distance("alice", "alise"), 1);
        assert_eq!(levenshtein_distance("alice", "bob"), 5);
        assert_eq!(levenshtein_distance("", "abc"), 3);
        assert_eq!(levenshtein_distance("abc", ""), 3);
    }

    #[test]
    fn test_fuzzy_match_score() {
        assert!((fuzzy_match_score("alice", "alice") - 1.0).abs() < 1e-6);
        assert!(fuzzy_match_score("alice", "alise") > 0.7);
        assert!(fuzzy_match_score("alice", "bob") < 0.5);
    }

    #[tokio::test]
    async fn test_statistics_tracking() -> Result<()> {
        let model = create_test_model();
        let ext = SparqlExtension::new(Box::new(model));

        {
            let mut model = ext.model.write().await;
            model.train(Some(10)).await?;
        }

        // Results are ignored; only the counters matter here.
        let _ = ext
            .vec_similarity("http://example.org/alice", "http://example.org/bob")
            .await;
        let _ = ext.vec_nearest("http://example.org/alice", 2, None).await;

        let stats = ext.get_statistics().await;

        assert!(stats.similarity_computations > 0);
        assert!(stats.nearest_neighbor_queries > 0);

        Ok(())
    }

    #[tokio::test]
    async fn test_semantic_cache() -> Result<()> {
        let model = create_test_model();
        let ext = SparqlExtension::new(Box::new(model));

        {
            let mut model = ext.model.write().await;
            model.train(Some(10)).await?;
        }

        // First call computes and populates the cache.
        let sim1 = ext
            .vec_similarity("http://example.org/alice", "http://example.org/bob")
            .await?;

        // Second call should be served from the cache with the same value.
        let sim2 = ext
            .vec_similarity("http://example.org/alice", "http://example.org/bob")
            .await?;

        assert!((sim1 - sim2).abs() < 1e-6);

        // Clearing must not panic.
        ext.clear_cache().await;

        Ok(())
    }
}