claude_agent/agent/
task_output.rs

1//! TaskOutputTool - retrieves results from running or completed tasks.
2
3use std::time::Duration;
4
5use async_trait::async_trait;
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8
9use super::task_registry::TaskRegistry;
10use crate::session::SessionState;
11use crate::tools::{ExecutionContext, SchemaTool};
12use crate::types::ToolResult;
13
14pub struct TaskOutputTool {
15    registry: TaskRegistry,
16}
17
18impl TaskOutputTool {
19    pub fn new(registry: TaskRegistry) -> Self {
20        Self { registry }
21    }
22}
23
24impl Clone for TaskOutputTool {
25    fn clone(&self) -> Self {
26        Self {
27            registry: self.registry.clone(),
28        }
29    }
30}
31
// Input accepted by the `TaskOutput` tool. Only `task_id` is required; `block`
// and `timeout` fall back to `default_block` / `default_timeout` when omitted.
//
// Plain `//` comments are used on purpose: `///` doc comments on this type and
// its fields are picked up by the `JsonSchema` derive as schema descriptions,
// so editing them changes what the model sees.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
// NOTE(review): a `#[schemars(...)]` attribute only affects the generated schema
// (no additional properties allowed); serde deserialization itself still ignores
// unknown keys. `#[serde(deny_unknown_fields)]` would be needed for strict
// rejection — confirm the lenient behavior is intended.
#[schemars(deny_unknown_fields)]
pub struct TaskOutputInput {
    /// The task ID to get output from
    pub task_id: String,
    /// Whether to wait for completion
    #[serde(default = "default_block")]
    pub block: bool,
    /// Max wait time in ms
    // The schema advertises 0..=600000; `TaskOutputTool::handle` also clamps the
    // value to 600000 at runtime, since deserialization does not enforce the range.
    #[serde(default = "default_timeout")]
    #[schemars(range(min = 0, max = 600000))]
    pub timeout: u64,
}
45
/// Serde default for `TaskOutputInput::block`.
///
/// Callers wait for the task to finish unless they explicitly send `"block": false`.
fn default_block() -> bool {
    true
}
49
/// Serde default for `TaskOutputInput::timeout`, in milliseconds (30 seconds).
fn default_timeout() -> u64 {
    30 * 1_000
}
53
54#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
55#[serde(rename_all = "snake_case")]
56pub enum TaskStatus {
57    Running,
58    Completed,
59    Failed,
60    Cancelled,
61    NotFound,
62}
63
64impl From<SessionState> for TaskStatus {
65    fn from(state: SessionState) -> Self {
66        match state {
67            SessionState::Active | SessionState::WaitingForTools | SessionState::WaitingForUser => {
68                TaskStatus::Running
69            }
70            SessionState::Completed => TaskStatus::Completed,
71            SessionState::Failed => TaskStatus::Failed,
72            SessionState::Cancelled => TaskStatus::Cancelled,
73            SessionState::Created | SessionState::Paused => TaskStatus::Running,
74        }
75    }
76}
77
/// JSON payload returned by the `TaskOutput` tool.
///
/// Fields serialize in declaration order; `output` and `error` are omitted
/// entirely (not emitted as `null`) when they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskOutputResult {
    /// ID of the queried task, echoed back from the tool input.
    pub task_id: String,
    /// Status of the task, or `NotFound` when the registry had no entry.
    pub status: TaskStatus,
    /// Task output, if the registry returned any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output: Option<String>,
    /// Error text, if any; set to "Task not found" for unknown task IDs.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
