1use super::config::{
4 VectorParameterType, VectorQuery, VectorServiceArg, VectorServiceFunction,
5 VectorServiceParameter, VectorServiceResult,
6};
7use super::query_executor::QueryExecutor;
8use anyhow::{anyhow, Result};
9use std::collections::HashMap;
10
11pub trait CustomVectorFunction: Send + Sync {
13 fn execute(&self, args: &[VectorServiceArg]) -> Result<VectorServiceResult>;
14 fn arity(&self) -> usize;
15 fn description(&self) -> String;
16}
17
18pub struct SparqlVectorFunctions {
20 function_registry: HashMap<String, VectorServiceFunction>,
21 custom_functions: HashMap<String, Box<dyn CustomVectorFunction>>,
22}
23
24impl SparqlVectorFunctions {
25 pub fn new() -> Self {
26 let mut functions = Self {
27 function_registry: HashMap::new(),
28 custom_functions: HashMap::new(),
29 };
30
31 functions.register_default_functions();
32 functions
33 }
34
35 fn register_default_functions(&mut self) {
37 self.function_registry.insert(
39 "similarity".to_string(),
40 VectorServiceFunction {
41 name: "similarity".to_string(),
42 arity: 2,
43 description: "Calculate similarity between two resources".to_string(),
44 parameters: vec![
45 VectorServiceParameter {
46 name: "resource1".to_string(),
47 param_type: VectorParameterType::IRI,
48 required: true,
49 description: "First resource for similarity comparison".to_string(),
50 },
51 VectorServiceParameter {
52 name: "resource2".to_string(),
53 param_type: VectorParameterType::IRI,
54 required: true,
55 description: "Second resource for similarity comparison".to_string(),
56 },
57 ],
58 },
59 );
60
61 self.function_registry.insert(
63 "similar".to_string(),
64 VectorServiceFunction {
65 name: "similar".to_string(),
66 arity: 3,
67 description: "Find similar resources to a given resource".to_string(),
68 parameters: vec![
69 VectorServiceParameter {
70 name: "resource".to_string(),
71 param_type: VectorParameterType::IRI,
72 required: true,
73 description: "Resource to find similar items for".to_string(),
74 },
75 VectorServiceParameter {
76 name: "limit".to_string(),
77 param_type: VectorParameterType::Number,
78 required: false,
79 description: "Maximum number of results to return".to_string(),
80 },
81 VectorServiceParameter {
82 name: "threshold".to_string(),
83 param_type: VectorParameterType::Number,
84 required: false,
85 description: "Minimum similarity threshold".to_string(),
86 },
87 ],
88 },
89 );
90
91 self.function_registry.insert(
93 "search".to_string(),
94 VectorServiceFunction {
95 name: "search".to_string(),
96 arity: 6,
97 description: "Search for resources using text query with cross-language support"
98 .to_string(),
99 parameters: vec![
100 VectorServiceParameter {
101 name: "query_text".to_string(),
102 param_type: VectorParameterType::String,
103 required: true,
104 description: "Text query for search".to_string(),
105 },
106 VectorServiceParameter {
107 name: "limit".to_string(),
108 param_type: VectorParameterType::Number,
109 required: false,
110 description: "Maximum number of results to return".to_string(),
111 },
112 VectorServiceParameter {
113 name: "threshold".to_string(),
114 param_type: VectorParameterType::Number,
115 required: false,
116 description: "Minimum similarity threshold".to_string(),
117 },
118 VectorServiceParameter {
119 name: "metric".to_string(),
120 param_type: VectorParameterType::String,
121 required: false,
122 description: "Similarity metric to use".to_string(),
123 },
124 VectorServiceParameter {
125 name: "cross_language".to_string(),
126 param_type: VectorParameterType::String,
127 required: false,
128 description: "Enable cross-language search (true/false)".to_string(),
129 },
130 VectorServiceParameter {
131 name: "languages".to_string(),
132 param_type: VectorParameterType::String,
133 required: false,
134 description: "Comma-separated list of target languages".to_string(),
135 },
136 ],
137 },
138 );
139
140 self.function_registry.insert(
142 "searchIn".to_string(),
143 VectorServiceFunction {
144 name: "searchIn".to_string(),
145 arity: 5,
146 description: "Search within a specific graph with scoping options".to_string(),
147 parameters: vec![
148 VectorServiceParameter {
149 name: "query".to_string(),
150 param_type: VectorParameterType::String,
151 required: true,
152 description: "Text query for search".to_string(),
153 },
154 VectorServiceParameter {
155 name: "graph".to_string(),
156 param_type: VectorParameterType::IRI,
157 required: true,
158 description: "Target graph IRI for scoped search".to_string(),
159 },
160 VectorServiceParameter {
161 name: "limit".to_string(),
162 param_type: VectorParameterType::Number,
163 required: false,
164 description: "Maximum number of results to return".to_string(),
165 },
166 VectorServiceParameter {
167 name: "scope".to_string(),
168 param_type: VectorParameterType::String,
169 required: false,
170 description:
171 "Search scope: 'exact', 'children', 'parents', 'hierarchy', 'related'"
172 .to_string(),
173 },
174 VectorServiceParameter {
175 name: "threshold".to_string(),
176 param_type: VectorParameterType::Number,
177 required: false,
178 description: "Minimum similarity threshold for results".to_string(),
179 },
180 ],
181 },
182 );
183
184 self.function_registry.insert(
186 "embed".to_string(),
187 VectorServiceFunction {
188 name: "embed".to_string(),
189 arity: 1,
190 description: "Generate embedding for text content".to_string(),
191 parameters: vec![VectorServiceParameter {
192 name: "text".to_string(),
193 param_type: VectorParameterType::String,
194 required: true,
195 description: "Text content to generate embedding for".to_string(),
196 }],
197 },
198 );
199
200 self.function_registry.insert(
202 "cluster".to_string(),
203 VectorServiceFunction {
204 name: "cluster".to_string(),
205 arity: 2,
206 description: "Cluster similar resources".to_string(),
207 parameters: vec![
208 VectorServiceParameter {
209 name: "resources".to_string(),
210 param_type: VectorParameterType::String,
211 required: true,
212 description: "List of resources to cluster".to_string(),
213 },
214 VectorServiceParameter {
215 name: "num_clusters".to_string(),
216 param_type: VectorParameterType::Number,
217 required: false,
218 description: "Number of clusters to create".to_string(),
219 },
220 ],
221 },
222 );
223
224 self.function_registry.insert(
226 "vector_similarity".to_string(),
227 VectorServiceFunction {
228 name: "vector_similarity".to_string(),
229 arity: 2,
230 description: "Calculate similarity between two vectors directly".to_string(),
231 parameters: vec![
232 VectorServiceParameter {
233 name: "vector1".to_string(),
234 param_type: VectorParameterType::Vector,
235 required: true,
236 description: "First vector for similarity comparison".to_string(),
237 },
238 VectorServiceParameter {
239 name: "vector2".to_string(),
240 param_type: VectorParameterType::Vector,
241 required: true,
242 description: "Second vector for similarity comparison".to_string(),
243 },
244 ],
245 },
246 );
247
248 self.function_registry.insert(
250 "embed_text".to_string(),
251 VectorServiceFunction {
252 name: "embed_text".to_string(),
253 arity: 1,
254 description: "Generate embedding for text content".to_string(),
255 parameters: vec![VectorServiceParameter {
256 name: "text".to_string(),
257 param_type: VectorParameterType::String,
258 required: true,
259 description: "Text content to generate embedding for".to_string(),
260 }],
261 },
262 );
263
264 self.function_registry.insert(
266 "search_text".to_string(),
267 VectorServiceFunction {
268 name: "search_text".to_string(),
269 arity: 2,
270 description: "Search for resources using text query".to_string(),
271 parameters: vec![
272 VectorServiceParameter {
273 name: "query_text".to_string(),
274 param_type: VectorParameterType::String,
275 required: true,
276 description: "Text query for search".to_string(),
277 },
278 VectorServiceParameter {
279 name: "limit".to_string(),
280 param_type: VectorParameterType::Number,
281 required: false,
282 description: "Maximum number of results to return".to_string(),
283 },
284 ],
285 },
286 );
287 }
288
289 pub fn register_function(&mut self, function: VectorServiceFunction) {
291 self.function_registry
292 .insert(function.name.clone(), function);
293 }
294
295 pub fn register_custom_function(
297 &mut self,
298 name: String,
299 function: Box<dyn CustomVectorFunction>,
300 ) {
301 self.custom_functions.insert(name, function);
302 }
303
304 pub fn execute_function(
306 &self,
307 function_name: &str,
308 args: &[VectorServiceArg],
309 executor: &mut QueryExecutor,
310 ) -> Result<VectorServiceResult> {
311 if let Some(custom_func) = self.custom_functions.get(function_name) {
313 return custom_func.execute(args);
314 }
315
316 if let Some(func_def) = self.function_registry.get(function_name) {
318 if func_def.arity > 0 && args.len() > func_def.arity {
320 return Err(anyhow!(
321 "Function {} expects at most {} arguments, got {}",
322 function_name,
323 func_def.arity,
324 args.len()
325 ));
326 }
327
328 match function_name {
330 "vector_similarity" => self.execute_vector_similarity(args),
331 "embed_text" | "embed" => self.execute_embed_text(args, executor),
332 _ => {
333 let query = VectorQuery::new(function_name.to_string(), args.to_vec());
335 let result = executor.execute_optimized_query(&query)?;
336
337 match function_name {
339 "similarity" => {
340 if let Some((_, score)) = result.results.first() {
342 Ok(VectorServiceResult::Number(*score))
343 } else {
344 Ok(VectorServiceResult::Number(0.0))
345 }
346 }
347 _ => Ok(VectorServiceResult::SimilarityList(result.results)),
348 }
349 }
350 }
351 } else {
352 Err(anyhow!("Unknown function: {}", function_name))
353 }
354 }
355
356 fn execute_vector_similarity(&self, args: &[VectorServiceArg]) -> Result<VectorServiceResult> {
358 if args.len() != 2 {
359 return Err(anyhow!(
360 "vector_similarity requires exactly 2 vector arguments"
361 ));
362 }
363
364 let vector1 = match &args[0] {
365 VectorServiceArg::Vector(v) => v,
366 _ => return Err(anyhow!("First argument must be a vector")),
367 };
368
369 let vector2 = match &args[1] {
370 VectorServiceArg::Vector(v) => v,
371 _ => return Err(anyhow!("Second argument must be a vector")),
372 };
373
374 let similarity = vector1.cosine_similarity(vector2)?;
375 Ok(VectorServiceResult::Number(similarity))
376 }
377
378 fn execute_embed_text(
380 &self,
381 args: &[VectorServiceArg],
382 executor: &mut QueryExecutor,
383 ) -> Result<VectorServiceResult> {
384 if args.is_empty() {
385 return Err(anyhow!("embed_text requires at least 1 argument"));
386 }
387
388 let _text = match &args[0] {
389 VectorServiceArg::String(s) | VectorServiceArg::Literal(s) => s,
390 _ => return Err(anyhow!("First argument must be text")),
391 };
392
393 let query = VectorQuery::new("embed".to_string(), args.to_vec());
395 let _result = executor.execute_optimized_query(&query)?;
396
397 let vector = crate::Vector::new(vec![0.0; 384]); Ok(VectorServiceResult::Vector(vector))
401 }
402
403 pub fn get_function(&self, name: &str) -> Option<&VectorServiceFunction> {
405 self.function_registry.get(name)
406 }
407
408 pub fn get_all_functions(&self) -> &HashMap<String, VectorServiceFunction> {
410 &self.function_registry
411 }
412
413 pub fn is_function_registered(&self, name: &str) -> bool {
415 self.function_registry.contains_key(name) || self.custom_functions.contains_key(name)
416 }
417
418 pub fn get_function_documentation(&self, name: &str) -> Option<String> {
420 if let Some(func) = self.function_registry.get(name) {
421 let mut doc = format!("Function: {}\n", func.name);
422 doc.push_str(&format!("Description: {}\n", func.description));
423 doc.push_str(&format!("Arity: {}\n", func.arity));
424 doc.push_str("Parameters:\n");
425
426 for param in &func.parameters {
427 doc.push_str(&format!(
428 " - {} ({:?}{}): {}\n",
429 param.name,
430 param.param_type,
431 if param.required {
432 ", required"
433 } else {
434 ", optional"
435 },
436 param.description
437 ));
438 }
439
440 Some(doc)
441 } else {
442 self.custom_functions.get(name).map(|custom_func| {
443 format!(
444 "Custom Function: {}\nDescription: {}\nArity: {}",
445 name,
446 custom_func.description(),
447 custom_func.arity()
448 )
449 })
450 }
451 }
452
453 pub fn generate_sparql_definitions(&self) -> String {
455 let mut definitions = String::new();
456 definitions.push_str("# OxiRS Vector SPARQL Functions\n\n");
457
458 for (name, func) in &self.function_registry {
459 definitions.push_str(&format!("## vec:{name}\n\n"));
460 definitions.push_str(&format!("**Description:** {}\n\n", func.description));
461
462 if func.arity > 0 {
463 definitions.push_str(&format!("**Arity:** {}\n\n", func.arity));
464 }
465
466 definitions.push_str("**Parameters:**\n\n");
467 for param in &func.parameters {
468 definitions.push_str(&format!(
469 "- `{}` ({:?}{}) - {}\n",
470 param.name,
471 param.param_type,
472 if param.required {
473 ", required"
474 } else {
475 ", optional"
476 },
477 param.description
478 ));
479 }
480
481 definitions.push_str("\n**Example:**\n\n");
483 definitions.push_str("```sparql\n");
484 match name.as_str() {
485 "similarity" => {
486 definitions.push_str("SELECT ?score WHERE {\n");
487 definitions.push_str(" BIND(vec:similarity(<http://example.org/doc1>, <http://example.org/doc2>) AS ?score)\n");
488 definitions.push_str("}\n");
489 }
490 "similar" => {
491 definitions.push_str("SELECT ?similar ?score WHERE {\n");
492 definitions.push_str(
493 " (?similar ?score) vec:similar (<http://example.org/doc1>, 10, 0.7)\n",
494 );
495 definitions.push_str("}\n");
496 }
497 "search" => {
498 definitions.push_str("SELECT ?resource ?score WHERE {\n");
499 definitions.push_str(
500 " (?resource ?score) vec:search (\"machine learning\", 10, 0.7)\n",
501 );
502 definitions.push_str("}\n");
503 }
504 "searchIn" => {
505 definitions.push_str("SELECT ?resource ?score WHERE {\n");
506 definitions.push_str(" (?resource ?score) vec:searchIn (\"AI research\", <http://example.org/graph1>, 10, \"exact\", 0.7)\n");
507 definitions.push_str("}\n");
508 }
509 "embed" => {
510 definitions.push_str("SELECT ?embedding WHERE {\n");
511 definitions.push_str(" BIND(vec:embed(\"example text\") AS ?embedding)\n");
512 definitions.push_str("}\n");
513 }
514 _ => {
515 definitions.push_str(&format!("# Example usage for vec:{name}\n"));
516 }
517 }
518 definitions.push_str("```\n\n");
519 }
520
521 definitions
522 }
523}
524
525impl Default for SparqlVectorFunctions {
526 fn default() -> Self {
527 Self::new()
528 }
529}
530
531pub struct CosineSimilarityFunction;
533
534impl CustomVectorFunction for CosineSimilarityFunction {
535 fn execute(&self, args: &[VectorServiceArg]) -> Result<VectorServiceResult> {
536 if args.len() != 2 {
537 return Err(anyhow!(
538 "Cosine similarity function requires exactly 2 arguments"
539 ));
540 }
541
542 let vector1 = match &args[0] {
543 VectorServiceArg::Vector(v) => v,
544 _ => return Err(anyhow!("First argument must be a vector")),
545 };
546
547 let vector2 = match &args[1] {
548 VectorServiceArg::Vector(v) => v,
549 _ => return Err(anyhow!("Second argument must be a vector")),
550 };
551
552 let similarity =
553 crate::similarity::cosine_similarity(&vector1.as_slice(), &vector2.as_slice());
554
555 Ok(VectorServiceResult::Number(similarity))
556 }
557
558 fn arity(&self) -> usize {
559 2
560 }
561
562 fn description(&self) -> String {
563 "Calculate cosine similarity between two vectors".to_string()
564 }
565}
566
567pub struct AverageSimilarityFunction;
569
570impl CustomVectorFunction for AverageSimilarityFunction {
571 fn execute(&self, args: &[VectorServiceArg]) -> Result<VectorServiceResult> {
572 if args.is_empty() {
573 return Err(anyhow!(
574 "Average similarity function requires at least 1 argument"
575 ));
576 }
577
578 let mut similarities = Vec::new();
579
580 for arg in args {
581 match arg {
582 VectorServiceArg::Number(sim) => similarities.push(*sim),
583 _ => return Err(anyhow!("All arguments must be numbers (similarity scores)")),
584 }
585 }
586
587 let average = similarities.iter().sum::<f32>() / similarities.len() as f32;
588 Ok(VectorServiceResult::Number(average))
589 }
590
591 fn arity(&self) -> usize {
592 0 }
594
595 fn description(&self) -> String {
596 "Calculate average of multiple similarity scores".to_string()
597 }
598}
599
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Vector;

    /// Extracts the numeric payload, failing the test for any other variant.
    fn expect_number(result: VectorServiceResult) -> f32 {
        match result {
            VectorServiceResult::Number(value) => value,
            _ => panic!("Expected number result"),
        }
    }

    #[test]
    fn test_function_registration() {
        let registry = SparqlVectorFunctions::new();

        for builtin in ["similarity", "similar", "search", "searchIn"].iter() {
            assert!(registry.is_function_registered(builtin));
        }
        assert!(!registry.is_function_registered("nonexistent"));
    }

    #[test]
    fn test_custom_function_registration() {
        let mut registry = SparqlVectorFunctions::new();

        registry.register_custom_function(
            "custom_cosine".to_string(),
            Box::new(CosineSimilarityFunction),
        );

        assert!(registry.is_function_registered("custom_cosine"));
    }

    #[test]
    fn test_custom_function_execution() {
        // Orthogonal unit vectors have a cosine similarity of zero.
        let args = vec![
            VectorServiceArg::Vector(Vector::new(vec![1.0, 0.0, 0.0])),
            VectorServiceArg::Vector(Vector::new(vec![0.0, 1.0, 0.0])),
        ];

        let similarity = expect_number(CosineSimilarityFunction.execute(&args).unwrap());

        assert!(similarity.abs() < 1e-6);
    }

    #[test]
    fn test_function_documentation() {
        let registry = SparqlVectorFunctions::new();

        let doc = registry.get_function_documentation("similarity").unwrap();

        for expected in ["similarity", "Calculate similarity", "resource1", "resource2"].iter() {
            assert!(doc.contains(expected));
        }
    }

    #[test]
    fn test_sparql_definitions_generation() {
        let registry = SparqlVectorFunctions::new();

        let definitions = registry.generate_sparql_definitions();

        for expected in ["vec:similarity", "vec:search", "SELECT", "```sparql"].iter() {
            assert!(definitions.contains(expected));
        }
    }

    #[test]
    fn test_average_similarity_function() {
        let args: Vec<VectorServiceArg> = [0.8, 0.9, 0.7]
            .iter()
            .map(|score| VectorServiceArg::Number(*score))
            .collect();

        let average = expect_number(AverageSimilarityFunction.execute(&args).unwrap());

        assert!((average - 0.8).abs() < 1e-6);
    }
}