oxirs_vec/sparql_integration/
sparql_functions.rs

1//! SPARQL vector function implementations
2
3use super::config::{
4    VectorParameterType, VectorQuery, VectorServiceArg, VectorServiceFunction,
5    VectorServiceParameter, VectorServiceResult,
6};
7use super::query_executor::QueryExecutor;
8use anyhow::{anyhow, Result};
9use std::collections::HashMap;
10
11/// Custom vector function trait for user-defined functions
12pub trait CustomVectorFunction: Send + Sync {
13    fn execute(&self, args: &[VectorServiceArg]) -> Result<VectorServiceResult>;
14    fn arity(&self) -> usize;
15    fn description(&self) -> String;
16}
17
18/// SPARQL vector functions implementation
19pub struct SparqlVectorFunctions {
20    function_registry: HashMap<String, VectorServiceFunction>,
21    custom_functions: HashMap<String, Box<dyn CustomVectorFunction>>,
22}
23
24impl SparqlVectorFunctions {
25    pub fn new() -> Self {
26        let mut functions = Self {
27            function_registry: HashMap::new(),
28            custom_functions: HashMap::new(),
29        };
30
31        functions.register_default_functions();
32        functions
33    }
34
35    /// Register all default SPARQL vector functions
36    fn register_default_functions(&mut self) {
37        // vec:similarity function
38        self.function_registry.insert(
39            "similarity".to_string(),
40            VectorServiceFunction {
41                name: "similarity".to_string(),
42                arity: 2,
43                description: "Calculate similarity between two resources".to_string(),
44                parameters: vec![
45                    VectorServiceParameter {
46                        name: "resource1".to_string(),
47                        param_type: VectorParameterType::IRI,
48                        required: true,
49                        description: "First resource for similarity comparison".to_string(),
50                    },
51                    VectorServiceParameter {
52                        name: "resource2".to_string(),
53                        param_type: VectorParameterType::IRI,
54                        required: true,
55                        description: "Second resource for similarity comparison".to_string(),
56                    },
57                ],
58            },
59        );
60
61        // vec:similar function
62        self.function_registry.insert(
63            "similar".to_string(),
64            VectorServiceFunction {
65                name: "similar".to_string(),
66                arity: 3,
67                description: "Find similar resources to a given resource".to_string(),
68                parameters: vec![
69                    VectorServiceParameter {
70                        name: "resource".to_string(),
71                        param_type: VectorParameterType::IRI,
72                        required: true,
73                        description: "Resource to find similar items for".to_string(),
74                    },
75                    VectorServiceParameter {
76                        name: "limit".to_string(),
77                        param_type: VectorParameterType::Number,
78                        required: false,
79                        description: "Maximum number of results to return".to_string(),
80                    },
81                    VectorServiceParameter {
82                        name: "threshold".to_string(),
83                        param_type: VectorParameterType::Number,
84                        required: false,
85                        description: "Minimum similarity threshold".to_string(),
86                    },
87                ],
88            },
89        );
90
91        // vec:search function
92        self.function_registry.insert(
93            "search".to_string(),
94            VectorServiceFunction {
95                name: "search".to_string(),
96                arity: 6,
97                description: "Search for resources using text query with cross-language support"
98                    .to_string(),
99                parameters: vec![
100                    VectorServiceParameter {
101                        name: "query_text".to_string(),
102                        param_type: VectorParameterType::String,
103                        required: true,
104                        description: "Text query for search".to_string(),
105                    },
106                    VectorServiceParameter {
107                        name: "limit".to_string(),
108                        param_type: VectorParameterType::Number,
109                        required: false,
110                        description: "Maximum number of results to return".to_string(),
111                    },
112                    VectorServiceParameter {
113                        name: "threshold".to_string(),
114                        param_type: VectorParameterType::Number,
115                        required: false,
116                        description: "Minimum similarity threshold".to_string(),
117                    },
118                    VectorServiceParameter {
119                        name: "metric".to_string(),
120                        param_type: VectorParameterType::String,
121                        required: false,
122                        description: "Similarity metric to use".to_string(),
123                    },
124                    VectorServiceParameter {
125                        name: "cross_language".to_string(),
126                        param_type: VectorParameterType::String,
127                        required: false,
128                        description: "Enable cross-language search (true/false)".to_string(),
129                    },
130                    VectorServiceParameter {
131                        name: "languages".to_string(),
132                        param_type: VectorParameterType::String,
133                        required: false,
134                        description: "Comma-separated list of target languages".to_string(),
135                    },
136                ],
137            },
138        );
139
140        // vec:searchIn function
141        self.function_registry.insert(
142            "searchIn".to_string(),
143            VectorServiceFunction {
144                name: "searchIn".to_string(),
145                arity: 5,
146                description: "Search within a specific graph with scoping options".to_string(),
147                parameters: vec![
148                    VectorServiceParameter {
149                        name: "query".to_string(),
150                        param_type: VectorParameterType::String,
151                        required: true,
152                        description: "Text query for search".to_string(),
153                    },
154                    VectorServiceParameter {
155                        name: "graph".to_string(),
156                        param_type: VectorParameterType::IRI,
157                        required: true,
158                        description: "Target graph IRI for scoped search".to_string(),
159                    },
160                    VectorServiceParameter {
161                        name: "limit".to_string(),
162                        param_type: VectorParameterType::Number,
163                        required: false,
164                        description: "Maximum number of results to return".to_string(),
165                    },
166                    VectorServiceParameter {
167                        name: "scope".to_string(),
168                        param_type: VectorParameterType::String,
169                        required: false,
170                        description:
171                            "Search scope: 'exact', 'children', 'parents', 'hierarchy', 'related'"
172                                .to_string(),
173                    },
174                    VectorServiceParameter {
175                        name: "threshold".to_string(),
176                        param_type: VectorParameterType::Number,
177                        required: false,
178                        description: "Minimum similarity threshold for results".to_string(),
179                    },
180                ],
181            },
182        );
183
184        // vec:embed function
185        self.function_registry.insert(
186            "embed".to_string(),
187            VectorServiceFunction {
188                name: "embed".to_string(),
189                arity: 1,
190                description: "Generate embedding for text content".to_string(),
191                parameters: vec![VectorServiceParameter {
192                    name: "text".to_string(),
193                    param_type: VectorParameterType::String,
194                    required: true,
195                    description: "Text content to generate embedding for".to_string(),
196                }],
197            },
198        );
199
200        // vec:cluster function
201        self.function_registry.insert(
202            "cluster".to_string(),
203            VectorServiceFunction {
204                name: "cluster".to_string(),
205                arity: 2,
206                description: "Cluster similar resources".to_string(),
207                parameters: vec![
208                    VectorServiceParameter {
209                        name: "resources".to_string(),
210                        param_type: VectorParameterType::String,
211                        required: true,
212                        description: "List of resources to cluster".to_string(),
213                    },
214                    VectorServiceParameter {
215                        name: "num_clusters".to_string(),
216                        param_type: VectorParameterType::Number,
217                        required: false,
218                        description: "Number of clusters to create".to_string(),
219                    },
220                ],
221            },
222        );
223
224        // vec:vector_similarity function (alias for direct vector similarity)
225        self.function_registry.insert(
226            "vector_similarity".to_string(),
227            VectorServiceFunction {
228                name: "vector_similarity".to_string(),
229                arity: 2,
230                description: "Calculate similarity between two vectors directly".to_string(),
231                parameters: vec![
232                    VectorServiceParameter {
233                        name: "vector1".to_string(),
234                        param_type: VectorParameterType::Vector,
235                        required: true,
236                        description: "First vector for similarity comparison".to_string(),
237                    },
238                    VectorServiceParameter {
239                        name: "vector2".to_string(),
240                        param_type: VectorParameterType::Vector,
241                        required: true,
242                        description: "Second vector for similarity comparison".to_string(),
243                    },
244                ],
245            },
246        );
247
248        // vec:embed_text function (alias for embed)
249        self.function_registry.insert(
250            "embed_text".to_string(),
251            VectorServiceFunction {
252                name: "embed_text".to_string(),
253                arity: 1,
254                description: "Generate embedding for text content".to_string(),
255                parameters: vec![VectorServiceParameter {
256                    name: "text".to_string(),
257                    param_type: VectorParameterType::String,
258                    required: true,
259                    description: "Text content to generate embedding for".to_string(),
260                }],
261            },
262        );
263
264        // vec:search_text function (alias for search)
265        self.function_registry.insert(
266            "search_text".to_string(),
267            VectorServiceFunction {
268                name: "search_text".to_string(),
269                arity: 2,
270                description: "Search for resources using text query".to_string(),
271                parameters: vec![
272                    VectorServiceParameter {
273                        name: "query_text".to_string(),
274                        param_type: VectorParameterType::String,
275                        required: true,
276                        description: "Text query for search".to_string(),
277                    },
278                    VectorServiceParameter {
279                        name: "limit".to_string(),
280                        param_type: VectorParameterType::Number,
281                        required: false,
282                        description: "Maximum number of results to return".to_string(),
283                    },
284                ],
285            },
286        );
287    }
288
289    /// Register a custom vector service function
290    pub fn register_function(&mut self, function: VectorServiceFunction) {
291        self.function_registry
292            .insert(function.name.clone(), function);
293    }
294
295    /// Register a custom vector function implementation
296    pub fn register_custom_function(
297        &mut self,
298        name: String,
299        function: Box<dyn CustomVectorFunction>,
300    ) {
301        self.custom_functions.insert(name, function);
302    }
303
304    /// Execute a SPARQL vector function
305    pub fn execute_function(
306        &self,
307        function_name: &str,
308        args: &[VectorServiceArg],
309        executor: &mut QueryExecutor,
310    ) -> Result<VectorServiceResult> {
311        // Check if it's a custom function first
312        if let Some(custom_func) = self.custom_functions.get(function_name) {
313            return custom_func.execute(args);
314        }
315
316        // Check if it's a built-in function
317        if let Some(func_def) = self.function_registry.get(function_name) {
318            // Validate arity if specified
319            if func_def.arity > 0 && args.len() > func_def.arity {
320                return Err(anyhow!(
321                    "Function {} expects at most {} arguments, got {}",
322                    function_name,
323                    func_def.arity,
324                    args.len()
325                ));
326            }
327
328            // Handle special functions that work with vectors directly
329            match function_name {
330                "vector_similarity" => self.execute_vector_similarity(args),
331                "embed_text" | "embed" => self.execute_embed_text(args, executor),
332                _ => {
333                    // Create a query for the function
334                    let query = VectorQuery::new(function_name.to_string(), args.to_vec());
335                    let result = executor.execute_optimized_query(&query)?;
336
337                    // Convert VectorQueryResult to VectorServiceResult based on function type
338                    match function_name {
339                        "similarity" => {
340                            // For similarity between resources, return a single number
341                            if let Some((_, score)) = result.results.first() {
342                                Ok(VectorServiceResult::Number(*score))
343                            } else {
344                                Ok(VectorServiceResult::Number(0.0))
345                            }
346                        }
347                        _ => Ok(VectorServiceResult::SimilarityList(result.results)),
348                    }
349                }
350            }
351        } else {
352            Err(anyhow!("Unknown function: {}", function_name))
353        }
354    }
355
356    /// Execute vector similarity function directly on vectors
357    fn execute_vector_similarity(&self, args: &[VectorServiceArg]) -> Result<VectorServiceResult> {
358        if args.len() != 2 {
359            return Err(anyhow!(
360                "vector_similarity requires exactly 2 vector arguments"
361            ));
362        }
363
364        let vector1 = match &args[0] {
365            VectorServiceArg::Vector(v) => v,
366            _ => return Err(anyhow!("First argument must be a vector")),
367        };
368
369        let vector2 = match &args[1] {
370            VectorServiceArg::Vector(v) => v,
371            _ => return Err(anyhow!("Second argument must be a vector")),
372        };
373
374        let similarity = vector1.cosine_similarity(vector2)?;
375        Ok(VectorServiceResult::Number(similarity))
376    }
377
378    /// Execute embed text function
379    fn execute_embed_text(
380        &self,
381        args: &[VectorServiceArg],
382        executor: &mut QueryExecutor,
383    ) -> Result<VectorServiceResult> {
384        if args.is_empty() {
385            return Err(anyhow!("embed_text requires at least 1 argument"));
386        }
387
388        let _text = match &args[0] {
389            VectorServiceArg::String(s) | VectorServiceArg::Literal(s) => s,
390            _ => return Err(anyhow!("First argument must be text")),
391        };
392
393        // Use the embed query type to generate the embedding
394        let query = VectorQuery::new("embed".to_string(), args.to_vec());
395        let _result = executor.execute_optimized_query(&query)?;
396
397        // For embed functions, we want to return the vector itself
398        // This is a simplified implementation - in practice, you'd generate the actual vector
399        let vector = crate::Vector::new(vec![0.0; 384]); // Placeholder vector
400        Ok(VectorServiceResult::Vector(vector))
401    }
402
403    /// Get function definition
404    pub fn get_function(&self, name: &str) -> Option<&VectorServiceFunction> {
405        self.function_registry.get(name)
406    }
407
408    /// Get all registered functions
409    pub fn get_all_functions(&self) -> &HashMap<String, VectorServiceFunction> {
410        &self.function_registry
411    }
412
413    /// Check if a function is registered
414    pub fn is_function_registered(&self, name: &str) -> bool {
415        self.function_registry.contains_key(name) || self.custom_functions.contains_key(name)
416    }
417
418    /// Get function documentation
419    pub fn get_function_documentation(&self, name: &str) -> Option<String> {
420        if let Some(func) = self.function_registry.get(name) {
421            let mut doc = format!("Function: {}\n", func.name);
422            doc.push_str(&format!("Description: {}\n", func.description));
423            doc.push_str(&format!("Arity: {}\n", func.arity));
424            doc.push_str("Parameters:\n");
425
426            for param in &func.parameters {
427                doc.push_str(&format!(
428                    "  - {} ({:?}{}): {}\n",
429                    param.name,
430                    param.param_type,
431                    if param.required {
432                        ", required"
433                    } else {
434                        ", optional"
435                    },
436                    param.description
437                ));
438            }
439
440            Some(doc)
441        } else {
442            self.custom_functions.get(name).map(|custom_func| {
443                format!(
444                    "Custom Function: {}\nDescription: {}\nArity: {}",
445                    name,
446                    custom_func.description(),
447                    custom_func.arity()
448                )
449            })
450        }
451    }
452
453    /// Generate SPARQL function definitions for documentation
454    pub fn generate_sparql_definitions(&self) -> String {
455        let mut definitions = String::new();
456        definitions.push_str("# OxiRS Vector SPARQL Functions\n\n");
457
458        for (name, func) in &self.function_registry {
459            definitions.push_str(&format!("## vec:{name}\n\n"));
460            definitions.push_str(&format!("**Description:** {}\n\n", func.description));
461
462            if func.arity > 0 {
463                definitions.push_str(&format!("**Arity:** {}\n\n", func.arity));
464            }
465
466            definitions.push_str("**Parameters:**\n\n");
467            for param in &func.parameters {
468                definitions.push_str(&format!(
469                    "- `{}` ({:?}{}) - {}\n",
470                    param.name,
471                    param.param_type,
472                    if param.required {
473                        ", required"
474                    } else {
475                        ", optional"
476                    },
477                    param.description
478                ));
479            }
480
481            // Add usage example
482            definitions.push_str("\n**Example:**\n\n");
483            definitions.push_str("```sparql\n");
484            match name.as_str() {
485                "similarity" => {
486                    definitions.push_str("SELECT ?score WHERE {\n");
487                    definitions.push_str("  BIND(vec:similarity(<http://example.org/doc1>, <http://example.org/doc2>) AS ?score)\n");
488                    definitions.push_str("}\n");
489                }
490                "similar" => {
491                    definitions.push_str("SELECT ?similar ?score WHERE {\n");
492                    definitions.push_str(
493                        "  (?similar ?score) vec:similar (<http://example.org/doc1>, 10, 0.7)\n",
494                    );
495                    definitions.push_str("}\n");
496                }
497                "search" => {
498                    definitions.push_str("SELECT ?resource ?score WHERE {\n");
499                    definitions.push_str(
500                        "  (?resource ?score) vec:search (\"machine learning\", 10, 0.7)\n",
501                    );
502                    definitions.push_str("}\n");
503                }
504                "searchIn" => {
505                    definitions.push_str("SELECT ?resource ?score WHERE {\n");
506                    definitions.push_str("  (?resource ?score) vec:searchIn (\"AI research\", <http://example.org/graph1>, 10, \"exact\", 0.7)\n");
507                    definitions.push_str("}\n");
508                }
509                "embed" => {
510                    definitions.push_str("SELECT ?embedding WHERE {\n");
511                    definitions.push_str("  BIND(vec:embed(\"example text\") AS ?embedding)\n");
512                    definitions.push_str("}\n");
513                }
514                _ => {
515                    definitions.push_str(&format!("# Example usage for vec:{name}\n"));
516                }
517            }
518            definitions.push_str("```\n\n");
519        }
520
521        definitions
522    }
523}
524
525impl Default for SparqlVectorFunctions {
526    fn default() -> Self {
527        Self::new()
528    }
529}
530
531/// Example custom function implementation
532pub struct CosineSimilarityFunction;
533
534impl CustomVectorFunction for CosineSimilarityFunction {
535    fn execute(&self, args: &[VectorServiceArg]) -> Result<VectorServiceResult> {
536        if args.len() != 2 {
537            return Err(anyhow!(
538                "Cosine similarity function requires exactly 2 arguments"
539            ));
540        }
541
542        let vector1 = match &args[0] {
543            VectorServiceArg::Vector(v) => v,
544            _ => return Err(anyhow!("First argument must be a vector")),
545        };
546
547        let vector2 = match &args[1] {
548            VectorServiceArg::Vector(v) => v,
549            _ => return Err(anyhow!("Second argument must be a vector")),
550        };
551
552        let similarity =
553            crate::similarity::cosine_similarity(&vector1.as_slice(), &vector2.as_slice());
554
555        Ok(VectorServiceResult::Number(similarity))
556    }
557
558    fn arity(&self) -> usize {
559        2
560    }
561
562    fn description(&self) -> String {
563        "Calculate cosine similarity between two vectors".to_string()
564    }
565}
566
567/// Example aggregate function implementation
568pub struct AverageSimilarityFunction;
569
570impl CustomVectorFunction for AverageSimilarityFunction {
571    fn execute(&self, args: &[VectorServiceArg]) -> Result<VectorServiceResult> {
572        if args.is_empty() {
573            return Err(anyhow!(
574                "Average similarity function requires at least 1 argument"
575            ));
576        }
577
578        let mut similarities = Vec::new();
579
580        for arg in args {
581            match arg {
582                VectorServiceArg::Number(sim) => similarities.push(*sim),
583                _ => return Err(anyhow!("All arguments must be numbers (similarity scores)")),
584            }
585        }
586
587        let average = similarities.iter().sum::<f32>() / similarities.len() as f32;
588        Ok(VectorServiceResult::Number(average))
589    }
590
591    fn arity(&self) -> usize {
592        0 // Variable arity
593    }
594
595    fn description(&self) -> String {
596        "Calculate average of multiple similarity scores".to_string()
597    }
598}
599
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Vector;

    /// The default functions are registered and unknown names are rejected.
    #[test]
    fn test_function_registration() {
        let registry = SparqlVectorFunctions::new();

        for known in ["similarity", "similar", "search", "searchIn"].iter() {
            assert!(registry.is_function_registered(known));
        }
        assert!(!registry.is_function_registered("nonexistent"));
    }

    /// A custom implementation becomes visible under its registration name.
    #[test]
    fn test_custom_function_registration() {
        let mut registry = SparqlVectorFunctions::new();

        registry.register_custom_function(
            "custom_cosine".to_string(),
            Box::new(CosineSimilarityFunction),
        );

        assert!(registry.is_function_registered("custom_cosine"));
    }

    /// Orthogonal unit vectors have a cosine similarity of zero.
    #[test]
    fn test_custom_function_execution() {
        let args = vec![
            VectorServiceArg::Vector(Vector::new(vec![1.0, 0.0, 0.0])),
            VectorServiceArg::Vector(Vector::new(vec![0.0, 1.0, 0.0])),
        ];

        let outcome = CosineSimilarityFunction.execute(&args).unwrap();

        if let VectorServiceResult::Number(similarity) = outcome {
            assert!(similarity.abs() < 1e-6); // Orthogonal vectors
        } else {
            panic!("Expected number result");
        }
    }

    /// Documentation for a default function mentions its name, description
    /// and parameters.
    #[test]
    fn test_function_documentation() {
        let registry = SparqlVectorFunctions::new();

        let doc = registry.get_function_documentation("similarity").unwrap();
        for expected in ["similarity", "Calculate similarity", "resource1", "resource2"].iter() {
            assert!(doc.contains(expected));
        }
    }

    /// The generated Markdown covers the default functions and includes
    /// SPARQL examples.
    #[test]
    fn test_sparql_definitions_generation() {
        let registry = SparqlVectorFunctions::new();

        let definitions = registry.generate_sparql_definitions();
        for expected in ["vec:similarity", "vec:search", "SELECT", "```sparql"].iter() {
            assert!(definitions.contains(expected));
        }
    }

    /// The mean of 0.8, 0.9 and 0.7 is 0.8.
    #[test]
    fn test_average_similarity_function() {
        let args: Vec<VectorServiceArg> = [0.8, 0.9, 0.7]
            .iter()
            .map(|score| VectorServiceArg::Number(*score))
            .collect();

        let outcome = AverageSimilarityFunction.execute(&args).unwrap();

        if let VectorServiceResult::Number(average) = outcome {
            assert!((average - 0.8).abs() < 1e-6);
        } else {
            panic!("Expected number result");
        }
    }
}