1use std::collections::HashSet;
6use std::sync::Arc;
7
8use async_trait::async_trait;
9
10use cognis_core::{Result, Runnable, RunnableConfig};
11
12use crate::builder::Graph;
13use crate::checkpoint::Checkpointer;
14use crate::durability::Durability;
15use crate::engine;
16use crate::state::GraphState;
17use crate::stream_mode::StreamModes;
18
/// An executable graph produced by compiling a [`Graph`].
///
/// Bundles the graph definition with run-time configuration:
/// - an optional checkpointer;
/// - the sets of node names registered for interrupt-before / interrupt-after;
/// - a durability setting.
///
/// `#[derive(Clone)]` on this generic struct only yields `Clone` when `S: Clone`.
/// The checkpointer is held in an `Arc`, so clones share the same checkpointer instance.
#[derive(Clone)]
pub struct CompiledGraph<S: GraphState> {
    /// The underlying graph definition (its node map, per-node annotations and optional version).
    pub(crate) graph: Graph<S>,
    /// Optional persistence backend; the state-inspection APIs return empty results
    /// (or a configuration error for `update_state`) when this is `None`.
    pub(crate) checkpointer: Option<Arc<dyn Checkpointer<S>>>,
    /// Node names registered via `with_interrupt_before`.
    /// NOTE(review): presumably the engine pauses before running these nodes — confirm in `engine`.
    pub(crate) interrupt_before: HashSet<String>,
    /// Node names registered via `with_interrupt_after`.
    /// NOTE(review): presumably the engine pauses after running these nodes — confirm in `engine`.
    pub(crate) interrupt_after: HashSet<String>,
    /// Durability setting; initialised to `Durability::default()` by `CompiledGraph::new`.
    pub(crate) durability: Durability,
}
29
30impl<S: GraphState> std::fmt::Debug for CompiledGraph<S> {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 f.debug_struct("CompiledGraph")
33 .field("node_count", &self.graph.nodes.len())
34 .field("has_checkpointer", &self.checkpointer.is_some())
35 .field("interrupt_before", &self.interrupt_before)
36 .field("interrupt_after", &self.interrupt_after)
37 .finish()
38 }
39}
40
41impl<S: GraphState> CompiledGraph<S> {
42 pub(crate) fn new(graph: Graph<S>) -> Self {
43 Self {
44 graph,
45 checkpointer: None,
46 interrupt_before: HashSet::new(),
47 interrupt_after: HashSet::new(),
48 durability: Durability::default(),
49 }
50 }
51
52 pub fn with_durability(mut self, d: Durability) -> Self {
55 self.durability = d;
56 self
57 }
58
59 pub fn durability(&self) -> &Durability {
61 &self.durability
62 }
63
64 pub fn node_count(&self) -> usize {
66 self.graph.nodes.len()
67 }
68
69 pub fn node_names(&self) -> Vec<&str> {
71 self.graph.nodes.keys().map(|s| s.as_str()).collect()
72 }
73
74 pub fn version(&self) -> Option<&str> {
76 self.graph.version.as_deref()
77 }
78
79 pub fn annotations(
82 &self,
83 node_name: &str,
84 ) -> &std::collections::HashMap<String, serde_json::Value> {
85 static EMPTY: std::sync::OnceLock<std::collections::HashMap<String, serde_json::Value>> =
86 std::sync::OnceLock::new();
87 self.graph
88 .annotations
89 .get(node_name)
90 .unwrap_or_else(|| EMPTY.get_or_init(std::collections::HashMap::new))
91 }
92
93 pub fn annotation(&self, node_name: &str, key: &str) -> Option<&serde_json::Value> {
95 self.graph
96 .annotations
97 .get(node_name)
98 .and_then(|m| m.get(key))
99 }
100}
101
102impl<S: GraphState + Clone + Send + 'static> CompiledGraph<S> {
103 pub fn with_checkpointer(mut self, cp: Arc<dyn Checkpointer<S>>) -> Self {
105 self.checkpointer = Some(cp);
106 self
107 }
108
109 pub fn with_interrupt_before<I, N>(mut self, names: I) -> Self
114 where
115 I: IntoIterator<Item = N>,
116 N: Into<String>,
117 {
118 self.interrupt_before
119 .extend(names.into_iter().map(Into::into));
120 self
121 }
122
123 pub fn with_interrupt_after<I, N>(mut self, names: I) -> Self
126 where
127 I: IntoIterator<Item = N>,
128 N: Into<String>,
129 {
130 self.interrupt_after
131 .extend(names.into_iter().map(Into::into));
132 self
133 }
134
135 pub async fn resume(
142 &self,
143 run_id: uuid::Uuid,
144 step: u64,
145 state: S,
146 config: RunnableConfig,
147 ) -> Result<S>
148 where
149 S::Update: Clone,
150 {
151 let mut cfg = config;
152 cfg.run_id = run_id;
153 engine::resume(self, state, cfg, step).await
154 }
155
156 pub async fn get_state(&self, run_id: uuid::Uuid) -> Result<Option<S>> {
163 match &self.checkpointer {
164 Some(cp) => cp.load(run_id, None).await,
165 None => Ok(None),
166 }
167 }
168
169 pub async fn get_state_at(&self, run_id: uuid::Uuid, step: u64) -> Result<Option<S>> {
171 match &self.checkpointer {
172 Some(cp) => cp.load(run_id, Some(step)).await,
173 None => Ok(None),
174 }
175 }
176
177 pub async fn get_state_history(&self, run_id: uuid::Uuid) -> Result<Vec<(u64, S)>> {
180 let cp = match &self.checkpointer {
181 Some(cp) => cp,
182 None => return Ok(Vec::new()),
183 };
184 let steps = cp.list(run_id).await?;
185 let mut out = Vec::with_capacity(steps.len());
186 for s in steps {
187 if let Some(state) = cp.load(run_id, Some(s)).await? {
188 out.push((s, state));
189 }
190 }
191 Ok(out)
192 }
193
194 pub async fn update_state(&self, run_id: uuid::Uuid, step: u64, state: &S) -> Result<()> {
199 match &self.checkpointer {
200 Some(cp) => cp.save(run_id, step, state).await,
201 None => Err(cognis_core::CognisError::Configuration(
202 "update_state requires a checkpointer; attach via .with_checkpointer(...)".into(),
203 )),
204 }
205 }
206}
207
208impl<S> CompiledGraph<S>
209where
210 S: GraphState + Clone + Send + 'static,
211 <S as GraphState>::Update: Clone,
212{
213 pub async fn stream_mode(
216 &self,
217 input: S,
218 modes: StreamModes,
219 config: RunnableConfig,
220 ) -> Result<cognis_core::EventStream> {
221 use cognis_core::Observer;
222 use futures::StreamExt;
223 use tokio::sync::mpsc;
224 use tokio_stream::wrappers::UnboundedReceiverStream;
225
226 struct ChannelObserver(mpsc::UnboundedSender<cognis_core::Event>);
227 impl Observer for ChannelObserver {
228 fn on_event(&self, event: &cognis_core::Event) {
229 let _ = self.0.send(event.clone());
230 }
231 }
232
233 let (tx, rx) = mpsc::unbounded_channel::<cognis_core::Event>();
234 let observer: Arc<dyn Observer> = Arc::new(ChannelObserver(tx));
235 let mut cfg = config;
236 cfg.observers.push(observer);
237
238 let this = self.clone();
239 tokio::spawn(async move {
240 let _ = engine::run(&this, input, cfg).await;
241 });
242
243 let filtered = UnboundedReceiverStream::new(rx).filter(move |e| {
244 let keep = modes.matches(e);
245 async move { keep }
246 });
247
248 Ok(cognis_core::EventStream::new(filtered))
249 }
250}
251
252#[async_trait]
253impl<S> Runnable<S, S> for CompiledGraph<S>
254where
255 S: GraphState + Clone + Send + 'static,
256 <S as GraphState>::Update: Clone,
257{
258 async fn invoke(&self, input: S, config: RunnableConfig) -> Result<S> {
259 engine::run(self, input, config).await
260 }
261
262 fn name(&self) -> &str {
263 "CompiledGraph"
264 }
265
266 async fn stream_events(
270 &self,
271 input: S,
272 config: RunnableConfig,
273 ) -> Result<cognis_core::EventStream> {
274 use cognis_core::Observer;
275 use tokio::sync::mpsc;
276 use tokio_stream::wrappers::UnboundedReceiverStream;
277
278 struct ChannelObserver(mpsc::UnboundedSender<cognis_core::Event>);
279 impl Observer for ChannelObserver {
280 fn on_event(&self, event: &cognis_core::Event) {
281 let _ = self.0.send(event.clone());
282 }
283 }
284
285 let (tx, rx) = mpsc::unbounded_channel::<cognis_core::Event>();
286 let observer: Arc<dyn Observer> = Arc::new(ChannelObserver(tx));
287 let mut cfg = config;
288 cfg.observers.push(observer);
289
290 let this = self.clone();
291 tokio::spawn(async move {
292 let _ = engine::run(&this, input, cfg).await;
293 });
294
295 Ok(cognis_core::EventStream::new(UnboundedReceiverStream::new(
296 rx,
297 )))
298 }
299}
300
#[cfg(test)]
mod tests {
    use super::*;
    use crate::goto::Goto;
    use crate::node::{node_fn, NodeOut};

