codex_runtime/runtime/client/
session.rs1use std::sync::atomic::{AtomicBool, Ordering};
2use std::sync::Arc;
3
4use tokio::sync::Mutex;
5
6use crate::runtime::api::{PromptRunError, PromptRunParams, PromptRunResult, PromptRunStream};
7use crate::runtime::core::Runtime;
8use crate::runtime::errors::RpcError;
9use crate::runtime::hooks::merge_hook_configs;
10
11use super::profile::{prepared_prompt_run_from_profile, session_prepared_prompt_run};
12use super::{RunProfile, SessionConfig};
13
/// Error text returned when an operation is attempted on a session that has
/// already been marked closed (see [`Session::close`]).
const SESSION_CLOSED_MESSAGE: &str = "session is closed";
15
16#[derive(Clone)]
17pub struct Session {
18 runtime: Runtime,
19 pub thread_id: String,
20 pub config: SessionConfig,
21 state: SessionState,
22}
23
24#[derive(Clone)]
25pub(super) struct SessionState {
26 closed: Arc<AtomicBool>,
27 close_result: Arc<Mutex<Option<Result<(), RpcError>>>>,
28}
29
30struct SessionClosePermit<'a> {
31 guard: tokio::sync::MutexGuard<'a, Option<Result<(), RpcError>>>,
32}
33
34#[derive(Clone, Debug, PartialEq)]
35pub(super) enum SessionCloseState {
36 ReturnCached(Result<(), RpcError>),
37 StartClosing,
38}
39
40fn ensure_session_open(closed: bool) -> Result<(), RpcError> {
41 if closed {
42 return Err(RpcError::InvalidRequest(SESSION_CLOSED_MESSAGE.to_owned()));
43 }
44 Ok(())
45}
46
47pub(super) fn next_close_state(cached: Option<&Result<(), RpcError>>) -> SessionCloseState {
48 match cached {
49 Some(result) => SessionCloseState::ReturnCached(result.clone()),
50 None => SessionCloseState::StartClosing,
51 }
52}
53
54impl SessionState {
55 pub(super) fn new() -> Self {
56 Self {
57 closed: Arc::new(AtomicBool::new(false)),
58 close_result: Arc::new(Mutex::new(None)),
59 }
60 }
61
62 fn is_closed(&self) -> bool {
63 self.closed.load(Ordering::Acquire)
64 }
65
66 pub(super) fn ensure_open_for_prompt(&self) -> Result<(), PromptRunError> {
67 ensure_session_open(self.is_closed()).map_err(PromptRunError::Rpc)
68 }
69
70 pub(super) fn ensure_open_for_rpc(&self) -> Result<(), RpcError> {
71 ensure_session_open(self.is_closed())
72 }
73
74 async fn acquire_close_permit(&self) -> SessionClosePermit<'_> {
75 SessionClosePermit {
76 guard: self.close_result.lock().await,
77 }
78 }
79
80 pub(super) fn mark_closed(&self) {
81 self.closed.store(true, Ordering::Release);
82 }
83}
84
85impl SessionClosePermit<'_> {
86 fn next_state(&self) -> SessionCloseState {
87 next_close_state(self.guard.as_ref())
88 }
89
90 fn store_result(mut self, result: Result<(), RpcError>) -> Result<(), RpcError> {
91 *self.guard = Some(result.clone());
92 result
93 }
94}
95
96impl Session {
97 pub(super) fn new(runtime: Runtime, thread_id: String, config: SessionConfig) -> Self {
98 Self {
99 runtime,
100 thread_id,
101 config,
102 state: SessionState::new(),
103 }
104 }
105
106 pub fn is_closed(&self) -> bool {
109 self.state.is_closed()
110 }
111
112 pub async fn ask(&self, prompt: impl Into<String>) -> Result<PromptRunResult, PromptRunError> {
116 self.state.ensure_open_for_prompt()?;
117 let prepared = session_prepared_prompt_run(&self.config, prompt);
118 self.runtime
119 .run_prompt_on_loaded_thread_with_hooks(
120 &self.thread_id,
121 prepared.params,
122 Some(prepared.hooks.as_ref()),
123 )
124 .await
125 }
126
127 pub async fn ask_stream(
130 &self,
131 prompt: impl Into<String>,
132 ) -> Result<PromptRunStream, PromptRunError> {
133 self.state.ensure_open_for_prompt()?;
134 let prepared = session_prepared_prompt_run(&self.config, prompt);
135 self.runtime
136 .run_prompt_on_loaded_thread_stream_with_hooks(
137 &self.thread_id,
138 prepared.params,
139 Some(prepared.hooks.as_ref()),
140 )
141 .await
142 }
143
144 pub async fn ask_wait(
148 &self,
149 prompt: impl Into<String>,
150 ) -> Result<PromptRunResult, PromptRunError> {
151 self.ask_stream(prompt).await?.finish().await
152 }
153
154 pub async fn ask_with(
158 &self,
159 params: PromptRunParams,
160 ) -> Result<PromptRunResult, PromptRunError> {
161 self.state.ensure_open_for_prompt()?;
162 self.runtime
163 .run_prompt_on_loaded_thread_with_hooks(
164 &self.thread_id,
165 params,
166 Some(&self.config.hooks),
167 )
168 .await
169 }
170
171 pub async fn ask_with_profile(
175 &self,
176 prompt: impl Into<String>,
177 profile: RunProfile,
178 ) -> Result<PromptRunResult, PromptRunError> {
179 self.state.ensure_open_for_prompt()?;
180 let prepared = prepared_prompt_run_from_profile(self.config.cwd.clone(), prompt, profile);
181 let merged_hooks = merge_hook_configs(&self.config.hooks, prepared.hooks.as_ref());
182 self.runtime
183 .run_prompt_on_loaded_thread_with_hooks(
184 &self.thread_id,
185 prepared.params,
186 Some(&merged_hooks),
187 )
188 .await
189 }
190
191 pub fn profile(&self) -> RunProfile {
194 self.config.profile()
195 }
196
197 pub async fn interrupt_turn(&self, turn_id: &str) -> Result<(), RpcError> {
201 self.state.ensure_open_for_rpc()?;
202 self.runtime.turn_interrupt(&self.thread_id, turn_id).await
203 }
204
205 pub async fn close(&self) -> Result<(), RpcError> {
209 let permit = self.state.acquire_close_permit().await;
210 match permit.next_state() {
211 SessionCloseState::ReturnCached(result) => return result,
212 SessionCloseState::StartClosing => {}
213 }
214
215 self.state.mark_closed();
216 let result = self.runtime.thread_archive(&self.thread_id).await;
217 permit.store_result(result)
218 }
219}