1#![warn(missing_docs)]
52#![warn(rust_2018_idioms)]
53#![warn(missing_debug_implementations)]
54
55mod aggregate;
56mod decode;
57mod eval;
58mod scalar;
59
60pub mod persistence;
61
62use std::sync::Arc;
63
64use serde::{Deserialize, Serialize};
65use thiserror::Error;
66use uni_plugin::PluginRegistry;
67
68pub use crate::aggregate::{DeclaredAggregateFn, install_aggregate_into_registry};
69pub use crate::persistence::{JsonFilePersistence, NullPersistence, Persistence, PersistenceError};
70pub use crate::scalar::DeclaredScalarFn;
71
/// Errors raised while declaring, installing, persisting, or dropping
/// declared (Cypher-bodied) plugins.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum CustomError {
    /// A declared body, its `signature_json` metadata, or a type name inside
    /// that metadata could not be parsed / interpreted.
    #[error("declared plugin body parse failure: {0}")]
    BodyParse(String),

    /// The declared qualified name collides with a native plugin registration.
    #[error("declared qname `{0}` is shadowed by a native plugin registration")]
    NativeShadow(String),

    /// A declaration lists a dependency that is not declared.
    #[error("declared plugin `{dependent}` depends on missing `{dep}`")]
    DependencyMissing {
        /// Qualified name of the declaration that lists the dependency.
        dependent: String,
        /// Qualified name of the dependency that could not be found.
        dep: String,
    },

    /// The declared dependency graph contains a cycle; the payload presumably
    /// lists the qualified names involved (built by the store — confirm).
    #[error("dependency cycle in declared plugins: {0:?}")]
    DependencyCycle(Vec<String>),

    /// The persistence backend failed.
    #[error("declared-plugin persistence: {0}")]
    Persistence(#[from] PersistenceError),

    /// Installing a declaration failed for a reason other than native
    /// shadowing (e.g. the procedure synthesizer rejected it).
    #[error("declared-plugin registration: {0}")]
    Registration(String),

    /// The calling principal lacks a required capability; the payload names it.
    #[error("declared-plugin capability denied: caller is missing `{0}`")]
    CapabilityDenied(String),
}
109
/// One declaration of a Cypher-bodied plugin (function, aggregate, procedure,
/// or trigger) as kept in the store and written to persistence.
///
/// Field order is significant: it fixes the derived `Serialize` output order
/// and the derived `Debug` formatting.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct DeclaredPlugin {
    /// Fully qualified name of the declaration.
    pub qname: String,
    /// Declaration kind; this file produces and installs `"function"`,
    /// `"aggregate"`, `"procedure"`, and `"trigger"`.
    pub kind: String,
    /// Primary Cypher source of the declaration (for aggregates this is the
    /// update expression; the full definition lives in `signature_json`).
    pub body: String,
    /// JSON object with kind-specific metadata (return type, argument names,
    /// mode, ...).
    pub signature_json: String,
    /// Qualified names of other declarations this one depends on.
    pub dependencies: Vec<String>,
    /// Id of the principal that made the declaration, or `"anonymous"`.
    pub declared_by: String,
    /// `false` once installation was blocked because a native registration
    /// shadows `qname`.
    pub active: bool,
}
135
/// `apoc.custom`-style meta-plugin: lets Cypher callers declare functions,
/// aggregates, procedures, and triggers at runtime.
///
/// Declarations live in a [`DeclaredPluginStore`], are saved through a
/// [`Persistence`] backend, and are installed into a shared [`PluginRegistry`].
pub struct CustomPlugin {
    /// In-memory set of declarations.
    store: Arc<DeclaredPluginStore>,
    /// Registry declarations are installed into.
    registry: Arc<PluginRegistry>,
    /// Backend declarations are loaded from and saved to.
    persistence: Arc<dyn Persistence>,
    /// Optional hook that turns procedure / trigger declarations into
    /// executable procedures; without it those kinds are stored but not
    /// installed.
    procedure_synthesizer: Option<Arc<dyn ProcedureBodySynthesizer>>,
    /// Lazily built manifest returned by `Plugin::manifest`.
    manifest: std::sync::OnceLock<uni_plugin::PluginManifest>,
}
165
/// Hook that turns a stored procedure / trigger declaration into an executable
/// [`ProcedurePlugin`](uni_plugin::traits::procedure::ProcedurePlugin).
///
/// Attached with [`CustomPlugin::with_procedure_synthesizer`]. The implementor
/// is presumably the embedding host, which can plan and run Cypher (not
/// visible here).
pub trait ProcedureBodySynthesizer: Send + Sync + std::fmt::Debug {
    /// Build a procedure implementation for `decl`.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when `decl` cannot be turned into a
    /// procedure; installers wrap it in [`CustomError::Registration`].
    fn synthesize(
        &self,
        decl: &DeclaredPlugin,
    ) -> Result<Arc<dyn uni_plugin::traits::procedure::ProcedurePlugin>, String>;
}
191
192impl std::fmt::Debug for CustomPlugin {
193 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194 f.debug_struct("CustomPlugin")
195 .field("store", &self.store)
196 .field("declared_count", &self.store.list().len())
197 .finish_non_exhaustive()
198 }
199}
200
201impl CustomPlugin {
202 pub const ID: &'static str = "custom";
204
205 pub fn new(
219 registry: Arc<PluginRegistry>,
220 persistence: Arc<dyn Persistence>,
221 ) -> Result<Self, CustomError> {
222 let store = Arc::new(DeclaredPluginStore::new());
223 let initial = persistence.load_all()?;
224 for plugin in initial {
225 store.declare_unchecked(plugin);
230 }
231 Ok(Self {
232 store,
233 registry,
234 persistence,
235 procedure_synthesizer: None,
236 manifest: std::sync::OnceLock::new(),
237 })
238 }
239
240 #[must_use]
248 pub fn with_procedure_synthesizer(
249 mut self,
250 synthesizer: Arc<dyn ProcedureBodySynthesizer>,
251 ) -> Self {
252 self.procedure_synthesizer = Some(synthesizer);
253 self
254 }
255
256 #[must_use]
261 pub fn new_in_memory() -> Self {
262 Self::new(Arc::new(PluginRegistry::new()), Arc::new(NullPersistence))
263 .expect("NullPersistence cannot fail")
264 }
265
266 #[must_use]
268 pub fn store(&self) -> &Arc<DeclaredPluginStore> {
269 &self.store
270 }
271
272 #[must_use]
274 pub fn registry(&self) -> &Arc<PluginRegistry> {
275 &self.registry
276 }
277
278 pub fn reactivate_into_registry(&self) -> Result<(), CustomError> {
291 let mut records = self.store.list();
292 records.sort_by_key(|a| a.dependencies.len());
293 for record in records {
294 let install_result = match record.kind.as_str() {
295 "function" => procedures::install_function_into_registry(&self.registry, &record),
296 "aggregate" => {
297 crate::aggregate::install_aggregate_into_registry(&self.registry, &record)
298 }
299 "procedure" | "trigger" => {
300 match self.procedure_synthesizer.as_ref() {
306 Some(synth) => procedures::install_synthesized_procedure(
307 &self.registry,
308 &record,
309 synth.as_ref(),
310 ),
311 None => continue,
312 }
313 }
314 _ => continue,
315 };
316 let mut record = record;
317 match install_result {
318 Ok(()) => {}
319 Err(CustomError::NativeShadow(_)) => {
320 record.active = false;
321 self.store.replace(record.clone());
322 let _ = self.persistence.save(&record);
323 }
324 Err(e) => return Err(e),
325 }
326 }
327 Ok(())
328 }
329
330 fn manifest_value() -> uni_plugin::PluginManifest {
331 use semver::Version;
332 use uni_plugin::{
333 AbiRange, Capability, CapabilitySet, Determinism, PluginId, PluginManifest,
334 ProvidedSurfaces, Scope, SideEffects,
335 };
336 PluginManifest {
337 id: PluginId::new(Self::ID),
338 version: env!("CARGO_PKG_VERSION")
339 .parse::<Version>()
340 .unwrap_or_else(|_| Version::new(0, 0, 0)),
341 abi: AbiRange::parse("^1").expect("manifest ABI range is valid"),
342 depends_on: vec![],
343 capabilities: CapabilitySet::from_iter_of([
344 Capability::Procedure,
345 Capability::ProcedureWrites,
346 Capability::PluginDeclare,
347 ]),
348 determinism: Determinism::Nondeterministic,
349 side_effects: SideEffects::ReadOnly,
350 scope: Scope::Instance,
351 hash: None,
352 signature: None,
353 provides: ProvidedSurfaces::default(),
354 docs: "apoc.custom-style meta-plugin: declare procedures / functions / aggregates / triggers from Cypher."
355 .to_owned(),
356 metadata: std::collections::BTreeMap::new(),
357 }
358 }
359}
360
361impl uni_plugin::Plugin for CustomPlugin {
362 fn manifest(&self) -> &uni_plugin::PluginManifest {
363 self.manifest.get_or_init(Self::manifest_value)
364 }
365
366 fn register(
367 &self,
368 r: &mut uni_plugin::PluginRegistrar<'_>,
369 ) -> Result<(), uni_plugin::PluginError> {
370 use uni_plugin::QName;
371
372 r.procedure(
373 QName::new(Self::ID, "plugin.listDeclared"),
374 procedures::list_declared_signature(),
375 std::sync::Arc::new(procedures::ListDeclaredProcedure::new(Arc::clone(
376 &self.store,
377 ))),
378 )?;
379 r.procedure(
380 QName::new(Self::ID, "plugin.dropDeclared"),
381 procedures::drop_declared_signature(),
382 std::sync::Arc::new(procedures::DropDeclaredProcedure::new(
383 Arc::clone(&self.store),
384 Arc::clone(&self.persistence),
385 Arc::clone(&self.registry),
386 )),
387 )?;
388 r.procedure(
389 QName::new(Self::ID, "plugin.declareFunction"),
390 procedures::declare_function_signature(),
391 std::sync::Arc::new(procedures::DeclareFunctionProcedure::new(
392 Arc::clone(&self.store),
393 Arc::clone(&self.persistence),
394 Arc::clone(&self.registry),
395 )),
396 )?;
397 r.procedure(
398 QName::new(Self::ID, "plugin.declareProcedure"),
399 procedures::declare_procedure_signature(),
400 std::sync::Arc::new(match self.procedure_synthesizer.as_ref() {
401 Some(synth) => procedures::DeclareProcedureProcedure::new_with_synthesis(
402 Arc::clone(&self.store),
403 Arc::clone(&self.persistence),
404 Arc::clone(&self.registry),
405 Arc::clone(synth),
406 ),
407 None => procedures::DeclareProcedureProcedure::new(
408 Arc::clone(&self.store),
409 Arc::clone(&self.persistence),
410 ),
411 }),
412 )?;
413 r.procedure(
414 QName::new(Self::ID, "plugin.declareAggregate"),
415 procedures::declare_aggregate_signature(),
416 std::sync::Arc::new(procedures::DeclareAggregateProcedure::new(
417 Arc::clone(&self.store),
418 Arc::clone(&self.persistence),
419 Arc::clone(&self.registry),
420 )),
421 )?;
422 r.procedure(
423 QName::new(Self::ID, "plugin.declareTrigger"),
424 procedures::declare_trigger_signature(),
425 std::sync::Arc::new(match self.procedure_synthesizer.as_ref() {
426 Some(synth) => procedures::DeclareTriggerProcedure::new_with_synthesis(
427 Arc::clone(&self.store),
428 Arc::clone(&self.persistence),
429 Arc::clone(&self.registry),
430 Arc::clone(synth),
431 ),
432 None => procedures::DeclareTriggerProcedure::new(
433 Arc::clone(&self.store),
434 Arc::clone(&self.persistence),
435 ),
436 }),
437 )?;
438 Ok(())
439 }
440}
441
442pub mod procedures {
444 use std::sync::Arc;
445
446 use arrow_array::builder::{BooleanBuilder, StringBuilder};
447 use arrow_array::{Array, BooleanArray, RecordBatch, StringArray};
448 use arrow_schema::{DataType, Field, Schema, SchemaRef};
449 use datafusion::execution::SendableRecordBatchStream;
450 use datafusion::logical_expr::ColumnarValue;
451 use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
452 use datafusion::scalar::ScalarValue;
453 use futures::stream;
454 use semver::Version;
455 use uni_cypher::parse_expression;
456 use uni_plugin::traits::procedure::{
457 NamedArgType, ProcedureContext, ProcedureMode, ProcedurePlugin, ProcedureSignature,
458 };
459 use uni_plugin::traits::scalar::{ArgType, ScalarPluginFn};
460 use uni_plugin::{
461 AbiRange, Capability, CapabilitySet, Determinism, FnError, Plugin, PluginError, PluginId,
462 PluginManifest, PluginRegistrar, PluginRegistry, ProvidedSurfaces, QName, Scope,
463 SideEffects,
464 };
465
466 use super::{CustomError, DeclaredPlugin, DeclaredPluginStore, DeclaredScalarFn, Persistence};
467 use crate::decode::{declared_plugin_id, local_part, map_plugin_error, type_str_to_arrow};
468
469 #[must_use]
475 pub fn list_declared_signature() -> ProcedureSignature {
476 ProcedureSignature {
477 args: vec![],
478 yields: vec![
479 Field::new("qname", DataType::Utf8, false),
480 Field::new("kind", DataType::Utf8, false),
481 Field::new("declared_by", DataType::Utf8, false),
482 Field::new("active", DataType::Boolean, false),
483 ],
484 mode: ProcedureMode::Read,
485 side_effects: SideEffects::ReadOnly,
486 retry_contract: None,
487 batch_input: None,
488 docs: "List every declared plugin (apoc.custom analogue) with its kind, declarer, and active state.".to_owned(),
489 }
490 }
491
492 #[must_use]
494 pub fn drop_declared_signature() -> ProcedureSignature {
495 write_signature(
496 vec![named_arg(
497 "qname",
498 DataType::Utf8,
499 "Qualified name of the declared plugin to drop.",
500 )],
501 "removed",
502 "Drop a previously declared plugin. Errors if other declared plugins depend on it.",
503 )
504 }
505
506 fn named_arg(name: &str, ty: DataType, doc: &str) -> NamedArgType {
507 NamedArgType {
508 name: smol_str::SmolStr::new(name),
509 ty: ArgType::Primitive(ty),
510 default: None,
511 doc: doc.to_owned(),
512 }
513 }
514
515 fn named_arg_default(name: &str, ty: DataType, doc: &str, default: &str) -> NamedArgType {
525 NamedArgType {
526 name: smol_str::SmolStr::new(name),
527 ty: ArgType::Primitive(ty),
528 default: Some(ScalarValue::Utf8(Some(default.to_owned()))),
529 doc: doc.to_owned(),
530 }
531 }
532
    /// Doc string shared by every `deps_json` argument (built by `deps_arg`).
    const DEPS_JSON_DOC: &str =
        "JSON array of qualified names this declaration depends on (empty by default).";
537
538 fn deps_arg() -> NamedArgType {
539 named_arg_default("deps_json", DataType::Utf8, DEPS_JSON_DOC, "[]")
540 }
541
542 fn write_signature(args: Vec<NamedArgType>, yield_col: &str, docs: &str) -> ProcedureSignature {
549 ProcedureSignature {
550 args,
551 yields: vec![Field::new(yield_col, DataType::Boolean, false)],
552 mode: ProcedureMode::Write,
553 side_effects: SideEffects::ReadOnly,
554 retry_contract: None,
555 batch_input: None,
556 docs: docs.to_owned(),
557 }
558 }
559
560 #[must_use]
562 pub fn declare_function_signature() -> ProcedureSignature {
563 write_signature(
564 vec![
565 named_arg("qname", DataType::Utf8, "Qualified name to register."),
566 named_arg("body", DataType::Utf8, "Cypher expression body."),
567 named_arg(
568 "return_type",
569 DataType::Utf8,
570 "Return type ('string', 'int', 'float', 'bool').",
571 ),
572 named_arg(
573 "arg_names_json",
574 DataType::Utf8,
575 "JSON array of argument names, in positional order.",
576 ),
577 deps_arg(),
578 ],
579 "registered",
580 "Declare a new scalar function. Body is a Cypher expression; arguments are bound by name (positional).",
581 )
582 }
583
584 #[must_use]
586 pub fn declare_procedure_signature() -> ProcedureSignature {
587 write_signature(
588 vec![
589 named_arg("qname", DataType::Utf8, "Qualified name to register."),
590 named_arg("body", DataType::Utf8, "Cypher query body."),
591 named_arg("mode", DataType::Utf8, "'READ' or 'WRITE'."),
592 named_arg(
593 "yield_json",
594 DataType::Utf8,
595 "JSON array describing yielded columns.",
596 ),
597 deps_arg(),
598 ],
599 "registered",
600 "Declare a new procedure. The body is a full Cypher query; arguments are bound by name.",
601 )
602 }
603
604 #[must_use]
606 pub fn declare_aggregate_signature() -> ProcedureSignature {
607 write_signature(
608 vec![
609 named_arg("qname", DataType::Utf8, "Qualified name to register."),
610 named_arg(
611 "init_expr",
612 DataType::Utf8,
613 "Init state expression (no parameters).",
614 ),
615 named_arg(
616 "update_expr",
617 DataType::Utf8,
618 "Update step expression; binds `$state` plus per-row args.",
619 ),
620 named_arg(
621 "finalize_expr",
622 DataType::Utf8,
623 "Finalize expression; binds `$state`.",
624 ),
625 named_arg_default(
626 "return_type",
627 DataType::Utf8,
628 "Return type ('string', 'int', 'float', 'bool').",
629 "float",
630 ),
631 named_arg_default(
632 "arg_names_json",
633 DataType::Utf8,
634 "JSON array of update-arg names, in positional order.",
635 "[]",
636 ),
637 deps_arg(),
638 ],
639 "registered",
640 "Declare a new aggregate function from Cypher init / update / finalize expressions.",
641 )
642 }
643
644 #[must_use]
646 pub fn declare_trigger_signature() -> ProcedureSignature {
647 write_signature(
648 vec![
649 named_arg("qname", DataType::Utf8, "Qualified name to register."),
650 named_arg(
651 "event_filter",
652 DataType::Utf8,
653 "Event filter (label or relationship pattern).",
654 ),
655 named_arg(
656 "body",
657 DataType::Utf8,
658 "Cypher body to execute when the trigger fires.",
659 ),
660 deps_arg(),
661 ],
662 "registered",
663 "Declare a new trigger that fires the given Cypher body on matched mutation events.",
664 )
665 }
666
667 #[derive(Debug)]
673 pub struct ListDeclaredProcedure {
674 store: Arc<DeclaredPluginStore>,
675 }
676
677 impl ListDeclaredProcedure {
678 #[must_use]
680 pub fn new(store: Arc<DeclaredPluginStore>) -> Self {
681 Self { store }
682 }
683 }
684
685 impl ProcedurePlugin for ListDeclaredProcedure {
686 fn signature(&self) -> &ProcedureSignature {
687 static SIG: std::sync::OnceLock<ProcedureSignature> = std::sync::OnceLock::new();
688 SIG.get_or_init(list_declared_signature)
689 }
690
691 fn invoke(
692 &self,
693 _ctx: ProcedureContext<'_>,
694 _args: &[ColumnarValue],
695 ) -> Result<SendableRecordBatchStream, FnError> {
696 let rows = self.store.list();
697 let mut qname = StringBuilder::new();
698 let mut kind = StringBuilder::new();
699 let mut declared_by = StringBuilder::new();
700 let mut active = BooleanBuilder::new();
701 for r in rows {
702 qname.append_value(&r.qname);
703 kind.append_value(&r.kind);
704 declared_by.append_value(&r.declared_by);
705 active.append_value(r.active);
706 }
707 let schema: SchemaRef = Arc::new(Schema::new(vec![
708 Field::new("qname", DataType::Utf8, false),
709 Field::new("kind", DataType::Utf8, false),
710 Field::new("declared_by", DataType::Utf8, false),
711 Field::new("active", DataType::Boolean, false),
712 ]));
713 let cols: Vec<Arc<dyn Array>> = vec![
714 Arc::new(qname.finish()),
715 Arc::new(kind.finish()),
716 Arc::new(declared_by.finish()),
717 Arc::new(active.finish()),
718 ];
719 let batch = RecordBatch::try_new(Arc::clone(&schema), cols)
720 .map_err(|e| FnError::new(0xB00, format!("listDeclared: {e}")))?;
721 Ok(Box::pin(RecordBatchStreamAdapter::new(
722 schema,
723 stream::iter(vec![Ok(batch)]),
724 )))
725 }
726 }
727
728 #[derive(Debug)]
730 pub struct DropDeclaredProcedure {
731 store: Arc<DeclaredPluginStore>,
732 persistence: Arc<dyn Persistence>,
733 registry: Arc<PluginRegistry>,
734 }
735
736 impl DropDeclaredProcedure {
737 #[must_use]
739 pub fn new(
740 store: Arc<DeclaredPluginStore>,
741 persistence: Arc<dyn Persistence>,
742 registry: Arc<PluginRegistry>,
743 ) -> Self {
744 Self {
745 store,
746 persistence,
747 registry,
748 }
749 }
750 }
751
752 impl ProcedurePlugin for DropDeclaredProcedure {
753 fn signature(&self) -> &ProcedureSignature {
754 static SIG: std::sync::OnceLock<ProcedureSignature> = std::sync::OnceLock::new();
755 SIG.get_or_init(drop_declared_signature)
756 }
757
758 fn invoke(
759 &self,
760 _ctx: ProcedureContext<'_>,
761 args: &[ColumnarValue],
762 ) -> Result<SendableRecordBatchStream, FnError> {
763 let qname = extract_string(args, 0, "qname")?;
764 let existed = self
765 .store
766 .drop_declared(&qname)
767 .map_err(|e| FnError::new(0xB01, format!("dropDeclared: {e}")))?;
768 if existed {
769 let pid = PluginId::new(declared_plugin_id(&qname));
773 self.registry.remove_plugin(&pid);
774 self.persistence
775 .delete(&qname)
776 .map_err(|e| FnError::new(0xB01, format!("dropDeclared persist: {e}")))?;
777 }
778 single_bool("removed", existed)
779 }
780 }
781
782 #[derive(Debug)]
788 pub struct DeclareFunctionProcedure {
789 store: Arc<DeclaredPluginStore>,
790 persistence: Arc<dyn Persistence>,
791 registry: Arc<PluginRegistry>,
792 }
793
794 impl DeclareFunctionProcedure {
795 #[must_use]
797 pub fn new(
798 store: Arc<DeclaredPluginStore>,
799 persistence: Arc<dyn Persistence>,
800 registry: Arc<PluginRegistry>,
801 ) -> Self {
802 Self {
803 store,
804 persistence,
805 registry,
806 }
807 }
808 }
809
810 impl ProcedurePlugin for DeclareFunctionProcedure {
811 fn signature(&self) -> &ProcedureSignature {
812 static SIG: std::sync::OnceLock<ProcedureSignature> = std::sync::OnceLock::new();
813 SIG.get_or_init(declare_function_signature)
814 }
815
816 fn invoke(
817 &self,
818 ctx: ProcedureContext<'_>,
819 args: &[ColumnarValue],
820 ) -> Result<SendableRecordBatchStream, FnError> {
821 let qname = extract_string(args, 0, "qname")?;
822 let body = extract_string(args, 1, "body")?;
823 let return_type = extract_string(args, 2, "return_type")?;
824 let arg_names_json = extract_string(args, 3, "arg_names_json")?;
825 let arg_names: Vec<String> = serde_json::from_str(&arg_names_json).map_err(|e| {
826 FnError::new(
827 FnError::CODE_TYPE_COERCION,
828 format!("declareFunction: arg_names_json parse: {e}"),
829 )
830 })?;
831 let dependencies = parse_deps(args, 4)?;
832 let declared_by = ctx
833 .principal
834 .map(|p| p.id.clone())
835 .unwrap_or_else(|| "anonymous".to_owned());
836
837 let record = DeclaredPlugin {
838 qname: qname.clone(),
839 kind: "function".to_owned(),
840 body,
841 signature_json: serde_json::to_string(&serde_json::json!({
842 "return_type": return_type,
843 "arg_names": arg_names,
844 }))
845 .unwrap_or_else(|_| "{}".to_owned()),
846 dependencies,
847 declared_by,
848 active: true,
849 };
850
851 self.store
852 .declare(record.clone())
853 .map_err(custom_to_fn_err)?;
854
855 match install_function_into_registry(&self.registry, &record) {
856 Ok(()) => {}
857 Err(CustomError::NativeShadow(_)) => {
858 let mut record = record.clone();
859 record.active = false;
860 self.store.replace(record.clone());
861 self.persistence.save(&record).map_err(|e| {
862 FnError::new(0xB20, format!("declareFunction persist: {e}"))
863 })?;
864 return single_bool("registered", false);
865 }
866 Err(e) => {
867 let _ = self.store.drop_declared(&qname);
869 return Err(custom_to_fn_err(e));
870 }
871 }
872
873 self.persistence
874 .save(&record)
875 .map_err(|e| FnError::new(0xB20, format!("declareFunction persist: {e}")))?;
876
877 single_bool("registered", true)
878 }
879 }
880
881 pub fn install_function_into_registry(
892 registry: &Arc<PluginRegistry>,
893 record: &DeclaredPlugin,
894 ) -> Result<(), CustomError> {
895 let parsed_body =
896 parse_expression(&record.body).map_err(|e| CustomError::BodyParse(format!("{e:?}")))?;
897 let sig_meta: serde_json::Value = serde_json::from_str(&record.signature_json)
898 .map_err(|e| CustomError::BodyParse(format!("signature_json: {e}")))?;
899 let return_type_str = sig_meta
900 .get("return_type")
901 .and_then(|v| v.as_str())
902 .unwrap_or("string");
903 let arg_names: Vec<String> = sig_meta
904 .get("arg_names")
905 .and_then(|v| v.as_array())
906 .map(|arr| {
907 arr.iter()
908 .filter_map(|v| v.as_str().map(str::to_owned))
909 .collect()
910 })
911 .unwrap_or_default();
912
913 let return_dt = type_str_to_arrow(return_type_str).ok_or_else(|| {
914 CustomError::BodyParse(format!("unknown return type `{return_type_str}`"))
915 })?;
916 let arg_pairs: Vec<(String, DataType)> = arg_names
917 .iter()
918 .map(|n| (n.clone(), DataType::Utf8))
919 .collect();
920 let signature = DeclaredScalarFn::build_signature(return_dt, &arg_pairs);
921 let scalar_fn = DeclaredScalarFn::new(parsed_body, arg_names, signature.clone());
922
923 let qname = QName::new(
927 declared_plugin_id(&record.qname),
928 local_part(&record.qname).to_ascii_lowercase(),
929 );
930 let plugin = SyntheticScalarPlugin {
931 plugin_id: PluginId::new(declared_plugin_id(&record.qname)),
932 qname,
933 signature,
934 function: Arc::new(scalar_fn) as Arc<dyn ScalarPluginFn>,
935 manifest: std::sync::OnceLock::new(),
936 };
937 let manifest = plugin.manifest().clone();
938 let caps = manifest.capabilities.clone();
939 let mut r = PluginRegistrar::new(manifest.id, &caps, registry);
940 plugin
941 .register(&mut r)
942 .map_err(|e| map_plugin_error(e, &record.qname))?;
943 r.commit_to_registry()
944 .map_err(|e| map_plugin_error(e, &record.qname))?;
945 Ok(())
946 }
947
948 pub(super) fn install_synthesized_procedure(
955 registry: &Arc<PluginRegistry>,
956 record: &DeclaredPlugin,
957 synthesizer: &dyn crate::ProcedureBodySynthesizer,
958 ) -> Result<(), CustomError> {
959 let plugin = synthesizer
960 .synthesize(record)
961 .map_err(CustomError::Registration)?;
962 let qname = QName::new(
963 declared_plugin_id(&record.qname),
964 local_part(&record.qname).to_ascii_lowercase(),
965 );
966 let signature = plugin.signature().clone();
967 let caps = {
968 let mut s = uni_plugin::CapabilitySet::new();
969 s.insert(uni_plugin::Capability::Procedure);
970 match signature.mode {
974 uni_plugin::traits::procedure::ProcedureMode::Write => {
975 s.insert(uni_plugin::Capability::ProcedureWrites);
976 }
977 uni_plugin::traits::procedure::ProcedureMode::Schema => {
978 s.insert(uni_plugin::Capability::ProcedureSchema);
979 }
980 uni_plugin::traits::procedure::ProcedureMode::Dbms => {
981 s.insert(uni_plugin::Capability::ProcedureDbms);
982 }
983 _ => {}
986 }
987 s
988 };
989 let plugin_id = uni_plugin::PluginId::new(declared_plugin_id(&record.qname));
990 let mut r = PluginRegistrar::new(plugin_id, &caps, registry);
991 r.procedure(qname, signature, plugin)
992 .map_err(|e| map_plugin_error(e, &record.qname))?;
993 r.commit_to_registry()
994 .map_err(|e| map_plugin_error(e, &record.qname))?;
995 Ok(())
996 }
997
    /// Single-function plugin used to install one declared scalar function
    /// under its own plugin id.
    struct SyntheticScalarPlugin {
        /// Plugin id derived from the declaration's qname.
        plugin_id: PluginId,
        /// Name the scalar function is registered under.
        qname: QName,
        /// Signature handed to the registrar.
        signature: uni_plugin::traits::scalar::FnSignature,
        /// The function implementation.
        function: Arc<dyn ScalarPluginFn>,
        /// Lazily built manifest (see `build_manifest`).
        manifest: std::sync::OnceLock<PluginManifest>,
    }
1010
1011 impl std::fmt::Debug for SyntheticScalarPlugin {
1012 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1013 f.debug_struct("SyntheticScalarPlugin")
1014 .field("plugin_id", &self.plugin_id)
1015 .field("qname", &self.qname)
1016 .finish_non_exhaustive()
1017 }
1018 }
1019
1020 impl SyntheticScalarPlugin {
1021 fn build_manifest(&self) -> PluginManifest {
1022 PluginManifest {
1023 id: self.plugin_id.clone(),
1024 version: Version::new(0, 0, 1),
1025 abi: AbiRange::parse("^1").expect("manifest ABI range is valid"),
1026 depends_on: vec![],
1027 capabilities: CapabilitySet::from_iter_of([Capability::ScalarFn]),
1028 determinism: Determinism::Pure,
1029 side_effects: SideEffects::ReadOnly,
1030 scope: Scope::Instance,
1031 hash: None,
1032 signature: None,
1033 provides: ProvidedSurfaces::default(),
1034 docs: "Declared scalar function (apoc.custom analogue).".to_owned(),
1035 metadata: std::collections::BTreeMap::new(),
1036 }
1037 }
1038 }
1039
1040 impl Plugin for SyntheticScalarPlugin {
1041 fn manifest(&self) -> &PluginManifest {
1042 self.manifest.get_or_init(|| self.build_manifest())
1043 }
1044
1045 fn register(&self, r: &mut PluginRegistrar<'_>) -> Result<(), PluginError> {
1046 r.scalar_fn(
1047 self.qname.clone(),
1048 self.signature.clone(),
1049 Arc::clone(&self.function),
1050 )?;
1051 Ok(())
1052 }
1053 }
1054
1055 #[derive(Debug)]
1071 pub struct DeclareAggregateProcedure {
1072 store: Arc<DeclaredPluginStore>,
1073 persistence: Arc<dyn Persistence>,
1074 registry: Arc<PluginRegistry>,
1075 }
1076
1077 impl DeclareAggregateProcedure {
1078 #[must_use]
1080 pub fn new(
1081 store: Arc<DeclaredPluginStore>,
1082 persistence: Arc<dyn Persistence>,
1083 registry: Arc<PluginRegistry>,
1084 ) -> Self {
1085 Self {
1086 store,
1087 persistence,
1088 registry,
1089 }
1090 }
1091 }
1092
1093 impl ProcedurePlugin for DeclareAggregateProcedure {
1094 fn signature(&self) -> &ProcedureSignature {
1095 static SIG: std::sync::OnceLock<ProcedureSignature> = std::sync::OnceLock::new();
1096 SIG.get_or_init(declare_aggregate_signature)
1097 }
1098
1099 fn invoke(
1100 &self,
1101 ctx: ProcedureContext<'_>,
1102 args: &[ColumnarValue],
1103 ) -> Result<SendableRecordBatchStream, FnError> {
1104 let qname = extract_string(args, 0, "qname")?;
1105 let init_src = extract_string(args, 1, "init_expr")?;
1106 let update_src = extract_string(args, 2, "update_expr")?;
1107 let finalize_src = extract_string(args, 3, "finalize_expr")?;
1108 let return_type = extract_string_or(args, 4, "float");
1109 let arg_names_json = extract_string_or(args, 5, "[]");
1110 let arg_names: Vec<String> = serde_json::from_str(&arg_names_json).map_err(|e| {
1111 FnError::new(
1112 FnError::CODE_TYPE_COERCION,
1113 format!("declareAggregate: arg_names_json parse: {e}"),
1114 )
1115 })?;
1116 let dependencies = parse_deps(args, 6)?;
1117 let declared_by = ctx
1118 .principal
1119 .map(|p| p.id.clone())
1120 .unwrap_or_else(|| "anonymous".to_owned());
1121
1122 let record = DeclaredPlugin {
1123 qname: qname.clone(),
1124 kind: "aggregate".to_owned(),
1125 body: update_src.clone(),
1129 signature_json: serde_json::to_string(&serde_json::json!({
1130 "init": init_src,
1131 "update": update_src,
1132 "finalize": finalize_src,
1133 "return_type": return_type,
1134 "arg_names": arg_names,
1135 }))
1136 .unwrap_or_else(|_| "{}".to_owned()),
1137 dependencies,
1138 declared_by,
1139 active: true,
1140 };
1141
1142 self.store
1143 .declare(record.clone())
1144 .map_err(custom_to_fn_err)?;
1145
1146 match crate::aggregate::install_aggregate_into_registry(&self.registry, &record) {
1147 Ok(()) => {}
1148 Err(CustomError::NativeShadow(_)) => {
1149 let mut record = record.clone();
1150 record.active = false;
1151 self.store.replace(record.clone());
1152 self.persistence.save(&record).map_err(|e| {
1153 FnError::new(0xB21, format!("declareAggregate persist: {e}"))
1154 })?;
1155 return single_bool("registered", false);
1156 }
1157 Err(e) => {
1158 let _ = self.store.drop_declared(&qname);
1159 return Err(custom_to_fn_err(e));
1160 }
1161 }
1162
1163 self.persistence
1164 .save(&record)
1165 .map_err(|e| FnError::new(0xB21, format!("declareAggregate persist: {e}")))?;
1166
1167 single_bool("registered", true)
1168 }
1169 }
1170
1171 macro_rules! declare_kind_procedure {
1178 ($name:ident, $sig_fn:ident, $kind:literal, $field_count:literal) => {
1179 #[derive(Debug)]
1187 pub struct $name {
1188 store: Arc<DeclaredPluginStore>,
1189 persistence: Arc<dyn Persistence>,
1190 registry: Arc<uni_plugin::PluginRegistry>,
1191 synthesizer:
1192 Option<Arc<dyn crate::ProcedureBodySynthesizer>>,
1193 }
1194
1195 impl $name {
1196 #[must_use]
1198 pub fn new(
1199 store: Arc<DeclaredPluginStore>,
1200 persistence: Arc<dyn Persistence>,
1201 ) -> Self {
1202 Self {
1203 store,
1204 persistence,
1205 registry: Arc::new(uni_plugin::PluginRegistry::new()),
1206 synthesizer: None,
1207 }
1208 }
1209
1210 #[must_use]
1214 pub fn new_with_synthesis(
1215 store: Arc<DeclaredPluginStore>,
1216 persistence: Arc<dyn Persistence>,
1217 registry: Arc<uni_plugin::PluginRegistry>,
1218 synthesizer: Arc<dyn crate::ProcedureBodySynthesizer>,
1219 ) -> Self {
1220 Self {
1221 store,
1222 persistence,
1223 registry,
1224 synthesizer: Some(synthesizer),
1225 }
1226 }
1227 }
1228
1229 impl ProcedurePlugin for $name {
1230 fn signature(&self) -> &ProcedureSignature {
1231 static SIG: std::sync::OnceLock<ProcedureSignature> =
1232 std::sync::OnceLock::new();
1233 SIG.get_or_init($sig_fn)
1234 }
1235
1236 fn invoke(
1237 &self,
1238 ctx: ProcedureContext<'_>,
1239 args: &[ColumnarValue],
1240 ) -> Result<SendableRecordBatchStream, FnError> {
1241 let qname = extract_string(args, 0, "qname")?;
1242 let sig_args = $sig_fn().args;
1247 let mut sig = serde_json::Map::new();
1248 for i in 1..($field_count - 1) {
1251 let key = sig_args
1252 .get(i)
1253 .map(|a| a.name.to_string())
1254 .unwrap_or_else(|| format!("arg{i}"));
1255 let v = extract_string(args, i, &key)?;
1256 sig.insert(key, serde_json::Value::String(v));
1257 }
1258 if $kind == "procedure" {
1266 if let Ok(mode_str) = extract_string(args, 2, "mode") {
1267 let mode_uc = mode_str.to_ascii_uppercase();
1268 if mode_uc == "WRITE" {
1269 let has_writes = ctx
1270 .principal
1271 .map(|p| {
1272 p.capabilities.contains_variant(
1273 &uni_plugin::Capability::ProcedureWrites,
1274 )
1275 })
1276 .unwrap_or(false);
1277 if !has_writes {
1278 return Err(FnError::new(
1279 0xB09,
1280 format!(
1281 "declareProcedure WRITE for `{qname}` denied: \
1282 principal lacks `Capability::ProcedureWrites`"
1283 ),
1284 ));
1285 }
1286 }
1287 sig.insert(
1288 "mode".to_owned(),
1289 serde_json::Value::String(mode_uc),
1290 );
1291 }
1292 }
1293 let dependencies = parse_deps(args, $field_count - 1)?;
1294 let declared_by = ctx
1295 .principal
1296 .map(|p| p.id.clone())
1297 .unwrap_or_else(|| "anonymous".to_owned());
1298 let body = sig_args
1302 .get(1)
1303 .map(|a| a.name.to_string())
1304 .and_then(|key| sig.get(&key))
1305 .and_then(|v| v.as_str())
1306 .unwrap_or("")
1307 .to_owned();
1308 let record = DeclaredPlugin {
1309 qname: qname.clone(),
1310 kind: $kind.to_owned(),
1311 body,
1312 signature_json: serde_json::to_string(&sig).unwrap_or_default(),
1313 dependencies,
1314 declared_by,
1315 active: true,
1316 };
1317 self.store
1318 .declare(record.clone())
1319 .map_err(custom_to_fn_err)?;
1320 self.persistence
1321 .save(&record)
1322 .map_err(|e| FnError::new(0xB30, format!("declare persist: {e}")))?;
1323 if let Some(synth) = self.synthesizer.as_ref() {
1328 if let Err(e) = crate::procedures::install_synthesized_procedure(
1329 &self.registry,
1330 &record,
1331 synth.as_ref(),
1332 ) {
1333 match e {
1338 CustomError::NativeShadow(_) => {
1339 let mut shadowed = record.clone();
1340 shadowed.active = false;
1341 self.store.replace(shadowed.clone());
1342 let _ = self.persistence.save(&shadowed);
1343 }
1344 other => {
1345 return Err(FnError::new(
1346 0xB31,
1347 format!("declare synthesize: {other}"),
1348 ));
1349 }
1350 }
1351 }
1352 }
1353 single_bool("registered", true)
1354 }
1355 }
1356 };
1357 }
1358
    // `custom.plugin.declareProcedure(qname, body, mode, yield_json[, deps_json])`:
    // 5 signature arguments.
    declare_kind_procedure!(
        DeclareProcedureProcedure,
        declare_procedure_signature,
        "procedure",
        5
    );
    // `custom.plugin.declareTrigger(qname, event_filter, body[, deps_json])`:
    // 4 signature arguments.
    declare_kind_procedure!(
        DeclareTriggerProcedure,
        declare_trigger_signature,
        "trigger",
        4
    );
