1use crate::App;
11use anyhow::Result;
12use oxi_agent::{Agent, AgentEvent};
13use std::sync::Arc;
14use std::time::Duration;
15use tokio::sync::mpsc;
16
/// Output format used by the non-interactive print mode.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PrintMode {
    /// Print the final assistant text to stdout once a prompt has finished.
    Text,
    /// Print every agent event to stdout as one line of JSON.
    Json,
}
25
26#[derive(Debug)]
28pub struct PrintModeOptions {
29 pub mode: PrintMode,
31 pub messages: Vec<String>,
33 pub initial_message: Option<String>,
35 pub no_stdin: bool,
38 pub no_session: bool,
40 pub quiet: bool,
42 pub timeout: Option<u64>,
44}
45
46impl Default for PrintModeOptions {
47 fn default() -> Self {
48 Self {
49 mode: PrintMode::Text,
50 messages: Vec::new(),
51 initial_message: None,
52 no_stdin: false,
53 no_session: false,
54 quiet: false,
55 timeout: None,
56 }
57 }
58}
59
60pub async fn run_print_mode(app: &App, options: PrintModeOptions) -> Result<i32> {
64 let PrintModeOptions {
65 mode,
66 messages,
67 initial_message,
68 no_stdin,
69 no_session: _,
70 quiet,
71 timeout,
72 } = options;
73
74 let _ = no_stdin;
76
77 let agent: Arc<Agent> = app.agent();
78 let mut exit_code = 0;
79
80 let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
82 ctrlc_handler(shutdown_tx)?;
83
84 let work = async {
86 if let Some(prompt) = initial_message {
88 let result = run_single_prompt(&agent, &prompt, mode, quiet, &mut shutdown_rx).await;
89 match result {
90 Ok(()) => {}
91 Err(PromptError::AgentError(msg)) => {
92 if mode == PrintMode::Text && !quiet {
93 eprintln!("Error: {}", msg);
94 }
95 exit_code = 1;
96 }
97 Err(PromptError::Shutdown) => {
98 exit_code = 130;
99 return;
100 }
101 }
102 }
103
104 for message in messages {
106 if shutdown_rx.try_recv().is_ok() {
107 exit_code = 130;
108 return;
109 }
110
111 let result = run_single_prompt(&agent, &message, mode, quiet, &mut shutdown_rx).await;
112 match result {
113 Ok(()) => {}
114 Err(PromptError::AgentError(msg)) => {
115 if mode == PrintMode::Text && !quiet {
116 eprintln!("Error: {}", msg);
117 }
118 exit_code = 1;
119 }
120 Err(PromptError::Shutdown) => {
121 exit_code = 130;
122 return;
123 }
124 }
125 }
126 };
127
128 if let Some(secs) = timeout {
130 match tokio::time::timeout(Duration::from_secs(secs), work).await {
131 Ok(()) => {}
132 Err(_) => {
133 if !quiet {
134 eprintln!("Timed out after {} seconds", secs);
135 }
136 exit_code = 124; }
138 }
139 } else {
140 work.await;
141 }
142
143 Ok(exit_code)
144}
145
/// Reasons a single prompt run can end without success.
///
/// `Debug` and equality are derived so results carrying this error can be unwrapped,
/// logged and compared in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
enum PromptError {
    /// The agent reported a failure; the payload is the error message.
    AgentError(String),
    /// A shutdown was signalled while the prompt was running.
    Shutdown,
}
151
152async fn run_single_prompt(
154 agent: &Arc<Agent>,
155 prompt: &str,
156 mode: PrintMode,
157 quiet: bool,
158 shutdown_rx: &mut mpsc::Receiver<()>,
159) -> Result<(), PromptError> {
160 let _ = quiet; let (event_tx, event_rx) = std::sync::mpsc::channel::<AgentEvent>();
164 let (async_tx, mut async_rx) = mpsc::channel::<AgentEvent>(256);
165
166 let agent_clone: Arc<Agent> = Arc::clone(agent);
168 let prompt_owned = prompt.to_string();
169
170 let bridge_handle = std::thread::spawn(move || {
177 while let Ok(event) = event_rx.recv() {
178 let _ = async_tx.try_send(event);
180 }
181 });
182
183 let agent_handle = tokio::task::spawn_blocking(move || {
184 let rt = tokio::runtime::Builder::new_current_thread()
185 .enable_all()
186 .build()
187 .expect("failed to build agent runtime");
188 rt.block_on(async {
189 let local = tokio::task::LocalSet::new();
190 local
191 .run_until(async {
192 let _ = agent_clone.run_with_channel(prompt_owned, event_tx).await;
193 })
194 .await;
195 });
196 });
197
198 let mut last_text = String::new();
200 let mut had_error = false;
201 let mut error_message = String::new();
202 let mut _stop_reason: Option<String> = None;
203
204 loop {
205 tokio::select! {
206 event = async_rx.recv() => {
207 match event {
208 Some(ev) => {
209 match &ev {
210 AgentEvent::TextChunk { text } => {
212 last_text.push_str(text);
213 }
214 AgentEvent::MessageUpdate { message: oxi_ai::Message::Assistant(asst), .. } => {
216 let text_only: String = asst.content.iter()
221 .filter_map(|b| match b {
222 oxi_ai::ContentBlock::Text(t) => Some(t.text.as_str()),
223 _ => None,
224 })
225 .collect();
226 if !text_only.is_empty() {
227 last_text = text_only;
228 } else {
229 let thinking_text: String = asst.content.iter()
231 .filter_map(|b| match b {
232 oxi_ai::ContentBlock::Thinking(t) => Some(t.thinking.as_str()),
233 _ => None,
234 })
235 .collect();
236 if !thinking_text.is_empty() {
237 last_text = thinking_text;
238 }
239 }
240 }
241 AgentEvent::MessageEnd { message: oxi_ai::Message::Assistant(asst) } => {
242 let text_only: String = asst.content.iter()
244 .filter_map(|b| match b {
245 oxi_ai::ContentBlock::Text(t) => Some(t.text.as_str()),
246 _ => None,
247 })
248 .collect();
249 if !text_only.is_empty() {
250 last_text = text_only;
251 } else {
252 let thinking_text: String = asst.content.iter()
254 .filter_map(|b| match b {
255 oxi_ai::ContentBlock::Thinking(t) => Some(t.thinking.as_str()),
256 _ => None,
257 })
258 .collect();
259 if !thinking_text.is_empty() {
260 last_text = thinking_text;
261 }
262 }
263 }
264 AgentEvent::Complete { .. } => {
265 _stop_reason = Some("complete".to_string());
266 }
267 AgentEvent::Error { message, .. } => {
268 had_error = true;
269 error_message = message.clone();
270 _stop_reason = Some("error".to_string());
271 }
272 _ => {}
273 }
274
275 if mode == PrintMode::Json {
276 if let Ok(json) = serde_json::to_string(&event_to_json(&ev)) {
277 println!("{}", json);
278 use std::io::Write;
279 std::io::stdout().flush().ok();
280 }
281 }
282 }
283 None => break,
284 }
285 }
286 _ = shutdown_rx.recv() => {
287 return Err(PromptError::Shutdown);
288 }
289 }
290 }
291
292 let _ = agent_handle.await;
294 let _ = bridge_handle.join();
295
296 if had_error {
297 return Err(PromptError::AgentError(error_message));
298 }
299
300 if mode == PrintMode::Text && !last_text.is_empty() {
302 println!("{}", last_text);
303 use std::io::Write;
304 std::io::stdout().flush().ok();
305 }
306
307 Ok(())
308}
309
310fn extract_text_from_message(msg: &oxi_ai::Message) -> String {
314 match msg {
315 oxi_ai::Message::Assistant(asst) => {
316 let text_only: String = asst
317 .content
318 .iter()
319 .filter_map(|b| match b {
320 oxi_ai::ContentBlock::Text(t) => Some(t.text.as_str()),
321 _ => None,
322 })
323 .collect::<Vec<_>>()
324 .join("");
325 if !text_only.is_empty() {
326 return text_only;
327 }
328 asst.content
330 .iter()
331 .filter_map(|b| match b {
332 oxi_ai::ContentBlock::Thinking(t) => Some(t.thinking.as_str()),
333 _ => None,
334 })
335 .collect::<Vec<_>>()
336 .join("")
337 }
338 _ => String::new(),
339 }
340}
341
342fn event_to_json(event: &AgentEvent) -> serde_json::Value {
344 match event {
345 AgentEvent::Start { .. } => serde_json::json!({
346 "type": "start"
347 }),
348 AgentEvent::Thinking => serde_json::json!({
349 "type": "thinking"
350 }),
351 AgentEvent::TextChunk { text } => serde_json::json!({
352 "type": "text_delta",
353 "text": text,
354 }),
355 AgentEvent::ToolCall { tool_call } => serde_json::json!({
356 "type": "tool_call",
357 "id": tool_call.id,
358 "name": tool_call.name,
359 "arguments": tool_call.arguments.to_string(),
360 }),
361 AgentEvent::ToolStart {
362 tool_name,
363 tool_call_id,
364 arguments: _,
365 } => serde_json::json!({
366 "type": "tool_start",
367 "tool_name": tool_name,
368 "tool_call_id": tool_call_id,
369 }),
370 AgentEvent::ToolComplete { result } => serde_json::json!({
371 "type": "tool_complete",
372 "content": result.content.chars().take(2000).collect::<String>(),
373 "is_error": result.is_error(),
374 }),
375 AgentEvent::ToolError {
376 error,
377 tool_call_id,
378 } => serde_json::json!({
379 "type": "tool_error",
380 "error": error,
381 "tool_call_id": tool_call_id,
382 }),
383 AgentEvent::Complete { .. } => serde_json::json!({
384 "type": "complete"
385 }),
386 AgentEvent::Error { message, .. } => serde_json::json!({
387 "type": "error",
388 "message": message,
389 }),
390 AgentEvent::Usage {
391 input_tokens,
392 output_tokens,
393 } => serde_json::json!({
394 "type": "usage",
395 "input_tokens": input_tokens,
396 "output_tokens": output_tokens,
397 }),
398
399 AgentEvent::AgentStart { .. } => serde_json::json!({
401 "type": "agent_start"
402 }),
403 AgentEvent::AgentEnd { .. } => serde_json::json!({
404 "type": "agent_end"
405 }),
406 AgentEvent::TurnStart { turn_number } => serde_json::json!({
407 "type": "turn_start",
408 "turn_number": turn_number,
409 }),
410 AgentEvent::TurnEnd { turn_number, .. } => serde_json::json!({
411 "type": "turn_end",
412 "turn_number": turn_number,
413 }),
414 AgentEvent::MessageStart { message } => {
415 let text = extract_text_from_message(message);
416 serde_json::json!({
417 "type": "message_start",
418 "text": text,
419 })
420 }
421 AgentEvent::MessageUpdate { message, delta } => {
422 let text = extract_text_from_message(message);
423 serde_json::json!({
424 "type": "message_update",
425 "text": text,
426 "delta": delta,
427 })
428 }
429 AgentEvent::MessageEnd { message } => {
430 let text = extract_text_from_message(message);
431 serde_json::json!({
432 "type": "message_end",
433 "text": text,
434 })
435 }
436 AgentEvent::ToolExecutionStart {
437 tool_call_id,
438 tool_name,
439 args,
440 } => serde_json::json!({
441 "type": "tool_execution_start",
442 "tool_call_id": tool_call_id,
443 "tool_name": tool_name,
444 "args": args.to_string(),
445 }),
446 AgentEvent::ToolExecutionEnd {
447 tool_call_id,
448 tool_name,
449 result,
450 is_error,
451 } => serde_json::json!({
452 "type": "tool_execution_end",
453 "tool_call_id": tool_call_id,
454 "tool_name": tool_name,
455 "result": result.content.chars().take(2000).collect::<String>(),
456 "is_error": is_error,
457 }),
458
459 _ => serde_json::json!({
461 "type": "unknown"
462 }),
463 }
464}
465
466fn ctrlc_handler(shutdown_tx: mpsc::Sender<()>) -> Result<()> {
468 std::thread::spawn(move || {
470 let _ = tokio::runtime::Builder::new_current_thread()
471 .enable_all()
472 .build()
473 .map(|rt| {
474 rt.block_on(async {
475 tokio::signal::ctrl_c().await.ok();
476 let _ = shutdown_tx.try_send(());
477 });
478 });
479 });
480 Ok(())
481}
482
483pub fn read_stdin_prompt() -> Result<String> {
485 use std::io::{self, Read};
486 let mut buffer = String::new();
487 io::stdin().read_to_string(&mut buffer)?;
488 Ok(buffer.trim().to_string())
489}
490
#[cfg(test)]
mod tests {
    use super::*;

