Skip to main content

gba_core/
session.rs

//! Multi-turn interactive session management.
//!
//! This module provides the [`Session`] struct for managing multi-turn
//! conversations with Claude, maintaining conversation history and
//! accumulating statistics across turns.
//!
//! # Overview
//!
//! A session wraps a [`ClaudeClient`] to provide:
//!
//! - Multi-turn conversation support with history tracking
//! - Streaming responses with event handlers
//! - Cumulative statistics across all conversation turns
//! - Session isolation via unique session IDs
//!
//! # Example
//!
//! ```no_run
//! use gba_core::{Engine, EngineConfig, TaskStats};
//! use gba_core::event::PrintEventHandler;
//! use gba_pm::PromptManager;
//!
//! # async fn example() -> gba_core::Result<()> {
//! // Create engine
//! let mut prompts = PromptManager::new();
//! prompts.load_dir("./tasks")?;
//!
//! let config = EngineConfig::builder()
//!     .workdir(".")
//!     .prompts(prompts)
//!     .build();
//! let engine = Engine::new(config)?;
//!
//! // Create a session (it connects automatically on the first send)
//! let mut session = engine.session(None)?;
//!
//! // Send messages (non-streaming)
//! let response = session.send("What is Rust?").await?;
//! println!("Response: {}", response);
//!
//! // Send with streaming
//! let mut handler = PrintEventHandler::new().with_auto_flush();
//! let response = session.send_stream("Tell me more about ownership", &mut handler).await?;
//!
//! // Check accumulated stats
//! let stats = session.stats();
//! println!("Total turns: {}", stats.turns);
//! println!("Total cost: ${:.4}", stats.cost_usd);
//!
//! // Disconnect when done
//! session.disconnect().await?;
//! # Ok(())
//! # }
//! ```
55
56use std::path::PathBuf;
57
58use claude_agent_sdk_rs::{
59    ClaudeAgentOptions, ClaudeClient, ContentBlock, Message, PermissionMode, ResultMessage,
60    SystemPrompt, ToolResultContent,
61};
62use futures::StreamExt;
63use tracing::{debug, info, instrument, trace};
64use uuid::Uuid;
65
66use crate::config::TaskConfig;
67use crate::error::{EngineError, Result};
68use crate::event::EventHandler;
69use crate::task::TaskStats;
70
/// A single entry in a conversation history.
///
/// Either text sent by the user or the text of an assistant reply. The type
/// derives `PartialEq`/`Eq`/`Hash` so histories can be compared (e.g. in
/// tests) and stored in hashed collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ConversationMessage {
    /// A message sent by the user.
    User(String),
    /// A response from the assistant.
    Assistant(String),
}

impl ConversationMessage {
    /// Borrow the text of the message, regardless of who sent it.
    #[must_use]
    pub fn content(&self) -> &str {
        match self {
            Self::User(content) | Self::Assistant(content) => content,
        }
    }

    /// Return `true` if this message was sent by the user.
    #[must_use]
    pub fn is_user(&self) -> bool {
        matches!(self, Self::User(_))
    }