87
88#[async_trait]
89impl SchemaTool for TaskOutputTool {
90    type Input = TaskOutputInput;
91
92    const NAME: &'static str = "TaskOutput";
93    const DESCRIPTION: &'static str = r#"
94- Retrieves output from a running or completed task (background shell, agent, or remote session)
95- Takes a task_id parameter identifying the task
96- Returns the task output along with status information
97- Use block=true (default) to wait for task completion
98- Use block=false for non-blocking check of current status
99- Task IDs can be found using the Task tool response
100- Works with all task types: background shells, async agents, and remote sessions
101- Output is limited to prevent excessive memory usage; for larger outputs, consider streaming
102- Important: task_id is the Task tool's returned ID, NOT a process PID"#;
103
104    async fn handle(&self, input: TaskOutputInput, _context: &ExecutionContext) -> ToolResult {
105        let timeout = Duration::from_millis(input.timeout.min(600000));
106
107        let result = if input.block {
108            self.registry
109                .wait_for_completion(&input.task_id, timeout)
110                .await
111        } else {
112            self.registry.get_result(&input.task_id).await
113        };
114
115        let output = match result {
116            Some((status, output, error)) => TaskOutputResult {
117                task_id: input.task_id,
118                status: status.into(),
119                output,
120                error,
121            },
122            None => TaskOutputResult {
123                task_id: input.task_id,
124                status: TaskStatus::NotFound,
125                output: None,
126                error: Some("Task not found".to_string()),
127            },
128        };
129
130        ToolResult::success(
131            serde_json::to_string_pretty(&output)
132                .unwrap_or_else(|_| format!("Task status: {:?}", output.status)),
133        )
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::{AgentMetrics, AgentResult, AgentState};
    use crate::session::MemoryPersistence;
    use crate::tools::Tool;
    use crate::types::{StopReason, ToolOutput, Usage};
    use std::sync::Arc;

    // Use valid UUIDs for tests to ensure consistent session IDs
    const TASK_1_UUID: &str = "00000000-0000-0000-0000-000000000011";
    const TASK_2_UUID: &str = "00000000-0000-0000-0000-000000000012";
    const TASK_3_UUID: &str = "00000000-0000-0000-0000-000000000013";

    fn test_registry() -> TaskRegistry {
        TaskRegistry::new(Arc::new(MemoryPersistence::new()))
    }

    fn mock_result() -> AgentResult {
        AgentResult {
            text: "Completed successfully".to_string(),
            usage: Usage::default(),
            tool_calls: 0,
            iterations: 1,
            stop_reason: StopReason::EndTurn,
            state: AgentState::Completed,
            metrics: AgentMetrics::default(),
            session_id: "test-session".to_string(),
            structured_output: None,
            messages: Vec::new(),
            uuid: "test-uuid".to_string(),
        }
    }

    /// Asserts that `output` is a success payload containing every string in `needles`.
    ///
    /// Fails the test on any non-`Success` output, so the content checks can
    /// never be skipped silently (the previous `if let` guards allowed that).
    fn assert_success_contains(output: &ToolOutput, needles: &[&str]) {
        match output {
            ToolOutput::Success(content) => {
                for needle in needles {
                    assert!(
                        content.contains(*needle),
                        "expected {:?} in tool output: {}",
                        needle,
                        content
                    );
                }
            }
            _ => panic!("expected ToolOutput::Success"),
        }
    }

    #[tokio::test]
    async fn test_task_output_completed() {
        let registry = test_registry();
        registry
            .register(TASK_1_UUID.into(), "explore".into(), "Test".into())
            .await;
        registry.complete(TASK_1_UUID, mock_result()).await;

        let tool = TaskOutputTool::new(registry);
        let context = crate::tools::ExecutionContext::default();
        let result = tool
            .execute(
                serde_json::json!({
                    "task_id": TASK_1_UUID
                }),
                &context,
            )
            .await;

        assert!(!result.is_error());
        assert_success_contains(&result.output, &["completed"]);
    }

    #[tokio::test]
    async fn test_task_output_not_found() {
        let registry = test_registry();
        let tool = TaskOutputTool::new(registry);
        let context = crate::tools::ExecutionContext::default();

        let result = tool
            .execute(
                serde_json::json!({
                    "task_id": "nonexistent"
                }),
                &context,
            )
            .await;

        assert_success_contains(&result.output, &["not_found", "Task not found"]);
    }

    #[tokio::test]
    async fn test_task_output_non_blocking() {
        let registry = test_registry();
        registry
            .register(TASK_2_UUID.into(), "explore".into(), "Running".into())
            .await;

        let tool = TaskOutputTool::new(registry);
        let context = crate::tools::ExecutionContext::default();
        let result = tool
            .execute(
                serde_json::json!({
                    "task_id": TASK_2_UUID,
                    "block": false
                }),
                &context,
            )
            .await;

        assert_success_contains(&result.output, &["running"]);
    }

    #[tokio::test]
    async fn test_task_output_failed() {
        let registry = test_registry();
        registry
            .register(TASK_3_UUID.into(), "explore".into(), "Failing".into())
            .await;
        registry
            .fail(TASK_3_UUID, "Something went wrong".into())
            .await;

        let tool = TaskOutputTool::new(registry);
        let context = crate::tools::ExecutionContext::default();
        let result = tool
            .execute(
                serde_json::json!({
                    "task_id": TASK_3_UUID
                }),
                &context,
            )
            .await;

        assert_success_contains(&result.output, &["failed", "Something went wrong"]);
    }

    #[test]
    fn test_task_status_from_session_state() {
        for state in [
            SessionState::Created,
            SessionState::Active,
            SessionState::WaitingForTools,
            SessionState::WaitingForUser,
            SessionState::Paused,
        ] {
            assert_eq!(TaskStatus::from(state), TaskStatus::Running);
        }
        assert_eq!(TaskStatus::from(SessionState::Completed), TaskStatus::Completed);
        assert_eq!(TaskStatus::from(SessionState::Failed), TaskStatus::Failed);
        assert_eq!(TaskStatus::from(SessionState::Cancelled), TaskStatus::Cancelled);
    }

    #[test]
    fn test_task_status_serializes_snake_case() {
        assert_eq!(
            serde_json::to_string(&TaskStatus::NotFound).unwrap(),
            "\"not_found\""
        );
        assert_eq!(
            serde_json::to_string(&TaskStatus::Running).unwrap(),
            "\"running\""
        );
    }

    #[test]
    fn test_input_defaults() {
        let input: TaskOutputInput =
            serde_json::from_value(serde_json::json!({ "task_id": "abc" })).unwrap();
        assert_eq!(input.task_id, "abc");
        assert!(input.block);
        assert_eq!(input.timeout, 30000);
    }
}