Skip to main content

codex_runtime/runtime/api/thread_api.rs

1use std::time::Duration;
2
3use crate::plugin::{BlockReason, HookPhase};
4use serde_json::{Map, Value};
5
6use crate::runtime::core::Runtime;
7use crate::runtime::errors::RpcError;
8use crate::runtime::hooks::RuntimeHookConfig;
9use crate::runtime::rpc_contract::{methods, RpcValidationMode};
10use crate::runtime::turn_output::{parse_thread_id, parse_turn_id};
11
12use super::flow::{
13    apply_pre_hook_actions_to_session, result_status, HookContextInput, HookExecutionState,
14    SessionMutationState,
15};
16use super::wire::{
17    deserialize_result, input_item_to_wire, serialize_params, thread_overrides_to_wire,
18    turn_start_params_to_wire, validate_turn_start_security,
19};
20use super::*;
21
22impl ThreadHandle {
23    pub fn runtime(&self) -> &crate::runtime::core::Runtime {
24        &self.runtime
25    }
26
27    pub async fn turn_start(&self, p: TurnStartParams) -> Result<TurnHandle, RpcError> {
28        ensure_turn_input_not_empty(&p.input)?;
29        validate_turn_start_security(&p)?;
30
31        let response = self
32            .runtime
33            .call_validated(
34                methods::TURN_START,
35                turn_start_params_to_wire(&self.thread_id, &p),
36            )
37            .await?;
38
39        let turn_id = parse_turn_id(&response).ok_or_else(|| {
40            RpcError::InvalidRequest(format!("turn/start missing turn id in result: {response}"))
41        })?;
42
43        Ok(TurnHandle {
44            turn_id,
45            thread_id: self.thread_id.clone(),
46        })
47    }
48
49    /// Start a follow-up turn anchored to an expected previous turn id.
50    /// Allocation: JSON params + input item wire objects.
51    /// Complexity: O(n), n = input item count.
52    pub async fn turn_steer(
53        &self,
54        expected_turn_id: &str,
55        input: Vec<InputItem>,
56    ) -> Result<super::TurnId, RpcError> {
57        ensure_turn_input_not_empty(&input)?;
58
59        let mut params = Map::<String, Value>::new();
60        params.insert("threadId".to_owned(), Value::String(self.thread_id.clone()));
61        params.insert(
62            "expectedTurnId".to_owned(),
63            Value::String(expected_turn_id.to_owned()),
64        );
65        params.insert(
66            "input".to_owned(),
67            Value::Array(input.iter().map(input_item_to_wire).collect()),
68        );
69        let response = self
70            .runtime
71            .call_validated(methods::TURN_START, Value::Object(params))
72            .await?;
73        parse_turn_id(&response).ok_or_else(|| {
74            RpcError::InvalidRequest(format!(
75                "turn/start(steer) missing turn id in result: {response}"
76            ))
77        })
78    }
79
80    pub async fn turn_interrupt(&self, turn_id: &str) -> Result<(), RpcError> {
81        self.runtime.turn_interrupt(&self.thread_id, turn_id).await
82    }
83}
84
85impl Runtime {
86    pub(crate) fn loaded_thread_handle(&self, thread_id: &str) -> ThreadHandle {
87        ThreadHandle {
88            thread_id: thread_id.to_owned(),
89            runtime: self.clone(),
90        }
91    }
92
93    pub async fn thread_start(&self, p: ThreadStartParams) -> Result<ThreadHandle, RpcError> {
94        self.thread_start_with_hooks(p, None).await
95    }
96
97    pub(crate) async fn thread_start_with_hooks(
98        &self,
99        p: ThreadStartParams,
100        scoped_hooks: Option<&RuntimeHookConfig>,
101    ) -> Result<ThreadHandle, RpcError> {
102        if !self.hooks_enabled_with(scoped_hooks) {
103            return self.thread_start_raw(p).await;
104        }
105
106        let (p, mut hook_state, start_cwd, start_model) = self
107            .prepare_session_start_hooks(p, None, scoped_hooks)
108            .await?;
109        let result = self.thread_start_raw(p).await;
110        self.finalize_session_start_hooks(
111            &mut hook_state,
112            start_cwd.as_deref(),
113            start_model.as_deref(),
114            None,
115            &result,
116            scoped_hooks,
117        )
118        .await;
119        self.publish_hook_report(hook_state.report);
120        result
121    }
122
123    pub async fn thread_resume(
124        &self,
125        thread_id: &str,
126        p: ThreadStartParams,
127    ) -> Result<ThreadHandle, RpcError> {
128        self.thread_resume_with_hooks(thread_id, p, None).await
129    }
130
131    pub(crate) async fn thread_resume_with_hooks(
132        &self,
133        thread_id: &str,
134        p: ThreadStartParams,
135        scoped_hooks: Option<&RuntimeHookConfig>,
136    ) -> Result<ThreadHandle, RpcError> {
137        if !self.hooks_enabled_with(scoped_hooks) {
138            return self.thread_resume_raw(thread_id, p).await;
139        }
140
141        let (p, mut hook_state, resume_cwd, resume_model) = self
142            .prepare_session_start_hooks(p, Some(thread_id), scoped_hooks)
143            .await?;
144        let result = self.thread_resume_raw(thread_id, p).await;
145        self.finalize_session_start_hooks(
146            &mut hook_state,
147            resume_cwd.as_deref(),
148            resume_model.as_deref(),
149            Some(thread_id),
150            &result,
151            scoped_hooks,
152        )
153        .await;
154        self.publish_hook_report(hook_state.report);
155        result
156    }
157
158    async fn prepare_session_start_hooks(
159        &self,
160        p: ThreadStartParams,
161        thread_id: Option<&str>,
162        scoped_hooks: Option<&RuntimeHookConfig>,
163    ) -> Result<
164        (
165            ThreadStartParams,
166            HookExecutionState,
167            Option<String>,
168            Option<String>,
169        ),
170        RpcError,
171    > {
172        let mut hook_state = HookExecutionState::new(self.next_hook_correlation_id());
173        let mut session_state =
174            SessionMutationState::from_thread_start(&p, hook_state.metadata.clone());
175        let decisions = self
176            .execute_pre_hook_phase(
177                &mut hook_state,
178                HookPhase::PreSessionStart,
179                p.cwd.as_deref(),
180                p.model.as_deref(),
181                thread_id,
182                None,
183                scoped_hooks,
184            )
185            .await
186            .map_err(block_reason_to_rpc_error)?;
187        apply_pre_hook_actions_to_session(
188            &mut session_state,
189            HookPhase::PreSessionStart,
190            decisions,
191            &mut hook_state.report,
192        );
193        hook_state.metadata = session_state.metadata.clone();
194
195        let mut p = p;
196        p.model = session_state.model;
197        let cwd = p.cwd.clone();
198        let model = p.model.clone();
199
200        Ok((p, hook_state, cwd, model))
201    }
202
203    async fn finalize_session_start_hooks(
204        &self,
205        hook_state: &mut HookExecutionState,
206        cwd: Option<&str>,
207        model: Option<&str>,
208        fallback_thread_id: Option<&str>,
209        result: &Result<ThreadHandle, RpcError>,
210        scoped_hooks: Option<&RuntimeHookConfig>,
211    ) {
212        let post_thread_id = result
213            .as_ref()
214            .ok()
215            .map(|thread| thread.thread_id.as_str())
216            .or(fallback_thread_id);
217        self.execute_post_hook_phase(
218            hook_state,
219            HookContextInput {
220                phase: HookPhase::PostSessionStart,
221                cwd,
222                model,
223                thread_id: post_thread_id,
224                turn_id: None,
225                main_status: Some(result_status(result)),
226            },
227            scoped_hooks,
228        )
229        .await;
230    }
231
232    pub(crate) async fn thread_resume_raw(
233        &self,
234        thread_id: &str,
235        p: ThreadStartParams,
236    ) -> Result<ThreadHandle, RpcError> {
237        let p = super::escalate_approval_if_tool_hooks(self, p);
238        super::wire::validate_thread_start_security(&p)?;
239        let mut params = Map::<String, Value>::new();
240        params.insert("threadId".to_owned(), Value::String(thread_id.to_owned()));
241        let overrides = thread_overrides_to_wire(&p);
242        if !overrides.is_empty() {
243            params.insert("overrides".to_owned(), Value::Object(overrides));
244        }
245
246        let response = self
247            .call_validated(methods::THREAD_RESUME, Value::Object(params))
248            .await?;
249        let resumed = parse_thread_id(&response).ok_or_else(|| {
250            RpcError::InvalidRequest(format!(
251                "thread/resume missing thread id in result: {response}"
252            ))
253        })?;
254        if resumed != thread_id {
255            return Err(RpcError::InvalidRequest(format!(
256                "thread/resume returned mismatched thread id: requested={thread_id} actual={resumed}"
257            )));
258        }
259        Ok(ThreadHandle {
260            thread_id: resumed,
261            runtime: self.clone(),
262        })
263    }
264
265    pub async fn thread_fork(&self, thread_id: &str) -> Result<ThreadHandle, RpcError> {
266        let mut params = Map::<String, Value>::new();
267        params.insert("threadId".to_owned(), Value::String(thread_id.to_owned()));
268        let response = self
269            .call_validated(methods::THREAD_FORK, Value::Object(params))
270            .await?;
271        let forked = parse_thread_id(&response).ok_or_else(|| {
272            RpcError::InvalidRequest(format!(
273                "thread/fork missing thread id in result: {response}"
274            ))
275        })?;
276        Ok(ThreadHandle {
277            thread_id: forked,
278            runtime: self.clone(),
279        })
280    }
281
282    /// Archive a thread (logical close on server side).
283    /// Allocation: one JSON object with thread id.
284    /// Complexity: O(1).
285    pub async fn thread_archive(&self, thread_id: &str) -> Result<(), RpcError> {
286        let mut params = Map::<String, Value>::new();
287        params.insert("threadId".to_owned(), Value::String(thread_id.to_owned()));
288        let _ = self
289            .call_validated(methods::THREAD_ARCHIVE, Value::Object(params))
290            .await?;
291        Ok(())
292    }
293
294    /// Read one thread by id.
295    /// Allocation: serialized params + decoded response object.
296    /// Complexity: O(n), n = thread payload size.
297    pub async fn thread_read(&self, p: ThreadReadParams) -> Result<ThreadReadResponse, RpcError> {
298        let params = serialize_params(methods::THREAD_READ, &p)?;
299        let response = self.call_validated(methods::THREAD_READ, params).await?;
300        deserialize_result(methods::THREAD_READ, response)
301    }
302
303    /// List persisted threads with optional filters and pagination.
304    /// Allocation: serialized params + decoded list payload.
305    /// Complexity: O(n), n = number of returned threads.
306    pub async fn thread_list(&self, p: ThreadListParams) -> Result<ThreadListResponse, RpcError> {
307        let params = serialize_params(methods::THREAD_LIST, &p)?;
308        let response = self.call_validated(methods::THREAD_LIST, params).await?;
309        deserialize_result(methods::THREAD_LIST, response)
310    }
311
312    /// List currently loaded thread ids from in-memory sessions.
313    /// Allocation: serialized params + decoded list payload.
314    /// Complexity: O(n), n = number of returned ids.
315    pub async fn thread_loaded_list(
316        &self,
317        p: ThreadLoadedListParams,
318    ) -> Result<ThreadLoadedListResponse, RpcError> {
319        let params = serialize_params(methods::THREAD_LOADED_LIST, &p)?;
320        let response = self
321            .call_validated(methods::THREAD_LOADED_LIST, params)
322            .await?;
323        deserialize_result(methods::THREAD_LOADED_LIST, response)
324    }
325
326    /// List skills for one or more working directories.
327    /// Allocation: serialized params + decoded inventory payload.
328    /// Complexity: O(n), n = number of returned cwd entries + skill metadata size.
329    pub async fn skills_list(&self, p: SkillsListParams) -> Result<SkillsListResponse, RpcError> {
330        let params = serialize_params(methods::SKILLS_LIST, &p)?;
331        let response = self.call_validated(methods::SKILLS_LIST, params).await?;
332        deserialize_result(methods::SKILLS_LIST, response)
333    }
334
335    /// Roll back the last `num_turns` turns from a thread.
336    /// Allocation: serialized params + decoded response payload.
337    /// Complexity: O(n), n = rolled thread payload size.
338    pub async fn thread_rollback(
339        &self,
340        p: ThreadRollbackParams,
341    ) -> Result<ThreadRollbackResponse, RpcError> {
342        let params = serialize_params(methods::THREAD_ROLLBACK, &p)?;
343        let response = self
344            .call_validated(methods::THREAD_ROLLBACK, params)
345            .await?;
346        deserialize_result(methods::THREAD_ROLLBACK, response)
347    }
348
349    /// Interrupt one in-flight turn for a thread.
350    /// Allocation: one JSON object with thread + turn id.
351    /// Complexity: O(1).
352    pub async fn turn_interrupt(&self, thread_id: &str, turn_id: &str) -> Result<(), RpcError> {
353        let _ = self
354            .call_validated(
355                methods::TURN_INTERRUPT,
356                interrupt_params(thread_id, turn_id),
357            )
358            .await?;
359        Ok(())
360    }
361
362    /// Interrupt one in-flight turn with explicit RPC timeout.
363    /// Allocation: one JSON object with thread + turn id.
364    /// Complexity: O(1).
365    pub async fn turn_interrupt_with_timeout(
366        &self,
367        thread_id: &str,
368        turn_id: &str,
369        timeout_duration: Duration,
370    ) -> Result<(), RpcError> {
371        let _ = self
372            .call_validated_with_mode_and_timeout(
373                methods::TURN_INTERRUPT,
374                interrupt_params(thread_id, turn_id),
375                RpcValidationMode::KnownMethods,
376                timeout_duration,
377            )
378            .await?;
379        Ok(())
380    }
381}
382
383fn interrupt_params(thread_id: &str, turn_id: &str) -> Value {
384    let mut params = Map::<String, Value>::new();
385    params.insert("threadId".to_owned(), Value::String(thread_id.to_owned()));
386    params.insert("turnId".to_owned(), Value::String(turn_id.to_owned()));
387    Value::Object(params)
388}
389
390fn ensure_turn_input_not_empty(input: &[InputItem]) -> Result<(), RpcError> {
391    if input.is_empty() {
392        return Err(RpcError::InvalidRequest(
393            "turn input must not be empty".to_owned(),
394        ));
395    }
396    Ok(())
397}
398
399/// Convert a `BlockReason` to `RpcError` for session-start callers.
400/// Allocation: one formatted String.
401fn block_reason_to_rpc_error(r: BlockReason) -> RpcError {
402    RpcError::InvalidRequest(format!(
403        "blocked by hook '{}' at {:?}: {}",
404        r.hook_name, r.phase, r.message
405    ))
406}