1use std::sync::{Arc, Mutex};
2
3use async_trait::async_trait;
4use serde_json::json;
5use tracing::warn;
6
7use super::traits::{Tool, ToolResult};
8use crate::sop::types::{SopRunAction, SopStepResult, SopStepStatus};
9use crate::sop::{SopAuditLogger, SopEngine, SopMetricsCollector};
10
11pub struct SopAdvanceTool {
13 engine: Arc<Mutex<SopEngine>>,
14 audit: Option<Arc<SopAuditLogger>>,
15 collector: Option<Arc<SopMetricsCollector>>,
16}
17
18impl SopAdvanceTool {
19 pub fn new(engine: Arc<Mutex<SopEngine>>) -> Self {
20 Self {
21 engine,
22 audit: None,
23 collector: None,
24 }
25 }
26
27 pub fn with_audit(mut self, audit: Arc<SopAuditLogger>) -> Self {
28 self.audit = Some(audit);
29 self
30 }
31
32 pub fn with_collector(mut self, collector: Arc<SopMetricsCollector>) -> Self {
33 self.collector = Some(collector);
34 self
35 }
36}
37
38#[async_trait]
39impl Tool for SopAdvanceTool {
40 fn name(&self) -> &str {
41 "sop_advance"
42 }
43
44 fn description(&self) -> &str {
45 "Report the result of the current SOP step and advance to the next step. Provide the run_id, whether the step succeeded or failed, and a brief output summary."
46 }
47
48 fn parameters_schema(&self) -> serde_json::Value {
49 json!({
50 "type": "object",
51 "properties": {
52 "run_id": {
53 "type": "string",
54 "description": "The run ID to advance"
55 },
56 "status": {
57 "type": "string",
58 "enum": ["completed", "failed", "skipped"],
59 "description": "Result status of the current step"
60 },
61 "output": {
62 "type": "string",
63 "description": "Brief summary of what happened in this step"
64 }
65 },
66 "required": ["run_id", "status", "output"]
67 })
68 }
69
70 async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
71 let run_id = args
72 .get("run_id")
73 .and_then(|v| v.as_str())
74 .ok_or_else(|| anyhow::anyhow!("Missing 'run_id' parameter"))?;
75
76 let status_str = args
77 .get("status")
78 .and_then(|v| v.as_str())
79 .ok_or_else(|| anyhow::anyhow!("Missing 'status' parameter"))?;
80
81 let output = args
82 .get("output")
83 .and_then(|v| v.as_str())
84 .ok_or_else(|| anyhow::anyhow!("Missing 'output' parameter"))?;
85
86 let step_status = match status_str {
87 "completed" => SopStepStatus::Completed,
88 "failed" => SopStepStatus::Failed,
89 "skipped" => SopStepStatus::Skipped,
90 other => {
91 return Ok(ToolResult {
92 success: false,
93 output: String::new(),
94 error: Some(format!(
95 "Invalid status '{other}'. Must be: completed, failed, or skipped"
96 )),
97 });
98 }
99 };
100
101 let (action, step_result_ok, finished_run) = {
103 let mut engine = self
104 .engine
105 .lock()
106 .map_err(|e| anyhow::anyhow!("Engine lock poisoned: {e}"))?;
107
108 let current_step = engine
109 .get_run(run_id)
110 .map(|r| r.current_step)
111 .ok_or_else(|| anyhow::anyhow!("Run not found: {run_id}"))?;
112
113 let now = now_iso8601();
114 let step_result = SopStepResult {
115 step_number: current_step,
116 status: step_status,
117 output: output.to_string(),
118 started_at: now.clone(),
119 completed_at: Some(now),
120 };
121 let step_result_clone = step_result.clone();
122
123 match engine.advance_step(run_id, step_result) {
124 Ok(action) => {
125 let finished = match &action {
127 SopRunAction::Completed { run_id, .. }
128 | SopRunAction::Failed { run_id, .. } => engine.get_run(run_id).cloned(),
129 _ => None,
130 };
131 (Ok(action), Some(step_result_clone), finished)
133 }
134 Err(e) => (Err(e), None, None),
135 }
136 };
137
138 if let Some(ref audit) = self.audit {
140 if let Some(ref sr) = step_result_ok {
141 if let Err(e) = audit.log_step_result(run_id, sr).await {
142 warn!("SOP audit log_step_result failed: {e}");
143 }
144 }
145 if let Some(ref run) = finished_run {
146 if let Err(e) = audit.log_run_complete(run).await {
147 warn!("SOP audit log_run_complete failed: {e}");
148 }
149 }
150 }
151
152 if let Some(ref collector) = self.collector {
154 if let Some(ref run) = finished_run {
155 collector.record_run_complete(run);
156 }
157 }
158
159 match action {
160 Ok(action) => {
161 let result_output = match action {
162 SopRunAction::ExecuteStep {
163 run_id, context, ..
164 } => {
165 format!("Step recorded. Next step for run {run_id}:\n\n{context}")
166 }
167 SopRunAction::WaitApproval {
168 run_id, context, ..
169 } => {
170 format!(
171 "Step recorded. Next step for run {run_id} (waiting for approval):\n\n{context}"
172 )
173 }
174 SopRunAction::Completed { run_id, sop_name } => {
175 format!("SOP '{sop_name}' run {run_id} completed successfully.")
176 }
177 SopRunAction::Failed {
178 run_id,
179 sop_name,
180 reason,
181 } => {
182 format!("SOP '{sop_name}' run {run_id} failed: {reason}")
183 }
184 SopRunAction::DeterministicStep { run_id, step, .. } => {
185 format!(
186 "Step recorded. Next deterministic step for run {run_id}: {}",
187 step.title
188 )
189 }
190 SopRunAction::CheckpointWait { run_id, step, .. } => {
191 format!(
192 "Step recorded. Run {run_id} paused at checkpoint: {}",
193 step.title
194 )
195 }
196 };
197 Ok(ToolResult {
198 success: true,
199 output: result_output,
200 error: None,
201 })
202 }
203 Err(e) => Ok(ToolResult {
204 success: false,
205 output: String::new(),
206 error: Some(format!("Failed to advance step: {e}")),
207 }),
208 }
209 }
210}
211
212use crate::sop::engine::now_iso8601;
213
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::SopConfig;
    use crate::memory::Memory;
    use crate::sop::engine::SopEngine;
    use crate::sop::types::*;

    /// Two-step manual SOP used by every test in this module.
    fn test_sop() -> Sop {
        Sop {
            name: "test-sop".into(),
            description: "Test SOP".into(),
            version: "1.0.0".into(),
            priority: SopPriority::Normal,
            execution_mode: SopExecutionMode::Auto,
            triggers: vec![SopTrigger::Manual],
            steps: vec![
                SopStep {
                    number: 1,
                    title: "Step one".into(),
                    body: "Do step one".into(),
                    suggested_tools: vec![],
                    requires_confirmation: false,
                    kind: SopStepKind::default(),
                    schema: None,
                },
                SopStep {
                    number: 2,
                    title: "Step two".into(),
                    body: "Do step two".into(),
                    suggested_tools: vec![],
                    requires_confirmation: false,
                    kind: SopStepKind::default(),
                    schema: None,
                },
            ],
            cooldown_secs: 0,
            max_concurrent: 1,
            location: None,
            deterministic: false,
        }
    }

    /// Engine with `test_sop` loaded and one manually started run; returns the
    /// shared engine together with that run's id.
    fn engine_with_active_run() -> (Arc<Mutex<SopEngine>>, String) {
        let mut engine = SopEngine::new(SopConfig::default());
        engine.set_sops_for_test(vec![test_sop()]);
        let event = SopEvent {
            source: SopTriggerSource::Manual,
            topic: None,
            payload: None,
            timestamp: "2026-02-19T12:00:00Z".into(),
        };
        engine.start_run("test-sop", event).unwrap();
        let run_id = engine
            .active_runs()
            .keys()
            .next()
            .expect("expected active run")
            .clone();
        (Arc::new(Mutex::new(engine)), run_id)
    }

    #[tokio::test]
    async fn advance_to_next_step() {
        let (engine, run_id) = engine_with_active_run();
        let tool = SopAdvanceTool::new(engine);
        let result = tool
            .execute(json!({
                "run_id": run_id,
                "status": "completed",
                "output": "Step 1 done successfully"
            }))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("Next step"));
        assert!(result.output.contains("Step two"));
    }

    #[tokio::test]
    async fn advance_to_completion() {
        let (engine, run_id) = engine_with_active_run();
        let tool = SopAdvanceTool::new(engine);

        let first = tool
            .execute(json!({
                "run_id": run_id,
                "status": "completed",
                "output": "Step 1 done"
            }))
            .await
            .unwrap();
        assert!(first.success, "first step should be recorded");
        assert!(first.output.contains("Next step"));

        let result = tool
            .execute(json!({
                "run_id": run_id,
                "status": "completed",
                "output": "Step 2 done"
            }))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("completed successfully"));
    }

    #[tokio::test]
    async fn advance_with_failure() {
        let (engine, run_id) = engine_with_active_run();
        let tool = SopAdvanceTool::new(engine);
        let result = tool
            .execute(json!({
                "run_id": run_id,
                "status": "failed",
                "output": "Valve stuck open"
            }))
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.output.contains("failed"));
        assert!(result.output.contains("Valve stuck open"));
    }

    #[tokio::test]
    async fn advance_invalid_status() {
        let (engine, run_id) = engine_with_active_run();
        let tool = SopAdvanceTool::new(engine);
        let result = tool
            .execute(json!({
                "run_id": run_id,
                "status": "invalid",
                "output": "whatever"
            }))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("Invalid status"));
    }

    #[tokio::test]
    async fn advance_unknown_run() {
        let engine = Arc::new(Mutex::new(SopEngine::new(SopConfig::default())));
        let tool = SopAdvanceTool::new(engine);
        let result = tool
            .execute(json!({
                "run_id": "nonexistent",
                "status": "completed",
                "output": "done"
            }))
            .await;
        let err = result.err().expect("unknown run should be an error");
        assert!(err.to_string().contains("Run not found"));
    }

    #[test]
    fn name_and_schema() {
        let engine = Arc::new(Mutex::new(SopEngine::new(SopConfig::default())));
        let tool = SopAdvanceTool::new(engine);
        assert_eq!(tool.name(), "sop_advance");
        let schema = tool.parameters_schema();
        assert!(schema["properties"]["run_id"].is_object());
        assert!(schema["properties"]["status"]["enum"].is_array());
        assert_eq!(
            schema["properties"]["status"]["enum"],
            json!(["completed", "failed", "skipped"])
        );
    }

    #[tokio::test]
    async fn advance_error_does_not_write_step_audit() {
        let engine = Arc::new(Mutex::new(SopEngine::new(SopConfig::default())));
        let memory: Arc<dyn Memory> = Arc::new(crate::memory::test_memory::TestMemory::new());
        let audit = Arc::new(SopAuditLogger::new(memory.clone()));

        let tool = SopAdvanceTool::new(engine).with_audit(audit.clone());
        // Unknown run: the tool errors before any step result is produced.
        let result = tool
            .execute(json!({
                "run_id": "nonexistent",
                "status": "completed",
                "output": "done"
            }))
            .await;
        assert!(result.is_err());

        let runs = audit.list_runs().await.unwrap();
        assert!(
            runs.is_empty(),
            "no audit entries should exist after advance error"
        );

        // Step results are stored under `sop_step_` keys; check those directly,
        // since `list_runs` alone does not prove no step entry was written.
        let entries = memory
            .list(
                Some(&crate::memory::traits::MemoryCategory::Custom("sop".into())),
                None,
            )
            .await
            .unwrap();
        assert!(
            !entries.iter().any(|e| e.key.starts_with("sop_step_")),
            "no step audit entries should exist after advance error"
        );
    }

    #[tokio::test]
    async fn advance_success_writes_step_audit() {
        let (engine, run_id) = engine_with_active_run();
        let memory: Arc<dyn Memory> = Arc::new(crate::memory::test_memory::TestMemory::new());
        let audit = Arc::new(SopAuditLogger::new(memory.clone()));

        let tool = SopAdvanceTool::new(engine).with_audit(audit.clone());
        let result = tool
            .execute(json!({
                "run_id": run_id,
                "status": "completed",
                "output": "Step 1 done"
            }))
            .await
            .unwrap();
        assert!(result.success);

        let entries = memory
            .list(
                Some(&crate::memory::traits::MemoryCategory::Custom("sop".into())),
                None,
            )
            .await
            .unwrap();
        let step_keys: Vec<_> = entries
            .iter()
            .filter(|e| e.key.starts_with("sop_step_"))
            .collect();
        assert!(
            !step_keys.is_empty(),
            "step audit should be written on success"
        );
    }
}