use crate::{EmbeddingModel, Vector};
use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, trace};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SparqlExtensionConfig {
pub default_similarity_threshold: f32,
pub max_expansions_per_element: usize,
pub enable_query_rewriting: bool,
pub enable_semantic_caching: bool,
pub semantic_cache_size: usize,
pub enable_fuzzy_matching: bool,
pub min_expansion_confidence: f32,
pub enable_parallel_processing: bool,
}
impl Default for SparqlExtensionConfig {
fn default() -> Self {
Self {
default_similarity_threshold: 0.7,
max_expansions_per_element: 10,
enable_query_rewriting: true,
enable_semantic_caching: true,
semantic_cache_size: 1000,
enable_fuzzy_matching: true,
min_expansion_confidence: 0.6,
enable_parallel_processing: true,
}
}
}
pub struct SparqlExtension {
model: Arc<RwLock<Box<dyn EmbeddingModel>>>,
config: SparqlExtensionConfig,
semantic_cache: Arc<RwLock<SemanticCache>>,
query_statistics: Arc<RwLock<QueryStatistics>>,
}
impl SparqlExtension {
pub fn new(model: Box<dyn EmbeddingModel>) -> Self {
Self {
model: Arc::new(RwLock::new(model)),
config: SparqlExtensionConfig::default(),
semantic_cache: Arc::new(RwLock::new(SemanticCache::new(1000))),
query_statistics: Arc::new(RwLock::new(QueryStatistics::default())),
}
}
pub fn with_config(model: Box<dyn EmbeddingModel>, config: SparqlExtensionConfig) -> Self {
let cache_size = config.semantic_cache_size;
Self {
model: Arc::new(RwLock::new(model)),
config,
semantic_cache: Arc::new(RwLock::new(SemanticCache::new(cache_size))),
query_statistics: Arc::new(RwLock::new(QueryStatistics::default())),
}
}
pub async fn vec_similarity(&self, entity1: &str, entity2: &str) -> Result<f32> {
trace!("Computing similarity between {} and {}", entity1, entity2);
if self.config.enable_semantic_caching {
let cache = self.semantic_cache.read().await;
let cache_key = format!("sim:{}:{}", entity1, entity2);
if let Some(cached_result) = cache.get(&cache_key) {
debug!("Cache hit for similarity computation");
return Ok(cached_result);
}
}
let model = self.model.read().await;
let emb1 = model.get_entity_embedding(entity1)?;
let emb2 = model.get_entity_embedding(entity2)?;
let similarity = normalized_cosine_similarity(&emb1, &emb2)?;
if self.config.enable_semantic_caching {
let mut cache = self.semantic_cache.write().await;
let cache_key = format!("sim:{}:{}", entity1, entity2);
cache.put(cache_key, similarity);
}
let mut stats = self.query_statistics.write().await;
stats.similarity_computations += 1;
Ok(similarity)
}
pub async fn vec_nearest(
&self,
entity: &str,
k: usize,
min_similarity: Option<f32>,
) -> Result<Vec<(String, f32)>> {
info!("Finding {} nearest neighbors for {}", k, entity);
let model = self.model.read().await;
let target_emb = model.get_entity_embedding(entity)?;
let all_entities = model.get_entities();
let threshold = min_similarity.unwrap_or(self.config.default_similarity_threshold);
let mut similarities: Vec<(String, f32)> = if self.config.enable_parallel_processing {
self.compute_similarities_parallel(&all_entities, &target_emb, entity)
.await?
} else {
self.compute_similarities_sequential(&all_entities, &target_emb, entity, &**model)
.await?
};
similarities.retain(|(_, sim)| *sim >= threshold);
similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let result: Vec<(String, f32)> = similarities.into_iter().take(k).collect();
let mut stats = self.query_statistics.write().await;
stats.nearest_neighbor_queries += 1;
Ok(result)
}
pub async fn vec_similar_entities(
&self,
entity: &str,
threshold: f32,
) -> Result<Vec<(String, f32)>> {
debug!(
"Finding entities similar to {} (threshold: {})",
entity, threshold
);
let model = self.model.read().await;
let target_emb = model.get_entity_embedding(entity)?;
let all_entities = model.get_entities();
let similarities = if self.config.enable_parallel_processing {
self.compute_similarities_parallel(&all_entities, &target_emb, entity)
.await?
} else {
self.compute_similarities_sequential(&all_entities, &target_emb, entity, &**model)
.await?
};
let result: Vec<(String, f32)> = similarities
.into_iter()
.filter(|(_, sim)| *sim >= threshold)
.collect();
Ok(result)
}
pub async fn vec_similar_relations(
&self,
relation: &str,
threshold: f32,
) -> Result<Vec<(String, f32)>> {
debug!(
"Finding relations similar to {} (threshold: {})",
relation, threshold
);
let model = self.model.read().await;
let target_emb = model.get_relation_embedding(relation)?;
let all_relations = model.get_relations();
let mut similarities = Vec::new();
for rel in &all_relations {
if rel == relation {
continue; }
let rel_emb = model.get_relation_embedding(rel)?;
let sim = cosine_similarity(&target_emb, &rel_emb)?;
if sim >= threshold {
similarities.push((rel.clone(), sim));
}
}
similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
Ok(similarities)
}
pub async fn expand_query_semantically(&self, query: &str) -> Result<ExpandedQuery> {
info!("Performing semantic query expansion");
let mut stats = self.query_statistics.write().await;
stats.query_expansions += 1;
drop(stats);
let model = self.model.read().await;
let parsed = parse_sparql_query(query)?;
let mut entity_expansions = HashMap::new();
let mut relation_expansions = HashMap::new();
for entity in &parsed.entities {
let similar = self
.vec_similar_entities(entity, self.config.min_expansion_confidence)
.await?;
let expansions: Vec<Expansion> = similar
.into_iter()
.take(self.config.max_expansions_per_element)
.map(|(uri, confidence)| Expansion {
original: entity.clone(),
expanded: uri,
confidence,
expansion_type: ExpansionType::Entity,
})
.collect();
if !expansions.is_empty() {
entity_expansions.insert(entity.clone(), expansions);
}
}
for relation in &parsed.relations {
let similar = self
.vec_similar_relations(relation, self.config.min_expansion_confidence)
.await?;
let expansions: Vec<Expansion> = similar
.into_iter()
.take(self.config.max_expansions_per_element)
.map(|(uri, confidence)| Expansion {
original: relation.clone(),
expanded: uri,
confidence,
expansion_type: ExpansionType::Relation,
})
.collect();
if !expansions.is_empty() {
relation_expansions.insert(relation.clone(), expansions);
}
}
drop(model);
let expanded_query = if self.config.enable_query_rewriting {
self.rewrite_query_with_expansions(query, &entity_expansions, &relation_expansions)
.await?
} else {
query.to_string()
};
let expansion_count = entity_expansions.len() + relation_expansions.len();
Ok(ExpandedQuery {
original_query: query.to_string(),
expanded_query,
entity_expansions,
relation_expansions,
expansion_count,
})
}
pub async fn fuzzy_match_entity(
&self,
entity_name: &str,
k: usize,
) -> Result<Vec<(String, f32)>> {
if !self.config.enable_fuzzy_matching {
return Ok(vec![]);
}
debug!("Performing fuzzy match for entity: {}", entity_name);
let model = self.model.read().await;
let all_entities = model.get_entities();
let mut matches = Vec::new();
for entity in &all_entities {
let score = fuzzy_match_score(entity_name, entity);
if score > 0.5 {
matches.push((entity.clone(), score));
}
}
matches.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
Ok(matches.into_iter().take(k).collect())
}
pub async fn get_statistics(&self) -> QueryStatistics {
self.query_statistics.read().await.clone()
}
pub async fn clear_cache(&self) {
let mut cache = self.semantic_cache.write().await;
cache.clear();
info!("Semantic cache cleared");
}
async fn compute_similarities_parallel(
&self,
entities: &[String],
target_emb: &Vector,
exclude_entity: &str,
) -> Result<Vec<(String, f32)>> {
use rayon::prelude::*;
let model = self.model.read().await;
let embeddings: Vec<_> = entities
.iter()
.filter(|e| e.as_str() != exclude_entity)
.filter_map(|e| {
model
.get_entity_embedding(e)
.ok()
.map(|emb| (e.clone(), emb))
})
.collect();
drop(model);
let target_emb_clone = target_emb.clone();
let similarities: Vec<(String, f32)> = embeddings
.par_iter()
.filter_map(|(entity, emb)| {
cosine_similarity(&target_emb_clone, emb)
.ok()
.map(|sim| (entity.clone(), sim))
})
.collect();
Ok(similarities)
}
async fn compute_similarities_sequential(
&self,
entities: &[String],
target_emb: &Vector,
exclude_entity: &str,
model: &dyn EmbeddingModel,
) -> Result<Vec<(String, f32)>> {
let mut similarities = Vec::new();
for entity in entities {
if entity == exclude_entity {
continue;
}
if let Ok(entity_emb) = model.get_entity_embedding(entity) {
if let Ok(sim) = cosine_similarity(target_emb, &entity_emb) {
similarities.push((entity.clone(), sim));
}
}
}
Ok(similarities)
}
async fn rewrite_query_with_expansions(
&self,
original_query: &str,
entity_expansions: &HashMap<String, Vec<Expansion>>,
relation_expansions: &HashMap<String, Vec<Expansion>>,
) -> Result<String> {
let mut rewritten = original_query.to_string();
for (original, expansions) in entity_expansions {
if let Some(first_expansion) = expansions.first() {
let union_clause = format!(
"\n UNION {{ # Semantic expansion for {}\n # Similar entity: {} (confidence: {:.2})\n }}",
original, first_expansion.expanded, first_expansion.confidence
);
rewritten.push_str(&union_clause);
}
}
for (original, expansions) in relation_expansions {
if let Some(first_expansion) = expansions.first() {
let comment = format!(
"\n # Relation '{}' can be expanded to '{}' (confidence: {:.2})",
original, first_expansion.expanded, first_expansion.confidence
);
rewritten.push_str(&comment);
}
}
Ok(rewritten)
}
}
struct SemanticCache {
cache: HashMap<String, f32>,
max_size: usize,
access_count: HashMap<String, u64>,
}
impl SemanticCache {
fn new(max_size: usize) -> Self {
Self {
cache: HashMap::new(),
max_size,
access_count: HashMap::new(),
}
}
fn get(&self, key: &str) -> Option<f32> {
self.cache.get(key).copied()
}
fn put(&mut self, key: String, value: f32) {
if self.cache.len() >= self.max_size {
if let Some(lru_key) = self
.access_count
.iter()
.min_by_key(|(_, &count)| count)
.map(|(k, _)| k.clone())
{
self.cache.remove(&lru_key);
self.access_count.remove(&lru_key);
}
}
self.cache.insert(key.clone(), value);
*self.access_count.entry(key).or_insert(0) += 1;
}
fn clear(&mut self) {
self.cache.clear();
self.access_count.clear();
}
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct QueryStatistics {
pub similarity_computations: u64,
pub nearest_neighbor_queries: u64,
pub query_expansions: u64,
pub fuzzy_matches: u64,
pub cache_hits: u64,
pub cache_misses: u64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpandedQuery {
pub original_query: String,
pub expanded_query: String,
pub entity_expansions: HashMap<String, Vec<Expansion>>,
pub relation_expansions: HashMap<String, Vec<Expansion>>,
pub expansion_count: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Expansion {
pub original: String,
pub expanded: String,
pub confidence: f32,
pub expansion_type: ExpansionType,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExpansionType {
Entity,
Relation,
Pattern,
}
#[derive(Debug, Clone)]
struct ParsedQuery {
entities: Vec<String>,
relations: Vec<String>,
variables: HashSet<String>,
}
fn parse_sparql_query(query: &str) -> Result<ParsedQuery> {
let mut entities = Vec::new();
let mut relations = Vec::new();
let mut variables = HashSet::new();
let uri_pattern =
regex::Regex::new(r"<(https?://[^>]+)>").expect("regex should compile for valid pattern");
let var_pattern =
regex::Regex::new(r"\?(\w+)").expect("regex should compile for valid pattern");
for line in query.lines() {
if line.contains("http://") || line.contains("https://") {
for cap in uri_pattern.captures_iter(line) {
let uri = cap[1].to_string();
if line.contains(&format!(" <{uri}> ")) {
relations.push(uri.clone());
} else {
entities.push(uri);
}
}
}
for cap in var_pattern.captures_iter(line) {
variables.insert(cap[1].to_string());
}
}
Ok(ParsedQuery {
entities,
relations,
variables,
})
}
fn cosine_similarity(v1: &Vector, v2: &Vector) -> Result<f32> {
if v1.dimensions != v2.dimensions {
return Err(anyhow!(
"Vector dimensions must match: {} vs {}",
v1.dimensions,
v2.dimensions
));
}
let dot_product: f32 = v1
.values
.iter()
.zip(v2.values.iter())
.map(|(a, b)| a * b)
.sum();
let norm1: f32 = v1.values.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm2: f32 = v2.values.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm1 == 0.0 || norm2 == 0.0 {
return Ok(0.0);
}
let cosine_sim = dot_product / (norm1 * norm2);
Ok(cosine_sim)
}
fn normalized_cosine_similarity(v1: &Vector, v2: &Vector) -> Result<f32> {
let cosine_sim = cosine_similarity(v1, v2)?;
Ok((cosine_sim + 1.0) / 2.0)
}
fn fuzzy_match_score(s1: &str, s2: &str) -> f32 {
let s1_lower = s1.to_lowercase();
let s2_lower = s2.to_lowercase();
if s1_lower == s2_lower {
return 1.0;
}
if s1_lower.contains(&s2_lower) || s2_lower.contains(&s1_lower) {
let max_len = s1.len().max(s2.len()) as f32;
let min_len = s1.len().min(s2.len()) as f32;
return min_len / max_len;
}
let distance = levenshtein_distance(&s1_lower, &s2_lower);
let max_len = s1.len().max(s2.len()) as f32;
if max_len == 0.0 {
return 1.0;
}
1.0 - (distance as f32 / max_len)
}
#[allow(clippy::needless_range_loop)]
fn levenshtein_distance(s1: &str, s2: &str) -> usize {
let len1 = s1.len();
let len2 = s2.len();
if len1 == 0 {
return len2;
}
if len2 == 0 {
return len1;
}
let mut matrix = vec![vec![0; len2 + 1]; len1 + 1];
for i in 0..=len1 {
matrix[i][0] = i;
}
for j in 0..=len2 {
matrix[0][j] = j;
}
let s1_chars: Vec<char> = s1.chars().collect();
let s2_chars: Vec<char> = s2.chars().collect();
for i in 1..=len1 {
for j in 1..=len2 {
let cost = if s1_chars[i - 1] == s2_chars[j - 1] {
0
} else {
1
};
matrix[i][j] = (matrix[i - 1][j] + 1)
.min(matrix[i][j - 1] + 1)
.min(matrix[i - 1][j - 1] + cost);
}
}
matrix[len1][len2]
}
#[cfg(test)]
mod tests {
use super::*;
use crate::models::TransE;
use crate::{ModelConfig, NamedNode, Triple};
fn create_test_model() -> TransE {
let config = ModelConfig::default().with_dimensions(10);
let mut model = TransE::new(config);
let triples = vec![
("alice", "knows", "bob"),
("bob", "knows", "charlie"),
("alice", "likes", "music"),
("charlie", "likes", "art"),
];
for (s, p, o) in triples {
let triple = Triple::new(
NamedNode::new(&format!("http://example.org/{s}")).expect("should succeed"),
NamedNode::new(&format!("http://example.org/{p}")).expect("should succeed"),
NamedNode::new(&format!("http://example.org/{o}")).expect("should succeed"),
);
model.add_triple(triple).expect("should succeed");
}
model
}
#[tokio::test]
async fn test_vec_similarity() -> Result<()> {
let model = create_test_model();
let ext = SparqlExtension::new(Box::new(model));
{
let mut model = ext.model.write().await;
model.train(Some(10)).await?;
}
let sim = ext
.vec_similarity("http://example.org/alice", "http://example.org/bob")
.await?;
assert!((0.0..=1.0).contains(&sim));
Ok(())
}
#[tokio::test]
async fn test_vec_nearest() -> Result<()> {
let model = create_test_model();
let ext = SparqlExtension::new(Box::new(model));
{
let mut model = ext.model.write().await;
model.train(Some(10)).await?;
}
let neighbors = ext
.vec_nearest("http://example.org/alice", 2, Some(0.0))
.await?;
assert!(neighbors.len() <= 2);
for (entity, sim) in neighbors {
assert!(!entity.is_empty());
assert!((0.0..=1.0).contains(&sim));
}
Ok(())
}
#[tokio::test]
async fn test_semantic_query_expansion() -> Result<()> {
let model = create_test_model();
let ext = SparqlExtension::new(Box::new(model));
{
let mut model = ext.model.write().await;
model.train(Some(10)).await?;
}
let query = r#"
SELECT ?s ?o WHERE {
?s <http://example.org/knows> ?o
}
"#;
let expanded = ext.expand_query_semantically(query).await?;
assert_eq!(expanded.original_query, query);
assert!(!expanded.expanded_query.is_empty());
Ok(())
}
#[tokio::test]
async fn test_fuzzy_match() -> Result<()> {
let model = create_test_model();
let ext = SparqlExtension::new(Box::new(model));
let matches = ext.fuzzy_match_entity("alice", 3).await?;
assert!(matches.len() <= 3);
for (entity, score) in matches {
assert!(!entity.is_empty());
assert!((0.0..=1.0).contains(&score));
}
Ok(())
}
#[test]
fn test_parse_sparql_query() -> Result<()> {
let query = r#"
SELECT ?s ?o WHERE {
?s <http://example.org/knows> ?o .
<http://example.org/alice> <http://example.org/likes> ?o .
}
"#;
let parsed = parse_sparql_query(query)?;
assert!(parsed.variables.contains("s"));
assert!(parsed.variables.contains("o"));
assert!(
!parsed.entities.is_empty() || !parsed.relations.is_empty(),
"Should extract at least some URIs from the query"
);
Ok(())
}
#[test]
fn test_cosine_similarity() -> Result<()> {
let v1 = Vector::new(vec![1.0, 0.0, 0.0]);
let v2 = Vector::new(vec![1.0, 0.0, 0.0]);
let sim = cosine_similarity(&v1, &v2)?;
assert!((sim - 1.0).abs() < 1e-6);
let v3 = Vector::new(vec![0.0, 1.0, 0.0]);
let sim2 = cosine_similarity(&v1, &v3)?;
assert!((sim2 - 0.0).abs() < 1e-6);
Ok(())
}
#[test]
fn test_levenshtein_distance() {
assert_eq!(levenshtein_distance("alice", "alice"), 0);
assert_eq!(levenshtein_distance("alice", "alise"), 1);
assert_eq!(levenshtein_distance("alice", "bob"), 5);
assert_eq!(levenshtein_distance("", "abc"), 3);
assert_eq!(levenshtein_distance("abc", ""), 3);
}
#[test]
fn test_fuzzy_match_score() {
assert!((fuzzy_match_score("alice", "alice") - 1.0).abs() < 1e-6);
assert!(fuzzy_match_score("alice", "alise") > 0.7);
assert!(fuzzy_match_score("alice", "bob") < 0.5);
}
#[tokio::test]
async fn test_statistics_tracking() -> Result<()> {
let model = create_test_model();
let ext = SparqlExtension::new(Box::new(model));
{
let mut model = ext.model.write().await;
model.train(Some(10)).await?;
}
let _ = ext
.vec_similarity("http://example.org/alice", "http://example.org/bob")
.await;
let _ = ext.vec_nearest("http://example.org/alice", 2, None).await;
let stats = ext.get_statistics().await;
assert!(stats.similarity_computations > 0);
assert!(stats.nearest_neighbor_queries > 0);
Ok(())
}
#[tokio::test]
async fn test_semantic_cache() -> Result<()> {
let model = create_test_model();
let ext = SparqlExtension::new(Box::new(model));
{
let mut model = ext.model.write().await;
model.train(Some(10)).await?;
}
let sim1 = ext
.vec_similarity("http://example.org/alice", "http://example.org/bob")
.await?;
let sim2 = ext
.vec_similarity("http://example.org/alice", "http://example.org/bob")
.await?;
assert!((sim1 - sim2).abs() < 1e-6);
ext.clear_cache().await;
Ok(())
}
}