1371
1372 fn columnar_utf8(cv: &ColumnarValue) -> Option<String> {
1383 match cv {
1384 ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => Some(s.clone()),
1385 ColumnarValue::Array(arr) => arr
1386 .as_any()
1387 .downcast_ref::<StringArray>()
1388 .and_then(|a| a.iter().next().flatten().map(|s| s.to_owned())),
1389 _ => None,
1390 }
1391 }
1392
1393 fn extract_string_or(args: &[ColumnarValue], i: usize, default: &str) -> String {
1399 args.get(i)
1400 .and_then(columnar_utf8)
1401 .unwrap_or_else(|| default.to_owned())
1402 }
1403
1404 fn parse_deps(args: &[ColumnarValue], i: usize) -> Result<Vec<String>, FnError> {
1407 let raw = extract_string_or(args, i, "[]");
1408 serde_json::from_str::<Vec<String>>(&raw).map_err(|e| {
1409 FnError::new(
1410 FnError::CODE_TYPE_COERCION,
1411 format!("declare: deps_json parse: {e}"),
1412 )
1413 })
1414 }
1415
1416 fn extract_string(args: &[ColumnarValue], i: usize, name: &str) -> Result<String, FnError> {
1417 let cv = args.get(i).ok_or_else(|| {
1418 FnError::new(
1419 FnError::CODE_TYPE_COERCION,
1420 format!("declare procedure missing arg `{name}` at position {i}"),
1421 )
1422 })?;
1423 if let Some(s) = columnar_utf8(cv) {
1424 return Ok(s);
1425 }
1426 let msg = match cv {
1430 ColumnarValue::Scalar(ScalarValue::Utf8(None) | ScalarValue::Null) => {
1431 format!("declare procedure arg `{name}` was null")
1432 }
1433 _ => format!("declare procedure arg `{name}` not Utf8"),
1434 };
1435 Err(FnError::new(FnError::CODE_TYPE_COERCION, msg))
1436 }
1437
1438 fn single_bool(col: &str, v: bool) -> Result<SendableRecordBatchStream, FnError> {
1439 let schema: SchemaRef =
1440 Arc::new(Schema::new(vec![Field::new(col, DataType::Boolean, false)]));
1441 let arr: Arc<dyn Array> = Arc::new(BooleanArray::from(vec![v]));
1442 let batch = RecordBatch::try_new(Arc::clone(&schema), vec![arr])
1443 .map_err(|e| FnError::new(0xB02, format!("single bool: {e}")))?;
1444 Ok(Box::pin(RecordBatchStreamAdapter::new(
1445 schema,
1446 stream::iter(vec![Ok(batch)]),
1447 )))
1448 }
1449
1450 fn custom_to_fn_err(e: CustomError) -> FnError {
1451 let code = match &e {
1452 CustomError::DependencyCycle(_) => 0xB03,
1453 CustomError::DependencyMissing { .. } => 0xB04,
1454 CustomError::NativeShadow(_) => 0xB05,
1455 CustomError::BodyParse(_) => 0xB06,
1456 CustomError::Persistence(_) => 0xB07,
1457 CustomError::Registration(_) => 0xB08,
1458 CustomError::CapabilityDenied(_) => 0xB09,
1459 };
1460 FnError::new(code, e.to_string())
1461 }
1462}
1463
/// In-memory table of declared plugins, keyed by qualified name.
///
/// Entries are held in a `BTreeMap` behind an `RwLock`. Iteration (for
/// example in `list`) is therefore ordered by qname.
#[derive(Debug, Default)]
pub struct DeclaredPluginStore {
    // qname -> declaration record.
    by_qname: std::sync::RwLock<std::collections::BTreeMap<String, DeclaredPlugin>>,
}
1477
1478impl DeclaredPluginStore {
1479 #[must_use]
1481 pub fn new() -> Self {
1482 Self::default()
1483 }
1484
1485 pub fn declare(&self, plugin: DeclaredPlugin) -> Result<(), CustomError> {
1495 {
1496 let map = self.by_qname.read().expect("declared-plugin lock poisoned");
1497 for dep in &plugin.dependencies {
1498 if !map.contains_key(dep) {
1499 return Err(CustomError::DependencyMissing {
1500 dependent: plugin.qname.clone(),
1501 dep: dep.clone(),
1502 });
1503 }
1504 }
1505 if would_introduce_cycle(&map, &plugin) {
1506 return Err(CustomError::DependencyCycle(chain_starting_at(
1507 &map, &plugin,
1508 )));
1509 }
1510 }
1511 self.by_qname
1512 .write()
1513 .expect("declared-plugin lock poisoned")
1514 .insert(plugin.qname.clone(), plugin);
1515 Ok(())
1516 }
1517
1518 pub fn declare_unchecked(&self, plugin: DeclaredPlugin) {
1521 self.by_qname
1522 .write()
1523 .expect("declared-plugin lock poisoned")
1524 .insert(plugin.qname.clone(), plugin);
1525 }
1526
1527 #[must_use]
1529 pub fn get(&self, qname: &str) -> Option<DeclaredPlugin> {
1530 self.by_qname
1531 .read()
1532 .expect("declared-plugin lock poisoned")
1533 .get(qname)
1534 .cloned()
1535 }
1536
1537 pub fn drop_declared(&self, qname: &str) -> Result<bool, CustomError> {
1547 let mut map = self
1548 .by_qname
1549 .write()
1550 .expect("declared-plugin lock poisoned");
1551 for other in map.values() {
1552 if other.dependencies.iter().any(|d| d == qname) {
1553 return Err(CustomError::DependencyMissing {
1554 dependent: other.qname.clone(),
1555 dep: qname.to_owned(),
1556 });
1557 }
1558 }
1559 Ok(map.remove(qname).is_some())
1560 }
1561
1562 pub fn drop_cascade(&self, qname: &str) -> Vec<String> {
1567 let mut removed = Vec::new();
1568 let mut map = self
1569 .by_qname
1570 .write()
1571 .expect("declared-plugin lock poisoned");
1572 let mut stack = vec![qname.to_owned()];
1573 while let Some(target) = stack.pop() {
1574 let dependents: Vec<String> = map
1575 .iter()
1576 .filter(|(_, p)| p.dependencies.iter().any(|d| d == &target))
1577 .map(|(k, _)| k.clone())
1578 .collect();
1579 if dependents.is_empty() {
1580 if map.remove(&target).is_some() {
1581 removed.push(target);
1582 }
1583 } else {
1584 stack.push(target);
1585 for d in dependents {
1586 stack.push(d);
1587 }
1588 }
1589 }
1590 removed
1591 }
1592
1593 pub fn replace(&self, plugin: DeclaredPlugin) {
1596 self.declare_unchecked(plugin);
1597 }
1598
1599 #[must_use]
1601 pub fn list(&self) -> Vec<DeclaredPlugin> {
1602 self.by_qname
1603 .read()
1604 .expect("declared-plugin lock poisoned")
1605 .values()
1606 .cloned()
1607 .collect()
1608 }
1609}
1610
1611fn would_introduce_cycle(
1612 map: &std::collections::BTreeMap<String, DeclaredPlugin>,
1613 candidate: &DeclaredPlugin,
1614) -> bool {
1615 fn reachable(
1616 map: &std::collections::BTreeMap<String, DeclaredPlugin>,
1617 start: &str,
1618 target: &str,
1619 visited: &mut std::collections::BTreeSet<String>,
1620 ) -> bool {
1621 if start == target {
1622 return true;
1623 }
1624 if !visited.insert(start.to_owned()) {
1625 return false;
1626 }
1627 if let Some(node) = map.get(start) {
1628 for d in &node.dependencies {
1629 if reachable(map, d, target, visited) {
1630 return true;
1631 }
1632 }
1633 }
1634 false
1635 }
1636 let mut visited = std::collections::BTreeSet::new();
1637 candidate
1638 .dependencies
1639 .iter()
1640 .any(|d| reachable(map, d, &candidate.qname, &mut visited))
1641}
1642
1643fn chain_starting_at(
1652 map: &std::collections::BTreeMap<String, DeclaredPlugin>,
1653 candidate: &DeclaredPlugin,
1654) -> Vec<String> {
1655 fn dfs(
1656 map: &std::collections::BTreeMap<String, DeclaredPlugin>,
1657 node: &str,
1658 target: &str,
1659 stack: &mut Vec<String>,
1660 visited: &mut std::collections::BTreeSet<String>,
1661 ) -> bool {
1662 stack.push(node.to_owned());
1663 if node == target {
1664 return true;
1665 }
1666 if !visited.insert(node.to_owned()) {
1667 stack.pop();
1668 return false;
1669 }
1670 if let Some(declared) = map.get(node) {
1671 for dep in &declared.dependencies {
1672 if dfs(map, dep, target, stack, visited) {
1673 return true;
1674 }
1675 }
1676 }
1677 stack.pop();
1678 false
1679 }
1680
1681 let mut visited = std::collections::BTreeSet::new();
1682 for dep in &candidate.dependencies {
1683 let mut stack = vec![candidate.qname.clone()];
1684 if dfs(map, dep, &candidate.qname, &mut stack, &mut visited) {
1685 return stack;
1686 }
1687 }
1688 vec![candidate.qname.clone()]
1689}
1690
#[cfg(test)]
mod tests {
    //! Unit tests for declared-plugin records, the store, reactivation, and the
    //! `declare procedure` capability gate.

