Skip to main content

oxirs_vec/sparql_integration/
multimodal_functions.rs

1//! SPARQL function bindings for multimodal search
2//!
3//! This module provides SPARQL integration for multimodal search fusion,
4//! allowing queries to combine text, vector, and spatial search modalities.
5
6use super::config::{VectorServiceArg, VectorServiceResult};
7use crate::hybrid_search::multimodal_fusion::{
8    FusedResult, FusionConfig, FusionStrategy, Modality, MultimodalFusion, NormalizationMethod,
9};
10use crate::hybrid_search::types::DocumentScore;
11use anyhow::{Context, Result};
12use serde::{Deserialize, Serialize};
13
14/// Multimodal search configuration
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct MultimodalSearchConfig {
17    /// Default weights for weighted fusion [text, vector, spatial]
18    pub default_weights: Vec<f64>,
19    /// Default fusion strategy
20    pub default_strategy: String,
21    /// Score normalization method
22    pub normalization: String,
23    /// Cascade thresholds [text, vector, spatial]
24    pub cascade_thresholds: Vec<f64>,
25}
26
27impl Default for MultimodalSearchConfig {
28    fn default() -> Self {
29        Self {
30            default_weights: vec![0.33, 0.33, 0.34],
31            default_strategy: "rankfusion".to_string(),
32            normalization: "minmax".to_string(),
33            cascade_thresholds: vec![0.5, 0.7, 0.8],
34        }
35    }
36}
37
38/// Multimodal search result for SPARQL
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct SparqlMultimodalResult {
41    /// Resource URI
42    pub uri: String,
43    /// Combined score
44    pub score: f64,
45    /// Individual modality scores
46    pub text_score: Option<f64>,
47    pub vector_score: Option<f64>,
48    pub spatial_score: Option<f64>,
49}
50
51impl From<FusedResult> for SparqlMultimodalResult {
52    fn from(result: FusedResult) -> Self {
53        let text_score = result.get_score(Modality::Text);
54        let vector_score = result.get_score(Modality::Vector);
55        let spatial_score = result.get_score(Modality::Spatial);
56
57        Self {
58            uri: result.uri,
59            score: result.total_score,
60            text_score,
61            vector_score,
62            spatial_score,
63        }
64    }
65}
66
67/// Execute multimodal search with multiple modalities
68///
69/// # Arguments
70/// * `text_query` - Optional text/keyword query string
71/// * `vector_query` - Optional vector embedding (comma-separated)
72/// * `spatial_query` - Optional WKT geometry (e.g., "POINT(10.0 20.0)")
73/// * `weights` - Optional weights [text, vector, spatial] (comma-separated)
74/// * `strategy` - Optional fusion strategy: "weighted", "sequential", "cascade", "rankfusion"
75/// * `limit` - Maximum number of results
76/// * `config` - Multimodal search configuration
77///
78/// # Returns
79/// Vector of multimodal search results with combined scores
80pub fn sparql_multimodal_search(
81    text_query: Option<String>,
82    vector_query: Option<String>,
83    spatial_query: Option<String>,
84    weights: Option<String>,
85    strategy: Option<String>,
86    limit: usize,
87    config: &MultimodalSearchConfig,
88) -> Result<Vec<SparqlMultimodalResult>> {
89    // Parse fusion strategy
90    let fusion_strategy = parse_fusion_strategy(
91        strategy.as_deref(),
92        weights.as_deref(),
93        &config.default_weights,
94        &config.cascade_thresholds,
95    )?;
96
97    // Parse normalization method
98    let normalization = parse_normalization(&config.normalization)?;
99
100    // Create fusion engine
101    let fusion_config = FusionConfig {
102        default_strategy: fusion_strategy.clone(),
103        score_normalization: normalization,
104    };
105    let fusion = MultimodalFusion::new(fusion_config);
106
107    // Execute individual searches
108    let text_results = if let Some(query) = text_query {
109        execute_text_search(&query, limit * 2)?
110    } else {
111        Vec::new()
112    };
113
114    let vector_results = if let Some(query) = vector_query {
115        let embedding = parse_vector(&query)?;
116        execute_vector_search(&embedding, limit * 2)?
117    } else {
118        Vec::new()
119    };
120
121    let spatial_results = if let Some(query) = spatial_query {
122        execute_spatial_search(&query, limit * 2)?
123    } else {
124        Vec::new()
125    };
126
127    // Fuse results
128    let fused = fusion.fuse(
129        &text_results,
130        &vector_results,
131        &spatial_results,
132        Some(fusion_strategy),
133    )?;
134
135    // Convert to SPARQL results and limit
136    let results: Vec<SparqlMultimodalResult> = fused
137        .into_iter()
138        .take(limit)
139        .map(SparqlMultimodalResult::from)
140        .collect();
141
142    Ok(results)
143}
144
145/// Parse fusion strategy from string
146fn parse_fusion_strategy(
147    strategy: Option<&str>,
148    weights: Option<&str>,
149    default_weights: &[f64],
150    cascade_thresholds: &[f64],
151) -> Result<FusionStrategy> {
152    match strategy {
153        Some("weighted") => {
154            let w = if let Some(weights_str) = weights {
155                parse_weights(weights_str)?
156            } else {
157                default_weights.to_vec()
158            };
159            Ok(FusionStrategy::Weighted { weights: w })
160        }
161        Some("sequential") => {
162            // Default order: Text → Vector
163            Ok(FusionStrategy::Sequential {
164                order: vec![Modality::Text, Modality::Vector],
165            })
166        }
167        Some("cascade") => Ok(FusionStrategy::Cascade {
168            thresholds: cascade_thresholds.to_vec(),
169        }),
170        Some("rankfusion") | None => Ok(FusionStrategy::RankFusion),
171        Some(s) => anyhow::bail!("Unknown fusion strategy: {}", s),
172    }
173}
174
175/// Parse normalization method from string
176fn parse_normalization(normalization: &str) -> Result<NormalizationMethod> {
177    match normalization.to_lowercase().as_str() {
178        "minmax" => Ok(NormalizationMethod::MinMax),
179        "zscore" => Ok(NormalizationMethod::ZScore),
180        "sigmoid" => Ok(NormalizationMethod::Sigmoid),
181        _ => anyhow::bail!("Unknown normalization method: {}", normalization),
182    }
183}
184
185/// Parse weights from comma-separated string
186fn parse_weights(weights_str: &str) -> Result<Vec<f64>> {
187    weights_str
188        .split(',')
189        .map(|s| {
190            s.trim()
191                .parse::<f64>()
192                .context("Failed to parse weight value")
193        })
194        .collect()
195}
196
197/// Parse vector embedding from comma-separated string
198fn parse_vector(vector_str: &str) -> Result<Vec<f32>> {
199    vector_str
200        .split(',')
201        .map(|s| {
202            s.trim()
203                .parse::<f32>()
204                .context("Failed to parse vector value")
205        })
206        .collect()
207}
208
209/// Execute text/keyword search (placeholder - integrate with actual text search)
210fn execute_text_search(query: &str, limit: usize) -> Result<Vec<DocumentScore>> {
211    // This is a placeholder implementation
212    // In production, integrate with Tantivy or BM25 search
213    Ok(vec![
214        DocumentScore {
215            doc_id: format!("text_result_1_{}", query),
216            score: 10.0,
217            rank: 0,
218        },
219        DocumentScore {
220            doc_id: format!("text_result_2_{}", query),
221            score: 8.0,
222            rank: 1,
223        },
224    ]
225    .into_iter()
226    .take(limit)
227    .collect())
228}
229
230/// Execute vector/semantic search (placeholder - integrate with actual vector search)
231fn execute_vector_search(embedding: &[f32], limit: usize) -> Result<Vec<DocumentScore>> {
232    // This is a placeholder implementation
233    // In production, integrate with HNSW or vector index
234    Ok(vec![
235        DocumentScore {
236            doc_id: format!("vector_result_1_dim{}", embedding.len()),
237            score: 0.95,
238            rank: 0,
239        },
240        DocumentScore {
241            doc_id: format!("vector_result_2_dim{}", embedding.len()),
242            score: 0.90,
243            rank: 1,
244        },
245    ]
246    .into_iter()
247    .take(limit)
248    .collect())
249}
250
251/// Execute spatial/geographic search (placeholder - integrate with actual spatial search)
252fn execute_spatial_search(wkt: &str, limit: usize) -> Result<Vec<DocumentScore>> {
253    // This is a placeholder implementation
254    // In production, integrate with GeoSPARQL or spatial index
255    Ok(vec![
256        DocumentScore {
257            doc_id: format!("spatial_result_1_{}", wkt),
258            score: 0.99,
259            rank: 0,
260        },
261        DocumentScore {
262            doc_id: format!("spatial_result_2_{}", wkt),
263            score: 0.92,
264            rank: 1,
265        },
266    ]
267    .into_iter()
268    .take(limit)
269    .collect())
270}
271
272/// Convert SPARQL arguments to multimodal search
273pub fn sparql_multimodal_search_from_args(
274    args: &[VectorServiceArg],
275    config: &MultimodalSearchConfig,
276) -> Result<VectorServiceResult> {
277    // Parse arguments
278    let mut text_query: Option<String> = None;
279    let vector_query: Option<String> = None;
280    let spatial_query: Option<String> = None;
281    let weights: Option<String> = None;
282    let strategy: Option<String> = None;
283    let mut limit: usize = 10;
284
285    // Extract named arguments (simplified parsing)
286    for arg in args {
287        match arg {
288            VectorServiceArg::String(s) => {
289                if text_query.is_none() {
290                    text_query = Some(s.clone());
291                }
292            }
293            VectorServiceArg::Number(n) => {
294                limit = *n as usize;
295            }
296            _ => {}
297        }
298    }
299
300    // Execute search
301    let results = sparql_multimodal_search(
302        text_query,
303        vector_query,
304        spatial_query,
305        weights,
306        strategy,
307        limit,
308        config,
309    )?;
310
311    // Convert to SPARQL result format
312    let similarity_list: Vec<(String, f32)> = results
313        .into_iter()
314        .map(|r| (r.uri, r.score as f32))
315        .collect();
316
317    Ok(VectorServiceResult::SimilarityList(similarity_list))
318}
319
/// Usage documentation for `vec:multimodal_search`, returned verbatim by
/// [`generate_multimodal_sparql_function`].
const MULTIMODAL_SPARQL_FUNCTION_DOC: &str = r#"
PREFIX vec: <http://oxirs.org/vec#>
PREFIX geo: <http://www.opengis.net/ont/geosparql#>