    /// `Start` serializes with the `start` tag.
    #[test]
    fn test_event_to_json_start() {
        let json = event_to_json(&AgentEvent::Start {
            prompt: String::from("test"),
        });
        assert_eq!(json["type"], "start");
    }

    /// `Thinking` serializes with the `thinking` tag.
    #[test]
    fn test_event_to_json_thinking() {
        let event = AgentEvent::Thinking;
        let json = event_to_json(&event);
        assert_eq!(json["type"], "thinking");
    }

    /// `TextChunk` serializes as a `text_delta` carrying the chunk text.
    #[test]
    fn test_event_to_json_text_chunk() {
        let json = event_to_json(&AgentEvent::TextChunk {
            text: String::from("Hello world"),
        });
        assert_eq!(json["type"], "text_delta");
        assert_eq!(json["text"], "Hello world");
    }

    /// `ToolCall` exposes the call id and tool name.
    #[test]
    fn test_event_to_json_tool_call() {
        let tool_call = oxi_ai::ToolCall {
            content_type: oxi_ai::ToolCallType::ToolCall,
            id: String::from("tc-1"),
            name: String::from("read_file"),
            arguments: serde_json::json!({"path": "/tmp/test.rs"}),
            thought_signature: None,
        };
        let json = event_to_json(&AgentEvent::ToolCall { tool_call });
        assert_eq!(json["type"], "tool_call");
        assert_eq!(json["name"], "read_file");
        assert_eq!(json["id"], "tc-1");
    }

