Skip to main content

oxirs_vec/sparql_integration/
multimodal_functions.rs

1//! SPARQL function bindings for multimodal search
2//!
3//! This module provides SPARQL integration for multimodal search fusion,
4//! allowing queries to combine text, vector, and spatial search modalities.
5
6use super::config::{VectorServiceArg, VectorServiceResult};
7use crate::hybrid_search::multimodal_fusion::{
8    FusedResult, FusionConfig, FusionStrategy, Modality, MultimodalFusion, NormalizationMethod,
9};
10use crate::hybrid_search::types::DocumentScore;
11use anyhow::{Context, Result};
12use serde::{Deserialize, Serialize};
13
/// Configuration for multimodal (text + vector + spatial) search fusion.
///
/// Strategy and normalization are stored as strings and parsed when a search
/// is executed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultimodalSearchConfig {
    /// Default weights for weighted fusion, ordered `[text, vector, spatial]`.
    /// Used by the "weighted" strategy when the query supplies no weights.
    pub default_weights: Vec<f64>,
    /// Name of the default fusion strategy (e.g. "rankfusion").
    pub default_strategy: String,
    /// Score normalization method name: "minmax", "zscore" or "sigmoid"
    /// (matched case-insensitively).
    pub normalization: String,
    /// Cascade thresholds, ordered `[text, vector, spatial]`.
    /// Used by the "cascade" strategy.
    pub cascade_thresholds: Vec<f64>,
}
26
27impl Default for MultimodalSearchConfig {
28    fn default() -> Self {
29        Self {
30            default_weights: vec![0.33, 0.33, 0.34],
31            default_strategy: "rankfusion".to_string(),
32            normalization: "minmax".to_string(),
33            cascade_thresholds: vec![0.5, 0.7, 0.8],
34        }
35    }
36}
37
/// A single multimodal search hit in the shape exposed to SPARQL.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SparqlMultimodalResult {
    /// Resource URI
    pub uri: String,
    /// Combined (fused) score across all contributing modalities
    pub score: f64,
    /// Score contributed by the text modality, if any
    pub text_score: Option<f64>,
    /// Score contributed by the vector modality, if any
    pub vector_score: Option<f64>,
    /// Score contributed by the spatial modality, if any
    pub spatial_score: Option<f64>,
}
50
51impl From<FusedResult> for SparqlMultimodalResult {
52    fn from(result: FusedResult) -> Self {
53        let text_score = result.get_score(Modality::Text);
54        let vector_score = result.get_score(Modality::Vector);
55        let spatial_score = result.get_score(Modality::Spatial);
56
57        Self {
58            uri: result.uri,
59            score: result.total_score,
60            text_score,
61            vector_score,
62            spatial_score,
63        }
64    }
65}
66
67/// Execute multimodal search with multiple modalities
68///
69/// # Arguments
70/// * `text_query` - Optional text/keyword query string
71/// * `vector_query` - Optional vector embedding (comma-separated)
72/// * `spatial_query` - Optional WKT geometry (e.g., "POINT(10.0 20.0)")
73/// * `weights` - Optional weights [text, vector, spatial] (comma-separated)
74/// * `strategy` - Optional fusion strategy: "weighted", "sequential", "cascade", "rankfusion"
75/// * `limit` - Maximum number of results
76/// * `config` - Multimodal search configuration
77///
78/// # Returns
79/// Vector of multimodal search results with combined scores
80pub fn sparql_multimodal_search(
81    text_query: Option<String>,
82    vector_query: Option<String>,
83    spatial_query: Option<String>,
84    weights: Option<String>,
85    strategy: Option<String>,
86    limit: usize,
87    config: &MultimodalSearchConfig,
88) -> Result<Vec<SparqlMultimodalResult>> {
89    // Parse fusion strategy
90    let fusion_strategy = parse_fusion_strategy(
91        strategy.as_deref(),
92        weights.as_deref(),
93        &config.default_weights,
94        &config.cascade_thresholds,
95    )?;
96
97    // Parse normalization method
98    let normalization = parse_normalization(&config.normalization)?;
99
100    // Create fusion engine
101    let fusion_config = FusionConfig {
102        default_strategy: fusion_strategy.clone(),
103        score_normalization: normalization,
104    };
105    let fusion = MultimodalFusion::new(fusion_config);
106
107    // Execute individual searches
108    let text_results = if let Some(query) = text_query {
109        execute_text_search(&query, limit * 2)?
110    } else {
111        Vec::new()
112    };
113
114    let vector_results = if let Some(query) = vector_query {
115        let embedding = parse_vector(&query)?;
116        execute_vector_search(&embedding, limit * 2)?
117    } else {
118        Vec::new()
119    };
120
121    let spatial_results = if let Some(query) = spatial_query {
122        execute_spatial_search(&query, limit * 2)?
123    } else {
124        Vec::new()
125    };
126
127    // Fuse results
128    let fused = fusion.fuse(
129        &text_results,
130        &vector_results,
131        &spatial_results,
132        Some(fusion_strategy),
133    )?;
134
135    // Convert to SPARQL results and limit
136    let results: Vec<SparqlMultimodalResult> = fused
137        .into_iter()
138        .take(limit)
139        .map(SparqlMultimodalResult::from)
140        .collect();
141
142    Ok(results)
143}
144
145/// Parse fusion strategy from string
146fn parse_fusion_strategy(
147    strategy: Option<&str>,
148    weights: Option<&str>,
149    default_weights: &[f64],
150    cascade_thresholds: &[f64],
151) -> Result<FusionStrategy> {
152    match strategy {
153        Some("weighted") => {
154            let w = if let Some(weights_str) = weights {
155                parse_weights(weights_str)?
156            } else {
157                default_weights.to_vec()
158            };
159            Ok(FusionStrategy::Weighted { weights: w })
160        }
161        Some("sequential") => {
162            // Default order: Text → Vector
163            Ok(FusionStrategy::Sequential {
164                order: vec![Modality::Text, Modality::Vector],
165            })
166        }
167        Some("cascade") => Ok(FusionStrategy::Cascade {
168            thresholds: cascade_thresholds.to_vec(),
169        }),
170        Some("rankfusion") | None => Ok(FusionStrategy::RankFusion),
171        Some(s) => anyhow::bail!("Unknown fusion strategy: {}", s),
172    }
173}
174
175/// Parse normalization method from string
176fn parse_normalization(normalization: &str) -> Result<NormalizationMethod> {
177    match normalization.to_lowercase().as_str() {
178        "minmax" => Ok(NormalizationMethod::MinMax),
179        "zscore" => Ok(NormalizationMethod::ZScore),
180        "sigmoid" => Ok(NormalizationMethod::Sigmoid),
181        _ => anyhow::bail!("Unknown normalization method: {}", normalization),
182    }
183}
184
185/// Parse weights from comma-separated string
186fn parse_weights(weights_str: &str) -> Result<Vec<f64>> {
187    weights_str
188        .split(',')
189        .map(|s| {
190            s.trim()
191                .parse::<f64>()
192                .context("Failed to parse weight value")
193        })
194        .collect()
195}
196
197/// Parse vector embedding from comma-separated string
198fn parse_vector(vector_str: &str) -> Result<Vec<f32>> {
199    vector_str
200        .split(',')
201        .map(|s| {
202            s.trim()
203                .parse::<f32>()
204                .context("Failed to parse vector value")
205        })
206        .collect()
207}
208
209/// Execute text/keyword search (placeholder - integrate with actual text search)
210fn execute_text_search(query: &str, limit: usize) -> Result<Vec<DocumentScore>> {
211    // This is a placeholder implementation
212    // In production, integrate with Tantivy or BM25 search
213    Ok(vec![
214        DocumentScore {
215            doc_id: format!("text_result_1_{}", query),
216            score: 10.0,
217            rank: 0,
218        },
219        DocumentScore {
220            doc_id: format!("text_result_2_{}", query),
221            score: 8.0,
222            rank: 1,
223        },
224    ]
225    .into_iter()
226    .take(limit)
227    .collect())
228}
229
230/// Execute vector/semantic search (placeholder - integrate with actual vector search)
231fn execute_vector_search(embedding: &[f32], limit: usize) -> Result<Vec<DocumentScore>> {
232    // This is a placeholder implementation
233    // In production, integrate with HNSW or vector index
234    Ok(vec![
235        DocumentScore {
236            doc_id: format!("vector_result_1_dim{}", embedding.len()),
237            score: 0.95,
238            rank: 0,
239        },
240        DocumentScore {
241            doc_id: format!("vector_result_2_dim{}", embedding.len()),
242            score: 0.90,
243            rank: 1,
244        },
245    ]
246    .into_iter()
247    .take(limit)
248    .collect())
249}
250
251/// Execute spatial/geographic search (placeholder - integrate with actual spatial search)
252fn execute_spatial_search(wkt: &str, limit: usize) -> Result<Vec<DocumentScore>> {
253    // This is a placeholder implementation
254    // In production, integrate with GeoSPARQL or spatial index
255    Ok(vec![
256        DocumentScore {
257            doc_id: format!("spatial_result_1_{}", wkt),
258            score: 0.99,
259            rank: 0,
260        },
261        DocumentScore {
262            doc_id: format!("spatial_result_2_{}", wkt),
263            score: 0.92,
264            rank: 1,
265        },
266    ]
267    .into_iter()
268    .take(limit)
269    .collect())
270}
271
272/// Convert SPARQL arguments to multimodal search
273pub fn sparql_multimodal_search_from_args(
274    args: &[VectorServiceArg],
275    config: &MultimodalSearchConfig,
276) -> Result<VectorServiceResult> {
277    // Parse arguments
278    let mut text_query: Option<String> = None;
279    let vector_query: Option<String> = None;
280    let spatial_query: Option<String> = None;
281    let weights: Option<String> = None;
282    let strategy: Option<String> = None;
283    let mut limit: usize = 10;
284
285    // Extract named arguments (simplified parsing)
286    for arg in args {
287        match arg {
288            VectorServiceArg::String(s) if text_query.is_none() => {
289                text_query = Some(s.clone());
290            }
291            VectorServiceArg::Number(n) => {
292                limit = *n as usize;
293            }
294            _ => {}
295        }
296    }
297
298    // Execute search
299    let results = sparql_multimodal_search(
300        text_query,
301        vector_query,
302        spatial_query,
303        weights,
304        strategy,
305        limit,
306        config,
307    )?;
308
309    // Convert to SPARQL result format
310    let similarity_list: Vec<(String, f32)> = results
311        .into_iter()
312        .map(|r| (r.uri, r.score as f32))
313        .collect();
314
315    Ok(VectorServiceResult::SimilarityList(similarity_list))
316}
317
/// Return the SPARQL prefix declarations and commented usage guide for the
/// `vec:multimodal_search` function as an owned string.
pub fn generate_multimodal_sparql_function() -> String {
    // Kept as one raw literal so the SPARQL text can be read and edited verbatim.
    const DEFINITION: &str = r#"
PREFIX vec: <http://oxirs.org/vec#>
PREFIX geo: <http://www.opengis.net/ont/geosparql#>

# Multimodal Search Function
# Combines text, vector, and spatial search with intelligent fusion
#
# Usage:
# SELECT ?entity ?score WHERE {
#   ?entity vec:multimodal_search(
#     text: "machine learning conference",
#     vector: "0.1,0.2,0.3,...",
#     spatial: "POINT(10.0 20.0)",
#     weights: "0.4,0.4,0.2",
#     strategy: "rankfusion",
#     limit: 10
#   ) .
#   BIND(vec:score(?entity) AS ?score)
# }
# ORDER BY DESC(?score)
#
# Parameters:
#   - text: Text/keyword query (optional)
#   - vector: Comma-separated embedding values (optional)
#   - spatial: WKT geometry string (optional)
#   - weights: Comma-separated weights [text, vector, spatial] (optional)
#   - strategy: Fusion strategy - "weighted", "sequential", "cascade", "rankfusion" (optional)
#   - limit: Maximum results (default: 10)
#
# Fusion Strategies:
#   - weighted: Linear combination of normalized scores
#   - sequential: Filter with one modality, rank with another
#   - cascade: Progressive filtering (fast → expensive)
#   - rankfusion: Reciprocal Rank Fusion (position-based)
"#;

    String::from(DEFINITION)
}
357
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;

    /// Explicit weights are parsed and carried by the weighted strategy.
    #[test]
    fn test_parse_fusion_strategy_weighted() -> Result<()> {
        let parsed = parse_fusion_strategy(
            Some("weighted"),
            Some("0.5,0.3,0.2"),
            &[0.33, 0.33, 0.34],
            &[0.5, 0.7, 0.8],
        )?;

        if let FusionStrategy::Weighted { weights } = parsed {
            let expected = [0.5, 0.3, 0.2];
            assert_eq!(weights.len(), 3);
            for (got, want) in weights.iter().zip(expected) {
                assert!((got - want).abs() < 1e-6);
            }
        } else {
            panic!("Expected Weighted strategy");
        }
        Ok(())
    }

    /// With no strategy name, rank fusion is selected.
    #[test]
    fn test_parse_fusion_strategy_default() -> Result<()> {
        let parsed = parse_fusion_strategy(None, None, &[0.33, 0.33, 0.34], &[0.5, 0.7, 0.8])?;

        assert!(
            matches!(parsed, FusionStrategy::RankFusion),
            "Expected RankFusion as default"
        );
        Ok(())
    }

    /// Every supported normalization name maps to its enum variant.
    #[test]
    fn test_parse_normalization() -> Result<()> {
        let minmax = parse_normalization("minmax")?;
        assert!(matches!(minmax, NormalizationMethod::MinMax));

        let zscore = parse_normalization("zscore")?;
        assert!(matches!(zscore, NormalizationMethod::ZScore));

        let sigmoid = parse_normalization("sigmoid")?;
        assert!(matches!(sigmoid, NormalizationMethod::Sigmoid));
        Ok(())
    }

    /// Whitespace around comma-separated weights is tolerated.
    #[test]
    fn test_parse_weights() -> Result<()> {
        let parsed = parse_weights("0.4, 0.35, 0.25")?;
        let expected = [0.4, 0.35, 0.25];

        assert_eq!(parsed.len(), 3);
        for (got, want) in parsed.iter().zip(expected) {
            assert!((got - want).abs() < 1e-6);
        }
        Ok(())
    }

    /// Whitespace around comma-separated vector components is tolerated.
    #[test]
    fn test_parse_vector() -> Result<()> {
        let parsed = parse_vector("0.1, 0.2, 0.3")?;
        let expected = [0.1_f32, 0.2, 0.3];

        assert_eq!(parsed.len(), 3);
        for (got, want) in parsed.iter().zip(expected) {
            assert!((got - want).abs() < 1e-6);
        }
        Ok(())
    }

    /// The default configuration has three weights/thresholds and the
    /// expected strategy and normalization names.
    #[test]
    fn test_multimodal_search_config_default() {
        let MultimodalSearchConfig {
            default_weights,
            default_strategy,
            normalization,
            cascade_thresholds,
        } = MultimodalSearchConfig::default();

        assert_eq!(default_weights.len(), 3);
        assert_eq!(default_strategy, "rankfusion");
        assert_eq!(normalization, "minmax");
        assert_eq!(cascade_thresholds.len(), 3);
    }

    /// Converting a fused result keeps the uri, total and per-modality
    /// scores, leaving unscored modalities as `None`.
    #[test]
    fn test_sparql_multimodal_result_conversion() {
        let mut fused = FusedResult::new("test_doc".to_string());
        fused.add_score(Modality::Text, 0.5);
        fused.add_score(Modality::Vector, 0.3);
        fused.calculate_total();

        let converted = SparqlMultimodalResult::from(fused);

        assert_eq!(converted.uri, "test_doc");
        assert!((converted.score - 0.8).abs() < 1e-6);
        assert_eq!(converted.text_score, Some(0.5));
        assert_eq!(converted.vector_score, Some(0.3));
        assert_eq!(converted.spatial_score, None);
    }

    /// The text stub returns hits whose ids embed the query.
    #[test]
    fn test_execute_text_search() -> Result<()> {
        let hits = execute_text_search("test query", 10)?;
        assert!(!hits.is_empty());
        assert!(hits[0].doc_id.contains("test query"));
        Ok(())
    }

    /// The vector stub returns hits whose ids embed the dimensionality.
    #[test]
    fn test_execute_vector_search() -> Result<()> {
        let embedding = vec![0.1, 0.2, 0.3];
        let hits = execute_vector_search(&embedding, 10)?;
        assert!(!hits.is_empty());
        assert!(hits[0].doc_id.contains("dim3"));
        Ok(())
    }

    /// The spatial stub returns hits whose ids embed the WKT string.
    #[test]
    fn test_execute_spatial_search() -> Result<()> {
        let hits = execute_spatial_search("POINT(10.0 20.0)", 10)?;
        assert!(!hits.is_empty());
        assert!(hits[0].doc_id.contains("POINT"));
        Ok(())
    }

    /// End to end: all three modalities supplied, rank fusion requested.
    #[test]
    fn test_sparql_multimodal_search_integration() -> Result<()> {
        let config = MultimodalSearchConfig::default();

        let text = Some("machine learning".to_string());
        let vector = Some("0.1,0.2,0.3".to_string());
        let spatial = Some("POINT(10.0 20.0)".to_string());
        let weights = Some("0.4,0.4,0.2".to_string());
        let strategy = Some("rankfusion".to_string());

        let hits =
            sparql_multimodal_search(text, vector, spatial, weights, strategy, 10, &config)?;

        assert!(!hits.is_empty());
        assert!(hits[0].score > 0.0);
        Ok(())
    }

    /// The generated SPARQL help text mentions the function and its
    /// modality parameters.
    #[test]
    fn test_generate_multimodal_sparql_function() {
        let sparql = generate_multimodal_sparql_function();
        for needle in ["vec:multimodal_search", "text:", "vector:", "spatial:"] {
            assert!(sparql.contains(needle));
        }
    }
}
505        assert!(sparql.contains("text:"));
506        assert!(sparql.contains("vector:"));
507        assert!(sparql.contains("spatial:"));
508    }
509}