# Multimodal Search Function
# Combines text, vector, and spatial search with intelligent fusion
#
# Usage:
# SELECT ?entity ?score WHERE {
#   ?entity vec:multimodal_search(
#     text: "machine learning conference",
#     vector: "0.1,0.2,0.3,...",
#     spatial: "POINT(10.0 20.0)",
#     weights: "0.4,0.4,0.2",
#     strategy: "rankfusion",
#     limit: 10
#   ) .
#   BIND(vec:score(?entity) AS ?score)
# }
# ORDER BY DESC(?score)
#
# Parameters:
#   - text: Text/keyword query (optional)
#   - vector: Comma-separated embedding values (optional)
#   - spatial: WKT geometry string (optional)
#   - weights: Comma-separated weights [text, vector, spatial] (optional)
#   - strategy: Fusion strategy - "weighted", "sequential", "cascade", "rankfusion" (optional)
#   - limit: Maximum results (default: 10)
#
# Fusion Strategies:
#   - weighted: Linear combination of normalized scores
#   - sequential: Filter with one modality, rank with another
#   - cascade: Progressive filtering (fast → expensive)
#   - rankfusion: Reciprocal Rank Fusion (position-based)
"#;

/// Generate the SPARQL prefix declarations and usage notes for multimodal search.
///
/// Returns an owned copy of a fixed text block.
pub fn generate_multimodal_sparql_function() -> String {
    MULTIMODAL_SPARQL_FUNCTION_DOC.to_owned()
}
359
#[cfg(test)]
mod tests {
    //! Unit tests for the SPARQL multimodal bindings: argument parsing, result
    //! conversion, the placeholder modality searches, and end-to-end fusion.