    use super::*;

    #[test]
    fn declared_plugin_round_trip_json() {
        let d = DeclaredPlugin {
            qname: "mycorp.fullName".to_owned(),
            kind: "function".to_owned(),
            body: "$first + ' ' + $last".to_owned(),
            signature_json: r#"{"args":["string","string"],"returns":"string"}"#.to_owned(),
            dependencies: vec![],
            declared_by: "alice".to_owned(),
            active: true,
        };
        let s = serde_json::to_string(&d).unwrap();
        let parsed: DeclaredPlugin = serde_json::from_str(&s).unwrap();
        assert_eq!(d, parsed);
    }

    #[test]
    fn custom_plugin_constructs_in_memory() {
        let _ = CustomPlugin::new_in_memory();
    }

    /// Synthesizer double that counts `synthesize` calls and always hands back
    /// a `StubProcedure`.
    #[derive(Debug)]
    struct StubSynthesizer {
        synthesized_count: std::sync::atomic::AtomicUsize,
    }

    impl StubSynthesizer {
        fn new() -> Self {
            Self {
                synthesized_count: std::sync::atomic::AtomicUsize::new(0),
            }
        }

        fn count(&self) -> usize {
            self.synthesized_count
                .load(std::sync::atomic::Ordering::SeqCst)
        }
    }

