1use crate::App;
11use anyhow::Result;
12use oxi_agent::{Agent, AgentEvent};
13use std::sync::Arc;
14use std::time::Duration;
15use tokio::sync::mpsc;
16
/// Formats `msg` as a user-facing error line, prefixed with `Error: `.
pub fn format_error(msg: &str) -> String {
    format!("Error: {msg}")
}
24
/// Output format used when running the agent non-interactively.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PrintMode {
    /// Print only the final assistant text after each prompt finishes.
    Text,
    /// Print every agent event as one JSON object per line, as it arrives.
    Json,
}

/// Settings for a non-interactive ("print mode") run.
#[derive(Debug)]
pub struct PrintModeOptions {
    /// Output format.
    pub mode: PrintMode,
    /// Prompts processed in order, after `initial_message`.
    pub messages: Vec<String>,
    /// Prompt processed first, if present.
    pub initial_message: Option<String>,
    /// NOTE(review): presumably disables reading a prompt from stdin; not read by
    /// `run_print_mode` itself — confirm with the caller.
    pub no_stdin: bool,
    /// NOTE(review): presumably disables session persistence; not read by
    /// `run_print_mode` itself — confirm with the caller.
    pub no_session: bool,
    /// Suppress diagnostics written to stderr (error and timeout messages).
    pub quiet: bool,
    /// Overall time limit, in seconds, for processing all prompts.
    pub timeout: Option<u64>,
}

