Skip to main content

codex_wrapper/
session.rs

1//! Stateful multi-turn session manager for the Codex CLI.
2//!
3//! [`Session`] wraps a [`Codex`] client and automatically threads
4//! conversation state across turns. The first call to [`send`](Session::send)
5//! dispatches via [`ExecCommand`]; subsequent calls use
6//! [`ExecResumeCommand`] with the captured `thread_id`.
7//!
8//! # Example
9//!
10//! ```no_run
11//! use std::sync::Arc;
12//! use codex_wrapper::{Codex, Session};
13//!
14//! # async fn example() -> codex_wrapper::Result<()> {
15//! let codex = Arc::new(Codex::builder().build()?);
16//! let mut session = Session::new(codex);
17//!
18//! let events = session.send("create a hello world program").await?;
19//! println!("turn 1: {} events", events.len());
20//!
21//! let events = session.send("now add error handling").await?;
22//! println!("turn 2: {} events, thread_id={:?}", events.len(), session.id());
23//! # Ok(())
24//! # }
25//! ```
26
27use std::sync::Arc;
28
29use crate::Codex;
30use crate::command::exec::{ExecCommand, ExecResumeCommand};
31use crate::error::{Error, Result};
32use crate::types::JsonLineEvent;
33
34/// A record of a single turn within a session.
35#[derive(Debug, Clone)]
36pub struct TurnRecord {
37    /// The parsed JSONL events returned by this turn.
38    pub events: Vec<JsonLineEvent>,
39}
40
41/// Stateful multi-turn session manager.
42///
43/// Wraps a [`Codex`] client and automatically threads conversation state
44/// across turns. On the first turn, an [`ExecCommand`] is used; on subsequent
45/// turns, an [`ExecResumeCommand`] resumes the session using the `thread_id`
46/// extracted from the JSONL event stream.
47///
48/// The `thread_id` is preserved even when a turn fails, as long as at least
49/// one event in the output carried it.
50///
51/// # Example
52///
53/// ```no_run
54/// use std::sync::Arc;
55/// use codex_wrapper::{Codex, Session};
56///
57/// # async fn example() -> codex_wrapper::Result<()> {
58/// let codex = Arc::new(Codex::builder().build()?);
59/// let mut session = Session::new(codex);
60///
61/// let events = session.send("summarize this repo").await?;
62/// assert!(session.id().is_some());
63/// assert_eq!(session.total_turns(), 1);
64///
65/// let events = session.send("now add more detail").await?;
66/// assert_eq!(session.total_turns(), 2);
67/// # Ok(())
68/// # }
69/// ```
70pub struct Session {
71    codex: Arc<Codex>,
72    thread_id: Option<String>,
73    history: Vec<TurnRecord>,
74}
75
76impl Session {
77    /// Create a new session with no prior state.
78    ///
79    /// The first call to [`send`](Session::send) will use [`ExecCommand`].
80    pub fn new(codex: Arc<Codex>) -> Self {
81        Self {
82            codex,
83            thread_id: None,
84            history: Vec::new(),
85        }
86    }
87
88    /// Resume an existing session by its `thread_id`.
89    ///
90    /// The next call to [`send`](Session::send) will use
91    /// [`ExecResumeCommand`] with the provided ID.
92    pub fn resume(codex: Arc<Codex>, thread_id: impl Into<String>) -> Self {
93        Self {
94            codex,
95            thread_id: Some(thread_id.into()),
96            history: Vec::new(),
97        }
98    }
99
100    /// Send a prompt, automatically routing to `exec` or `exec resume`.
101    ///
102    /// On the first turn (no `thread_id`), dispatches via [`ExecCommand`].
103    /// On subsequent turns, dispatches via [`ExecResumeCommand`] with the
104    /// captured `thread_id`.
105    ///
106    /// Returns the parsed JSONL events for this turn.
107    pub async fn send(&mut self, prompt: impl Into<String>) -> Result<Vec<JsonLineEvent>> {
108        let prompt = prompt.into();
109
110        match &self.thread_id {
111            None => {
112                let cmd = ExecCommand::new(&prompt);
113                self.run_exec(cmd).await
114            }
115            Some(id) => {
116                let cmd = ExecResumeCommand::new()
117                    .session_id(id.clone())
118                    .prompt(prompt);
119                self.run_resume(cmd).await
120            }
121        }
122    }
123
124    /// Execute an [`ExecCommand`] with full control over its options.
125    ///
126    /// Use this when you need to configure model, sandbox, approval policy,
127    /// or other flags beyond what [`send`](Session::send) provides.
128    /// The session still captures the `thread_id` from the output.
129    pub async fn execute(&mut self, cmd: ExecCommand) -> Result<Vec<JsonLineEvent>> {
130        self.run_exec(cmd).await
131    }
132
133    /// Execute an [`ExecResumeCommand`] with full control over its options.
134    ///
135    /// Use this when you need to configure flags on the resume command
136    /// beyond what [`send`](Session::send) provides.
137    /// The session still captures the `thread_id` from the output.
138    pub async fn execute_resume(&mut self, cmd: ExecResumeCommand) -> Result<Vec<JsonLineEvent>> {
139        self.run_resume(cmd).await
140    }
141
142    // TODO: streaming support depends on #20
143    // pub async fn stream(&mut self, prompt: impl Into<String>) -> ...
144    // pub async fn stream_execute(&mut self, cmd: ExecCommand) -> ...
145
146    /// Returns the `thread_id` captured from the most recent turn, if any.
147    #[must_use]
148    pub fn id(&self) -> Option<&str> {
149        self.thread_id.as_deref()
150    }
151
152    /// Total number of completed turns in this session.
153    #[must_use]
154    pub fn total_turns(&self) -> usize {
155        self.history.len()
156    }
157
158    /// Borrow the full turn history.
159    #[must_use]
160    pub fn history(&self) -> &[TurnRecord] {
161        &self.history
162    }
163
164    /// Run an [`ExecCommand`] and record the turn.
165    async fn run_exec(&mut self, cmd: ExecCommand) -> Result<Vec<JsonLineEvent>> {
166        match cmd.execute_json_lines(&self.codex).await {
167            Ok(events) => {
168                self.capture_thread_id(&events);
169                self.history.push(TurnRecord {
170                    events: events.clone(),
171                });
172                Ok(events)
173            }
174            Err(Error::CommandFailed {
175                stdout,
176                stderr,
177                exit_code,
178                command,
179                working_dir,
180            }) => {
181                self.try_capture_thread_id_from_stdout(&stdout);
182                Err(Error::CommandFailed {
183                    stdout,
184                    stderr,
185                    exit_code,
186                    command,
187                    working_dir,
188                })
189            }
190            Err(e) => Err(e),
191        }
192    }
193
194    /// Run an [`ExecResumeCommand`] and record the turn.
195    async fn run_resume(&mut self, cmd: ExecResumeCommand) -> Result<Vec<JsonLineEvent>> {
196        match cmd.execute_json_lines(&self.codex).await {
197            Ok(events) => {
198                self.capture_thread_id(&events);
199                self.history.push(TurnRecord {
200                    events: events.clone(),
201                });
202                Ok(events)
203            }
204            Err(Error::CommandFailed {
205                stdout,
206                stderr,
207                exit_code,
208                command,
209                working_dir,
210            }) => {
211                self.try_capture_thread_id_from_stdout(&stdout);
212                Err(Error::CommandFailed {
213                    stdout,
214                    stderr,
215                    exit_code,
216                    command,
217                    working_dir,
218                })
219            }
220            Err(e) => Err(e),
221        }
222    }
223
224    /// Extract `thread_id` from parsed events (first match wins).
225    fn capture_thread_id(&mut self, events: &[JsonLineEvent]) {
226        if let Some(id) = events.iter().find_map(|e| e.thread_id()) {
227            self.thread_id = Some(id.to_string());
228        }
229    }
230
231    /// Best-effort extraction of `thread_id` from raw stdout on error paths.
232    fn try_capture_thread_id_from_stdout(&mut self, stdout: &str) {
233        for line in stdout.lines() {
234            if let Ok(event) = serde_json::from_str::<JsonLineEvent>(line)
235                && let Some(id) = event.thread_id()
236            {
237                self.thread_id = Some(id.to_string());
238                return;
239            }
240        }
241    }
242}
243
244impl std::fmt::Debug for Session {
245    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
246        f.debug_struct("Session")
247            .field("thread_id", &self.thread_id)
248            .field("total_turns", &self.history.len())
249            .finish_non_exhaustive()
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// Client configured with `/usr/bin/false` as its binary; none of the
    /// tests below dispatch a command through it.
    fn test_codex() -> Arc<Codex> {
        let codex = Codex::builder().binary("/usr/bin/false").build().unwrap();
        Arc::new(codex)
    }

    /// Parse a single JSON object into an event, panicking on malformed input.
    fn parse_event(json: &str) -> JsonLineEvent {
        serde_json::from_str(json).unwrap()
    }

    #[test]
    fn new_session_has_no_state() {
        let fresh = Session::new(test_codex());

        assert!(fresh.id().is_none());
        assert_eq!(fresh.total_turns(), 0);
        assert!(fresh.history().is_empty());
    }

    #[test]
    fn resume_session_has_thread_id() {
        let resumed = Session::resume(test_codex(), "thread_abc");

        assert_eq!(resumed.id(), Some("thread_abc"));
        assert_eq!(resumed.total_turns(), 0);
    }

    #[test]
    fn capture_thread_id_from_events() {
        let mut session = Session::new(test_codex());
        // The first event carries no thread_id; the second one does.
        let events = vec![
            parse_event(r#"{"type":"message.created","role":"assistant"}"#),
            parse_event(
                r#"{"type":"thread.started","thread_id":"thread_xyz","session_id":"sess_1"}"#,
            ),
        ];

        session.capture_thread_id(&events);

        assert_eq!(session.id(), Some("thread_xyz"));
    }

    #[test]
    fn capture_thread_id_noop_when_absent() {
        let mut session = Session::new(test_codex());
        let events = vec![parse_event(r#"{"type":"message.created"}"#)];

        session.capture_thread_id(&events);

        assert!(session.id().is_none());
    }

    #[test]
    fn try_capture_thread_id_from_stdout_parses_json() {
        let mut session = Session::new(test_codex());
        // Two JSONL lines; only the first carries a thread_id.
        let raw_output = r#"{"type":"thread.started","thread_id":"thread_err"}
{"type":"error","message":"something went wrong"}"#;

        session.try_capture_thread_id_from_stdout(raw_output);

        assert_eq!(session.id(), Some("thread_err"));
    }

    #[test]
    fn try_capture_thread_id_from_stdout_ignores_garbage() {
        let mut session = Session::new(test_codex());

        session.try_capture_thread_id_from_stdout("not json\nalso not json");

        assert!(session.id().is_none());
    }

    #[test]
    fn debug_impl() {
        let session = Session::resume(test_codex(), "thread_dbg");

        let rendered = format!("{session:?}");

        assert!(rendered.contains("thread_dbg"));
        assert!(rendered.contains("total_turns: 0"));
    }
}