    impl crate::ProcedureBodySynthesizer for StubSynthesizer {
        fn synthesize(
            &self,
            _decl: &DeclaredPlugin,
        ) -> Result<Arc<dyn uni_plugin::traits::procedure::ProcedurePlugin>, String> {
            self.synthesized_count
                .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            Ok(Arc::new(StubProcedure {
                signature: stub_signature(),
            }))
        }
    }

    /// Procedure that only exposes a signature; it is registered but never run.
    #[derive(Debug)]
    struct StubProcedure {
        signature: uni_plugin::traits::procedure::ProcedureSignature,
    }

    fn stub_signature() -> uni_plugin::traits::procedure::ProcedureSignature {
        use arrow_schema::{DataType, Field};
        uni_plugin::traits::procedure::ProcedureSignature {
            args: vec![],
            yields: vec![Field::new("ok", DataType::Boolean, false)],
            mode: uni_plugin::traits::procedure::ProcedureMode::Read,
            side_effects: uni_plugin::SideEffects::ReadOnly,
            retry_contract: None,
            batch_input: None,
            docs: "stub".to_owned(),
        }
    }

    impl uni_plugin::traits::procedure::ProcedurePlugin for StubProcedure {
        fn signature(&self) -> &uni_plugin::traits::procedure::ProcedureSignature {
            &self.signature
        }

