1mod context;
7
8use std::collections::{HashMap, HashSet};
9use std::hash::Hash;
10use xxhash_rust::xxh3::Xxh3;
11
12use chrono::{DateTime, Utc};
13
14pub use context::InterruptContext;
15
16tokio::task_local! {
20 pub static INTERRUPT_CONTEXT: std::sync::Arc<InterruptContext>;
21}
22
23#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
29pub struct InterruptSignal {
30 pub index: usize,
32
33 pub id: Option<String>,
35
36 pub payload: serde_json::Value,
38
39 #[serde(default = "InterruptSignal::current_timestamp")]
41 pub timestamp: DateTime<Utc>,
42}
43
44impl InterruptSignal {
45 #[must_use]
47 fn current_timestamp() -> DateTime<Utc> {
48 Utc::now()
49 }
50}
51
52#[derive(Clone, Debug)]
56pub enum ResumeValue {
57 Single(serde_json::Value),
59
60 ById(std::collections::HashMap<String, serde_json::Value>),
63
64 ByNamespace(std::collections::HashMap<String, serde_json::Value>),
68}
69
70#[allow(
72 clippy::fallible_impl_from,
73 reason = "empty Vec is converted to Null, which is a valid value"
74)]
75impl From<Vec<serde_json::Value>> for ResumeValue {
76 fn from(values: Vec<serde_json::Value>) -> Self {
77 if values.is_empty() {
79 Self::Single(serde_json::Value::Null)
80 } else if values.len() == 1 {
81 Self::Single(values.into_iter().next().unwrap())
82 } else {
83 let map: std::collections::HashMap<String, serde_json::Value> = values
85 .into_iter()
86 .enumerate()
87 .map(|(i, v)| (i.to_string(), v))
88 .collect();
89 Self::ByNamespace(map)
90 }
91 }
92}
93
94#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
98pub struct InterruptRecord {
99 pub id: String,
101
102 pub node: String,
104
105 pub payload: serde_json::Value,
107
108 pub timestamp: DateTime<Utc>,
110
111 pub resumed_at: Option<DateTime<Utc>>,
113
114 pub resume_value: Option<serde_json::Value>,
116}
117
118#[must_use]
142pub fn extract_namespace(interrupt_id: &str) -> Option<&str> {
143 if let Some(colon_pos) = interrupt_id.find(':') {
145 if colon_pos > 0 {
147 return Some(&interrupt_id[..colon_pos]);
148 }
149 }
150 None
151}
152
153#[expect(
196 clippy::implicit_hasher,
197 reason = "accepting standard HashMap is fine for this use case"
198)]
199pub fn validate_resume_coverage(
200 pending: &[InterruptSignal],
201 resume_values: &HashMap<String, serde_json::Value>,
202) -> Result<(), Vec<String>> {
203 let mut uncovered = Vec::new();
204
205 for signal in pending {
206 if let Some(ref id) = signal.id {
207 if !resume_values.contains_key(id) {
208 uncovered.push(id.clone());
209 }
210 }
211 }
212
213 if uncovered.is_empty() {
214 Ok(())
215 } else {
216 Err(uncovered)
217 }
218}
219
220pub const HIDDEN_TAG: &str = "__hidden__";
227
228#[must_use]
262pub fn is_hidden_node(node_name: &str, tags: &[String]) -> bool {
263 let is_hidden_by_name =
264 node_name.starts_with("__") && node_name.ends_with("__") && node_name.len() > 4;
265 let is_hidden_by_tag = tags.iter().any(|tag| tag == HIDDEN_TAG);
266 is_hidden_by_name || is_hidden_by_tag
267}
268
269#[must_use]
299pub fn generate_interrupt_id(node_name: &str, index: usize) -> String {
300 let mut hasher = Xxh3::new();
301 node_name.hash(&mut hasher);
302 index.hash(&mut hasher);
303 let hash = hasher.digest128();
304 format!("{hash:032x}")
305}
306
307#[allow(
331 clippy::implicit_hasher,
332 reason = "accepting standard HashSet is fine for this use case"
333)]
334#[must_use]
335pub fn should_interrupt<S: crate::State>(
336 pending_tasks: &[crate::PendingTask<S>],
337 interrupt_before: &HashSet<String>,
338 interrupt_after: &HashSet<String>,
339 channel_versions: &HashMap<String, u64>,
340 versions_seen_for_interrupt: &HashMap<String, u64>,
341) -> Option<Vec<InterruptSignal>> {
342 let any_updates = channel_versions
344 .iter()
345 .any(|(chan, ver)| ver > versions_seen_for_interrupt.get(chan).unwrap_or(&0));
346
347 if !any_updates && !versions_seen_for_interrupt.is_empty() {
348 return None;
349 }
350
351 let mut signals = Vec::new();
353
354 for task in pending_tasks {
355 let node_name = &task.node_name;
356 let tags: &[String] = &[];
358
359 if is_hidden_node(node_name, tags) {
362 continue;
363 }
364
365 if interrupt_before.contains(node_name) {
366 let timestamp = Utc::now();
367 signals.push(InterruptSignal {
368 index: signals.len(),
369 id: Some(generate_interrupt_id(node_name, signals.len())),
370 payload: serde_json::json!({
371 "node": node_name,
372 "reason": "interrupt_before",
373 }),
374 timestamp,
375 });
376 }
377
378 if interrupt_after.contains(node_name) {
379 let timestamp = Utc::now();
380 signals.push(InterruptSignal {
381 index: signals.len(),
382 id: Some(generate_interrupt_id(node_name, signals.len())),
383 payload: serde_json::json!({
384 "node": node_name,
385 "reason": "interrupt_after",
386 }),
387 timestamp,
388 });
389 }
390 }
391
392 if signals.is_empty() {
393 None
394 } else {
395 Some(signals)
396 }
397}
398
399#[expect(
416 clippy::unused_async,
417 reason = "async is required by the interrupt! macro's .await expansion"
418)]
419pub async fn __interrupt_impl(
420 ctx: &crate::interrupt::InterruptContext,
421 payload: serde_json::Value,
422 id: Option<&str>,
423) -> Result<serde_json::Value, crate::JunctureError> {
424 let index = ctx.next_index();
425
426 let interrupt_id = id.map_or_else(
427 || {
428 generate_interrupt_id("current_node", index)
430 },
431 std::string::ToString::to_string,
432 );
433
434 if let Some(value) = ctx.get_resume_value(index) {
435 return Ok(value);
436 }
437
438 ctx.send_interrupt(InterruptSignal {
439 index,
440 id: Some(interrupt_id),
441 payload,
442 timestamp: Utc::now(),
443 })
444 .map_err(|_err| crate::JunctureError::execution("interrupt channel closed"))?;
445
446 Err(crate::JunctureError::interrupted(index))
447}
448
449#[derive(Clone, Debug, Default)]
454pub struct Scratchpad {
455 processed_interrupts: HashSet<String>,
457
458 data: HashMap<String, serde_json::Value>,
460
461 interrupt_history: Vec<InterruptRecord>,
463}
464
465impl Scratchpad {
466 #[must_use]
468 pub fn new() -> Self {
469 Self {
470 processed_interrupts: HashSet::new(),
471 data: HashMap::new(),
472 interrupt_history: Vec::new(),
473 }
474 }
475
476 #[must_use]
486 pub fn is_interrupt_processed(&self, id: &str) -> bool {
487 self.processed_interrupts.contains(id)
488 }
489
490 #[must_use]
495 pub fn get_null_resume(&self, interrupt_id: &str) -> bool {
496 self.is_interrupt_processed(interrupt_id)
497 }
498
499 pub fn mark_interrupt_processed(&mut self, id: &str) {
505 self.processed_interrupts.insert(id.to_string());
506 }
507
508 #[must_use]
518 pub fn get_data(&self, key: &str) -> Option<&serde_json::Value> {
519 self.data.get(key)
520 }
521
522 pub fn set_data(&mut self, key: String, value: serde_json::Value) {
529 self.data.insert(key, value);
530 }
531
532 pub fn record_interrupt(&mut self, id: String, node: String, payload: serde_json::Value) {
543 let record = InterruptRecord {
544 id,
545 node,
546 payload,
547 timestamp: Utc::now(),
548 resumed_at: None,
549 resume_value: None,
550 };
551 self.interrupt_history.push(record);
552 }
553
554 pub fn record_resume(&mut self, id: &str, value: serde_json::Value) {
564 if let Some(record) = self.interrupt_history.iter_mut().find(|r| r.id == id) {
565 record.resumed_at = Some(Utc::now());
566 record.resume_value = Some(value);
567 }
568 }
569
570 #[must_use]
578 pub fn interrupt_history(&self) -> &[InterruptRecord] {
579 &self.interrupt_history
580 }
581
582 pub fn clear_transient(&mut self) {
587 self.data
588 .retain(|key, _value| key.starts_with("null_resume:"));
589 }
590}
591
592#[cfg(test)]
593mod tests {
594 use super::*;
595
596 #[test]
599 fn scratchpad_get_null_resume() {
600 let mut pad = Scratchpad::new();
601 assert!(!pad.get_null_resume("int-1"));
602 pad.mark_interrupt_processed("int-1");
603 assert!(pad.get_null_resume("int-1"));
604 assert!(!pad.get_null_resume("int-2"));
605 }
606
607 #[test]
608 fn scratchpad_record_interrupt() {
609 let mut pad = Scratchpad::new();
610 pad.record_interrupt(
611 "int-1".to_string(),
612 "node_a".to_string(),
613 serde_json::json!({"reason": "test"}),
614 );
615
616 let history = pad.interrupt_history();
617 assert_eq!(history.len(), 1);
618 assert_eq!(history[0].id, "int-1");
619 assert_eq!(history[0].node, "node_a");
620 assert_eq!(history[0].payload["reason"], "test");
621 assert!(history[0].resumed_at.is_none());
622 assert!(history[0].resume_value.is_none());
623 }
624
625 #[test]
626 fn scratchpad_record_resume() {
627 let mut pad = Scratchpad::new();
628 pad.record_interrupt(
629 "int-1".to_string(),
630 "node_a".to_string(),
631 serde_json::json!({}),
632 );
633
634 pad.record_resume("int-1", serde_json::json!("approved"));
635
636 let history = pad.interrupt_history();
637 assert_eq!(history.len(), 1);
638 assert!(history[0].resumed_at.is_some());
639 assert_eq!(history[0].resume_value, Some(serde_json::json!("approved")));
640 }
641
642 #[test]
643 fn scratchpad_interrupt_history_order() {
644 let mut pad = Scratchpad::new();
645
646 pad.record_interrupt(
647 "int-1".to_string(),
648 "node_a".to_string(),
649 serde_json::json!({}),
650 );
651 std::thread::sleep(std::time::Duration::from_millis(10));
652 pad.record_interrupt(
653 "int-2".to_string(),
654 "node_b".to_string(),
655 serde_json::json!({}),
656 );
657
658 let history = pad.interrupt_history();
659 assert_eq!(history.len(), 2);
660 assert!(history[0].timestamp < history[1].timestamp);
661 }
662
663 #[test]
664 fn scratchpad_clear_transient() {
665 let mut pad = Scratchpad::new();
666 pad.set_data("temp_key".to_string(), serde_json::json!("temp"));
667 pad.set_data(
668 "null_resume:persistent".to_string(),
669 serde_json::json!("keep"),
670 );
671
672 pad.clear_transient();
673
674 assert!(pad.get_data("temp_key").is_none());
675 assert_eq!(
676 pad.get_data("null_resume:persistent"),
677 Some(&serde_json::json!("keep"))
678 );
679 }
680
681 #[test]
682 fn scratchpad_clear_transient_empty() {
683 let mut pad = Scratchpad::new();
684 pad.clear_transient();
685 assert!(pad.data.is_empty());
686 }
687
688 #[test]
689 fn scratchpad_record_resume_nonexistent() {
690 let mut pad = Scratchpad::new();
691 pad.record_resume("nonexistent", serde_json::json!("value"));
693 assert_eq!(pad.interrupt_history().len(), 0);
694 }
695
696 #[test]
699 fn extract_namespace_with_namespace() {
700 assert_eq!(extract_namespace("agent:review#0"), Some("agent"));
701 assert_eq!(extract_namespace("namespace:node#index"), Some("namespace"));
702 }
703
704 #[test]
705 fn extract_namespace_without_namespace() {
706 assert_eq!(extract_namespace("node_name#index"), None);
707 assert_eq!(extract_namespace("simple_id"), None);
708 assert_eq!(extract_namespace("no_colon"), None);
709 }
710
711 #[test]
712 fn extract_namespace_empty_namespace() {
713 assert_eq!(extract_namespace(":node#index"), None);
714 assert_eq!(extract_namespace(":only_colon"), None);
715 }
716
717 #[test]
720 fn validate_resume_coverage_complete() {
721 let pending = vec![InterruptSignal {
722 index: 0,
723 id: Some("int-1".to_string()),
724 payload: serde_json::json!({}),
725 timestamp: Utc::now(),
726 }];
727
728 let mut resume_values = HashMap::new();
729 resume_values.insert("int-1".to_string(), serde_json::json!("value"));
730
731 validate_resume_coverage(&pending, &resume_values).unwrap();
732 }
733
734 #[test]
735 fn validate_resume_coverage_incomplete() {
736 let pending = vec![
737 InterruptSignal {
738 index: 0,
739 id: Some("int-1".to_string()),
740 payload: serde_json::json!({}),
741 timestamp: Utc::now(),
742 },
743 InterruptSignal {
744 index: 1,
745 id: Some("int-2".to_string()),
746 payload: serde_json::json!({}),
747 timestamp: Utc::now(),
748 },
749 ];
750
751 let mut resume_values = HashMap::new();
752 resume_values.insert("int-1".to_string(), serde_json::json!("value"));
753
754 let result = validate_resume_coverage(&pending, &resume_values);
755 assert!(result.is_err());
756 assert_eq!(result.unwrap_err(), vec!["int-2".to_string()]);
757 }
758
759 #[test]
760 fn validate_resume_coverage_empty_pending() {
761 let pending = vec![];
762 let resume_values = HashMap::new();
763
764 validate_resume_coverage(&pending, &resume_values).unwrap();
765 }
766
767 #[test]
768 fn validate_resume_coverage_no_id() {
769 let pending = vec![InterruptSignal {
770 index: 0,
771 id: None,
772 payload: serde_json::json!({}),
773 timestamp: Utc::now(),
774 }];
775
776 let resume_values = HashMap::new();
777
778 validate_resume_coverage(&pending, &resume_values).unwrap();
780 }
781
782 #[test]
783 fn validate_resume_coverage_multiple_uncovered() {
784 let pending = vec![
785 InterruptSignal {
786 index: 0,
787 id: Some("int-1".to_string()),
788 payload: serde_json::json!({}),
789 timestamp: Utc::now(),
790 },
791 InterruptSignal {
792 index: 1,
793 id: Some("int-2".to_string()),
794 payload: serde_json::json!({}),
795 timestamp: Utc::now(),
796 },
797 InterruptSignal {
798 index: 2,
799 id: Some("int-3".to_string()),
800 payload: serde_json::json!({}),
801 timestamp: Utc::now(),
802 },
803 ];
804
805 let resume_values = HashMap::new();
806
807 let result = validate_resume_coverage(&pending, &resume_values);
808 assert!(result.is_err());
809 let uncovered = result.unwrap_err();
810 assert_eq!(uncovered.len(), 3);
811 assert!(uncovered.contains(&"int-1".to_string()));
812 assert!(uncovered.contains(&"int-2".to_string()));
813 assert!(uncovered.contains(&"int-3".to_string()));
814 }
815
816 #[test]
819 fn hidden_node_double_underscore_prefix_and_suffix() {
820 assert!(is_hidden_node("__route__", &[]));
821 assert!(is_hidden_node("__internal__", &[]));
822 assert!(is_hidden_node("__error_handler__", &[]));
823 }
824
825 #[test]
826 fn normal_nodes_are_not_hidden() {
827 assert!(!is_hidden_node("my_node", &[]));
828 assert!(!is_hidden_node("agent", &[]));
829 assert!(!is_hidden_node("review", &[]));
830 }
831
832 #[test]
833 fn partial_underscore_prefix_is_not_hidden() {
834 assert!(!is_hidden_node("__incomplete", &[]));
835 assert!(!is_hidden_node("__only_start", &[]));
836 }
837
838 #[test]
839 fn partial_underscore_suffix_is_not_hidden() {
840 assert!(!is_hidden_node("only_end__", &[]));
841 assert!(!is_hidden_node("incomplete__", &[]));
842 }
843
844 #[test]
845 fn bare_double_underscore_is_not_hidden() {
846 assert!(!is_hidden_node("____", &[]));
848 }
849
850 #[test]
851 fn hidden_tag_constant_value() {
852 assert_eq!(HIDDEN_TAG, "__hidden__");
853 }
854
855 #[test]
856 fn hidden_node_by_tag() {
857 assert!(is_hidden_node("my_node", &["__hidden__".to_string()]));
859 assert!(is_hidden_node(
860 "agent",
861 &["__hidden__".to_string(), "other".to_string()]
862 ));
863 }
864
865 #[test]
866 fn hidden_node_by_tag_only_when_exact_match() {
867 assert!(!is_hidden_node("my_node", &["_hidden_".to_string()]));
869 assert!(!is_hidden_node("my_node", &["hidden".to_string()]));
870 assert!(!is_hidden_node("my_node", &["__hidden".to_string()]));
871 assert!(!is_hidden_node("my_node", &["hidden__".to_string()]));
872 }
873
874 #[test]
875 fn hidden_node_by_name_or_tag() {
876 assert!(is_hidden_node("__internal__", &[])); assert!(is_hidden_node("normal_node", &["__hidden__".to_string()])); assert!(is_hidden_node("__internal__", &["__hidden__".to_string()])); }
881
882 #[test]
883 fn normal_node_without_tag_not_hidden() {
884 assert!(!is_hidden_node("my_node", &[]));
885 assert!(!is_hidden_node("my_node", &["other_tag".to_string()]));
886 assert!(!is_hidden_node(
887 "my_node",
888 &["tag1".to_string(), "tag2".to_string()]
889 ));
890 }
891
892 #[derive(Clone, Debug, Default, serde::Serialize)]
896 struct TestState;
897
898 impl crate::State for TestState {
899 type Update = TestUpdate;
900 type FieldVersions = crate::state::FieldVersions;
901
902 fn apply(&mut self, _: Self::Update) -> crate::FieldsChanged {
903 crate::FieldsChanged(0)
904 }
905 fn reset_ephemeral(&mut self) {}
906 }
907
908 #[derive(Clone, Debug, Default, serde::Serialize)]
909 struct TestUpdate;
910
911 fn make_task(node_name: &str) -> crate::PendingTask<TestState> {
912 crate::PendingTask::pull(uuid::Uuid::new_v4().to_string(), node_name.to_string())
913 }
914
915 #[test]
916 fn hidden_nodes_filtered_from_interrupt_before() {
917 let tasks = vec![
918 make_task("agent"),
919 make_task("__route__"),
920 make_task("review"),
921 ];
922
923 let mut interrupt_before = HashSet::new();
924 interrupt_before.insert("agent".to_string());
925 interrupt_before.insert("__route__".to_string());
926 interrupt_before.insert("review".to_string());
927
928 let channel_versions: HashMap<String, u64> =
929 std::iter::once(("field_0".to_string(), 1u64)).collect();
930 let versions_seen = HashMap::new();
931
932 let result = should_interrupt(
933 &tasks,
934 &interrupt_before,
935 &HashSet::new(),
936 &channel_versions,
937 &versions_seen,
938 );
939
940 let signals = result.expect("should return signals");
941 assert_eq!(signals.len(), 2, "hidden node __route__ should be filtered");
943 let nodes: Vec<&str> = signals
944 .iter()
945 .filter_map(|s| s.payload.get("node").and_then(|v| v.as_str()))
946 .collect();
947 assert!(nodes.contains(&"agent"), "agent should be present");
948 assert!(nodes.contains(&"review"), "review should be present");
949 assert!(
950 !nodes.contains(&"__route__"),
951 "__route__ should be filtered"
952 );
953 }
954
955 #[test]
956 fn hidden_nodes_filtered_from_interrupt_after() {
957 let tasks = vec![make_task("agent"), make_task("__internal_router__")];
958
959 let mut interrupt_after = HashSet::new();
960 interrupt_after.insert("agent".to_string());
961 interrupt_after.insert("__internal_router__".to_string());
962
963 let channel_versions: HashMap<String, u64> =
964 std::iter::once(("field_0".to_string(), 1u64)).collect();
965 let versions_seen = HashMap::new();
966
967 let result = should_interrupt(
968 &tasks,
969 &HashSet::new(),
970 &interrupt_after,
971 &channel_versions,
972 &versions_seen,
973 );
974
975 let signals = result.expect("should return signals");
976 assert_eq!(
977 signals.len(),
978 1,
979 "only agent should produce a signal, __internal_router__ filtered"
980 );
981 let node = signals[0]
982 .payload
983 .get("node")
984 .and_then(|v| v.as_str())
985 .expect("should have node");
986 assert_eq!(node, "agent");
987 }
988
989 #[test]
990 fn all_hidden_nodes_produces_no_signals() {
991 let tasks = vec![make_task("__route__"), make_task("__handler__")];
992
993 let mut interrupt_before = HashSet::new();
994 interrupt_before.insert("__route__".to_string());
995 interrupt_before.insert("__handler__".to_string());
996
997 let channel_versions: HashMap<String, u64> =
998 std::iter::once(("field_0".to_string(), 1u64)).collect();
999 let versions_seen = HashMap::new();
1000
1001 let result = should_interrupt(
1002 &tasks,
1003 &interrupt_before,
1004 &HashSet::new(),
1005 &channel_versions,
1006 &versions_seen,
1007 );
1008
1009 assert!(
1010 result.is_none(),
1011 "all-hidden-node tasks should produce no interrupt signals"
1012 );
1013 }
1014}
1015
1016