1use crate::algebra::{Expression, Term, Variable};
7use anyhow::{anyhow, bail, Result};
8use oxirs_core::model::NamedNode;
9use std::any::Any;
10use std::collections::HashMap;
11use std::fmt::Debug;
12use std::sync::{Arc, RwLock};
13
/// Registry of user-supplied extensions: functions, operators, aggregates,
/// plugins and type converters.
///
/// Every collection lives behind its own `Arc<RwLock<..>>`, and all fields are
/// public, so callers can also reach the underlying maps directly.
#[derive(Debug)]
pub struct ExtensionRegistry {
    /// Custom functions, keyed by the value of `CustomFunction::name`.
    pub functions: Arc<RwLock<HashMap<String, Box<dyn CustomFunction>>>>,
    /// Custom operators, keyed by the value of `CustomOperator::symbol`.
    pub operators: Arc<RwLock<HashMap<String, Box<dyn CustomOperator>>>>,
    /// Custom aggregates, keyed by the value of `CustomAggregate::name`.
    pub aggregates: Arc<RwLock<HashMap<String, Box<dyn CustomAggregate>>>>,
    /// Plugins in registration order; a plugin is stored only after its
    /// `initialize` call succeeded.
    pub plugins: Arc<RwLock<Vec<Box<dyn ExtensionPlugin>>>>,
    /// Type converters, keyed by the string `"<from_type>:<to_type>"`.
    pub type_converters: Arc<RwLock<HashMap<String, Box<dyn TypeConverter>>>>,
}
28
29pub trait CustomFunction: Send + Sync + Debug {
31 fn name(&self) -> &str;
33
34 fn arity(&self) -> Option<usize>;
36
37 fn parameter_types(&self) -> Vec<ValueType>;
39
40 fn return_type(&self) -> ValueType;
42
43 fn documentation(&self) -> &str;
45
46 fn execute(&self, args: &[Value], context: &ExecutionContext) -> Result<Value>;
48
49 fn clone_function(&self) -> Box<dyn CustomFunction>;
51
52 fn validate(&self, args: &[Expression]) -> Result<()> {
54 if let Some(expected_arity) = self.arity() {
55 if args.len() != expected_arity {
56 bail!(
57 "Function {} expects {} arguments, got {}",
58 self.name(),
59 expected_arity,
60 args.len()
61 );
62 }
63 }
64 Ok(())
65 }
66
67 fn cost_estimate(&self, args: &[Expression]) -> f64 {
69 100.0 + args.len() as f64 * 10.0
71 }
72
73 fn is_deterministic(&self) -> bool {
75 true
76 }
77
78 fn can_pushdown(&self) -> bool {
80 self.is_deterministic()
81 }
82}
83
/// Contract for a user-defined operator stored in an `ExtensionRegistry`.
pub trait CustomOperator: Send + Sync + Debug {
    /// Symbol under which the operator is registered and looked up.
    fn symbol(&self) -> &str;

    /// Precedence of the operator.
    // NOTE(review): whether larger values bind tighter is not visible in this
    // file — confirm against the parser.
    fn precedence(&self) -> i32;

    /// How the operator associates (see `Associativity`).
    fn associativity(&self) -> Associativity;

    /// Operand-count category of the operator (see `OperatorType`).
    fn operator_type(&self) -> OperatorType;

    /// Applies the operator to its operands.
    ///
    /// Both operands are optional.
    // NOTE(review): presumably a unary operator receives only one side —
    // confirm which side callers populate.
    fn execute(
        &self,
        left: Option<&Value>,
        right: Option<&Value>,
        context: &ExecutionContext,
    ) -> Result<Value>;

    /// Computes the result type for the given operand types, or an error.
    fn type_check(
        &self,
        left_type: Option<ValueType>,
        right_type: Option<ValueType>,
    ) -> Result<ValueType>;
}
113
/// Contract for a user-defined aggregate stored in an `ExtensionRegistry`.
pub trait CustomAggregate: Send + Sync + Debug {
    /// Name under which the aggregate is registered and looked up.
    fn name(&self) -> &str;

    /// Creates a fresh accumulator; this is what
    /// `ExtensionRegistry::create_aggregate_state` returns.
    fn init(&self) -> Box<dyn AggregateState>;

