use crate::server::events::HitlEventEmitter;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Details of a single graph interrupt.
///
/// Serialized by serde with the Rust field names (e.g. `node_id`), unlike
/// [`serialize_interrupt_data`], which emits a camelCase `nodeId` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InterruptData {
/// Node the interrupt refers to; `"dynamic"` for interrupts routed through
/// `GraphInterruptHandler::handle_graph_interrupt_direct` with a type other
/// than `"before"` / `"after"`.
pub node_id: String,
/// Message describing the interrupt.
pub message: String,
/// Arbitrary JSON payload attached to the interrupt (`Value::Null` when none).
pub data: Value,
/// Thread identifier supplied by the caller.
pub thread_id: String,
/// Checkpoint identifier supplied by the caller (presumably the checkpoint to
/// resume from — confirm against the graph runtime).
pub checkpoint_id: String,
/// Step number supplied by the caller.
pub step: usize,
}
impl InterruptData {
    /// Assembles an [`InterruptData`] from its individual parts.
    ///
    /// Each argument is moved into the field of the same name. Note that the
    /// argument order (`thread_id` and `checkpoint_id` first) differs from the
    /// struct's field order.
    pub fn from_interrupted(
        thread_id: String,
        checkpoint_id: String,
        node_id: String,
        message: String,
        data: Value,
        step: usize,
    ) -> Self {
        Self {
            thread_id,
            checkpoint_id,
            node_id,
            message,
            data,
            step,
        }
    }

    /// Serializes this interrupt to a JSON string.
    ///
    /// A serialization failure is swallowed and yields the literal `"{}"`.
    pub fn to_json(&self) -> String {
        match serde_json::to_string(self) {
            Ok(serialized) => serialized,
            Err(_) => String::from("{}"),
        }
    }
}
/// Snapshot of an interrupted session as kept in [`InterruptedSessionStore`].
///
/// Combines the fields of an [`InterruptData`] with the session id, the
/// captured graph state and the time the snapshot was created.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InterruptedSessionState {
/// Session this snapshot belongs to.
pub session_id: String,
/// Thread identifier copied from the originating [`InterruptData`].
pub thread_id: String,
/// Checkpoint identifier copied from the originating [`InterruptData`].
pub checkpoint_id: String,
/// Node identifier copied from the originating [`InterruptData`].
pub node_id: String,
/// Interrupt message copied from the originating [`InterruptData`].
pub message: String,
/// Interrupt payload copied from the originating [`InterruptData`].
pub data: Value,
/// Graph state captured at the interrupt (key -> JSON value).
pub state: HashMap<String, Value>,
/// Step number copied from the originating [`InterruptData`].
pub step: usize,
/// Creation time in milliseconds since the Unix epoch (0 if the system clock
/// reads earlier than the epoch).
pub interrupted_at: u64,
}
impl InterruptedSessionState {
pub fn new(
session_id: String,
interrupt_data: InterruptData,
state: HashMap<String, Value>,
) -> Self {
Self {
session_id,
thread_id: interrupt_data.thread_id,
checkpoint_id: interrupt_data.checkpoint_id,
node_id: interrupt_data.node_id,
message: interrupt_data.message,
data: interrupt_data.data,
state,
step: interrupt_data.step,
interrupted_at: current_timestamp_ms(),
}
}
}
/// In-memory map from session id to [`InterruptedSessionState`].
///
/// The map is guarded by a tokio `RwLock`, so every accessor is `async`.
#[derive(Debug, Default)]
pub struct InterruptedSessionStore {
/// Interrupted sessions keyed by session id.
sessions: RwLock<HashMap<String, InterruptedSessionState>>,
}
impl InterruptedSessionStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self {
            sessions: RwLock::new(HashMap::new()),
        }
    }

    /// Stores `state` under `session_id`, replacing any previous entry.
    ///
    /// The map key is `session_id`, not `state.session_id`; callers in this
    /// module pass the same value for both.
    pub async fn store(&self, session_id: &str, state: InterruptedSessionState) {
        let mut sessions = self.sessions.write().await;
        sessions.insert(session_id.to_string(), state);
    }

    /// Returns a clone of the state stored for `session_id`, if any.
    pub async fn get(&self, session_id: &str) -> Option<InterruptedSessionState> {
        let sessions = self.sessions.read().await;
        sessions.get(session_id).cloned()
    }

    /// Removes and returns the state stored for `session_id`, if any.
    pub async fn remove(&self, session_id: &str) -> Option<InterruptedSessionState> {
        let mut sessions = self.sessions.write().await;
        sessions.remove(session_id)
    }

    /// Returns `true` if a state is stored for `session_id`.
    pub async fn is_interrupted(&self, session_id: &str) -> bool {
        let sessions = self.sessions.read().await;
        sessions.contains_key(session_id)
    }

    /// Returns the ids of all stored sessions, in unspecified (hash-map) order.
    pub async fn list_interrupted(&self) -> Vec<String> {
        let sessions = self.sessions.read().await;
        sessions.keys().cloned().collect()
    }

    /// Drops every session whose interrupt is `max_age_ms` milliseconds old or older.
    ///
    /// Age is computed with `saturating_sub`. An entry stamped later than
    /// "now" therefore counts as age 0 and is kept. That happens when the wall
    /// clock stepped backwards, or when `current_timestamp_ms` fell back to 0.
    /// Without `saturating_sub`, the subtraction underflows: it panics in debug
    /// builds and wraps to a huge age in release builds, which wrongly evicts
    /// the entry.
    pub async fn cleanup_old(&self, max_age_ms: u64) {
        let now = current_timestamp_ms();
        let mut sessions = self.sessions.write().await;
        sessions.retain(|_, state| now.saturating_sub(state.interrupted_at) < max_age_ms);
    }
}
/// Records graph interrupts into an [`InterruptedSessionStore`] and, when
/// built with a [`HitlEventEmitter`], also emits an interrupt event for each.
pub struct GraphInterruptHandler {
/// Shared store holding the interrupted-session snapshots.
store: Arc<InterruptedSessionStore>,
/// Optional event sink; `None` when built with `GraphInterruptHandler::new`.
emitter: Option<HitlEventEmitter>,
}
impl GraphInterruptHandler {
    /// Creates a handler that records interrupts in `store` without emitting events.
    pub fn new(store: Arc<InterruptedSessionStore>) -> Self {
        Self {
            store,
            emitter: None,
        }
    }