impl Default for PrintModeOptions {
    /// Text output, no prompts, every flag off, and no timeout.
    fn default() -> Self {
        PrintModeOptions {
            mode: PrintMode::Text,
            messages: vec![],
            initial_message: None,
            no_stdin: false,
            no_session: false,
            quiet: false,
            timeout: None,
        }
    }
}
67
68pub async fn run_print_mode(app: &App, options: PrintModeOptions) -> Result<i32> {
72 let PrintModeOptions {
73 mode,
74 messages,
75 initial_message,
76 no_stdin,
77 no_session: _,
78 quiet,
79 timeout,
80 } = options;
81
82 let _ = no_stdin;
84
85 let agent: Arc<Agent> = app.agent();
86 let mut exit_code = 0;
87
88 let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
90 ctrlc_handler(shutdown_tx)?;
91
92 let work = async {
94 if let Some(prompt) = initial_message {
96 let result = run_single_prompt(&agent, &prompt, mode, quiet, &mut shutdown_rx).await;
97 match result {
98 Ok(()) => {}
99 Err(PromptError::AgentError(msg)) => {
100 if mode == PrintMode::Text && !quiet {
101 eprintln!("{}", format_error(&msg));
102 }
103 exit_code = 1;
104 }
105 Err(PromptError::Shutdown) => {
106 exit_code = 130;
107 return;
108 }
109 }
110 }
111
112 for message in messages {
114 if shutdown_rx.try_recv().is_ok() {
115 exit_code = 130;
116 return;
117 }
118
119 let result = run_single_prompt(&agent, &message, mode, quiet, &mut shutdown_rx).await;
120 match result {
121 Ok(()) => {}
122 Err(PromptError::AgentError(msg)) => {
123 if mode == PrintMode::Text && !quiet {
124 eprintln!("{}", format_error(&msg));
125 }
126 exit_code = 1;
127 }
128 Err(PromptError::Shutdown) => {
129 exit_code = 130;
130 return;
131 }
132 }
133 }
134 };
135
136 if let Some(secs) = timeout {
138 match tokio::time::timeout(Duration::from_secs(secs), work).await {
139 Ok(()) => {}
140 Err(_) => {
141 if !quiet {
142 eprintln!("Timed out after {} seconds", secs);
143 }
144 exit_code = 124; }
146 }
147 } else {
148 work.await;
149 }
150
151 Ok(exit_code)
152}
153
/// Why a single prompt run ended without succeeding.
enum PromptError {
    /// The agent reported a failure; carries the error message to show the user.
    AgentError(String),
    /// A shutdown request (Ctrl+C) arrived while the prompt was running.
    Shutdown,
}
159
160async fn run_single_prompt(
162 agent: &Arc<Agent>,
163 prompt: &str,
164 mode: PrintMode,
165 quiet: bool,
166 shutdown_rx: &mut mpsc::Receiver<()>,
167) -> Result<(), PromptError> {
168 let _ = quiet; let (event_tx, event_rx) = std::sync::mpsc::channel::<AgentEvent>();
172 let (async_tx, mut async_rx) = mpsc::channel::<AgentEvent>(256);
173
174 let agent_clone: Arc<Agent> = Arc::clone(agent);
176 let prompt_owned = prompt.to_string();
177
178 let bridge_handle = std::thread::spawn(move || {
185 while let Ok(event) = event_rx.recv() {
186 let _ = async_tx.try_send(event);
188 }
189 });
190
191 let agent_handle = tokio::task::spawn_blocking(move || {
192 let rt = tokio::runtime::Builder::new_current_thread()
193 .enable_all()
194 .build()
195 .expect("failed to build agent runtime");
196 rt.block_on(async {
197 let local = tokio::task::LocalSet::new();
198 local
199 .run_until(async {
200 let _ = agent_clone.run_with_channel(prompt_owned, event_tx).await;
201 })
202 .await;
203 });
204 });
205
206 let mut last_text = String::new();
208 let mut had_error = false;
209 let mut error_message = String::new();
210 let mut _stop_reason: Option<String> = None;
211
212 loop {
213 tokio::select! {
214 event = async_rx.recv() => {
215 match event {
216 Some(ev) => {
217 match &ev {
218 AgentEvent::TextChunk { text } => {
220 last_text.push_str(text);
221 }
222 AgentEvent::MessageUpdate { message: oxi_sdk::Message::Assistant(asst), .. } => {
224 let text_only: String = asst.content.iter()
229 .filter_map(|b| match b {
230 oxi_sdk::ContentBlock::Text(t) => Some(t.text.as_str()),
231 _ => None,
232 })
233 .collect();
234 if !text_only.is_empty() {
235 last_text = text_only;
236 } else {
237 let thinking_text: String = asst.content.iter()
239 .filter_map(|b| match b {
240 oxi_sdk::ContentBlock::Thinking(t) => Some(t.thinking.as_str()),
241 _ => None,
242 })
243 .collect();
244 if !thinking_text.is_empty() {
245 last_text = thinking_text;
246 }
247 }
248 }
249 AgentEvent::MessageEnd { message: oxi_sdk::Message::Assistant(asst) } => {
250 let text_only: String = asst.content.iter()
252 .filter_map(|b| match b {
253 oxi_sdk::ContentBlock::Text(t) => Some(t.text.as_str()),
254 _ => None,
255 })
256 .collect();
257 if !text_only.is_empty() {
258 last_text = text_only;
259 } else {
260 let thinking_text: String = asst.content.iter()
262 .filter_map(|b| match b {
263 oxi_sdk::ContentBlock::Thinking(t) => Some(t.thinking.as_str()),
264 _ => None,
265 })
266 .collect();
267 if !thinking_text.is_empty() {
268 last_text = thinking_text;
269 }
270 }
271 }
272 AgentEvent::Complete { .. } => {
273 _stop_reason = Some("complete".to_string());
274 }
275 AgentEvent::Error { message, .. } => {
276 had_error = true;
277 error_message = message.clone();
278 _stop_reason = Some("error".to_string());
279 }
280 _ => {}
281 }
282
283 if mode == PrintMode::Json
284 && let Ok(json) = serde_json::to_string(&event_to_json(&ev)) {
285 println!("{}", json);
286 use std::io::Write;
287 std::io::stdout().flush().ok();
288 }
289 }
290 None => break,
291 }
292 }
293 _ = shutdown_rx.recv() => {
294 return Err(PromptError::Shutdown);
295 }
296 }
297 }
298
299 let _ = agent_handle.await;
301 let _ = bridge_handle.join();
302
303 if had_error {
304 return Err(PromptError::AgentError(error_message));
305 }
306
307 if mode == PrintMode::Text && !last_text.is_empty() {
309 println!("{}", last_text);
310 use std::io::Write;
311 std::io::stdout().flush().ok();
312 }
313
314 Ok(())
315}
316
317fn extract_text_from_message(msg: &oxi_sdk::Message) -> String {
321 match msg {
322 oxi_sdk::Message::Assistant(asst) => {
323 let text_only: String = asst
324 .content
325 .iter()
326 .filter_map(|b| match b {
327 oxi_sdk::ContentBlock::Text(t) => Some(t.text.as_str()),
328 _ => None,
329 })
330 .collect::<Vec<_>>()
331 .join("");
332 if !text_only.is_empty() {
333 return text_only;
334 }
335 asst.content
337 .iter()
338 .filter_map(|b| match b {
339 oxi_sdk::ContentBlock::Thinking(t) => Some(t.thinking.as_str()),
340 _ => None,
341 })
342 .collect::<Vec<_>>()
343 .join("")
344 }
345 _ => String::new(),
346 }
347}
348
349fn event_to_json(event: &AgentEvent) -> serde_json::Value {
351 match event {
352 AgentEvent::Start { .. } => serde_json::json!({
353 "type": "start"
354 }),
355 AgentEvent::Thinking => serde_json::json!({
356 "type": "thinking"
357 }),
358 AgentEvent::TextChunk { text } => serde_json::json!({
359 "type": "text_delta",
360 "text": text,
361 }),
362 AgentEvent::ToolCall { tool_call } => serde_json::json!({
363 "type": "tool_call",
364 "id": tool_call.id,
365 "name": tool_call.name,
366 "arguments": tool_call.arguments.to_string(),
367 }),
368 AgentEvent::ToolStart {
369 tool_name,
370 tool_call_id,
371 arguments: _,
372 } => serde_json::json!({
373 "type": "tool_start",
374 "tool_name": tool_name,
375 "tool_call_id": tool_call_id,
376 }),
377 AgentEvent::ToolComplete { result } => serde_json::json!({
378 "type": "tool_complete",
379 "content": result.content.chars().take(2000).collect::<String>(),
380 "is_error": result.is_error(),
381 }),
382 AgentEvent::ToolError {
383 error,
384 tool_call_id,
385 } => serde_json::json!({
386 "type": "tool_error",
387 "error": error,
388 "tool_call_id": tool_call_id,
389 }),
390 AgentEvent::Complete { .. } => serde_json::json!({
391 "type": "complete"
392 }),
393 AgentEvent::Error { message, .. } => serde_json::json!({
394 "type": "error",
395 "message": message,
396 }),
397 AgentEvent::Usage {
398 input_tokens,
399 output_tokens,
400 } => serde_json::json!({
401 "type": "usage",
402 "input_tokens": input_tokens,
403 "output_tokens": output_tokens,
404 }),
405
406 AgentEvent::AgentStart { .. } => serde_json::json!({
408 "type": "agent_start"
409 }),
410 AgentEvent::AgentEnd { .. } => serde_json::json!({
411 "type": "agent_end"
412 }),
413 AgentEvent::TurnStart { turn_number } => serde_json::json!({
414 "type": "turn_start",
415 "turn_number": turn_number,
416 }),
417 AgentEvent::TurnEnd { turn_number, .. } => serde_json::json!({
418 "type": "turn_end",
419 "turn_number": turn_number,
420 }),
421 AgentEvent::MessageStart { message } => {
422 let text = extract_text_from_message(message);
423 serde_json::json!({
424 "type": "message_start",
425 "text": text,
426 })
427 }
428 AgentEvent::MessageUpdate { message, delta } => {
429 let text = extract_text_from_message(message);
430 serde_json::json!({
431 "type": "message_update",
432 "text": text,
433 "delta": delta,
434 })
435 }
436 AgentEvent::MessageEnd { message } => {
437 let text = extract_text_from_message(message);
438 serde_json::json!({
439 "type": "message_end",
440 "text": text,
441 })
442 }
443 AgentEvent::ToolExecutionStart {
444 tool_call_id,
445 tool_name,
446 args,
447 ..
448 } => serde_json::json!({
449 "type": "tool_execution_start",
450 "tool_call_id": tool_call_id,
451 "tool_name": tool_name,
452 "args": args.to_string(),
453 }),
454 AgentEvent::ToolExecutionEnd {
455 tool_call_id,
456 tool_name,
457 result,
458 is_error,
459 } => serde_json::json!({
460 "type": "tool_execution_end",
461 "tool_call_id": tool_call_id,
462 "tool_name": tool_name,
463 "result": result.content.chars().take(2000).collect::<String>(),
464 "is_error": is_error,
465 }),
466
467 _ => serde_json::json!({
469 "type": "unknown"
470 }),
471 }
472}
473
474fn ctrlc_handler(shutdown_tx: mpsc::Sender<()>) -> Result<()> {
476 std::thread::spawn(move || {
478 let _ = tokio::runtime::Builder::new_current_thread()
479 .enable_all()
480 .build()
481 .map(|rt| {
482 rt.block_on(async {
483 tokio::signal::ctrl_c().await.ok();
484 let _ = shutdown_tx.try_send(());
485 });
486 });
487 });
488 Ok(())
489}
490
491pub fn read_stdin_prompt() -> Result<String> {
493 use std::io::{self, Read};
494 let mut buffer = String::new();
495 io::stdin().read_to_string(&mut buffer)?;
496 Ok(buffer.trim().to_string())
497}
498
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts an owned event to JSON so each test can assert on it directly.
    fn json_of(event: AgentEvent) -> serde_json::Value {
        event_to_json(&event)
    }

    #[test]
    fn test_event_to_json_start() {
        let json = json_of(AgentEvent::Start {
            prompt: "test".to_string(),
        });
        assert_eq!(json["type"], "start");
    }

    #[test]
    fn test_event_to_json_thinking() {
        assert_eq!(json_of(AgentEvent::Thinking)["type"], "thinking");
    }

    #[test]
    fn test_event_to_json_text_chunk() {
        let json = json_of(AgentEvent::TextChunk {
            text: "Hello world".to_string(),
        });
        assert_eq!(json["type"], "text_delta");
        assert_eq!(json["text"], "Hello world");
    }

    #[test]
    fn test_event_to_json_tool_call() {
        let tool_call = oxi_sdk::ToolCall {
            content_type: oxi_sdk::ToolCallType::ToolCall,
            id: "tc-1".to_string(),
            name: "read_file".to_string(),
            arguments: serde_json::json!({"path": "/tmp/test.rs"}),
            thought_signature: None,
        };
        let json = json_of(AgentEvent::ToolCall { tool_call });
        assert_eq!(json["type"], "tool_call");
        assert_eq!(json["name"], "read_file");
        assert_eq!(json["id"], "tc-1");
    }

    #[test]
    fn test_event_to_json_error() {
        let json = json_of(AgentEvent::Error {
            message: "Something went wrong".to_string(),
            session_id: None,
        });
        assert_eq!(json["type"], "error");
        assert_eq!(json["message"], "Something went wrong");
    }

    #[test]
    fn test_event_to_json_complete() {
        let json = json_of(AgentEvent::Complete {
            content: "done".to_string(),
            stop_reason: "end_turn".to_string(),
        });
        assert_eq!(json["type"], "complete");
    }

    #[test]
    fn test_event_to_json_tool_complete() {
        let result = oxi_sdk::ToolResult {
            tool_call_id: "tc-1".to_string(),
            content: "file contents here".to_string(),
            status: "success".to_string(),
        };
        let json = json_of(AgentEvent::ToolComplete { result });
        assert_eq!(json["type"], "tool_complete");
        assert_eq!(json["is_error"], false);
    }

    #[test]
    fn test_print_mode_default_options() {
        let PrintModeOptions {
            mode,
            messages,
            initial_message,
            no_stdin,
            no_session,
            quiet,
            timeout,
        } = PrintModeOptions::default();
        assert_eq!(mode, PrintMode::Text);
        assert!(messages.is_empty());
        assert!(initial_message.is_none());
        assert!(!no_stdin);
        assert!(!no_session);
        assert!(!quiet);
        assert!(timeout.is_none());
    }

    #[test]
    fn test_print_mode_equality() {
        let (text, json) = (PrintMode::Text, PrintMode::Json);
        assert_eq!(text, PrintMode::Text);
        assert_eq!(json, PrintMode::Json);
        assert_ne!(text, json);
    }
}