    use super::*;
    use anyhow::Result;

    #[test]
    fn test_parse_fusion_strategy_weighted() -> Result<()> {
        let strategy = parse_fusion_strategy(
            Some("weighted"),
            Some("0.5,0.3,0.2"),
            &[0.33, 0.33, 0.34],
            &[0.5, 0.7, 0.8],
        )?;

        match strategy {
            FusionStrategy::Weighted { weights } => {
                assert_eq!(weights.len(), 3);
                assert!((weights[0] - 0.5).abs() < 1e-6);
                assert!((weights[1] - 0.3).abs() < 1e-6);
                assert!((weights[2] - 0.2).abs() < 1e-6);
            }
            _ => panic!("Expected Weighted strategy"),
        }
        Ok(())
    }

    #[test]
    fn test_parse_fusion_strategy_weighted_falls_back_to_default_weights() -> Result<()> {
        let strategy =
            parse_fusion_strategy(Some("weighted"), None, &[0.33, 0.33, 0.34], &[0.5, 0.7, 0.8])?;

        match strategy {
            FusionStrategy::Weighted { weights } => assert_eq!(weights, vec![0.33, 0.33, 0.34]),
            _ => panic!("Expected Weighted strategy"),
        }
        Ok(())
    }

    #[test]
    fn test_parse_fusion_strategy_sequential() -> Result<()> {
        let strategy = parse_fusion_strategy(
            Some("sequential"),
            None,
            &[0.33, 0.33, 0.34],
            &[0.5, 0.7, 0.8],
        )?;

        match strategy {
            FusionStrategy::Sequential { order } => {
                assert!(matches!(
                    order.as_slice(),
                    [Modality::Text, Modality::Vector]
                ));
            }
            _ => panic!("Expected Sequential strategy"),
        }
        Ok(())
    }