    /// Creates a handler that records interrupts in `store` and also reports
    /// each one through `emitter`.
    pub fn with_emitter(store: Arc<InterruptedSessionStore>, emitter: HitlEventEmitter) -> Self {
        Self {
            store,
            emitter: Some(emitter),
        }
    }

    /// Records an interrupt for `session_id` and, if an emitter is configured,
    /// emits an interrupt event for it.
    ///
    /// Any interrupt previously stored for the same session is replaced.
    ///
    /// The state is stored *before* the event is emitted. Anything reacting to
    /// the event (e.g. a client calling [`Self::is_interrupted`] or
    /// [`Self::get_interrupted_state`] to resume) then already sees the stored
    /// interrupt instead of racing against the store write.
    #[allow(clippy::too_many_arguments)] // parameters mirror the fields of `InterruptData`
    pub async fn handle_interrupt(
        &self,
        session_id: &str,
        thread_id: String,
        checkpoint_id: String,
        node_id: String,
        message: String,
        data: Value,
        state: HashMap<String, Value>,
        step: usize,
    ) {
        let interrupt_data = InterruptData::from_interrupted(
            thread_id,
            checkpoint_id,
            node_id.clone(),
            message.clone(),
            data.clone(),
            step,
        );
        let session_state =
            InterruptedSessionState::new(session_id.to_string(), interrupt_data, state);
        self.store.store(session_id, session_state).await;

        if let Some(emitter) = &self.emitter {
            emitter.emit_interrupt(&node_id, &message, data).await;
        }
    }

