Skip to main content

synaptic_graph/
compiled.rs

1use std::collections::{HashMap, HashSet};
2use std::hash::{Hash, Hasher};
3use std::pin::Pin;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use futures::Stream;
8use serde_json::Value;
9use synaptic_core::SynapticError;
10use tokio::sync::RwLock;
11
12use crate::checkpoint::{Checkpoint, CheckpointConfig, Checkpointer};
13use crate::command::{CommandGoto, GraphResult, NodeOutput};
14use crate::edge::{ConditionalEdge, Edge};
15use crate::node::Node;
16use crate::state::State;
17use crate::END;
18
/// Per-node caching configuration.
///
/// Currently a policy consists solely of how long a cached output stays
/// usable.
#[derive(Debug, Clone)]
pub struct CachePolicy {
    /// How long a cached entry remains valid after it is created.
    pub ttl: Duration,
}

impl CachePolicy {
    /// Build a policy whose cached entries live for `ttl`.
    pub fn new(ttl: Duration) -> Self {
        CachePolicy { ttl }
    }
}
32
33/// Cached node output with expiry.
34pub(crate) struct CachedEntry<S: State> {
35    output: NodeOutput<S>,
36    created: Instant,
37    ttl: Duration,
38}
39
40impl<S: State> CachedEntry<S> {
41    fn is_valid(&self) -> bool {
42        self.created.elapsed() < self.ttl
43    }
44}
45
46/// Hash a serializable state to use as a cache key.
47fn hash_state(value: &Value) -> u64 {
48    let mut hasher = std::collections::hash_map::DefaultHasher::new();
49    let canonical = value.to_string();
50    canonical.hash(&mut hasher);
51    hasher.finish()
52}
53
/// Controls what is yielded during graph streaming.
///
/// NOTE(review): in the streaming code visible in this file,
/// `stream_with_config` ignores the mode entirely, and
/// `stream_modes_with_config` only distinguishes `Updates` (pre-node state)
/// from every other mode (post-node state). The richer semantics described
/// per variant below are the intended contract; confirm where, if anywhere,
/// they are implemented.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamMode {
    /// Yield full state after each node executes.
    Values,
    /// Intended to yield only the delta produced by a node.
    ///
    /// In `stream_modes_with_config` this currently yields the state snapshot
    /// taken *before* the node ran, so callers can diff it against the
    /// matching `Values` event.
    Updates,
    /// Intended for chat UIs that only want AI messages.
    ///
    /// In `stream_modes_with_config` this yields the same full post-node
    /// state as `Values`; callers are expected to filter for messages.
    Messages,
    /// Intended to carry detailed debug information such as node timing.
    ///
    /// In `stream_modes_with_config` this currently yields the same full
    /// post-node state as `Values`.
    Debug,
    /// Intended for custom events emitted via StreamWriter.
    ///
    /// In `stream_modes_with_config` this currently yields the same full
    /// post-node state as `Values`.
    Custom,
}
68
/// An event yielded during graph streaming.
///
/// One event is produced per executed node (per requested mode when using
/// multi-mode streaming).
#[derive(Debug, Clone)]
pub struct GraphEvent<S> {
    /// The node that just executed.
    pub node: String,
    /// The state snapshot attached to the event.
    ///
    /// In the streaming code in this file this is the state after the node's
    /// output was applied, except for `StreamMode::Updates` events from
    /// `stream_modes_with_config`, which carry the state from before the node
    /// ran.
    pub state: S,
}
77
/// An event yielded during multi-mode streaming, tagged with its stream mode.
///
/// Wraps a plain [`GraphEvent`] so that consumers of a single stream can tell
/// which requested [`StreamMode`] a given item belongs to.
#[derive(Debug, Clone)]
pub struct MultiGraphEvent<S> {
    /// Which stream mode produced this event.
    pub mode: StreamMode,
    /// The underlying graph event.
    pub event: GraphEvent<S>,
}
86
/// A stream of graph events.
///
/// A pinned, boxed, `Send` stream whose items are either a [`GraphEvent`] or
/// a `SynapticError`; `'a` bounds how long the stream may live.
pub type GraphStream<'a, S> =
    Pin<Box<dyn Stream<Item = Result<GraphEvent<S>, SynapticError>> + Send + 'a>>;