        fn invoke(
            &self,
            _ctx: uni_plugin::traits::procedure::ProcedureContext<'_>,
            _args: &[datafusion::logical_expr::ColumnarValue],
        ) -> Result<datafusion::execution::SendableRecordBatchStream, uni_plugin::FnError> {
            unimplemented!(
                "StubProcedure does not execute; the synthesizer test only checks registration"
            )
        }
    }

    #[test]
    fn synthesizer_synthesize_called_on_reactivate() {
        let synth = Arc::new(StubSynthesizer::new());
        let store = Arc::new(DeclaredPluginStore::new());
        store
            .declare(DeclaredPlugin {
                qname: "mycorp.findFriends".to_owned(),
                kind: "procedure".to_owned(),
                body: "MATCH (p)-[:KNOWS]->(f) RETURN f".to_owned(),
                signature_json: "{}".to_owned(),
                dependencies: vec![],
                declared_by: "test".to_owned(),
                active: true,
            })
            .unwrap();

        let registry = Arc::new(uni_plugin::PluginRegistry::new());
        let plugin = CustomPlugin {
            store: Arc::clone(&store),
            registry: Arc::clone(&registry),
            persistence: Arc::new(NullPersistence),
            procedure_synthesizer: Some(synth.clone()),
            manifest: std::sync::OnceLock::new(),
        };
        plugin
            .reactivate_into_registry()
            .expect("reactivate must call synthesizer for procedure-kind records");
        assert_eq!(
            synth.count(),
            1,
            "synthesizer should have been called for the one procedure declaration"
        );
    }

