Skip to main content

cognis_core/
serialization.rs

//! Minimal Runnable serialization layer.
//!
//! V1 had a heavy "serializable runnable" hierarchy modeled after
//! Python's serialize-everything pattern. V2 takes a smaller,
//! Rust-native shape:
//!
//! - [`RunnableDefinition`] is a tagged enum describing a runnable's
//!   shape (kind + config).
//! - The [`Serializable`] trait, implemented alongside [`Runnable`]
//!   (i.e. `Runnable + Serializable`), lets a runnable emit its definition.
//! - [`Loader`] reconstructs leaf runnables from definitions via a registry
//!   of named factories. Lambda closures are serialized by name only, so the
//!   caller must register the lambda factory under the same name that was
//!   used when dumping.
//!
//! This is enough to ship composed pipelines (Pipe/Each/Pipe of Pipe)
//! without a full serialize-everything story. Restoring a composition is
//! the caller's job: [`Loader`] rebuilds the `Lambda`/`Opaque` leaves, and
//! the caller reassembles them following the definition's structure.
17
18use std::collections::HashMap;
19use std::sync::Arc;
20
21use serde::{Deserialize, Serialize};
22
23use crate::runnable::Runnable;
24use crate::{CognisError, Result};
25
26/// Serializable description of a runnable's shape.
27#[derive(Debug, Clone, Serialize, Deserialize)]
28#[serde(tag = "kind", rename_all = "snake_case")]
29pub enum RunnableDefinition {
30    /// Sequential `a.pipe(b)`.
31    Pipe {
32        /// Left-side definition.
33        a: Box<RunnableDefinition>,
34        /// Right-side definition.
35        b: Box<RunnableDefinition>,
36    },
37    /// Element-wise wrapper.
38    Each {
39        /// Inner definition.
40        inner: Box<RunnableDefinition>,
41    },
42    /// Identity runnable.
43    Passthrough,
44    /// Named lambda — the caller's `Loader` must register a factory
45    /// under `name`.
46    Lambda {
47        /// Registered factory name.
48        name: String,
49    },
50    /// Free-form opaque definition. Runnables that can't fully describe
51    /// themselves emit this; the caller must know how to rebuild from
52    /// `name` + `params`.
53    Opaque {
54        /// Type identifier for the loader.
55        name: String,
56        /// Arbitrary config payload.
57        params: serde_json::Value,
58    },
59}
60
61/// Trait runnables implement to participate in dump/load. Default is
62/// `Opaque { name = type name }` — override to capture useful structure.
63pub trait Serializable {
64    /// Emit a definition describing this runnable's shape.
65    fn to_definition(&self) -> RunnableDefinition;
66}
67
68/// One factory in a `Loader`: takes opaque params, produces a runnable.
69type Factory<I, O> = Box<dyn Fn(&serde_json::Value) -> Arc<dyn Runnable<I, O>> + Send + Sync>;
70
71/// Reconstructs runnables from definitions. Caller registers factories
72/// by name for each `Lambda` / `Opaque` kind they expect to load.
73pub struct Loader<I, O>
74where
75    I: Send + 'static,
76    O: Send + 'static,
77{
78    factories: HashMap<String, Factory<I, O>>,
79}
80
81impl<I, O> Default for Loader<I, O>
82where
83    I: Send + 'static,
84    O: Send + 'static,
85{
86    fn default() -> Self {
87        Self {
88            factories: HashMap::new(),
89        }
90    }
91}
92
93impl<I, O> Loader<I, O>
94where
95    I: Send + 'static,
96    O: Send + 'static,
97{
98    /// Empty loader.
99    pub fn new() -> Self {
100        Self::default()
101    }
102
103    /// Register a factory for a named lambda or opaque kind. The factory
104    /// receives the params payload (`serde_json::Value::Null` for
105    /// lambdas) and must produce an `Arc<dyn Runnable<I, O>>`.
106    pub fn register<F>(&mut self, name: impl Into<String>, factory: F)
107    where
108        F: Fn(&serde_json::Value) -> Arc<dyn Runnable<I, O>> + Send + Sync + 'static,
109    {
110        self.factories.insert(name.into(), Box::new(factory));
111    }
112
113    /// Reconstruct a `Runnable<I, O>` from its definition.
114    pub fn load(&self, def: &RunnableDefinition) -> Result<Arc<dyn Runnable<I, O>>> {
115        match def {
116            RunnableDefinition::Lambda { name } | RunnableDefinition::Opaque { name, .. } => {
117                let params = match def {
118                    RunnableDefinition::Opaque { params, .. } => params.clone(),
119                    _ => serde_json::Value::Null,
120                };
121                self.factories.get(name).map(|f| f(&params)).ok_or_else(|| {
122                    CognisError::Configuration(format!(
123                        "Loader: no factory registered for `{name}`"
124                    ))
125                })
126            }
127            // Composite kinds aren't directly buildable here — they need
128            // type-aligned inner definitions, which a generic loader
129            // can't enforce. Callers reconstruct composites via their
130            // own builder using the inner definitions.
131            other => Err(CognisError::Configuration(format!(
132                "Loader::load: composite definitions ({other:?}) must be reconstructed by caller — \
133                 use the inner definitions to build the composition explicitly"
134            ))),
135        }
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runnable::RunnableConfig;
    use async_trait::async_trait;

    /// Test runnable that ignores its input and returns a fixed value.
    struct Const(u32);
    #[async_trait]
    impl Runnable<u32, u32> for Const {
        async fn invoke(&self, _: u32, _: RunnableConfig) -> Result<u32> {
            Ok(self.0)
        }
    }

    #[test]
    fn roundtrip_definition_serde() {
        let d = RunnableDefinition::Pipe {
            a: Box::new(RunnableDefinition::Lambda {
                name: "step_a".into(),
            }),
            b: Box::new(RunnableDefinition::Passthrough),
        };
        let s = serde_json::to_string(&d).unwrap();
        let back: RunnableDefinition = serde_json::from_str(&s).unwrap();
        assert!(matches!(back, RunnableDefinition::Pipe { .. }));
        // Compare whole JSON trees so nested content is verified too, not
        // just the outer variant.
        assert_eq!(
            serde_json::to_value(&back).unwrap(),
            serde_json::to_value(&d).unwrap()
        );
    }

    #[test]
    fn definition_json_uses_snake_case_kind_tag() {
        let d = RunnableDefinition::Each {
            inner: Box::new(RunnableDefinition::Passthrough),
        };
        assert_eq!(
            serde_json::to_value(&d).unwrap(),
            serde_json::json!({ "kind": "each", "inner": { "kind": "passthrough" } })
        );
    }

    #[tokio::test]
    async fn loader_resolves_named_factory() {
        let mut loader = Loader::<u32, u32>::new();
        loader.register("k", |_| Arc::new(Const(7)));
        let r = loader
            .load(&RunnableDefinition::Lambda { name: "k".into() })
            .unwrap();
        assert_eq!(r.invoke(0, RunnableConfig::default()).await.unwrap(), 7);
    }

    #[tokio::test]
    async fn loader_passes_opaque_params_to_factory() {
        let mut loader = Loader::<u32, u32>::new();
        loader.register("const", |params| {
            let v = params["value"]
                .as_u64()
                .expect("test params carry a u64 `value`");
            Arc::new(Const(u32::try_from(v).expect("test value fits in u32")))
        });
        let r = loader
            .load(&RunnableDefinition::Opaque {
                name: "const".into(),
                params: serde_json::json!({ "value": 42 }),
            })
            .unwrap();
        assert_eq!(r.invoke(0, RunnableConfig::default()).await.unwrap(), 42);
    }

    #[test]
    fn loader_unknown_factory_errors() {
        let loader = Loader::<u32, u32>::new();
        let err = loader
            .load(&RunnableDefinition::Lambda {
                name: "nope".into(),
            })
            .err();
        assert!(matches!(err, Some(CognisError::Configuration(_))));
    }

    #[test]
    fn loader_rejects_composite_definitions() {
        let loader = Loader::<u32, u32>::new();
        for def in [
            RunnableDefinition::Passthrough,
            RunnableDefinition::Each {
                inner: Box::new(RunnableDefinition::Passthrough),
            },
            RunnableDefinition::Pipe {
                a: Box::new(RunnableDefinition::Passthrough),
                b: Box::new(RunnableDefinition::Passthrough),
            },
        ] {
            assert!(matches!(
                loader.load(&def).err(),
                Some(CognisError::Configuration(_))
            ));
        }
    }
}