cognis_core/
serialization.rs1use std::collections::HashMap;
19use std::sync::Arc;
20
21use serde::{Deserialize, Serialize};
22
23use crate::runnable::Runnable;
24use crate::{CognisError, Result};
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28#[serde(tag = "kind", rename_all = "snake_case")]
29pub enum RunnableDefinition {
30 Pipe {
32 a: Box<RunnableDefinition>,
34 b: Box<RunnableDefinition>,
36 },
37 Each {
39 inner: Box<RunnableDefinition>,
41 },
42 Passthrough,
44 Lambda {
47 name: String,
49 },
50 Opaque {
54 name: String,
56 params: serde_json::Value,
58 },
59}
60
61pub trait Serializable {
64 fn to_definition(&self) -> RunnableDefinition;
66}
67
68type Factory<I, O> = Box<dyn Fn(&serde_json::Value) -> Arc<dyn Runnable<I, O>> + Send + Sync>;
70
71pub struct Loader<I, O>
74where
75 I: Send + 'static,
76 O: Send + 'static,
77{
78 factories: HashMap<String, Factory<I, O>>,
79}
80
81impl<I, O> Default for Loader<I, O>
82where
83 I: Send + 'static,
84 O: Send + 'static,
85{
86 fn default() -> Self {
87 Self {
88 factories: HashMap::new(),
89 }
90 }
91}
92
93impl<I, O> Loader<I, O>
94where
95 I: Send + 'static,
96 O: Send + 'static,
97{
98 pub fn new() -> Self {
100 Self::default()
101 }
102
103 pub fn register<F>(&mut self, name: impl Into<String>, factory: F)
107 where
108 F: Fn(&serde_json::Value) -> Arc<dyn Runnable<I, O>> + Send + Sync + 'static,
109 {
110 self.factories.insert(name.into(), Box::new(factory));
111 }
112
113 pub fn load(&self, def: &RunnableDefinition) -> Result<Arc<dyn Runnable<I, O>>> {
115 match def {
116 RunnableDefinition::Lambda { name } | RunnableDefinition::Opaque { name, .. } => {
117 let params = match def {
118 RunnableDefinition::Opaque { params, .. } => params.clone(),
119 _ => serde_json::Value::Null,
120 };
121 self.factories.get(name).map(|f| f(¶ms)).ok_or_else(|| {
122 CognisError::Configuration(format!(
123 "Loader: no factory registered for `{name}`"
124 ))
125 })
126 }
127 other => Err(CognisError::Configuration(format!(
132 "Loader::load: composite definitions ({other:?}) must be reconstructed by caller — \
133 use the inner definitions to build the composition explicitly"
134 ))),
135 }
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runnable::RunnableConfig;
    use async_trait::async_trait;

    /// Test runnable that ignores its input and yields the wrapped value.
    struct Const(u32);

    #[async_trait]
    impl Runnable<u32, u32> for Const {
        async fn invoke(&self, _: u32, _: RunnableConfig) -> Result<u32> {
            Ok(self.0)
        }
    }

    /// A nested definition survives a JSON round trip, and the variant tag is
    /// written under `kind` in snake_case.
    #[test]
    fn roundtrip_definition_serde() {
        let d = RunnableDefinition::Pipe {
            a: Box::new(RunnableDefinition::Lambda {
                name: "step_a".into(),
            }),
            b: Box::new(RunnableDefinition::Passthrough),
        };
        let s = serde_json::to_string(&d).unwrap();

        let raw: serde_json::Value = serde_json::from_str(&s).unwrap();
        assert_eq!(raw["kind"], "pipe");

        let back: RunnableDefinition = serde_json::from_str(&s).unwrap();
        match back {
            RunnableDefinition::Pipe { a, b } => {
                assert!(matches!(
                    *a,
                    RunnableDefinition::Lambda { ref name } if name == "step_a"
                ));
                assert!(matches!(*b, RunnableDefinition::Passthrough));
            }
            other => panic!("expected Pipe, got {other:?}"),
        }
    }

    /// A `Lambda` definition resolves to the runnable built by its factory.
    #[tokio::test]
    async fn loader_resolves_named_factory() {
        let mut loader = Loader::<u32, u32>::new();
        loader.register("k", |_| Arc::new(Const(7)));
        let r = loader
            .load(&RunnableDefinition::Lambda { name: "k".into() })
            .unwrap();
        assert_eq!(r.invoke(0, RunnableConfig::default()).await.unwrap(), 7);
    }

    /// A `Lambda` definition hands JSON null to its factory.
    #[tokio::test]
    async fn loader_passes_null_params_for_lambda() {
        let mut loader = Loader::<u32, u32>::new();
        loader.register("probe", |params| {
            Arc::new(Const(u32::from(params.is_null())))
        });
        let r = loader
            .load(&RunnableDefinition::Lambda {
                name: "probe".into(),
            })
            .unwrap();
        assert_eq!(r.invoke(0, RunnableConfig::default()).await.unwrap(), 1);
    }

    /// An `Opaque` definition forwards its params to the factory unchanged.
    #[tokio::test]
    async fn loader_forwards_opaque_params() {
        let mut loader = Loader::<u32, u32>::new();
        loader.register("const", |params| {
            let value = params
                .get("value")
                .and_then(serde_json::Value::as_u64)
                .and_then(|v| u32::try_from(v).ok())
                .unwrap_or(0);
            Arc::new(Const(value))
        });
        let r = loader
            .load(&RunnableDefinition::Opaque {
                name: "const".into(),
                params: serde_json::json!({ "value": 42 }),
            })
            .unwrap();
        assert_eq!(r.invoke(0, RunnableConfig::default()).await.unwrap(), 42);
    }

    /// Looking up a name with no registered factory is an error.
    // Plain `#[test]`: nothing here awaits, so no async runtime is needed.
    #[test]
    fn loader_unknown_factory_errors() {
        let loader = Loader::<u32, u32>::new();
        assert!(loader
            .load(&RunnableDefinition::Lambda {
                name: "nope".into(),
            })
            .is_err());
    }

    /// Composite definitions are rejected with a configuration error.
    #[test]
    fn loader_rejects_composite_definitions() {
        let loader = Loader::<u32, u32>::new();
        let composites = [
            RunnableDefinition::Passthrough,
            RunnableDefinition::Each {
                inner: Box::new(RunnableDefinition::Passthrough),
            },
            RunnableDefinition::Pipe {
                a: Box::new(RunnableDefinition::Passthrough),
                b: Box::new(RunnableDefinition::Passthrough),
            },
        ];
        for def in &composites {
            assert!(matches!(
                loader.load(def),
                Err(CognisError::Configuration(_))
            ));
        }
    }
}