    /// Translates a raw graph interrupt into a [`Self::handle_interrupt`] call.
    ///
    /// * `"before"` / `"after"`: `interrupt_message` is treated as the node id.
    ///   The message becomes `Interrupt before '<node>'` or
    ///   `Interrupt after '<node>'`. `interrupt_data` is ignored and the
    ///   payload is `Value::Null`.
    /// * any other `interrupt_type`: the node id is `"dynamic"`,
    ///   `interrupt_message` is used verbatim, and `interrupt_data` defaults
    ///   to `Value::Null`.
    #[allow(dead_code, clippy::too_many_arguments)]
    pub async fn handle_graph_interrupt_direct(
        &self,
        session_id: &str,
        thread_id: String,
        checkpoint_id: String,
        interrupt_type: &str,
        interrupt_message: String,
        interrupt_data: Option<Value>,
        state: HashMap<String, Value>,
        step: usize,
    ) {
        let (node_id, message, data) = match interrupt_type {
            // Static interrupts carry the node name in `interrupt_message`.
            "before" | "after" => {
                let message = format!("Interrupt {} '{}'", interrupt_type, interrupt_message);
                (interrupt_message, message, Value::Null)
            }
            _ => (
                "dynamic".to_string(),
                interrupt_message,
                interrupt_data.unwrap_or(Value::Null),
            ),
        };
        self.handle_interrupt(
            session_id,
            thread_id,
            checkpoint_id,
            node_id,
            message,
            data,
            state,
            step,
        )
        .await;
    }

    /// Returns a clone of the interrupted state stored for `session_id`, if any.
    pub async fn get_interrupted_state(&self, session_id: &str) -> Option<InterruptedSessionState> {
        self.store.get(session_id).await
    }

    /// Returns `true` if an interrupt is currently stored for `session_id`.
    pub async fn is_interrupted(&self, session_id: &str) -> bool {
        self.store.is_interrupted(session_id).await
    }

