Skip to main content

cognis_core/
runnable.rs

1//! The unified `Runnable<I, O>` trait + per-call configuration.
2
3use std::sync::Arc;
4use std::time::Instant;
5
6use async_trait::async_trait;
7use futures::stream::{self, StreamExt};
8use uuid::Uuid;
9
10use crate::extensions::Extensions;
11use crate::stream::Observer;
12
/// Per-invocation configuration. Defaults are sensible; override only
/// what you need.
///
/// Construct with [`RunnableConfig::default`] / [`RunnableConfig::new`] and
/// refine with the `with_*` builder methods.
///
/// Note on cloning: `Clone` is derived, so each field is cloned with its own
/// `Clone` impl. In particular `extras` is NOT carried over — per the
/// `Extensions` clone contract a cloned config has an empty `extras`.
#[derive(Clone)]
pub struct RunnableConfig {
    /// Maximum number of graph supersteps / chain depth before erroring with
    /// `CognisError::RecursionLimit`. Defaults to 25.
    pub recursion_limit: u32,
    /// Maximum concurrent in-flight tasks (used by `batch` and parallel nodes).
    /// Defaults to `num_cpus::get()` (at least 1). The default `batch` treats
    /// a value of 0 as 1.
    pub max_concurrency: usize,
    /// Free-form telemetry tags (e.g. ["production", "feature/foo"]).
    pub tags: Vec<String>,
    /// User-supplied metadata, attached to every emitted Event.
    /// Defaults to `serde_json::Value::Null`.
    pub metadata: serde_json::Value,
    /// Event subscribers. Multiple are allowed; each receives every event
    /// passed to `emit`, in registration order.
    pub observers: Vec<Arc<dyn Observer>>,
    /// Correlation ID for this invocation. A fresh UUID is generated by
    /// `Default::default()` and by `clone_for_subcall`.
    pub run_id: Uuid,
    /// Cooperative cancellation token; `None` means `is_cancelled()` is
    /// always false.
    pub cancel_token: Option<tokio_util::sync::CancellationToken>,
    /// Hard deadline (if set, framework checks it at every superstep boundary).
    pub deadline: Option<Instant>,
    /// Plugin-supplied typed payloads. Not inherited by clones or by
    /// `clone_for_subcall` children (both start with an empty map).
    pub extras: Extensions,
    /// Parent observation id for trace nesting. Set by composition sites
    /// (Pipe, batch, graph engine) when invoking a sub-runnable. Defaults
    /// to `None` for top-level invocations.
    pub parent_run_id: Option<Uuid>,
}
41
42impl Default for RunnableConfig {
43    fn default() -> Self {
44        Self {
45            recursion_limit: 25,
46            max_concurrency: num_cpus::get().max(1),
47            tags: Vec::new(),
48            metadata: serde_json::Value::Null,
49            observers: Vec::new(),
50            run_id: Uuid::new_v4(),
51            cancel_token: None,
52            deadline: None,
53            extras: Extensions::new(),
54            parent_run_id: None,
55        }
56    }
57}
58
59impl RunnableConfig {
60    /// Create with defaults. Equivalent to `RunnableConfig::default()`.
61    pub fn new() -> Self {
62        Self::default()
63    }
64
65    /// Set the recursion limit (builder-style).
66    pub fn with_recursion_limit(mut self, n: u32) -> Self {
67        self.recursion_limit = n;
68        self
69    }
70
71    /// Set the max concurrency (builder-style).
72    pub fn with_max_concurrency(mut self, n: usize) -> Self {
73        self.max_concurrency = n;
74        self
75    }
76
77    /// Add a single observer (builder-style).
78    pub fn with_observer(mut self, o: Arc<dyn Observer>) -> Self {
79        self.observers.push(o);
80        self
81    }
82
83    /// Add a tag (builder-style).
84    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
85        self.tags.push(tag.into());
86        self
87    }
88
89    /// Set the cancellation token (builder-style).
90    pub fn with_cancel_token(mut self, t: tokio_util::sync::CancellationToken) -> Self {
91        self.cancel_token = Some(t);
92        self
93    }
94
95    /// Set the parent run id (builder-style). Used by composition sites
96    /// to thread trace nesting down to children.
97    pub fn with_parent_run_id(mut self, id: Uuid) -> Self {
98        self.parent_run_id = Some(id);
99        self
100    }
101
102    /// Notify every registered observer of an event.
103    pub fn emit(&self, event: &crate::stream::Event) {
104        for o in &self.observers {
105            o.on_event(event);
106        }
107    }
108
109    /// True if the cancel token has been triggered.
110    pub fn is_cancelled(&self) -> bool {
111        self.cancel_token
112            .as_ref()
113            .map(|t| t.is_cancelled())
114            .unwrap_or(false)
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn defaults_sane() {
        let cfg = RunnableConfig::default();
        assert_eq!(25, cfg.recursion_limit);
        assert!(cfg.max_concurrency > 0);
        assert!(cfg.observers.is_empty());
    }

    #[test]
    fn builder_chains() {
        let cfg = RunnableConfig::new()
            .with_recursion_limit(10)
            .with_max_concurrency(4)
            .with_tag("prod");
        assert_eq!((cfg.recursion_limit, cfg.max_concurrency), (10, 4));
        assert_eq!(cfg.tags, ["prod"]);
    }

    #[test]
    fn cancel_default_false() {
        assert!(!RunnableConfig::default().is_cancelled());
    }

    #[test]
    fn config_clones_with_extras_emptied() {
        let mut original = RunnableConfig::default()
            .with_recursion_limit(50)
            .with_max_concurrency(8)
            .with_tag("test");
        original.extras.insert(42u32);
        assert!(original.extras.contains::<u32>());

        let copy = original.clone();
        assert_eq!((copy.recursion_limit, copy.max_concurrency), (50, 8));
        assert_eq!(copy.tags, ["test"]);
        // `Extensions::clone` deliberately does not deep-copy payloads (Plan #2).
        assert!(copy.extras.is_empty());
    }

    #[test]
    fn parent_run_id_default_is_none() {
        assert_eq!(RunnableConfig::default().parent_run_id, None);
    }

    #[test]
    fn clone_for_subcall_sets_parent_run_id_to_self() {
        let parent = std::sync::Arc::new(RunnableConfig::default());
        let child = RunnableConfig::clone_for_subcall(&parent);
        assert_eq!(Some(parent.run_id), child.parent_run_id);
        assert_ne!(parent.run_id, child.run_id);
    }

    #[test]
    fn with_parent_run_id_builder() {
        let parent_id = Uuid::new_v4();
        let cfg = RunnableConfig::default().with_parent_run_id(parent_id);
        assert_eq!(Some(parent_id), cfg.parent_run_id);
    }
}
185
186/// The unified contract every cognis primitive implements.
187///
188/// Generic over `I` (input) and `O` (output). One required method (`invoke`);
189/// `batch`, `stream`, and `stream_events` have sensible defaults that
190/// implementations override only when they can do better.
191#[async_trait]
192pub trait Runnable<I, O>: Send + Sync
193where
194    I: Send + 'static,
195    O: Send + 'static,
196{
197    /// One-shot invocation. The hot path.
198    async fn invoke(&self, input: I, config: RunnableConfig) -> crate::Result<O>;
199
200    /// Run multiple inputs in parallel. Defaults to `buffer_unordered`
201    /// honouring `config.max_concurrency`.
202    async fn batch(&self, inputs: Vec<I>, config: RunnableConfig) -> crate::Result<Vec<O>>
203    where
204        I: 'static,
205        O: 'static,
206        Self: Sized + Sync,
207    {
208        let concurrency = config.max_concurrency.max(1);
209        let cfg = Arc::new(config);
210        stream::iter(inputs)
211            .map(|input| {
212                let cfg = cfg.clone();
213                async move {
214                    self.invoke(input, RunnableConfig::clone_for_subcall(&cfg))
215                        .await
216                }
217            })
218            .buffer_unordered(concurrency)
219            .collect::<Vec<_>>()
220            .await
221            .into_iter()
222            .collect()
223    }
224
225    /// Stream the final output (chunks of `O`). Default emits one item via
226    /// `invoke` — non-streaming runnables are correct without override.
227    async fn stream(&self, input: I, config: RunnableConfig) -> crate::Result<RunnableStream<O>>
228    where
229        Self: Sized + Sync,
230    {
231        let result = self.invoke(input, config).await;
232        Ok(RunnableStream::once(result))
233    }
234
235    /// Stream structured events. Default emits OnStart + OnEnd around an
236    /// `invoke` call. Graph engines override to surface per-node events.
237    async fn stream_events(&self, input: I, config: RunnableConfig) -> crate::Result<EventStream>
238    where
239        I: serde::Serialize,
240        O: serde::Serialize,
241        Self: Sized + Sync,
242    {
243        let runnable = self.name().to_string();
244        let run_id = config.run_id;
245        let input_json = serde_json::to_value(&input).unwrap_or(serde_json::Value::Null);
246
247        let on_start = Event::OnStart {
248            runnable: runnable.clone(),
249            run_id,
250            input: input_json,
251        };
252        let result = self.invoke(input, config).await;
253        let on_end_or_err = match &result {
254            Ok(o) => Event::OnEnd {
255                runnable,
256                run_id,
257                output: serde_json::to_value(o).unwrap_or(serde_json::Value::Null),
258            },
259            Err(e) => Event::OnError {
260                error: e.to_string(),
261                run_id,
262            },
263        };
264
265        Ok(EventStream::new(stream::iter(vec![
266            on_start,
267            on_end_or_err,
268        ])))
269    }
270
271    /// Friendly name for telemetry / introspection.
272    fn name(&self) -> &str {
273        std::any::type_name::<Self>()
274    }
275
276    /// JSON Schema for the input type, if known.
277    fn input_schema(&self) -> Option<serde_json::Value> {
278        None
279    }
280
281    /// JSON Schema for the output type, if known.
282    fn output_schema(&self) -> Option<serde_json::Value> {
283        None
284    }
285}
286
287use crate::stream::{Event, EventStream, RunnableStream};
288
289impl RunnableConfig {
290    /// Build a child config for a sub-call (batch / fan-out).
291    /// Reuses `tags`, `metadata`, `observers`, `cancel_token`, `deadline`
292    /// — everything except a fresh `run_id` and an empty `extras`.
293    pub fn clone_for_subcall(parent: &Arc<RunnableConfig>) -> RunnableConfig {
294        RunnableConfig {
295            recursion_limit: parent.recursion_limit,
296            max_concurrency: parent.max_concurrency,
297            tags: parent.tags.clone(),
298            metadata: parent.metadata.clone(),
299            observers: parent.observers.clone(),
300            run_id: Uuid::new_v4(),
301            parent_run_id: Some(parent.run_id),
302            cancel_token: parent.cancel_token.clone(),
303            deadline: parent.deadline,
304            extras: Extensions::new(),
305        }
306    }
307}
308
#[cfg(test)]
mod runnable_tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Minimal runnable: returns twice its input.
    struct Doubler;

    #[async_trait]
    impl Runnable<u32, u32> for Doubler {
        async fn invoke(&self, input: u32, _: RunnableConfig) -> crate::Result<u32> {
            Ok(input * 2)
        }
    }

    /// Runnable that echoes its input while recording the peak number of
    /// `invoke` calls in flight at the same time.
    ///
    /// Each call yields to the scheduler once while counted as in flight,
    /// so calls would overlap if `batch` ignored the concurrency limit.
    #[derive(Default)]
    struct ConcurrencyProbe {
        in_flight: AtomicUsize,
        peak: AtomicUsize,
    }

    #[async_trait]
    impl Runnable<u32, u32> for ConcurrencyProbe {
        async fn invoke(&self, input: u32, _: RunnableConfig) -> crate::Result<u32> {
            let now = self.in_flight.fetch_add(1, Ordering::SeqCst) + 1;
            self.peak.fetch_max(now, Ordering::SeqCst);
            tokio::task::yield_now().await;
            self.in_flight.fetch_sub(1, Ordering::SeqCst);
            Ok(input)
        }
    }

    #[tokio::test]
    async fn invoke_works() {
        let d = Doubler;
        let out = d.invoke(5, RunnableConfig::default()).await.unwrap();
        assert_eq!(out, 10);
    }

    #[tokio::test]
    async fn default_batch_runs_each() {
        let d = Doubler;
        let out = d
            .batch(vec![1, 2, 3, 4], RunnableConfig::default())
            .await
            .unwrap();
        let mut sorted = out;
        sorted.sort();
        assert_eq!(sorted, vec![2, 4, 6, 8]);
    }

    #[tokio::test]
    async fn default_stream_emits_one_item() {
        let d = Doubler;
        let s = d.stream(7, RunnableConfig::default()).await.unwrap();
        let v = s.collect_into_vec().await.unwrap();
        assert_eq!(v, vec![14]);
    }

    #[tokio::test]
    async fn default_stream_events_emits_start_end() {
        use futures::StreamExt;
        let d = Doubler;
        let mut s = d.stream_events(3, RunnableConfig::default()).await.unwrap();
        let mut events = Vec::new();
        while let Some(e) = s.next().await {
            events.push(e);
        }
        assert_eq!(events.len(), 2);
        assert!(matches!(events[0], Event::OnStart { .. }));
        assert!(matches!(events[1], Event::OnEnd { .. }));
    }

    /// Verifies that the default `batch` never has more than
    /// `max_concurrency` `invoke` calls in flight, and still processes every
    /// input.
    #[tokio::test]
    async fn batch_respects_max_concurrency() {
        for limit in [1usize, 2].iter().copied() {
            let probe = ConcurrencyProbe::default();
            let cfg = RunnableConfig::default().with_max_concurrency(limit);
            let mut out = probe.batch(vec![1, 2, 3, 4, 5], cfg).await.unwrap();
            out.sort();
            assert_eq!(out, vec![1, 2, 3, 4, 5]);

            let peak = probe.peak.load(Ordering::SeqCst);
            assert!(
                (1..=limit).contains(&peak),
                "peak in-flight calls {peak} outside 1..={limit}"
            );
        }
    }
}