use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A pairing of a node name with a JSON argument payload.
///
/// NOTE(review): presumably a request to run the node named `node` with `arg` as its
/// input — TODO confirm against the executor that consumes it.
///
/// NOTE: inside this module the name `Send` shadows the prelude's `std::marker::Send`
/// trait (local items take precedence over prelude names), so a `T: Send` bound written
/// here would name this struct. Spell the marker trait as `std::marker::Send` if needed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Send {
/// Name of the node this value is addressed to.
pub node: String,
/// Arbitrary JSON payload carried alongside `node`.
pub arg: serde_json::Value,
}
impl Send {
pub fn new(node: impl Into<String>, arg: serde_json::Value) -> Self {
Self {
node: node.into(),
arg,
}
}
}
/// A bundle of optional directives, normally assembled with the `with_*` builder methods.
///
/// Every field starts as `None` (derived `Default`); serde writes unset fields as `null`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Command {
/// Named JSON updates; stays `None` until the first `with_update` call.
/// NOTE(review): keys are presumably state-field names — confirm with the consumer.
pub update: Option<HashMap<String, serde_json::Value>>,
/// Node names to go to, in the order supplied — presumably the next nodes to run (TODO confirm).
pub goto: Option<Vec<String>>,
/// Optional value to resume with — NOTE(review): presumably answers a pending `Interrupt`; confirm.
pub resume: Option<serde_json::Value>,
}
impl Command {
    /// Creates a command with every directive unset; equivalent to `Command::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds the update entry `key` -> `value`, creating the `update` map on first use.
    ///
    /// Calling this again with the same key overwrites the earlier value.
    #[must_use]
    pub fn with_update(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.update
            .get_or_insert_with(HashMap::new)
            .insert(key.into(), value);
        self
    }

    /// Sets the `goto` targets, replacing any previously set list.
    ///
    /// Accepts any iterable of string-like items (a `Vec<&str>` as before, but also
    /// arrays, `Vec<String>`, or an iterator chain); iteration order is preserved.
    #[must_use]
    pub fn with_goto<I, S>(mut self, nodes: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.goto = Some(nodes.into_iter().map(Into::into).collect());
        self
    }

