1use std::path::PathBuf;
7use std::sync::{Arc, Mutex};
8
9use opi_agent::event::AgentEvent;
10use opi_agent::hooks::{
11 AfterToolCallContext, AfterToolCallResult, AgentHooks, BeforeToolCallContext,
12 BeforeToolCallResult, PrepareNextTurnContext, ShouldStopAfterTurnContext,
13};
14use opi_agent::loop_types::AgentError;
15use opi_agent::message::AgentMessage;
16use opi_agent::session_event::{AgentSessionEvent, SessionCostTotals, SessionTokenTotals};
17use opi_ai::message::Message;
18use opi_ai::provider::Provider;
19use opi_ai::stream::AssistantStreamEvent;
20
21use crate::config::OpiConfig;
22use crate::harness::{CodingHarness, ResumeInfo};
23use crate::policy::is_mutating_tool;
24
/// Version of the NDJSON stream format produced by `NonInteractiveRunner::run_json`.
///
/// Emitted as the `schema_version` field of the `session_header` line that opens every
/// stream; bump it whenever the line format changes incompatibly.
pub const NDJSON_SCHEMA_VERSION: u32 = 1;
27
/// Process exit codes reported through `NonInteractiveResult::exit_code`.
///
/// The discriminants are the literal process status values; convert with `as i32`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(i32)]
pub enum ExitCode {
    /// The prompt completed without any reported error.
    Success = 0,
    /// Generic runtime failure (used here for hook errors and exceeding the turn limit).
    RuntimeFailure = 1,
    /// Configuration problem (not produced by the runner in this file).
    ConfigError = 2,
    /// The provider rejected the credentials.
    AuthFailure = 3,
    /// Provider error, or an assistant message that carried an error message.
    ProviderFailure = 4,
    /// A tool invocation failed.
    ToolFailure = 5,
    /// The run was cancelled; 130 matches the shell convention of 128 + SIGINT.
    Interrupted = 130,
}
44
/// Everything a non-interactive run hands back to the caller.
#[derive(Clone, Debug)]
pub struct NonInteractiveResult {
    /// Text destined for standard output (plain assistant text or NDJSON lines).
    pub stdout: String,
    /// Diagnostic text destined for standard error; empty on a clean run.
    pub stderr: String,
    /// Process exit status, one of the `ExitCode` discriminants.
    pub exit_code: i32,
}
56
57pub struct NonInteractiveRunner {
63 harness: CodingHarness,
64}
65
66impl NonInteractiveRunner {
67 pub fn new(
69 provider: Box<dyn Provider>,
70 model: String,
71 config: OpiConfig,
72 workspace_root: PathBuf,
73 allow_mutating: bool,
74 user_system_prompt: Option<String>,
75 initial_messages: Vec<AgentMessage>,
76 ) -> Self {
77 Self::new_with_resume(
78 provider,
79 model,
80 config,
81 workspace_root,
82 allow_mutating,
83 user_system_prompt,
84 initial_messages,
85 None,
86 )
87 }
88
89 #[allow(clippy::too_many_arguments)]
92 pub fn new_with_resume(
93 provider: Box<dyn Provider>,
94 model: String,
95 config: OpiConfig,
96 workspace_root: PathBuf,
97 allow_mutating: bool,
98 user_system_prompt: Option<String>,
99 initial_messages: Vec<AgentMessage>,
100 resume_info: Option<ResumeInfo>,
101 ) -> Self {
102 let hooks = Box::new(NonInteractiveHooks { allow_mutating });
103 let harness = CodingHarness::new_with_hooks_and_resume(
104 provider,
105 model,
106 config,
107 workspace_root,
108 hooks,
109 user_system_prompt,
110 initial_messages,
111 resume_info,
112 );
113 Self { harness }
114 }
115
116 pub async fn run_json(&mut self, prompt: &str) -> NonInteractiveResult {
118 let output: Arc<Mutex<String>> = Arc::new(Mutex::new(String::new()));
119
120 {
122 let header = serde_json::json!({
123 "type": "session_header",
124 "schema_version": NDJSON_SCHEMA_VERSION,
125 });
126 let mut out = output.lock().unwrap();
127 out.push_str(&header.to_string());
128 out.push('\n');
129 }
130
131 let out = output.clone();
132 self.harness.subscribe(Box::new(move |event| {
133 let session_event = match event {
134 AgentEvent::AutoRetryStart {
135 attempt,
136 max_attempts,
137 delay_ms,
138 error_message,
139 } => AgentSessionEvent::AutoRetryStart {
140 attempt: *attempt,
141 max_attempts: *max_attempts,
142 delay_ms: *delay_ms,
143 error_message: error_message.clone(),
144 },
145 AgentEvent::AutoRetryEnd {
146 success,
147 attempt,
148 final_error,
149 } => AgentSessionEvent::AutoRetryEnd {
150 success: *success,
151 attempt: *attempt,
152 final_error: final_error.clone(),
153 },
154 AgentEvent::CompactionStart { reason } => {
155 AgentSessionEvent::CompactionStart { reason: *reason }
156 }
157 AgentEvent::CompactionEnd {
158 reason,
159 result,
160 aborted,
161 error_message,
162 } => AgentSessionEvent::CompactionEnd {
163 reason: *reason,
164 result: result.clone(),
165 aborted: *aborted,
166 will_retry: false,
167 error_message: error_message.clone(),
168 },
169 _ => AgentSessionEvent::Agent {
170 event: event.clone(),
171 },
172 };
173 if let Ok(json) = serde_json::to_string(&session_event)
174 && let Ok(mut guard) = out.lock()
175 {
176 guard.push_str(&json);
177 guard.push('\n');
178 }
179 }));
180
181 let prompt_result = self.harness.prompt(prompt).await;
182
183 if let Some(session) = self.harness.session() {
187 let usage = session.usage();
188 let cost = session.cost_summary().map(|c| SessionCostTotals {
189 input: c.input_cost,
190 output: c.output_cost,
191 cache_read: c.cache_read_cost,
192 cache_write: c.cache_write_cost,
193 total: c.total_cost(),
194 });
195 let summary_event = AgentSessionEvent::SessionSummary {
196 session_id: session.session_id().to_owned(),
197 model: session.model().to_owned(),
198 turns: usage.turn_count(),
199 tokens: SessionTokenTotals {
200 input: usage.total_input_tokens(),
201 output: usage.total_output_tokens(),
202 cache_read: usage.total_cache_read_tokens(),
203 cache_write: usage.total_cache_write_tokens(),
204 },
205 cost_usd: cost,
206 };
207 if let Ok(json) = serde_json::to_string(&summary_event)
208 && let Ok(mut guard) = output.lock()
209 {
210 guard.push_str(&json);
211 guard.push('\n');
212 }
213 }
214
215 match prompt_result {
216 Ok(messages) => {
217 if let Some(error) = find_error_message(&messages) {
218 return NonInteractiveResult {
219 stdout: output.lock().map(|g| g.clone()).unwrap_or_default(),
220 stderr: error,
221 exit_code: ExitCode::ProviderFailure as i32,
222 };
223 }
224 NonInteractiveResult {
225 stdout: output.lock().map(|g| g.clone()).unwrap_or_default(),
226 stderr: String::new(),
227 exit_code: ExitCode::Success as i32,
228 }
229 }
230 Err(AgentError::Cancelled) => NonInteractiveResult {
231 stdout: output.lock().map(|g| g.clone()).unwrap_or_default(),
232 stderr: "cancelled".into(),
233 exit_code: ExitCode::Interrupted as i32,
234 },
235 Err(AgentError::AuthFailed(e)) => NonInteractiveResult {
236 stdout: output.lock().map(|g| g.clone()).unwrap_or_default(),
237 stderr: format!("authentication error: {e}"),
238 exit_code: ExitCode::AuthFailure as i32,
239 },
240 Err(AgentError::Provider(e)) => NonInteractiveResult {
241 stdout: output.lock().map(|g| g.clone()).unwrap_or_default(),
242 stderr: format!("provider error: {e}"),
243 exit_code: ExitCode::ProviderFailure as i32,
244 },
245 Err(AgentError::Tool(e)) => NonInteractiveResult {
246 stdout: output.lock().map(|g| g.clone()).unwrap_or_default(),
247 stderr: format!("tool error: {e}"),
248 exit_code: ExitCode::ToolFailure as i32,
249 },
250 Err(AgentError::Hook(e)) => NonInteractiveResult {
251 stdout: output.lock().map(|g| g.clone()).unwrap_or_default(),
252 stderr: format!("hook error: {e}"),
253 exit_code: ExitCode::RuntimeFailure as i32,
254 },
255 Err(AgentError::MaxTurnsExceeded(n)) => NonInteractiveResult {
256 stdout: output.lock().map(|g| g.clone()).unwrap_or_default(),
257 stderr: format!("max turns exceeded ({n})"),
258 exit_code: ExitCode::RuntimeFailure as i32,
259 },
260 }
261 }
262
263 pub fn cancel(&self) {
265 self.harness.cancel();
266 }
267
268 pub async fn run(&mut self, prompt: &str) -> NonInteractiveResult {
270 let text_parts: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
272 let persist_errors: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
273 let tp = text_parts.clone();
274 let pe = persist_errors.clone();
275 self.harness.subscribe(Box::new(move |event| match event {
276 AgentEvent::MessageUpdate {
277 assistant_event, ..
278 } => {
279 if let AssistantStreamEvent::TextDelta { delta, .. } = assistant_event.as_ref()
280 && let Ok(mut guard) = tp.lock()
281 {
282 guard.push(delta.clone());
283 }
284 }
285 AgentEvent::SessionPersistError { message } => {
286 if let Ok(mut guard) = pe.lock() {
287 guard.push(message.clone());
288 }
289 }
290 _ => {}
291 }));
292
293 let prompt_result = self.harness.prompt(prompt).await;
294
295 let persist_stderr = format_persist_errors(&persist_errors);
298
299 match prompt_result {
300 Ok(messages) => {
301 if let Some(error) = find_error_message(&messages) {
303 let mut stderr = error;
304 stderr.push_str(&persist_stderr);
305 return NonInteractiveResult {
306 stdout: String::new(),
307 stderr,
308 exit_code: ExitCode::ProviderFailure as i32,
309 };
310 }
311
312 let stdout = text_parts.lock().map(|g| g.join("")).unwrap_or_default();
313 NonInteractiveResult {
314 stdout,
315 stderr: persist_stderr,
316 exit_code: ExitCode::Success as i32,
317 }
318 }
319 Err(AgentError::Cancelled) => NonInteractiveResult {
320 stdout: String::new(),
321 stderr: format!("cancelled{persist_stderr}"),
322 exit_code: ExitCode::Interrupted as i32,
323 },
324 Err(AgentError::AuthFailed(e)) => NonInteractiveResult {
325 stdout: String::new(),
326 stderr: format!("authentication error: {e}{persist_stderr}"),
327 exit_code: ExitCode::AuthFailure as i32,
328 },
329 Err(AgentError::Provider(e)) => NonInteractiveResult {
330 stdout: String::new(),
331 stderr: format!("provider error: {e}{persist_stderr}"),
332 exit_code: ExitCode::ProviderFailure as i32,
333 },
334 Err(AgentError::Tool(e)) => NonInteractiveResult {
335 stdout: String::new(),
336 stderr: format!("tool error: {e}{persist_stderr}"),
337 exit_code: ExitCode::ToolFailure as i32,
338 },
339 Err(AgentError::Hook(e)) => NonInteractiveResult {
340 stdout: String::new(),
341 stderr: format!("hook error: {e}{persist_stderr}"),
342 exit_code: ExitCode::RuntimeFailure as i32,
343 },
344 Err(AgentError::MaxTurnsExceeded(n)) => NonInteractiveResult {
345 stdout: String::new(),
346 stderr: format!("max turns exceeded ({n}){persist_stderr}"),
347 exit_code: ExitCode::RuntimeFailure as i32,
348 },
349 }
350 }
351}
352
353fn find_error_message(messages: &[AgentMessage]) -> Option<String> {
359 for msg in messages {
360 if let AgentMessage::Llm(Message::Assistant(asst)) = msg
361 && let Some(err) = &asst.error_message
362 {
363 return Some(err.clone());
364 }
365 }
366 None
367}
368
/// Renders collected session-persistence errors as a suffix for a stderr message.
///
/// Each error becomes `"\nsession persist error: <error>"`. The leading newline lets the
/// result be appended directly after another message. An empty string is returned when no
/// errors were recorded.
///
/// A poisoned mutex is tolerated: the errors recorded before the panic are still rendered,
/// so reporting a failure can never itself panic.
pub fn format_persist_errors(errors: &Arc<Mutex<Vec<String>>>) -> String {
    let guard = errors
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let mut out = String::new();
    for error in guard.iter() {
        out.push_str("\nsession persist error: ");
        out.push_str(error);
    }
    out
}
382
/// Agent hooks that enforce the tool policy for non-interactive runs.
struct NonInteractiveHooks {
    /// When `false`, tools for which `is_mutating_tool` returns true are denied.
    allow_mutating: bool,
}
391
392impl AgentHooks for NonInteractiveHooks {
393 fn convert_to_llm(&self, messages: &[AgentMessage]) -> Result<Vec<Message>, AgentError> {
394 Ok(crate::harness::agent_messages_to_llm(messages))
395 }
396
397 fn before_tool_call(
398 &self,
399 ctx: BeforeToolCallContext,
400 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = BeforeToolCallResult> + Send>> {
401 let allowed = self.allow_mutating;
402 let tool_name = ctx.tool_name.clone();
403 Box::pin(async move {
404 if !allowed && is_mutating_tool(&tool_name) {
405 return BeforeToolCallResult::Deny {
406 reason: format!(
407 "tool '{}' is not allowed in non-interactive mode without --allow-mutating",
408 tool_name
409 ),
410 };
411 }
412 BeforeToolCallResult::Allow
413 })
414 }
415
416 fn after_tool_call(
417 &self,
418 _ctx: AfterToolCallContext,
419 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = AfterToolCallResult> + Send>> {
420 Box::pin(async { AfterToolCallResult::Keep })
421 }
422
423 fn should_stop_after_turn(
424 &self,
425 _ctx: ShouldStopAfterTurnContext,
426 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>> {
427 Box::pin(async { false })
428 }
429
430 fn prepare_next_turn(
431 &self,
432 _ctx: PrepareNextTurnContext,
433 ) -> std::pin::Pin<
434 Box<
435 dyn std::future::Future<Output = Option<opi_agent::loop_types::AgentLoopTurnUpdate>>
436 + Send,
437 >,
438 > {
439 Box::pin(async { None })
440 }
441}