    #[test]
    fn reactivate_skips_procedure_when_no_synthesizer() {
        let store = Arc::new(DeclaredPluginStore::new());
        store
            .declare(DeclaredPlugin {
                qname: "mycorp.findFriends".to_owned(),
                kind: "procedure".to_owned(),
                body: "MATCH (p)-[:KNOWS]->(f) RETURN f".to_owned(),
                signature_json: "{}".to_owned(),
                dependencies: vec![],
                declared_by: "test".to_owned(),
                active: true,
            })
            .unwrap();

        let registry = Arc::new(uni_plugin::PluginRegistry::new());
        let plugin = CustomPlugin {
            store,
            registry,
            persistence: Arc::new(NullPersistence),
            procedure_synthesizer: None,
            manifest: std::sync::OnceLock::new(),
        };
        plugin
            .reactivate_into_registry()
            .expect("reactivate must succeed even with procedure records when no synthesizer");
    }

    fn utf8_scalar(s: &str) -> datafusion::logical_expr::ColumnarValue {
        datafusion::logical_expr::ColumnarValue::Scalar(datafusion::scalar::ScalarValue::Utf8(
            Some(s.to_owned()),
        ))
    }

    /// Invokes a fresh `DeclareProcedureProcedure` with `args`, optionally as
    /// `principal`, discarding the output stream.
    fn drive_declare_procedure(
        args: &[datafusion::logical_expr::ColumnarValue],
        principal: Option<&uni_plugin::traits::connector::Principal>,
    ) -> Result<(), uni_plugin::FnError> {
        use uni_plugin::traits::procedure::ProcedurePlugin;

        let store = Arc::new(DeclaredPluginStore::new());
        let decl = procedures::DeclareProcedureProcedure::new(store, Arc::new(NullPersistence));
        let mut ctx = uni_plugin::traits::procedure::ProcedureContext::new();
        if let Some(p) = principal {
            ctx = ctx.with_principal(p);
        }
        decl.invoke(ctx, args).map(|_| ())
    }