    /// `Error` serializes with its message.
    #[test]
    fn test_event_to_json_error() {
        let json = event_to_json(&AgentEvent::Error {
            message: String::from("Something went wrong"),
            session_id: None,
        });
        assert_eq!(json["type"], "error");
        assert_eq!(json["message"], "Something went wrong");
    }

    /// `Complete` serializes with the `complete` tag.
    #[test]
    fn test_event_to_json_complete() {
        let json = event_to_json(&AgentEvent::Complete {
            content: String::from("done"),
            stop_reason: String::from("end_turn"),
        });
        assert_eq!(json["type"], "complete");
    }

    /// A successful `ToolComplete` is reported with `is_error == false`.
    #[test]
    fn test_event_to_json_tool_complete() {
        let result = oxi_ai::ToolResult {
            tool_call_id: String::from("tc-1"),
            content: String::from("file contents here"),
            status: String::from("success"),
        };
        let json = event_to_json(&AgentEvent::ToolComplete { result });
        assert_eq!(json["type"], "tool_complete");
        assert_eq!(json["is_error"], false);
    }

    /// Default options: text mode, nothing queued, all flags off, no timeout.
    #[test]
    fn test_print_mode_default_options() {
        let PrintModeOptions {
            mode,
            messages,
            initial_message,
            no_stdin,
            no_session,
            quiet,
            timeout,
        } = PrintModeOptions::default();

        assert_eq!(mode, PrintMode::Text);
        assert!(messages.is_empty());
        assert!(initial_message.is_none());
        assert!(!no_stdin);
        assert!(!no_session);
        assert!(!quiet);
        assert!(timeout.is_none());
    }

    /// Each mode equals itself and differs from the other.
    #[test]
    fn test_print_mode_equality() {
        assert_eq!(PrintMode::Text, PrintMode::Text);
        assert_eq!(PrintMode::Json, PrintMode::Json);
        assert_ne!(PrintMode::Text, PrintMode::Json);
    }
}