use std::collections::HashMap;
use selene_core::{DbString, GraphId, Value, db_string};
use crate::ProcedureContext;
use crate::procedure_registry::{
ProcedureError, ProcedureHandle, ProcedureMetadata, ProcedureRegistry, ProcedureResult,
};
use crate::runtime::builtins::{BUILTIN_SPECS, BuiltinKind};
use crate::runtime::native_algorithms::{ALGO_SPECS, AlgoKind, AlgorithmCatalogs, forget_graph};
/// Version number reported by `BuiltinProcedureRegistry::registry_version`.
///
/// NOTE(review): presumably fixed at zero because nothing in this file adds or
/// removes procedures after construction — confirm against the
/// `ProcedureRegistry` trait contract.
const REGISTRY_VERSION: u64 = 0;
/// Execution route stored per procedure handle; `execute` matches on it to
/// decide which family of implementations to invoke.
#[derive(Clone, Copy, Debug)]
enum Dispatch {
/// A native algorithm, registered from `ALGO_SPECS`.
Algo(AlgoKind),
/// A builtin procedure, registered from `BUILTIN_SPECS`.
Builtin(BuiltinKind),
}
/// Registry of the procedures declared in `ALGO_SPECS` and `BUILTIN_SPECS`.
///
/// All three indexes are filled once in `new`; the methods in this file only
/// read them afterwards.
#[derive(Debug)]
pub struct BuiltinProcedureRegistry {
/// Metadata keyed by the procedure's name segments (used by `lookup`).
by_name: HashMap<Box<[DbString]>, ProcedureMetadata>,
/// Dispatch target keyed by handle (used by `execute`).
by_handle: HashMap<ProcedureHandle, Dispatch>,
/// `(name segments, metadata)` pairs in registration order: every
/// `ALGO_SPECS` entry first, then every `BUILTIN_SPECS` entry.
ordered: Vec<(Vec<DbString>, ProcedureMetadata)>,
/// Catalog state handed to algorithm execution and to `forget_graph`.
catalogs: AlgorithmCatalogs,
}
impl BuiltinProcedureRegistry {
#[must_use]
pub fn new() -> Self {
let mut by_name = HashMap::new();
let mut by_handle = HashMap::new();
let mut ordered = Vec::new();
let mut next_handle = 1_u64;
for spec in &ALGO_SPECS {
let handle = ProcedureHandle::new(next_handle);
next_handle += 1;
let name = procedure_name_segments(spec.name);
let metadata = spec.kind.metadata(handle, spec.description);
by_handle.insert(handle, Dispatch::Algo(spec.kind));
by_name.insert(name.clone().into_boxed_slice(), metadata.clone());
ordered.push((name, metadata));
}
for spec in &BUILTIN_SPECS {
let handle = ProcedureHandle::new(next_handle);
next_handle += 1;
let name = procedure_name_segments(spec.name);
let metadata = spec
.kind
.metadata(handle, spec.description, spec.since_version);
by_handle.insert(handle, Dispatch::Builtin(spec.kind));
by_name.insert(name.clone().into_boxed_slice(), metadata.clone());
ordered.push((name, metadata));
}
Self {
by_name,
by_handle,
ordered,
catalogs: AlgorithmCatalogs::default(),
}
}
pub fn forget_graph(&self, graph_id: GraphId) -> bool {
forget_graph(&self.catalogs, graph_id)
}
}
impl Default for BuiltinProcedureRegistry {
/// Equivalent to `BuiltinProcedureRegistry::new()`.
fn default() -> Self {
Self::new()
}
}
impl ProcedureRegistry for BuiltinProcedureRegistry {
fn lookup(&self, name: &[DbString]) -> Option<ProcedureMetadata> {
self.by_name.get(name).cloned()
}
fn registry_version(&self) -> u64 {
REGISTRY_VERSION
}
fn iter_handles(&self) -> Box<dyn Iterator<Item = (Vec<DbString>, ProcedureMetadata)> + '_> {
Box::new(
self.ordered
.iter()
.map(|(name, metadata)| (name.clone(), metadata.clone())),
)
}
fn execute(
&self,
handle: ProcedureHandle,
args: &[Value],
ctx: &mut ProcedureContext<'_, '_>,
) -> Result<ProcedureResult, ProcedureError> {
let _span = tracing::span!(
tracing::Level::INFO,
"selene.procedure.dispatch",
procedure = tracing::field::Empty
)
.entered();
let Some(dispatch) = self.by_handle.get(&handle).copied() else {
return Err(ProcedureError::UnknownProcedure { name: Box::new([]) });
};
match dispatch {
Dispatch::Algo(kind) => {
tracing::Span::current()
.record("procedure", tracing::field::display(procedure_name(kind)));
let ProcedureContext::Graph(graph_ctx) = ctx else {
return Err(ProcedureError::TierMismatch {
expected: crate::ProcedureTier::Graph,
actual: ctx.tier(),
});
};
kind.execute(&self.catalogs, graph_ctx, args)
}
Dispatch::Builtin(kind) => {
tracing::Span::current()
.record("procedure", tracing::field::display(builtin_name(kind)));
match (kind.tier(), ctx) {
(crate::ProcedureTier::Graph, ProcedureContext::Graph(graph_ctx)) => {
kind.execute_graph(graph_ctx, args)
}
(crate::ProcedureTier::Mutation, ProcedureContext::Mutation(mut_ctx)) => {
kind.execute_mutation(mut_ctx, args)
}
(
crate::ProcedureTier::Maintenance,
ProcedureContext::Maintenance(maintenance_ctx),
) => kind.execute_maintenance(maintenance_ctx, args),
(expected, ctx) => Err(ProcedureError::TierMismatch {
expected,
actual: ctx.tier(),
}),
}
}
}
}
}
/// Returns the dotted name (segments joined with `.`) of the first
/// `ALGO_SPECS` entry whose kind equals `kind`, or an empty string when no
/// entry matches.
fn procedure_name(kind: AlgoKind) -> String {
    ALGO_SPECS
        .iter()
        .find_map(|spec| (spec.kind == kind).then(|| spec.name.join(".")))
        .unwrap_or_default()
}
/// Returns the dotted name (segments joined with `.`) of the first
/// `BUILTIN_SPECS` entry whose kind equals `kind`, or an empty string when no
/// entry matches.
fn builtin_name(kind: BuiltinKind) -> String {
    BUILTIN_SPECS
        .iter()
        .find_map(|spec| (spec.kind == kind).then(|| spec.name.join(".")))
        .unwrap_or_default()
}
/// Converts each static name segment into a `DbString`, preserving order.
///
/// # Panics
///
/// Panics if `db_string` rejects a segment. The names are compile-time
/// constants, so a failure indicates a bad spec table rather than bad input.
fn procedure_name_segments(raw: &'static [&'static str]) -> Vec<DbString> {
    let mut segments = Vec::with_capacity(raw.len());
    for segment in raw {
        let converted =
            db_string(segment).expect("static procedure name segment fits DB string cap");
        segments.push(converted);
    }
    segments
}
// Test-only module; compiled only under `cfg(test)` and loaded from the file
// named by `#[path]`.
#[cfg(test)]
#[path = "builtin_registry_surface_tests.rs"]
mod surface_tests;
// Test-only module; compiled only under `cfg(test)` and loaded from the file
// named by `#[path]`.
#[cfg(test)]
#[path = "builtin_registry_tests.rs"]
mod tests;