use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use uuid::Uuid;
/// Kind of operation an [`Action`] performs (stored in `Action::action_type`, JSON key `"type"`).
///
/// Serialized in `snake_case`: `"tool_call"`, `"state_write"`, `"state_read"`, `"assertion"`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ActionType {
    /// `"tool_call"`. NOTE(review): presumably paired with `Action::tool` naming the tool to
    /// invoke — confirm against the executor.
    ToolCall,
    /// `"state_write"`.
    StateWrite,
    /// `"state_read"`.
    StateRead,
    /// `"assertion"`.
    Assertion,
}
/// Per-[`Action`] policy for handling a failed action (see `Action::failure_behavior`).
///
/// Serialized in `snake_case` (`"abort"`, `"retry"`, `"skip"`). `Abort` is the [`Default`],
/// which is also what a missing `failure_behavior` field deserializes to. The code that acts
/// on this policy is not in this file.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum FailureBehavior {
    /// `"abort"` — the default.
    #[default]
    Abort,
    /// `"retry"`.
    Retry,
    /// `"skip"`.
    Skip,
}
/// Status recorded for an action in [`ActionResult::status`].
///
/// Serialized in `snake_case` (`"proposed"`, `"validated"`, `"rejected"`, `"executing"`,
/// `"succeeded"`, `"failed"`, `"skipped"`). Derives `Eq + Hash` so it can key the map returned
/// by `ProposalResult::summary`. `ProposalResult::all_succeeded` treats only `Succeeded` as
/// success.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ActionStatus {
    Proposed,
    Validated,
    Rejected,
    Executing,
    Succeeded,
    Failed,
    Skipped,
}
/// A condition attached to an [`Action`] via `Action::preconditions`.
///
/// Only `key` is required when deserializing; the remaining fields fall back to defaults.
/// NOTE(review): which `operator` strings are accepted besides `"eq"`, and how they are
/// evaluated, is decided outside this file — confirm against the evaluator.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Precondition {
    /// Key the condition refers to (presumably a state key — confirm).
    pub key: String,
    /// Comparison operator; defaults to `"eq"` when absent.
    #[serde(default = "default_operator")]
    pub operator: String,
    /// Operand for the comparison; defaults to JSON `null` when absent.
    #[serde(default)]
    pub value: Value,
    /// Optional description; defaults to an empty string.
    #[serde(default)]
    pub description: String,
}
/// Serde default for `Precondition::operator`: the string `"eq"`.
fn default_operator() -> String {
    String::from("eq")
}
/// Generates a short random identifier: the first 12 hex digits of a fresh v4 UUID.
///
/// `Uuid::simple()` renders the UUID as 32 ASCII lowercase hex characters with no hyphens, so
/// byte 12 is always a valid char boundary and `truncate` cannot panic. Those 12 digits come
/// before the UUID version nibble, so all 48 bits are random. Truncating the owned string in
/// place reuses its buffer rather than slicing it and allocating a second `String`.
///
/// Used as the serde default for `Action::id` and `ActionProposal::id`.
fn short_id() -> String {
    let mut id = Uuid::new_v4().simple().to_string();
    id.truncate(12);
    id
}
/// A single entry in an [`ActionProposal`]'s `actions` list.
///
/// Only `type` is required when deserializing; every other field has a default, so minimal
/// JSON such as `{"type": "state_read"}` is accepted.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Action {
    /// Identifier; when absent, a fresh 12-hex-digit id is generated by `short_id`.
    #[serde(default = "short_id")]
    pub id: String,
    /// Kind of action; serialized under the JSON key `"type"`.
    #[serde(rename = "type")]
    pub action_type: ActionType,
    /// Tool name, if any; omitted from serialized JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool: Option<String>,
    /// Parameters keyed by name; defaults to empty.
    #[serde(default)]
    pub parameters: HashMap<String, Value>,
    /// Conditions attached to this action; defaults to empty.
    #[serde(default)]
    pub preconditions: Vec<Precondition>,
    /// Expected effects keyed by name; defaults to empty.
    #[serde(default)]
    pub expected_effects: HashMap<String, Value>,
    /// Keys this action declares a dependency on; defaults to empty.
    #[serde(default)]
    pub state_dependencies: Vec<String>,
    /// Idempotency flag; defaults to `false`.
    #[serde(default)]
    pub idempotent: bool,
    /// Retry budget; defaults to 3 (`default_max_retries`).
    #[serde(default = "default_max_retries")]
    pub max_retries: u32,
    /// Failure policy; defaults to [`FailureBehavior::Abort`].
    #[serde(default)]
    pub failure_behavior: FailureBehavior,
    /// Optional timeout (presumably milliseconds, per the name); omitted from JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub timeout_ms: Option<u64>,
    /// Arbitrary extra data; defaults to empty.
    #[serde(default)]
    pub metadata: HashMap<String, Value>,
}
/// Serde default for `Action::max_retries`.
fn default_max_retries() -> u32 {
    const DEFAULT_MAX_RETRIES: u32 = 3;
    DEFAULT_MAX_RETRIES
}
/// A batch of [`Action`]s proposed together.
///
/// Only `actions` is required when deserializing. The defaults are:
/// - `id`: a fresh `short_id`;
/// - `source`: `"unknown"`;
/// - `timestamp`: the current UTC time at deserialization;
/// - `context`: an empty map.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ActionProposal {
    /// Proposal identifier; generated by `short_id` when absent.
    #[serde(default = "short_id")]
    pub id: String,
    /// Originator label; defaults to `"unknown"`.
    #[serde(default = "default_source")]
    pub source: String,
    /// The proposed actions (required).
    pub actions: Vec<Action>,
    /// Creation time; when absent, set to `Utc::now()` at deserialization.
    #[serde(default = "Utc::now")]
    pub timestamp: DateTime<Utc>,
    /// Arbitrary context data; defaults to empty.
    #[serde(default)]
    pub context: HashMap<String, Value>,
}
/// Serde default for `ActionProposal::source` when the payload does not name one.
fn default_source() -> String {
    "unknown".to_owned()
}
/// Outcome recorded for one action.
///
/// `action_id` and `status` are required when deserializing. `output`, `error` and
/// `duration_ms` are omitted from serialized JSON when `None`. `timestamp` defaults to
/// `Utc::now()` at deserialization.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ActionResult {
    /// Id of the action this result belongs to (presumably matches `Action::id` — confirm).
    pub action_id: String,
    /// Final or current status of the action.
    pub status: ActionStatus,
    /// Optional output value.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output: Option<Value>,
    /// Optional error message.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
    /// State changes keyed by name; defaults to empty.
    #[serde(default)]
    pub state_changes: HashMap<String, Value>,
    /// Optional duration (presumably milliseconds, per the name).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub duration_ms: Option<f64>,
    /// When the result was recorded; defaults to `Utc::now()` at deserialization.
    #[serde(default = "Utc::now")]
    pub timestamp: DateTime<Utc>,
}
/// Rate limit declared for a tool in `ToolSchema::rate_limit`. Both fields are required.
///
/// NOTE(review): presumably means "at most `max_calls` calls per `interval_secs` seconds";
/// the enforcement logic is not in this file — confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolRateLimit {
    pub max_calls: u32,
    pub interval_secs: f64,
}
/// Declaration of a tool: its name plus optional metadata.
///
/// Only `name` is required when deserializing. `parameters` defaults to an empty JSON object
/// (`{}`). The other fields default to empty, `false` or `None`, and the `Option` fields are
/// omitted from serialized JSON when `None`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ToolSchema {
    /// Tool name (required).
    pub name: String,
    /// Optional description; defaults to an empty string.
    #[serde(default)]
    pub description: String,
    /// Parameter description as a JSON value (presumably a JSON Schema — confirm); defaults to `{}`.
    #[serde(default = "default_parameters_schema")]
    pub parameters: Value,
    /// Optional description of the return value.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub returns: Option<Value>,
    /// Idempotency flag; defaults to `false`.
    #[serde(default)]
    pub idempotent: bool,
    /// Optional cache TTL in seconds (per the name).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cache_ttl_secs: Option<u64>,
    /// Optional rate limit.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub rate_limit: Option<ToolRateLimit>,
}
fn default_parameters_schema() -> Value {
Value::Object(Default::default())
}
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct CostSummary {
pub tool_calls: u32,
pub actions_executed: u32,
pub actions_skipped: u32,
pub total_duration_ms: f64,
pub retries: u32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostTarget {
pub target_tool_calls: u32,
pub target_duration_ms: f64,
pub target_actions: u32,
pub cost_weight: f64,
}
impl Default for CostTarget {
fn default() -> Self {
Self {
target_tool_calls: 5,
target_duration_ms: 5000.0,
target_actions: 10,
cost_weight: 0.2,
}
}
}
/// Per-action results for a proposal (identified by `proposal_id`) plus an aggregate
/// [`CostSummary`].
///
/// Only `proposal_id` is required when deserializing. `results` defaults to empty and `cost`
/// defaults to an all-zero [`CostSummary`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ProposalResult {
    pub proposal_id: String,
    #[serde(default)]
    pub results: Vec<ActionResult>,
    #[serde(default)]
    pub cost: CostSummary,
}
impl ProposalResult {
    /// Returns `true` when every recorded result has status [`ActionStatus::Succeeded`].
    ///
    /// An empty `results` list is vacuously successful and yields `true`.
    pub fn all_succeeded(&self) -> bool {
        !self
            .results
            .iter()
            .any(|outcome| outcome.status != ActionStatus::Succeeded)
    }

