1use std::time::Duration;
4
5use async_trait::async_trait;
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8
9use super::task_registry::TaskRegistry;
10use crate::session::SessionState;
11use crate::tools::{ExecutionContext, SchemaTool};
12use crate::types::ToolResult;
13
14#[derive(Clone)]
15pub struct TaskOutputTool {
16 registry: TaskRegistry,
17}
18
19impl TaskOutputTool {
20 pub fn new(registry: TaskRegistry) -> Self {
21 Self { registry }
22 }
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
26#[schemars(deny_unknown_fields)]
27pub struct TaskOutputInput {
28 pub task_id: String,
30 #[serde(default = "default_block")]
32 pub block: bool,
33 #[serde(default = "default_timeout")]
35 #[schemars(range(min = 0, max = 600000))]
36 pub timeout: u64,
37}
38
/// Serde default for [`TaskOutputInput::block`]: wait for the task unless the
/// caller opts out.
fn default_block() -> bool {
    true
}

/// Serde default for [`TaskOutputInput::timeout`]: 30 seconds, expressed in
/// milliseconds.
fn default_timeout() -> u64 {
    30_000
}
46
/// Coarse status of a task as reported by the `TaskOutput` tool.
///
/// Serialized in `snake_case` (e.g. `NotFound` becomes `"not_found"`).
// NOTE: keep variant order stable; formats that encode variant indices depend on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TaskStatus {
    /// Mapped from `SessionState::Created`, `Active`, or `WaitingForTools`.
    Running,
    /// Mapped from `SessionState::Completed`.
    Completed,
    /// Mapped from `SessionState::Failed`.
    Failed,
    /// Mapped from `SessionState::Cancelled`.
    Cancelled,
    /// The registry returned no result for the requested task ID.
    NotFound,
}
56
57impl From<SessionState> for TaskStatus {
58 fn from(state: SessionState) -> Self {
59 match state {
60 SessionState::Created | SessionState::Active | SessionState::WaitingForTools => {
61 TaskStatus::Running
62 }
63 SessionState::Completed => TaskStatus::Completed,
64 SessionState::Failed => TaskStatus::Failed,
65 SessionState::Cancelled => TaskStatus::Cancelled,
66 }
67 }
68}
69
/// JSON payload returned by the `TaskOutput` tool.
///
/// Field declaration order determines key order in the serialized JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskOutputResult {
    /// ID of the queried task, echoed back from the input.
    pub task_id: String,
    /// Status of the task at the time of the query.
    pub status: TaskStatus,
    /// Task output, if the registry provided any; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output: Option<String>,
    /// Error message, if any; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
