1use crate::App;
11use anyhow::Result;
12use oxi_agent::{Agent, AgentEvent};
13use std::sync::Arc;
14use std::time::Duration;
15use tokio::sync::mpsc;
16
/// Output format used by `run_print_mode`.
///
/// `Text` prints the final assistant text of each prompt on stdout; `Json` streams one JSON
/// object per agent event on stdout as the events arrive.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PrintMode {
    /// Human-readable output: the latest assistant text, printed once per prompt.
    Text,
    /// Machine-readable output: one JSON line per agent event.
    Json,
}
25
26#[derive(Debug)]
28pub struct PrintModeOptions {
29 pub mode: PrintMode,
31 pub messages: Vec<String>,
33 pub initial_message: Option<String>,
35 pub no_stdin: bool,
38 pub no_session: bool,
40 pub quiet: bool,
42 pub timeout: Option<u64>,
44}
45
46impl Default for PrintModeOptions {
47 fn default() -> Self {
48 Self {
49 mode: PrintMode::Text,
50 messages: Vec::new(),
51 initial_message: None,
52 no_stdin: false,
53 no_session: false,
54 quiet: false,
55 timeout: None,
56 }
57 }
58}
59
60pub async fn run_print_mode(app: &App, options: PrintModeOptions) -> Result<i32> {
64 let PrintModeOptions {
65 mode,
66 messages,
67 initial_message,
68 no_stdin,
69 no_session: _,
70 quiet,
71 timeout,
72 } = options;
73
74 let _ = no_stdin;
76
77 let agent: Arc<Agent> = app.agent();
78 let mut exit_code = 0;
79
80 let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
82 ctrlc_handler(shutdown_tx)?;
83
84 let work = async {
86 if let Some(prompt) = initial_message {
88 let result = run_single_prompt(&agent, &prompt, mode, quiet, &mut shutdown_rx).await;
89 match result {
90 Ok(()) => {}
91 Err(PromptError::AgentError(msg)) => {
92 if mode == PrintMode::Text && !quiet {
93 eprintln!("Error: {}", msg);
94 }
95 exit_code = 1;
96 }
97 Err(PromptError::Shutdown) => {
98 exit_code = 130;
99 return;
100 }
101 }
102 }
103
104 for message in messages {
106 if shutdown_rx.try_recv().is_ok() {
107 exit_code = 130;
108 return;
109 }
110
111 let result = run_single_prompt(&agent, &message, mode, quiet, &mut shutdown_rx).await;
112 match result {
113 Ok(()) => {}
114 Err(PromptError::AgentError(msg)) => {
115 if mode == PrintMode::Text && !quiet {
116 eprintln!("Error: {}", msg);
117 }
118 exit_code = 1;
119 }
120 Err(PromptError::Shutdown) => {
121 exit_code = 130;
122 return;
123 }
124 }
125 }
126 };
127
128 if let Some(secs) = timeout {
130 match tokio::time::timeout(Duration::from_secs(secs), work).await {
131 Ok(()) => {}
132 Err(_) => {
133 if !quiet {
134 eprintln!("Timed out after {} seconds", secs);
135 }
136 exit_code = 124; }
138 }
139 } else {
140 work.await;
141 }
142
143 Ok(exit_code)
144}
145
/// Why a single prompt run ended without success.
#[derive(Debug)]
enum PromptError {
    /// The agent run failed; carries a human-readable error message.
    AgentError(String),
    /// A shutdown (Ctrl-C) was requested while the prompt was running.
    Shutdown,
}
151
152async fn run_single_prompt(
154 agent: &Arc<Agent>,
155 prompt: &str,
156 mode: PrintMode,
157 quiet: bool,
158 shutdown_rx: &mut mpsc::Receiver<()>,
159) -> Result<(), PromptError> {
160 let _ = quiet; let (event_tx, event_rx) = std::sync::mpsc::channel::<AgentEvent>();
164 let (async_tx, mut async_rx) = mpsc::channel::<AgentEvent>(256);
165
166 let agent_clone: Arc<Agent> = Arc::clone(agent);
168 let prompt_owned = prompt.to_string();
169
170 let bridge_handle = std::thread::spawn(move || {
177 while let Ok(event) = event_rx.recv() {
178 let _ = async_tx.try_send(event);
180 }
181 });
182
183 let agent_handle = tokio::task::spawn_blocking(move || {
184 let rt = tokio::runtime::Builder::new_current_thread()
185 .enable_all()
186 .build()
187 .expect("failed to build agent runtime");
188 rt.block_on(async {
189 let local = tokio::task::LocalSet::new();
190 local
191 .run_until(async {
192 let _ = agent_clone.run_with_channel(prompt_owned, event_tx).await;
193 })
194 .await;
195 });
196 });
197
198 let mut last_text = String::new();
200 let mut had_error = false;
201 let mut error_message = String::new();
202 let mut _stop_reason: Option<String> = None;
203
204 loop {
205 tokio::select! {
206 event = async_rx.recv() => {
207 match event {
208 Some(ev) => {
209 match &ev {
210 AgentEvent::TextChunk { text } => {
212 last_text.push_str(text);
213 }
214 AgentEvent::MessageUpdate { message: oxi_sdk::Message::Assistant(asst), .. } => {
216 let text_only: String = asst.content.iter()
221 .filter_map(|b| match b {
222 oxi_sdk::ContentBlock::Text(t) => Some(t.text.as_str()),
223 _ => None,
224 })
225 .collect();
226 if !text_only.is_empty() {
227 last_text = text_only;
228 } else {
229 let thinking_text: String = asst.content.iter()
231 .filter_map(|b| match b {
232 oxi_sdk::ContentBlock::Thinking(t) => Some(t.thinking.as_str()),
233 _ => None,
234 })
235 .collect();
236 if !thinking_text.is_empty() {
237 last_text = thinking_text;
238 }
239 }
240 }
241 AgentEvent::MessageEnd { message: oxi_sdk::Message::Assistant(asst) } => {
242 let text_only: String = asst.content.iter()
244 .filter_map(|b| match b {
245 oxi_sdk::ContentBlock::Text(t) => Some(t.text.as_str()),
246 _ => None,
247 })
248 .collect();
249 if !text_only.is_empty() {
250 last_text = text_only;
251 } else {
252 let thinking_text: String = asst.content.iter()
254 .filter_map(|b| match b {
255 oxi_sdk::ContentBlock::Thinking(t) => Some(t.thinking.as_str()),
256 _ => None,
257 })
258 .collect();
259 if !thinking_text.is_empty() {
260 last_text = thinking_text;
261 }
262 }
263 }
264 AgentEvent::Complete { .. } => {
265 _stop_reason = Some("complete".to_string());
266 }
267 AgentEvent::Error { message, .. } => {
268 had_error = true;
269 error_message = message.clone();
270 _stop_reason = Some("error".to_string());
271 }
272 _ => {}
273 }
274
275 if mode == PrintMode::Json
276 && let Ok(json) = serde_json::to_string(&event_to_json(&ev)) {
277 println!("{}", json);
278 use std::io::Write;
279 std::io::stdout().flush().ok();
280 }
281 }
282 None => break,
283 }
284 }
285 _ = shutdown_rx.recv() => {
286 return Err(PromptError::Shutdown);
287 }
288 }
289 }
290
291 let _ = agent_handle.await;
293 let _ = bridge_handle.join();
294
295 if had_error {
296 return Err(PromptError::AgentError(error_message));
297 }
298
299 if mode == PrintMode::Text && !last_text.is_empty() {
301 println!("{}", last_text);
302 use std::io::Write;
303 std::io::stdout().flush().ok();
304 }
305
306 Ok(())
307}
308
309fn extract_text_from_message(msg: &oxi_sdk::Message) -> String {
313 match msg {
314 oxi_sdk::Message::Assistant(asst) => {
315 let text_only: String = asst
316 .content
317 .iter()
318 .filter_map(|b| match b {
319 oxi_sdk::ContentBlock::Text(t) => Some(t.text.as_str()),
320 _ => None,
321 })
322 .collect::<Vec<_>>()
323 .join("");
324 if !text_only.is_empty() {
325 return text_only;
326 }
327 asst.content
329 .iter()
330 .filter_map(|b| match b {
331 oxi_sdk::ContentBlock::Thinking(t) => Some(t.thinking.as_str()),
332 _ => None,
333 })
334 .collect::<Vec<_>>()
335 .join("")
336 }
337 _ => String::new(),
338 }
339}
340
341fn event_to_json(event: &AgentEvent) -> serde_json::Value {
343 match event {
344 AgentEvent::Start { .. } => serde_json::json!({
345 "type": "start"
346 }),
347 AgentEvent::Thinking => serde_json::json!({
348 "type": "thinking"
349 }),
350 AgentEvent::TextChunk { text } => serde_json::json!({
351 "type": "text_delta",
352 "text": text,
353 }),
354 AgentEvent::ToolCall { tool_call } => serde_json::json!({
355 "type": "tool_call",
356 "id": tool_call.id,
357 "name": tool_call.name,
358 "arguments": tool_call.arguments.to_string(),
359 }),
360 AgentEvent::ToolStart {
361 tool_name,
362 tool_call_id,
363 arguments: _,
364 } => serde_json::json!({
365 "type": "tool_start",
366 "tool_name": tool_name,
367 "tool_call_id": tool_call_id,
368 }),
369 AgentEvent::ToolComplete { result } => serde_json::json!({
370 "type": "tool_complete",
371 "content": result.content.chars().take(2000).collect::<String>(),
372 "is_error": result.is_error(),
373 }),
374 AgentEvent::ToolError {
375 error,
376 tool_call_id,
377 } => serde_json::json!({
378 "type": "tool_error",
379 "error": error,
380 "tool_call_id": tool_call_id,
381 }),
382 AgentEvent::Complete { .. } => serde_json::json!({
383 "type": "complete"
384 }),
385 AgentEvent::Error { message, .. } => serde_json::json!({
386 "type": "error",
387 "message": message,
388 }),
389 AgentEvent::Usage {
390 input_tokens,
391 output_tokens,
392 } => serde_json::json!({
393 "type": "usage",
394 "input_tokens": input_tokens,
395 "output_tokens": output_tokens,
396 }),
397
398 AgentEvent::AgentStart { .. } => serde_json::json!({
400 "type": "agent_start"
401 }),
402 AgentEvent::AgentEnd { .. } => serde_json::json!({
403 "type": "agent_end"
404 }),
405 AgentEvent::TurnStart { turn_number } => serde_json::json!({
406 "type": "turn_start",
407 "turn_number": turn_number,
408 }),
409 AgentEvent::TurnEnd { turn_number, .. } => serde_json::json!({
410 "type": "turn_end",
411 "turn_number": turn_number,
412 }),
413 AgentEvent::MessageStart { message } => {
414 let text = extract_text_from_message(message);
415 serde_json::json!({
416 "type": "message_start",
417 "text": text,
418 })
419 }
420 AgentEvent::MessageUpdate { message, delta } => {
421 let text = extract_text_from_message(message);
422 serde_json::json!({
423 "type": "message_update",
424 "text": text,
425 "delta": delta,
426 })
427 }
428 AgentEvent::MessageEnd { message } => {
429 let text = extract_text_from_message(message);
430 serde_json::json!({
431 "type": "message_end",
432 "text": text,
433 })
434 }
435 AgentEvent::ToolExecutionStart {
436 tool_call_id,
437 tool_name,
438 args,
439 ..
440 } => serde_json::json!({
441 "type": "tool_execution_start",
442 "tool_call_id": tool_call_id,
443 "tool_name": tool_name,
444 "args": args.to_string(),
445 }),
446 AgentEvent::ToolExecutionEnd {
447 tool_call_id,
448 tool_name,
449 result,
450 is_error,
451 } => serde_json::json!({
452 "type": "tool_execution_end",
453 "tool_call_id": tool_call_id,
454 "tool_name": tool_name,
455 "result": result.content.chars().take(2000).collect::<String>(),
456 "is_error": is_error,
457 }),
458
459 _ => serde_json::json!({
461 "type": "unknown"
462 }),
463 }
464}
465
466fn ctrlc_handler(shutdown_tx: mpsc::Sender<()>) -> Result<()> {
468 std::thread::spawn(move || {
470 let _ = tokio::runtime::Builder::new_current_thread()
471 .enable_all()
472 .build()
473 .map(|rt| {
474 rt.block_on(async {
475 tokio::signal::ctrl_c().await.ok();
476 let _ = shutdown_tx.try_send(());
477 });
478 });
479 });
480 Ok(())
481}
482
483pub fn read_stdin_prompt() -> Result<String> {
485 use std::io::{self, Read};
486 let mut buffer = String::new();
487 io::stdin().read_to_string(&mut buffer)?;
488 Ok(buffer.trim().to_string())
489}
490
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_event_to_json_start() {
        let value = event_to_json(&AgentEvent::Start {
            prompt: "test".to_string(),
        });
        assert_eq!(value["type"], "start");
    }

    #[test]
    fn test_event_to_json_thinking() {
        let value = event_to_json(&AgentEvent::Thinking);
        assert_eq!(value["type"], "thinking");
    }

    #[test]
    fn test_event_to_json_text_chunk() {
        let value = event_to_json(&AgentEvent::TextChunk {
            text: "Hello world".to_string(),
        });
        assert_eq!(value["type"], "text_delta");
        assert_eq!(value["text"], "Hello world");
    }

    #[test]
    fn test_event_to_json_tool_call() {
        let tool_call = oxi_sdk::ToolCall {
            content_type: oxi_sdk::ToolCallType::ToolCall,
            id: "tc-1".to_string(),
            name: "read_file".to_string(),
            arguments: serde_json::json!({"path": "/tmp/test.rs"}),
            thought_signature: None,
        };
        let value = event_to_json(&AgentEvent::ToolCall { tool_call });
        assert_eq!(value["type"], "tool_call");
        assert_eq!(value["name"], "read_file");
        assert_eq!(value["id"], "tc-1");
    }

    #[test]
    fn test_event_to_json_error() {
        let value = event_to_json(&AgentEvent::Error {
            message: "Something went wrong".to_string(),
            session_id: None,
        });
        assert_eq!(value["type"], "error");
        assert_eq!(value["message"], "Something went wrong");
    }

    #[test]
    fn test_event_to_json_complete() {
        let value = event_to_json(&AgentEvent::Complete {
            content: "done".to_string(),
            stop_reason: "end_turn".to_string(),
        });
        assert_eq!(value["type"], "complete");
    }

    #[test]
    fn test_event_to_json_tool_complete() {
        let result = oxi_sdk::ToolResult {
            tool_call_id: "tc-1".to_string(),
            content: "file contents here".to_string(),
            status: "success".to_string(),
        };
        let value = event_to_json(&AgentEvent::ToolComplete { result });
        assert_eq!(value["type"], "tool_complete");
        assert_eq!(value["is_error"], false);
    }

    #[test]
    fn test_print_mode_default_options() {
        let PrintModeOptions {
            mode,
            messages,
            initial_message,
            no_stdin,
            no_session,
            quiet,
            timeout,
        } = PrintModeOptions::default();
        assert_eq!(mode, PrintMode::Text);
        assert!(messages.is_empty());
        assert!(initial_message.is_none());
        assert!(!no_stdin);
        assert!(!no_session);
        assert!(!quiet);
        assert!(timeout.is_none());
    }

    #[test]
    fn test_print_mode_equality() {
        let (text, json) = (PrintMode::Text, PrintMode::Json);
        assert_eq!(text, PrintMode::Text);
        assert_eq!(json, PrintMode::Json);
        assert_ne!(text, json);
    }
}