1use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8
9use crate::kernel::identity::RunId;
10
/// Identifier of an [`Interrupt`]; a plain owned string.
pub type InterruptId = std::string::String;
13
14#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
16pub enum InterruptKind {
17 HumanInTheLoop,
19 ApprovalRequired,
21 ToolCallWaiting,
23 Custom(String),
25}
26
27#[derive(Clone, Debug, Serialize, Deserialize)]
29pub struct Interrupt {
30 pub id: InterruptId,
32 pub thread_id: RunId,
34 pub kind: InterruptKind,
36 pub payload_schema: serde_json::Value,
38 pub created_at: DateTime<Utc>,
40 pub step_id: Option<String>,
42}
43
44impl Interrupt {
45 pub fn new(
47 id: InterruptId,
48 thread_id: RunId,
49 kind: InterruptKind,
50 payload_schema: serde_json::Value,
51 ) -> Self {
52 Self {
53 id,
54 thread_id,
55 kind,
56 payload_schema,
57 created_at: Utc::now(),
58 step_id: None,
59 }
60 }
61
62 pub fn with_step(
64 id: InterruptId,
65 thread_id: RunId,
66 kind: InterruptKind,
67 payload_schema: serde_json::Value,
68 step_id: String,
69 ) -> Self {
70 Self {
71 id,
72 thread_id,
73 kind,
74 payload_schema,
75 created_at: Utc::now(),
76 step_id: Some(step_id),
77 }
78 }
79}
80
81pub trait InterruptStore: Send + Sync {
83 fn save(&self, interrupt: &Interrupt) -> Result<(), InterruptError>;
85
86 fn load(&self, id: &InterruptId) -> Result<Option<Interrupt>, InterruptError>;
88
89 fn load_for_run(&self, thread_id: &RunId) -> Result<Vec<Interrupt>, InterruptError>;
91
92 fn delete(&self, id: &InterruptId) -> Result<(), InterruptError>;
94}
95
96#[derive(Debug, thiserror::Error)]
98pub enum InterruptError {
99 #[error("Interrupt store error: {0}")]
100 Store(String),
101 #[error("Interrupt not found: {0}")]
102 NotFound(InterruptId),
103}
104
105#[derive(Debug, Default)]
107pub struct InMemoryInterruptStore {
108 by_id: std::sync::RwLock<std::collections::HashMap<InterruptId, Interrupt>>,
109}
110
111impl InMemoryInterruptStore {
112 pub fn new() -> Self {
113 Self::default()
114 }
115}
116
117impl InterruptStore for InMemoryInterruptStore {
118 fn save(&self, interrupt: &Interrupt) -> Result<(), InterruptError> {
119 let mut guard = self
120 .by_id
121 .write()
122 .map_err(|e| InterruptError::Store(e.to_string()))?;
123 guard.insert(interrupt.id.clone(), interrupt.clone());
124 Ok(())
125 }
126
127 fn load(&self, id: &InterruptId) -> Result<Option<Interrupt>, InterruptError> {
128 let guard = self
129 .by_id
130 .read()
131 .map_err(|e| InterruptError::Store(e.to_string()))?;
132 Ok(guard.get(id).cloned())
133 }
134
135 fn load_for_run(&self, thread_id: &RunId) -> Result<Vec<Interrupt>, InterruptError> {
136 let guard = self
137 .by_id
138 .read()
139 .map_err(|e| InterruptError::Store(e.to_string()))?;
140 Ok(guard
141 .values()
142 .filter(|i| i.thread_id == *thread_id)
143 .cloned()
144 .collect())
145 }
146
147 fn delete(&self, id: &InterruptId) -> Result<(), InterruptError> {
148 let mut guard = self
149 .by_id
150 .write()
151 .map_err(|e| InterruptError::Store(e.to_string()))?;
152 guard.remove(id);
153 Ok(())
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an interrupt with an empty payload schema.
    fn sample(id: &'static str, run: &'static str, kind: InterruptKind) -> Interrupt {
        Interrupt::new(id.into(), run.into(), kind, serde_json::json!({}))
    }

    #[test]
    fn save_and_load_interrupt() {
        let store = InMemoryInterruptStore::new();
        let interrupt = Interrupt::new(
            "intr-1".into(),
            "run-1".into(),
            InterruptKind::HumanInTheLoop,
            serde_json::json!({"type": "string"}),
        );
        store.save(&interrupt).unwrap();

        let loaded = store
            .load(&"intr-1".into())
            .unwrap()
            .expect("saved interrupt should be loadable");
        assert_eq!(loaded.id, "intr-1");
        assert_eq!(loaded.kind, InterruptKind::HumanInTheLoop);
        assert_eq!(loaded.payload_schema, serde_json::json!({"type": "string"}));
        assert!(loaded.step_id.is_none());
    }

    #[test]
    fn load_missing_returns_none() {
        let store = InMemoryInterruptStore::new();
        assert!(store.load(&"absent".into()).unwrap().is_none());
    }

    #[test]
    fn save_overwrites_same_id() {
        let store = InMemoryInterruptStore::new();
        store
            .save(&sample("i1", "run-1", InterruptKind::HumanInTheLoop))
            .unwrap();
        store
            .save(&sample("i1", "run-1", InterruptKind::ApprovalRequired))
            .unwrap();

        let loaded = store.load(&"i1".into()).unwrap().expect("i1 should exist");
        assert_eq!(loaded.kind, InterruptKind::ApprovalRequired);
        assert_eq!(store.load_for_run(&"run-1".into()).unwrap().len(), 1);
    }

    #[test]
    fn load_for_run_filters() {
        let store = InMemoryInterruptStore::new();
        store
            .save(&sample("i1", "run-a", InterruptKind::ApprovalRequired))
            .unwrap();
        store
            .save(&sample("i2", "run-b", InterruptKind::HumanInTheLoop))
            .unwrap();
        store
            .save(&sample("i3", "run-a", InterruptKind::ToolCallWaiting))
            .unwrap();

        // Compare as a sorted list so the test does not depend on result order.
        let mut ids: Vec<_> = store
            .load_for_run(&"run-a".into())
            .unwrap()
            .into_iter()
            .map(|i| i.id)
            .collect();
        ids.sort();
        assert_eq!(ids, ["i1", "i3"]);

        assert!(store.load_for_run(&"run-z".into()).unwrap().is_empty());
    }

    #[test]
    fn with_step_records_step_id() {
        let interrupt = Interrupt::with_step(
            "i1".into(),
            "run-1".into(),
            InterruptKind::ToolCallWaiting,
            serde_json::json!({}),
            "step-7".to_string(),
        );
        assert_eq!(interrupt.step_id.as_deref(), Some("step-7"));
        assert_eq!(interrupt.kind, InterruptKind::ToolCallWaiting);
    }

    #[test]
    fn delete_removes_interrupt() {
        let store = InMemoryInterruptStore::new();
        store
            .save(&sample("i1", "run-1", InterruptKind::Custom("test".into())))
            .unwrap();
        store.delete(&"i1".into()).unwrap();
        assert!(store.load(&"i1".into()).unwrap().is_none());

        // Deleting an id that is already gone is a no-op, not an error.
        store.delete(&"i1".into()).unwrap();
    }
}