    /// Whether the aggregate supports DISTINCT (default: `true`).
    fn supports_distinct(&self) -> bool {
        true
    }

    /// Human-readable description of the aggregate.
    fn documentation(&self) -> &str;
}
130
/// Mutable accumulator produced by `CustomAggregate::init`.
pub trait AggregateState: Send + Sync + Debug {
    /// Feeds one value into the accumulator.
    fn add(&mut self, value: &Value) -> Result<()>;

    /// Returns the aggregate result for the values added so far.
    fn result(&self) -> Result<Value>;

    /// Resets the accumulator.
    // NOTE(review): presumably back to the state `init` produced — confirm
    // with implementors.
    fn reset(&mut self);

    /// Produces a boxed copy of the accumulator.
    fn clone_state(&self) -> Box<dyn AggregateState>;
}
145
/// Contract for a plugin that installs extensions into an `ExtensionRegistry`.
pub trait ExtensionPlugin: Send + Sync + Debug {
    /// Plugin name; `ExtensionRegistry::validate_extensions` matches other
    /// plugins' dependency entries against it.
    fn name(&self) -> &str;

    /// Plugin version string.
    fn version(&self) -> &str;

    /// Names of the plugins this plugin depends on.
    fn dependencies(&self) -> Vec<String>;

    /// Called by `ExtensionRegistry::register_plugin` before the plugin is
    /// stored; an error here aborts the registration.
    fn initialize(&mut self, registry: &mut ExtensionRegistry) -> Result<()>;

    /// Shuts the plugin down.
    // NOTE(review): nothing in this file calls `shutdown` — confirm who owns
    // plugin teardown.
    fn shutdown(&mut self) -> Result<()>;

    /// Descriptive metadata about the plugin.
    fn metadata(&self) -> PluginMetadata;
}
166
/// Descriptive information about a plugin, as returned by
/// `ExtensionPlugin::metadata`.
///
/// Derives `Default`, `PartialEq` and `Eq` so metadata can be built with
/// struct-update syntax (`..Default::default()`) and compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PluginMetadata {
    /// Plugin name.
    pub name: String,
    /// Plugin version string.
    pub version: String,
    /// Author of the plugin.
    pub author: String,
    /// Short description of the plugin.
    pub description: String,
    /// License the plugin is distributed under.
    pub license: String,
    /// Optional homepage.
    pub homepage: Option<String>,
    /// Optional source repository.
    pub repository: Option<String>,
}
178
/// Contract for a user-defined conversion between two value type names.
///
/// `ExtensionRegistry::register_type_converter` stores a converter under the
/// key `"<from_type>:<to_type>"`; `ExtensionRegistry::convert_value` looks it
/// up using `Value::type_name` as the source part.
pub trait TypeConverter: Send + Sync + Debug {
    /// Name of the source type.
    // Clippy expects `from_*` methods to be constructors without `self`; this
    // one is a plain getter, hence the allow.
    #[allow(clippy::wrong_self_convention)]
    fn from_type(&self) -> &str;

    /// Name of the target type.
    fn to_type(&self) -> &str;

    /// Converts `value` to the target type.
    fn convert(&self, value: &Value) -> Result<Value>;

