Skip to main content

codex_runtime/runtime/client/
session.rs

1use std::sync::atomic::{AtomicBool, Ordering};
2use std::sync::Arc;
3
4use tokio::sync::Mutex;
5
6use crate::runtime::api::{PromptRunError, PromptRunParams, PromptRunResult, PromptRunStream};
7use crate::runtime::core::Runtime;
8use crate::runtime::errors::RpcError;
9use crate::runtime::hooks::merge_hook_configs;
10
11use super::profile::{prepared_prompt_run_from_profile, session_prepared_prompt_run};
12use super::{RunProfile, SessionConfig};
13
/// Message carried by the `RpcError::InvalidRequest` returned when a caller
/// uses a session whose local handle has already been closed.
const SESSION_CLOSED_MESSAGE: &str = "session is closed";
15
16#[derive(Clone)]
17pub struct Session {
18    runtime: Runtime,
19    pub thread_id: String,
20    pub config: SessionConfig,
21    state: SessionState,
22}
23
24#[derive(Clone)]
25pub(super) struct SessionState {
26    closed: Arc<AtomicBool>,
27    close_result: Arc<Mutex<Option<Result<(), RpcError>>>>,
28}
29
30struct SessionClosePermit<'a> {
31    guard: tokio::sync::MutexGuard<'a, Option<Result<(), RpcError>>>,
32}
33
34#[derive(Clone, Debug, PartialEq)]
35pub(super) enum SessionCloseState {
36    ReturnCached(Result<(), RpcError>),
37    StartClosing,
38}
39
40fn ensure_session_open(closed: bool) -> Result<(), RpcError> {
41    if closed {
42        return Err(RpcError::InvalidRequest(SESSION_CLOSED_MESSAGE.to_owned()));
43    }
44    Ok(())
45}
46
47pub(super) fn next_close_state(cached: Option<&Result<(), RpcError>>) -> SessionCloseState {
48    match cached {
49        Some(result) => SessionCloseState::ReturnCached(result.clone()),
50        None => SessionCloseState::StartClosing,
51    }
52}
53
54impl SessionState {
55    pub(super) fn new() -> Self {
56        Self {
57            closed: Arc::new(AtomicBool::new(false)),
58            close_result: Arc::new(Mutex::new(None)),
59        }
60    }
61
62    fn is_closed(&self) -> bool {
63        self.closed.load(Ordering::Acquire)
64    }
65
66    pub(super) fn ensure_open_for_prompt(&self) -> Result<(), PromptRunError> {
67        ensure_session_open(self.is_closed()).map_err(PromptRunError::Rpc)
68    }
69
70    pub(super) fn ensure_open_for_rpc(&self) -> Result<(), RpcError> {
71        ensure_session_open(self.is_closed())
72    }
73
74    async fn acquire_close_permit(&self) -> SessionClosePermit<'_> {
75        SessionClosePermit {
76            guard: self.close_result.lock().await,
77        }
78    }
79
80    pub(super) fn mark_closed(&self) {
81        self.closed.store(true, Ordering::Release);
82    }
83}
84
85impl SessionClosePermit<'_> {
86    fn next_state(&self) -> SessionCloseState {
87        next_close_state(self.guard.as_ref())
88    }
89
90    fn store_result(mut self, result: Result<(), RpcError>) -> Result<(), RpcError> {
91        *self.guard = Some(result.clone());
92        result
93    }
94}
95
96impl Session {
97    pub(super) fn new(runtime: Runtime, thread_id: String, config: SessionConfig) -> Self {
98        Self {
99            runtime,
100            thread_id,
101            config,
102            state: SessionState::new(),
103        }
104    }
105
106    /// Returns true when this local session handle is closed.
107    /// Allocation: none. Complexity: O(1).
108    pub fn is_closed(&self) -> bool {
109        self.state.is_closed()
110    }
111
112    /// Continue this session with one prompt.
113    /// Side effects: sends turn/start RPC calls on one already-loaded thread.
114    /// Allocation: PromptRunParams clone payloads (cwd/model/sandbox/attachments). Complexity: O(n), n = attachment count + prompt length.
115    pub async fn ask(&self, prompt: impl Into<String>) -> Result<PromptRunResult, PromptRunError> {
116        self.state.ensure_open_for_prompt()?;
117        let prepared = session_prepared_prompt_run(&self.config, prompt);
118        self.runtime
119            .run_prompt_on_loaded_thread_with_hooks(
120                &self.thread_id,
121                prepared.params,
122                Some(prepared.hooks.as_ref()),
123            )
124            .await
125    }
126
127    /// Continue this session with one prompt and receive scoped typed turn events.
128    /// Side effects: sends turn/start RPC calls on one already-loaded thread and consumes only matching live events.
129    pub async fn ask_stream(
130        &self,
131        prompt: impl Into<String>,
132    ) -> Result<PromptRunStream, PromptRunError> {
133        self.state.ensure_open_for_prompt()?;
134        let prepared = session_prepared_prompt_run(&self.config, prompt);
135        self.runtime
136            .run_prompt_on_loaded_thread_stream_with_hooks(
137                &self.thread_id,
138                prepared.params,
139                Some(prepared.hooks.as_ref()),
140            )
141            .await
142    }
143
144    /// Continue this session with one prompt and wait for the scoped stream to finish.
145    /// Side effects: sends turn/start RPC calls on one already-loaded thread and drains the matching turn stream to completion.
146    /// Allocation: PromptRunParams clone payloads (cwd/model/sandbox/attachments). Complexity: O(n), n = attachment count + prompt length.
147    pub async fn ask_wait(
148        &self,
149        prompt: impl Into<String>,
150    ) -> Result<PromptRunResult, PromptRunError> {
151        self.ask_stream(prompt).await?.finish().await
152    }
153
154    /// Continue this session with one prompt while overriding selected turn options.
155    /// Side effects: sends turn/start RPC calls on one already-loaded thread.
156    /// Allocation: depends on caller-provided params. Complexity: O(1) wrapper.
157    pub async fn ask_with(
158        &self,
159        params: PromptRunParams,
160    ) -> Result<PromptRunResult, PromptRunError> {
161        self.state.ensure_open_for_prompt()?;
162        self.runtime
163            .run_prompt_on_loaded_thread_with_hooks(
164                &self.thread_id,
165                params,
166                Some(&self.config.hooks),
167            )
168            .await
169    }
170
171    /// Continue this session with one prompt using one explicit profile override.
172    /// Side effects: sends turn/start RPC calls on one already-loaded thread.
173    /// Allocation: moves profile-owned Strings/vectors + one prompt String. Complexity: O(n), n = attachment count + field sizes.
174    pub async fn ask_with_profile(
175        &self,
176        prompt: impl Into<String>,
177        profile: RunProfile,
178    ) -> Result<PromptRunResult, PromptRunError> {
179        self.state.ensure_open_for_prompt()?;
180        let prepared = prepared_prompt_run_from_profile(self.config.cwd.clone(), prompt, profile);
181        let merged_hooks = merge_hook_configs(&self.config.hooks, prepared.hooks.as_ref());
182        self.runtime
183            .run_prompt_on_loaded_thread_with_hooks(
184                &self.thread_id,
185                prepared.params,
186                Some(&merged_hooks),
187            )
188            .await
189    }
190
191    /// Return current session default profile snapshot.
192    /// Allocation: clones Strings/attachments. Complexity: O(n), n = attachment count + string sizes.
193    pub fn profile(&self) -> RunProfile {
194        self.config.profile()
195    }
196
197    /// Interrupt one in-flight turn in this session.
198    /// Side effects: sends turn/interrupt RPC call to app-server.
199    /// Allocation: one small JSON payload in runtime layer. Complexity: O(1).
200    pub async fn interrupt_turn(&self, turn_id: &str) -> Result<(), RpcError> {
201        self.state.ensure_open_for_rpc()?;
202        self.runtime.turn_interrupt(&self.thread_id, turn_id).await
203    }
204
205    /// Archive this session on server side.
206    /// Side effects: sends thread/archive RPC call to app-server.
207    /// Allocation: one small JSON payload in runtime layer. Complexity: O(1).
208    pub async fn close(&self) -> Result<(), RpcError> {
209        let permit = self.state.acquire_close_permit().await;
210        match permit.next_state() {
211            SessionCloseState::ReturnCached(result) => return result,
212            SessionCloseState::StartClosing => {}
213        }
214
215        self.state.mark_closed();
216        let result = self.runtime.thread_archive(&self.thread_id).await;
217        permit.store_result(result)
218    }
219}