1use super::config::{VectorServiceArg, VectorServiceResult};
7use crate::hybrid_search::multimodal_fusion::{
8 FusedResult, FusionConfig, FusionStrategy, Modality, MultimodalFusion, NormalizationMethod,
9};
10use crate::hybrid_search::types::DocumentScore;
11use anyhow::{Context, Result};
12use serde::{Deserialize, Serialize};
13
/// Configuration for SPARQL-driven multimodal search.
///
/// All fields are plain serializable values; the string fields are parsed into
/// fusion types when a search is executed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultimodalSearchConfig {
    /// Weights used by the `"weighted"` strategy when the caller supplies none.
    // NOTE(review): presumably ordered [text, vector, spatial] like the
    // `weights` query parameter — confirm against the fusion implementation.
    pub default_weights: Vec<f64>,
    /// Name of the default fusion strategy. The strategy parser in this file
    /// recognises `"weighted"`, `"sequential"`, `"cascade"` and `"rankfusion"`.
    pub default_strategy: String,
    /// Name of the score normalization method: `"minmax"`, `"zscore"` or
    /// `"sigmoid"` (matched case-insensitively when parsed).
    pub normalization: String,
    /// Thresholds handed to the `"cascade"` fusion strategy.
    pub cascade_thresholds: Vec<f64>,
}
26
27impl Default for MultimodalSearchConfig {
28 fn default() -> Self {
29 Self {
30 default_weights: vec![0.33, 0.33, 0.34],
31 default_strategy: "rankfusion".to_string(),
32 normalization: "minmax".to_string(),
33 cascade_thresholds: vec![0.5, 0.7, 0.8],
34 }
35 }
36}
37
/// A single fused search hit in the shape returned to SPARQL callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SparqlMultimodalResult {
    /// Identifier of the matched resource (copied from the fused result's `uri`).
    pub uri: String,
    /// Combined score (copied from the fused result's `total_score`).
    pub score: f64,
    /// Score reported for the text modality, if the fused result has one.
    pub text_score: Option<f64>,
    /// Score reported for the vector modality, if the fused result has one.
    pub vector_score: Option<f64>,
    /// Score reported for the spatial modality, if the fused result has one.
    pub spatial_score: Option<f64>,
}
50
51impl From<FusedResult> for SparqlMultimodalResult {
52 fn from(result: FusedResult) -> Self {
53 let text_score = result.get_score(Modality::Text);
54 let vector_score = result.get_score(Modality::Vector);
55 let spatial_score = result.get_score(Modality::Spatial);
56
57 Self {
58 uri: result.uri,
59 score: result.total_score,
60 text_score,
61 vector_score,
62 spatial_score,
63 }
64 }
65}
66
67pub fn sparql_multimodal_search(
81 text_query: Option<String>,
82 vector_query: Option<String>,
83 spatial_query: Option<String>,
84 weights: Option<String>,
85 strategy: Option<String>,
86 limit: usize,
87 config: &MultimodalSearchConfig,
88) -> Result<Vec<SparqlMultimodalResult>> {
89 let fusion_strategy = parse_fusion_strategy(
91 strategy.as_deref(),
92 weights.as_deref(),
93 &config.default_weights,
94 &config.cascade_thresholds,
95 )?;
96
97 let normalization = parse_normalization(&config.normalization)?;
99
100 let fusion_config = FusionConfig {
102 default_strategy: fusion_strategy.clone(),
103 score_normalization: normalization,
104 };
105 let fusion = MultimodalFusion::new(fusion_config);
106
107 let text_results = if let Some(query) = text_query {
109 execute_text_search(&query, limit * 2)?
110 } else {
111 Vec::new()
112 };
113
114 let vector_results = if let Some(query) = vector_query {
115 let embedding = parse_vector(&query)?;
116 execute_vector_search(&embedding, limit * 2)?
117 } else {
118 Vec::new()
119 };
120
121 let spatial_results = if let Some(query) = spatial_query {
122 execute_spatial_search(&query, limit * 2)?
123 } else {
124 Vec::new()
125 };
126
127 let fused = fusion.fuse(
129 &text_results,
130 &vector_results,
131 &spatial_results,
132 Some(fusion_strategy),
133 )?;
134
135 let results: Vec<SparqlMultimodalResult> = fused
137 .into_iter()
138 .take(limit)
139 .map(SparqlMultimodalResult::from)
140 .collect();
141
142 Ok(results)
143}
144
145fn parse_fusion_strategy(
147 strategy: Option<&str>,
148 weights: Option<&str>,
149 default_weights: &[f64],
150 cascade_thresholds: &[f64],
151) -> Result<FusionStrategy> {
152 match strategy {
153 Some("weighted") => {
154 let w = if let Some(weights_str) = weights {
155 parse_weights(weights_str)?
156 } else {
157 default_weights.to_vec()
158 };
159 Ok(FusionStrategy::Weighted { weights: w })
160 }
161 Some("sequential") => {
162 Ok(FusionStrategy::Sequential {
164 order: vec![Modality::Text, Modality::Vector],
165 })
166 }
167 Some("cascade") => Ok(FusionStrategy::Cascade {
168 thresholds: cascade_thresholds.to_vec(),
169 }),
170 Some("rankfusion") | None => Ok(FusionStrategy::RankFusion),
171 Some(s) => anyhow::bail!("Unknown fusion strategy: {}", s),
172 }
173}
174
175fn parse_normalization(normalization: &str) -> Result<NormalizationMethod> {
177 match normalization.to_lowercase().as_str() {
178 "minmax" => Ok(NormalizationMethod::MinMax),
179 "zscore" => Ok(NormalizationMethod::ZScore),
180 "sigmoid" => Ok(NormalizationMethod::Sigmoid),
181 _ => anyhow::bail!("Unknown normalization method: {}", normalization),
182 }
183}
184
185fn parse_weights(weights_str: &str) -> Result<Vec<f64>> {
187 weights_str
188 .split(',')
189 .map(|s| {
190 s.trim()
191 .parse::<f64>()
192 .context("Failed to parse weight value")
193 })
194 .collect()
195}
196
197fn parse_vector(vector_str: &str) -> Result<Vec<f32>> {
199 vector_str
200 .split(',')
201 .map(|s| {
202 s.trim()
203 .parse::<f32>()
204 .context("Failed to parse vector value")
205 })
206 .collect()
207}
208
209fn execute_text_search(query: &str, limit: usize) -> Result<Vec<DocumentScore>> {
211 Ok(vec![
214 DocumentScore {
215 doc_id: format!("text_result_1_{}", query),
216 score: 10.0,
217 rank: 0,
218 },
219 DocumentScore {
220 doc_id: format!("text_result_2_{}", query),
221 score: 8.0,
222 rank: 1,
223 },
224 ]
225 .into_iter()
226 .take(limit)
227 .collect())
228}
229
230fn execute_vector_search(embedding: &[f32], limit: usize) -> Result<Vec<DocumentScore>> {
232 Ok(vec![
235 DocumentScore {
236 doc_id: format!("vector_result_1_dim{}", embedding.len()),
237 score: 0.95,
238 rank: 0,
239 },
240 DocumentScore {
241 doc_id: format!("vector_result_2_dim{}", embedding.len()),
242 score: 0.90,
243 rank: 1,
244 },
245 ]
246 .into_iter()
247 .take(limit)
248 .collect())
249}
250
251fn execute_spatial_search(wkt: &str, limit: usize) -> Result<Vec<DocumentScore>> {
253 Ok(vec![
256 DocumentScore {
257 doc_id: format!("spatial_result_1_{}", wkt),
258 score: 0.99,
259 rank: 0,
260 },
261 DocumentScore {
262 doc_id: format!("spatial_result_2_{}", wkt),
263 score: 0.92,
264 rank: 1,
265 },
266 ]
267 .into_iter()
268 .take(limit)
269 .collect())
270}
271
272pub fn sparql_multimodal_search_from_args(
274 args: &[VectorServiceArg],
275 config: &MultimodalSearchConfig,
276) -> Result<VectorServiceResult> {
277 let mut text_query: Option<String> = None;
279 let vector_query: Option<String> = None;
280 let spatial_query: Option<String> = None;
281 let weights: Option<String> = None;
282 let strategy: Option<String> = None;
283 let mut limit: usize = 10;
284
285 for arg in args {
287 match arg {
288 VectorServiceArg::String(s) if text_query.is_none() => {
289 text_query = Some(s.clone());
290 }
291 VectorServiceArg::Number(n) => {
292 limit = *n as usize;
293 }
294 _ => {}
295 }
296 }
297
298 let results = sparql_multimodal_search(
300 text_query,
301 vector_query,
302 spatial_query,
303 weights,
304 strategy,
305 limit,
306 config,
307 )?;
308
309 let similarity_list: Vec<(String, f32)> = results
311 .into_iter()
312 .map(|r| (r.uri, r.score as f32))
313 .collect();
314
315 Ok(VectorServiceResult::SimilarityList(similarity_list))
316}
317
/// Returns the SPARQL prefix declarations and commented usage notes for the
/// `vec:multimodal_search` function as an owned string.
///
/// The text is static; a fresh `String` is allocated on every call.
pub fn generate_multimodal_sparql_function() -> String {
    const MULTIMODAL_SPARQL_DOC: &str = r#"
PREFIX vec: <http://oxirs.org/vec#>
PREFIX geo: <http://www.opengis.net/ont/geosparql#>

# Multimodal Search Function
# Combines text, vector, and spatial search with intelligent fusion
#
# Usage:
# SELECT ?entity ?score WHERE {
# ?entity vec:multimodal_search(
# text: "machine learning conference",
# vector: "0.1,0.2,0.3,...",
# spatial: "POINT(10.0 20.0)",
# weights: "0.4,0.4,0.2",
# strategy: "rankfusion",
# limit: 10
# ) .
# BIND(vec:score(?entity) AS ?score)
# }
# ORDER BY DESC(?score)
#
# Parameters:
# - text: Text/keyword query (optional)
# - vector: Comma-separated embedding values (optional)
# - spatial: WKT geometry string (optional)
# - weights: Comma-separated weights [text, vector, spatial] (optional)
# - strategy: Fusion strategy - "weighted", "sequential", "cascade", "rankfusion" (optional)
# - limit: Maximum results (default: 10)
#
# Fusion Strategies:
# - weighted: Linear combination of normalized scores
# - sequential: Filter with one modality, rank with another
# - cascade: Progressive filtering (fast → expensive)
# - rankfusion: Reciprocal Rank Fusion (position-based)
"#;

    MULTIMODAL_SPARQL_DOC.to_owned()
}
357
#[cfg(test)]
mod tests {
    //! Unit tests for the parsing helpers, the placeholder searches, the result
    //! conversion and the end-to-end search entry point.