79
80#[async_trait]
81impl SchemaTool for TaskOutputTool {
82 type Input = TaskOutputInput;
83
84 const NAME: &'static str = "TaskOutput";
85 const DESCRIPTION: &'static str = r#"
86- Retrieves output from a running or completed task (background shell, agent, or remote session)
87- Takes a task_id parameter identifying the task
88- Returns the task output along with status information
89- Use block=true (default) to wait for task completion
90- Use block=false for non-blocking check of current status
91- Task IDs can be found using the Task tool response
92- Works with all task types: background shells, async agents, and remote sessions
93- Output is limited to prevent excessive memory usage; for larger outputs, consider streaming
94- Important: task_id is the Task tool's returned ID, NOT a process PID"#;
95
96 async fn handle(&self, input: TaskOutputInput, _context: &ExecutionContext) -> ToolResult {
97 let timeout = Duration::from_millis(input.timeout.min(600000));
98
99 let result = if input.block {
100 self.registry
101 .wait_for_completion(&input.task_id, timeout)
102 .await
103 } else {
104 self.registry.get_result(&input.task_id).await
105 };
106
107 let output = match result {
108 Some((status, output, error)) => TaskOutputResult {
109 task_id: input.task_id,
110 status: status.into(),
111 output,
112 error,
113 },
114 None => TaskOutputResult {
115 task_id: input.task_id,
116 status: TaskStatus::NotFound,
117 output: None,
118 error: Some("Task not found".to_string()),
119 },
120 };
121
122 ToolResult::success(
123 serde_json::to_string_pretty(&output)
124 .unwrap_or_else(|_| format!("Task status: {:?}", output.status)),
125 )
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::{AgentMetrics, AgentResult, AgentState};
    use crate::session::MemoryPersistence;
    use crate::tools::Tool;
    use crate::types::{StopReason, ToolOutput, Usage};
    use std::sync::Arc;

    const TASK_1_UUID: &str = "00000000-0000-0000-0000-000000000011";
    const TASK_2_UUID: &str = "00000000-0000-0000-0000-000000000012";
    const TASK_3_UUID: &str = "00000000-0000-0000-0000-000000000013";

    fn test_registry() -> TaskRegistry {
        TaskRegistry::new(Arc::new(MemoryPersistence::new()))
    }

    fn mock_result() -> AgentResult {
        AgentResult {
            text: "Completed successfully".to_string(),
            usage: Usage::default(),
            tool_calls: 0,
            iterations: 1,
            stop_reason: StopReason::EndTurn,
            state: AgentState::Completed,
            metrics: AgentMetrics::default(),
            session_id: "test-session".to_string(),
            structured_output: None,
            messages: Vec::new(),
            uuid: "test-uuid".to_string(),
        }
    }

    /// Executes `tool` with `input`, asserts the call succeeded, and returns the
    /// tool output parsed as JSON.
    ///
    /// Panics (failing the test) when the output is not `ToolOutput::Success` or
    /// not valid JSON, instead of silently skipping the caller's assertions.
    async fn run_tool(tool: &TaskOutputTool, input: serde_json::Value) -> serde_json::Value {
        let context = crate::tools::ExecutionContext::default();
        let result = tool.execute(input, &context).await;
        assert!(!result.is_error(), "TaskOutput returned an error result");
        match &result.output {
            ToolOutput::Success(content) => serde_json::from_str(content).unwrap_or_else(|err| {
                panic!("TaskOutput output is not valid JSON ({}): {}", err, content)
            }),
            _ => panic!("expected ToolOutput::Success from TaskOutput"),
        }
    }

    #[tokio::test]
    async fn test_task_output_completed() {
        let registry = test_registry();
        registry
            .register(TASK_1_UUID.into(), "Explore".into(), "Test".into())
            .await;
        registry.complete(TASK_1_UUID, mock_result()).await;

        let tool = TaskOutputTool::new(registry);
        let json = run_tool(&tool, serde_json::json!({ "task_id": TASK_1_UUID })).await;

        assert_eq!(json["task_id"], TASK_1_UUID);
        assert_eq!(json["status"], "completed");
    }

    #[tokio::test]
    async fn test_task_output_not_found() {
        let tool = TaskOutputTool::new(test_registry());
        let json = run_tool(&tool, serde_json::json!({ "task_id": "nonexistent" })).await;

        assert_eq!(json["task_id"], "nonexistent");
        assert_eq!(json["status"], "not_found");
        assert_eq!(json["error"], "Task not found");
        assert!(
            json.get("output").is_none(),
            "absent output should be omitted from the JSON"
        );
    }

    #[tokio::test]
    async fn test_task_output_non_blocking() {
        let registry = test_registry();
        registry
            .register(TASK_2_UUID.into(), "Explore".into(), "Running".into())
            .await;

        let tool = TaskOutputTool::new(registry);
        let json = run_tool(
            &tool,
            serde_json::json!({
                "task_id": TASK_2_UUID,
                "block": false
            }),
        )
        .await;

        assert_eq!(json["task_id"], TASK_2_UUID);
        assert_eq!(json["status"], "running");
    }

    #[tokio::test]
    async fn test_task_output_failed() {
        let registry = test_registry();
        registry
            .register(TASK_3_UUID.into(), "Explore".into(), "Failing".into())
            .await;
        registry
            .fail(TASK_3_UUID, "Something went wrong".into())
            .await;

        let tool = TaskOutputTool::new(registry);
        let json = run_tool(&tool, serde_json::json!({ "task_id": TASK_3_UUID })).await;

        assert_eq!(json["status"], "failed");
        // The failure message must be surfaced somewhere in the payload.
        assert!(json.to_string().contains("Something went wrong"));
    }

    #[test]
    fn test_input_defaults() {
        let input: TaskOutputInput =
            serde_json::from_value(serde_json::json!({ "task_id": TASK_1_UUID }))
                .expect("minimal input should deserialize");

        assert_eq!(input.task_id, TASK_1_UUID);
        assert!(input.block);
        assert_eq!(input.timeout, 30000);
    }

    #[test]
    fn test_session_state_maps_to_task_status() {
        use crate::session::SessionState;

        assert_eq!(TaskStatus::from(SessionState::Created), TaskStatus::Running);
        assert_eq!(TaskStatus::from(SessionState::Active), TaskStatus::Running);
        assert_eq!(
            TaskStatus::from(SessionState::WaitingForTools),
            TaskStatus::Running
        );
        assert_eq!(
            TaskStatus::from(SessionState::Completed),
            TaskStatus::Completed
        );
        assert_eq!(TaskStatus::from(SessionState::Failed), TaskStatus::Failed);
        assert_eq!(
            TaskStatus::from(SessionState::Cancelled),
            TaskStatus::Cancelled
        );
    }
}