    #[test]
    fn declare_procedure_write_rejected_without_procedure_writes() {
        let args = vec![
            utf8_scalar("mycorp.deleteAll"),
            utf8_scalar("MATCH (n) DETACH DELETE n"),
            utf8_scalar("WRITE"),
            utf8_scalar("[]"),
            utf8_scalar("[]"),
        ];
        let p = uni_plugin::traits::connector::Principal {
            id: "alice".to_owned(),
            groups: vec![],
            capabilities: uni_plugin::CapabilitySet::new(),
        };
        let err = drive_declare_procedure(&args, Some(&p))
            .expect_err("WRITE without ProcedureWrites must fail");
        assert_eq!(err.code, 0xB09, "expected capability-denied code 0xB09");
    }

    #[test]
    fn declare_procedure_write_allowed_with_procedure_writes() {
        let args = vec![
            utf8_scalar("mycorp.deleteAll"),
            utf8_scalar("MATCH (n) DETACH DELETE n"),
            utf8_scalar("WRITE"),
            utf8_scalar("[]"),
            utf8_scalar("[]"),
        ];
        let mut caps = uni_plugin::CapabilitySet::new();
        caps.insert(uni_plugin::Capability::ProcedureWrites);
        let p = uni_plugin::traits::connector::Principal {
            id: "admin".to_owned(),
            groups: vec!["admin".to_owned()],
            capabilities: caps,
        };
        drive_declare_procedure(&args, Some(&p)).expect("WRITE with ProcedureWrites must succeed");
    }