    /// Return `true` if this message is an assistant response.
    #[must_use]
    pub fn is_assistant(&self) -> bool {
        matches!(self, Self::Assistant(_))
    }
}
104
105/// Multi-turn interactive session with Claude.
106///
107/// A session maintains a persistent connection to Claude, enabling
108/// multi-turn conversations while tracking history and statistics.
109///
110/// Sessions are created via [`Engine::session()`](crate::Engine::session).
111pub struct Session {
112    /// The Claude client for bidirectional streaming.
113    client: ClaudeClient,
114    /// Unique session identifier.
115    session_id: String,
116    /// Conversation history.
117    history: Vec<ConversationMessage>,
118    /// Accumulated statistics across all turns.
119    stats: TaskStats,
120    /// Whether the session is connected.
121    connected: bool,
122}
123
124impl std::fmt::Debug for Session {
125    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126        f.debug_struct("Session")
127            .field("session_id", &self.session_id)
128            .field("history_len", &self.history.len())
129            .field("stats", &self.stats)
130            .field("connected", &self.connected)
131            .finish()
132    }
133}
134
135impl Session {
136    /// Create a new session with the given options.
137    ///
138    /// # Arguments
139    ///
140    /// * `options` - Claude agent options for the session
141    /// * `session_id` - Optional session ID; if None, a UUID is generated
142    ///
143    /// # Errors
144    ///
145    /// Returns an error if the client cannot be created.
146    pub(crate) fn new(options: ClaudeAgentOptions, session_id: Option<String>) -> Result<Self> {
147        let session_id = session_id.unwrap_or_else(|| Uuid::new_v4().to_string());
148        debug!(session_id = %session_id, "creating new session");
149
150        let client = ClaudeClient::new(options);
151
152        Ok(Self {
153            client,
154            session_id,
155            history: Vec::new(),
156            stats: TaskStats::default(),
157            connected: false,
158        })
159    }
160
161    /// Connect the session.
162    ///
163    /// This must be called before sending messages.
164    ///
165    /// # Errors
166    ///
167    /// Returns an error if the connection fails.
168    pub async fn connect(&mut self) -> Result<()> {
169        if self.connected {
170            return Ok(());
171        }
172
173        debug!(session_id = %self.session_id, "connecting session");
174        self.client.connect().await?;
175        self.connected = true;
176        info!(session_id = %self.session_id, "session connected");
177
178        Ok(())
179    }
180
181    /// Send a message and get the complete response.
182    ///
183    /// This method sends a message to Claude and waits for the complete
184    /// response. For streaming responses, use [`send_stream`](Self::send_stream).
185    ///
186    /// # Arguments
187    ///
188    /// * `message` - The message to send
189    ///
190    /// # Errors
191    ///
192    /// Returns an error if:
193    /// - The session is not connected
194    /// - Sending the message fails
195    /// - Receiving the response fails
196    ///
197    /// # Example
198    ///
199    /// ```no_run
200    /// # use gba_core::session::Session;
201    /// # async fn example(session: &mut Session) -> gba_core::Result<()> {
202    /// let response = session.send("Hello Claude!").await?;
203    /// println!("Claude says: {}", response);
204    /// # Ok(())
205    /// # }
206    /// ```
207    #[instrument(skip(self, message), fields(session_id = %self.session_id))]
208    pub async fn send(&mut self, message: &str) -> Result<String> {
209        self.ensure_connected().await?;
210
211        info!("sending message");
212        self.history
213            .push(ConversationMessage::User(message.to_string()));
214
215        // Send the query
216        self.client
217            .query_with_session(message, &self.session_id)
218            .await?;
219
220        // Collect all messages first to avoid borrow conflicts
221        let mut messages = Vec::new();
222        {
223            let mut stream = self.client.receive_response();
224            while let Some(result) = stream.next().await {
225                messages.push(result?);
226            }
227        }
228
229        // Process collected messages
230        let mut response_text = String::new();
231        for msg in &messages {
232            self.process_message_no_handler(msg, &mut response_text);
233        }
234
235        // Store assistant response
236        self.history
237            .push(ConversationMessage::Assistant(response_text.clone()));
238        debug!(
239            response_len = response_text.len(),
240            "message sent and response received"
241        );
242
243        Ok(response_text)
244    }
245
246    /// Send a message with streaming events.
247    ///
248    /// This method sends a message to Claude and streams the response
249    /// through an event handler, allowing real-time processing of the
250    /// response.
251    ///
252    /// # Arguments
253    ///
254    /// * `message` - The message to send
255    /// * `handler` - Event handler for streaming events
256    ///
257    /// # Errors
258    ///
259    /// Returns an error if:
260    /// - The session is not connected
261    /// - Sending the message fails
262    /// - Receiving the response fails
263    ///
264    /// # Example
265    ///
266    /// ```no_run
267    /// # use gba_core::session::Session;
268    /// # use gba_core::event::PrintEventHandler;
269    /// # async fn example(session: &mut Session) -> gba_core::Result<()> {
270    /// let mut handler = PrintEventHandler::new().with_auto_flush();
271    /// let response = session.send_stream("Explain async/await", &mut handler).await?;
272    /// # Ok(())
273    /// # }
274    /// ```
275    #[instrument(skip(self, message, handler), fields(session_id = %self.session_id))]
276    pub async fn send_stream(
277        &mut self,
278        message: &str,
279        handler: &mut impl EventHandler,
280    ) -> Result<String> {
281        self.ensure_connected().await?;
282
283        info!("sending message with streaming");
284        self.history
285            .push(ConversationMessage::User(message.to_string()));
286
287        // Send the query
288        self.client
289            .query_with_session(message, &self.session_id)
290            .await?;
291
292        // Collect messages first, then process them
293        let mut messages = Vec::new();
294        {
295            let mut stream = self.client.receive_response();
296            while let Some(result) = stream.next().await {
297                match result {
298                    Ok(msg) => messages.push(msg),
299                    Err(e) => {
300                        let error_msg = e.to_string();
301                        handler.on_error(&error_msg);
302                        return Err(e.into());
303                    }
304                }
305            }
306        }
307
308        // Process collected messages with handler
309        let mut response_text = String::new();
310        for msg in &messages {
311            self.process_message_with_handler(msg, &mut response_text, handler);
312        }
313
314        handler.on_complete();
315
316        // Store assistant response
317        self.history
318            .push(ConversationMessage::Assistant(response_text.clone()));
319        debug!(
320            response_len = response_text.len(),
321            "streaming message sent and response received"
322        );
323
324        Ok(response_text)
325    }
326
327    /// Get the conversation history.
328    ///
329    /// Returns all messages exchanged in this session, in chronological order.
330    #[must_use]
331    pub fn history(&self) -> &[ConversationMessage] {
332        &self.history
333    }
334
335    /// Clear the conversation history.
336    ///
337    /// This clears the local history but does not affect the Claude session's
338    /// memory. To start a completely fresh conversation, create a new session.
339    pub fn clear(&mut self) {
340        self.history.clear();
341        debug!(session_id = %self.session_id, "conversation history cleared");
342    }
343
344    /// Get the accumulated statistics for this session.
345    ///
346    /// Statistics are accumulated across all turns in the session.
347    #[must_use]
348    pub fn stats(&self) -> &TaskStats {
349        &self.stats
350    }
351
352    /// Get the session ID.
353    #[must_use]
354    pub fn session_id(&self) -> &str {
355        &self.session_id
356    }
357
358    /// Check if the session is connected.
359    #[must_use]
360    pub fn is_connected(&self) -> bool {
361        self.connected
362    }
363
364    /// Interrupt the current operation.
365    ///
366    /// This sends an interrupt signal to stop any ongoing Claude operation.
367    ///
368    /// # Errors
369    ///
370    /// Returns an error if the session is not connected or interruption fails.
371    pub async fn interrupt(&self) -> Result<()> {
372        if !self.connected {
373            return Err(EngineError::config_error("Session not connected"));
374        }
375
376        self.client.interrupt().await?;
377        debug!(session_id = %self.session_id, "interrupt sent");
378
379        Ok(())
380    }
381
382    /// Disconnect the session.
383    ///
384    /// This cleanly disconnects from Claude. The session cannot be used
385    /// after disconnection.
386    ///
387    /// # Errors
388    ///
389    /// Returns an error if disconnection fails.
390    pub async fn disconnect(&mut self) -> Result<()> {
391        if !self.connected {
392            return Ok(());
393        }
394
395        debug!(session_id = %self.session_id, "disconnecting session");
396        self.client.disconnect().await?;
397        self.connected = false;
398        info!(session_id = %self.session_id, "session disconnected");
399
400        Ok(())
401    }
402
403    /// Ensure the session is connected, connecting if necessary.
404    async fn ensure_connected(&mut self) -> Result<()> {
405        if !self.connected {
406            self.connect().await?;
407        }
408        Ok(())
409    }
410
411    /// Process a message from the stream without a handler.
412    fn process_message_no_handler(&mut self, msg: &Message, response_text: &mut String) {
413        match msg {
414            Message::Assistant(assistant_msg) => {
415                for block in &assistant_msg.message.content {
416                    if let ContentBlock::Text(text) = block {
417                        response_text.push_str(&text.text);
418                    }
419                }
420            }
421            Message::Result(result_msg) => {
422                self.update_stats_from_result(result_msg);
423            }
424            Message::User(_)
425            | Message::System(_)
426            | Message::StreamEvent(_)
427            | Message::ControlCancelRequest(_) => {
428                // Ignore these message types
429            }
430        }
431    }
432
433    /// Process a message from the stream with an event handler.
434    fn process_message_with_handler(
435        &mut self,
436        msg: &Message,
437        response_text: &mut String,
438        handler: &mut impl EventHandler,
439    ) {
440        match msg {
441            Message::Assistant(assistant_msg) => {
442                for block in &assistant_msg.message.content {
443                    match block {
444                        ContentBlock::Text(text) => {
445                            response_text.push_str(&text.text);
446                            handler.on_text(&text.text);
447                        }
448                        ContentBlock::ToolUse(tool_use) => {
449                            handler.on_tool_use(&tool_use.name, &tool_use.input);
450                        }
451                        _ => {}
452                    }
453                }
454            }
455            Message::User(user_msg) => {
456                // Handle tool results from user messages
457                if let Some(ref content) = user_msg.content {
458                    for block in content {
459                        if let ContentBlock::ToolResult(tool_result) = block {
460                            let result_str = match &tool_result.content {
461                                Some(ToolResultContent::Text(s)) => s.as_str(),
462                                Some(ToolResultContent::Blocks(_)) => "[structured content]",
463                                None => "",
464                            };
465                            handler.on_tool_result(result_str);
466                        }
467                    }
468                }
469            }
470            Message::Result(result_msg) => {
471                self.update_stats_from_result(result_msg);
472
473                if result_msg.is_error {
474                    handler.on_error("Claude reported an error");
475                }
476            }
477            Message::System(_) | Message::StreamEvent(_) | Message::ControlCancelRequest(_) => {
478                // Ignore these message types
479            }
480        }
481    }
482
483    /// Update session stats from a result message.
484    fn update_stats_from_result(&mut self, result_msg: &ResultMessage) {
485        self.stats.turns += result_msg.num_turns;
486        self.stats.cost_usd += result_msg.total_cost_usd.unwrap_or(0.0);
487
488        if let Some(usage) = &result_msg.usage {
489            if let Some(input) = usage.get("input_tokens").and_then(|v| v.as_u64()) {
490                self.stats.input_tokens += input;
491            }
492            if let Some(output) = usage.get("output_tokens").and_then(|v| v.as_u64()) {
493                self.stats.output_tokens += output;
494            }
495        }
496
497        trace!(
498            turns = result_msg.num_turns,
499            cost = result_msg.total_cost_usd,
500            "result message processed"
501        );
502    }
503}
504
505/// Builder for creating sessions with custom options.
506///
507/// This is used internally by [`Engine::session()`](crate::Engine::session).
508pub struct SessionBuilder {
509    workdir: PathBuf,
510    base_options: Option<ClaudeAgentOptions>,
511    task_config: Option<TaskConfig>,
512    system_prompt: Option<SystemPrompt>,
513    session_id: Option<String>,
514}
515
516impl std::fmt::Debug for SessionBuilder {
517    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
518        f.debug_struct("SessionBuilder")
519            .field("workdir", &self.workdir)
520            .field("task_config", &self.task_config)
521            .field("session_id", &self.session_id)
522            .finish()
523    }
524}
525
526impl SessionBuilder {
527    /// Create a new session builder.
528    pub(crate) fn new(workdir: PathBuf) -> Self {
529        Self {
530            workdir,
531            base_options: None,
532            task_config: None,
533            system_prompt: None,
534            session_id: None,
535        }
536    }
537
538    /// Set base agent options.
539    pub(crate) fn with_base_options(mut self, options: ClaudeAgentOptions) -> Self {
540        self.base_options = Some(options);
541        self
542    }
543
544    /// Set task configuration.
545    pub(crate) fn with_task_config(mut self, config: TaskConfig) -> Self {
546        self.task_config = Some(config);
547        self
548    }
549
550    /// Set system prompt.
551    pub(crate) fn with_system_prompt(mut self, prompt: SystemPrompt) -> Self {
552        self.system_prompt = Some(prompt);
553        self
554    }
555
556    /// Set session ID.
557    pub(crate) fn with_session_id(mut self, id: String) -> Self {
558        self.session_id = Some(id);
559        self
560    }
561
562    /// Build the session.
563    ///
564    /// # Errors
565    ///
566    /// Returns an error if the session cannot be created.
567    pub(crate) fn build(self) -> Result<Session> {
568        let mut options = ClaudeAgentOptions::default();
569
570        // Apply base options
571        if let Some(base) = self.base_options {
572            if base.model.is_some() {
573                options.model = base.model;
574            }
575            if base.permission_mode.is_some() {
576                options.permission_mode = base.permission_mode;
577            }
578            if base.max_turns.is_some() {
579                options.max_turns = base.max_turns;
580            }
581            if base.cwd.is_some() {
582                options.cwd = base.cwd;
583            }
584        }
585
586        // Set working directory
587        if options.cwd.is_none() {
588            options.cwd = Some(self.workdir);
589        }
590
591        // Apply task config
592        if let Some(config) = self.task_config {
593            if !config.tools.is_empty() {
594                options.allowed_tools = config.tools;
595            }
596            if !config.disallowed_tools.is_empty() {
597                options.disallowed_tools = config.disallowed_tools;
598            }
599        }
600
601        // Set system prompt
602        if let Some(prompt) = self.system_prompt {
603            options.system_prompt = Some(prompt);
604        }
605
606        // Default to bypass permissions - no approval needed for any operation
607        if options.permission_mode.is_none() {
608            options.permission_mode = Some(PermissionMode::BypassPermissions);
609        }
610
611        // Skip version check for faster startup
612        options.skip_version_check = true;
613
614        Session::new(options, self.session_id)
615    }
616}
617
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder rooted at the fixed working directory used by these tests.
    fn test_builder() -> SessionBuilder {
        SessionBuilder::new(PathBuf::from("/tmp/test"))
    }

    #[test]
    fn test_should_create_conversation_message() {
        let from_user = ConversationMessage::User(String::from("Hello"));
        let from_assistant = ConversationMessage::Assistant(String::from("Hi there"));

        assert!(from_user.is_user());
        assert!(!from_user.is_assistant());
        assert_eq!(from_user.content(), "Hello");

        assert!(from_assistant.is_assistant());
        assert!(!from_assistant.is_user());
        assert_eq!(from_assistant.content(), "Hi there");
    }

    #[test]
    fn test_should_build_session_with_defaults() {
        let session = test_builder().build().unwrap();

        assert!(!session.session_id().is_empty());
        assert!(session.history().is_empty());
        assert_eq!(session.stats().turns, 0);
    }

    #[test]
    fn test_should_build_session_with_custom_id() {
        let session = test_builder()
            .with_session_id(String::from("custom-session"))
            .build()
            .unwrap();

        assert_eq!(session.session_id(), "custom-session");
    }

    #[test]
    fn test_should_clear_history() {
        let mut session = test_builder().build().unwrap();

        // Seed the private history directly; no client round-trip is needed.
        session.history.extend([
            ConversationMessage::User(String::from("test")),
            ConversationMessage::Assistant(String::from("response")),
        ]);
        assert_eq!(session.history().len(), 2);

        session.clear();
        assert!(session.history().is_empty());
    }

    #[test]
    fn test_task_stats_accumulation() {
        let mut stats = TaskStats::default();

        // Two batches, applied in order, as two result messages would be.
        for (turns, input, output, cost) in [(5, 1000, 500, 0.05), (3, 800, 400, 0.03)] {
            stats.turns += turns;
            stats.input_tokens += input;
            stats.output_tokens += output;
            stats.cost_usd += cost;
        }

        assert_eq!(stats.turns, 8);
        assert_eq!(stats.input_tokens, 1800);
        assert_eq!(stats.output_tokens, 900);
        assert!((stats.cost_usd - 0.08).abs() < f64::EPSILON);
    }
}