    use super::*;
    use anyhow::Result;

    /// An explicit weight string overrides the defaults for `"weighted"`.
    #[test]
    fn test_parse_fusion_strategy_weighted() -> Result<()> {
        let defaults = [0.33, 0.33, 0.34];
        let thresholds = [0.5, 0.7, 0.8];
        let parsed =
            parse_fusion_strategy(Some("weighted"), Some("0.5,0.3,0.2"), &defaults, &thresholds)?;

        let weights = match parsed {
            FusionStrategy::Weighted { weights } => weights,
            _ => panic!("Expected Weighted strategy"),
        };

        assert_eq!(weights.len(), 3);
        for (actual, expected) in weights.iter().zip([0.5_f64, 0.3, 0.2].iter()) {
            assert!((actual - expected).abs() < 1e-6);
        }
        Ok(())
    }

    /// With no strategy name the parser yields rank fusion.
    #[test]
    fn test_parse_fusion_strategy_default() -> Result<()> {
        let parsed = parse_fusion_strategy(None, None, &[0.33, 0.33, 0.34], &[0.5, 0.7, 0.8])?;

        assert!(
            matches!(parsed, FusionStrategy::RankFusion),
            "Expected RankFusion as default"
        );
        Ok(())
    }

    /// Every supported normalization name maps to its variant.
    #[test]
    fn test_parse_normalization() -> Result<()> {
        let minmax = parse_normalization("minmax")?;
        let zscore = parse_normalization("zscore")?;
        let sigmoid = parse_normalization("sigmoid")?;

        assert!(matches!(minmax, NormalizationMethod::MinMax));
        assert!(matches!(zscore, NormalizationMethod::ZScore));
        assert!(matches!(sigmoid, NormalizationMethod::Sigmoid));
        Ok(())
    }