    /// Counts how many results carry each [`ActionStatus`].
    ///
    /// Statuses that never occur are absent from the map rather than mapped to zero.
    pub fn summary(&self) -> HashMap<ActionStatus, usize> {
        self.results
            .iter()
            .fold(HashMap::new(), |mut tally, outcome| {
                *tally.entry(outcome.status.clone()).or_default() += 1;
                tally
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Builds a minimal [`ActionResult`] for `action_id` with the given status and no
    /// output, error, state changes, or duration.
    fn make_result(action_id: &str, status: ActionStatus) -> ActionResult {
        ActionResult {
            action_id: action_id.to_string(),
            status,
            output: None,
            error: None,
            state_changes: HashMap::new(),
            duration_ms: None,
            timestamp: Utc::now(),
        }
    }

    #[test]
    fn action_type_serializes_snake_case() {
        assert_eq!(
            serde_json::to_string(&ActionType::ToolCall).unwrap(),
            "\"tool_call\""
        );
        assert_eq!(
            serde_json::to_string(&ActionType::StateWrite).unwrap(),
            "\"state_write\""
        );
    }

    #[test]
    fn failure_behavior_serializes_snake_case() {
        assert_eq!(
            serde_json::to_string(&FailureBehavior::Abort).unwrap(),
            "\"abort\""
        );
        assert_eq!(
            serde_json::to_string(&FailureBehavior::Retry).unwrap(),
            "\"retry\""
        );
    }

    #[test]
    fn action_roundtrip_json() {
        let action = Action {
            id: "abc123".to_string(),
            action_type: ActionType::ToolCall,
            tool: Some("add".to_string()),
            parameters: [("a".to_string(), Value::from(1)), ("b".to_string(), Value::from(2))]
                .into(),
            preconditions: vec![Precondition {
                key: "auth".to_string(),
                operator: "eq".to_string(),
                value: Value::Bool(true),
                description: String::new(),
            }],
            expected_effects: [("sum".to_string(), Value::from(3))].into(),
            state_dependencies: vec!["auth".to_string()],
            idempotent: true,
            max_retries: 3,
            failure_behavior: FailureBehavior::Retry,
            timeout_ms: Some(5000),
            metadata: HashMap::new(),
        };
        let json = serde_json::to_string_pretty(&action).unwrap();
        let roundtripped: Action = serde_json::from_str(&json).unwrap();
        // Compare the whole struct (it derives `PartialEq`) so a regression in any field —
        // parameters, preconditions, expected effects, retries, metadata — is caught, not
        // just the handful of fields previously checked.
        assert_eq!(action, roundtripped);
    }

    #[test]
    fn proposal_roundtrip_json() {
        let proposal = ActionProposal {
            id: "prop1".to_string(),
            source: "test".to_string(),
            actions: vec![Action {
                id: "a1".to_string(),
                action_type: ActionType::StateWrite,
                tool: None,
                parameters: [
                    ("key".to_string(), Value::from("x")),
                    ("value".to_string(), Value::from(42)),
                ]
                .into(),
                preconditions: vec![],
                expected_effects: HashMap::new(),
                state_dependencies: vec![],
                idempotent: false,
                max_retries: 3,
                failure_behavior: FailureBehavior::Abort,
                timeout_ms: None,
                metadata: HashMap::new(),
            }],
            timestamp: Utc::now(),
            context: HashMap::new(),
        };
        let json = serde_json::to_string(&proposal).unwrap();
        let roundtripped: ActionProposal = serde_json::from_str(&json).unwrap();
        assert_eq!(proposal.id, roundtripped.id);
        assert_eq!(proposal.source, roundtripped.source);
        // Full element-wise comparison rather than only checking the length.
        assert_eq!(proposal.actions, roundtripped.actions);
    }

    #[test]
    fn action_deserializes_minimal_json_with_defaults() {
        let action: Action = serde_json::from_str(r#"{"type": "state_read"}"#).unwrap();
        assert_eq!(action.id.len(), 12);
        assert!(action.id.chars().all(|c| c.is_ascii_hexdigit()));
        assert_eq!(action.action_type, ActionType::StateRead);
        assert!(action.tool.is_none());
        assert!(action.parameters.is_empty());
        assert!(action.preconditions.is_empty());
        assert!(action.expected_effects.is_empty());
        assert!(action.state_dependencies.is_empty());
        assert!(!action.idempotent);
        assert_eq!(action.max_retries, 3);
        assert_eq!(action.failure_behavior, FailureBehavior::Abort);
        assert!(action.timeout_ms.is_none());
        assert!(action.metadata.is_empty());
    }

    #[test]
    fn precondition_deserializes_with_defaults() {
        let precondition: Precondition = serde_json::from_str(r#"{"key": "auth"}"#).unwrap();
        assert_eq!(precondition.key, "auth");
        assert_eq!(precondition.operator, "eq");
        assert_eq!(precondition.value, Value::Null);
        assert_eq!(precondition.description, "");
    }

    #[test]
    fn tool_schema_deserializes_with_defaults() {
        let schema: ToolSchema = serde_json::from_str(r#"{"name": "add"}"#).unwrap();
        assert_eq!(schema.name, "add");
        assert_eq!(schema.description, "");
        assert_eq!(schema.parameters, Value::Object(Default::default()));
        assert!(schema.returns.is_none());
        assert!(!schema.idempotent);
        assert!(schema.cache_ttl_secs.is_none());
        assert!(schema.rate_limit.is_none());
    }

    #[test]
    fn action_result_serializes() {
        let result = ActionResult {
            action_id: "a1".to_string(),
            status: ActionStatus::Succeeded,
            output: Some(Value::from(42)),
            error: None,
            state_changes: HashMap::new(),
            duration_ms: Some(1.5),
            timestamp: Utc::now(),
        };
        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("\"succeeded\""));
    }

    #[test]
    fn proposal_result_all_succeeded() {
        let pr = ProposalResult {
            proposal_id: "p1".to_string(),
            results: vec![
                make_result("a1", ActionStatus::Succeeded),
                make_result("a2", ActionStatus::Succeeded),
            ],
            cost: CostSummary::default(),
        };
        assert!(pr.all_succeeded());
    }

    #[test]
    fn proposal_result_not_all_succeeded() {
        let pr = ProposalResult {
            proposal_id: "p1".to_string(),
            results: vec![
                make_result("a1", ActionStatus::Succeeded),
                ActionResult {
                    error: Some("boom".to_string()),
                    ..make_result("a2", ActionStatus::Failed)
                },
            ],
            cost: CostSummary::default(),
        };
        assert!(!pr.all_succeeded());
    }

    #[test]
    fn proposal_result_with_no_results_is_vacuously_successful() {
        let pr = ProposalResult {
            proposal_id: "p0".to_string(),
            results: vec![],
            cost: CostSummary::default(),
        };
        assert!(pr.all_succeeded());
        assert!(pr.summary().is_empty());
    }

    #[test]
    fn proposal_result_summary_counts_each_status() {
        let pr = ProposalResult {
            proposal_id: "p1".to_string(),
            results: vec![
                make_result("a1", ActionStatus::Succeeded),
                make_result("a2", ActionStatus::Failed),
                make_result("a3", ActionStatus::Succeeded),
            ],
            cost: CostSummary::default(),
        };
        let counts = pr.summary();
        assert_eq!(counts.get(&ActionStatus::Succeeded), Some(&2));
        assert_eq!(counts.get(&ActionStatus::Failed), Some(&1));
        assert_eq!(counts.get(&ActionStatus::Skipped), None);
        assert_eq!(counts.len(), 2);
    }

    #[test]
    fn cost_summary_default_is_zero() {
        let cost = CostSummary::default();
        assert_eq!(cost.tool_calls, 0);
        assert_eq!(cost.actions_executed, 0);
        assert_eq!(cost.actions_skipped, 0);
        assert_eq!(cost.total_duration_ms, 0.0);
        assert_eq!(cost.retries, 0);
    }

    #[test]
    fn cost_summary_serde_roundtrip() {
        let cost = CostSummary {
            tool_calls: 3,
            actions_executed: 5,
            actions_skipped: 1,
            total_duration_ms: 42.5,
            retries: 2,
        };
        let json = serde_json::to_string(&cost).unwrap();
        let roundtripped: CostSummary = serde_json::from_str(&json).unwrap();
        assert_eq!(cost, roundtripped);
    }

    #[test]
    fn cost_target_default_values() {
        let target = CostTarget::default();
        assert_eq!(target.target_tool_calls, 5);
        assert_eq!(target.target_duration_ms, 5000.0);
        assert_eq!(target.target_actions, 10);
        assert_eq!(target.cost_weight, 0.2);
    }

    #[test]
    fn proposal_result_deserializes_without_cost() {
        let json = r#"{"proposal_id": "p1", "results": []}"#;
        let pr: ProposalResult = serde_json::from_str(json).unwrap();
        assert_eq!(pr.cost, CostSummary::default());
    }

    #[test]
    fn deserialize_from_python_compatible_json() {
        let json = r#"{
            "id": "test123",
            "type": "tool_call",
            "tool": "add",
            "parameters": {"a": 1, "b": 2},
            "preconditions": [],
            "expected_effects": {"sum": 3},
            "state_dependencies": [],
            "idempotent": true,
            "max_retries": 3,
            "failure_behavior": "retry",
            "timeout_ms": 5000,
            "metadata": {}
        }"#;
        let action: Action = serde_json::from_str(json).unwrap();
        assert_eq!(action.id, "test123");
        assert_eq!(action.action_type, ActionType::ToolCall);
        assert_eq!(action.tool, Some("add".to_string()));
        assert!(action.idempotent);
        assert_eq!(action.failure_behavior, FailureBehavior::Retry);
        assert_eq!(action.timeout_ms, Some(5000));
    }
}