90
/// A stream of multi-mode graph events.
///
/// Same shape as [`GraphStream`], but each successful item is a
/// [`MultiGraphEvent`] tagged with the mode that produced it.
pub type MultiGraphStream<'a, S> =
    Pin<Box<dyn Stream<Item = Result<MultiGraphEvent<S>, SynapticError>> + Send + 'a>>;
94
/// The compiled, executable graph.
///
/// Holds the nodes, routing information and optional persistence/caching
/// configuration used by the `invoke*` and `stream*` methods.
pub struct CompiledGraph<S: State> {
    /// Executable nodes, looked up by name while walking the graph.
    pub(crate) nodes: HashMap<String, Box<dyn Node<S>>>,
    /// Static edges between nodes.
    // NOTE(review): presumably consulted by `find_next_node`; not visible here.
    pub(crate) edges: Vec<Edge>,
    /// Edges whose target depends on the current state.
    // NOTE(review): presumably consulted by `find_next_node`; not visible here.
    pub(crate) conditional_edges: Vec<ConditionalEdge<S>>,
    /// Name of the node execution starts at when not resuming from a checkpoint.
    pub(crate) entry_point: String,
    /// Nodes that cause execution to stop just before they would run.
    pub(crate) interrupt_before: HashSet<String>,
    /// Nodes that cause execution to stop right after they ran
    /// (only checked when the node did not override routing).
    pub(crate) interrupt_after: HashSet<String>,
    /// Optional persistence backend; only used when a `CheckpointConfig`
    /// is also supplied to the call.
    pub(crate) checkpointer: Option<Arc<dyn Checkpointer>>,
    /// Cache policies keyed by node name.
    pub(crate) cache_policies: HashMap<String, CachePolicy>,
    /// Node-level cache: node_name -> (state_hash -> cached_output).
    pub(crate) cache: Arc<RwLock<HashMap<String, HashMap<u64, CachedEntry<S>>>>>,
    /// Nodes marked as deferred (wait for all incoming edges).
    pub(crate) deferred: HashSet<String>,
}
111
112impl<S: State> std::fmt::Debug for CompiledGraph<S> {
113    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114        f.debug_struct("CompiledGraph")
115            .field("entry_point", &self.entry_point)
116            .field("node_count", &self.nodes.len())
117            .field("edge_count", &self.edges.len())
118            .field("conditional_edge_count", &self.conditional_edges.len())
119            .finish()
120    }
121}
122
123/// Internal helper: process a `NodeOutput` and return the next node to visit.
124/// Returns `(next_node_or_none, interrupt_value_or_none)`.
125fn handle_node_output<S: State>(
126    output: NodeOutput<S>,
127    state: &mut S,
128    current_node: &str,
129    find_next: impl Fn(&str, &S) -> String,
130) -> (Option<String>, Option<serde_json::Value>) {
131    match output {
132        NodeOutput::State(new_state) => {
133            *state = new_state;
134            (None, None) // use normal routing
135        }
136        NodeOutput::Command(cmd) => {
137            // Apply state update if present
138            if let Some(update) = cmd.update {
139                state.merge(update);
140            }
141
142            // Check for interrupt
143            if let Some(interrupt_value) = cmd.interrupt_value {
144                return (None, Some(interrupt_value));
145            }
146
147            // Determine routing
148            match cmd.goto {
149                Some(CommandGoto::One(target)) => (Some(target), None),
150                Some(CommandGoto::Many(_sends)) => {
151                    // Fan-out: for now, execute Send targets sequentially
152                    // Full parallel execution is handled in the main loop
153                    (Some("__fanout__".to_string()), None)
154                }
155                None => {
156                    let next = find_next(current_node, state);
157                    (Some(next), None)
158                }
159            }
160        }
161    }
162}
163
164/// Helper to serialize state into a checkpoint.
165fn make_checkpoint<S: serde::Serialize>(
166    state: &S,
167    next_node: Option<String>,
168    node_name: &str,
169) -> Result<Checkpoint, SynapticError> {
170    let state_val = serde_json::to_value(state)
171        .map_err(|e| SynapticError::Graph(format!("serialize state: {e}")))?;
172    Ok(Checkpoint::new(state_val, next_node).with_metadata("source", serde_json::json!(node_name)))
173}
174
175impl<S: State> CompiledGraph<S> {
176    /// Set a checkpointer for state persistence.
177    pub fn with_checkpointer(mut self, checkpointer: Arc<dyn Checkpointer>) -> Self {
178        self.checkpointer = Some(checkpointer);
179        self
180    }
181
182    /// Execute the graph with initial state.
183    pub async fn invoke(&self, state: S) -> Result<GraphResult<S>, SynapticError>
184    where
185        S: serde::Serialize + serde::de::DeserializeOwned,
186    {
187        self.invoke_with_config(state, None).await
188    }
189
190    /// Execute with optional checkpoint config for resumption.
191    pub async fn invoke_with_config(
192        &self,
193        mut state: S,
194        config: Option<CheckpointConfig>,
195    ) -> Result<GraphResult<S>, SynapticError>
196    where
197        S: serde::Serialize + serde::de::DeserializeOwned,
198    {
199        // If there's a checkpoint, try to resume from it
200        let mut resume_from: Option<String> = None;
201        if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
202            if let Some(checkpoint) = checkpointer.get(cfg).await? {
203                state = serde_json::from_value(checkpoint.state).map_err(|e| {
204                    SynapticError::Graph(format!("failed to deserialize checkpoint state: {e}"))
205                })?;
206                resume_from = checkpoint.next_node;
207            }
208        }
209
210        let mut current_node = resume_from.unwrap_or_else(|| self.entry_point.clone());
211        let mut max_iterations = 100; // safety guard
212
213        loop {
214            if current_node == END {
215                break;
216            }
217            if max_iterations == 0 {
218                return Err(SynapticError::Graph(
219                    "max iterations (100) exceeded — possible infinite loop".to_string(),
220                ));
221            }
222            max_iterations -= 1;
223
224            // Check interrupt_before
225            if self.interrupt_before.contains(&current_node) {
226                if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
227                    let checkpoint =
228                        make_checkpoint(&state, Some(current_node.clone()), &current_node)?;
229                    checkpointer.put(cfg, &checkpoint).await?;
230                }
231                return Ok(GraphResult::Interrupted {
232                    state,
233                    interrupt_value: serde_json::json!({
234                        "reason": format!("interrupted before node '{current_node}'")
235                    }),
236                });
237            }
238
239            // Execute node (with optional cache)
240            let node = self
241                .nodes
242                .get(&current_node)
243                .ok_or_else(|| SynapticError::Graph(format!("node '{current_node}' not found")))?;
244            let output = self
245                .execute_with_cache(&current_node, node.as_ref(), state.clone())
246                .await?;
247
248            // Handle the output
249            let (next_override, interrupt_value) =
250                handle_node_output(output, &mut state, &current_node, |cur, s| {
251                    self.find_next_node(cur, s)
252                });
253
254            // Check for interrupt from Command
255            if let Some(interrupt_val) = interrupt_value {
256                if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
257                    let next = self.find_next_node(&current_node, &state);
258                    let checkpoint = make_checkpoint(&state, Some(next), &current_node)?;
259                    checkpointer.put(cfg, &checkpoint).await?;
260                }
261                return Ok(GraphResult::Interrupted {
262                    state,
263                    interrupt_value: interrupt_val,
264                });
265            }
266
267            // Handle fan-out (Send)
268            if next_override.as_deref() == Some("__fanout__") {
269                // TODO: full parallel fan-out
270                break;
271            }
272
273            let next = if let Some(target) = next_override {
274                target
275            } else {
276                // Check interrupt_after (only when no command override)
277                if self.interrupt_after.contains(&current_node) {
278                    let next = self.find_next_node(&current_node, &state);
279                    if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
280                        let checkpoint = make_checkpoint(&state, Some(next), &current_node)?;
281                        checkpointer.put(cfg, &checkpoint).await?;
282                    }
283                    return Ok(GraphResult::Interrupted {
284                        state,
285                        interrupt_value: serde_json::json!({
286                            "reason": format!("interrupted after node '{current_node}'")
287                        }),
288                    });
289                }
290
291                // Normal routing
292                self.find_next_node(&current_node, &state)
293            };
294
295            // Save checkpoint after each node
296            if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
297                let checkpoint = make_checkpoint(&state, Some(next.clone()), &current_node)?;
298                checkpointer.put(cfg, &checkpoint).await?;
299            }
300
301            current_node = next;
302        }
303
304        Ok(GraphResult::Complete(state))
305    }
306
307    /// Stream graph execution, yielding a `GraphEvent` after each node.
308    pub fn stream(&self, state: S, mode: StreamMode) -> GraphStream<'_, S>
309    where
310        S: serde::Serialize + serde::de::DeserializeOwned + Clone,
311    {
312        self.stream_with_config(state, mode, None)
313    }
314
    /// Stream graph execution with optional checkpoint config.
    ///
    /// Walks the graph like `invoke_with_config`, but yields an
    /// `Ok(GraphEvent)` carrying the post-node state after every executed
    /// node. Any failure, and any interrupt (`interrupt_before`,
    /// `interrupt_after`, or one requested by a node's command), is yielded as
    /// a single `Err` item, after which the stream ends. Reaching `END` ends
    /// the stream without a final item.
    ///
    /// Checkpoints are only read and written when the graph has a
    /// checkpointer and `config` is `Some`.
    ///
    /// NOTE(review): `_mode` is not used; every mode yields the same events.
    /// NOTE(review): nodes are run via `process` directly, so unlike
    /// `invoke_with_config` this path does not call `execute_with_cache` —
    /// confirm that skipping the cache is intended.
    pub fn stream_with_config(
        &self,
        state: S,
        _mode: StreamMode,
        config: Option<CheckpointConfig>,
    ) -> GraphStream<'_, S>
    where
        S: serde::Serialize + serde::de::DeserializeOwned + Clone,
    {
        Box::pin(async_stream::stream! {
            let mut state = state;

            // If there's a checkpoint, try to resume from it
            let mut resume_from: Option<String> = None;
            if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
                match checkpointer.get(cfg).await {
                    Ok(Some(checkpoint)) => {
                        match serde_json::from_value(checkpoint.state) {
                            Ok(s) => {
                                // The stored state replaces the caller-supplied one.
                                state = s;
                                resume_from = checkpoint.next_node;
                            }
                            Err(e) => {
                                yield Err(SynapticError::Graph(format!(
                                    "failed to deserialize checkpoint state: {e}"
                                )));
                                return;
                            }
                        }
                    }
                    Ok(None) => {}
                    Err(e) => {
                        yield Err(e);
                        return;
                    }
                }
            }

            let mut current_node = resume_from.unwrap_or_else(|| self.entry_point.clone());
            // Safety guard: at most 100 nodes are executed per stream.
            let mut max_iterations = 100;

            loop {
                if current_node == END {
                    break;
                }
                if max_iterations == 0 {
                    yield Err(SynapticError::Graph(
                        "max iterations (100) exceeded — possible infinite loop".to_string(),
                    ));
                    return;
                }
                max_iterations -= 1;

                // Check interrupt_before
                if self.interrupt_before.contains(&current_node) {
                    if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
                        // The checkpoint points back at the node that was about to run.
                        match make_checkpoint(&state, Some(current_node.clone()), &current_node) {
                            Ok(checkpoint) => {
                                if let Err(e) = checkpointer.put(cfg, &checkpoint).await {
                                    yield Err(e);
                                    return;
                                }
                            }
                            Err(e) => {
                                yield Err(e);
                                return;
                            }
                        }
                    }
                    // Interrupts are reported to stream consumers as an error item.
                    yield Err(SynapticError::Graph(format!(
                        "interrupted before node '{current_node}'"
                    )));
                    return;
                }

                // Execute node
                let node = match self.nodes.get(&current_node) {
                    Some(n) => n,
                    None => {
                        yield Err(SynapticError::Graph(format!("node '{current_node}' not found")));
                        return;
                    }
                };

                let output = match node.process(state.clone()).await {
                    Ok(o) => o,
                    Err(e) => {
                        yield Err(e);
                        return;
                    }
                };

                // Handle the node output
                let mut interrupt_val = None;
                let next_override = match output {
                    NodeOutput::State(new_state) => {
                        state = new_state;
                        None
                    }
                    NodeOutput::Command(cmd) => {
                        // The update is merged even if the command also interrupts.
                        if let Some(update) = cmd.update {
                            state.merge(update);
                        }

                        if let Some(iv) = cmd.interrupt_value {
                            interrupt_val = Some(iv);
                            None
                        } else {
                            match cmd.goto {
                                Some(CommandGoto::One(target)) => Some(target),
                                // Fan-out is not executed here; it routes straight to END.
                                Some(CommandGoto::Many(_)) => Some(END.to_string()),
                                None => None,
                            }
                        }
                    }
                };

                // Yield event
                // (emitted before any interrupt handling, so an interrupting
                // node still produces its event first)
                let event = GraphEvent {
                    node: current_node.clone(),
                    state: state.clone(),
                };
                yield Ok(event);

                // Check for interrupt from Command
                if let Some(iv) = interrupt_val {
                    if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
                        // Resume point is the node that normal routing would pick next.
                        let next = self.find_next_node(&current_node, &state);
                        match make_checkpoint(&state, Some(next), &current_node) {
                            Ok(checkpoint) => {
                                if let Err(e) = checkpointer.put(cfg, &checkpoint).await {
                                    yield Err(e);
                                    return;
                                }
                            }
                            Err(e) => {
                                yield Err(e);
                                return;
                            }
                        }
                    }
                    yield Err(SynapticError::Graph(format!(
                        "interrupted by node '{current_node}': {iv}"
                    )));
                    return;
                }

                let next = if let Some(target) = next_override {
                    target
                } else {
                    // Check interrupt_after (only when no command override)
                    if self.interrupt_after.contains(&current_node) {
                        let next = self.find_next_node(&current_node, &state);
                        if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
                            match make_checkpoint(&state, Some(next), &current_node) {
                                Ok(checkpoint) => {
                                    if let Err(e) = checkpointer.put(cfg, &checkpoint).await {
                                        yield Err(e);
                                        return;
                                    }
                                }
                                Err(e) => {
                                    yield Err(e);
                                    return;
                                }
                            }
                        }
                        yield Err(SynapticError::Graph(format!(
                            "interrupted after node '{current_node}'"
                        )));
                        return;
                    }

                    // Find next node via normal edge routing
                    self.find_next_node(&current_node, &state)
                };

                // Save checkpoint
                if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
                    match make_checkpoint(&state, Some(next.clone()), &current_node) {
                        Ok(checkpoint) => {
                            if let Err(e) = checkpointer.put(cfg, &checkpoint).await {
                                yield Err(e);
                                return;
                            }
                        }
                        Err(e) => {
                            yield Err(e);
                            return;
                        }
                    }
                }

                current_node = next;
            }
        })
    }