    #[test]
    fn test_parse_fusion_strategy_cascade() -> Result<()> {
        let strategy =
            parse_fusion_strategy(Some("cascade"), None, &[0.33, 0.33, 0.34], &[0.5, 0.7, 0.8])?;

        match strategy {
            FusionStrategy::Cascade { thresholds } => assert_eq!(thresholds, vec![0.5, 0.7, 0.8]),
            _ => panic!("Expected Cascade strategy"),
        }
        Ok(())
    }

    #[test]
    fn test_parse_fusion_strategy_default() -> Result<()> {
        let strategy = parse_fusion_strategy(None, None, &[0.33, 0.33, 0.34], &[0.5, 0.7, 0.8])?;

        match strategy {
            FusionStrategy::RankFusion => {}
            _ => panic!("Expected RankFusion as default"),
        }
        Ok(())
    }

    #[test]
    fn test_parse_fusion_strategy_unknown_is_error() {
        let result =
            parse_fusion_strategy(Some("bogus"), None, &[0.33, 0.33, 0.34], &[0.5, 0.7, 0.8]);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_normalization() -> Result<()> {
        assert!(matches!(
            parse_normalization("minmax")?,
            NormalizationMethod::MinMax
        ));
        assert!(matches!(
            parse_normalization("zscore")?,
            NormalizationMethod::ZScore
        ));
        assert!(matches!(
            parse_normalization("sigmoid")?,
            NormalizationMethod::Sigmoid
        ));
        Ok(())
    }

    #[test]
    fn test_parse_normalization_is_case_insensitive_and_rejects_unknown() -> Result<()> {
        assert!(matches!(
            parse_normalization("MinMax")?,
            NormalizationMethod::MinMax
        ));
        assert!(parse_normalization("l2").is_err());
        Ok(())
    }

    #[test]
    fn test_parse_weights() -> Result<()> {
        let weights = parse_weights("0.4, 0.35, 0.25")?;
        assert_eq!(weights.len(), 3);
        assert!((weights[0] - 0.4).abs() < 1e-6);
        assert!((weights[1] - 0.35).abs() < 1e-6);
        assert!((weights[2] - 0.25).abs() < 1e-6);
        Ok(())
    }

    #[test]
    fn test_parse_weights_rejects_invalid() {
        assert!(parse_weights("0.5,abc").is_err());
        assert!(parse_weights("").is_err());
    }

    #[test]
    fn test_parse_vector() -> Result<()> {
        let vector = parse_vector("0.1, 0.2, 0.3")?;
        assert_eq!(vector.len(), 3);
        assert!((vector[0] - 0.1).abs() < 1e-6);
        assert!((vector[1] - 0.2).abs() < 1e-6);
        assert!((vector[2] - 0.3).abs() < 1e-6);
        Ok(())
    }

    #[test]
    fn test_parse_vector_rejects_invalid() {
        assert!(parse_vector("0.1,,0.3").is_err());
        assert!(parse_vector("0.1,x").is_err());
    }

    #[test]
    fn test_multimodal_search_config_default() {
        let config = MultimodalSearchConfig::default();
        assert_eq!(config.default_weights.len(), 3);
        assert_eq!(config.default_strategy, "rankfusion");
        assert_eq!(config.normalization, "minmax");
        assert_eq!(config.cascade_thresholds.len(), 3);
    }

    #[test]
    fn test_sparql_multimodal_result_conversion() {
        let mut fused = FusedResult::new("test_doc".to_string());
        fused.add_score(Modality::Text, 0.5);
        fused.add_score(Modality::Vector, 0.3);
        fused.calculate_total();

        let sparql_result: SparqlMultimodalResult = fused.into();

        assert_eq!(sparql_result.uri, "test_doc");
        assert!((sparql_result.score - 0.8).abs() < 1e-6);
        assert_eq!(sparql_result.text_score, Some(0.5));
        assert_eq!(sparql_result.vector_score, Some(0.3));
        assert_eq!(sparql_result.spatial_score, None);
    }

    #[test]
    fn test_execute_text_search() -> Result<()> {
        let results = execute_text_search("test query", 10)?;
        assert!(!results.is_empty());
        assert!(results[0].doc_id.contains("test query"));
        Ok(())
    }

    #[test]
    fn test_execute_vector_search() -> Result<()> {
        let embedding = vec![0.1, 0.2, 0.3];
        let results = execute_vector_search(&embedding, 10)?;
        assert!(!results.is_empty());
        assert!(results[0].doc_id.contains("dim3"));
        Ok(())
    }

    #[test]
    fn test_execute_spatial_search() -> Result<()> {
        let results = execute_spatial_search("POINT(10.0 20.0)", 10)?;
        assert!(!results.is_empty());
        assert!(results[0].doc_id.contains("POINT"));
        Ok(())
    }

    #[test]
    fn test_execute_searches_respect_limit() -> Result<()> {
        assert_eq!(execute_text_search("q", 1)?.len(), 1);
        assert!(execute_vector_search(&[0.1], 0)?.is_empty());
        assert_eq!(execute_spatial_search("POINT(0 0)", 5)?.len(), 2);
        Ok(())
    }

    #[test]
    fn test_sparql_multimodal_search_integration() -> Result<()> {
        let config = MultimodalSearchConfig::default();

        let results = sparql_multimodal_search(
            Some("machine learning".to_string()),
            Some("0.1,0.2,0.3".to_string()),
            Some("POINT(10.0 20.0)".to_string()),
            Some("0.4,0.4,0.2".to_string()),
            Some("rankfusion".to_string()),
            10,
            &config,
        )?;

        assert!(!results.is_empty());
        assert!(results[0].score > 0.0);
        Ok(())
    }

    #[test]
    fn test_sparql_multimodal_search_respects_limit() -> Result<()> {
        let config = MultimodalSearchConfig::default();

        let results = sparql_multimodal_search(
            Some("machine learning".to_string()),
            Some("0.1,0.2,0.3".to_string()),
            Some("POINT(10.0 20.0)".to_string()),
            None,
            Some("rankfusion".to_string()),
            1,
            &config,
        )?;

        assert_eq!(results.len(), 1);
        Ok(())
    }

    #[test]
    fn test_generate_multimodal_sparql_function() {
        let sparql = generate_multimodal_sparql_function();
        assert!(sparql.contains("vec:multimodal_search"));
        assert!(sparql.contains("text:"));
        assert!(sparql.contains("vector:"));
        assert!(sparql.contains("spatial:"));
    }
}