    /// Weights are parsed with surrounding whitespace ignored.
    #[test]
    fn test_parse_weights() -> Result<()> {
        let weights = parse_weights("0.4, 0.35, 0.25")?;

        assert_eq!(weights.len(), 3);
        for (actual, expected) in weights.iter().zip([0.4_f64, 0.35, 0.25].iter()) {
            assert!((actual - expected).abs() < 1e-6);
        }
        Ok(())
    }

    /// Embedding components are parsed with surrounding whitespace ignored.
    #[test]
    fn test_parse_vector() -> Result<()> {
        let vector = parse_vector("0.1, 0.2, 0.3")?;

        assert_eq!(vector.len(), 3);
        for (actual, expected) in vector.iter().zip([0.1_f32, 0.2, 0.3].iter()) {
            assert!((actual - expected).abs() < 1e-6);
        }
        Ok(())
    }

    /// The default configuration carries the documented stock values.
    #[test]
    fn test_multimodal_search_config_default() {
        let MultimodalSearchConfig {
            default_weights,
            default_strategy,
            normalization,
            cascade_thresholds,
        } = MultimodalSearchConfig::default();

        assert_eq!(default_weights.len(), 3);
        assert_eq!(default_strategy, "rankfusion");
        assert_eq!(normalization, "minmax");
        assert_eq!(cascade_thresholds.len(), 3);
    }