    /// Removes and returns the interrupted state stored for `session_id`, if any.
    pub async fn clear_interrupted_state(
        &self,
        session_id: &str,
    ) -> Option<InterruptedSessionState> {
        self.store.remove(session_id).await
    }
}
pub fn serialize_interrupt_data(node_id: &str, message: &str, data: Value) -> Value {
serde_json::json!({
"nodeId": node_id,
"message": message,
"data": data,
"timestamp": current_timestamp_ms()
})
}
/// Converts a human response into a map of state updates.
///
/// A JSON object is flattened into its own key/value pairs. Any other JSON
/// value is wrapped under the single key `"response"`.
pub fn deserialize_interrupt_response(response: Value) -> HashMap<String, Value> {
    match response {
        Value::Object(fields) => fields.into_iter().collect(),
        other => std::iter::once(("response".to_string(), other)).collect(),
    }
}
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// Yields 0 if the system clock reads earlier than the epoch.
fn current_timestamp_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        // `as_millis` is u128; the narrowing cast only truncates ~585 million
        // years after the epoch.
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
// Process-wide interrupted-session store, created lazily on first access.
lazy_static::lazy_static! {
/// Global [`InterruptedSessionStore`] shared by every user of this static.
pub static ref INTERRUPTED_SESSIONS: Arc<InterruptedSessionStore> =
Arc::new(InterruptedSessionStore::new());
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// The high-risk "review" interrupt (step 5) shared by several tests.
    fn review_interrupt(thread_id: &str, checkpoint_id: &str) -> InterruptData {
        InterruptData::from_interrupted(
            thread_id.to_string(),
            checkpoint_id.to_string(),
            "review".to_string(),
            "Human approval required".to_string(),
            json!({"risk_level": "high"}),
            5,
        )
    }

    /// A single-entry graph state map.
    fn task_state() -> HashMap<String, Value> {
        let mut state = HashMap::new();
        state.insert("task".to_string(), json!("Delete files"));
        state
    }

    /// A step-1 "review" snapshot with a null payload and an empty state.
    fn bare_session(
        session_id: &str,
        thread_id: &str,
        checkpoint_id: &str,
        message: &str,
    ) -> InterruptedSessionState {
        let interrupt = InterruptData::from_interrupted(
            thread_id.to_string(),
            checkpoint_id.to_string(),
            "review".to_string(),
            message.to_string(),
            Value::Null,
            1,
        );
        InterruptedSessionState::new(session_id.to_string(), interrupt, HashMap::new())
    }

    #[test]
    fn test_interrupt_data_creation() {
        let data = review_interrupt("thread-123", "checkpoint-456");
        assert_eq!(data.node_id, "review");
        assert_eq!(data.message, "Human approval required");
        assert_eq!(data.thread_id, "thread-123");
        assert_eq!(data.checkpoint_id, "checkpoint-456");
        assert_eq!(data.step, 5);
    }

    #[test]
    fn test_interrupt_data_serialization() {
        let serialized = review_interrupt("thread-123", "checkpoint-456").to_json();
        let expected_fragments = [
            "\"node_id\":\"review\"",
            "\"message\":\"Human approval required\"",
            "\"risk_level\":\"high\"",
        ];
        for fragment in &expected_fragments {
            assert!(
                serialized.contains(*fragment),
                "missing {} in {}",
                fragment,
                serialized
            );
        }
    }

    #[tokio::test]
    async fn test_interrupted_session_store() {
        let store = InterruptedSessionStore::new();
        let snapshot = InterruptedSessionState::new(
            "session-789".to_string(),
            review_interrupt("thread-123", "checkpoint-456"),
            task_state(),
        );
        store.store("session-789", snapshot).await;
        assert!(store.is_interrupted("session-789").await);

        let retrieved = store
            .get("session-789")
            .await
            .expect("session should have been stored");
        assert_eq!(retrieved.node_id, "review");
        assert_eq!(retrieved.thread_id, "thread-123");

        assert!(store.remove("session-789").await.is_some());
        assert!(!store.is_interrupted("session-789").await);
    }

    #[test]
    fn test_serialize_interrupt_data() {
        let payload = serialize_interrupt_data(
            "review",
            "Human approval required",
            json!({"risk_level": "high"}),
        );
        assert_eq!(payload["nodeId"], "review");
        assert_eq!(payload["message"], "Human approval required");
        assert_eq!(payload["data"]["risk_level"], "high");
        assert!(payload["timestamp"].is_number());
    }

    #[test]
    fn test_deserialize_interrupt_response_object() {
        let updates = deserialize_interrupt_response(json!({
            "approved": true,
            "comment": "Looks good"
        }));
        assert_eq!(updates.get("approved"), Some(&json!(true)));
        assert_eq!(updates.get("comment"), Some(&json!("Looks good")));
    }

    #[test]
    fn test_deserialize_interrupt_response_non_object() {
        let updates = deserialize_interrupt_response(json!("approve"));
        assert_eq!(updates.get("response"), Some(&json!("approve")));
    }

    #[tokio::test]
    async fn test_graph_interrupt_handler() {
        let handler = GraphInterruptHandler::new(Arc::new(InterruptedSessionStore::new()));
        handler
            .handle_interrupt(
                "session-123",
                "thread-456".to_string(),
                "checkpoint-789".to_string(),
                "review".to_string(),
                "Human approval required".to_string(),
                json!({"risk_level": "high"}),
                task_state(),
                5,
            )
            .await;
        assert!(handler.is_interrupted("session-123").await);

        let snapshot = handler
            .get_interrupted_state("session-123")
            .await
            .expect("interrupt should have been recorded");
        assert_eq!(snapshot.node_id, "review");
        assert_eq!(snapshot.thread_id, "thread-456");
        assert_eq!(snapshot.checkpoint_id, "checkpoint-789");

        assert!(handler.clear_interrupted_state("session-123").await.is_some());
        assert!(!handler.is_interrupted("session-123").await);
    }

    #[tokio::test]
    async fn test_cleanup_old_sessions() {
        const HOUR_MS: u64 = 60 * 60 * 1000;
        let store = InterruptedSessionStore::new();

        let mut stale = bare_session("old-session", "thread-123", "checkpoint-456", "Old interrupt");
        stale.interrupted_at = current_timestamp_ms() - 2 * HOUR_MS;
        store.store("old-session", stale).await;

        let fresh = bare_session(
            "recent-session",
            "thread-789",
            "checkpoint-012",
            "Recent interrupt",
        );
        store.store("recent-session", fresh).await;

        store.cleanup_old(HOUR_MS).await;
        assert!(!store.is_interrupted("old-session").await);
        assert!(store.is_interrupted("recent-session").await);
    }
}