    /// Minimal test state: a single counter.
    #[derive(Default, Clone, Debug, PartialEq, serde::Serialize)]
    struct Counter {
        n: u32,
    }

    /// Update for `Counter`; applying it adds `n` to the state's counter.
    #[derive(Default, Clone)]
    struct CounterUpdate {
        n: u32,
    }

    impl GraphState for Counter {
        type Update = CounterUpdate;
        fn apply(&mut self, u: Self::Update) {
            self.n += u.n;
        }
    }

    // "a" adds 1 and routes to "b"; "b" adds 10 and ends. Final state: 0 + 1 + 10 = 11.
    #[tokio::test]
    async fn linear_two_nodes_runs_to_end() {
        let g = Graph::<Counter>::new()
            .node(
                "a",
                node_fn::<Counter, _, _>("a", |_s, _c| async move {
                    Ok(NodeOut {
                        update: CounterUpdate { n: 1 },
                        goto: Goto::node("b"),
                    })
                }),
            )
            .node(
                "b",
                node_fn::<Counter, _, _>("b", |_s, _c| async move {
                    Ok(NodeOut {
                        update: CounterUpdate { n: 10 },
                        goto: Goto::end(),
                    })
                }),
            )
            .start_at("a")
            .compile()
            .unwrap();

        let out = g
            .invoke(Counter::default(), RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(out, Counter { n: 11 });
    }

    // A self-loop that adds 1 per visit until the counter reaches 5. It then ends with
    // a zero update, so the final state is exactly 5.
    #[tokio::test]
    async fn cycle_terminates_via_state_check() {
        let g = Graph::<Counter>::new()
            .node(
                "tick",
                node_fn::<Counter, _, _>("tick", |s, _c| {
                    // Read the counter before building the future so the `async move`
                    // block does not capture `s`.
                    // NOTE(review): presumably required by node_fn's borrowed state
                    // argument — confirm.
                    let cur = s.n;
                    async move {
                        if cur >= 5 {
                            Ok(NodeOut {
                                update: CounterUpdate { n: 0 },
                                goto: Goto::end(),
                            })
                        } else {
                            Ok(NodeOut {
                                update: CounterUpdate { n: 1 },
                                goto: Goto::node("tick"),
                            })
                        }
                    }
                }),
            )
            .start_at("tick")
            .compile()
            .unwrap();

        let out = g
            .invoke(Counter::default(), RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(out, Counter { n: 5 });
    }

    // An unconditional self-loop must be stopped by the configured recursion limit,
    // reported as `RecursionLimit { limit: 3 }`.
    #[tokio::test]
    async fn recursion_limit_is_honored() {
        let g = Graph::<Counter>::new()
            .node(
                "loop",
                node_fn::<Counter, _, _>("loop", |_s, _c| async move {
                    Ok(NodeOut {
                        update: CounterUpdate { n: 1 },
                        goto: Goto::node("loop"),
                    })
                }),
            )
            .start_at("loop")
            .compile()
            .unwrap();

        let cfg = RunnableConfig::default().with_recursion_limit(3);
        let err = g.invoke(Counter::default(), cfg).await.unwrap_err();
        assert!(matches!(
            err,
            cognis_core::CognisError::RecursionLimit { limit: 3 }
        ));
    }

    // A clone of a compiled graph can be invoked independently and gives the same result.
    #[tokio::test]
    async fn compiled_graph_clones_and_runs() {
        let g = Graph::<Counter>::new()
            .node(
                "a",
                node_fn::<Counter, _, _>("a", |_s, _c| async move {
                    Ok(NodeOut {
                        update: CounterUpdate { n: 1 },
                        goto: Goto::end(),
                    })
                }),
            )
            .start_at("a")
            .compile()
            .unwrap();
        let g2 = g.clone();
        let r1 = g
            .invoke(Counter::default(), RunnableConfig::default())
            .await
            .unwrap();
        let r2 = g2
            .invoke(Counter::default(), RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(r1.n, 1);
        assert_eq!(r2.n, 1);
    }

    // Routing is decided at run time, so `compile()` succeeds here. The invocation then
    // fails, and its error message names the missing target.
    #[tokio::test]
    async fn route_to_unknown_node_errors() {
        let g = Graph::<Counter>::new()
            .node(
                "bad",
                node_fn::<Counter, _, _>("bad", |_s, _c| async move {
                    Ok(NodeOut {
                        update: CounterUpdate { n: 0 },
                        goto: Goto::node("ghost"),
                    })
                }),
            )
            .start_at("bad")
            .compile()
            .unwrap();
        let err = g
            .invoke(Counter::default(), RunnableConfig::default())
            .await
            .unwrap_err();
        assert!(format!("{err}").contains("ghost"));
    }

    // Drains the event stream until it closes, then checks that it contains a node-start
    // event for "a" and for "b", plus an end event.
    #[tokio::test]
    async fn stream_events_emits_per_node() {
        use cognis_core::Event;
        use futures::StreamExt;

        let g = Graph::<Counter>::new()
            .node(
                "a",
                node_fn::<Counter, _, _>("a", |_, _| async move {
                    Ok(NodeOut {
                        update: CounterUpdate { n: 1 },
                        goto: Goto::node("b"),
                    })
                }),
            )
            .node(
                "b",
                node_fn::<Counter, _, _>("b", |_, _| async move {
                    Ok(NodeOut {
                        update: CounterUpdate { n: 1 },
                        goto: Goto::end(),
                    })
                }),
            )
            .start_at("a")
            .compile()
            .unwrap();

        let mut s = g
            .stream_events(Counter::default(), RunnableConfig::default())
            .await
            .unwrap();
        let mut events = Vec::new();
        // The loop only terminates once the stream closes, i.e. after the background
        // run has finished and released its channel sender.
        while let Some(e) = s.next().await {
            events.push(e);
        }
        assert!(events
            .iter()
            .any(|e| matches!(e, Event::OnNodeStart { node, .. } if node == "a")));
        assert!(events
            .iter()
            .any(|e| matches!(e, Event::OnNodeStart { node, .. } if node == "b")));
        assert!(events.iter().any(|e| matches!(e, Event::OnEnd { .. })));
    }
}