use std::collections::HashMap;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use crate::runnable::Runnable;
use crate::{CognisError, Result};
/// Serializable description of a runnable.
///
/// Encoded with an internal `kind` tag whose value is the variant name in
/// `snake_case`, e.g. `{"kind": "lambda", "name": "step_a"}` or
/// `{"kind": "passthrough"}`.
///
/// Only the leaf variants ([`Lambda`](Self::Lambda) and [`Opaque`](Self::Opaque))
/// can be turned back into a runnable by [`Loader::load`]; every other variant
/// must be rebuilt by the caller from its inner definitions.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum RunnableDefinition {
    /// Composite of two definitions, `a` and `b`; rejected by [`Loader::load`].
    Pipe {
        a: Box<RunnableDefinition>,
        b: Box<RunnableDefinition>,
    },
    /// Composite wrapping a single `inner` definition; rejected by [`Loader::load`].
    Each {
        inner: Box<RunnableDefinition>,
    },
    /// Payload-free definition; rejected by [`Loader::load`].
    Passthrough,
    /// Named leaf. [`Loader::load`] calls the factory registered under `name`
    /// with [`serde_json::Value::Null`] as its parameters.
    Lambda {
        name: String,
    },
    /// Named leaf with JSON parameters. [`Loader::load`] calls the factory
    /// registered under `name` with a reference to `params`.
    Opaque {
        name: String,
        params: serde_json::Value,
    },
}
/// Types that can describe themselves as a [`RunnableDefinition`].
pub trait Serializable {
    /// Returns an owned definition describing `self`.
    fn to_definition(&self) -> RunnableDefinition;
}
/// Boxed constructor stored by [`Loader`]: builds a shared runnable from a
/// definition's JSON parameters. `Send + Sync` so the loader can be shared.
type Factory<I, O> = Box<dyn Fn(&serde_json::Value) -> Arc<dyn Runnable<I, O>> + Send + Sync>;
/// Registry of named factories that turns leaf [`RunnableDefinition`]s
/// (`Lambda` / `Opaque`) back into runnables.
///
/// Populate it with [`Loader::register`] and resolve definitions with
/// [`Loader::load`].
pub struct Loader<I, O>
where
    I: Send + 'static,
    O: Send + 'static,
{
    // Factories keyed by the definition `name` they handle.
    factories: HashMap<String, Factory<I, O>>,
}
impl<I, O> Default for Loader<I, O>
where
I: Send + 'static,
O: Send + 'static,
{
fn default() -> Self {
Self {
factories: HashMap::new(),
}
}
}
impl<I, O> Loader<I, O>
where
    I: Send + 'static,
    O: Send + 'static,
{
    /// Creates a loader with no registered factories (same as [`Loader::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `factory` under `name`.
    ///
    /// [`Loader::load`] calls the factory with the definition's JSON parameters
    /// ([`serde_json::Value::Null`] for [`RunnableDefinition::Lambda`]).
    /// Registering a name that is already present replaces the previous factory.
    pub fn register<F>(&mut self, name: impl Into<String>, factory: F)
    where
        F: Fn(&serde_json::Value) -> Arc<dyn Runnable<I, O>> + Send + Sync + 'static,
    {
        self.factories.insert(name.into(), Box::new(factory));
    }

    /// Instantiates a leaf definition through its registered factory.
    ///
    /// `Lambda { name }` invokes the factory for `name` with a JSON null;
    /// `Opaque { name, params }` invokes it with `params`.
    ///
    /// # Errors
    ///
    /// Returns [`CognisError::Configuration`] when:
    /// - no factory is registered under the definition's `name`, or
    /// - `def` is `Pipe`, `Each` or `Passthrough`. These are not built here;
    ///   the caller must assemble them from the inner definitions.
    pub fn load(&self, def: &RunnableDefinition) -> Result<Arc<dyn Runnable<I, O>>> {
        // Lambdas carry no parameters; their factory receives an explicit JSON null.
        let null = serde_json::Value::Null;
        let (name, params) = match def {
            RunnableDefinition::Lambda { name } => (name, &null),
            // Borrow instead of cloning: the factory only needs `&Value`, and a
            // clone would deep-copy an arbitrarily large JSON tree on every load.
            RunnableDefinition::Opaque { name, params } => (name, params),
            // Variants are listed explicitly (no `_` arm) so adding a new one
            // forces a decision here instead of silently hitting this error.
            RunnableDefinition::Pipe { .. }
            | RunnableDefinition::Each { .. }
            | RunnableDefinition::Passthrough => {
                return Err(CognisError::Configuration(format!(
                    "Loader::load: composite definitions ({def:?}) must be reconstructed by caller — \
                     use the inner definitions to build the composition explicitly"
                )));
            }
        };
        let factory = self.factories.get(name).ok_or_else(|| {
            CognisError::Configuration(format!("Loader: no factory registered for `{name}`"))
        })?;
        Ok(factory(params))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runnable::RunnableConfig;
    use async_trait::async_trait;

    /// Test runnable that ignores its input and always yields the wrapped value.
    struct Const(u32);

    #[async_trait]
    impl Runnable<u32, u32> for Const {
        async fn invoke(&self, _: u32, _: RunnableConfig) -> Result<u32> {
            Ok(self.0)
        }
    }

    #[test]
    fn roundtrip_definition_serde() {
        let d = RunnableDefinition::Pipe {
            a: Box::new(RunnableDefinition::Lambda {
                name: "step_a".into(),
            }),
            b: Box::new(RunnableDefinition::Passthrough),
        };
        let s = serde_json::to_string(&d).unwrap();
        let back: RunnableDefinition = serde_json::from_str(&s).unwrap();
        // Check the nested contents survived, not just the outer variant.
        match back {
            RunnableDefinition::Pipe { a, b } => {
                assert!(matches!(&*a, RunnableDefinition::Lambda { name } if name == "step_a"));
                assert!(matches!(&*b, RunnableDefinition::Passthrough));
            }
            other => panic!("expected Pipe, got {other:?}"),
        }
    }

    #[test]
    fn definition_uses_snake_case_kind_tag() {
        let v = serde_json::to_value(&RunnableDefinition::Passthrough).unwrap();
        assert_eq!(v, serde_json::json!({ "kind": "passthrough" }));
        let v = serde_json::to_value(&RunnableDefinition::Lambda {
            name: "step_a".into(),
        })
        .unwrap();
        assert_eq!(v, serde_json::json!({ "kind": "lambda", "name": "step_a" }));
    }

    #[tokio::test]
    async fn loader_resolves_named_factory() {
        let mut loader = Loader::<u32, u32>::new();
        loader.register("k", |_| Arc::new(Const(7)));
        let r = loader
            .load(&RunnableDefinition::Lambda { name: "k".into() })
            .unwrap();
        assert_eq!(r.invoke(0, RunnableConfig::default()).await.unwrap(), 7);
    }

    #[tokio::test]
    async fn loader_passes_null_params_for_lambda() {
        let mut loader = Loader::<u32, u32>::new();
        loader.register("k", |params| Arc::new(Const(u32::from(params.is_null()))));
        let r = loader
            .load(&RunnableDefinition::Lambda { name: "k".into() })
            .unwrap();
        assert_eq!(r.invoke(0, RunnableConfig::default()).await.unwrap(), 1);
    }

    #[tokio::test]
    async fn loader_passes_opaque_params_to_factory() {
        let mut loader = Loader::<u32, u32>::new();
        loader.register("from_params", |params| {
            let n = params["value"].as_u64().expect("`value` must be a u64");
            Arc::new(Const(u32::try_from(n).expect("`value` must fit in u32")))
        });
        let def = RunnableDefinition::Opaque {
            name: "from_params".into(),
            params: serde_json::json!({ "value": 42 }),
        };
        let r = loader.load(&def).unwrap();
        assert_eq!(r.invoke(0, RunnableConfig::default()).await.unwrap(), 42);
    }

    #[tokio::test]
    async fn loader_unknown_factory_errors() {
        let loader = Loader::<u32, u32>::new();
        assert!(loader
            .load(&RunnableDefinition::Lambda {
                name: "nope".into(),
            })
            .is_err());
    }

    #[test]
    fn loader_rejects_composite_definitions() {
        let mut loader = Loader::<u32, u32>::new();
        // Register the inner name so any error must come from the composite itself.
        loader.register("k", |_| Arc::new(Const(7)));
        let leaf = || Box::new(RunnableDefinition::Lambda { name: "k".into() });
        let composites = [
            RunnableDefinition::Passthrough,
            RunnableDefinition::Each { inner: leaf() },
            RunnableDefinition::Pipe {
                a: leaf(),
                b: leaf(),
            },
        ];
        for def in &composites {
            assert!(loader.load(def).is_err(), "expected an error for {def:?}");
        }
    }
}