1use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::collections::HashMap;
7use uuid::Uuid;
8
9#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
11#[serde(rename_all = "snake_case")]
12pub enum ActionType {
13 ToolCall,
14 StateWrite,
15 StateRead,
16 Assertion,
17}
18
19#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
21#[serde(rename_all = "snake_case")]
22pub enum FailureBehavior {
23 #[default]
24 Abort,
25 Retry,
26 Skip,
27}
28
29#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
31#[serde(rename_all = "snake_case")]
32pub enum ActionStatus {
33 Proposed,
34 Validated,
35 Rejected,
36 Executing,
37 Succeeded,
38 Failed,
39 Skipped,
40}
41
42#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
46pub struct Precondition {
47 pub key: String,
48 #[serde(default = "default_operator")]
50 pub operator: String,
51 #[serde(default)]
52 pub value: Value,
53 #[serde(default)]
54 pub description: String,
55}
56
/// Serde default for [`Precondition::operator`]: the equality operator `"eq"`.
fn default_operator() -> String {
    String::from("eq")
}
60
61fn short_id() -> String {
63 Uuid::new_v4().simple().to_string()[..12].to_string()
64}
65
66#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
71pub struct Action {
72 #[serde(default = "short_id")]
73 pub id: String,
74
75 #[serde(rename = "type")]
76 pub action_type: ActionType,
77
78 #[serde(default, skip_serializing_if = "Option::is_none")]
79 pub tool: Option<String>,
80
81 #[serde(default)]
82 pub parameters: HashMap<String, Value>,
83
84 #[serde(default)]
85 pub preconditions: Vec<Precondition>,
86
87 #[serde(default)]
88 pub expected_effects: HashMap<String, Value>,
89
90 #[serde(default)]
91 pub state_dependencies: Vec<String>,
92
93 #[serde(default)]
94 pub idempotent: bool,
95
96 #[serde(default = "default_max_retries")]
97 pub max_retries: u32,
98
99 #[serde(default)]
100 pub failure_behavior: FailureBehavior,
101
102 #[serde(default, skip_serializing_if = "Option::is_none")]
103 pub timeout_ms: Option<u64>,
104
105 #[serde(default)]
106 pub metadata: HashMap<String, Value>,
107}
108
/// Serde default for [`Action::max_retries`].
fn default_max_retries() -> u32 {
    const DEFAULT_MAX_RETRIES: u32 = 3;
    DEFAULT_MAX_RETRIES
}
112
113#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
115pub struct ActionProposal {
116 #[serde(default = "short_id")]
117 pub id: String,
118
119 #[serde(default = "default_source")]
120 pub source: String,
121
122 pub actions: Vec<Action>,
123
124 #[serde(default = "Utc::now")]
125 pub timestamp: DateTime<Utc>,
126
127 #[serde(default)]
128 pub context: HashMap<String, Value>,
129}
130
/// Serde default for [`ActionProposal::source`] when no producer is named.
fn default_source() -> String {
    "unknown".to_owned()
}
134
135#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
137pub struct ActionResult {
138 pub action_id: String,
139 pub status: ActionStatus,
140
141 #[serde(default, skip_serializing_if = "Option::is_none")]
142 pub output: Option<Value>,
143
144 #[serde(default, skip_serializing_if = "Option::is_none")]
145 pub error: Option<String>,
146
147 #[serde(default)]
148 pub state_changes: HashMap<String, Value>,
149
150 #[serde(default, skip_serializing_if = "Option::is_none")]
151 pub duration_ms: Option<f64>,
152
153 #[serde(default = "Utc::now")]
154 pub timestamp: DateTime<Utc>,
155}
156
157#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
159pub struct ToolRateLimit {
160 pub max_calls: u32,
161 pub interval_secs: f64,
162}
163
164#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
169pub struct ToolSchema {
170 pub name: String,
171 #[serde(default)]
172 pub description: String,
173 #[serde(default = "default_parameters_schema")]
175 pub parameters: Value,
176 #[serde(default, skip_serializing_if = "Option::is_none")]
178 pub returns: Option<Value>,
179 #[serde(default)]
181 pub idempotent: bool,
182 #[serde(default, skip_serializing_if = "Option::is_none")]
184 pub cache_ttl_secs: Option<u64>,
185 #[serde(default, skip_serializing_if = "Option::is_none")]
187 pub rate_limit: Option<ToolRateLimit>,
188}
189
190fn default_parameters_schema() -> Value {
191 Value::Object(Default::default())
192}
193
194#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
196pub struct CostSummary {
197 pub tool_calls: u32,
198 pub actions_executed: u32,
199 pub actions_skipped: u32,
200 pub total_duration_ms: f64,
201 pub retries: u32,
202}
203
204#[derive(Debug, Clone, Serialize, Deserialize)]
212pub struct CostTarget {
213 pub target_tool_calls: u32,
215 pub target_duration_ms: f64,
217 pub target_actions: u32,
219 pub cost_weight: f64,
221}
222
223impl Default for CostTarget {
224 fn default() -> Self {
225 Self {
226 target_tool_calls: 5,
227 target_duration_ms: 5000.0,
228 target_actions: 10,
229 cost_weight: 0.2,
230 }
231 }
232}
233
234#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
236pub struct ProposalResult {
237 pub proposal_id: String,
238
239 #[serde(default)]
240 pub results: Vec<ActionResult>,
241
242 #[serde(default)]
243 pub cost: CostSummary,
244}
245
246impl ProposalResult {
247 pub fn all_succeeded(&self) -> bool {
248 self.results
249 .iter()
250 .all(|r| r.status == ActionStatus::Succeeded)
251 }
252
253 pub fn summary(&self) -> HashMap<ActionStatus, usize> {
254 let mut counts = HashMap::new();
255 for r in &self.results {
256 *counts.entry(r.status.clone()).or_insert(0) += 1;
257 }
258 counts
259 }
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Builds an `ActionResult` with only the fields these tests vary.
    fn result_with(action_id: &str, status: ActionStatus, error: Option<&str>) -> ActionResult {
        ActionResult {
            action_id: action_id.to_string(),
            status,
            output: None,
            error: error.map(str::to_string),
            state_changes: HashMap::new(),
            duration_ms: None,
            timestamp: Utc::now(),
        }
    }

    #[test]
    fn action_type_serializes_snake_case() {
        assert_eq!(
            serde_json::to_string(&ActionType::ToolCall).unwrap(),
            "\"tool_call\""
        );
        assert_eq!(
            serde_json::to_string(&ActionType::StateWrite).unwrap(),
            "\"state_write\""
        );
        assert_eq!(
            serde_json::to_string(&ActionType::StateRead).unwrap(),
            "\"state_read\""
        );
        assert_eq!(
            serde_json::to_string(&ActionType::Assertion).unwrap(),
            "\"assertion\""
        );
    }

    #[test]
    fn failure_behavior_serializes_snake_case() {
        assert_eq!(
            serde_json::to_string(&FailureBehavior::Abort).unwrap(),
            "\"abort\""
        );
        assert_eq!(
            serde_json::to_string(&FailureBehavior::Retry).unwrap(),
            "\"retry\""
        );
        assert_eq!(
            serde_json::to_string(&FailureBehavior::Skip).unwrap(),
            "\"skip\""
        );
    }

    #[test]
    fn failure_behavior_defaults_to_abort() {
        assert_eq!(FailureBehavior::default(), FailureBehavior::Abort);
    }

    #[test]
    fn action_roundtrip_json() {
        let action = Action {
            id: "abc123".to_string(),
            action_type: ActionType::ToolCall,
            tool: Some("add".to_string()),
            parameters: [
                ("a".to_string(), Value::from(1)),
                ("b".to_string(), Value::from(2)),
            ]
            .into(),
            preconditions: vec![Precondition {
                key: "auth".to_string(),
                operator: "eq".to_string(),
                value: Value::Bool(true),
                description: String::new(),
            }],
            expected_effects: [("sum".to_string(), Value::from(3))].into(),
            state_dependencies: vec!["auth".to_string()],
            idempotent: true,
            max_retries: 3,
            failure_behavior: FailureBehavior::Retry,
            timeout_ms: Some(5000),
            metadata: HashMap::new(),
        };

        let json = serde_json::to_string_pretty(&action).unwrap();
        let roundtripped: Action = serde_json::from_str(&json).unwrap();

        // `Action` derives PartialEq, so compare every field (parameters,
        // preconditions, effects, ...) rather than a hand-picked subset.
        assert_eq!(action, roundtripped);
    }

    #[test]
    fn proposal_roundtrip_json() {
        let proposal = ActionProposal {
            id: "prop1".to_string(),
            source: "test".to_string(),
            actions: vec![Action {
                id: "a1".to_string(),
                action_type: ActionType::StateWrite,
                tool: None,
                parameters: [
                    ("key".to_string(), Value::from("x")),
                    ("value".to_string(), Value::from(42)),
                ]
                .into(),
                preconditions: vec![],
                expected_effects: HashMap::new(),
                state_dependencies: vec![],
                idempotent: false,
                max_retries: 3,
                failure_behavior: FailureBehavior::Abort,
                timeout_ms: None,
                metadata: HashMap::new(),
            }],
            timestamp: Utc::now(),
            context: HashMap::new(),
        };

        let json = serde_json::to_string(&proposal).unwrap();
        let roundtripped: ActionProposal = serde_json::from_str(&json).unwrap();

        assert_eq!(proposal, roundtripped);
    }

    #[test]
    fn action_result_serializes() {
        let result = ActionResult {
            action_id: "a1".to_string(),
            status: ActionStatus::Succeeded,
            output: Some(Value::from(42)),
            error: None,
            state_changes: HashMap::new(),
            duration_ms: Some(1.5),
            timestamp: Utc::now(),
        };

        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("\"succeeded\""));
        assert!(json.contains("\"duration_ms\""));
        // `None` options are skipped entirely, not emitted as null.
        assert!(!json.contains("\"error\""));
    }

    #[test]
    fn proposal_result_all_succeeded() {
        let pr = ProposalResult {
            proposal_id: "p1".to_string(),
            results: vec![
                result_with("a1", ActionStatus::Succeeded, None),
                result_with("a2", ActionStatus::Succeeded, None),
            ],
            cost: CostSummary::default(),
        };
        assert!(pr.all_succeeded());
    }

    #[test]
    fn proposal_result_not_all_succeeded() {
        let pr = ProposalResult {
            proposal_id: "p1".to_string(),
            results: vec![
                result_with("a1", ActionStatus::Succeeded, None),
                result_with("a2", ActionStatus::Failed, Some("boom")),
            ],
            cost: CostSummary::default(),
        };
        assert!(!pr.all_succeeded());
    }

    #[test]
    fn proposal_result_all_succeeded_is_vacuously_true_when_empty() {
        let pr = ProposalResult {
            proposal_id: "p1".to_string(),
            results: vec![],
            cost: CostSummary::default(),
        };
        assert!(pr.all_succeeded());
    }

    #[test]
    fn proposal_result_summary_counts_statuses() {
        let pr = ProposalResult {
            proposal_id: "p1".to_string(),
            results: vec![
                result_with("a1", ActionStatus::Succeeded, None),
                result_with("a2", ActionStatus::Failed, Some("boom")),
                result_with("a3", ActionStatus::Succeeded, None),
            ],
            cost: CostSummary::default(),
        };

        let summary = pr.summary();
        assert_eq!(summary.len(), 2);
        assert_eq!(summary.get(&ActionStatus::Succeeded), Some(&2));
        assert_eq!(summary.get(&ActionStatus::Failed), Some(&1));
        assert_eq!(summary.get(&ActionStatus::Skipped), None);
    }

    #[test]
    fn cost_summary_default_is_zero() {
        let cost = CostSummary::default();
        assert_eq!(cost.tool_calls, 0);
        assert_eq!(cost.actions_executed, 0);
        assert_eq!(cost.actions_skipped, 0);
        assert_eq!(cost.total_duration_ms, 0.0);
        assert_eq!(cost.retries, 0);
    }

    #[test]
    fn cost_summary_serde_roundtrip() {
        let cost = CostSummary {
            tool_calls: 3,
            actions_executed: 5,
            actions_skipped: 1,
            total_duration_ms: 42.5,
            retries: 2,
        };
        let json = serde_json::to_string(&cost).unwrap();
        let roundtripped: CostSummary = serde_json::from_str(&json).unwrap();
        assert_eq!(cost, roundtripped);
    }

    #[test]
    fn cost_target_default_values() {
        let target = CostTarget::default();
        assert_eq!(target.target_tool_calls, 5);
        assert_eq!(target.target_duration_ms, 5000.0);
        assert_eq!(target.target_actions, 10);
        assert_eq!(target.cost_weight, 0.2);
    }

    #[test]
    fn proposal_result_deserializes_without_cost() {
        let json = r#"{"proposal_id": "p1", "results": []}"#;
        let pr: ProposalResult = serde_json::from_str(json).unwrap();
        assert_eq!(pr.cost, CostSummary::default());
    }

    #[test]
    fn action_deserializes_with_defaults() {
        let action: Action = serde_json::from_str(r#"{"type": "state_read"}"#).unwrap();

        // Generated id: 12 hex characters from `short_id`.
        assert_eq!(action.id.len(), 12);
        assert!(action.id.chars().all(|c| c.is_ascii_hexdigit()));
        assert_eq!(action.action_type, ActionType::StateRead);
        assert_eq!(action.tool, None);
        assert!(action.parameters.is_empty());
        assert!(action.preconditions.is_empty());
        assert!(action.expected_effects.is_empty());
        assert!(action.state_dependencies.is_empty());
        assert!(!action.idempotent);
        assert_eq!(action.max_retries, 3);
        assert_eq!(action.failure_behavior, FailureBehavior::Abort);
        assert_eq!(action.timeout_ms, None);
        assert!(action.metadata.is_empty());
    }

    #[test]
    fn precondition_deserializes_with_defaults() {
        let pre: Precondition = serde_json::from_str(r#"{"key": "auth"}"#).unwrap();
        assert_eq!(pre.key, "auth");
        assert_eq!(pre.operator, "eq");
        assert_eq!(pre.value, Value::Null);
        assert_eq!(pre.description, "");
    }

    #[test]
    fn proposal_deserializes_with_defaults() {
        let proposal: ActionProposal = serde_json::from_str(r#"{"actions": []}"#).unwrap();
        assert_eq!(proposal.id.len(), 12);
        assert_eq!(proposal.source, "unknown");
        assert!(proposal.actions.is_empty());
        assert!(proposal.context.is_empty());
    }

    #[test]
    fn tool_schema_deserializes_with_defaults() {
        let schema: ToolSchema = serde_json::from_str(r#"{"name": "add"}"#).unwrap();
        assert_eq!(schema.name, "add");
        assert_eq!(schema.description, "");
        // Defaults to an empty object, not null.
        assert_eq!(schema.parameters, Value::Object(Default::default()));
        assert_eq!(schema.returns, None);
        assert!(!schema.idempotent);
        assert_eq!(schema.cache_ttl_secs, None);
        assert_eq!(schema.rate_limit, None);
    }

    #[test]
    fn deserialize_from_python_compatible_json() {
        let json = r#"{
            "id": "test123",
            "type": "tool_call",
            "tool": "add",
            "parameters": {"a": 1, "b": 2},
            "preconditions": [],
            "expected_effects": {"sum": 3},
            "state_dependencies": [],
            "idempotent": true,
            "max_retries": 3,
            "failure_behavior": "retry",
            "timeout_ms": 5000,
            "metadata": {}
        }"#;

        let action: Action = serde_json::from_str(json).unwrap();
        assert_eq!(action.id, "test123");
        assert_eq!(action.action_type, ActionType::ToolCall);
        assert_eq!(action.tool, Some("add".to_string()));
        assert_eq!(action.parameters.get("a"), Some(&Value::from(1)));
        assert_eq!(action.parameters.get("b"), Some(&Value::from(2)));
        assert_eq!(action.expected_effects.get("sum"), Some(&Value::from(3)));
        assert!(action.idempotent);
        assert_eq!(action.max_retries, 3);
        assert_eq!(action.failure_behavior, FailureBehavior::Retry);
        assert_eq!(action.timeout_ms, Some(5000));
    }
}