513
514    /// Stream graph execution with multiple stream modes.
515    ///
516    /// Each event is tagged with the `StreamMode` that produced it.
517    /// For a single node execution, one event per requested mode is emitted.
518    pub fn stream_modes(&self, state: S, modes: Vec<StreamMode>) -> MultiGraphStream<'_, S>
519    where
520        S: serde::Serialize + serde::de::DeserializeOwned + Clone,
521    {
522        self.stream_modes_with_config(state, modes, None)
523    }
524
    /// Stream graph execution with multiple stream modes and optional checkpoint config.
    ///
    /// Walks the graph like `stream_with_config`, but after every executed
    /// node yields one `Ok(MultiGraphEvent)` per entry in `modes`, in the
    /// order given. `StreamMode::Updates` events carry the state from before
    /// the node ran; every other mode carries the state after the node's
    /// output was applied. With an empty `modes` the graph still runs but no
    /// `Ok` items are produced.
    ///
    /// Any failure or interrupt is yielded as a single `Err` item, after
    /// which the stream ends. Checkpoints are only read and written when the
    /// graph has a checkpointer and `config` is `Some`.
    ///
    /// NOTE(review): nodes are run via `process` directly, so unlike
    /// `invoke_with_config` this path does not call `execute_with_cache` —
    /// confirm that skipping the cache is intended.
    pub fn stream_modes_with_config(
        &self,
        state: S,
        modes: Vec<StreamMode>,
        config: Option<CheckpointConfig>,
    ) -> MultiGraphStream<'_, S>
    where
        S: serde::Serialize + serde::de::DeserializeOwned + Clone,
    {
        Box::pin(async_stream::stream! {
            let mut state = state;

            // If there's a checkpoint, try to resume from it
            let mut resume_from: Option<String> = None;
            if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
                match checkpointer.get(cfg).await {
                    Ok(Some(checkpoint)) => {
                        match serde_json::from_value(checkpoint.state) {
                            Ok(s) => {
                                // The stored state replaces the caller-supplied one.
                                state = s;
                                resume_from = checkpoint.next_node;
                            }
                            Err(e) => {
                                yield Err(SynapticError::Graph(format!(
                                    "failed to deserialize checkpoint state: {e}"
                                )));
                                return;
                            }
                        }
                    }
                    Ok(None) => {}
                    Err(e) => {
                        yield Err(e);
                        return;
                    }
                }
            }

            let mut current_node = resume_from.unwrap_or_else(|| self.entry_point.clone());
            // Safety guard: at most 100 nodes are executed per stream.
            let mut max_iterations = 100;

            loop {
                if current_node == END {
                    break;
                }
                if max_iterations == 0 {
                    yield Err(SynapticError::Graph(
                        "max iterations (100) exceeded — possible infinite loop".to_string(),
                    ));
                    return;
                }
                max_iterations -= 1;

                // Check interrupt_before
                if self.interrupt_before.contains(&current_node) {
                    if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
                        // The checkpoint points back at the node that was about to run.
                        match make_checkpoint(&state, Some(current_node.clone()), &current_node) {
                            Ok(checkpoint) => {
                                if let Err(e) = checkpointer.put(cfg, &checkpoint).await {
                                    yield Err(e);
                                    return;
                                }
                            }
                            Err(e) => {
                                yield Err(e);
                                return;
                            }
                        }
                    }
                    // Interrupts are reported to stream consumers as an error item.
                    yield Err(SynapticError::Graph(format!(
                        "interrupted before node '{current_node}'"
                    )));
                    return;
                }

                // Snapshot state before node execution (for Updates mode diff)
                // NOTE(review): cloned on every iteration, even when `modes`
                // does not contain `StreamMode::Updates`.
                let state_before = state.clone();

                // Execute node
                let node = match self.nodes.get(&current_node) {
                    Some(n) => n,
                    None => {
                        yield Err(SynapticError::Graph(format!("node '{current_node}' not found")));
                        return;
                    }
                };

                let output = match node.process(state.clone()).await {
                    Ok(o) => o,
                    Err(e) => {
                        yield Err(e);
                        return;
                    }
                };

                // Handle the node output
                let mut interrupt_val = None;
                let next_override = match output {
                    NodeOutput::State(new_state) => {
                        state = new_state;
                        None
                    }
                    NodeOutput::Command(cmd) => {
                        // The update is merged even if the command also interrupts.
                        if let Some(update) = cmd.update {
                            state.merge(update);
                        }

                        if let Some(iv) = cmd.interrupt_value {
                            interrupt_val = Some(iv);
                            None
                        } else {
                            match cmd.goto {
                                Some(CommandGoto::One(target)) => Some(target),
                                // Fan-out is not executed here; it routes straight to END.
                                Some(CommandGoto::Many(_)) => Some(END.to_string()),
                                None => None,
                            }
                        }
                    }
                };

                // Yield events for each requested mode
                // (emitted before any interrupt handling, so an interrupting
                // node still produces its events first)
                for mode in &modes {
                    let event = match mode {
                        StreamMode::Values | StreamMode::Debug | StreamMode::Custom => {
                            // Full state after node execution
                            GraphEvent {
                                node: current_node.clone(),
                                state: state.clone(),
                            }
                        }
                        StreamMode::Updates => {
                            // State before node (the "delta" is the difference)
                            // For Updates, we yield the pre-node state so callers
                            // can diff against the full Values event
                            GraphEvent {
                                node: current_node.clone(),
                                state: state_before.clone(),
                            }
                        }
                        StreamMode::Messages => {
                            // Same as Values — callers filter for AI messages
                            GraphEvent {
                                node: current_node.clone(),
                                state: state.clone(),
                            }
                        }
                    };
                    yield Ok(MultiGraphEvent {
                        mode: *mode,
                        event,
                    });
                }

                // Check for interrupt from Command
                if let Some(iv) = interrupt_val {
                    if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
                        // Resume point is the node that normal routing would pick next.
                        let next = self.find_next_node(&current_node, &state);
                        match make_checkpoint(&state, Some(next), &current_node) {
                            Ok(checkpoint) => {
                                if let Err(e) = checkpointer.put(cfg, &checkpoint).await {
                                    yield Err(e);
                                    return;
                                }
                            }
                            Err(e) => {
                                yield Err(e);
                                return;
                            }
                        }
                    }
                    yield Err(SynapticError::Graph(format!(
                        "interrupted by node '{current_node}': {iv}"
                    )));
                    return;
                }

                let next = if let Some(target) = next_override {
                    target
                } else {
                    // Check interrupt_after (only when no command override)
                    if self.interrupt_after.contains(&current_node) {
                        let next = self.find_next_node(&current_node, &state);
                        if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
                            match make_checkpoint(&state, Some(next), &current_node) {
                                Ok(checkpoint) => {
                                    if let Err(e) = checkpointer.put(cfg, &checkpoint).await {
                                        yield Err(e);
                                        return;
                                    }
                                }
                                Err(e) => {
                                    yield Err(e);
                                    return;
                                }
                            }
                        }
                        yield Err(SynapticError::Graph(format!(
                            "interrupted after node '{current_node}'"
                        )));
                        return;
                    }

                    // Find next node via normal edge routing
                    self.find_next_node(&current_node, &state)
                };

                // Save checkpoint
                if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
                    match make_checkpoint(&state, Some(next.clone()), &current_node) {
                        Ok(checkpoint) => {
                            if let Err(e) = checkpointer.put(cfg, &checkpoint).await {
                                yield Err(e);
                                return;
                            }
                        }
                        Err(e) => {
                            yield Err(e);
                            return;
                        }
                    }
                }

                current_node = next;
            }
        })
    }
