Skip to main content

codex_runtime/runtime/api/
prompt_run.rs

1use std::time::Duration;
2
3use crate::plugin::{BlockReason, HookPhase};
4use tokio::sync::broadcast::error::RecvError;
5
6use crate::runtime::core::Runtime;
7use crate::runtime::errors::{RpcError, RuntimeError};
8use crate::runtime::hooks::{PreHookDecision, RuntimeHookConfig};
9use crate::runtime::rpc_contract::{methods, RpcValidationMode};
10use crate::runtime::turn_lifecycle::{
11    collect_turn_terminal_with_limits, interrupt_turn_best_effort_detached, LaggedTurnTerminal,
12    TurnCollectError,
13};
14use crate::runtime::turn_output::{TurnStreamCollector, TurnTerminalEvent};
15
16use super::attachment_validation::validate_prompt_attachments;
17use super::flow::{
18    apply_pre_hook_actions_to_prompt, build_hook_context, extract_assistant_text_from_turn,
19    result_status, HookContextInput, HookExecutionState, PromptMutationState,
20};
21use super::turn_error::{extract_turn_error_signal, PromptTurnErrorSignal};
22use super::wire::{
23    deserialize_result, serialize_params, thread_start_params_from_prompt,
24    turn_start_params_from_prompt,
25};
26use super::*;
27
/// Which thread a prompt run is aimed at.
#[derive(Clone, Copy)]
enum PromptRunTarget<'t> {
    /// Start a brand-new thread (`None`) or resume the given thread id before the turn.
    OpenOrResume(Option<&'t str>),
    /// Use an already-loaded thread id directly, without a start/resume RPC.
    Loaded(&'t str),
}

impl<'t> PromptRunTarget<'t> {
    /// Thread id to report to hooks before the run yields its own: the loaded or resumed
    /// thread id, or `None` when a new thread is going to be started.
    fn hook_thread_id(self) -> Option<&'t str> {
        match self {
            Self::Loaded(id) => Some(id),
            Self::OpenOrResume(maybe_id) => maybe_id,
        }
    }
}
42
43impl Runtime {
44    /// Run one prompt with safe default policies using only cwd + prompt.
45    /// Side effects: same as `run_prompt`. Allocation: params object + two Strings.
46    /// Complexity: O(n), n = input string lengths + streamed turn output size.
47    pub async fn run_prompt_simple(
48        &self,
49        cwd: impl Into<String>,
50        prompt: impl Into<String>,
51    ) -> Result<PromptRunResult, PromptRunError> {
52        self.run_prompt(PromptRunParams::new(cwd, prompt)).await
53    }
54
55    /// Run one prompt end-to-end and return the final assistant text.
56    /// Side effects: sends thread/turn RPC calls and consumes live event stream.
57    /// Allocation: O(n), n = prompt length + attachment count + streamed text.
58    pub async fn run_prompt(&self, p: PromptRunParams) -> Result<PromptRunResult, PromptRunError> {
59        self.run_prompt_with_hooks(p, None).await
60    }
61
62    pub(crate) async fn run_prompt_with_hooks(
63        &self,
64        p: PromptRunParams,
65        scoped_hooks: Option<&RuntimeHookConfig>,
66    ) -> Result<PromptRunResult, PromptRunError> {
67        self.run_prompt_target_with_hooks(None, p, scoped_hooks)
68            .await
69    }
70
71    /// Continue an existing thread with one additional prompt turn.
72    /// Side effects: sends thread/resume + turn/start RPC calls and consumes live event stream.
73    /// Allocation: O(n), n = prompt length + attachment count + streamed text.
74    pub async fn run_prompt_in_thread(
75        &self,
76        thread_id: &str,
77        p: PromptRunParams,
78    ) -> Result<PromptRunResult, PromptRunError> {
79        self.run_prompt_in_thread_with_hooks(thread_id, p, None)
80            .await
81    }
82
83    pub(crate) async fn run_prompt_in_thread_with_hooks(
84        &self,
85        thread_id: &str,
86        p: PromptRunParams,
87        scoped_hooks: Option<&RuntimeHookConfig>,
88    ) -> Result<PromptRunResult, PromptRunError> {
89        self.run_prompt_target_with_hooks(Some(thread_id), p, scoped_hooks)
90            .await
91    }
92
93    pub(crate) async fn run_prompt_on_loaded_thread_with_hooks(
94        &self,
95        thread_id: &str,
96        p: PromptRunParams,
97        scoped_hooks: Option<&RuntimeHookConfig>,
98    ) -> Result<PromptRunResult, PromptRunError> {
99        self.run_prompt_with_hook_scaffold(PromptRunTarget::Loaded(thread_id), p, scoped_hooks)
100            .await
101    }
102
103    async fn run_prompt_target_with_hooks(
104        &self,
105        thread_id: Option<&str>,
106        p: PromptRunParams,
107        scoped_hooks: Option<&RuntimeHookConfig>,
108    ) -> Result<PromptRunResult, PromptRunError> {
109        self.run_prompt_with_hook_scaffold(
110            PromptRunTarget::OpenOrResume(thread_id),
111            p,
112            scoped_hooks,
113        )
114        .await
115    }
116
117    async fn run_prompt_with_hook_scaffold(
118        &self,
119        target: PromptRunTarget<'_>,
120        p: PromptRunParams,
121        scoped_hooks: Option<&RuntimeHookConfig>,
122    ) -> Result<PromptRunResult, PromptRunError> {
123        if !self.hooks_enabled_with(scoped_hooks) {
124            return self
125                .run_prompt_target_entry_dispatch(target, p, None, scoped_hooks)
126                .await;
127        }
128
129        let fallback_thread_id = target.hook_thread_id();
130        let (p, mut hook_state, run_cwd, run_model) = self
131            .prepare_prompt_pre_run_hooks(p, fallback_thread_id, scoped_hooks)
132            .await?;
133        let result = self
134            .run_prompt_target_entry_dispatch(target, p, Some(&mut hook_state), scoped_hooks)
135            .await;
136        self.finalize_prompt_run_hooks(
137            &mut hook_state,
138            run_cwd.as_str(),
139            run_model.as_deref(),
140            fallback_thread_id,
141            &result,
142            scoped_hooks,
143        )
144        .await;
145        self.publish_hook_report(hook_state.report);
146        result
147    }
148
149    async fn run_prompt_target_entry_dispatch(
150        &self,
151        target: PromptRunTarget<'_>,
152        p: PromptRunParams,
153        hook_state: Option<&mut HookExecutionState>,
154        scoped_hooks: Option<&RuntimeHookConfig>,
155    ) -> Result<PromptRunResult, PromptRunError> {
156        match target {
157            PromptRunTarget::OpenOrResume(thread_id) => {
158                self.run_prompt_entry(thread_id, p, hook_state, scoped_hooks)
159                    .await
160            }
161            PromptRunTarget::Loaded(thread_id) => {
162                self.run_prompt_on_loaded_thread_entry(thread_id, p, hook_state, scoped_hooks)
163                    .await
164            }
165        }
166    }
167
168    async fn open_prompt_thread(
169        &self,
170        thread_id: Option<&str>,
171        p: &PromptRunParams,
172        scoped_hooks: Option<&RuntimeHookConfig>,
173    ) -> Result<ThreadHandle, RpcError> {
174        let mut start = thread_start_params_from_prompt(p);
175        if self.has_pre_tool_use_hooks_with(scoped_hooks)
176            && matches!(start.approval_policy, None | Some(ApprovalPolicy::Never))
177        {
178            start.approval_policy = Some(ApprovalPolicy::Untrusted);
179        }
180        match thread_id {
181            Some(existing_thread_id) => self.thread_resume_raw(existing_thread_id, start).await,
182            None => self.thread_start_raw(start).await,
183        }
184    }
185
186    async fn run_prompt_entry(
187        &self,
188        thread_id: Option<&str>,
189        p: PromptRunParams,
190        hook_state: Option<&mut HookExecutionState>,
191        scoped_hooks: Option<&RuntimeHookConfig>,
192    ) -> Result<PromptRunResult, PromptRunError> {
193        validate_prompt_attachments(&p.cwd, &p.attachments).await?;
194        let effort = p.effort.unwrap_or(DEFAULT_REASONING_EFFORT);
195        let thread = self.open_prompt_thread(thread_id, &p, scoped_hooks).await?;
196        self.run_prompt_on_thread(thread, p, effort, hook_state, scoped_hooks)
197            .await
198    }
199
200    async fn run_prompt_on_loaded_thread_entry(
201        &self,
202        thread_id: &str,
203        p: PromptRunParams,
204        hook_state: Option<&mut HookExecutionState>,
205        scoped_hooks: Option<&RuntimeHookConfig>,
206    ) -> Result<PromptRunResult, PromptRunError> {
207        validate_prompt_attachments(&p.cwd, &p.attachments).await?;
208        let effort = p.effort.unwrap_or(DEFAULT_REASONING_EFFORT);
209        let thread = self.loaded_thread_handle(thread_id);
210        self.run_prompt_on_thread(thread, p, effort, hook_state, scoped_hooks)
211            .await
212    }
213
214    async fn prepare_prompt_pre_run_hooks(
215        &self,
216        mut p: PromptRunParams,
217        thread_id: Option<&str>,
218        scoped_hooks: Option<&RuntimeHookConfig>,
219    ) -> Result<(PromptRunParams, HookExecutionState, String, Option<String>), PromptRunError> {
220        let mut hook_state = HookExecutionState::new(self.next_hook_correlation_id());
221        let mut prompt_state = PromptMutationState::from_params(&p, hook_state.metadata.clone());
222        let decisions = self
223            .execute_pre_hook_phase(
224                &mut hook_state,
225                HookPhase::PreRun,
226                Some(p.cwd.as_str()),
227                prompt_state.model.as_deref(),
228                thread_id,
229                None,
230                scoped_hooks,
231            )
232            .await
233            .map_err(PromptRunError::from_block)?;
234        apply_pre_hook_actions_to_prompt(
235            &mut prompt_state,
236            p.cwd.as_str(),
237            HookPhase::PreRun,
238            decisions,
239            &mut hook_state.report,
240        )
241        .await;
242        hook_state.metadata = prompt_state.metadata.clone();
243        p.prompt = prompt_state.prompt;
244        p.model = prompt_state.model;
245        p.attachments = prompt_state.attachments;
246        let run_cwd = p.cwd.clone();
247        let run_model = p.model.clone();
248        Ok((p, hook_state, run_cwd, run_model))
249    }
250
251    async fn finalize_prompt_run_hooks(
252        &self,
253        hook_state: &mut HookExecutionState,
254        run_cwd: &str,
255        run_model: Option<&str>,
256        fallback_thread_id: Option<&str>,
257        result: &Result<PromptRunResult, PromptRunError>,
258        scoped_hooks: Option<&RuntimeHookConfig>,
259    ) {
260        let post_thread_id = result
261            .as_ref()
262            .ok()
263            .map(|value| value.thread_id.as_str())
264            .or(fallback_thread_id);
265        self.execute_post_hook_phase(
266            hook_state,
267            HookContextInput {
268                phase: HookPhase::PostRun,
269                cwd: Some(run_cwd),
270                model: run_model,
271                thread_id: post_thread_id,
272                turn_id: None,
273                main_status: Some(result_status(result)),
274            },
275            scoped_hooks,
276        )
277        .await;
278    }
279
280    async fn run_prompt_on_thread(
281        &self,
282        thread: ThreadHandle,
283        p: PromptRunParams,
284        effort: ReasoningEffort,
285        hook_state: Option<&mut HookExecutionState>,
286        scoped_hooks: Option<&RuntimeHookConfig>,
287    ) -> Result<PromptRunResult, PromptRunError> {
288        let mut hook_state = hook_state;
289        let mut p = p;
290        if let Some(state) = hook_state.as_deref_mut() {
291            let mut prompt_state = PromptMutationState::from_params(&p, state.metadata.clone());
292            let decisions = self
293                .execute_pre_hook_phase(
294                    state,
295                    HookPhase::PreTurn,
296                    Some(p.cwd.as_str()),
297                    prompt_state.model.as_deref(),
298                    Some(thread.thread_id.as_str()),
299                    None,
300                    scoped_hooks,
301                )
302                .await
303                .map_err(PromptRunError::from_block)?;
304            apply_pre_hook_actions_to_prompt(
305                &mut prompt_state,
306                p.cwd.as_str(),
307                HookPhase::PreTurn,
308                decisions,
309                &mut state.report,
310            )
311            .await;
312            state.metadata = prompt_state.metadata;
313            p.prompt = prompt_state.prompt;
314            p.model = prompt_state.model;
315            p.attachments = prompt_state.attachments;
316        }
317
318        self.register_thread_scoped_pre_tool_use_hooks(&thread.thread_id, scoped_hooks);
319        let live_rx = self.subscribe_live();
320        let mut post_turn_id: Option<String> = None;
321        let run_result = match thread
322            .turn_start(turn_start_params_from_prompt(&p, effort))
323            .await
324            .map_err(PromptRunError::Rpc)
325        {
326            Ok(turn) => {
327                post_turn_id = Some(turn.turn_id.clone());
328                self.collect_prompt_turn_assistant_text(live_rx, &thread, &turn.turn_id, p.timeout)
329                    .await
330                    .map(|assistant_text| PromptRunResult {
331                        thread_id: thread.thread_id.clone(),
332                        turn_id: turn.turn_id,
333                        assistant_text,
334                    })
335            }
336            Err(err) => Err(err),
337        };
338
339        if let Some(state) = hook_state {
340            self.execute_post_hook_phase(
341                state,
342                HookContextInput {
343                    phase: HookPhase::PostTurn,
344                    cwd: Some(p.cwd.as_str()),
345                    model: p.model.as_deref(),
346                    thread_id: Some(thread.thread_id.as_str()),
347                    turn_id: post_turn_id.as_deref(),
348                    main_status: Some(result_status(&run_result)),
349                },
350                scoped_hooks,
351            )
352            .await;
353        }
354
355        self.clear_thread_scoped_pre_tool_use_hooks(&thread.thread_id);
356        run_result
357    }
358
359    async fn collect_prompt_turn_assistant_text(
360        &self,
361        mut live_rx: tokio::sync::broadcast::Receiver<crate::runtime::events::Envelope>,
362        thread: &ThreadHandle,
363        turn_id: &str,
364        timeout_duration: Duration,
365    ) -> Result<String, PromptRunError> {
366        const INTERRUPT_RPC_TIMEOUT: Duration = Duration::from_millis(500);
367
368        let mut stream = TurnStreamCollector::new(&thread.thread_id, turn_id);
369        let mut last_turn_error: Option<PromptTurnErrorSignal> = None;
370        let collected = collect_turn_terminal_with_limits(
371            &mut live_rx,
372            &mut stream,
373            usize::MAX,
374            timeout_duration,
375            |envelope| {
376                if let Some(err) = extract_turn_error_signal(envelope) {
377                    last_turn_error = Some(err);
378                }
379                Ok::<(), RpcError>(())
380            },
381            |lag_probe_budget| async move {
382                self.read_turn_terminal_after_lag(&thread.thread_id, turn_id, lag_probe_budget)
383                    .await
384            },
385        )
386        .await;
387
388        let (terminal, lagged_terminal) = match collected {
389            Ok(result) => result,
390            Err(TurnCollectError::Timeout) => {
391                interrupt_turn_best_effort_detached(
392                    thread.runtime().clone(),
393                    thread.thread_id.clone(),
394                    turn_id.to_owned(),
395                    INTERRUPT_RPC_TIMEOUT,
396                );
397                return Err(PromptRunError::Timeout(timeout_duration));
398            }
399            Err(TurnCollectError::StreamClosed) => {
400                return Err(PromptRunError::Runtime(RuntimeError::Internal(format!(
401                    "live stream closed: {}",
402                    RecvError::Closed
403                ))));
404            }
405            Err(TurnCollectError::EventBudgetExceeded) => {
406                return Err(PromptRunError::Runtime(RuntimeError::Internal(
407                    "turn event budget exhausted while collecting assistant output".to_owned(),
408                )));
409            }
410            Err(TurnCollectError::TargetEnvelope(err)) => return Err(PromptRunError::Rpc(err)),
411            Err(TurnCollectError::LagProbe(RpcError::Timeout)) => {
412                interrupt_turn_best_effort_detached(
413                    thread.runtime().clone(),
414                    thread.thread_id.clone(),
415                    turn_id.to_owned(),
416                    INTERRUPT_RPC_TIMEOUT,
417                );
418                return Err(PromptRunError::Timeout(timeout_duration));
419            }
420            Err(TurnCollectError::LagProbe(err)) => return Err(PromptRunError::Rpc(err)),
421        };
422
423        let lagged_completed_text = match lagged_terminal.as_ref() {
424            Some(LaggedTurnTerminal::Completed { assistant_text }) => assistant_text.clone(),
425            _ => None,
426        };
427
428        match terminal {
429            TurnTerminalEvent::Completed => Self::finalize_prompt_turn_assistant_text(
430                stream.into_assistant_text(),
431                lagged_completed_text,
432                last_turn_error,
433            ),
434            TurnTerminalEvent::Failed => {
435                if let Some(err) = last_turn_error {
436                    Err(PromptRunError::TurnFailedWithContext(
437                        err.into_failure(PromptTurnTerminalState::Failed),
438                    ))
439                } else if let Some(LaggedTurnTerminal::Failed { message }) =
440                    lagged_terminal.as_ref()
441                {
442                    if let Some(message) = message.clone() {
443                        Err(PromptRunError::TurnFailedWithContext(PromptTurnFailure {
444                            terminal_state: PromptTurnTerminalState::Failed,
445                            source_method: "thread/read".to_owned(),
446                            code: None,
447                            message,
448                        }))
449                    } else {
450                        Err(PromptRunError::TurnFailed)
451                    }
452                } else {
453                    Err(PromptRunError::TurnFailed)
454                }
455            }
456            TurnTerminalEvent::Interrupted | TurnTerminalEvent::Cancelled => {
457                Err(PromptRunError::TurnInterrupted)
458            }
459        }
460    }
461
462    fn finalize_prompt_turn_assistant_text(
463        collected_assistant_text: String,
464        lagged_completed_text: Option<String>,
465        last_turn_error: Option<PromptTurnErrorSignal>,
466    ) -> Result<String, PromptRunError> {
467        let assistant_text = if let Some(snapshot_text) = lagged_completed_text {
468            if snapshot_text.trim().is_empty() {
469                collected_assistant_text
470            } else {
471                snapshot_text
472            }
473        } else {
474            collected_assistant_text
475        };
476        let assistant_text = assistant_text.trim().to_owned();
477        if assistant_text.is_empty() {
478            if let Some(err) = last_turn_error {
479                Err(PromptRunError::TurnCompletedWithoutAssistantText(
480                    err.into_failure(PromptTurnTerminalState::CompletedWithoutAssistantText),
481                ))
482            } else {
483                Err(PromptRunError::EmptyAssistantText)
484            }
485        } else {
486            Ok(assistant_text)
487        }
488    }
489
490    async fn read_turn_terminal_after_lag(
491        &self,
492        thread_id: &str,
493        turn_id: &str,
494        timeout_duration: Duration,
495    ) -> Result<Option<LaggedTurnTerminal>, RpcError> {
496        let params = serialize_params(
497            methods::THREAD_READ,
498            &ThreadReadParams {
499                thread_id: thread_id.to_owned(),
500                include_turns: Some(true),
501            },
502        )?;
503        let response = self
504            .call_validated_with_mode_and_timeout(
505                methods::THREAD_READ,
506                params,
507                RpcValidationMode::KnownMethods,
508                timeout_duration,
509            )
510            .await?;
511        let response: ThreadReadResponse = deserialize_result(methods::THREAD_READ, response)?;
512
513        let Some(turn) = response.thread.turns.iter().find(|turn| turn.id == turn_id) else {
514            return Ok(None);
515        };
516
517        let terminal = match turn.status {
518            ThreadTurnStatus::Completed => Some(LaggedTurnTerminal::Completed {
519                assistant_text: extract_assistant_text_from_turn(turn),
520            }),
521            ThreadTurnStatus::Failed => Some(LaggedTurnTerminal::Failed {
522                message: turn.error.as_ref().map(|error| error.message.clone()),
523            }),
524            ThreadTurnStatus::Interrupted => Some(LaggedTurnTerminal::Interrupted),
525            ThreadTurnStatus::InProgress => None,
526        };
527        Ok(terminal)
528    }
529
530    #[allow(clippy::too_many_arguments)]
531    pub(super) async fn execute_pre_hook_phase(
532        &self,
533        hook_state: &mut HookExecutionState,
534        phase: HookPhase,
535        cwd: Option<&str>,
536        model: Option<&str>,
537        thread_id: Option<&str>,
538        turn_id: Option<&str>,
539        scoped_hooks: Option<&RuntimeHookConfig>,
540    ) -> Result<Vec<PreHookDecision>, BlockReason> {
541        let ctx = build_hook_context(
542            hook_state.correlation_id.as_str(),
543            &hook_state.metadata,
544            HookContextInput {
545                phase,
546                cwd,
547                model,
548                thread_id,
549                turn_id,
550                main_status: None,
551            },
552        );
553        self.run_pre_hooks_with(&ctx, &mut hook_state.report, scoped_hooks)
554            .await
555    }
556
557    pub(super) async fn execute_post_hook_phase(
558        &self,
559        hook_state: &mut HookExecutionState,
560        input: HookContextInput<'_>,
561        scoped_hooks: Option<&RuntimeHookConfig>,
562    ) {
563        let ctx = build_hook_context(
564            hook_state.correlation_id.as_str(),
565            &hook_state.metadata,
566            input,
567        );
568        self.run_post_hooks_with(&ctx, &mut hook_state.report, scoped_hooks)
569            .await;
570    }
571}