1use crate::App;
11use anyhow::Result;
12use oxi_agent::{Agent, AgentEvent};
13use std::sync::Arc;
14use tokio::sync::mpsc;
15
/// Output format used when running prompts non-interactively.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PrintMode {
    /// Collect the streamed text and print it once the prompt has finished;
    /// agent errors are reported on stderr.
    Text,
    /// Print every agent event as a single-line JSON object as it arrives.
    Json,
}
24
25#[derive(Debug)]
27pub struct PrintModeOptions {
28 pub mode: PrintMode,
30 pub messages: Vec<String>,
32 pub initial_message: Option<String>,
34}
35
36impl Default for PrintModeOptions {
37 fn default() -> Self {
38 Self {
39 mode: PrintMode::Text,
40 messages: Vec::new(),
41 initial_message: None,
42 }
43 }
44}
45
46pub async fn run_print_mode(app: &App, options: PrintModeOptions) -> Result<i32> {
50 let PrintModeOptions {
51 mode,
52 messages,
53 initial_message,
54 } = options;
55
56 let agent: Arc<Agent> = app.agent();
57 let mut exit_code = 0;
58
59 let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
61 ctrlc_handler(shutdown_tx)?;
62
63 if let Some(prompt) = initial_message {
65 let result = run_single_prompt(&agent, &prompt, mode, &mut shutdown_rx).await;
66 match result {
67 Ok(()) => {}
68 Err(PromptError::AgentError(msg)) => {
69 if mode == PrintMode::Text {
70 eprintln!("Error: {}", msg);
71 }
72 exit_code = 1;
73 }
74 Err(PromptError::Shutdown) => {
75 exit_code = 130; return Ok(exit_code);
77 }
78 }
79 }
80
81 for message in messages {
83 if shutdown_rx.try_recv().is_ok() {
84 exit_code = 130;
85 return Ok(exit_code);
86 }
87
88 let result = run_single_prompt(&agent, &message, mode, &mut shutdown_rx).await;
89 match result {
90 Ok(()) => {}
91 Err(PromptError::AgentError(msg)) => {
92 if mode == PrintMode::Text {
93 eprintln!("Error: {}", msg);
94 }
95 exit_code = 1;
96 }
97 Err(PromptError::Shutdown) => {
98 exit_code = 130;
99 return Ok(exit_code);
100 }
101 }
102 }
103
104 Ok(exit_code)
105}
106
/// Reasons a single prompt run can end without succeeding.
///
/// `Debug` is derived so that `Result<_, PromptError>` values can be
/// inspected, unwrapped and used in assertions.
#[derive(Debug)]
enum PromptError {
    /// The agent reported a failure; the payload is the error message.
    AgentError(String),
    /// The run was interrupted by a shutdown request (Ctrl-C).
    Shutdown,
}
112
113async fn run_single_prompt(
115 agent: &Arc<Agent>,
116 prompt: &str,
117 mode: PrintMode,
118 shutdown_rx: &mut mpsc::Receiver<()>,
119) -> Result<(), PromptError> {
120 let (event_tx, mut event_rx) = mpsc::channel::<AgentEvent>(256);
121
122 let agent_clone: Arc<Agent> = Arc::clone(agent);
124 let prompt_owned = prompt.to_string();
125
126 let agent_handle = tokio::task::spawn_blocking(move || {
127 let rt = tokio::runtime::Builder::new_current_thread()
128 .enable_all()
129 .build()
130 .expect("failed to build agent runtime");
131 rt.block_on(async {
132 let local = tokio::task::LocalSet::new();
133 local
134 .run_until(async {
135 let _ = agent_clone.run_with_channel(prompt_owned, event_tx).await;
136 })
137 .await;
138 });
139 });
140
141 let mut last_text = String::new();
143 let mut had_error = false;
144 let mut error_message = String::new();
145 let mut _stop_reason: Option<String> = None;
146
147 loop {
148 tokio::select! {
149 event = event_rx.recv() => {
150 match event {
151 Some(ev) => {
152 match &ev {
153 AgentEvent::TextChunk { text } => {
154 last_text.push_str(text);
155 }
156 AgentEvent::Complete { .. } => {
157 _stop_reason = Some("complete".to_string());
158 }
159 AgentEvent::Error { message } => {
160 had_error = true;
161 error_message = message.clone();
162 _stop_reason = Some("error".to_string());
163 }
164 _ => {}
165 }
166
167 if mode == PrintMode::Json {
168 if let Ok(json) = serde_json::to_string(&event_to_json(&ev)) {
169 println!("{}", json);
170 }
171 }
172 }
173 None => break,
174 }
175 }
176 _ = shutdown_rx.recv() => {
177 return Err(PromptError::Shutdown);
178 }
179 }
180 }
181
182 let _ = agent_handle.await;
184
185 if had_error {
186 return Err(PromptError::AgentError(error_message));
187 }
188
189 if mode == PrintMode::Text && !last_text.is_empty() {
191 println!("{}", last_text);
192 }
193
194 Ok(())
195}
196
197fn event_to_json(event: &AgentEvent) -> serde_json::Value {
199 match event {
200 AgentEvent::Start { .. } => serde_json::json!({
201 "type": "start"
202 }),
203 AgentEvent::Thinking => serde_json::json!({
204 "type": "thinking"
205 }),
206 AgentEvent::TextChunk { text } => serde_json::json!({
207 "type": "text_delta",
208 "text": text,
209 }),
210 AgentEvent::ToolCall { tool_call } => serde_json::json!({
211 "type": "tool_call",
212 "id": tool_call.id,
213 "name": tool_call.name,
214 "arguments": tool_call.arguments.to_string(),
215 }),
216 AgentEvent::ToolStart { tool_name, tool_call_id } => serde_json::json!({
217 "type": "tool_start",
218 "tool_name": tool_name,
219 "tool_call_id": tool_call_id,
220 }),
221 AgentEvent::ToolComplete { result } => serde_json::json!({
222 "type": "tool_complete",
223 "content": result.content.chars().take(2000).collect::<String>(),
224 "is_error": result.is_error(),
225 }),
226 AgentEvent::ToolError { error, tool_call_id } => serde_json::json!({
227 "type": "tool_error",
228 "error": error,
229 "tool_call_id": tool_call_id,
230 }),
231 AgentEvent::Complete { .. } => serde_json::json!({
232 "type": "complete"
233 }),
234 AgentEvent::Error { message } => serde_json::json!({
235 "type": "error",
236 "message": message,
237 }),
238 _ => serde_json::json!({
239 "type": "unknown"
240 }),
241 }
242}
243
244fn ctrlc_handler(shutdown_tx: mpsc::Sender<()>) -> Result<()> {
246 std::thread::spawn(move || {
248 let _ = tokio::runtime::Builder::new_current_thread()
249 .enable_all()
250 .build()
251 .map(|rt| {
252 rt.block_on(async {
253 tokio::signal::ctrl_c().await.ok();
254 let _ = shutdown_tx.try_send(());
255 });
256 });
257 });
258 Ok(())
259}
260
261pub fn read_stdin_prompt() -> Result<String> {
263 use std::io::{self, Read};
264 let mut buffer = String::new();
265 io::stdin().read_to_string(&mut buffer)?;
266 Ok(buffer.trim().to_string())
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    // --- event_to_json: one test per mapped event variant ---

    #[test]
    fn test_event_to_json_start() {
        let rendered = event_to_json(&AgentEvent::Start {
            prompt: String::from("test"),
        });
        assert_eq!(rendered["type"], "start");
    }

    #[test]
    fn test_event_to_json_thinking() {
        let rendered = event_to_json(&AgentEvent::Thinking);
        assert_eq!(rendered["type"], "thinking");
    }

    #[test]
    fn test_event_to_json_text_chunk() {
        let rendered = event_to_json(&AgentEvent::TextChunk {
            text: String::from("Hello world"),
        });
        assert_eq!(rendered["type"], "text_delta");
        assert_eq!(rendered["text"], "Hello world");
    }

    #[test]
    fn test_event_to_json_tool_call() {
        let call = oxi_ai::ToolCall {
            content_type: oxi_ai::ToolCallType::ToolCall,
            id: String::from("tc-1"),
            name: String::from("read_file"),
            arguments: serde_json::json!({"path": "/tmp/test.rs"}),
            thought_signature: None,
        };
        let rendered = event_to_json(&AgentEvent::ToolCall { tool_call: call });
        assert_eq!(rendered["type"], "tool_call");
        assert_eq!(rendered["name"], "read_file");
        assert_eq!(rendered["id"], "tc-1");
    }

    #[test]
    fn test_event_to_json_error() {
        let rendered = event_to_json(&AgentEvent::Error {
            message: String::from("Something went wrong"),
        });
        assert_eq!(rendered["type"], "error");
        assert_eq!(rendered["message"], "Something went wrong");
    }

    #[test]
    fn test_event_to_json_complete() {
        let rendered = event_to_json(&AgentEvent::Complete {
            content: String::from("done"),
            stop_reason: String::from("end_turn"),
        });
        assert_eq!(rendered["type"], "complete");
    }

    #[test]
    fn test_event_to_json_tool_complete() {
        let outcome = oxi_ai::ToolResult {
            tool_call_id: String::from("tc-1"),
            content: String::from("file contents here"),
            status: String::from("success"),
        };
        let rendered = event_to_json(&AgentEvent::ToolComplete { result: outcome });
        assert_eq!(rendered["type"], "tool_complete");
        assert_eq!(rendered["is_error"], false);
    }

    // --- option and mode types ---

    #[test]
    fn test_print_mode_default_options() {
        let defaults = PrintModeOptions::default();
        assert_eq!(defaults.mode, PrintMode::Text);
        assert!(defaults.messages.is_empty());
        assert!(defaults.initial_message.is_none());
    }

    #[test]
    fn test_print_mode_equality() {
        assert_eq!(PrintMode::Text, PrintMode::Text);
        assert_eq!(PrintMode::Json, PrintMode::Json);
        assert_ne!(PrintMode::Text, PrintMode::Json);
    }
}