    #[test]
    fn declare_procedure_read_does_not_require_procedure_writes() {
        let args = vec![
            utf8_scalar("mycorp.findFriends"),
            utf8_scalar("MATCH (p)-[:KNOWS]->(f) RETURN f"),
            utf8_scalar("READ"),
            utf8_scalar("[]"),
            utf8_scalar("[]"),
        ];
        let p = uni_plugin::traits::connector::Principal::anonymous();
        drive_declare_procedure(&args, Some(&p))
            .expect("READ mode declaration must not require ProcedureWrites");
    }

    /// Builds a minimal active `function` declaration with the given dependencies.
    fn make(qname: &str, deps: &[&str]) -> DeclaredPlugin {
        DeclaredPlugin {
            qname: qname.to_owned(),
            kind: "function".to_owned(),
            body: String::new(),
            signature_json: "{}".to_owned(),
            dependencies: deps.iter().map(|s| s.to_string()).collect(),
            declared_by: "test".to_owned(),
            active: true,
        }
    }

    #[test]
    fn store_declare_and_get() {
        let s = DeclaredPluginStore::new();
        s.declare(make("a.foo", &[])).unwrap();
        assert_eq!(s.get("a.foo").unwrap().qname, "a.foo");
    }

    #[test]
    fn store_rejects_missing_dependency() {
        let s = DeclaredPluginStore::new();
        match s.declare(make("a.foo", &["a.bar"])) {
            Err(CustomError::DependencyMissing { dependent, dep }) => {
                assert_eq!(dependent, "a.foo");
                assert_eq!(dep, "a.bar");
            }
            other => panic!("expected DependencyMissing, got {other:?}"),
        }
    }

    #[test]
    fn store_detects_cycle() {
        let s = DeclaredPluginStore::new();
        s.declare(make("a", &[])).unwrap();
        s.declare(make("b", &["a"])).unwrap();
        match s.declare(make("a", &["b"])) {
            Err(CustomError::DependencyCycle(_)) => {}
            other => panic!("expected DependencyCycle, got {other:?}"),
        }
    }

    #[test]
    fn store_protects_against_drop_with_dependents() {
        let s = DeclaredPluginStore::new();
        s.declare(make("a", &[])).unwrap();
        s.declare(make("b", &["a"])).unwrap();
        assert!(s.drop_declared("a").is_err());
        assert!(s.drop_declared("b").unwrap());
        assert!(s.drop_declared("a").unwrap());
    }

    #[test]
    fn store_cascade_removes_dependents() {
        let s = DeclaredPluginStore::new();
        s.declare(make("a", &[])).unwrap();
        s.declare(make("b", &["a"])).unwrap();
        s.declare(make("c", &["b"])).unwrap();
        let removed = s.drop_cascade("a");
        assert_eq!(removed.len(), 3);
        assert!(removed.iter().any(|q| q == "a"));
        assert!(removed.iter().any(|q| q == "b"));
        assert!(removed.iter().any(|q| q == "c"));
        assert!(s.list().is_empty());
    }

    #[test]
    fn store_list_returns_all_declared() {
        let s = DeclaredPluginStore::new();
        s.declare(make("x", &[])).unwrap();
        s.declare(make("y", &[])).unwrap();
        assert_eq!(s.list().len(), 2);
    }
}
2001}