Skip to main content

claude_agent/agent/
task_output.rs

1//! TaskOutputTool - retrieves results from running or completed tasks.
2
3use std::time::Duration;
4
5use async_trait::async_trait;
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8
9use super::task_registry::TaskRegistry;
10use crate::session::SessionState;
11use crate::tools::{ExecutionContext, SchemaTool};
12use crate::types::ToolResult;
13
14#[derive(Clone)]
15pub struct TaskOutputTool {
16    registry: TaskRegistry,
17}
18
19impl TaskOutputTool {
20    pub fn new(registry: TaskRegistry) -> Self {
21        Self { registry }
22    }
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
26#[schemars(deny_unknown_fields)]
27pub struct TaskOutputInput {
28    /// The task ID to get output from
29    pub task_id: String,
30    /// Whether to wait for completion
31    #[serde(default = "default_block")]
32    pub block: bool,
33    /// Max wait time in ms
34    #[serde(default = "default_timeout")]
35    #[schemars(range(min = 0, max = 600000))]
36    pub timeout: u64,
37}
38
/// Serde default for `TaskOutputInput::block`: wait for completion unless the caller opts out.
fn default_block() -> bool { true }
42
/// Serde default for `TaskOutputInput::timeout`, in milliseconds (30 seconds).
fn default_timeout() -> u64 {
    30_000
}
46
47#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
48#[serde(rename_all = "snake_case")]
49pub enum TaskStatus {
50    Running,
51    Completed,
52    Failed,
53    Cancelled,
54    NotFound,
55}
56
57impl From<SessionState> for TaskStatus {
58    fn from(state: SessionState) -> Self {
59        match state {
60            SessionState::Created | SessionState::Active | SessionState::WaitingForTools => {
61                TaskStatus::Running
62            }
63            SessionState::Completed => TaskStatus::Completed,
64            SessionState::Failed => TaskStatus::Failed,
65            SessionState::Cancelled => TaskStatus::Cancelled,
66        }
67    }
68}
69
70#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct TaskOutputResult {
72    pub task_id: String,
73    pub status: TaskStatus,
74    #[serde(skip_serializing_if = "Option::is_none")]
75    pub output: Option<String>,
76    #[serde(skip_serializing_if = "Option::is_none")]
77    pub error: Option<String>,
78}
79
80#[async_trait]
81impl SchemaTool for TaskOutputTool {
82    type Input = TaskOutputInput;
83
84    const NAME: &'static str = "TaskOutput";
85    const DESCRIPTION: &'static str = r#"
86- Retrieves output from a running or completed task (background shell, agent, or remote session)
87- Takes a task_id parameter identifying the task
88- Returns the task output along with status information
89- Use block=true (default) to wait for task completion
90- Use block=false for non-blocking check of current status
91- Task IDs can be found using the Task tool response
92- Works with all task types: background shells, async agents, and remote sessions
93- Output is limited to prevent excessive memory usage; for larger outputs, consider streaming
94- Important: task_id is the Task tool's returned ID, NOT a process PID"#;
95
96    async fn handle(&self, input: TaskOutputInput, _context: &ExecutionContext) -> ToolResult {
97        let timeout = Duration::from_millis(input.timeout.min(600000));
98
99        let result = if input.block {
100            self.registry
101                .wait_for_completion(&input.task_id, timeout)
102                .await
103        } else {
104            self.registry.get_result(&input.task_id).await
105        };
106
107        let output = match result {
108            Some((status, output, error)) => TaskOutputResult {
109                task_id: input.task_id,
110                status: status.into(),
111                output,
112                error,
113            },
114            None => TaskOutputResult {
115                task_id: input.task_id,
116                status: TaskStatus::NotFound,
117                output: None,
118                error: Some("Task not found".to_string()),
119            },
120        };
121
122        ToolResult::success(
123            serde_json::to_string_pretty(&output)
124                .unwrap_or_else(|_| format!("Task status: {:?}", output.status)),
125        )
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::{AgentMetrics, AgentResult, AgentState};
    use crate::session::MemoryPersistence;
    use crate::tools::Tool;
    use crate::types::{StopReason, ToolOutput, Usage};
    use std::sync::Arc;

    // Use valid UUIDs for tests to ensure consistent session IDs
    const TASK_1_UUID: &str = "00000000-0000-0000-0000-000000000011";
    const TASK_2_UUID: &str = "00000000-0000-0000-0000-000000000012";
    const TASK_3_UUID: &str = "00000000-0000-0000-0000-000000000013";

    fn test_registry() -> TaskRegistry {
        TaskRegistry::new(Arc::new(MemoryPersistence::new()))
    }

    fn mock_result() -> AgentResult {
        AgentResult {
            text: "Completed successfully".to_string(),
            usage: Usage::default(),
            tool_calls: 0,
            iterations: 1,
            stop_reason: StopReason::EndTurn,
            state: AgentState::Completed,
            metrics: AgentMetrics::default(),
            session_id: "test-session".to_string(),
            structured_output: None,
            messages: Vec::new(),
            uuid: "test-uuid".to_string(),
        }
    }

    // Each test extracts the `Success` payload with a `match` that panics otherwise.
    // A plain `if let` would let the assertions be skipped silently if the tool ever
    // returned a non-success output.

    #[tokio::test]
    async fn test_task_output_completed() {
        let registry = test_registry();
        registry
            .register(TASK_1_UUID.into(), "Explore".into(), "Test".into())
            .await;
        registry.complete(TASK_1_UUID, mock_result()).await;

        let tool = TaskOutputTool::new(registry);
        let context = crate::tools::ExecutionContext::default();
        let result = tool
            .execute(
                serde_json::json!({
                    "task_id": TASK_1_UUID
                }),
                &context,
            )
            .await;

        assert!(!result.is_error());
        let content = match &result.output {
            ToolOutput::Success(content) => content,
            _ => panic!("expected a successful tool output"),
        };
        assert!(content.contains("completed"));
    }

    #[tokio::test]
    async fn test_task_output_not_found() {
        let registry = test_registry();
        let tool = TaskOutputTool::new(registry);
        let context = crate::tools::ExecutionContext::default();

        let result = tool
            .execute(
                serde_json::json!({
                    "task_id": "nonexistent"
                }),
                &context,
            )
            .await;

        // An unknown task is reported in the payload, not as a tool error.
        assert!(!result.is_error());
        let content = match &result.output {
            ToolOutput::Success(content) => content,
            _ => panic!("expected a successful tool output"),
        };
        assert!(content.contains("not_found"));
        assert!(content.contains("Task not found"));
    }

    #[tokio::test]
    async fn test_task_output_non_blocking() {
        let registry = test_registry();
        registry
            .register(TASK_2_UUID.into(), "Explore".into(), "Running".into())
            .await;

        let tool = TaskOutputTool::new(registry);
        let context = crate::tools::ExecutionContext::default();
        let result = tool
            .execute(
                serde_json::json!({
                    "task_id": TASK_2_UUID,
                    "block": false
                }),
                &context,
            )
            .await;

        assert!(!result.is_error());
        let content = match &result.output {
            ToolOutput::Success(content) => content,
            _ => panic!("expected a successful tool output"),
        };
        assert!(content.contains("running"));
    }

    #[tokio::test]
    async fn test_task_output_failed() {
        let registry = test_registry();
        registry
            .register(TASK_3_UUID.into(), "Explore".into(), "Failing".into())
            .await;
        registry
            .fail(TASK_3_UUID, "Something went wrong".into())
            .await;

        let tool = TaskOutputTool::new(registry);
        let context = crate::tools::ExecutionContext::default();
        let result = tool
            .execute(
                serde_json::json!({
                    "task_id": TASK_3_UUID
                }),
                &context,
            )
            .await;

        assert!(!result.is_error());
        let content = match &result.output {
            ToolOutput::Success(content) => content,
            _ => panic!("expected a successful tool output"),
        };
        assert!(content.contains("failed"));
        assert!(content.contains("Something went wrong"));
    }
}