Skip to main content

opi_coding_agent/
runner.rs

1//! Non-interactive runner (S10).
2//!
3//! Takes a single prompt, runs it through the agent, captures assistant text
4//! for stdout, diagnostics for stderr, and returns an exit code.
5
6use std::path::PathBuf;
7use std::sync::{Arc, Mutex};
8
9use opi_agent::event::AgentEvent;
10use opi_agent::hooks::{
11    AfterToolCallContext, AfterToolCallResult, AgentHooks, BeforeToolCallContext,
12    BeforeToolCallResult, PrepareNextTurnContext, ShouldStopAfterTurnContext,
13};
14use opi_agent::loop_types::AgentError;
15use opi_agent::message::AgentMessage;
16use opi_agent::session_event::{AgentSessionEvent, SessionCostTotals, SessionTokenTotals};
17use opi_ai::message::Message;
18use opi_ai::provider::Provider;
19use opi_ai::stream::AssistantStreamEvent;
20
21use crate::config::OpiConfig;
22use crate::harness::{CodingHarness, ResumeInfo};
23use crate::policy::is_mutating_tool;
24
/// Version number of the NDJSON output schema.
///
/// Written as the `schema_version` field of the `session_header` line that
/// opens the output of [`NonInteractiveRunner::run_json`].
pub const NDJSON_SCHEMA_VERSION: u32 = 1;
27
28// ---------------------------------------------------------------------------
29// Exit codes (S10)
30// ---------------------------------------------------------------------------
31
/// Process exit codes reported by the non-interactive runner.
///
/// The enum is `#[repr(i32)]`, so `code as i32` yields the numeric value
/// stored in [`NonInteractiveResult::exit_code`].
#[repr(i32)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExitCode {
    /// The prompt completed and no assistant message carried an error.
    Success = 0,
    /// Generic runtime failure (used for hook errors and exceeded turn limits).
    RuntimeFailure = 1,
    /// Configuration problem.
    ConfigError = 2,
    /// Authentication with the provider failed.
    AuthFailure = 3,
    /// The provider returned an error.
    ProviderFailure = 4,
    /// A tool invocation failed.
    ToolFailure = 5,
    /// The run was cancelled.
    Interrupted = 130,
}
44
45// ---------------------------------------------------------------------------
46// Result
47// ---------------------------------------------------------------------------
48
/// Everything a non-interactive run hands back to its caller.
#[derive(Clone, Debug)]
pub struct NonInteractiveResult {
    /// Text destined for standard output (assistant text, or NDJSON lines in
    /// JSON mode).
    pub stdout: String,
    /// Diagnostics destined for standard error; empty when nothing went wrong.
    pub stderr: String,
    /// Numeric process exit code (an [`ExitCode`] cast to `i32`).
    pub exit_code: i32,
}
56
57// ---------------------------------------------------------------------------
58// Runner
59// ---------------------------------------------------------------------------
60
61/// Non-interactive runner that executes a single prompt and captures output.
62pub struct NonInteractiveRunner {
63    harness: CodingHarness,
64}
65
66impl NonInteractiveRunner {
67    /// Create a new non-interactive runner.
68    pub fn new(
69        provider: Box<dyn Provider>,
70        model: String,
71        config: OpiConfig,
72        workspace_root: PathBuf,
73        allow_mutating: bool,
74        user_system_prompt: Option<String>,
75        initial_messages: Vec<AgentMessage>,
76    ) -> Self {
77        Self::new_with_resume(
78            provider,
79            model,
80            config,
81            workspace_root,
82            allow_mutating,
83            user_system_prompt,
84            initial_messages,
85            None,
86        )
87    }
88
89    /// Create a new non-interactive runner, optionally adopting an existing
90    /// session (resume).
91    #[allow(clippy::too_many_arguments)]
92    pub fn new_with_resume(
93        provider: Box<dyn Provider>,
94        model: String,
95        config: OpiConfig,
96        workspace_root: PathBuf,
97        allow_mutating: bool,
98        user_system_prompt: Option<String>,
99        initial_messages: Vec<AgentMessage>,
100        resume_info: Option<ResumeInfo>,
101    ) -> Self {
102        let hooks = Box::new(NonInteractiveHooks { allow_mutating });
103        let harness = CodingHarness::new_with_hooks_and_resume(
104            provider,
105            model,
106            config,
107            workspace_root,
108            hooks,
109            user_system_prompt,
110            initial_messages,
111            resume_info,
112        );
113        Self { harness }
114    }
115
116    /// Run a single prompt in JSON mode, returning NDJSON output in stdout.
117    pub async fn run_json(&mut self, prompt: &str) -> NonInteractiveResult {
118        let output: Arc<Mutex<String>> = Arc::new(Mutex::new(String::new()));
119
120        // Schema version header line
121        {
122            let header = serde_json::json!({
123                "type": "session_header",
124                "schema_version": NDJSON_SCHEMA_VERSION,
125            });
126            let mut out = output.lock().unwrap();
127            out.push_str(&header.to_string());
128            out.push('\n');
129        }
130
131        let out = output.clone();
132        self.harness.subscribe(Box::new(move |event| {
133            let session_event = match event {
134                AgentEvent::AutoRetryStart {
135                    attempt,
136                    max_attempts,
137                    delay_ms,
138                    error_message,
139                } => AgentSessionEvent::AutoRetryStart {
140                    attempt: *attempt,
141                    max_attempts: *max_attempts,
142                    delay_ms: *delay_ms,
143                    error_message: error_message.clone(),
144                },
145                AgentEvent::AutoRetryEnd {
146                    success,
147                    attempt,
148                    final_error,
149                } => AgentSessionEvent::AutoRetryEnd {
150                    success: *success,
151                    attempt: *attempt,
152                    final_error: final_error.clone(),
153                },
154                AgentEvent::CompactionStart { reason } => {
155                    AgentSessionEvent::CompactionStart { reason: *reason }
156                }
157                AgentEvent::CompactionEnd {
158                    reason,
159                    result,
160                    aborted,
161                    error_message,
162                } => AgentSessionEvent::CompactionEnd {
163                    reason: *reason,
164                    result: result.clone(),
165                    aborted: *aborted,
166                    will_retry: false,
167                    error_message: error_message.clone(),
168                },
169                _ => AgentSessionEvent::Agent {
170                    event: event.clone(),
171                },
172            };
173            if let Ok(json) = serde_json::to_string(&session_event)
174                && let Ok(mut guard) = out.lock()
175            {
176                guard.push_str(&json);
177                guard.push('\n');
178            }
179        }));
180
181        let prompt_result = self.harness.prompt(prompt).await;
182
183        // Emit a final `SessionSummary` event with cumulative token totals
184        // and (when known) cost breakdown. Emitted before the result match so
185        // even error paths surface what the user spent before failing.
186        if let Some(session) = self.harness.session() {
187            let usage = session.usage();
188            let cost = session.cost_summary().map(|c| SessionCostTotals {
189                input: c.input_cost,
190                output: c.output_cost,
191                cache_read: c.cache_read_cost,
192                cache_write: c.cache_write_cost,
193                total: c.total_cost(),
194            });
195            let summary_event = AgentSessionEvent::SessionSummary {
196                session_id: session.session_id().to_owned(),
197                model: session.model().to_owned(),
198                turns: usage.turn_count(),
199                tokens: SessionTokenTotals {
200                    input: usage.total_input_tokens(),
201                    output: usage.total_output_tokens(),
202                    cache_read: usage.total_cache_read_tokens(),
203                    cache_write: usage.total_cache_write_tokens(),
204                },
205                cost_usd: cost,
206            };
207            if let Ok(json) = serde_json::to_string(&summary_event)
208                && let Ok(mut guard) = output.lock()
209            {
210                guard.push_str(&json);
211                guard.push('\n');
212            }
213        }
214
215        match prompt_result {
216            Ok(messages) => {
217                if let Some(error) = find_error_message(&messages) {
218                    return NonInteractiveResult {
219                        stdout: output.lock().map(|g| g.clone()).unwrap_or_default(),
220                        stderr: error,
221                        exit_code: ExitCode::ProviderFailure as i32,
222                    };
223                }
224                NonInteractiveResult {
225                    stdout: output.lock().map(|g| g.clone()).unwrap_or_default(),
226                    stderr: String::new(),
227                    exit_code: ExitCode::Success as i32,
228                }
229            }
230            Err(AgentError::Cancelled) => NonInteractiveResult {
231                stdout: output.lock().map(|g| g.clone()).unwrap_or_default(),
232                stderr: "cancelled".into(),
233                exit_code: ExitCode::Interrupted as i32,
234            },
235            Err(AgentError::AuthFailed(e)) => NonInteractiveResult {
236                stdout: output.lock().map(|g| g.clone()).unwrap_or_default(),
237                stderr: format!("authentication error: {e}"),
238                exit_code: ExitCode::AuthFailure as i32,
239            },
240            Err(AgentError::Provider(e)) => NonInteractiveResult {
241                stdout: output.lock().map(|g| g.clone()).unwrap_or_default(),
242                stderr: format!("provider error: {e}"),
243                exit_code: ExitCode::ProviderFailure as i32,
244            },
245            Err(AgentError::Tool(e)) => NonInteractiveResult {
246                stdout: output.lock().map(|g| g.clone()).unwrap_or_default(),
247                stderr: format!("tool error: {e}"),
248                exit_code: ExitCode::ToolFailure as i32,
249            },
250            Err(AgentError::Hook(e)) => NonInteractiveResult {
251                stdout: output.lock().map(|g| g.clone()).unwrap_or_default(),
252                stderr: format!("hook error: {e}"),
253                exit_code: ExitCode::RuntimeFailure as i32,
254            },
255            Err(AgentError::MaxTurnsExceeded(n)) => NonInteractiveResult {
256                stdout: output.lock().map(|g| g.clone()).unwrap_or_default(),
257                stderr: format!("max turns exceeded ({n})"),
258                exit_code: ExitCode::RuntimeFailure as i32,
259            },
260        }
261    }
262
263    /// Cancel the running operation.
264    pub fn cancel(&self) {
265        self.harness.cancel();
266    }
267
268    /// Run a single prompt and return captured output.
269    pub async fn run(&mut self, prompt: &str) -> NonInteractiveResult {
270        // Subscribe to capture text from TextDelta events and persist errors
271        let text_parts: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
272        let persist_errors: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
273        let tp = text_parts.clone();
274        let pe = persist_errors.clone();
275        self.harness.subscribe(Box::new(move |event| match event {
276            AgentEvent::MessageUpdate {
277                assistant_event, ..
278            } => {
279                if let AssistantStreamEvent::TextDelta { delta, .. } = assistant_event.as_ref()
280                    && let Ok(mut guard) = tp.lock()
281                {
282                    guard.push(delta.clone());
283                }
284            }
285            AgentEvent::SessionPersistError { message } => {
286                if let Ok(mut guard) = pe.lock() {
287                    guard.push(message.clone());
288                }
289            }
290            _ => {}
291        }));
292
293        let prompt_result = self.harness.prompt(prompt).await;
294
295        // Format persist errors AFTER prompt returns so events emitted
296        // during the run are captured.
297        let persist_stderr = format_persist_errors(&persist_errors);
298
299        match prompt_result {
300            Ok(messages) => {
301                // Check for provider errors in assistant messages
302                if let Some(error) = find_error_message(&messages) {
303                    let mut stderr = error;
304                    stderr.push_str(&persist_stderr);
305                    return NonInteractiveResult {
306                        stdout: String::new(),
307                        stderr,
308                        exit_code: ExitCode::ProviderFailure as i32,
309                    };
310                }
311
312                let stdout = text_parts.lock().map(|g| g.join("")).unwrap_or_default();
313                NonInteractiveResult {
314                    stdout,
315                    stderr: persist_stderr,
316                    exit_code: ExitCode::Success as i32,
317                }
318            }
319            Err(AgentError::Cancelled) => NonInteractiveResult {
320                stdout: String::new(),
321                stderr: format!("cancelled{persist_stderr}"),
322                exit_code: ExitCode::Interrupted as i32,
323            },
324            Err(AgentError::AuthFailed(e)) => NonInteractiveResult {
325                stdout: String::new(),
326                stderr: format!("authentication error: {e}{persist_stderr}"),
327                exit_code: ExitCode::AuthFailure as i32,
328            },
329            Err(AgentError::Provider(e)) => NonInteractiveResult {
330                stdout: String::new(),
331                stderr: format!("provider error: {e}{persist_stderr}"),
332                exit_code: ExitCode::ProviderFailure as i32,
333            },
334            Err(AgentError::Tool(e)) => NonInteractiveResult {
335                stdout: String::new(),
336                stderr: format!("tool error: {e}{persist_stderr}"),
337                exit_code: ExitCode::ToolFailure as i32,
338            },
339            Err(AgentError::Hook(e)) => NonInteractiveResult {
340                stdout: String::new(),
341                stderr: format!("hook error: {e}{persist_stderr}"),
342                exit_code: ExitCode::RuntimeFailure as i32,
343            },
344            Err(AgentError::MaxTurnsExceeded(n)) => NonInteractiveResult {
345                stdout: String::new(),
346                stderr: format!("max turns exceeded ({n}){persist_stderr}"),
347                exit_code: ExitCode::RuntimeFailure as i32,
348            },
349        }
350    }
351}
352
353// ---------------------------------------------------------------------------
354// Helpers
355// ---------------------------------------------------------------------------
356
357/// Find the first error_message in assistant messages.
358fn find_error_message(messages: &[AgentMessage]) -> Option<String> {
359    for msg in messages {
360        if let AgentMessage::Llm(Message::Assistant(asst)) = msg
361            && let Some(err) = &asst.error_message
362        {
363            return Some(err.clone());
364        }
365    }
366    None
367}
368
/// Format any captured session persist errors into a stderr suffix.
///
/// Each error becomes its own line of the form
/// `"\nsession persist error: <message>"`, in capture order. Returns an empty
/// string when no errors were captured.
///
/// A poisoned mutex does not cause a panic: the messages recorded before the
/// poisoning are still reported, so a panic elsewhere cannot hide the
/// diagnostics collected so far.
pub fn format_persist_errors(errors: &Arc<Mutex<Vec<String>>>) -> String {
    // Recover the inner data from a poisoned lock instead of unwrapping; the
    // vector of strings is still perfectly readable after a writer panicked.
    let captured = errors
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);

    let mut suffix = String::new();
    for message in captured.iter() {
        suffix.push_str("\nsession persist error: ");
        suffix.push_str(message);
    }
    suffix
}
382
383// ---------------------------------------------------------------------------
384// Hooks
385// ---------------------------------------------------------------------------
386
/// Agent hooks used in non-interactive mode; they enforce the tool safety
/// policy.
struct NonInteractiveHooks {
    /// When `false`, tools classified as mutating are denied before they run.
    allow_mutating: bool,
}
391
392impl AgentHooks for NonInteractiveHooks {
393    fn convert_to_llm(&self, messages: &[AgentMessage]) -> Result<Vec<Message>, AgentError> {
394        Ok(crate::harness::agent_messages_to_llm(messages))
395    }
396
397    fn before_tool_call(
398        &self,
399        ctx: BeforeToolCallContext,
400    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = BeforeToolCallResult> + Send>> {
401        let allowed = self.allow_mutating;
402        let tool_name = ctx.tool_name.clone();
403        Box::pin(async move {
404            if !allowed && is_mutating_tool(&tool_name) {
405                return BeforeToolCallResult::Deny {
406                    reason: format!(
407                        "tool '{}' is not allowed in non-interactive mode without --allow-mutating",
408                        tool_name
409                    ),
410                };
411            }
412            BeforeToolCallResult::Allow
413        })
414    }
415
416    fn after_tool_call(
417        &self,
418        _ctx: AfterToolCallContext,
419    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = AfterToolCallResult> + Send>> {
420        Box::pin(async { AfterToolCallResult::Keep })
421    }
422
423    fn should_stop_after_turn(
424        &self,
425        _ctx: ShouldStopAfterTurnContext,
426    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = bool> + Send>> {
427        Box::pin(async { false })
428    }
429
430    fn prepare_next_turn(
431        &self,
432        _ctx: PrepareNextTurnContext,
433    ) -> std::pin::Pin<
434        Box<
435            dyn std::future::Future<Output = Option<opi_agent::loop_types::AgentLoopTurnUpdate>>
436                + Send,
437        >,
438    > {
439        Box::pin(async { None })
440    }
441}