Skip to main content

claude_agent/agent/
task_output.rs

1//! TaskOutputTool - retrieves results from running or completed tasks.
2
3use std::time::Duration;
4
5use async_trait::async_trait;
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8
9use super::task_registry::TaskRegistry;
10use crate::session::SessionState;
11use crate::tools::{ExecutionContext, SchemaTool};
12use crate::types::ToolResult;
13
14pub struct TaskOutputTool {
15    registry: TaskRegistry,
16}
17
18impl TaskOutputTool {
19    pub fn new(registry: TaskRegistry) -> Self {
20        Self { registry }
21    }
22}
23
24impl Clone for TaskOutputTool {
25    fn clone(&self) -> Self {
26        Self {
27            registry: self.registry.clone(),
28        }
29    }
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
33#[schemars(deny_unknown_fields)]
34pub struct TaskOutputInput {
35    /// The task ID to get output from
36    pub task_id: String,
37    /// Whether to wait for completion
38    #[serde(default = "default_block")]
39    pub block: bool,
40    /// Max wait time in ms
41    #[serde(default = "default_timeout")]
42    #[schemars(range(min = 0, max = 600000))]
43    pub timeout: u64,
44}
45
/// Serde default for `TaskOutputInput::block`: wait for completion unless told otherwise.
const fn default_block() -> bool {
    true
}
49
/// Serde default for `TaskOutputInput::timeout`, in milliseconds (30 seconds).
const fn default_timeout() -> u64 {
    30_000
}
53
54#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
55#[serde(rename_all = "snake_case")]
56pub enum TaskStatus {
57    Running,
58    Completed,
59    Failed,
60    Cancelled,
61    NotFound,
62}
63
64impl From<SessionState> for TaskStatus {
65    fn from(state: SessionState) -> Self {
66        match state {
67            SessionState::Created | SessionState::Active | SessionState::WaitingForTools => {
68                TaskStatus::Running
69            }
70            SessionState::Completed => TaskStatus::Completed,
71            SessionState::Failed => TaskStatus::Failed,
72            SessionState::Cancelled => TaskStatus::Cancelled,
73        }
74    }
75}
76
77#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct TaskOutputResult {
79    pub task_id: String,
80    pub status: TaskStatus,
81    #[serde(skip_serializing_if = "Option::is_none")]
82    pub output: Option<String>,
83    #[serde(skip_serializing_if = "Option::is_none")]
84    pub error: Option<String>,
85}
86
87#[async_trait]
88impl SchemaTool for TaskOutputTool {
89    type Input = TaskOutputInput;
90
91    const NAME: &'static str = "TaskOutput";
92    const DESCRIPTION: &'static str = r#"
93- Retrieves output from a running or completed task (background shell, agent, or remote session)
94- Takes a task_id parameter identifying the task
95- Returns the task output along with status information
96- Use block=true (default) to wait for task completion
97- Use block=false for non-blocking check of current status
98- Task IDs can be found using the Task tool response
99- Works with all task types: background shells, async agents, and remote sessions
100- Output is limited to prevent excessive memory usage; for larger outputs, consider streaming
101- Important: task_id is the Task tool's returned ID, NOT a process PID"#;
102
103    async fn handle(&self, input: TaskOutputInput, _context: &ExecutionContext) -> ToolResult {
104        let timeout = Duration::from_millis(input.timeout.min(600000));
105
106        let result = if input.block {
107            self.registry
108                .wait_for_completion(&input.task_id, timeout)
109                .await
110        } else {
111            self.registry.get_result(&input.task_id).await
112        };
113
114        let output = match result {
115            Some((status, output, error)) => TaskOutputResult {
116                task_id: input.task_id,
117                status: status.into(),
118                output,
119                error,
120            },
121            None => TaskOutputResult {
122                task_id: input.task_id,
123                status: TaskStatus::NotFound,
124                output: None,
125                error: Some("Task not found".to_string()),
126            },
127        };
128
129        ToolResult::success(
130            serde_json::to_string_pretty(&output)
131                .unwrap_or_else(|_| format!("Task status: {:?}", output.status)),
132        )
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::{AgentMetrics, AgentResult, AgentState};
    use crate::session::MemoryPersistence;
    use crate::tools::Tool;
    use crate::types::{StopReason, ToolOutput, Usage};
    use std::sync::Arc;

    // Use valid UUIDs for tests to ensure consistent session IDs
    const TASK_1_UUID: &str = "00000000-0000-0000-0000-000000000011";
    const TASK_2_UUID: &str = "00000000-0000-0000-0000-000000000012";
    const TASK_3_UUID: &str = "00000000-0000-0000-0000-000000000013";

    fn test_registry() -> TaskRegistry {
        TaskRegistry::new(Arc::new(MemoryPersistence::new()))
    }

    fn mock_result() -> AgentResult {
        AgentResult {
            text: "Completed successfully".to_string(),
            usage: Usage::default(),
            tool_calls: 0,
            iterations: 1,
            stop_reason: StopReason::EndTurn,
            state: AgentState::Completed,
            metrics: AgentMetrics::default(),
            session_id: "test-session".to_string(),
            structured_output: None,
            messages: Vec::new(),
            uuid: "test-uuid".to_string(),
        }
    }

    /// Asserts that `output` is `ToolOutput::Success` and that its text contains every needle.
    ///
    /// Panics on any other variant, so a non-success result fails the test instead of
    /// silently skipping the content checks.
    fn assert_success_contains(output: &ToolOutput, needles: &[&str]) {
        match output {
            ToolOutput::Success(content) => {
                for &needle in needles {
                    assert!(
                        content.contains(needle),
                        "success output is missing {:?}",
                        needle
                    );
                }
            }
            _ => panic!("expected ToolOutput::Success"),
        }
    }

    #[tokio::test]
    async fn test_task_output_completed() {
        let registry = test_registry();
        registry
            .register(TASK_1_UUID.into(), "Explore".into(), "Test".into())
            .await;
        registry.complete(TASK_1_UUID, mock_result()).await;

        let tool = TaskOutputTool::new(registry);
        let context = crate::tools::ExecutionContext::default();
        let result = tool
            .execute(serde_json::json!({ "task_id": TASK_1_UUID }), &context)
            .await;

        assert!(!result.is_error());
        assert_success_contains(&result.output, &["completed"]);
    }

    #[tokio::test]
    async fn test_task_output_not_found() {
        let registry = test_registry();
        let tool = TaskOutputTool::new(registry);
        let context = crate::tools::ExecutionContext::default();

        let result = tool
            .execute(serde_json::json!({ "task_id": "nonexistent" }), &context)
            .await;

        // A missing task is reported in the payload, not as a tool error.
        assert!(!result.is_error());
        assert_success_contains(&result.output, &["not_found"]);
    }

    #[tokio::test]
    async fn test_task_output_non_blocking() {
        let registry = test_registry();
        registry
            .register(TASK_2_UUID.into(), "Explore".into(), "Running".into())
            .await;

        let tool = TaskOutputTool::new(registry);
        let context = crate::tools::ExecutionContext::default();
        let result = tool
            .execute(
                serde_json::json!({
                    "task_id": TASK_2_UUID,
                    "block": false
                }),
                &context,
            )
            .await;

        assert!(!result.is_error());
        assert_success_contains(&result.output, &["running"]);
    }

    #[tokio::test]
    async fn test_task_output_failed() {
        let registry = test_registry();
        registry
            .register(TASK_3_UUID.into(), "Explore".into(), "Failing".into())
            .await;
        registry
            .fail(TASK_3_UUID, "Something went wrong".into())
            .await;

        let tool = TaskOutputTool::new(registry);
        let context = crate::tools::ExecutionContext::default();
        let result = tool
            .execute(serde_json::json!({ "task_id": TASK_3_UUID }), &context)
            .await;

        assert!(!result.is_error());
        assert_success_contains(&result.output, &["failed", "Something went wrong"]);
    }
}