Skip to main content

codex_runtime/runtime/api/
prompt_run.rs

1use std::time::Duration;
2
3use crate::plugin::{BlockReason, HookPhase};
4use tokio::sync::broadcast::error::RecvError;
5use tokio::time::{timeout, Instant};
6
7use crate::runtime::core::Runtime;
8use crate::runtime::detached_task::{current_detached_task_plan, spawn_detached_task};
9use crate::runtime::errors::{RpcError, RuntimeError};
10use crate::runtime::events::{
11    extract_agent_message_delta, extract_turn_cancelled, extract_turn_completed,
12    extract_turn_failed, extract_turn_interrupted, Envelope,
13};
14use crate::runtime::hooks::{PreHookDecision, RuntimeHookConfig};
15use crate::runtime::rpc_contract::{methods, RpcValidationMode};
16use crate::runtime::turn_lifecycle::{
17    collect_turn_terminal_with_limits, interrupt_turn_best_effort_detached,
18    interrupt_turn_best_effort_with_timeout, LaggedTurnTerminal, TurnCollectError,
19};
20use crate::runtime::turn_output::{TurnStreamCollector, TurnTerminalEvent};
21
22use super::attachment_validation::validate_prompt_attachments;
23use super::flow::{
24    apply_pre_hook_actions_to_prompt, build_hook_context, extract_assistant_text_from_turn,
25    result_status, HookContextInput, HookExecutionState, PromptMutationState,
26};
27use super::models::{PromptRunStreamState, PromptStreamCleanupState};
28use super::turn_error::{extract_turn_error_signal, PromptTurnErrorSignal};
29use super::wire::{
30    deserialize_result, serialize_params, thread_start_params_from_prompt,
31    turn_start_params_from_prompt,
32};
33use super::*;
34
/// Timeout handed to the best-effort turn-interrupt helpers used in this file.
const INTERRUPT_RPC_TIMEOUT: Duration = Duration::from_millis(500);
36
/// Which thread a prompt run should execute against.
#[derive(Clone, Copy)]
enum PromptRunTarget<'a> {
    /// Open a thread for the run: resume the given id when present, otherwise
    /// start a new thread.
    OpenOrResume(Option<&'a str>),
    /// Run on the thread with this id via `run_prompt_on_loaded_thread_entry`.
    Loaded(&'a str),
}

impl<'a> PromptRunTarget<'a> {
    /// Thread id to report to hooks, if one is known before the run starts.
    fn hook_thread_id(self) -> Option<&'a str> {
        match self {
            Self::Loaded(id) => Some(id),
            Self::OpenOrResume(maybe_id) => maybe_id,
        }
    }
}
51
52impl Runtime {
53    /// Run one prompt with safe default policies using only cwd + prompt.
54    /// Side effects: same as `run_prompt`. Allocation: params object + two Strings.
55    /// Complexity: O(n), n = input string lengths + streamed turn output size.
56    pub async fn run_prompt_simple(
57        &self,
58        cwd: impl Into<String>,
59        prompt: impl Into<String>,
60    ) -> Result<PromptRunResult, PromptRunError> {
61        self.run_prompt(PromptRunParams::new(cwd, prompt)).await
62    }
63
64    /// Run one prompt end-to-end and return the final assistant text.
65    /// Side effects: sends thread/turn RPC calls and consumes live event stream.
66    /// Allocation: O(n), n = prompt length + attachment count + streamed text.
67    pub async fn run_prompt(&self, p: PromptRunParams) -> Result<PromptRunResult, PromptRunError> {
68        self.run_prompt_with_hooks(p, None).await
69    }
70
71    pub(crate) async fn run_prompt_with_hooks(
72        &self,
73        p: PromptRunParams,
74        scoped_hooks: Option<&RuntimeHookConfig>,
75    ) -> Result<PromptRunResult, PromptRunError> {
76        self.run_prompt_target_with_hooks(None, p, scoped_hooks)
77            .await
78    }
79
80    /// Continue an existing thread with one additional prompt turn.
81    /// Side effects: sends thread/resume + turn/start RPC calls and consumes live event stream.
82    /// Allocation: O(n), n = prompt length + attachment count + streamed text.
83    pub async fn run_prompt_in_thread(
84        &self,
85        thread_id: &str,
86        p: PromptRunParams,
87    ) -> Result<PromptRunResult, PromptRunError> {
88        self.run_prompt_in_thread_with_hooks(thread_id, p, None)
89            .await
90    }
91
92    pub(crate) async fn run_prompt_in_thread_with_hooks(
93        &self,
94        thread_id: &str,
95        p: PromptRunParams,
96        scoped_hooks: Option<&RuntimeHookConfig>,
97    ) -> Result<PromptRunResult, PromptRunError> {
98        self.run_prompt_target_with_hooks(Some(thread_id), p, scoped_hooks)
99            .await
100    }
101
102    pub(crate) async fn run_prompt_on_loaded_thread_with_hooks(
103        &self,
104        thread_id: &str,
105        p: PromptRunParams,
106        scoped_hooks: Option<&RuntimeHookConfig>,
107    ) -> Result<PromptRunResult, PromptRunError> {
108        self.run_prompt_with_hook_scaffold(PromptRunTarget::Loaded(thread_id), p, scoped_hooks)
109            .await
110    }
111
112    pub(crate) async fn run_prompt_on_loaded_thread_stream_with_hooks(
113        &self,
114        thread_id: &str,
115        p: PromptRunParams,
116        scoped_hooks: Option<&RuntimeHookConfig>,
117    ) -> Result<PromptRunStream, PromptRunError> {
118        validate_prompt_attachments(&p.cwd, &p.attachments).await?;
119        let effort = p.effort.unwrap_or(DEFAULT_REASONING_EFFORT);
120        let thread = self.loaded_thread_handle(thread_id);
121        self.run_prompt_on_thread_stream(thread, p, effort, scoped_hooks)
122            .await
123    }
124
125    async fn run_prompt_target_with_hooks(
126        &self,
127        thread_id: Option<&str>,
128        p: PromptRunParams,
129        scoped_hooks: Option<&RuntimeHookConfig>,
130    ) -> Result<PromptRunResult, PromptRunError> {
131        self.run_prompt_with_hook_scaffold(
132            PromptRunTarget::OpenOrResume(thread_id),
133            p,
134            scoped_hooks,
135        )
136        .await
137    }
138
139    async fn run_prompt_with_hook_scaffold(
140        &self,
141        target: PromptRunTarget<'_>,
142        p: PromptRunParams,
143        scoped_hooks: Option<&RuntimeHookConfig>,
144    ) -> Result<PromptRunResult, PromptRunError> {
145        if !self.hooks_enabled_with(scoped_hooks) {
146            return self
147                .run_prompt_target_entry_dispatch(target, p, None, scoped_hooks)
148                .await;
149        }
150
151        let fallback_thread_id = target.hook_thread_id();
152        let (p, mut hook_state, run_cwd, run_model) = self
153            .prepare_prompt_pre_run_hooks(p, fallback_thread_id, scoped_hooks)
154            .await?;
155        let result = self
156            .run_prompt_target_entry_dispatch(target, p, Some(&mut hook_state), scoped_hooks)
157            .await;
158        self.finalize_prompt_run_hooks(
159            &mut hook_state,
160            run_cwd.as_str(),
161            run_model.as_deref(),
162            fallback_thread_id,
163            &result,
164            scoped_hooks,
165        )
166        .await;
167        self.publish_hook_report(hook_state.report);
168        result
169    }
170
171    async fn run_prompt_target_entry_dispatch(
172        &self,
173        target: PromptRunTarget<'_>,
174        p: PromptRunParams,
175        hook_state: Option<&mut HookExecutionState>,
176        scoped_hooks: Option<&RuntimeHookConfig>,
177    ) -> Result<PromptRunResult, PromptRunError> {
178        match target {
179            PromptRunTarget::OpenOrResume(thread_id) => {
180                self.run_prompt_entry(thread_id, p, hook_state, scoped_hooks)
181                    .await
182            }
183            PromptRunTarget::Loaded(thread_id) => {
184                self.run_prompt_on_loaded_thread_entry(thread_id, p, hook_state, scoped_hooks)
185                    .await
186            }
187        }
188    }
189
190    async fn open_prompt_thread(
191        &self,
192        thread_id: Option<&str>,
193        p: &PromptRunParams,
194        scoped_hooks: Option<&RuntimeHookConfig>,
195    ) -> Result<ThreadHandle, RpcError> {
196        let mut start = thread_start_params_from_prompt(p);
197        if self.has_pre_tool_use_hooks_with(scoped_hooks)
198            && matches!(start.approval_policy, None | Some(ApprovalPolicy::Never))
199        {
200            start.approval_policy = Some(ApprovalPolicy::Untrusted);
201        }
202        match thread_id {
203            Some(existing_thread_id) => self.thread_resume_raw(existing_thread_id, start).await,
204            None => self.thread_start_raw(start).await,
205        }
206    }
207
208    async fn run_prompt_entry(
209        &self,
210        thread_id: Option<&str>,
211        p: PromptRunParams,
212        hook_state: Option<&mut HookExecutionState>,
213        scoped_hooks: Option<&RuntimeHookConfig>,
214    ) -> Result<PromptRunResult, PromptRunError> {
215        validate_prompt_attachments(&p.cwd, &p.attachments).await?;
216        let effort = p.effort.unwrap_or(DEFAULT_REASONING_EFFORT);
217        let thread = self.open_prompt_thread(thread_id, &p, scoped_hooks).await?;
218        self.run_prompt_on_thread(thread, p, effort, hook_state, scoped_hooks)
219            .await
220    }
221
222    async fn run_prompt_on_loaded_thread_entry(
223        &self,
224        thread_id: &str,
225        p: PromptRunParams,
226        hook_state: Option<&mut HookExecutionState>,
227        scoped_hooks: Option<&RuntimeHookConfig>,
228    ) -> Result<PromptRunResult, PromptRunError> {
229        validate_prompt_attachments(&p.cwd, &p.attachments).await?;
230        let effort = p.effort.unwrap_or(DEFAULT_REASONING_EFFORT);
231        let thread = self.loaded_thread_handle(thread_id);
232        self.run_prompt_on_thread(thread, p, effort, hook_state, scoped_hooks)
233            .await
234    }
235
236    async fn prepare_prompt_pre_run_hooks(
237        &self,
238        mut p: PromptRunParams,
239        thread_id: Option<&str>,
240        scoped_hooks: Option<&RuntimeHookConfig>,
241    ) -> Result<(PromptRunParams, HookExecutionState, String, Option<String>), PromptRunError> {
242        let mut hook_state = HookExecutionState::new(self.next_hook_correlation_id());
243        let mut prompt_state = PromptMutationState::from_params(&p, hook_state.metadata.clone());
244        let decisions = self
245            .execute_pre_hook_phase(
246                &mut hook_state,
247                HookPhase::PreRun,
248                Some(p.cwd.as_str()),
249                prompt_state.model.as_deref(),
250                thread_id,
251                None,
252                scoped_hooks,
253            )
254            .await
255            .map_err(PromptRunError::from_block)?;
256        apply_pre_hook_actions_to_prompt(
257            &mut prompt_state,
258            p.cwd.as_str(),
259            HookPhase::PreRun,
260            decisions,
261            &mut hook_state.report,
262        )
263        .await;
264        hook_state.metadata = prompt_state.metadata.clone();
265        p.prompt = prompt_state.prompt;
266        p.model = prompt_state.model;
267        p.attachments = prompt_state.attachments;
268        let run_cwd = p.cwd.clone();
269        let run_model = p.model.clone();
270        Ok((p, hook_state, run_cwd, run_model))
271    }
272
273    async fn finalize_prompt_run_hooks(
274        &self,
275        hook_state: &mut HookExecutionState,
276        run_cwd: &str,
277        run_model: Option<&str>,
278        fallback_thread_id: Option<&str>,
279        result: &Result<PromptRunResult, PromptRunError>,
280        scoped_hooks: Option<&RuntimeHookConfig>,
281    ) {
282        let post_thread_id = result
283            .as_ref()
284            .ok()
285            .map(|value| value.thread_id.as_str())
286            .or(fallback_thread_id);
287        self.execute_post_hook_phase(
288            hook_state,
289            HookContextInput {
290                phase: HookPhase::PostRun,
291                cwd: Some(run_cwd),
292                model: run_model,
293                thread_id: post_thread_id,
294                turn_id: None,
295                main_status: Some(result_status(result)),
296            },
297            scoped_hooks,
298        )
299        .await;
300    }
301
302    async fn prepare_prompt_pre_turn_hooks(
303        &self,
304        thread_id: &str,
305        mut p: PromptRunParams,
306        hook_state: Option<&mut HookExecutionState>,
307        scoped_hooks: Option<&RuntimeHookConfig>,
308    ) -> Result<PromptRunParams, PromptRunError> {
309        let Some(state) = hook_state else {
310            return Ok(p);
311        };
312
313        let mut prompt_state = PromptMutationState::from_params(&p, state.metadata.clone());
314        let decisions = self
315            .execute_pre_hook_phase(
316                state,
317                HookPhase::PreTurn,
318                Some(p.cwd.as_str()),
319                prompt_state.model.as_deref(),
320                Some(thread_id),
321                None,
322                scoped_hooks,
323            )
324            .await
325            .map_err(PromptRunError::from_block)?;
326        apply_pre_hook_actions_to_prompt(
327            &mut prompt_state,
328            p.cwd.as_str(),
329            HookPhase::PreTurn,
330            decisions,
331            &mut state.report,
332        )
333        .await;
334        state.metadata = prompt_state.metadata;
335        p.prompt = prompt_state.prompt;
336        p.model = prompt_state.model;
337        p.attachments = prompt_state.attachments;
338        Ok(p)
339    }
340
341    async fn run_prompt_on_thread(
342        &self,
343        thread: ThreadHandle,
344        p: PromptRunParams,
345        effort: ReasoningEffort,
346        hook_state: Option<&mut HookExecutionState>,
347        scoped_hooks: Option<&RuntimeHookConfig>,
348    ) -> Result<PromptRunResult, PromptRunError> {
349        let mut hook_state = hook_state;
350        let p = self
351            .prepare_prompt_pre_turn_hooks(
352                thread.thread_id.as_str(),
353                p,
354                hook_state.as_deref_mut(),
355                scoped_hooks,
356            )
357            .await?;
358
359        self.register_thread_scoped_pre_tool_use_hooks(&thread.thread_id, scoped_hooks);
360        let live_rx = self.subscribe_live();
361        let mut post_turn_id: Option<String> = None;
362        let run_result = match thread
363            .turn_start(turn_start_params_from_prompt(&p, effort))
364            .await
365            .map_err(PromptRunError::Rpc)
366        {
367            Ok(turn) => {
368                post_turn_id = Some(turn.turn_id.clone());
369                self.collect_prompt_turn_assistant_text(live_rx, &thread, &turn.turn_id, p.timeout)
370                    .await
371                    .map(|assistant_text| PromptRunResult {
372                        thread_id: thread.thread_id.clone(),
373                        turn_id: turn.turn_id,
374                        assistant_text,
375                    })
376            }
377            Err(err) => Err(err),
378        };
379
380        if let Some(state) = hook_state {
381            self.execute_post_hook_phase(
382                state,
383                HookContextInput {
384                    phase: HookPhase::PostTurn,
385                    cwd: Some(p.cwd.as_str()),
386                    model: p.model.as_deref(),
387                    thread_id: Some(thread.thread_id.as_str()),
388                    turn_id: post_turn_id.as_deref(),
389                    main_status: Some(result_status(&run_result)),
390                },
391                scoped_hooks,
392            )
393            .await;
394        }
395
396        self.clear_thread_scoped_pre_tool_use_hooks(&thread.thread_id);
397        run_result
398    }
399
400    async fn run_prompt_on_thread_stream(
401        &self,
402        thread: ThreadHandle,
403        p: PromptRunParams,
404        effort: ReasoningEffort,
405        scoped_hooks: Option<&RuntimeHookConfig>,
406    ) -> Result<PromptRunStream, PromptRunError> {
407        let mut hook_state = if self.hooks_enabled_with(scoped_hooks) {
408            Some(HookExecutionState::new(self.next_hook_correlation_id()))
409        } else {
410            None
411        };
412        let p = self
413            .prepare_prompt_pre_turn_hooks(
414                thread.thread_id.as_str(),
415                p,
416                hook_state.as_mut(),
417                scoped_hooks,
418            )
419            .await?;
420
421        self.register_thread_scoped_pre_tool_use_hooks(&thread.thread_id, scoped_hooks);
422        let live_rx = self.subscribe_live();
423        let timeout_duration = p.timeout;
424        let run_cwd = p.cwd.clone();
425        let run_model = p.model.clone();
426
427        let turn = match thread
428            .turn_start(turn_start_params_from_prompt(&p, effort))
429            .await
430            .map_err(PromptRunError::Rpc)
431        {
432            Ok(turn) => turn,
433            Err(err) => {
434                if let Some(state) = hook_state.as_mut() {
435                    self.execute_post_hook_phase(
436                        state,
437                        HookContextInput {
438                            phase: HookPhase::PostTurn,
439                            cwd: Some(run_cwd.as_str()),
440                            model: run_model.as_deref(),
441                            thread_id: Some(thread.thread_id.as_str()),
442                            turn_id: None,
443                            main_status: Some("error"),
444                        },
445                        scoped_hooks,
446                    )
447                    .await;
448                    self.publish_hook_report(state.report.clone());
449                }
450                self.clear_thread_scoped_pre_tool_use_hooks(&thread.thread_id);
451                return Err(err);
452            }
453        };
454        let cleanup = PromptStreamCleanupState {
455            run_cwd,
456            run_model,
457            scoped_hooks: scoped_hooks.cloned(),
458            hook_state,
459            cleaned_up: false,
460        };
461
462        Ok(PromptRunStream {
463            runtime: self.clone(),
464            thread_id: thread.thread_id.clone(),
465            turn_id: turn.turn_id.clone(),
466            live_rx,
467            stream: TurnStreamCollector::new(&thread.thread_id, &turn.turn_id),
468            state: PromptRunStreamState {
469                last_turn_error: None,
470                lagged_terminal: None,
471                final_result: None,
472            },
473            deadline: Instant::now() + timeout_duration,
474            timeout: timeout_duration,
475            cleanup,
476        })
477    }
478
479    async fn collect_prompt_turn_assistant_text(
480        &self,
481        mut live_rx: tokio::sync::broadcast::Receiver<crate::runtime::events::Envelope>,
482        thread: &ThreadHandle,
483        turn_id: &str,
484        timeout_duration: Duration,
485    ) -> Result<String, PromptRunError> {
486        let mut stream = TurnStreamCollector::new(&thread.thread_id, turn_id);
487        let mut last_turn_error: Option<PromptTurnErrorSignal> = None;
488        let collected = collect_turn_terminal_with_limits(
489            &mut live_rx,
490            &mut stream,
491            usize::MAX,
492            timeout_duration,
493            |envelope| {
494                if let Some(err) = extract_turn_error_signal(envelope) {
495                    last_turn_error = Some(err);
496                }
497                Ok::<(), RpcError>(())
498            },
499            |lag_probe_budget| async move {
500                self.read_turn_terminal_after_lag(&thread.thread_id, turn_id, lag_probe_budget)
501                    .await
502            },
503        )
504        .await;
505
506        let (terminal, lagged_terminal) = match collected {
507            Ok(result) => result,
508            Err(TurnCollectError::Timeout) => {
509                interrupt_turn_best_effort_detached(
510                    thread.runtime().clone(),
511                    thread.thread_id.clone(),
512                    turn_id.to_owned(),
513                    INTERRUPT_RPC_TIMEOUT,
514                );
515                return Err(PromptRunError::Timeout(timeout_duration));
516            }
517            Err(TurnCollectError::StreamClosed) => {
518                return Err(PromptRunError::Runtime(RuntimeError::Internal(format!(
519                    "live stream closed: {}",
520                    RecvError::Closed
521                ))));
522            }
523            Err(TurnCollectError::EventBudgetExceeded) => {
524                return Err(PromptRunError::Runtime(RuntimeError::Internal(
525                    "turn event budget exhausted while collecting assistant output".to_owned(),
526                )));
527            }
528            Err(TurnCollectError::TargetEnvelope(err)) => return Err(PromptRunError::Rpc(err)),
529            Err(TurnCollectError::LagProbe(RpcError::Timeout)) => {
530                interrupt_turn_best_effort_detached(
531                    thread.runtime().clone(),
532                    thread.thread_id.clone(),
533                    turn_id.to_owned(),
534                    INTERRUPT_RPC_TIMEOUT,
535                );
536                return Err(PromptRunError::Timeout(timeout_duration));
537            }
538            Err(TurnCollectError::LagProbe(err)) => return Err(PromptRunError::Rpc(err)),
539        };
540
541        Self::resolve_prompt_turn_assistant_text(
542            terminal,
543            stream.into_assistant_text(),
544            lagged_terminal.as_ref(),
545            last_turn_error,
546        )
547    }
548
549    fn resolve_prompt_turn_assistant_text(
550        terminal: TurnTerminalEvent,
551        collected_assistant_text: String,
552        lagged_terminal: Option<&LaggedTurnTerminal>,
553        last_turn_error: Option<PromptTurnErrorSignal>,
554    ) -> Result<String, PromptRunError> {
555        match terminal {
556            TurnTerminalEvent::Completed => Self::finalize_prompt_turn_assistant_text(
557                collected_assistant_text,
558                lagged_completed_text(lagged_terminal),
559                last_turn_error,
560            ),
561            TurnTerminalEvent::Failed => prompt_turn_failed_error(last_turn_error, lagged_terminal),
562            TurnTerminalEvent::Interrupted | TurnTerminalEvent::Cancelled => {
563                Err(PromptRunError::TurnInterrupted)
564            }
565        }
566    }
567
568    fn finalize_prompt_turn_assistant_text(
569        collected_assistant_text: String,
570        lagged_completed_text: Option<String>,
571        last_turn_error: Option<PromptTurnErrorSignal>,
572    ) -> Result<String, PromptRunError> {
573        let assistant_text = if let Some(snapshot_text) = lagged_completed_text {
574            if snapshot_text.trim().is_empty() {
575                collected_assistant_text
576            } else {
577                snapshot_text
578            }
579        } else {
580            collected_assistant_text
581        };
582        let assistant_text = assistant_text.trim().to_owned();
583        if assistant_text.is_empty() {
584            if let Some(err) = last_turn_error {
585                Err(PromptRunError::TurnCompletedWithoutAssistantText(
586                    err.into_failure(PromptTurnTerminalState::CompletedWithoutAssistantText),
587                ))
588            } else {
589                Err(PromptRunError::EmptyAssistantText)
590            }
591        } else {
592            Ok(assistant_text)
593        }
594    }
595
596    async fn read_turn_terminal_after_lag(
597        &self,
598        thread_id: &str,
599        turn_id: &str,
600        timeout_duration: Duration,
601    ) -> Result<Option<LaggedTurnTerminal>, RpcError> {
602        let params = serialize_params(
603            methods::THREAD_READ,
604            &ThreadReadParams {
605                thread_id: thread_id.to_owned(),
606                include_turns: Some(true),
607            },
608        )?;
609        let response = self
610            .call_validated_with_mode_and_timeout(
611                methods::THREAD_READ,
612                params,
613                RpcValidationMode::KnownMethods,
614                timeout_duration,
615            )
616            .await?;
617        let response: ThreadReadResponse = deserialize_result(methods::THREAD_READ, response)?;
618
619        let Some(turn) = response.thread.turns.iter().find(|turn| turn.id == turn_id) else {
620            return Ok(None);
621        };
622
623        Ok(lagged_terminal_from_turn(turn))
624    }
625
626    #[allow(clippy::too_many_arguments)]
627    pub(super) async fn execute_pre_hook_phase(
628        &self,
629        hook_state: &mut HookExecutionState,
630        phase: HookPhase,
631        cwd: Option<&str>,
632        model: Option<&str>,
633        thread_id: Option<&str>,
634        turn_id: Option<&str>,
635        scoped_hooks: Option<&RuntimeHookConfig>,
636    ) -> Result<Vec<PreHookDecision>, BlockReason> {
637        let ctx = build_hook_context(
638            hook_state.correlation_id.as_str(),
639            &hook_state.metadata,
640            HookContextInput {
641                phase,
642                cwd,
643                model,
644                thread_id,
645                turn_id,
646                main_status: None,
647            },
648        );
649        self.run_pre_hooks_with(&ctx, &mut hook_state.report, scoped_hooks)
650            .await
651    }
652
653    pub(super) async fn execute_post_hook_phase(
654        &self,
655        hook_state: &mut HookExecutionState,
656        input: HookContextInput<'_>,
657        scoped_hooks: Option<&RuntimeHookConfig>,
658    ) {
659        let ctx = build_hook_context(
660            hook_state.correlation_id.as_str(),
661            &hook_state.metadata,
662            input,
663        );
664        self.run_post_hooks_with(&ctx, &mut hook_state.report, scoped_hooks)
665            .await;
666    }
667}
668
669impl PromptRunStream {
670    /// Borrow target thread id for this scoped stream.
671    pub fn thread_id(&self) -> &str {
672        self.thread_id.as_str()
673    }
674
675    /// Borrow target turn id for this scoped stream.
676    pub fn turn_id(&self) -> &str {
677        self.turn_id.as_str()
678    }
679
680    /// Receive the next typed event for the target turn.
681    pub async fn recv(&mut self) -> Result<Option<PromptRunStreamEvent>, PromptRunError> {
682        if self.state.final_result.is_some() {
683            return Ok(None);
684        }
685
686        loop {
687            let now = Instant::now();
688            if now >= self.deadline {
689                return Err(self.timeout_with_interrupt().await);
690            }
691            let remaining = self.deadline.saturating_duration_since(now);
692
693            let envelope = match timeout(remaining, self.live_rx.recv()).await {
694                Ok(Ok(envelope)) => envelope,
695                Ok(Err(RecvError::Lagged(_))) => {
696                    let lag_probe_budget = self.deadline.saturating_duration_since(Instant::now());
697                    if lag_probe_budget.is_zero() {
698                        return Err(self.timeout_with_interrupt().await);
699                    }
700
701                    match self
702                        .runtime
703                        .read_turn_terminal_after_lag(
704                            &self.thread_id,
705                            &self.turn_id,
706                            lag_probe_budget,
707                        )
708                        .await
709                    {
710                        Ok(Some(snapshot)) => {
711                            self.state.lagged_terminal = Some(snapshot.clone());
712                            let observation =
713                                observe_lagged_terminal(&self.thread_id, &self.turn_id, &snapshot);
714                            return Ok(self.apply_observation(observation).await);
715                        }
716                        Ok(None) => continue,
717                        Err(RpcError::Timeout) => return Err(self.timeout_with_interrupt().await),
718                        Err(err) => return Err(self.fail(PromptRunError::Rpc(err)).await),
719                    }
720                }
721                Ok(Err(RecvError::Closed)) => {
722                    return Err(self
723                        .fail(PromptRunError::Runtime(RuntimeError::Internal(format!(
724                            "live stream closed: {}",
725                            RecvError::Closed
726                        ))))
727                        .await);
728                }
729                Err(_) => return Err(self.timeout_with_interrupt().await),
730            };
731
732            if !self.stream.is_target_envelope(&envelope) {
733                continue;
734            }
735
736            let terminal = self.stream.push_envelope(&envelope);
737            let observation = observe_target_envelope(&envelope, terminal);
738            let next = self.apply_observation(observation).await;
739            if next.is_some() {
740                return Ok(next);
741            }
742            if self.state.final_result.is_some() {
743                return Ok(None);
744            }
745        }
746    }
747
748    /// Drain the stream to its terminal result.
749    pub async fn finish(mut self) -> Result<PromptRunResult, PromptRunError> {
750        while self.state.final_result.is_none() {
751            if self.recv().await?.is_none() {
752                break;
753            }
754        }
755
756        match self.state.final_result.take() {
757            Some(result) => result,
758            None => Err(PromptRunError::Runtime(RuntimeError::Internal(
759                "prompt stream finished without terminal result".to_owned(),
760            ))),
761        }
762    }
763
764    async fn complete(&mut self, result: Result<PromptRunResult, PromptRunError>) {
765        self.cleanup(stream_result_status(&result)).await;
766        self.state.final_result = Some(result);
767    }
768
769    async fn timeout_with_interrupt(&mut self) -> PromptRunError {
770        self.interrupt_best_effort();
771        self.fail(PromptRunError::Timeout(self.timeout)).await
772    }
773
774    async fn fail(&mut self, err: PromptRunError) -> PromptRunError {
775        self.cleanup("error").await;
776        self.state.final_result = Some(Err(err.clone()));
777        err
778    }
779
780    async fn cleanup(&mut self, main_status: &'static str) {
781        let Some(plan) = self.take_cleanup_plan(main_status, false) else {
782            return;
783        };
784        run_cleanup_plan(&self.runtime, plan).await;
785    }
786
787    fn detach_cleanup(&mut self, main_status: &'static str) {
788        let Some(plan) = self.take_cleanup_plan(main_status, true) else {
789            return;
790        };
791
792        let runtime = self.runtime.clone();
793        let fallback_runtime = runtime.clone();
794        let thread_id = plan.thread_id.clone();
795        spawn_detached_task(
796            async move {
797                run_cleanup_plan(&runtime, plan).await;
798            },
799            current_detached_task_plan("prompt_stream_cleanup"),
800            move || {
801                fallback_runtime.record_detached_task_init_failed();
802                fallback_runtime.clear_thread_scoped_pre_tool_use_hooks(&thread_id);
803            },
804        );
805    }
806
807    fn interrupt_best_effort(&self) {
808        interrupt_turn_best_effort_detached(
809            self.runtime.clone(),
810            self.thread_id.clone(),
811            self.turn_id.clone(),
812            INTERRUPT_RPC_TIMEOUT,
813        );
814    }
815
816    fn take_cleanup_plan(
817        &mut self,
818        main_status: &'static str,
819        send_interrupt: bool,
820    ) -> Option<PromptStreamCleanupPlan> {
821        self.cleanup.take_plan(
822            self.thread_id.clone(),
823            self.turn_id.clone(),
824            main_status,
825            send_interrupt,
826        )
827    }
828
829    async fn apply_observation(
830        &mut self,
831        observation: PromptStreamObservation,
832    ) -> Option<PromptRunStreamEvent> {
833        let transition = reduce_prompt_stream_observation(
834            &mut self.state,
835            self.thread_id.as_str(),
836            self.turn_id.as_str(),
837            self.stream.clone().into_assistant_text(),
838            observation,
839        );
840        if let Some(result) = transition.terminal_result {
841            self.complete(result).await;
842        }
843        transition.event
844    }
845}
846
847fn lagged_completed_text(lagged_terminal: Option<&LaggedTurnTerminal>) -> Option<String> {
848    match lagged_terminal {
849        Some(LaggedTurnTerminal::Completed { assistant_text }) => assistant_text.clone(),
850        _ => None,
851    }
852}
853
854fn prompt_turn_failed_error(
855    last_turn_error: Option<PromptTurnErrorSignal>,
856    lagged_terminal: Option<&LaggedTurnTerminal>,
857) -> Result<String, PromptRunError> {
858    if let Some(err) = last_turn_error {
859        Err(PromptRunError::TurnFailedWithContext(
860            err.into_failure(PromptTurnTerminalState::Failed),
861        ))
862    } else if let Some(LaggedTurnTerminal::Failed { message }) = lagged_terminal {
863        if let Some(message) = message.clone() {
864            Err(PromptRunError::TurnFailedWithContext(PromptTurnFailure {
865                terminal_state: PromptTurnTerminalState::Failed,
866                source_method: "thread/read".to_owned(),
867                code: None,
868                message,
869            }))
870        } else {
871            Err(PromptRunError::TurnFailed)
872        }
873    } else {
874        Err(PromptRunError::TurnFailed)
875    }
876}
877
878impl Drop for PromptRunStream {
879    fn drop(&mut self) {
880        if !self.cleanup.cleaned_up {
881            self.detach_cleanup("error");
882        }
883    }
884}
885
886fn envelope_to_stream_event(envelope: &Envelope) -> Option<PromptRunStreamEvent> {
887    extract_agent_message_delta(envelope)
888        .map(PromptRunStreamEvent::AgentMessageDelta)
889        .or_else(|| extract_turn_completed(envelope).map(PromptRunStreamEvent::TurnCompleted))
890        .or_else(|| extract_turn_failed(envelope).map(PromptRunStreamEvent::TurnFailed))
891        .or_else(|| extract_turn_interrupted(envelope).map(PromptRunStreamEvent::TurnInterrupted))
892        .or_else(|| extract_turn_cancelled(envelope).map(PromptRunStreamEvent::TurnCancelled))
893}
894
895fn lagged_terminal_to_stream_event(
896    thread_id: &str,
897    turn_id: &str,
898    terminal: &LaggedTurnTerminal,
899) -> Option<PromptRunStreamEvent> {
900    match terminal {
901        LaggedTurnTerminal::Completed { assistant_text } => {
902            Some(PromptRunStreamEvent::TurnCompleted(
903                crate::runtime::events::TurnCompletedNotification {
904                    thread_id: thread_id.to_owned(),
905                    turn_id: turn_id.to_owned(),
906                    text: assistant_text.clone(),
907                },
908            ))
909        }
910        LaggedTurnTerminal::Failed { message } => Some(PromptRunStreamEvent::TurnFailed(
911            crate::runtime::events::TurnFailedNotification {
912                thread_id: thread_id.to_owned(),
913                turn_id: turn_id.to_owned(),
914                code: None,
915                message: message.clone(),
916            },
917        )),
918        LaggedTurnTerminal::Cancelled => Some(PromptRunStreamEvent::TurnCancelled(
919            crate::runtime::events::TurnCancelledNotification {
920                thread_id: thread_id.to_owned(),
921                turn_id: turn_id.to_owned(),
922            },
923        )),
924        LaggedTurnTerminal::Interrupted => Some(PromptRunStreamEvent::TurnInterrupted(
925            crate::runtime::events::TurnInterruptedNotification {
926                thread_id: thread_id.to_owned(),
927                turn_id: turn_id.to_owned(),
928            },
929        )),
930    }
931}
932
933fn observe_target_envelope(
934    envelope: &Envelope,
935    terminal: Option<TurnTerminalEvent>,
936) -> PromptStreamObservation {
937    PromptStreamObservation {
938        event: envelope_to_stream_event(envelope),
939        terminal,
940        turn_error: extract_turn_error_signal(envelope),
941    }
942}
943
944fn observe_lagged_terminal(
945    thread_id: &str,
946    turn_id: &str,
947    terminal: &LaggedTurnTerminal,
948) -> PromptStreamObservation {
949    PromptStreamObservation {
950        event: lagged_terminal_to_stream_event(thread_id, turn_id, terminal),
951        terminal: Some(terminal.as_terminal_event()),
952        turn_error: None,
953    }
954}
955
956fn lagged_terminal_from_turn(turn: &ThreadTurnView) -> Option<LaggedTurnTerminal> {
957    match turn.status {
958        ThreadTurnStatus::Completed => Some(LaggedTurnTerminal::Completed {
959            assistant_text: extract_assistant_text_from_turn(turn),
960        }),
961        ThreadTurnStatus::Failed => Some(LaggedTurnTerminal::Failed {
962            message: turn.error.as_ref().map(|error| error.message.clone()),
963        }),
964        ThreadTurnStatus::Cancelled => Some(LaggedTurnTerminal::Cancelled),
965        ThreadTurnStatus::Interrupted => Some(LaggedTurnTerminal::Interrupted),
966        ThreadTurnStatus::InProgress => None,
967    }
968}
969
970fn stream_result_status(result: &Result<PromptRunResult, PromptRunError>) -> &'static str {
971    if result.is_ok() {
972        "ok"
973    } else {
974        "error"
975    }
976}
977
978struct PromptStreamCleanupPlan {
979    thread_id: String,
980    turn_id: String,
981    run_cwd: String,
982    run_model: Option<String>,
983    scoped_hooks: Option<RuntimeHookConfig>,
984    hook_state: Option<HookExecutionState>,
985    main_status: &'static str,
986    send_interrupt: bool,
987}
988
989struct PromptStreamObservation {
990    event: Option<PromptRunStreamEvent>,
991    terminal: Option<TurnTerminalEvent>,
992    turn_error: Option<PromptTurnErrorSignal>,
993}
994
995struct PromptStreamTransition {
996    event: Option<PromptRunStreamEvent>,
997    terminal_result: Option<Result<PromptRunResult, PromptRunError>>,
998}
999
1000fn reduce_prompt_stream_observation(
1001    state: &mut PromptRunStreamState,
1002    thread_id: &str,
1003    turn_id: &str,
1004    collected_assistant_text: String,
1005    observation: PromptStreamObservation,
1006) -> PromptStreamTransition {
1007    if let Some(err) = observation.turn_error {
1008        state.last_turn_error = Some(err);
1009    }
1010
1011    let terminal_result = observation.terminal.map(|terminal| {
1012        build_prompt_run_result(
1013            thread_id,
1014            turn_id,
1015            collected_assistant_text,
1016            state.lagged_terminal.as_ref(),
1017            state.last_turn_error.clone(),
1018            terminal,
1019        )
1020    });
1021
1022    PromptStreamTransition {
1023        event: observation.event,
1024        terminal_result,
1025    }
1026}
1027
1028fn build_prompt_run_result(
1029    thread_id: &str,
1030    turn_id: &str,
1031    collected_assistant_text: String,
1032    lagged_terminal: Option<&LaggedTurnTerminal>,
1033    last_turn_error: Option<PromptTurnErrorSignal>,
1034    terminal: TurnTerminalEvent,
1035) -> Result<PromptRunResult, PromptRunError> {
1036    Runtime::resolve_prompt_turn_assistant_text(
1037        terminal,
1038        collected_assistant_text,
1039        lagged_terminal,
1040        last_turn_error,
1041    )
1042    .map(|assistant_text| PromptRunResult {
1043        thread_id: thread_id.to_owned(),
1044        turn_id: turn_id.to_owned(),
1045        assistant_text,
1046    })
1047}
1048
1049impl PromptStreamCleanupState {
1050    fn take_plan(
1051        &mut self,
1052        thread_id: String,
1053        turn_id: String,
1054        main_status: &'static str,
1055        send_interrupt: bool,
1056    ) -> Option<PromptStreamCleanupPlan> {
1057        if self.cleaned_up {
1058            return None;
1059        }
1060
1061        self.cleaned_up = true;
1062        Some(PromptStreamCleanupPlan {
1063            thread_id,
1064            turn_id,
1065            run_cwd: self.run_cwd.clone(),
1066            run_model: self.run_model.clone(),
1067            scoped_hooks: self.scoped_hooks.clone(),
1068            hook_state: self.hook_state.take(),
1069            main_status,
1070            send_interrupt,
1071        })
1072    }
1073}
1074
1075async fn run_cleanup_plan(runtime: &Runtime, mut plan: PromptStreamCleanupPlan) {
1076    if let Some(state) = plan.hook_state.as_mut() {
1077        runtime
1078            .execute_post_hook_phase(
1079                state,
1080                HookContextInput {
1081                    phase: HookPhase::PostTurn,
1082                    cwd: Some(plan.run_cwd.as_str()),
1083                    model: plan.run_model.as_deref(),
1084                    thread_id: Some(plan.thread_id.as_str()),
1085                    turn_id: Some(plan.turn_id.as_str()),
1086                    main_status: Some(plan.main_status),
1087                },
1088                plan.scoped_hooks.as_ref(),
1089            )
1090            .await;
1091        runtime.publish_hook_report(state.report.clone());
1092    }
1093
1094    if plan.send_interrupt {
1095        interrupt_turn_best_effort_with_timeout(
1096            runtime,
1097            plan.thread_id.as_str(),
1098            plan.turn_id.as_str(),
1099            INTERRUPT_RPC_TIMEOUT,
1100        )
1101        .await;
1102    }
1103
1104    runtime.clear_thread_scoped_pre_tool_use_hooks(&plan.thread_id);
1105}