1use std::time::Duration;
4
5use async_trait::async_trait;
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8
9use super::task_registry::TaskRegistry;
10use crate::session::SessionState;
11use crate::tools::{ExecutionContext, SchemaTool};
12use crate::types::ToolResult;
13
14pub struct TaskOutputTool {
15 registry: TaskRegistry,
16}
17
18impl TaskOutputTool {
19 pub fn new(registry: TaskRegistry) -> Self {
20 Self { registry }
21 }
22}
23
24impl Clone for TaskOutputTool {
25 fn clone(&self) -> Self {
26 Self {
27 registry: self.registry.clone(),
28 }
29 }
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
33#[schemars(deny_unknown_fields)]
34pub struct TaskOutputInput {
35 pub task_id: String,
37 #[serde(default = "default_block")]
39 pub block: bool,
40 #[serde(default = "default_timeout")]
42 #[schemars(range(min = 0, max = 600000))]
43 pub timeout: u64,
44}
45
/// Serde default for [`TaskOutputInput::block`]: wait for completion unless told otherwise.
fn default_block() -> bool {
    true
}

/// Serde default for [`TaskOutputInput::timeout`], in milliseconds (30 seconds).
fn default_timeout() -> u64 {
    30_000
}
53
54#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
55#[serde(rename_all = "snake_case")]
56pub enum TaskStatus {
57 Running,
58 Completed,
59 Failed,
60 Cancelled,
61 NotFound,
62}
63
64impl From<SessionState> for TaskStatus {
65 fn from(state: SessionState) -> Self {
66 match state {
67 SessionState::Created | SessionState::Active | SessionState::WaitingForTools => {
68 TaskStatus::Running
69 }
70 SessionState::Completed => TaskStatus::Completed,
71 SessionState::Failed => TaskStatus::Failed,
72 SessionState::Cancelled => TaskStatus::Cancelled,
73 }
74 }
75}
76
77#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct TaskOutputResult {
79 pub task_id: String,
80 pub status: TaskStatus,
81 #[serde(skip_serializing_if = "Option::is_none")]
82 pub output: Option<String>,
83 #[serde(skip_serializing_if = "Option::is_none")]
84 pub error: Option<String>,
85}
86
87#[async_trait]
88impl SchemaTool for TaskOutputTool {
89 type Input = TaskOutputInput;
90
91 const NAME: &'static str = "TaskOutput";
92 const DESCRIPTION: &'static str = r#"
93- Retrieves output from a running or completed task (background shell, agent, or remote session)
94- Takes a task_id parameter identifying the task
95- Returns the task output along with status information
96- Use block=true (default) to wait for task completion
97- Use block=false for non-blocking check of current status
98- Task IDs can be found using the Task tool response
99- Works with all task types: background shells, async agents, and remote sessions
100- Output is limited to prevent excessive memory usage; for larger outputs, consider streaming
101- Important: task_id is the Task tool's returned ID, NOT a process PID"#;
102
103 async fn handle(&self, input: TaskOutputInput, _context: &ExecutionContext) -> ToolResult {
104 let timeout = Duration::from_millis(input.timeout.min(600000));
105
106 let result = if input.block {
107 self.registry
108 .wait_for_completion(&input.task_id, timeout)
109 .await
110 } else {
111 self.registry.get_result(&input.task_id).await
112 };
113
114 let output = match result {
115 Some((status, output, error)) => TaskOutputResult {
116 task_id: input.task_id,
117 status: status.into(),
118 output,
119 error,
120 },
121 None => TaskOutputResult {
122 task_id: input.task_id,
123 status: TaskStatus::NotFound,
124 output: None,
125 error: Some("Task not found".to_string()),
126 },
127 };
128
129 ToolResult::success(
130 serde_json::to_string_pretty(&output)
131 .unwrap_or_else(|_| format!("Task status: {:?}", output.status)),
132 )
133 }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::{AgentMetrics, AgentResult, AgentState};
    use crate::session::MemoryPersistence;
    use crate::tools::Tool;
    use crate::types::{StopReason, ToolOutput, Usage};
    use std::sync::Arc;

    // Fixed, distinct IDs so each test registers its own task.
    const TASK_1_UUID: &str = "00000000-0000-0000-0000-000000000011";
    const TASK_2_UUID: &str = "00000000-0000-0000-0000-000000000012";
    const TASK_3_UUID: &str = "00000000-0000-0000-0000-000000000013";

    /// Builds a registry backed by in-memory persistence.
    fn test_registry() -> TaskRegistry {
        TaskRegistry::new(Arc::new(MemoryPersistence::new()))
    }

    /// A minimal successful agent result used to complete a task.
    fn mock_result() -> AgentResult {
        AgentResult {
            text: "Completed successfully".to_string(),
            usage: Usage::default(),
            tool_calls: 0,
            iterations: 1,
            stop_reason: StopReason::EndTurn,
            state: AgentState::Completed,
            metrics: AgentMetrics::default(),
            session_id: "test-session".to_string(),
            structured_output: None,
            messages: Vec::new(),
            uuid: "test-uuid".to_string(),
        }
    }

    // Each test below requires `ToolOutput::Success` via `let ... else` and
    // panics otherwise, so the payload assertions can never be skipped silently.

    #[tokio::test]
    async fn test_task_output_completed() {
        let registry = test_registry();
        registry
            .register(TASK_1_UUID.into(), "Explore".into(), "Test".into())
            .await;
        registry.complete(TASK_1_UUID, mock_result()).await;

        let tool = TaskOutputTool::new(registry);
        let context = crate::tools::ExecutionContext::default();
        let result = tool
            .execute(
                serde_json::json!({
                    "task_id": TASK_1_UUID
                }),
                &context,
            )
            .await;

        assert!(!result.is_error());
        let ToolOutput::Success(content) = &result.output else {
            panic!("expected ToolOutput::Success for a completed task");
        };
        assert!(content.contains("completed"));
    }

    #[tokio::test]
    async fn test_task_output_not_found() {
        let registry = test_registry();
        let tool = TaskOutputTool::new(registry);
        let context = crate::tools::ExecutionContext::default();

        let result = tool
            .execute(
                serde_json::json!({
                    "task_id": "nonexistent"
                }),
                &context,
            )
            .await;

        let ToolOutput::Success(content) = &result.output else {
            panic!("expected ToolOutput::Success carrying a not_found status");
        };
        assert!(content.contains("not_found"));
    }

    #[tokio::test]
    async fn test_task_output_non_blocking() {
        let registry = test_registry();
        registry
            .register(TASK_2_UUID.into(), "Explore".into(), "Running".into())
            .await;

        let tool = TaskOutputTool::new(registry);
        let context = crate::tools::ExecutionContext::default();
        let result = tool
            .execute(
                serde_json::json!({
                    "task_id": TASK_2_UUID,
                    "block": false
                }),
                &context,
            )
            .await;

        let ToolOutput::Success(content) = &result.output else {
            panic!("expected ToolOutput::Success for a running task");
        };
        assert!(content.contains("running"));
    }

    #[tokio::test]
    async fn test_task_output_failed() {
        let registry = test_registry();
        registry
            .register(TASK_3_UUID.into(), "Explore".into(), "Failing".into())
            .await;
        registry
            .fail(TASK_3_UUID, "Something went wrong".into())
            .await;

        let tool = TaskOutputTool::new(registry);
        let context = crate::tools::ExecutionContext::default();
        let result = tool
            .execute(
                serde_json::json!({
                    "task_id": TASK_3_UUID
                }),
                &context,
            )
            .await;

        let ToolOutput::Success(content) = &result.output else {
            panic!("expected ToolOutput::Success for a failed task");
        };
        assert!(content.contains("failed"));
        assert!(content.contains("Something went wrong"));
    }
}