752
753    /// Update state on an interrupted graph (for human-in-the-loop).
754    pub async fn update_state(
755        &self,
756        config: &CheckpointConfig,
757        update: S,
758    ) -> Result<(), SynapticError>
759    where
760        S: serde::Serialize + serde::de::DeserializeOwned,
761    {
762        let checkpointer = self
763            .checkpointer
764            .as_ref()
765            .ok_or_else(|| SynapticError::Graph("no checkpointer configured".to_string()))?;
766
767        let checkpoint = checkpointer
768            .get(config)
769            .await?
770            .ok_or_else(|| SynapticError::Graph("no checkpoint found".to_string()))?;
771
772        let mut current_state: S = serde_json::from_value(checkpoint.state)
773            .map_err(|e| SynapticError::Graph(format!("deserialize: {e}")))?;
774
775        current_state.merge(update);
776
777        let updated = Checkpoint::new(
778            serde_json::to_value(&current_state)
779                .map_err(|e| SynapticError::Graph(format!("serialize: {e}")))?,
780            checkpoint.next_node,
781        )
782        .with_metadata("source", serde_json::json!("update_state"));
783        checkpointer.put(config, &updated).await?;
784
785        Ok(())
786    }
787
788    /// Get the current state for a thread from the checkpointer.
789    ///
790    /// Returns `None` if no checkpoint exists for the given thread.
791    pub async fn get_state(&self, config: &CheckpointConfig) -> Result<Option<S>, SynapticError>
792    where
793        S: serde::de::DeserializeOwned,
794    {
795        let checkpointer = self
796            .checkpointer
797            .as_ref()
798            .ok_or_else(|| SynapticError::Graph("no checkpointer configured".to_string()))?;
799
800        match checkpointer.get(config).await? {
801            Some(checkpoint) => {
802                let state: S = serde_json::from_value(checkpoint.state).map_err(|e| {
803                    SynapticError::Graph(format!("failed to deserialize checkpoint state: {e}"))
804                })?;
805                Ok(Some(state))
806            }
807            None => Ok(None),
808        }
809    }
810
811    /// Get the state history for a thread (all checkpoints).
812    ///
813    /// Returns a list of `(state, next_node)` pairs, ordered from oldest to newest.
814    pub async fn get_state_history(
815        &self,
816        config: &CheckpointConfig,
817    ) -> Result<Vec<(S, Option<String>)>, SynapticError>
818    where
819        S: serde::de::DeserializeOwned,
820    {
821        let checkpointer = self
822            .checkpointer
823            .as_ref()
824            .ok_or_else(|| SynapticError::Graph("no checkpointer configured".to_string()))?;
825
826        let checkpoints = checkpointer.list(config).await?;
827        let mut history = Vec::with_capacity(checkpoints.len());
828
829        for checkpoint in checkpoints {
830            let state: S = serde_json::from_value(checkpoint.state).map_err(|e| {
831                SynapticError::Graph(format!("failed to deserialize checkpoint state: {e}"))
832            })?;
833            history.push((state, checkpoint.next_node));
834        }
835
836        Ok(history)
837    }
838
839    /// Execute a node, using cache if a CachePolicy is set for it.
840    async fn execute_with_cache(
841        &self,
842        node_name: &str,
843        node: &dyn Node<S>,
844        state: S,
845    ) -> Result<NodeOutput<S>, SynapticError>
846    where
847        S: serde::Serialize,
848    {
849        let policy = self.cache_policies.get(node_name);
850        if policy.is_none() {
851            return node.process(state).await;
852        }
853        let policy = policy.unwrap();
854
855        // Compute state hash for cache key
856        let state_val = serde_json::to_value(&state)
857            .map_err(|e| SynapticError::Graph(format!("cache: serialize state: {e}")))?;
858        let key = hash_state(&state_val);
859
860        // Check cache
861        {
862            let cache = self.cache.read().await;
863            if let Some(node_cache) = cache.get(node_name) {
864                if let Some(entry) = node_cache.get(&key) {
865                    if entry.is_valid() {
866                        return Ok(entry.output.clone());
867                    }
868                }
869            }
870        }
871
872        // Cache miss — execute the node
873        let output = node.process(state).await?;
874
875        // Store in cache
876        {
877            let mut cache = self.cache.write().await;
878            let node_cache = cache.entry(node_name.to_string()).or_default();
879            node_cache.insert(
880                key,
881                CachedEntry {
882                    output: output.clone(),
883                    created: Instant::now(),
884                    ttl: policy.ttl,
885                },
886            );
887        }
888
889        Ok(output)
890    }
891
892    /// Returns true if the given node is deferred (waits for all incoming paths).
893    pub fn is_deferred(&self, node_name: &str) -> bool {
894        self.deferred.contains(node_name)
895    }
896
897    /// Returns the number of incoming edges (fixed + conditional) for a node.
898    pub fn incoming_edge_count(&self, node_name: &str) -> usize {
899        let fixed = self.edges.iter().filter(|e| e.target == node_name).count();
900        // Conditional edges may route to this node but we can't statically count them,
901        // so we count the path_map entries that reference this node.
902        let conditional = self
903            .conditional_edges
904            .iter()
905            .filter_map(|ce| ce.path_map.as_ref())
906            .flat_map(|pm| pm.values())
907            .filter(|target| *target == node_name)
908            .count();
909        fixed + conditional
910    }
911
912    fn find_next_node(&self, current: &str, state: &S) -> String {
913        // Check conditional edges first
914        for ce in &self.conditional_edges {
915            if ce.source == current {
916                return (ce.router)(state);
917            }
918        }
919
920        // Check fixed edges
921        for edge in &self.edges {
922            if edge.source == current {
923                return edge.target.clone();
924            }
925        }
926
927        // No outgoing edge means END
928        END.to_string()
929    }
930}