Skip to main content

synaptic_graph/
compiled.rs

1use std::collections::{HashMap, HashSet};
2use std::hash::{Hash, Hasher};
3use std::pin::Pin;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use futures::Stream;
8use serde_json::Value;
9use synaptic_core::SynapticError;
10use tokio::sync::RwLock;
11
12use crate::checkpoint::{Checkpoint, CheckpointConfig, Checkpointer};
13use crate::command::{CommandGoto, GraphResult, NodeOutput};
14use crate::edge::{ConditionalEdge, Edge};
15use crate::node::Node;
16use crate::state::State;
17use crate::END;
18
/// Caching behaviour that can be attached to an individual node.
#[derive(Debug, Clone)]
pub struct CachePolicy {
    /// How long a cached node output remains usable (its time-to-live).
    pub ttl: Duration,
}

impl CachePolicy {
    /// Builds a policy whose cached entries live for `ttl`.
    pub fn new(ttl: Duration) -> Self {
        CachePolicy { ttl }
    }
}
32
33/// Cached node output with expiry.
34pub(crate) struct CachedEntry<S: State> {
35    output: NodeOutput<S>,
36    created: Instant,
37    ttl: Duration,
38}
39
40impl<S: State> CachedEntry<S> {
41    fn is_valid(&self) -> bool {
42        self.created.elapsed() < self.ttl
43    }
44}
45
46/// Hash a serializable state to use as a cache key.
47fn hash_state(value: &Value) -> u64 {
48    let mut hasher = std::collections::hash_map::DefaultHasher::new();
49    let canonical = value.to_string();
50    canonical.hash(&mut hasher);
51    hasher.finish()
52}
53
/// Selects what a streaming run reports after each node.
///
/// Only `stream_modes` / `stream_modes_with_config` distinguish between modes. `stream` /
/// `stream_with_config` accept a mode but always yield the full post-node state.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamMode {
    /// Full state after each node executes.
    Values,
    /// The state as it was *before* the node ran. Compare it with a `Values` event for the
    /// same node to derive the node's update; no delta is computed here.
    Updates,
    /// Intended for chat UIs. Currently yields the full post-node state; callers filter it
    /// for AI messages themselves.
    Messages,
    /// Intended for detailed debug output. Currently yields the full post-node state; no
    /// timing information is attached by the multi-mode stream.
    Debug,
    /// Intended for custom events emitted via `StreamWriter`. The multi-mode stream currently
    /// yields the full post-node state for this mode.
    Custom,
}
68
/// One step reported by a streaming run.
#[derive(Debug, Clone)]
pub struct GraphEvent<S> {
    /// Name of the node that just executed.
    pub node: String,
    /// State snapshot for this step. It is the post-node state for every mode except
    /// `StreamMode::Updates`, where it is the state as it was *before* the node ran.
    pub state: S,
}
77
78/// An event yielded during multi-mode streaming, tagged with its stream mode.
79#[derive(Debug, Clone)]
80pub struct MultiGraphEvent<S> {
81    /// Which stream mode produced this event.
82    pub mode: StreamMode,
83    /// The underlying graph event.
84    pub event: GraphEvent<S>,
85}
86
87/// A stream of graph events.
88pub type GraphStream<'a, S> =
89    Pin<Box<dyn Stream<Item = Result<GraphEvent<S>, SynapticError>> + Send + 'a>>;
90
91/// A stream of multi-mode graph events.
92pub type MultiGraphStream<'a, S> =
93    Pin<Box<dyn Stream<Item = Result<MultiGraphEvent<S>, SynapticError>> + Send + 'a>>;
94
95/// The compiled, executable graph.
96pub struct CompiledGraph<S: State> {
97    pub(crate) nodes: HashMap<String, Box<dyn Node<S>>>,
98    pub(crate) edges: Vec<Edge>,
99    pub(crate) conditional_edges: Vec<ConditionalEdge<S>>,
100    pub(crate) entry_point: String,
101    pub(crate) interrupt_before: HashSet<String>,
102    pub(crate) interrupt_after: HashSet<String>,
103    pub(crate) checkpointer: Option<Arc<dyn Checkpointer>>,
104    /// Cache policies keyed by node name.
105    pub(crate) cache_policies: HashMap<String, CachePolicy>,
106    /// Node-level cache: node_name -> (state_hash -> cached_output).
107    #[expect(clippy::type_complexity)]
108    pub(crate) cache: Arc<RwLock<HashMap<String, HashMap<u64, CachedEntry<S>>>>>,
109    /// Nodes marked as deferred (wait for all incoming edges).
110    pub(crate) deferred: HashSet<String>,
111}
112
113impl<S: State> std::fmt::Debug for CompiledGraph<S> {
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        f.debug_struct("CompiledGraph")
116            .field("entry_point", &self.entry_point)
117            .field("node_count", &self.nodes.len())
118            .field("edge_count", &self.edges.len())
119            .field("conditional_edge_count", &self.conditional_edges.len())
120            .finish()
121    }
122}
123
124/// Internal helper: process a `NodeOutput` and return the next node to visit.
125/// Returns `(next_node_or_none, interrupt_value_or_none)`.
126fn handle_node_output<S: State>(
127    output: NodeOutput<S>,
128    state: &mut S,
129    current_node: &str,
130    find_next: impl Fn(&str, &S) -> String,
131) -> (Option<String>, Option<serde_json::Value>) {
132    match output {
133        NodeOutput::State(new_state) => {
134            *state = new_state;
135            (None, None) // use normal routing
136        }
137        NodeOutput::Command(cmd) => {
138            // Apply state update if present
139            if let Some(update) = cmd.update {
140                state.merge(update);
141            }
142
143            // Check for interrupt
144            if let Some(interrupt_value) = cmd.interrupt_value {
145                return (None, Some(interrupt_value));
146            }
147
148            // Determine routing
149            match cmd.goto {
150                Some(CommandGoto::One(target)) => (Some(target), None),
151                Some(CommandGoto::Many(_sends)) => {
152                    // Fan-out: for now, execute Send targets sequentially
153                    // Full parallel execution is handled in the main loop
154                    (Some("__fanout__".to_string()), None)
155                }
156                None => {
157                    let next = find_next(current_node, state);
158                    (Some(next), None)
159                }
160            }
161        }
162    }
163}
164
165/// Helper to serialize state into a checkpoint.
166fn make_checkpoint<S: serde::Serialize>(
167    state: &S,
168    next_node: Option<String>,
169    node_name: &str,
170) -> Result<Checkpoint, SynapticError> {
171    let state_val = serde_json::to_value(state)
172        .map_err(|e| SynapticError::Graph(format!("serialize state: {e}")))?;
173    Ok(Checkpoint::new(state_val, next_node).with_metadata("source", serde_json::json!(node_name)))
174}
175
176impl<S: State> CompiledGraph<S> {
177    /// Set a checkpointer for state persistence.
178    pub fn with_checkpointer(mut self, checkpointer: Arc<dyn Checkpointer>) -> Self {
179        self.checkpointer = Some(checkpointer);
180        self
181    }
182
183    /// Set a `StoreCheckpointer` backed by the given store.
184    ///
185    /// Convenience method equivalent to:
186    /// ```ignore
187    /// graph.with_checkpointer(Arc::new(StoreCheckpointer::new(store)))
188    /// ```
189    pub fn with_store_checkpointer(self, store: Arc<dyn synaptic_core::Store>) -> Self {
190        self.with_checkpointer(Arc::new(crate::StoreCheckpointer::new(store)))
191    }
192
193    /// Execute the graph with initial state.
194    pub async fn invoke(&self, state: S) -> Result<GraphResult<S>, SynapticError>
195    where
196        S: serde::Serialize + serde::de::DeserializeOwned,
197    {
198        self.invoke_with_config(state, None).await
199    }
200
201    /// Execute with optional checkpoint config for resumption.
202    pub async fn invoke_with_config(
203        &self,
204        mut state: S,
205        config: Option<CheckpointConfig>,
206    ) -> Result<GraphResult<S>, SynapticError>
207    where
208        S: serde::Serialize + serde::de::DeserializeOwned,
209    {
210        // If there's a checkpoint, try to resume from it
211        let mut resume_from: Option<String> = None;
212        if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
213            if let Some(checkpoint) = checkpointer.get(cfg).await? {
214                state = serde_json::from_value(checkpoint.state).map_err(|e| {
215                    SynapticError::Graph(format!("failed to deserialize checkpoint state: {e}"))
216                })?;
217                resume_from = checkpoint.next_node;
218            }
219        }
220
221        let mut current_node = resume_from.unwrap_or_else(|| self.entry_point.clone());
222        let mut max_iterations = 100; // safety guard
223
224        loop {
225            if current_node == END {
226                break;
227            }
228            if max_iterations == 0 {
229                return Err(SynapticError::Graph(
230                    "max iterations (100) exceeded — possible infinite loop".to_string(),
231                ));
232            }
233            max_iterations -= 1;
234
235            // Check interrupt_before
236            if self.interrupt_before.contains(&current_node) {
237                if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
238                    let checkpoint =
239                        make_checkpoint(&state, Some(current_node.clone()), &current_node)?;
240                    checkpointer.put(cfg, &checkpoint).await?;
241                }
242                return Ok(GraphResult::Interrupted {
243                    state,
244                    interrupt_value: serde_json::json!({
245                        "reason": format!("interrupted before node '{current_node}'")
246                    }),
247                });
248            }
249
250            // Execute node (with optional cache)
251            let node = self
252                .nodes
253                .get(&current_node)
254                .ok_or_else(|| SynapticError::Graph(format!("node '{current_node}' not found")))?;
255            let output = self
256                .execute_with_cache(&current_node, node.as_ref(), state.clone())
257                .await?;
258
259            // Handle the output
260            let (next_override, interrupt_value) =
261                handle_node_output(output, &mut state, &current_node, |cur, s| {
262                    self.find_next_node(cur, s)
263                });
264
265            // Check for interrupt from Command
266            if let Some(interrupt_val) = interrupt_value {
267                if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
268                    let next = self.find_next_node(&current_node, &state);
269                    let checkpoint = make_checkpoint(&state, Some(next), &current_node)?;
270                    checkpointer.put(cfg, &checkpoint).await?;
271                }
272                return Ok(GraphResult::Interrupted {
273                    state,
274                    interrupt_value: interrupt_val,
275                });
276            }
277
278            // Handle fan-out (Send)
279            if next_override.as_deref() == Some("__fanout__") {
280                // TODO: full parallel fan-out
281                break;
282            }
283
284            let next = if let Some(target) = next_override {
285                target
286            } else {
287                // Check interrupt_after (only when no command override)
288                if self.interrupt_after.contains(&current_node) {
289                    let next = self.find_next_node(&current_node, &state);
290                    if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
291                        let checkpoint = make_checkpoint(&state, Some(next), &current_node)?;
292                        checkpointer.put(cfg, &checkpoint).await?;
293                    }
294                    return Ok(GraphResult::Interrupted {
295                        state,
296                        interrupt_value: serde_json::json!({
297                            "reason": format!("interrupted after node '{current_node}'")
298                        }),
299                    });
300                }
301
302                // Normal routing
303                self.find_next_node(&current_node, &state)
304            };
305
306            // Save checkpoint after each node
307            if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
308                let checkpoint = make_checkpoint(&state, Some(next.clone()), &current_node)?;
309                checkpointer.put(cfg, &checkpoint).await?;
310            }
311
312            current_node = next;
313        }
314
315        Ok(GraphResult::Complete(state))
316    }
317
318    /// Stream graph execution, yielding a `GraphEvent` after each node.
319    pub fn stream(&self, state: S, mode: StreamMode) -> GraphStream<'_, S>
320    where
321        S: serde::Serialize + serde::de::DeserializeOwned + Clone,
322    {
323        self.stream_with_config(state, mode, None)
324    }
325
326    /// Stream graph execution with optional checkpoint config.
327    pub fn stream_with_config(
328        &self,
329        state: S,
330        _mode: StreamMode,
331        config: Option<CheckpointConfig>,
332    ) -> GraphStream<'_, S>
333    where
334        S: serde::Serialize + serde::de::DeserializeOwned + Clone,
335    {
336        Box::pin(async_stream::stream! {
337            let mut state = state;
338
339            // If there's a checkpoint, try to resume from it
340            let mut resume_from: Option<String> = None;
341            if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
342                match checkpointer.get(cfg).await {
343                    Ok(Some(checkpoint)) => {
344                        match serde_json::from_value(checkpoint.state) {
345                            Ok(s) => {
346                                state = s;
347                                resume_from = checkpoint.next_node;
348                            }
349                            Err(e) => {
350                                yield Err(SynapticError::Graph(format!(
351                                    "failed to deserialize checkpoint state: {e}"
352                                )));
353                                return;
354                            }
355                        }
356                    }
357                    Ok(None) => {}
358                    Err(e) => {
359                        yield Err(e);
360                        return;
361                    }
362                }
363            }
364
365            let mut current_node = resume_from.unwrap_or_else(|| self.entry_point.clone());
366            let mut max_iterations = 100;
367
368            loop {
369                if current_node == END {
370                    break;
371                }
372                if max_iterations == 0 {
373                    yield Err(SynapticError::Graph(
374                        "max iterations (100) exceeded — possible infinite loop".to_string(),
375                    ));
376                    return;
377                }
378                max_iterations -= 1;
379
380                // Check interrupt_before
381                if self.interrupt_before.contains(&current_node) {
382                    if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
383                        match make_checkpoint(&state, Some(current_node.clone()), &current_node) {
384                            Ok(checkpoint) => {
385                                if let Err(e) = checkpointer.put(cfg, &checkpoint).await {
386                                    yield Err(e);
387                                    return;
388                                }
389                            }
390                            Err(e) => {
391                                yield Err(e);
392                                return;
393                            }
394                        }
395                    }
396                    yield Err(SynapticError::Graph(format!(
397                        "interrupted before node '{current_node}'"
398                    )));
399                    return;
400                }
401
402                // Execute node
403                let node = match self.nodes.get(&current_node) {
404                    Some(n) => n,
405                    None => {
406                        yield Err(SynapticError::Graph(format!("node '{current_node}' not found")));
407                        return;
408                    }
409                };
410
411                let output = match node.process(state.clone()).await {
412                    Ok(o) => o,
413                    Err(e) => {
414                        yield Err(e);
415                        return;
416                    }
417                };
418
419                // Handle the node output
420                let mut interrupt_val = None;
421                let next_override = match output {
422                    NodeOutput::State(new_state) => {
423                        state = new_state;
424                        None
425                    }
426                    NodeOutput::Command(cmd) => {
427                        if let Some(update) = cmd.update {
428                            state.merge(update);
429                        }
430
431                        if let Some(iv) = cmd.interrupt_value {
432                            interrupt_val = Some(iv);
433                            None
434                        } else {
435                            match cmd.goto {
436                                Some(CommandGoto::One(target)) => Some(target),
437                                Some(CommandGoto::Many(_)) => Some(END.to_string()),
438                                None => None,
439                            }
440                        }
441                    }
442                };
443
444                // Yield event
445                let event = GraphEvent {
446                    node: current_node.clone(),
447                    state: state.clone(),
448                };
449                yield Ok(event);
450
451                // Check for interrupt from Command
452                if let Some(iv) = interrupt_val {
453                    if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
454                        let next = self.find_next_node(&current_node, &state);
455                        match make_checkpoint(&state, Some(next), &current_node) {
456                            Ok(checkpoint) => {
457                                if let Err(e) = checkpointer.put(cfg, &checkpoint).await {
458                                    yield Err(e);
459                                    return;
460                                }
461                            }
462                            Err(e) => {
463                                yield Err(e);
464                                return;
465                            }
466                        }
467                    }
468                    yield Err(SynapticError::Graph(format!(
469                        "interrupted by node '{current_node}': {iv}"
470                    )));
471                    return;
472                }
473
474                let next = if let Some(target) = next_override {
475                    target
476                } else {
477                    // Check interrupt_after (only when no command override)
478                    if self.interrupt_after.contains(&current_node) {
479                        let next = self.find_next_node(&current_node, &state);
480                        if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
481                            match make_checkpoint(&state, Some(next), &current_node) {
482                                Ok(checkpoint) => {
483                                    if let Err(e) = checkpointer.put(cfg, &checkpoint).await {
484                                        yield Err(e);
485                                        return;
486                                    }
487                                }
488                                Err(e) => {
489                                    yield Err(e);
490                                    return;
491                                }
492                            }
493                        }
494                        yield Err(SynapticError::Graph(format!(
495                            "interrupted after node '{current_node}'"
496                        )));
497                        return;
498                    }
499
500                    // Find next node via normal edge routing
501                    self.find_next_node(&current_node, &state)
502                };
503
504                // Save checkpoint
505                if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
506                    match make_checkpoint(&state, Some(next.clone()), &current_node) {
507                        Ok(checkpoint) => {
508                            if let Err(e) = checkpointer.put(cfg, &checkpoint).await {
509                                yield Err(e);
510                                return;
511                            }
512                        }
513                        Err(e) => {
514                            yield Err(e);
515                            return;
516                        }
517                    }
518                }
519
520                current_node = next;
521            }
522        })
523    }
524
525    /// Stream graph execution with multiple stream modes.
526    ///
527    /// Each event is tagged with the `StreamMode` that produced it.
528    /// For a single node execution, one event per requested mode is emitted.
529    pub fn stream_modes(&self, state: S, modes: Vec<StreamMode>) -> MultiGraphStream<'_, S>
530    where
531        S: serde::Serialize + serde::de::DeserializeOwned + Clone,
532    {
533        self.stream_modes_with_config(state, modes, None)
534    }
535
536    /// Stream graph execution with multiple stream modes and optional checkpoint config.
537    pub fn stream_modes_with_config(
538        &self,
539        state: S,
540        modes: Vec<StreamMode>,
541        config: Option<CheckpointConfig>,
542    ) -> MultiGraphStream<'_, S>
543    where
544        S: serde::Serialize + serde::de::DeserializeOwned + Clone,
545    {
546        Box::pin(async_stream::stream! {
547            let mut state = state;
548
549            // If there's a checkpoint, try to resume from it
550            let mut resume_from: Option<String> = None;
551            if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
552                match checkpointer.get(cfg).await {
553                    Ok(Some(checkpoint)) => {
554                        match serde_json::from_value(checkpoint.state) {
555                            Ok(s) => {
556                                state = s;
557                                resume_from = checkpoint.next_node;
558                            }
559                            Err(e) => {
560                                yield Err(SynapticError::Graph(format!(
561                                    "failed to deserialize checkpoint state: {e}"
562                                )));
563                                return;
564                            }
565                        }
566                    }
567                    Ok(None) => {}
568                    Err(e) => {
569                        yield Err(e);
570                        return;
571                    }
572                }
573            }
574
575            let mut current_node = resume_from.unwrap_or_else(|| self.entry_point.clone());
576            let mut max_iterations = 100;
577
578            loop {
579                if current_node == END {
580                    break;
581                }
582                if max_iterations == 0 {
583                    yield Err(SynapticError::Graph(
584                        "max iterations (100) exceeded — possible infinite loop".to_string(),
585                    ));
586                    return;
587                }
588                max_iterations -= 1;
589
590                // Check interrupt_before
591                if self.interrupt_before.contains(&current_node) {
592                    if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
593                        match make_checkpoint(&state, Some(current_node.clone()), &current_node) {
594                            Ok(checkpoint) => {
595                                if let Err(e) = checkpointer.put(cfg, &checkpoint).await {
596                                    yield Err(e);
597                                    return;
598                                }
599                            }
600                            Err(e) => {
601                                yield Err(e);
602                                return;
603                            }
604                        }
605                    }
606                    yield Err(SynapticError::Graph(format!(
607                        "interrupted before node '{current_node}'"
608                    )));
609                    return;
610                }
611
612                // Snapshot state before node execution (for Updates mode diff)
613                let state_before = state.clone();
614
615                // Execute node
616                let node = match self.nodes.get(&current_node) {
617                    Some(n) => n,
618                    None => {
619                        yield Err(SynapticError::Graph(format!("node '{current_node}' not found")));
620                        return;
621                    }
622                };
623
624                let output = match node.process(state.clone()).await {
625                    Ok(o) => o,
626                    Err(e) => {
627                        yield Err(e);
628                        return;
629                    }
630                };
631
632                // Handle the node output
633                let mut interrupt_val = None;
634                let next_override = match output {
635                    NodeOutput::State(new_state) => {
636                        state = new_state;
637                        None
638                    }
639                    NodeOutput::Command(cmd) => {
640                        if let Some(update) = cmd.update {
641                            state.merge(update);
642                        }
643
644                        if let Some(iv) = cmd.interrupt_value {
645                            interrupt_val = Some(iv);
646                            None
647                        } else {
648                            match cmd.goto {
649                                Some(CommandGoto::One(target)) => Some(target),
650                                Some(CommandGoto::Many(_)) => Some(END.to_string()),
651                                None => None,
652                            }
653                        }
654                    }
655                };
656
657                // Yield events for each requested mode
658                for mode in &modes {
659                    let event = match mode {
660                        StreamMode::Values | StreamMode::Debug | StreamMode::Custom => {
661                            // Full state after node execution
662                            GraphEvent {
663                                node: current_node.clone(),
664                                state: state.clone(),
665                            }
666                        }
667                        StreamMode::Updates => {
668                            // State before node (the "delta" is the difference)
669                            // For Updates, we yield the pre-node state so callers
670                            // can diff against the full Values event
671                            GraphEvent {
672                                node: current_node.clone(),
673                                state: state_before.clone(),
674                            }
675                        }
676                        StreamMode::Messages => {
677                            // Same as Values — callers filter for AI messages
678                            GraphEvent {
679                                node: current_node.clone(),
680                                state: state.clone(),
681                            }
682                        }
683                    };
684                    yield Ok(MultiGraphEvent {
685                        mode: *mode,
686                        event,
687                    });
688                }
689
690                // Check for interrupt from Command
691                if let Some(iv) = interrupt_val {
692                    if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
693                        let next = self.find_next_node(&current_node, &state);
694                        match make_checkpoint(&state, Some(next), &current_node) {
695                            Ok(checkpoint) => {
696                                if let Err(e) = checkpointer.put(cfg, &checkpoint).await {
697                                    yield Err(e);
698                                    return;
699                                }
700                            }
701                            Err(e) => {
702                                yield Err(e);
703                                return;
704                            }
705                        }
706                    }
707                    yield Err(SynapticError::Graph(format!(
708                        "interrupted by node '{current_node}': {iv}"
709                    )));
710                    return;
711                }
712
713                let next = if let Some(target) = next_override {
714                    target
715                } else {
716                    // Check interrupt_after (only when no command override)
717                    if self.interrupt_after.contains(&current_node) {
718                        let next = self.find_next_node(&current_node, &state);
719                        if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
720                            match make_checkpoint(&state, Some(next), &current_node) {
721                                Ok(checkpoint) => {
722                                    if let Err(e) = checkpointer.put(cfg, &checkpoint).await {
723                                        yield Err(e);
724                                        return;
725                                    }
726                                }
727                                Err(e) => {
728                                    yield Err(e);
729                                    return;
730                                }
731                            }
732                        }
733                        yield Err(SynapticError::Graph(format!(
734                            "interrupted after node '{current_node}'"
735                        )));
736                        return;
737                    }
738
739                    // Find next node via normal edge routing
740                    self.find_next_node(&current_node, &state)
741                };
742
743                // Save checkpoint
744                if let (Some(ref checkpointer), Some(ref cfg)) = (&self.checkpointer, &config) {
745                    match make_checkpoint(&state, Some(next.clone()), &current_node) {
746                        Ok(checkpoint) => {
747                            if let Err(e) = checkpointer.put(cfg, &checkpoint).await {
748                                yield Err(e);
749                                return;
750                            }
751                        }
752                        Err(e) => {
753                            yield Err(e);
754                            return;
755                        }
756                    }
757                }
758
759                current_node = next;
760            }
761        })
762    }
763
764    /// Update state on an interrupted graph (for human-in-the-loop).
765    pub async fn update_state(
766        &self,
767        config: &CheckpointConfig,
768        update: S,
769    ) -> Result<(), SynapticError>
770    where
771        S: serde::Serialize + serde::de::DeserializeOwned,
772    {
773        let checkpointer = self
774            .checkpointer
775            .as_ref()
776            .ok_or_else(|| SynapticError::Graph("no checkpointer configured".to_string()))?;
777
778        let checkpoint = checkpointer
779            .get(config)
780            .await?
781            .ok_or_else(|| SynapticError::Graph("no checkpoint found".to_string()))?;
782
783        let mut current_state: S = serde_json::from_value(checkpoint.state)
784            .map_err(|e| SynapticError::Graph(format!("deserialize: {e}")))?;
785
786        current_state.merge(update);
787
788        let updated = Checkpoint::new(
789            serde_json::to_value(&current_state)
790                .map_err(|e| SynapticError::Graph(format!("serialize: {e}")))?,
791            checkpoint.next_node,
792        )
793        .with_metadata("source", serde_json::json!("update_state"));
794        checkpointer.put(config, &updated).await?;
795
796        Ok(())
797    }
798
799    /// Get the current state for a thread from the checkpointer.
800    ///
801    /// Returns `None` if no checkpoint exists for the given thread.
802    pub async fn get_state(&self, config: &CheckpointConfig) -> Result<Option<S>, SynapticError>
803    where
804        S: serde::de::DeserializeOwned,
805    {
806        let checkpointer = self
807            .checkpointer
808            .as_ref()
809            .ok_or_else(|| SynapticError::Graph("no checkpointer configured".to_string()))?;
810
811        match checkpointer.get(config).await? {
812            Some(checkpoint) => {
813                let state: S = serde_json::from_value(checkpoint.state).map_err(|e| {
814                    SynapticError::Graph(format!("failed to deserialize checkpoint state: {e}"))
815                })?;
816                Ok(Some(state))
817            }
818            None => Ok(None),
819        }
820    }
821
822    /// Get the state history for a thread (all checkpoints).
823    ///
824    /// Returns a list of `(state, next_node)` pairs, ordered from oldest to newest.
825    pub async fn get_state_history(
826        &self,
827        config: &CheckpointConfig,
828    ) -> Result<Vec<(S, Option<String>)>, SynapticError>
829    where
830        S: serde::de::DeserializeOwned,
831    {
832        let checkpointer = self
833            .checkpointer
834            .as_ref()
835            .ok_or_else(|| SynapticError::Graph("no checkpointer configured".to_string()))?;
836
837        let checkpoints = checkpointer.list(config).await?;
838        let mut history = Vec::with_capacity(checkpoints.len());
839
840        for checkpoint in checkpoints {
841            let state: S = serde_json::from_value(checkpoint.state).map_err(|e| {
842                SynapticError::Graph(format!("failed to deserialize checkpoint state: {e}"))
843            })?;
844            history.push((state, checkpoint.next_node));
845        }
846
847        Ok(history)
848    }
849
850    /// Execute a node, using cache if a CachePolicy is set for it.
851    async fn execute_with_cache(
852        &self,
853        node_name: &str,
854        node: &dyn Node<S>,
855        state: S,
856    ) -> Result<NodeOutput<S>, SynapticError>
857    where
858        S: serde::Serialize,
859    {
860        let policy = self.cache_policies.get(node_name);
861        if policy.is_none() {
862            return node.process(state).await;
863        }
864        let policy = policy.unwrap();
865
866        // Compute state hash for cache key
867        let state_val = serde_json::to_value(&state)
868            .map_err(|e| SynapticError::Graph(format!("cache: serialize state: {e}")))?;
869        let key = hash_state(&state_val);
870
871        // Check cache
872        {
873            let cache = self.cache.read().await;
874            if let Some(node_cache) = cache.get(node_name) {
875                if let Some(entry) = node_cache.get(&key) {
876                    if entry.is_valid() {
877                        return Ok(entry.output.clone());
878                    }
879                }
880            }
881        }
882
883        // Cache miss — execute the node
884        let output = node.process(state).await?;
885
886        // Store in cache
887        {
888            let mut cache = self.cache.write().await;
889            let node_cache = cache.entry(node_name.to_string()).or_default();
890            node_cache.insert(
891                key,
892                CachedEntry {
893                    output: output.clone(),
894                    created: Instant::now(),
895                    ttl: policy.ttl,
896                },
897            );
898        }
899
900        Ok(output)
901    }
902
903    /// Returns true if the given node is deferred (waits for all incoming paths).
904    pub fn is_deferred(&self, node_name: &str) -> bool {
905        self.deferred.contains(node_name)
906    }
907
908    /// Returns the number of incoming edges (fixed + conditional) for a node.
909    pub fn incoming_edge_count(&self, node_name: &str) -> usize {
910        let fixed = self.edges.iter().filter(|e| e.target == node_name).count();
911        // Conditional edges may route to this node but we can't statically count them,
912        // so we count the path_map entries that reference this node.
913        let conditional = self
914            .conditional_edges
915            .iter()
916            .filter_map(|ce| ce.path_map.as_ref())
917            .flat_map(|pm| pm.values())
918            .filter(|target| *target == node_name)
919            .count();
920        fixed + conditional
921    }
922
923    fn find_next_node(&self, current: &str, state: &S) -> String {
924        // Check conditional edges first
925        for ce in &self.conditional_edges {
926            if ce.source == current {
927                return (ce.router)(state);
928            }
929        }
930
931        // Check fixed edges
932        for edge in &self.edges {
933            if edge.source == current {
934                return edge.target.clone();
935            }
936        }
937
938        // No outgoing edge means END
939        END.to_string()
940    }
941}