    /// Reports whether `value` can be converted.
    // NOTE(review): `convert_value` in this file does not consult this before
    // calling `convert` — confirm whether that is intended.
    fn can_convert(&self, value: &Value) -> bool;
}
194
/// Static type descriptor for values exchanged with extensions.
///
/// Derives `Eq` and `Hash` (every payload supports both) so a type can be used
/// as a `HashMap`/`HashSet` key, for example for signature lookup.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ValueType {
    /// Plain string.
    String,
    /// Integer number.
    Integer,
    /// Floating-point number.
    Float,
    /// Boolean.
    Boolean,
    /// Date-time.
    DateTime,
    /// Duration.
    Duration,
    /// IRI.
    Iri,
    /// Blank node.
    BlankNode,
    /// RDF literal.
    Literal,
    /// Extension-defined type identified by name.
    Custom(String),
    /// List whose elements have the given type.
    List(Box<ValueType>),
    /// Value of the given type that may be absent.
    Optional(Box<ValueType>),
    /// Any one of the listed types.
    Union(Vec<ValueType>),
}
212
/// Runtime value passed to and returned from extensions.
///
/// Only `Debug` is derived; `Clone`, `PartialEq` and `PartialOrd` are
/// implemented by hand in this file because the `Custom` payload is an opaque
/// `Box<dyn Any + Send + Sync>`.
#[derive(Debug)]
pub enum Value {
    /// Plain string.
    String(String),
    /// 64-bit signed integer.
    Integer(i64),
    /// 64-bit float.
    Float(f64),
    /// Boolean.
    Boolean(bool),
    /// UTC date-time.
    DateTime(chrono::DateTime<chrono::Utc>),
    /// Duration.
    Duration(chrono::Duration),
    /// IRI, stored as a string.
    Iri(String),
    /// Blank node identifier.
    BlankNode(String),
    /// RDF literal with optional language tag and optional datatype string.
    Literal {
        value: String,
        language: Option<String>,
        datatype: Option<String>,
    },
    /// List of values.
    List(Vec<Value>),
    /// Absent value.
    Null,
    /// Extension-defined value: a type name plus an opaque payload.
    Custom {
        type_name: String,
        data: Box<dyn Any + Send + Sync>,
    },
}
236
237impl Clone for Value {
238 fn clone(&self) -> Self {
239 match self {
240 Value::String(s) => Value::String(s.clone()),
241 Value::Integer(i) => Value::Integer(*i),
242 Value::Float(f) => Value::Float(*f),
243 Value::Boolean(b) => Value::Boolean(*b),
244 Value::DateTime(dt) => Value::DateTime(*dt),
245 Value::Duration(d) => Value::Duration(*d),
246 Value::Iri(iri) => Value::Iri(iri.clone()),
247 Value::BlankNode(id) => Value::BlankNode(id.clone()),
248 Value::Literal {
249 value,
250 language,
251 datatype,
252 } => Value::Literal {
253 value: value.clone(),
254 language: language.clone(),
255 datatype: datatype.clone(),
256 },
257 Value::List(list) => Value::List(list.clone()),
258 Value::Null => Value::Null,
259 Value::Custom { type_name, .. } => {
260 Value::String(format!("Custom({type_name})"))
262 }
263 }
264 }
265}
266
267impl PartialEq for Value {
268 fn eq(&self, other: &Self) -> bool {
269 match (self, other) {
270 (Value::String(a), Value::String(b)) => a == b,
271 (Value::Integer(a), Value::Integer(b)) => a == b,
272 (Value::Float(a), Value::Float(b)) => a == b,
273 (Value::Boolean(a), Value::Boolean(b)) => a == b,
274 (Value::DateTime(a), Value::DateTime(b)) => a == b,
275 (Value::Duration(a), Value::Duration(b)) => a == b,
276 (Value::Iri(a), Value::Iri(b)) => a == b,
277 (Value::BlankNode(a), Value::BlankNode(b)) => a == b,
278 (
279 Value::Literal {
280 value: v1,
281 language: l1,
282 datatype: d1,
283 },
284 Value::Literal {
285 value: v2,
286 language: l2,
287 datatype: d2,
288 },
289 ) => v1 == v2 && l1 == l2 && d1 == d2,
290 (Value::List(a), Value::List(b)) => a == b,
291 (Value::Null, Value::Null) => true,
292 (Value::Custom { type_name: t1, .. }, Value::Custom { type_name: t2, .. }) => t1 == t2,
293 _ => false,
294 }
295 }
296}
297
298impl PartialOrd for Value {
299 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
300 use std::cmp::Ordering;
301 match (self, other) {
302 (Value::String(a), Value::String(b)) => a.partial_cmp(b),
303 (Value::Integer(a), Value::Integer(b)) => a.partial_cmp(b),
304 (Value::Float(a), Value::Float(b)) => a.partial_cmp(b),
305 (Value::Boolean(a), Value::Boolean(b)) => a.partial_cmp(b),
306 (Value::DateTime(a), Value::DateTime(b)) => a.partial_cmp(b),
307 (Value::Duration(a), Value::Duration(b)) => a.partial_cmp(b),
308 (Value::Iri(a), Value::Iri(b)) => a.partial_cmp(b),
309 (Value::BlankNode(a), Value::BlankNode(b)) => a.partial_cmp(b),
310 (
311 Value::Literal {
312 value: v1,
313 language: l1,
314 datatype: d1,
315 },
316 Value::Literal {
317 value: v2,
318 language: l2,
319 datatype: d2,
320 },
321 ) => match v1.partial_cmp(v2) {
322 Some(Ordering::Equal) => match l1.partial_cmp(l2) {
323 Some(Ordering::Equal) => d1.partial_cmp(d2),
324 other => other,
325 },
326 other => other,
327 },
328 (Value::Integer(a), Value::Float(b)) => (*a as f64).partial_cmp(b),
329 (Value::Float(a), Value::Integer(b)) => a.partial_cmp(&(*b as f64)),
330 (Value::Null, Value::Null) => Some(Ordering::Equal),
331 (Value::Null, _) => Some(Ordering::Less),
332 (_, Value::Null) => Some(Ordering::Greater),
333 _ => None, }
335 }
336}
337
/// Associativity of a custom operator.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq`, so the enum can be
/// used as a map or set key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Associativity {
    /// Left-associative.
    Left,
    /// Right-associative.
    Right,
    /// Non-associative.
    None,
}
345
/// Operand-count category of a custom operator.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq`, so the enum can be
/// used as a map or set key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OperatorType {
    /// Two operands.
    Binary,
    /// One operand.
    Unary,
    /// Three operands.
    Ternary,
}
353
/// Context handed to custom functions and operators when they execute.
#[derive(Debug, Clone)]
pub struct ExecutionContext {
    /// Map from variable to term.
    // NOTE(review): presumably the current solution bindings — confirm.
    pub variables: HashMap<Variable, Term>,
    /// String-to-string namespace map.
    // NOTE(review): presumably prefix -> IRI — confirm.
    pub namespaces: HashMap<String, String>,
    /// Optional base IRI.
    pub base_iri: Option<String>,
    /// Optional dataset context string.
    pub dataset_context: Option<String>,
    /// UTC timestamp associated with the query.
    pub query_time: chrono::DateTime<chrono::Utc>,
    /// Optimization level in effect.
    pub optimization_level: OptimizationLevel,
    /// Optional memory limit.
    // NOTE(review): unit is not shown in this file (bytes?) — confirm.
    pub memory_limit: Option<usize>,
    /// Optional time limit.
    pub time_limit: Option<std::time::Duration>,
}
366
/// Optimization level carried in an `ExecutionContext`.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq`, so the enum can be
/// used as a map or set key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OptimizationLevel {
    /// No optimization.
    None,
    /// Basic optimization.
    Basic,
    /// Aggressive optimization.
    Aggressive,
}
374
375impl ExtensionRegistry {
376 pub fn new() -> Self {
377 Self {
378 functions: Arc::new(RwLock::new(HashMap::new())),
379 operators: Arc::new(RwLock::new(HashMap::new())),
380 aggregates: Arc::new(RwLock::new(HashMap::new())),
381 plugins: Arc::new(RwLock::new(Vec::new())),
382 type_converters: Arc::new(RwLock::new(HashMap::new())),
383 }
384 }
385
386 pub fn register_function<F>(&self, function: F) -> Result<()>
388 where
389 F: CustomFunction + 'static,
390 {
391 let name = function.name().to_string();
392 let mut functions = self
393 .functions
394 .write()
395 .map_err(|_| anyhow!("Failed to acquire write lock on functions"))?;
396 functions.insert(name, Box::new(function));
397 Ok(())
398 }
399
400 pub fn register_operator<O>(&self, operator: O) -> Result<()>
402 where
403 O: CustomOperator + 'static,
404 {
405 let symbol = operator.symbol().to_string();
406 let mut operators = self
407 .operators
408 .write()
409 .map_err(|_| anyhow!("Failed to acquire write lock on operators"))?;
410 operators.insert(symbol, Box::new(operator));
411 Ok(())
412 }
413
414 pub fn register_aggregate<A>(&self, aggregate: A) -> Result<()>
416 where
417 A: CustomAggregate + 'static,
418 {
419 let name = aggregate.name().to_string();
420 let mut aggregates = self
421 .aggregates
422 .write()
423 .map_err(|_| anyhow!("Failed to acquire write lock on aggregates"))?;
424 aggregates.insert(name, Box::new(aggregate));
425 Ok(())
426 }
427
428 pub fn register_plugin<P>(&mut self, mut plugin: P) -> Result<()>
430 where
431 P: ExtensionPlugin + 'static,
432 {
433 plugin.initialize(self)?;
435
436 let mut plugins = self
437 .plugins
438 .write()
439 .map_err(|_| anyhow!("Failed to acquire write lock on plugins"))?;
440 plugins.push(Box::new(plugin));
441 Ok(())
442 }
443
444 pub fn register_type_converter<T>(&self, converter: T) -> Result<()>
446 where
447 T: TypeConverter + 'static,
448 {
449 let key = format!("{}:{}", converter.from_type(), converter.to_type());
450 let mut converters = self
451 .type_converters
452 .write()
453 .map_err(|_| anyhow!("Failed to acquire write lock on type converters"))?;
454 converters.insert(key, Box::new(converter));
455 Ok(())
456 }
457
458 pub fn get_function(&self, name: &str) -> Result<Option<Box<dyn CustomFunction>>> {
460 let functions = self
461 .functions
462 .read()
463 .map_err(|_| anyhow!("Failed to acquire read lock on functions"))?;
464 Ok(functions.get(name).map(|f| f.clone_function()))
465 }
466
467 pub fn has_function(&self, name: &str) -> Result<bool> {
469 let functions = self
470 .functions
471 .read()
472 .map_err(|_| anyhow!("Failed to acquire read lock on functions"))?;
473 Ok(functions.contains_key(name))
474 }
475
476 pub fn has_operator(&self, symbol: &str) -> Result<bool> {
478 let operators = self
479 .operators
480 .read()
481 .map_err(|_| anyhow!("Failed to acquire read lock on operators"))?;
482 Ok(operators.contains_key(symbol))
483 }
484
485 pub fn has_aggregate(&self, name: &str) -> Result<bool> {
487 let aggregates = self
488 .aggregates
489 .read()
490 .map_err(|_| anyhow!("Failed to acquire read lock on aggregates"))?;
491 Ok(aggregates.contains_key(name))
492 }
493
494 pub fn execute_function(
496 &self,
497 name: &str,
498 args: &[Value],
499 context: &ExecutionContext,
500 ) -> Result<Value> {
501 let functions = self
502 .functions
503 .read()
504 .map_err(|_| anyhow!("Failed to acquire read lock on functions"))?;
505
506 if let Some(func) = functions.get(name) {
507 func.execute(args, context)
508 } else {
509 Err(anyhow!("Function '{}' not found", name))
510 }
511 }
512
513 pub fn execute_operator(
515 &self,
516 symbol: &str,
517 left: Option<&Value>,
518 right: Option<&Value>,
519 context: &ExecutionContext,
520 ) -> Result<Value> {
521 let operators = self
522 .operators
523 .read()
524 .map_err(|_| anyhow!("Failed to acquire read lock on operators"))?;
525
526 if let Some(op) = operators.get(symbol) {
527 op.execute(left, right, context)
528 } else {
529 Err(anyhow!("Operator '{}' not found", symbol))
530 }
531 }
532
533 pub fn create_aggregate_state(&self, name: &str) -> Result<Box<dyn AggregateState>> {
535 let aggregates = self
536 .aggregates
537 .read()
538 .map_err(|_| anyhow!("Failed to acquire read lock on aggregates"))?;
539
540 if let Some(agg) = aggregates.get(name) {
541 Ok(agg.init())
542 } else {
543 Err(anyhow!("Aggregate '{}' not found", name))
544 }
545 }
546
547 pub fn convert_value(&self, value: &Value, target_type: &str) -> Result<Value> {
549 let source_type = value.type_name();
550 let key = format!("{source_type}:{target_type}");
551
552 let converters = self
553 .type_converters
554 .read()
555 .map_err(|_| anyhow!("Failed to acquire read lock on type converters"))?;
556
557 if let Some(converter) = converters.get(&key) {
558 converter.convert(value)
559 } else {
560 self.builtin_convert(value, target_type)
562 }
563 }
564
565 fn builtin_convert(&self, value: &Value, target_type: &str) -> Result<Value> {
567 match (value, target_type) {
568 (Value::String(s), "integer") => s
569 .parse::<i64>()
570 .map(Value::Integer)
571 .map_err(|_| anyhow!("Cannot convert '{}' to integer", s)),
572 (Value::String(s), "float") => s
573 .parse::<f64>()
574 .map(Value::Float)
575 .map_err(|_| anyhow!("Cannot convert '{}' to float", s)),
576 (Value::Integer(i), "string") => Ok(Value::String(i.to_string())),
577 (Value::Float(f), "string") => Ok(Value::String(f.to_string())),
578 (Value::Boolean(b), "string") => Ok(Value::String(b.to_string())),
579 _ => bail!(
580 "No conversion available from {} to {}",
581 value.type_name(),
582 target_type
583 ),
584 }
585 }
586
587 pub fn list_functions(&self) -> Result<Vec<String>> {
589 let functions = self
590 .functions
591 .read()
592 .map_err(|_| anyhow!("Failed to acquire read lock on functions"))?;
593 Ok(functions.keys().cloned().collect())
594 }
595
596 pub fn list_operators(&self) -> Result<Vec<String>> {
598 let operators = self
599 .operators
600 .read()
601 .map_err(|_| anyhow!("Failed to acquire read lock on operators"))?;
602 Ok(operators.keys().cloned().collect())
603 }
604
605 pub fn validate_extensions(&self) -> Result<Vec<String>> {
607 let mut errors = Vec::new();
608
609 let plugins = self
611 .plugins
612 .read()
613 .map_err(|_| anyhow!("Failed to acquire read lock on plugins"))?;
614
615 for plugin in plugins.iter() {
616 for dep in plugin.dependencies() {
617 let found = plugins.iter().any(|p| p.name() == dep);
618 if !found {
619 errors.push(format!(
620 "Plugin '{}' missing dependency '{}'",
621 plugin.name(),
622 dep
623 ));
624 }
625 }
626 }
627
628 Ok(errors)
629 }
630}
631
632impl Default for ExtensionRegistry {
633 fn default() -> Self {
634 Self::new()
635 }
636}
637
638impl Value {
639 pub fn type_name(&self) -> &str {
641 match self {
642 Value::String(_) => "string",
643 Value::Integer(_) => "integer",
644 Value::Float(_) => "float",
645 Value::Boolean(_) => "boolean",
646 Value::DateTime(_) => "datetime",
647 Value::Duration(_) => "duration",
648 Value::Iri(_) => "iri",
649 Value::BlankNode(_) => "bnode",
650 Value::Literal { .. } => "literal",
651 Value::List(_) => "list",
652 Value::Null => "null",
653 Value::Custom { type_name, .. } => type_name,
654 }
655 }
656
657 pub fn to_term(&self) -> Result<Term> {
659 match self {
660 Value::String(s) => Ok(Term::Literal(crate::algebra::Literal {
661 value: s.clone(),
662 language: None,
663 datatype: None,
664 })),
665 Value::Iri(iri) => Ok(Term::Iri(NamedNode::new_unchecked(iri.clone()))),
666 Value::BlankNode(id) => Ok(Term::BlankNode(id.clone())),
667 Value::Literal {
668 value,
669 language,
670 datatype,
671 } => Ok(Term::Literal(crate::algebra::Literal {
672 value: value.clone(),
673 language: language.clone(),
674 datatype: datatype
675 .as_ref()
676 .map(|dt| NamedNode::new_unchecked(dt.clone())),
677 })),
678 _ => bail!("Cannot convert {} to Term", self.type_name()),
679 }
680 }
681
682 pub fn from_term(term: &Term) -> Self {
684 match term {
685 Term::Iri(iri) => Value::Iri(iri.as_str().to_string()),
686 Term::BlankNode(id) => Value::BlankNode(id.clone()),
687 Term::Literal(lit) => Value::Literal {
688 value: lit.value.clone(),
689 language: lit.language.clone(),
690 datatype: lit.datatype.as_ref().map(|dt| dt.as_str().to_string()),
691 },
692 Term::Variable(var) => Value::String(format!("?{var}")),
693 Term::QuotedTriple(_) => Value::String("<<quoted triple>>".to_string()),
694 Term::PropertyPath(_) => Value::String("<property path>".to_string()),
695 }
696 }
697}
698
/// Defines an ad-hoc `CustomFunction` and registers it on `$registry`.
///
/// Arguments, in order: the registry expression, the function name (anything
/// with `to_string()`), a `Vec<ValueType>` of parameter types, the return
/// `ValueType`, and the body as a
/// `fn(&[Value], &ExecutionContext) -> Result<Value>`.
///
/// The generated function reports `Some(params.len())` as its arity and
/// `"Generated function"` as its documentation. The macro evaluates to the
/// result of `$registry.register_function(..)`.
///
/// The expansion names `ValueType`, `Value`, `ExecutionContext`, `Result` and
/// `CustomFunction` without a path, so all of them must be in scope where the
/// macro is invoked.
#[macro_export]
macro_rules! register_function {
    ($registry:expr_2021, $name:expr_2021, $params:expr_2021, $return_type:expr_2021, $body:expr_2021) => {{
        // Local to this expansion's block, so several invocations do not clash.
        #[derive(Debug, Clone)]
        struct GeneratedFunction {
            name: String,
            params: Vec<ValueType>,
            return_type: ValueType,
            // A plain function pointer, so the body cannot be a capturing closure.
            body: fn(&[Value], &ExecutionContext) -> Result<Value>,
        }

        impl CustomFunction for GeneratedFunction {
            fn name(&self) -> &str {
                &self.name
            }
            fn arity(&self) -> Option<usize> {
                Some(self.params.len())
            }
            fn parameter_types(&self) -> Vec<ValueType> {
                self.params.clone()
            }
            fn return_type(&self) -> ValueType {
                self.return_type.clone()
            }
            fn documentation(&self) -> &str {
                "Generated function"
            }
            fn clone_function(&self) -> Box<dyn CustomFunction> {
                Box::new(self.clone())
            }

            fn execute(&self, args: &[Value], context: &ExecutionContext) -> Result<Value> {
                (self.body)(args, context)
            }
        }

        let func = GeneratedFunction {
            name: $name.to_string(),
            params: $params,
            return_type: $return_type,
            body: $body,
        };

        $registry.register_function(func)
    }};
}
746
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: a two-argument function that adds integers.
    #[derive(Debug, Clone)]
    struct TestFunction;

    impl CustomFunction for TestFunction {
        fn name(&self) -> &str {
            "http://example.org/test"
        }

        fn arity(&self) -> Option<usize> {
            Some(2)
        }

        fn parameter_types(&self) -> Vec<ValueType> {
            vec![ValueType::Integer, ValueType::Integer]
        }

        fn return_type(&self) -> ValueType {
            ValueType::Integer
        }

        fn documentation(&self) -> &str {
            "Test function that adds two integers"
        }

        fn clone_function(&self) -> Box<dyn CustomFunction> {
            Box::new(self.clone())
        }

        fn execute(&self, args: &[Value], _context: &ExecutionContext) -> Result<Value> {
            match args {
                [Value::Integer(a), Value::Integer(b)] => Ok(Value::Integer(a + b)),
                // Right count, wrong types.
                [_, _] => bail!("Expected integer arguments"),
                _ => bail!("Expected 2 arguments, got {}", args.len()),
            }
        }
    }

    /// A registered function can be looked up again by name.
    #[test]
    fn test_function_registration() {
        let registry = ExtensionRegistry::new();

        let registered = registry.register_function(TestFunction);
        assert!(registered.is_ok());

        let lookup = registry.get_function("http://example.org/test").unwrap();
        assert!(lookup.is_some());
    }

    /// Executing the fixture on 5 and 3 yields 8.
    #[test]
    fn test_function_execution() {
        let context = ExecutionContext {
            variables: HashMap::new(),
            namespaces: HashMap::new(),
            base_iri: None,
            dataset_context: None,
            query_time: chrono::Utc::now(),
            optimization_level: OptimizationLevel::Basic,
            memory_limit: None,
            time_limit: None,
        };
        let args = [Value::Integer(5), Value::Integer(3)];

        let sum = TestFunction.execute(&args, &context).unwrap();
        assert_eq!(sum, Value::Integer(8));
    }
}