1use super::config::{VectorServiceArg, VectorServiceResult};
7use crate::hybrid_search::multimodal_fusion::{
8 FusedResult, FusionConfig, FusionStrategy, Modality, MultimodalFusion, NormalizationMethod,
9};
10use crate::hybrid_search::types::DocumentScore;
11use anyhow::{Context, Result};
12use serde::{Deserialize, Serialize};
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct MultimodalSearchConfig {
17 pub default_weights: Vec<f64>,
19 pub default_strategy: String,
21 pub normalization: String,
23 pub cascade_thresholds: Vec<f64>,
25}
26
27impl Default for MultimodalSearchConfig {
28 fn default() -> Self {
29 Self {
30 default_weights: vec![0.33, 0.33, 0.34],
31 default_strategy: "rankfusion".to_string(),
32 normalization: "minmax".to_string(),
33 cascade_thresholds: vec![0.5, 0.7, 0.8],
34 }
35 }
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct SparqlMultimodalResult {
41 pub uri: String,
43 pub score: f64,
45 pub text_score: Option<f64>,
47 pub vector_score: Option<f64>,
48 pub spatial_score: Option<f64>,
49}
50
51impl From<FusedResult> for SparqlMultimodalResult {
52 fn from(result: FusedResult) -> Self {
53 let text_score = result.get_score(Modality::Text);
54 let vector_score = result.get_score(Modality::Vector);
55 let spatial_score = result.get_score(Modality::Spatial);
56
57 Self {
58 uri: result.uri,
59 score: result.total_score,
60 text_score,
61 vector_score,
62 spatial_score,
63 }
64 }
65}
66
67pub fn sparql_multimodal_search(
81 text_query: Option<String>,
82 vector_query: Option<String>,
83 spatial_query: Option<String>,
84 weights: Option<String>,
85 strategy: Option<String>,
86 limit: usize,
87 config: &MultimodalSearchConfig,
88) -> Result<Vec<SparqlMultimodalResult>> {
89 let fusion_strategy = parse_fusion_strategy(
91 strategy.as_deref(),
92 weights.as_deref(),
93 &config.default_weights,
94 &config.cascade_thresholds,
95 )?;
96
97 let normalization = parse_normalization(&config.normalization)?;
99
100 let fusion_config = FusionConfig {
102 default_strategy: fusion_strategy.clone(),
103 score_normalization: normalization,
104 };
105 let fusion = MultimodalFusion::new(fusion_config);
106
107 let text_results = if let Some(query) = text_query {
109 execute_text_search(&query, limit * 2)?
110 } else {
111 Vec::new()
112 };
113
114 let vector_results = if let Some(query) = vector_query {
115 let embedding = parse_vector(&query)?;
116 execute_vector_search(&embedding, limit * 2)?
117 } else {
118 Vec::new()
119 };
120
121 let spatial_results = if let Some(query) = spatial_query {
122 execute_spatial_search(&query, limit * 2)?
123 } else {
124 Vec::new()
125 };
126
127 let fused = fusion.fuse(
129 &text_results,
130 &vector_results,
131 &spatial_results,
132 Some(fusion_strategy),
133 )?;
134
135 let results: Vec<SparqlMultimodalResult> = fused
137 .into_iter()
138 .take(limit)
139 .map(SparqlMultimodalResult::from)
140 .collect();
141
142 Ok(results)
143}
144
145fn parse_fusion_strategy(
147 strategy: Option<&str>,
148 weights: Option<&str>,
149 default_weights: &[f64],
150 cascade_thresholds: &[f64],
151) -> Result<FusionStrategy> {
152 match strategy {
153 Some("weighted") => {
154 let w = if let Some(weights_str) = weights {
155 parse_weights(weights_str)?
156 } else {
157 default_weights.to_vec()
158 };
159 Ok(FusionStrategy::Weighted { weights: w })
160 }
161 Some("sequential") => {
162 Ok(FusionStrategy::Sequential {
164 order: vec![Modality::Text, Modality::Vector],
165 })
166 }
167 Some("cascade") => Ok(FusionStrategy::Cascade {
168 thresholds: cascade_thresholds.to_vec(),
169 }),
170 Some("rankfusion") | None => Ok(FusionStrategy::RankFusion),
171 Some(s) => anyhow::bail!("Unknown fusion strategy: {}", s),
172 }
173}
174
175fn parse_normalization(normalization: &str) -> Result<NormalizationMethod> {
177 match normalization.to_lowercase().as_str() {
178 "minmax" => Ok(NormalizationMethod::MinMax),
179 "zscore" => Ok(NormalizationMethod::ZScore),
180 "sigmoid" => Ok(NormalizationMethod::Sigmoid),
181 _ => anyhow::bail!("Unknown normalization method: {}", normalization),
182 }
183}
184
185fn parse_weights(weights_str: &str) -> Result<Vec<f64>> {
187 weights_str
188 .split(',')
189 .map(|s| {
190 s.trim()
191 .parse::<f64>()
192 .context("Failed to parse weight value")
193 })
194 .collect()
195}
196
197fn parse_vector(vector_str: &str) -> Result<Vec<f32>> {
199 vector_str
200 .split(',')
201 .map(|s| {
202 s.trim()
203 .parse::<f32>()
204 .context("Failed to parse vector value")
205 })
206 .collect()
207}
208
209fn execute_text_search(query: &str, limit: usize) -> Result<Vec<DocumentScore>> {
211 Ok(vec![
214 DocumentScore {
215 doc_id: format!("text_result_1_{}", query),
216 score: 10.0,
217 rank: 0,
218 },
219 DocumentScore {
220 doc_id: format!("text_result_2_{}", query),
221 score: 8.0,
222 rank: 1,
223 },
224 ]
225 .into_iter()
226 .take(limit)
227 .collect())
228}
229
230fn execute_vector_search(embedding: &[f32], limit: usize) -> Result<Vec<DocumentScore>> {
232 Ok(vec![
235 DocumentScore {
236 doc_id: format!("vector_result_1_dim{}", embedding.len()),
237 score: 0.95,
238 rank: 0,
239 },
240 DocumentScore {
241 doc_id: format!("vector_result_2_dim{}", embedding.len()),
242 score: 0.90,
243 rank: 1,
244 },
245 ]
246 .into_iter()
247 .take(limit)
248 .collect())
249}
250
251fn execute_spatial_search(wkt: &str, limit: usize) -> Result<Vec<DocumentScore>> {
253 Ok(vec![
256 DocumentScore {
257 doc_id: format!("spatial_result_1_{}", wkt),
258 score: 0.99,
259 rank: 0,
260 },
261 DocumentScore {
262 doc_id: format!("spatial_result_2_{}", wkt),
263 score: 0.92,
264 rank: 1,
265 },
266 ]
267 .into_iter()
268 .take(limit)
269 .collect())
270}
271
272pub fn sparql_multimodal_search_from_args(
274 args: &[VectorServiceArg],
275 config: &MultimodalSearchConfig,
276) -> Result<VectorServiceResult> {
277 let mut text_query: Option<String> = None;
279 let vector_query: Option<String> = None;
280 let spatial_query: Option<String> = None;
281 let weights: Option<String> = None;
282 let strategy: Option<String> = None;
283 let mut limit: usize = 10;
284
285 for arg in args {
287 match arg {
288 VectorServiceArg::String(s) => {
289 if text_query.is_none() {
290 text_query = Some(s.clone());
291 }
292 }
293 VectorServiceArg::Number(n) => {
294 limit = *n as usize;
295 }
296 _ => {}
297 }
298 }
299
300 let results = sparql_multimodal_search(
302 text_query,
303 vector_query,
304 spatial_query,
305 weights,
306 strategy,
307 limit,
308 config,
309 )?;
310
311 let similarity_list: Vec<(String, f32)> = results
313 .into_iter()
314 .map(|r| (r.uri, r.score as f32))
315 .collect();
316
317 Ok(VectorServiceResult::SimilarityList(similarity_list))
318}
319
/// Returns a SPARQL comment block documenting the `vec:multimodal_search` function.
///
/// The text is help/documentation output only; it lists the parameters and fusion
/// strategies accepted by the multimodal search.
pub fn generate_multimodal_sparql_function() -> String {
    r#"
PREFIX vec: <http://oxirs.org/vec#>
PREFIX geo: <http://www.opengis.net/ont/geosparql#>

# Multimodal Search Function
# Combines text, vector, and spatial search with intelligent fusion
#
# Usage:
#   SELECT ?entity ?score WHERE {
#     ?entity vec:multimodal_search(
#       text: "machine learning conference",
#       vector: "0.1,0.2,0.3,...",
#       spatial: "POINT(10.0 20.0)",
#       weights: "0.4,0.4,0.2",
#       strategy: "weighted",
#       limit: 10
#     ) .
#     BIND(vec:score(?entity) AS ?score)
#   }
#   ORDER BY DESC(?score)
#
# Parameters:
#   - text: Text/keyword query (optional)
#   - vector: Comma-separated embedding values (optional)
#   - spatial: WKT geometry string (optional)
#   - weights: Comma-separated weights [text, vector, spatial]
#              (optional; only used by the "weighted" strategy)
#   - strategy: Fusion strategy - "weighted", "sequential", "cascade", "rankfusion"
#               (optional, default: "rankfusion")
#   - limit: Maximum results (default: 10)
#
# Fusion Strategies:
#   - weighted: Linear combination of normalized scores
#   - sequential: Filter with one modality, rank with another
#   - cascade: Progressive filtering (fast → expensive)
#   - rankfusion: Reciprocal Rank Fusion (position-based)
"#
    .to_string()
}
359
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;

    #[test]
    fn test_parse_fusion_strategy_weighted() -> Result<()> {
        let strategy = parse_fusion_strategy(
            Some("weighted"),
            Some("0.5,0.3,0.2"),
            &[0.33, 0.33, 0.34],
            &[0.5, 0.7, 0.8],
        )?;

        match strategy {
            FusionStrategy::Weighted { weights } => {
                assert_eq!(weights.len(), 3);
                assert!((weights[0] - 0.5).abs() < 1e-6);
                assert!((weights[1] - 0.3).abs() < 1e-6);
                assert!((weights[2] - 0.2).abs() < 1e-6);
            }
            _ => panic!("Expected Weighted strategy"),
        }
        Ok(())
    }

    // Without explicit weights the "weighted" strategy must fall back to the defaults.
    #[test]
    fn test_parse_fusion_strategy_weighted_uses_default_weights() -> Result<()> {
        let strategy =
            parse_fusion_strategy(Some("weighted"), None, &[0.33, 0.33, 0.34], &[0.5, 0.7, 0.8])?;

        match strategy {
            FusionStrategy::Weighted { weights } => assert_eq!(weights, vec![0.33, 0.33, 0.34]),
            _ => panic!("Expected Weighted strategy"),
        }
        Ok(())
    }

    #[test]
    fn test_parse_fusion_strategy_cascade_uses_thresholds() -> Result<()> {
        let strategy =
            parse_fusion_strategy(Some("cascade"), None, &[0.33, 0.33, 0.34], &[0.5, 0.7, 0.8])?;

        match strategy {
            FusionStrategy::Cascade { thresholds } => assert_eq!(thresholds, vec![0.5, 0.7, 0.8]),
            _ => panic!("Expected Cascade strategy"),
        }
        Ok(())
    }

    #[test]
    fn test_parse_fusion_strategy_default() -> Result<()> {
        let strategy = parse_fusion_strategy(None, None, &[0.33, 0.33, 0.34], &[0.5, 0.7, 0.8])?;

        match strategy {
            FusionStrategy::RankFusion => {}
            _ => panic!("Expected RankFusion as default"),
        }
        Ok(())
    }

    #[test]
    fn test_parse_fusion_strategy_unknown_is_error() {
        let result =
            parse_fusion_strategy(Some("bogus"), None, &[0.33, 0.33, 0.34], &[0.5, 0.7, 0.8]);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_normalization() -> Result<()> {
        assert!(matches!(
            parse_normalization("minmax")?,
            NormalizationMethod::MinMax
        ));
        assert!(matches!(
            parse_normalization("zscore")?,
            NormalizationMethod::ZScore
        ));
        assert!(matches!(
            parse_normalization("sigmoid")?,
            NormalizationMethod::Sigmoid
        ));
        Ok(())
    }

    #[test]
    fn test_parse_normalization_is_case_insensitive() -> Result<()> {
        assert!(matches!(
            parse_normalization("MinMax")?,
            NormalizationMethod::MinMax
        ));
        assert!(matches!(
            parse_normalization("ZSCORE")?,
            NormalizationMethod::ZScore
        ));
        Ok(())
    }

    #[test]
    fn test_parse_normalization_unknown_is_error() {
        assert!(parse_normalization("l2").is_err());
    }

    #[test]
    fn test_parse_weights() -> Result<()> {
        let weights = parse_weights("0.4, 0.35, 0.25")?;
        assert_eq!(weights.len(), 3);
        assert!((weights[0] - 0.4).abs() < 1e-6);
        assert!((weights[1] - 0.35).abs() < 1e-6);
        assert!((weights[2] - 0.25).abs() < 1e-6);
        Ok(())
    }

    #[test]
    fn test_parse_weights_rejects_malformed_input() {
        assert!(parse_weights("0.5,abc").is_err());
        assert!(parse_weights("").is_err());
    }

    #[test]
    fn test_parse_vector() -> Result<()> {
        let vector = parse_vector("0.1, 0.2, 0.3")?;
        assert_eq!(vector.len(), 3);
        assert!((vector[0] - 0.1).abs() < 1e-6);
        assert!((vector[1] - 0.2).abs() < 1e-6);
        assert!((vector[2] - 0.3).abs() < 1e-6);
        Ok(())
    }

    #[test]
    fn test_parse_vector_rejects_malformed_input() {
        assert!(parse_vector("0.1,,0.3").is_err());
        assert!(parse_vector("").is_err());
    }

    #[test]
    fn test_multimodal_search_config_default() {
        let config = MultimodalSearchConfig::default();
        assert_eq!(config.default_weights.len(), 3);
        assert_eq!(config.default_strategy, "rankfusion");
        assert_eq!(config.normalization, "minmax");
        assert_eq!(config.cascade_thresholds.len(), 3);
    }

    #[test]
    fn test_sparql_multimodal_result_conversion() {
        let mut fused = FusedResult::new("test_doc".to_string());
        fused.add_score(Modality::Text, 0.5);
        fused.add_score(Modality::Vector, 0.3);
        fused.calculate_total();

        let sparql_result: SparqlMultimodalResult = fused.into();

        assert_eq!(sparql_result.uri, "test_doc");
        assert!((sparql_result.score - 0.8).abs() < 1e-6);
        assert_eq!(sparql_result.text_score, Some(0.5));
        assert_eq!(sparql_result.vector_score, Some(0.3));
        assert_eq!(sparql_result.spatial_score, None);
    }

    #[test]
    fn test_execute_text_search() -> Result<()> {
        let results = execute_text_search("test query", 10)?;
        assert!(!results.is_empty());
        assert!(results[0].doc_id.contains("test query"));
        Ok(())
    }

    #[test]
    fn test_execute_vector_search() -> Result<()> {
        let embedding = vec![0.1, 0.2, 0.3];
        let results = execute_vector_search(&embedding, 10)?;
        assert!(!results.is_empty());
        assert!(results[0].doc_id.contains("dim3"));
        Ok(())
    }

    #[test]
    fn test_execute_spatial_search() -> Result<()> {
        let results = execute_spatial_search("POINT(10.0 20.0)", 10)?;
        assert!(!results.is_empty());
        assert!(results[0].doc_id.contains("POINT"));
        Ok(())
    }

    #[test]
    fn test_sparql_multimodal_search_integration() -> Result<()> {
        let config = MultimodalSearchConfig::default();

        let results = sparql_multimodal_search(
            Some("machine learning".to_string()),
            Some("0.1,0.2,0.3".to_string()),
            Some("POINT(10.0 20.0)".to_string()),
            Some("0.4,0.4,0.2".to_string()),
            Some("rankfusion".to_string()),
            10,
            &config,
        )?;

        assert!(!results.is_empty());
        assert!(results[0].score > 0.0);
        Ok(())
    }

    #[test]
    fn test_sparql_multimodal_search_respects_limit() -> Result<()> {
        let config = MultimodalSearchConfig::default();

        let results = sparql_multimodal_search(
            Some("machine learning".to_string()),
            Some("0.1,0.2,0.3".to_string()),
            Some("POINT(10.0 20.0)".to_string()),
            None,
            Some("rankfusion".to_string()),
            1,
            &config,
        )?;

        assert!(results.len() <= 1);
        Ok(())
    }

    #[test]
    fn test_sparql_multimodal_search_rejects_bad_vector() {
        let config = MultimodalSearchConfig::default();

        let result = sparql_multimodal_search(
            None,
            Some("0.1,abc".to_string()),
            None,
            None,
            Some("rankfusion".to_string()),
            10,
            &config,
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_generate_multimodal_sparql_function() {
        let sparql = generate_multimodal_sparql_function();
        assert!(sparql.contains("vec:multimodal_search"));
        assert!(sparql.contains("text:"));
        assert!(sparql.contains("vector:"));
        assert!(sparql.contains("spatial:"));
    }
}