1use std::path::PathBuf;
57
58use claude_agent_sdk_rs::{
59 ClaudeAgentOptions, ClaudeClient, ContentBlock, Message, PermissionMode, ResultMessage,
60 SystemPrompt, ToolResultContent,
61};
62use futures::StreamExt;
63use tracing::{debug, info, instrument, trace};
64use uuid::Uuid;
65
66use crate::config::TaskConfig;
67use crate::error::{EngineError, Result};
68use crate::event::EventHandler;
69use crate::task::TaskStats;
70
/// One entry in a [`Session`]'s locally recorded conversation history.
///
/// Equality compares both the speaker and the text, so recorded transcripts
/// can be checked directly with `==` / `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConversationMessage {
    /// Text sent by the user.
    User(String),
    /// Concatenated text the assistant produced in reply.
    Assistant(String),
}

impl ConversationMessage {
    /// Returns the text carried by this message, whichever side sent it.
    #[must_use]
    pub fn content(&self) -> &str {
        match self {
            Self::User(content) | Self::Assistant(content) => content,
        }
    }

    /// Returns `true` if this message was sent by the user.
    #[must_use]
    pub fn is_user(&self) -> bool {
        matches!(self, Self::User(_))
    }

    /// Returns `true` if this message was produced by the assistant.
    #[must_use]
    pub fn is_assistant(&self) -> bool {
        matches!(self, Self::Assistant(_))
    }
}
104
105pub struct Session {
112 client: ClaudeClient,
114 session_id: String,
116 history: Vec<ConversationMessage>,
118 stats: TaskStats,
120 connected: bool,
122}
123
124impl std::fmt::Debug for Session {
125 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 f.debug_struct("Session")
127 .field("session_id", &self.session_id)
128 .field("history_len", &self.history.len())
129 .field("stats", &self.stats)
130 .field("connected", &self.connected)
131 .finish()
132 }
133}
134
135impl Session {
136 pub(crate) fn new(options: ClaudeAgentOptions, session_id: Option<String>) -> Result<Self> {
147 let session_id = session_id.unwrap_or_else(|| Uuid::new_v4().to_string());
148 debug!(session_id = %session_id, "creating new session");
149
150 let client = ClaudeClient::new(options);
151
152 Ok(Self {
153 client,
154 session_id,
155 history: Vec::new(),
156 stats: TaskStats::default(),
157 connected: false,
158 })
159 }
160
161 pub async fn connect(&mut self) -> Result<()> {
169 if self.connected {
170 return Ok(());
171 }
172
173 debug!(session_id = %self.session_id, "connecting session");
174 self.client.connect().await?;
175 self.connected = true;
176 info!(session_id = %self.session_id, "session connected");
177
178 Ok(())
179 }
180
181 #[instrument(skip(self, message), fields(session_id = %self.session_id))]
208 pub async fn send(&mut self, message: &str) -> Result<String> {
209 self.ensure_connected().await?;
210
211 info!("sending message");
212 self.history
213 .push(ConversationMessage::User(message.to_string()));
214
215 self.client
217 .query_with_session(message, &self.session_id)
218 .await?;
219
220 let mut messages = Vec::new();
222 {
223 let mut stream = self.client.receive_response();
224 while let Some(result) = stream.next().await {
225 messages.push(result?);
226 }
227 }
228
229 let mut response_text = String::new();
231 for msg in &messages {
232 self.process_message_no_handler(msg, &mut response_text);
233 }
234
235 self.history
237 .push(ConversationMessage::Assistant(response_text.clone()));
238 debug!(
239 response_len = response_text.len(),
240 "message sent and response received"
241 );
242
243 Ok(response_text)
244 }
245
246 #[instrument(skip(self, message, handler), fields(session_id = %self.session_id))]
276 pub async fn send_stream(
277 &mut self,
278 message: &str,
279 handler: &mut impl EventHandler,
280 ) -> Result<String> {
281 self.ensure_connected().await?;
282
283 info!("sending message with streaming");
284 self.history
285 .push(ConversationMessage::User(message.to_string()));
286
287 self.client
289 .query_with_session(message, &self.session_id)
290 .await?;
291
292 let mut messages = Vec::new();
294 {
295 let mut stream = self.client.receive_response();
296 while let Some(result) = stream.next().await {
297 match result {
298 Ok(msg) => messages.push(msg),
299 Err(e) => {
300 let error_msg = e.to_string();
301 handler.on_error(&error_msg);
302 return Err(e.into());
303 }
304 }
305 }
306 }
307
308 let mut response_text = String::new();
310 for msg in &messages {
311 self.process_message_with_handler(msg, &mut response_text, handler);
312 }
313
314 handler.on_complete();
315
316 self.history
318 .push(ConversationMessage::Assistant(response_text.clone()));
319 debug!(
320 response_len = response_text.len(),
321 "streaming message sent and response received"
322 );
323
324 Ok(response_text)
325 }
326
327 #[must_use]
331 pub fn history(&self) -> &[ConversationMessage] {
332 &self.history
333 }
334
335 pub fn clear(&mut self) {
340 self.history.clear();
341 debug!(session_id = %self.session_id, "conversation history cleared");
342 }
343
344 #[must_use]
348 pub fn stats(&self) -> &TaskStats {
349 &self.stats
350 }
351
352 #[must_use]
354 pub fn session_id(&self) -> &str {
355 &self.session_id
356 }
357
358 #[must_use]
360 pub fn is_connected(&self) -> bool {
361 self.connected
362 }
363
364 pub async fn interrupt(&self) -> Result<()> {
372 if !self.connected {
373 return Err(EngineError::config_error("Session not connected"));
374 }
375
376 self.client.interrupt().await?;
377 debug!(session_id = %self.session_id, "interrupt sent");
378
379 Ok(())
380 }
381
382 pub async fn disconnect(&mut self) -> Result<()> {
391 if !self.connected {
392 return Ok(());
393 }
394
395 debug!(session_id = %self.session_id, "disconnecting session");
396 self.client.disconnect().await?;
397 self.connected = false;
398 info!(session_id = %self.session_id, "session disconnected");
399
400 Ok(())
401 }
402
403 async fn ensure_connected(&mut self) -> Result<()> {
405 if !self.connected {
406 self.connect().await?;
407 }
408 Ok(())
409 }
410
411 fn process_message_no_handler(&mut self, msg: &Message, response_text: &mut String) {
413 match msg {
414 Message::Assistant(assistant_msg) => {
415 for block in &assistant_msg.message.content {
416 if let ContentBlock::Text(text) = block {
417 response_text.push_str(&text.text);
418 }
419 }
420 }
421 Message::Result(result_msg) => {
422 self.update_stats_from_result(result_msg);
423 }
424 Message::User(_)
425 | Message::System(_)
426 | Message::StreamEvent(_)
427 | Message::ControlCancelRequest(_) => {
428 }
430 }
431 }
432
433 fn process_message_with_handler(
435 &mut self,
436 msg: &Message,
437 response_text: &mut String,
438 handler: &mut impl EventHandler,
439 ) {
440 match msg {
441 Message::Assistant(assistant_msg) => {
442 for block in &assistant_msg.message.content {
443 match block {
444 ContentBlock::Text(text) => {
445 response_text.push_str(&text.text);
446 handler.on_text(&text.text);
447 }
448 ContentBlock::ToolUse(tool_use) => {
449 handler.on_tool_use(&tool_use.name, &tool_use.input);
450 }
451 _ => {}
452 }
453 }
454 }
455 Message::User(user_msg) => {
456 if let Some(ref content) = user_msg.content {
458 for block in content {
459 if let ContentBlock::ToolResult(tool_result) = block {
460 let result_str = match &tool_result.content {
461 Some(ToolResultContent::Text(s)) => s.as_str(),
462 Some(ToolResultContent::Blocks(_)) => "[structured content]",
463 None => "",
464 };
465 handler.on_tool_result(result_str);
466 }
467 }
468 }
469 }
470 Message::Result(result_msg) => {
471 self.update_stats_from_result(result_msg);
472
473 if result_msg.is_error {
474 handler.on_error("Claude reported an error");
475 }
476 }
477 Message::System(_) | Message::StreamEvent(_) | Message::ControlCancelRequest(_) => {
478 }
480 }
481 }
482
483 fn update_stats_from_result(&mut self, result_msg: &ResultMessage) {
485 self.stats.turns += result_msg.num_turns;
486 self.stats.cost_usd += result_msg.total_cost_usd.unwrap_or(0.0);
487
488 if let Some(usage) = &result_msg.usage {
489 if let Some(input) = usage.get("input_tokens").and_then(|v| v.as_u64()) {
490 self.stats.input_tokens += input;
491 }
492 if let Some(output) = usage.get("output_tokens").and_then(|v| v.as_u64()) {
493 self.stats.output_tokens += output;
494 }
495 }
496
497 trace!(
498 turns = result_msg.num_turns,
499 cost = result_msg.total_cost_usd,
500 "result message processed"
501 );
502 }
503}
504
505pub struct SessionBuilder {
509 workdir: PathBuf,
510 base_options: Option<ClaudeAgentOptions>,
511 task_config: Option<TaskConfig>,
512 system_prompt: Option<SystemPrompt>,
513 session_id: Option<String>,
514}
515
516impl std::fmt::Debug for SessionBuilder {
517 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
518 f.debug_struct("SessionBuilder")
519 .field("workdir", &self.workdir)
520 .field("task_config", &self.task_config)
521 .field("session_id", &self.session_id)
522 .finish()
523 }
524}
525
526impl SessionBuilder {
527 pub(crate) fn new(workdir: PathBuf) -> Self {
529 Self {
530 workdir,
531 base_options: None,
532 task_config: None,
533 system_prompt: None,
534 session_id: None,
535 }
536 }
537
538 pub(crate) fn with_base_options(mut self, options: ClaudeAgentOptions) -> Self {
540 self.base_options = Some(options);
541 self
542 }
543
544 pub(crate) fn with_task_config(mut self, config: TaskConfig) -> Self {
546 self.task_config = Some(config);
547 self
548 }
549
550 pub(crate) fn with_system_prompt(mut self, prompt: SystemPrompt) -> Self {
552 self.system_prompt = Some(prompt);
553 self
554 }
555
556 pub(crate) fn with_session_id(mut self, id: String) -> Self {
558 self.session_id = Some(id);
559 self
560 }
561
562 pub(crate) fn build(self) -> Result<Session> {
568 let mut options = ClaudeAgentOptions::default();
569
570 if let Some(base) = self.base_options {
572 if base.model.is_some() {
573 options.model = base.model;
574 }
575 if base.permission_mode.is_some() {
576 options.permission_mode = base.permission_mode;
577 }
578 if base.max_turns.is_some() {
579 options.max_turns = base.max_turns;
580 }
581 if base.cwd.is_some() {
582 options.cwd = base.cwd;
583 }
584 }
585
586 if options.cwd.is_none() {
588 options.cwd = Some(self.workdir);
589 }
590
591 if let Some(config) = self.task_config {
593 if !config.tools.is_empty() {
594 options.allowed_tools = config.tools;
595 }
596 if !config.disallowed_tools.is_empty() {
597 options.disallowed_tools = config.disallowed_tools;
598 }
599 }
600
601 if let Some(prompt) = self.system_prompt {
603 options.system_prompt = Some(prompt);
604 }
605
606 if options.permission_mode.is_none() {
608 options.permission_mode = Some(PermissionMode::BypassPermissions);
609 }
610
611 options.skip_version_check = true;
613
614 Session::new(options, self.session_id)
615 }
616}
617
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unconnected session with default options rooted at a dummy directory.
    fn fresh_session() -> Session {
        SessionBuilder::new(PathBuf::from("/tmp/test"))
            .build()
            .expect("building a session with default options should succeed")
    }

    #[test]
    fn test_should_create_conversation_message() {
        let cases = [
            (ConversationMessage::User("Hello".to_string()), true, "Hello"),
            (
                ConversationMessage::Assistant("Hi there".to_string()),
                false,
                "Hi there",
            ),
        ];

        for (msg, from_user, text) in &cases {
            assert_eq!(msg.is_user(), *from_user);
            assert_eq!(msg.is_assistant(), !*from_user);
            assert_eq!(msg.content(), *text);
        }
    }

    #[test]
    fn test_should_build_session_with_defaults() {
        let session = fresh_session();

        assert!(!session.session_id().is_empty());
        assert!(session.history().is_empty());
        assert_eq!(session.stats().turns, 0);
    }

    #[test]
    fn test_should_build_session_with_custom_id() {
        let session = SessionBuilder::new(PathBuf::from("/tmp/test"))
            .with_session_id("custom-session".to_string())
            .build()
            .unwrap();

        assert_eq!(session.session_id(), "custom-session");
    }

    #[test]
    fn test_should_clear_history() {
        let mut session = fresh_session();

        session.history.extend([
            ConversationMessage::User("test".to_string()),
            ConversationMessage::Assistant("response".to_string()),
        ]);
        assert_eq!(session.history().len(), 2);

        session.clear();

        assert!(session.history().is_empty());
    }

    #[test]
    fn test_task_stats_accumulation() {
        let mut stats = TaskStats::default();

        // (turns, input tokens, output tokens, cost) added in sequence.
        for (turns, input, output, cost) in [(5, 1000, 500, 0.05), (3, 800, 400, 0.03)] {
            stats.turns += turns;
            stats.input_tokens += input;
            stats.output_tokens += output;
            stats.cost_usd += cost;
        }

        assert_eq!(stats.turns, 8);
        assert_eq!(stats.input_tokens, 1800);
        assert_eq!(stats.output_tokens, 900);
        assert!((stats.cost_usd - 0.08).abs() < f64::EPSILON);
    }
}