    /// Sets the `resume` value, replacing any previous one.
    #[must_use]
    pub fn with_resume(mut self, value: serde_json::Value) -> Self {
        self.resume = Some(value);
        self
    }
}
/// Selects which kind of output a stream should produce.
///
/// Serialized as the lowercase variant name (`"values"`, `"updates"`, `"checkpoints"`, ...)
/// because of `rename_all = "lowercase"`. The default is `Values` (see the `Default` impl).
///
/// NOTE(review): presumably each mode selects a family of `StreamEvent` variants
/// (`Checkpoints` -> `Checkpoint`, `Tasks` -> `TaskStart`/`TaskEnd`, `Messages` -> `Message`);
/// nothing here enforces that mapping — confirm in the emitter.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum StreamMode {
Values,
Updates,
Checkpoints,
Tasks,
Debug,
Messages,
Custom,
}
impl Default for StreamMode {
fn default() -> Self {
StreamMode::Values
}
}
/// One event emitted on a stream.
///
/// Uses serde's adjacently tagged representation (`tag = "event", content = "data"`), so an
/// event serializes as `{"event": "<VariantName>", "data": { ...variant fields... }}`, e.g.
/// `{"event": "TaskStart", "data": {"task_id": "t1", "node": "n1"}}`. Variants that have their
/// own `data` field therefore nest it: `{"event": "Values", "data": {"ns": [], "data": ..., ...}}`.
///
/// Variant names are emitted verbatim (PascalCase). Unlike `StreamMode` there is no
/// `rename_all` here, and some names differ from their mode counterpart
/// (`Checkpoint` vs `StreamMode::Checkpoints`, `Message` vs `StreamMode::Messages`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "event", content = "data")]
pub enum StreamEvent {
/// A values payload together with any interrupts attached to it.
Values {
/// Namespace path of the emitter — NOTE(review): presumably empty at the root; confirm.
ns: Vec<String>,
data: serde_json::Value,
interrupts: Vec<Interrupt>,
},
/// An update payload attributed to the node named `node`.
Updates {
ns: Vec<String>,
data: serde_json::Value,
node: String,
},
/// Checkpoint notice; `step` is a step counter (NOTE(review): base not visible here — confirm).
Checkpoint {
ns: Vec<String>,
checkpoint_id: String,
step: usize,
},
/// Presumably emitted when task `task_id` on `node` begins.
TaskStart {
task_id: String,
node: String,
},
/// Presumably emitted when task `task_id` on `node` finishes, carrying its `result`.
TaskEnd {
task_id: String,
node: String,
result: serde_json::Value,
},
/// Free-form diagnostic message with arbitrary JSON context.
Debug {
message: String,
context: HashMap<String, serde_json::Value>,
},
/// Text content with arbitrary JSON metadata.
Message {
content: String,
metadata: HashMap<String, serde_json::Value>,
},
/// Caller-defined event: `event_type` is a free-form label and `data` its payload.
Custom {
event_type: String,
data: serde_json::Value,
},
}
/// An interrupt record: a JSON `value`, when it applies, and optionally the node it belongs to.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Interrupt {
/// Arbitrary JSON payload attached to the interrupt.
pub value: serde_json::Value,
/// When the interrupt applies; serialized as `"before"`, `"during"` or `"after"`.
pub when: InterruptType,
/// Associated node, if any; `None` serializes as JSON `null`.
pub node: Option<String>,
}
/// When an interrupt applies — NOTE(review): presumably relative to a node's execution; confirm.
///
/// Serialized lowercase (`"before"`, `"during"`, `"after"`) via `rename_all = "lowercase"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum InterruptType {
Before,
During,
After,
}
/// Exponential-backoff retry configuration.
///
/// [`RetryPolicy::delay_for_attempt`] computes
/// `initial_delay_ms * backoff_multiplier^attempt`, capped at `max_delay_ms`.
/// `RetryPolicy::default()` is 3 attempts, 100 ms initial delay, 10 000 ms cap, x2 multiplier.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryPolicy {
/// Maximum number of attempts. NOTE(review): not read in this file — confirm whether the
/// executor counts the initial try as an attempt.
pub max_attempts: usize,
/// Delay for attempt 0, in milliseconds.
pub initial_delay_ms: u64,
/// Upper bound on any computed delay, in milliseconds.
pub max_delay_ms: u64,
/// Per-attempt growth factor; values > 1 lengthen the delay, values < 1 shorten it.
pub backoff_multiplier: f64,
}
impl Default for RetryPolicy {
fn default() -> Self {
Self {
max_attempts: 3,
initial_delay_ms: 100,
max_delay_ms: 10_000,
backoff_multiplier: 2.0,
}
}
}
impl RetryPolicy {
    /// Creates a policy allowing `max_attempts` attempts, with every other field taken
    /// from [`RetryPolicy::default`].
    #[must_use]
    pub fn new(max_attempts: usize) -> Self {
        Self {
            max_attempts,
            ..Self::default()
        }
    }

