1use std::time::Duration;
4
5use async_trait::async_trait;
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8
9use super::task_registry::TaskRegistry;
10use crate::session::SessionState;
11use crate::tools::{ExecutionContext, SchemaTool};
12use crate::types::ToolResult;
13
14pub struct TaskOutputTool {
15 registry: TaskRegistry,
16}
17
18impl TaskOutputTool {
19 pub fn new(registry: TaskRegistry) -> Self {
20 Self { registry }
21 }
22}
23
24impl Clone for TaskOutputTool {
25 fn clone(&self) -> Self {
26 Self {
27 registry: self.registry.clone(),
28 }
29 }
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
33#[schemars(deny_unknown_fields)]
34pub struct TaskOutputInput {
35 pub task_id: String,
37 #[serde(default = "default_block")]
39 pub block: bool,
40 #[serde(default = "default_timeout")]
42 #[schemars(range(min = 0, max = 600000))]
43 pub timeout: u64,
44}
45
/// Serde default for `TaskOutputInput::block`: callers wait for completion
/// unless they explicitly opt out.
const fn default_block() -> bool {
    true
}
49
/// Serde default for `TaskOutputInput::timeout`, in milliseconds (30 seconds).
const fn default_timeout() -> u64 {
    30_000
}
53
54#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
55#[serde(rename_all = "snake_case")]
56pub enum TaskStatus {
57 Running,
58 Completed,
59 Failed,
60 Cancelled,
61 NotFound,
62}
63
64impl From<SessionState> for TaskStatus {
65 fn from(state: SessionState) -> Self {
66 match state {
67 SessionState::Active | SessionState::WaitingForTools | SessionState::WaitingForUser => {
68 TaskStatus::Running
69 }
70 SessionState::Completed => TaskStatus::Completed,
71 SessionState::Failed => TaskStatus::Failed,
72 SessionState::Cancelled => TaskStatus::Cancelled,
73 SessionState::Created | SessionState::Paused => TaskStatus::Running,
74 }
75 }
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct TaskOutputResult {
80 pub task_id: String,
81 pub status: TaskStatus,
82 #[serde(skip_serializing_if = "Option::is_none")]
83 pub output: Option<String>,
84 #[serde(skip_serializing_if = "Option::is_none")]
85 pub error: Option<String>,
86}
87
88#[async_trait]
89impl SchemaTool for TaskOutputTool {
90 type Input = TaskOutputInput;
91
92 const NAME: &'static str = "TaskOutput";
93 const DESCRIPTION: &'static str = r#"
94- Retrieves output from a running or completed task (background shell, agent, or remote session)
95- Takes a task_id parameter identifying the task
96- Returns the task output along with status information
97- Use block=true (default) to wait for task completion
98- Use block=false for non-blocking check of current status
99- Task IDs can be found using the Task tool response
100- Works with all task types: background shells, async agents, and remote sessions
101- Output is limited to prevent excessive memory usage; for larger outputs, consider streaming
102- Important: task_id is the Task tool's returned ID, NOT a process PID"#;
103
104 async fn handle(&self, input: TaskOutputInput, _context: &ExecutionContext) -> ToolResult {
105 let timeout = Duration::from_millis(input.timeout.min(600000));
106
107 let result = if input.block {
108 self.registry
109 .wait_for_completion(&input.task_id, timeout)
110 .await
111 } else {
112 self.registry.get_result(&input.task_id).await
113 };
114
115 let output = match result {
116 Some((status, output, error)) => TaskOutputResult {
117 task_id: input.task_id,
118 status: status.into(),
119 output,
120 error,
121 },
122 None => TaskOutputResult {
123 task_id: input.task_id,
124 status: TaskStatus::NotFound,
125 output: None,
126 error: Some("Task not found".to_string()),
127 },
128 };
129
130 ToolResult::success(
131 serde_json::to_string_pretty(&output)
132 .unwrap_or_else(|_| format!("Task status: {:?}", output.status)),
133 )
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::{AgentMetrics, AgentResult, AgentState};
    use crate::session::MemoryPersistence;
    use crate::tools::Tool;
    use crate::types::{StopReason, ToolOutput, Usage};
    use std::sync::Arc;

    const TASK_1_UUID: &str = "00000000-0000-0000-0000-000000000011";
    const TASK_2_UUID: &str = "00000000-0000-0000-0000-000000000012";
    const TASK_3_UUID: &str = "00000000-0000-0000-0000-000000000013";

    /// Builds an empty registry backed by in-memory persistence.
    fn test_registry() -> TaskRegistry {
        TaskRegistry::new(Arc::new(MemoryPersistence::new()))
    }

    /// A successful agent result used to complete registered tasks.
    fn mock_result() -> AgentResult {
        AgentResult {
            text: "Completed successfully".to_string(),
            usage: Usage::default(),
            tool_calls: 0,
            iterations: 1,
            stop_reason: StopReason::EndTurn,
            state: AgentState::Completed,
            metrics: AgentMetrics::default(),
            session_id: "test-session".to_string(),
            structured_output: None,
            messages: Vec::new(),
            uuid: "test-uuid".to_string(),
        }
    }

    /// Asserts that `output` is `ToolOutput::Success` and that its text contains
    /// every string in `needles`.
    ///
    /// Panics on any other variant, so a test cannot pass silently when the
    /// tool returns an unexpected kind of output.
    fn assert_success_contains(output: &ToolOutput, needles: &[&str]) {
        match output {
            ToolOutput::Success(content) => {
                for needle in needles {
                    assert!(
                        content.contains(*needle),
                        "expected {:?} in tool output:\n{}",
                        needle,
                        content
                    );
                }
            }
            _ => panic!("expected ToolOutput::Success"),
        }
    }

    #[tokio::test]
    async fn test_task_output_completed() {
        let registry = test_registry();
        registry
            .register(TASK_1_UUID.into(), "explore".into(), "Test".into())
            .await;
        registry.complete(TASK_1_UUID, mock_result()).await;

        let tool = TaskOutputTool::new(registry);
        let context = crate::tools::ExecutionContext::default();
        let result = tool
            .execute(
                serde_json::json!({
                    "task_id": TASK_1_UUID
                }),
                &context,
            )
            .await;

        assert!(!result.is_error());
        assert_success_contains(&result.output, &["completed"]);
    }

    #[tokio::test]
    async fn test_task_output_not_found() {
        let registry = test_registry();
        let tool = TaskOutputTool::new(registry);
        let context = crate::tools::ExecutionContext::default();

        let result = tool
            .execute(
                serde_json::json!({
                    "task_id": "nonexistent"
                }),
                &context,
            )
            .await;

        // `handle` reports unknown tasks as a successful call with a
        // `not_found` status and a fixed error message.
        assert_success_contains(&result.output, &["not_found", "Task not found"]);
    }

    #[tokio::test]
    async fn test_task_output_non_blocking() {
        let registry = test_registry();
        registry
            .register(TASK_2_UUID.into(), "explore".into(), "Running".into())
            .await;

        let tool = TaskOutputTool::new(registry);
        let context = crate::tools::ExecutionContext::default();
        let result = tool
            .execute(
                serde_json::json!({
                    "task_id": TASK_2_UUID,
                    "block": false
                }),
                &context,
            )
            .await;

        assert_success_contains(&result.output, &["running"]);
    }

    #[tokio::test]
    async fn test_task_output_failed() {
        let registry = test_registry();
        registry
            .register(TASK_3_UUID.into(), "explore".into(), "Failing".into())
            .await;
        registry
            .fail(TASK_3_UUID, "Something went wrong".into())
            .await;

        let tool = TaskOutputTool::new(registry);
        let context = crate::tools::ExecutionContext::default();
        let result = tool
            .execute(
                serde_json::json!({
                    "task_id": TASK_3_UUID
                }),
                &context,
            )
            .await;

        assert_success_contains(&result.output, &["failed", "Something went wrong"]);
    }
}