    /// Converting a fused result keeps the URI, total and per-modality scores.
    #[test]
    fn test_sparql_multimodal_result_conversion() {
        let mut fused = FusedResult::new("test_doc".to_string());
        fused.add_score(Modality::Text, 0.5);
        fused.add_score(Modality::Vector, 0.3);
        fused.calculate_total();

        let converted = SparqlMultimodalResult::from(fused);

        assert_eq!(converted.uri, "test_doc");
        assert!((converted.score - 0.8).abs() < 1e-6);
        assert_eq!(converted.text_score, Some(0.5));
        assert_eq!(converted.vector_score, Some(0.3));
        assert_eq!(converted.spatial_score, None);
    }

    /// The placeholder text search echoes the query in its ids.
    #[test]
    fn test_execute_text_search() -> Result<()> {
        let hits = execute_text_search("test query", 10)?;

        assert!(!hits.is_empty());
        assert!(hits[0].doc_id.contains("test query"));
        Ok(())
    }

    /// The placeholder vector search echoes the dimensionality in its ids.
    #[test]
    fn test_execute_vector_search() -> Result<()> {
        let embedding = vec![0.1, 0.2, 0.3];
        let hits = execute_vector_search(&embedding, 10)?;

        assert!(!hits.is_empty());
        assert!(hits[0].doc_id.contains("dim3"));
        Ok(())
    }

    /// The placeholder spatial search echoes the geometry in its ids.
    #[test]
    fn test_execute_spatial_search() -> Result<()> {
        let hits = execute_spatial_search("POINT(10.0 20.0)", 10)?;

        assert!(!hits.is_empty());
        assert!(hits[0].doc_id.contains("POINT"));
        Ok(())
    }

    /// All three modalities together produce a non-empty, positively scored list.
    #[test]
    fn test_sparql_multimodal_search_integration() -> Result<()> {
        let config = MultimodalSearchConfig::default();
        let text = Some("machine learning".to_string());
        let vector = Some("0.1,0.2,0.3".to_string());
        let spatial = Some("POINT(10.0 20.0)".to_string());
        let weights = Some("0.4,0.4,0.2".to_string());
        let strategy = Some("rankfusion".to_string());

        let hits = sparql_multimodal_search(text, vector, spatial, weights, strategy, 10, &config)?;

        assert!(!hits.is_empty());
        assert!(hits[0].score > 0.0);
        Ok(())
    }

    /// The generated SPARQL help text mentions the function and its key parameters.
    #[test]
    fn test_generate_multimodal_sparql_function() {
        let sparql = generate_multimodal_sparql_function();

        for needle in ["vec:multimodal_search", "text:", "vector:", "spatial:"].iter() {
            assert!(sparql.contains(needle));
        }
    }
}