    /// Returns the backoff delay, in milliseconds, for the zero-based `attempt`.
    ///
    /// Computes `initial_delay_ms * backoff_multiplier^attempt`, capped at `max_delay_ms`.
    /// With the default policy, attempts 0, 1 and 2 give 100, 200 and 400 ms.
    ///
    /// The result never exceeds `max_delay_ms`, an `initial_delay_ms` of 0 always yields 0,
    /// and the computation cannot panic: float overflow is capped at `max_delay_ms` and a
    /// negative product (negative multiplier) saturates to 0.
    #[must_use]
    pub fn delay_for_attempt(&self, attempt: usize) -> u64 {
        // Without this, a huge attempt count makes `powi` overflow to infinity, and
        // `0.0 * inf` is NaN, which `f64::min` would turn into `max_delay_ms`.
        if self.initial_delay_ms == 0 {
            return 0;
        }
        // A plain `attempt as i32` wraps above `i32::MAX` (2^31 becomes `i32::MIN`), flipping
        // the exponent negative and collapsing the delay to ~0 ms. Clamp first so the sign
        // is kept; the clamp also makes the cast lossless.
        let exponent = attempt.min(i32::MAX as usize) as i32;
        let delay = self.initial_delay_ms as f64 * self.backoff_multiplier.powi(exponent);
        // `f64 as u64` saturates (negatives -> 0, +inf -> u64::MAX). The final integer `min`
        // covers `max_delay_ms` values above 2^53 that the u64 -> f64 conversion rounds up.
        (delay.min(self.max_delay_ms as f64) as u64).min(self.max_delay_ms)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_send_creation() {
        let send = Send::new("test_node", json!({"key": "value"}));
        assert_eq!(send.node, "test_node");
        assert!(send.arg.is_object());
    }

    #[test]
    fn test_command_builder() {
        let cmd = Command::new()
            .with_update("key1", json!("val1"))
            .with_update("key2", json!(42))
            .with_goto(vec!["node1", "node2"]);
        assert_eq!(cmd.update.as_ref().unwrap().len(), 2);
        assert_eq!(cmd.goto.as_ref().unwrap().len(), 2);
    }

    #[test]
    fn test_command_defaults_and_overwrites() {
        let empty = Command::new();
        assert!(empty.update.is_none());
        assert!(empty.goto.is_none());
        assert!(empty.resume.is_none());

        // Later updates to the same key win; a later `with_goto` replaces the list, in order.
        let cmd = Command::new()
            .with_update("k", json!(1))
            .with_update("k", json!(2))
            .with_goto(vec!["ignored"])
            .with_goto(vec!["b", "a"])
            .with_resume(json!("yes"));
        let update = cmd.update.as_ref().unwrap();
        assert_eq!(update.len(), 1);
        assert_eq!(update["k"], json!(2));
        assert_eq!(cmd.goto, Some(vec!["b".to_string(), "a".to_string()]));
        assert_eq!(cmd.resume, Some(json!("yes")));
    }

    #[test]
    fn test_retry_policy_delay() {
        let policy = RetryPolicy::default();
        assert_eq!(policy.delay_for_attempt(0), 100);
        assert_eq!(policy.delay_for_attempt(1), 200);
        assert_eq!(policy.delay_for_attempt(2), 400);
        assert_eq!(policy.delay_for_attempt(10), 10_000);
    }

    #[test]
    fn test_retry_policy_new_keeps_other_defaults() {
        let policy = RetryPolicy::new(5);
        let defaults = RetryPolicy::default();
        assert_eq!(policy.max_attempts, 5);
        assert_eq!(policy.initial_delay_ms, defaults.initial_delay_ms);
        assert_eq!(policy.max_delay_ms, defaults.max_delay_ms);
        assert_eq!(
            policy.backoff_multiplier.to_bits(),
            defaults.backoff_multiplier.to_bits()
        );
    }

    #[test]
    fn test_stream_mode_serialization() {
        let mode = StreamMode::Values;
        let json = serde_json::to_string(&mode).unwrap();
        assert_eq!(json, "\"values\"");
        let deserialized: StreamMode = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, StreamMode::Values);
    }

    #[test]
    fn test_all_stream_modes_serialize_lowercase() {
        let cases = [
            (StreamMode::Values, "values"),
            (StreamMode::Updates, "updates"),
            (StreamMode::Checkpoints, "checkpoints"),
            (StreamMode::Tasks, "tasks"),
            (StreamMode::Debug, "debug"),
            (StreamMode::Messages, "messages"),
            (StreamMode::Custom, "custom"),
        ];
        for &(mode, name) in &cases {
            assert_eq!(serde_json::to_value(mode).unwrap(), json!(name));
            assert_eq!(serde_json::from_value::<StreamMode>(json!(name)).unwrap(), mode);
        }
        assert_eq!(StreamMode::default(), StreamMode::Values);
    }

    #[test]
    fn test_interrupt_serialization() {
        let interrupt = Interrupt {
            value: json!({"q": 1}),
            when: InterruptType::During,
            node: None,
        };
        assert_eq!(
            serde_json::to_value(&interrupt).unwrap(),
            json!({"value": {"q": 1}, "when": "during", "node": null})
        );
        for &(when, name) in &[
            (InterruptType::Before, "before"),
            (InterruptType::During, "during"),
            (InterruptType::After, "after"),
        ] {
            assert_eq!(serde_json::to_value(when).unwrap(), json!(name));
            assert_eq!(serde_json::from_value::<InterruptType>(json!(name)).unwrap(), when);
        }
    }

    #[test]
    fn test_stream_event_is_adjacently_tagged() {
        let event = StreamEvent::TaskStart {
            task_id: "t1".to_string(),
            node: "n1".to_string(),
        };
        let value = serde_json::to_value(&event).unwrap();
        assert_eq!(
            value,
            json!({"event": "TaskStart", "data": {"task_id": "t1", "node": "n1"}})
        );
        match serde_json::from_value::<StreamEvent>(value).unwrap() {
            StreamEvent::TaskStart { task_id, node } => {
                assert_eq!(task_id, "t1");
                assert_eq!(node, "n1");
            }
            other => panic!("unexpected variant: {:?}", other),
        }

        // A variant's own `data` field nests inside the adjacent `data` content.
        let values = StreamEvent::Values {
            ns: vec![],
            data: json!(1),
            interrupts: vec![],
        };
        assert_eq!(
            serde_json::to_value(&values).unwrap(),
            json!({"event": "Values", "data": {"ns": [], "data": 1, "interrupts": []}})
        );
    }
}