1use dashmap::DashMap;
7use std::collections::{BTreeMap, HashMap};
8use std::error::Error;
9use std::path::PathBuf;
10use std::sync::Arc;
11use std::sync::OnceLock;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::time::Instant;
14use tokio::sync::broadcast;
15
16use claude_code_agent_sdk::types::config::PermissionMode as SdkPermissionMode;
17use claude_code_agent_sdk::types::mcp::McpSdkServerConfig;
18use claude_code_agent_sdk::{
19 ClaudeAgentOptions, ClaudeClient, HookEvent, HookMatcher, McpServerConfig, McpServers,
20 SystemPrompt, SystemPromptPreset,
21};
22use sacp::JrConnectionCx;
23use sacp::link::AgentToClient;
24use sacp::schema::{
25 CurrentModeUpdate, McpServer, SessionId, SessionModeId, SessionNotification, SessionUpdate,
26};
27use tokio::sync::RwLock;
28use tracing::instrument;
29
30use crate::converter::NotificationConverter;
31use crate::hooks::{HookCallbackRegistry, create_post_tool_use_hook, create_pre_tool_use_hook};
32use crate::mcp::AcpMcpServer;
33use crate::permissions::create_can_use_tool_callback;
34use crate::settings::{PermissionChecker, SettingsManager};
35use crate::terminal::TerminalClient;
36use crate::types::{AgentConfig, AgentError, NewSessionMeta, Result};
37
38use super::BackgroundProcessManager;
39use super::permission::{PermissionHandler, PermissionMode};
40use super::usage::UsageTracker;
41
/// Names of the built-in Claude Code tools handed to `ClaudeAgentOptions::use_acp_tools`
/// when a session is created.
///
/// Judging by the name, these are the tools that get ACP-backed replacements; the exact
/// effect is decided by `use_acp_tools` in the SDK.
fn get_acp_replacement_tools() -> Vec<&'static str> {
    const REPLACED_TOOLS: [&str; 6] = ["Bash", "BashOutput", "KillShell", "Read", "Write", "Edit"];
    REPLACED_TOOLS.to_vec()
}
61
62pub struct Session {
67 pub session_id: String,
69 pub cwd: PathBuf,
71 client: RwLock<ClaudeClient>,
73 permission: Arc<RwLock<PermissionHandler>>,
75 usage_tracker: UsageTracker,
77 converter: NotificationConverter,
79 connected: AtomicBool,
81 hook_callback_registry: Arc<HookCallbackRegistry>,
83 permission_checker: Arc<RwLock<PermissionChecker>>,
85 current_model: OnceLock<String>,
87 acp_mcp_server: Arc<AcpMcpServer>,
89 background_processes: Arc<BackgroundProcessManager>,
91 external_mcp_servers: OnceLock<Vec<McpServer>>,
94 external_mcp_connected: AtomicBool,
96 connection_cx_lock: Arc<OnceLock<JrConnectionCx<AgentToClient>>>,
99 cancel_sender: broadcast::Sender<()>,
101 permission_cache: Arc<DashMap<String, bool>>,
106 tool_use_id_cache: Arc<DashMap<String, String>>,
111 cancelled: AtomicBool,
115}
116
117pub fn stable_cache_key(tool_input: &serde_json::Value) -> String {
123 fn canonicalize(value: &serde_json::Value) -> serde_json::Value {
124 match value {
125 serde_json::Value::Object(map) => {
126 let sorted: BTreeMap<_, _> = map
128 .iter()
129 .map(|(k, v)| (k.clone(), canonicalize(v)))
130 .collect();
131 serde_json::Value::Object(sorted.into_iter().collect())
132 }
133 serde_json::Value::Array(arr) => {
134 serde_json::Value::Array(arr.iter().map(canonicalize).collect())
135 }
136 other => other.clone(),
137 }
138 }
139 canonicalize(tool_input).to_string()
140}
141
142impl Session {
143 #[instrument(
155 name = "session_create",
156 skip(config, meta),
157 fields(
158 session_id = %session_id,
159 cwd = ?cwd,
160 has_meta = meta.is_some(),
161 )
162 )]
163 pub fn new(
164 session_id: String,
165 cwd: PathBuf,
166 config: &AgentConfig,
167 meta: Option<&NewSessionMeta>,
168 ) -> Result<Arc<Self>> {
169 let start_time = Instant::now();
170
171 tracing::info!(
172 session_id = %session_id,
173 cwd = ?cwd,
174 "Creating new session"
175 );
176
177 let hook_callback_registry = Arc::new(HookCallbackRegistry::new());
179
180 let settings_manager = SettingsManager::new(&cwd)
183 .unwrap_or_else(|e| {
184 tracing::warn!("Failed to load settings manager from cwd: {}. Using default settings.", e);
185 if let Some(home) = dirs::home_dir() {
187 tracing::info!("Attempting to load settings from home directory");
188 SettingsManager::new(&home).unwrap_or_else(|e2| {
189 tracing::error!("Failed to load settings from home directory: {}. Using minimal default settings.", e2);
190 SettingsManager::new_with_settings(crate::settings::Settings::default(), "/")
192 })
193 } else {
194 tracing::error!("No home directory found. Using minimal default settings.");
195 SettingsManager::new_with_settings(crate::settings::Settings::default(), "/")
196 }
197 });
198 let permission_checker = Arc::new(RwLock::new(PermissionChecker::new(
201 settings_manager.settings().clone(),
202 &cwd,
203 )));
204
205 let permission_handler = Arc::new(RwLock::new(PermissionHandler::with_checker(
209 permission_checker.clone(),
210 )));
211
212 let connection_cx_lock: Arc<OnceLock<JrConnectionCx<AgentToClient>>> =
214 Arc::new(OnceLock::new());
215
216 let permission_cache: Arc<DashMap<String, bool>> = Arc::new(DashMap::new());
219
220 let tool_use_id_cache: Arc<DashMap<String, String>> = Arc::new(DashMap::new());
224
225 let pre_tool_use_hook = create_pre_tool_use_hook(
227 connection_cx_lock.clone(),
228 session_id.clone(),
229 Some(permission_checker.clone()),
230 permission_handler.clone(),
231 permission_cache.clone(),
232 tool_use_id_cache.clone(),
233 );
234 let post_tool_use_hook = create_post_tool_use_hook(hook_callback_registry.clone());
235
236 let mut hooks_map: HashMap<HookEvent, Vec<HookMatcher>> = HashMap::new();
238 hooks_map.insert(
239 HookEvent::PreToolUse,
240 vec![
241 HookMatcher::builder()
242 .hooks(vec![pre_tool_use_hook])
243 .build(),
244 ],
245 );
246 hooks_map.insert(
247 HookEvent::PostToolUse,
248 vec![
249 HookMatcher::builder()
250 .hooks(vec![post_tool_use_hook])
251 .build(),
252 ],
253 );
254
255 tracing::info!(
256 session_id = %session_id,
257 hooks_count = 2,
258 "Hooks configured: PreToolUse, PostToolUse"
259 );
260
261 let session_lock: Arc<OnceLock<Arc<Session>>> = Arc::new(OnceLock::new());
263
264 let acp_mcp_server = Arc::new(AcpMcpServer::new("acp", env!("CARGO_PKG_VERSION")));
266
267 let background_processes = Arc::new(BackgroundProcessManager::new());
269
270 let mut mcp_servers_dict = HashMap::new();
272 mcp_servers_dict.insert(
273 "acp".to_string(),
274 McpServerConfig::Sdk(McpSdkServerConfig {
275 name: "acp".to_string(),
276 instance: acp_mcp_server.clone(),
277 }),
278 );
279
280 tracing::info!(
281 session_id = %session_id,
282 mcp_server_count = mcp_servers_dict.len(),
283 "MCP servers configured"
284 );
285
286 let can_use_tool_callback = create_can_use_tool_callback(session_lock.clone());
288
289 let mut options = ClaudeAgentOptions::builder()
297 .cwd(cwd.clone())
298 .hooks(hooks_map)
299 .mcp_servers(McpServers::Dict(mcp_servers_dict))
300 .can_use_tool(can_use_tool_callback)
301 .permission_mode(SdkPermissionMode::AcceptEdits)
302 .max_buffer_size(20 * 1024 * 1024) .build();
305
306 tracing::info!(
308 session_id = %session_id,
309 has_can_use_tool = options.can_use_tool.is_some(),
310 has_hooks = options.hooks.is_some(),
311 "Options configured after build"
312 );
313
314 match &options.mcp_servers {
316 McpServers::Dict(dict) => {
317 tracing::debug!(
318 session_id = %session_id,
319 servers = ?dict.keys().collect::<Vec<_>>(),
320 "MCP servers registered"
321 );
322 }
323 McpServers::Empty => {
324 tracing::warn!(
325 session_id = %session_id,
326 "MCP servers is Empty - this is unexpected!"
327 );
328 }
329 McpServers::Path(p) => {
330 tracing::warn!(
331 session_id = %session_id,
332 path = ?p,
333 "MCP servers is Path - this is unexpected!"
334 );
335 }
336 }
337
338 let acp_tools = get_acp_replacement_tools();
341 options.use_acp_tools(&acp_tools);
342
343 options.include_partial_messages = true;
346
347 tracing::debug!(
348 session_id = %session_id,
349 acp_tools = ?acp_tools,
350 disallowed_tools = ?options.disallowed_tools,
351 allowed_tools = ?options.allowed_tools,
352 "ACP tools configured"
353 );
354
355 config.apply_to_options(&mut options);
357
358 tracing::debug!(
359 session_id = %session_id,
360 model = ?options.model,
361 fallback_model = ?options.fallback_model,
362 max_thinking_tokens = ?options.max_thinking_tokens,
363 base_url = ?config.base_url,
364 api_key = ?config.masked_api_key(),
365 env_vars_count = options.env.len(),
366 "Agent config applied"
367 );
368
369 if let Some(meta) = meta {
371 if let Some(replace) = meta.get_system_prompt_replace() {
373 options.system_prompt = Some(SystemPrompt::Text(replace.to_string()));
375 tracing::info!(
376 session_id = %session_id,
377 prompt_len = replace.len(),
378 "Using custom system prompt from meta (replace)"
379 );
380 } else if let Some(append) = meta.get_system_prompt_append() {
381 let preset = SystemPromptPreset::with_append("claude_code", append);
383 options.system_prompt = Some(SystemPrompt::Preset(preset));
384 tracing::info!(
385 session_id = %session_id,
386 append_len = append.len(),
387 "Appending to system prompt from meta"
388 );
389 }
390
391 if let Some(resume_id) = meta.get_resume_session_id() {
393 options.resume = Some(resume_id.to_string());
394 tracing::info!(
395 session_id = %session_id,
396 resume_session_id = %resume_id,
397 "Resuming from previous session"
398 );
399 }
400
401 if let Some(tokens) = meta.get_max_thinking_tokens() {
403 options.max_thinking_tokens = Some(tokens);
404 tracing::info!(
405 session_id = %session_id,
406 max_thinking_tokens = tokens,
407 "Extended thinking mode enabled via meta"
408 );
409 }
410 }
411
412 let client = ClaudeClient::new(options);
414
415 let elapsed = start_time.elapsed();
416 tracing::info!(
417 session_id = %session_id,
418 elapsed_ms = elapsed.as_millis(),
419 "Session created successfully"
420 );
421
422 let cwd_for_converter = cwd.clone();
424
425 let session = Self {
427 session_id,
428 cwd,
429 client: RwLock::new(client),
430 permission: permission_handler,
431 usage_tracker: UsageTracker::new(),
432 converter: NotificationConverter::with_cwd(cwd_for_converter),
433 connected: AtomicBool::new(false),
434 hook_callback_registry,
435 permission_checker,
436 current_model: OnceLock::new(),
437 acp_mcp_server,
438 background_processes,
439 external_mcp_servers: OnceLock::new(),
440 external_mcp_connected: AtomicBool::new(false),
441 connection_cx_lock,
442 cancel_sender: broadcast::channel(1).0,
443 permission_cache,
444 tool_use_id_cache,
445 cancelled: AtomicBool::new(false),
446 };
447
448 let session_arc = Arc::new(session);
450
451 drop(session_lock.set(session_arc.clone()));
453
454 Ok(session_arc)
455 }
456
457 pub fn set_external_mcp_servers(&self, servers: Vec<McpServer>) {
463 if !servers.is_empty() {
464 tracing::info!(
465 session_id = %self.session_id,
466 server_count = servers.len(),
467 "Storing external MCP servers for later connection"
468 );
469
470 for server in &servers {
471 match server {
472 McpServer::Stdio(s) => {
473 tracing::debug!(
474 session_id = %self.session_id,
475 server_name = %s.name,
476 command = ?s.command,
477 args = ?s.args,
478 "External MCP server (stdio)"
479 );
480 }
481 McpServer::Http(s) => {
482 tracing::debug!(
483 session_id = %self.session_id,
484 server_name = %s.name,
485 url = %s.url,
486 "External MCP server (http)"
487 );
488 }
489 McpServer::Sse(s) => {
490 tracing::debug!(
491 session_id = %self.session_id,
492 server_name = %s.name,
493 url = %s.url,
494 "External MCP server (sse)"
495 );
496 }
497 _ => {
498 tracing::debug!(
499 session_id = %self.session_id,
500 "External MCP server (unknown type)"
501 );
502 }
503 }
504 }
505 }
506
507 if self.external_mcp_servers.get().is_none() {
509 drop(self.external_mcp_servers.set(servers));
510 }
511 }
512
513 pub fn set_connection_cx(&self, cx: JrConnectionCx<AgentToClient>) {
518 if self.connection_cx_lock.get().is_none() {
519 drop(self.connection_cx_lock.set(cx));
520 }
521 }
522
523 pub fn get_connection_cx(&self) -> Option<&JrConnectionCx<AgentToClient>> {
527 self.connection_cx_lock.get()
528 }
529
530 pub fn cache_permission(&self, tool_input: &serde_json::Value, allowed: bool) {
535 let key = stable_cache_key(tool_input);
536 tracing::debug!(
537 key_len = key.len(),
538 allowed = allowed,
539 "Caching permission result"
540 );
541 self.permission_cache.insert(key, allowed);
542 }
543
544 pub fn check_cached_permission(&self, tool_input: &serde_json::Value) -> Option<bool> {
550 let key = stable_cache_key(tool_input);
551 self.permission_cache.remove(&key).map(|(_, v)| v)
552 }
553
554 pub fn permission_cache(&self) -> Arc<DashMap<String, bool>> {
556 Arc::clone(&self.permission_cache)
557 }
558
559 pub fn cache_tool_use_id(&self, tool_input: &serde_json::Value, tool_use_id: &str) {
564 let key = stable_cache_key(tool_input);
565 tracing::debug!(
566 key_len = key.len(),
567 tool_use_id = %tool_use_id,
568 "Caching tool_use_id"
569 );
570 self.tool_use_id_cache.insert(key, tool_use_id.to_string());
571 }
572
573 pub fn get_cached_tool_use_id(&self, tool_input: &serde_json::Value) -> Option<String> {
579 let key = stable_cache_key(tool_input);
580 self.tool_use_id_cache.remove(&key).map(|(_, v)| v)
581 }
582
583 pub fn tool_use_id_cache(&self) -> Arc<DashMap<String, String>> {
585 Arc::clone(&self.tool_use_id_cache)
586 }
587
588 #[instrument(
593 name = "connect_external_mcp_servers",
594 skip(self),
595 fields(session_id = %self.session_id)
596 )]
597 pub async fn connect_external_mcp_servers(&self) -> Result<()> {
598 if self.external_mcp_connected.load(Ordering::SeqCst) {
600 tracing::debug!(
601 session_id = %self.session_id,
602 "External MCP servers already connected"
603 );
604 return Ok(());
605 }
606
607 let Some(servers) = self.external_mcp_servers.get() else {
609 tracing::debug!(
610 session_id = %self.session_id,
611 "No external MCP servers to connect"
612 );
613 self.external_mcp_connected.store(true, Ordering::SeqCst);
614 return Ok(());
615 };
616
617 let servers_vec: Vec<_> = servers.clone();
619
620 let server_count = servers_vec.len();
621 let start_time = Instant::now();
622
623 tracing::info!(
624 session_id = %self.session_id,
625 server_count = server_count,
626 "Connecting to external MCP servers"
627 );
628
629 let external_manager = self.acp_mcp_server.mcp_server().external_manager();
630
631 let mut success_count = 0;
632 let mut error_count = 0;
633
634 for server in &servers_vec {
635 match server {
636 McpServer::Stdio(s) => {
637 let server_start = Instant::now();
638
639 tracing::info!(
640 session_id = %self.session_id,
641 server_name = %s.name,
642 command = ?s.command,
643 args = ?s.args,
644 "Connecting to external MCP server (stdio)"
645 );
646
647 let env: Option<HashMap<String, String>> = if s.env.is_empty() {
649 None
650 } else {
651 Some(
652 s.env
653 .iter()
654 .map(|e| (e.name.clone(), e.value.clone()))
655 .collect(),
656 )
657 };
658
659 match external_manager
660 .connect(
661 s.name.clone(),
662 s.command.to_string_lossy().as_ref(),
663 &s.args,
664 env.as_ref(),
665 Some(self.cwd.as_path()),
666 )
667 .await
668 {
669 Ok(()) => {
670 success_count += 1;
671 let elapsed = server_start.elapsed();
672 tracing::info!(
673 session_id = %self.session_id,
674 server_name = %s.name,
675 elapsed_ms = elapsed.as_millis(),
676 "Successfully connected to external MCP server"
677 );
678 }
679 Err(e) => {
680 error_count += 1;
681 let elapsed = server_start.elapsed();
682 tracing::error!(
683 session_id = %self.session_id,
684 server_name = %s.name,
685 error = %e,
686 elapsed_ms = elapsed.as_millis(),
687 "Failed to connect to external MCP server"
688 );
689 }
690 }
691 }
692 McpServer::Http(s) => {
693 tracing::warn!(
694 session_id = %self.session_id,
695 server_name = %s.name,
696 url = %s.url,
697 "HTTP MCP servers not yet supported"
698 );
699 }
700 McpServer::Sse(s) => {
701 tracing::warn!(
702 session_id = %self.session_id,
703 server_name = %s.name,
704 url = %s.url,
705 "SSE MCP servers not yet supported"
706 );
707 }
708 _ => {
709 tracing::warn!(
710 session_id = %self.session_id,
711 "Unknown MCP server type - not supported"
712 );
713 }
714 }
715 }
716
717 let total_elapsed = start_time.elapsed();
718 tracing::info!(
719 session_id = %self.session_id,
720 total_servers = server_count,
721 success_count = success_count,
722 error_count = error_count,
723 total_elapsed_ms = total_elapsed.as_millis(),
724 "Finished connecting external MCP servers"
725 );
726
727 self.external_mcp_connected.store(true, Ordering::SeqCst);
728 Ok(())
729 }
730
731 #[instrument(
735 name = "session_connect",
736 skip(self),
737 fields(session_id = %self.session_id)
738 )]
739 pub async fn connect(&self) -> Result<()> {
740 if self.connected.load(Ordering::SeqCst) {
741 tracing::debug!(
742 session_id = %self.session_id,
743 "Already connected to Claude CLI"
744 );
745 return Ok(());
746 }
747
748 let start_time = Instant::now();
749 tracing::info!(
750 session_id = %self.session_id,
751 cwd = ?self.cwd,
752 "Connecting to Claude CLI..."
753 );
754
755 let mut client = self.client.write().await;
756 client.connect().await.map_err(|e| {
757 let agent_error = AgentError::from(e);
758 tracing::error!(
759 session_id = %self.session_id,
760 error = %agent_error,
761 error_code = ?agent_error.error_code(),
762 is_retryable = %agent_error.is_retryable(),
763 error_chain = ?agent_error.source(),
764 "Failed to connect to Claude CLI"
765 );
766 agent_error
767 })?;
768
769 self.connected.store(true, Ordering::SeqCst);
770
771 let elapsed = start_time.elapsed();
772 tracing::info!(
773 session_id = %self.session_id,
774 elapsed_ms = elapsed.as_millis(),
775 "Successfully connected to Claude CLI"
776 );
777
778 Ok(())
779 }
780
781 #[instrument(
783 name = "session_disconnect",
784 skip(self),
785 fields(session_id = %self.session_id)
786 )]
787 pub async fn disconnect(&self) -> Result<()> {
788 if !self.connected.load(Ordering::SeqCst) {
789 tracing::debug!(
790 session_id = %self.session_id,
791 "Already disconnected from Claude CLI"
792 );
793 return Ok(());
794 }
795
796 let start_time = Instant::now();
797 tracing::info!(
798 session_id = %self.session_id,
799 "Disconnecting from Claude CLI..."
800 );
801
802 let mut client = self.client.write().await;
803 client.disconnect().await.map_err(|e| {
804 let agent_error = AgentError::from(e);
805 tracing::error!(
806 session_id = %self.session_id,
807 error = %agent_error,
808 error_code = ?agent_error.error_code(),
809 is_retryable = %agent_error.is_retryable(),
810 error_chain = ?agent_error.source(),
811 "Failed to disconnect from Claude CLI"
812 );
813 agent_error
814 })?;
815
816 self.connected.store(false, Ordering::SeqCst);
817
818 let elapsed = start_time.elapsed();
819 tracing::info!(
820 session_id = %self.session_id,
821 elapsed_ms = elapsed.as_millis(),
822 "Disconnected from Claude CLI"
823 );
824
825 Ok(())
826 }
827
828 pub fn is_connected(&self) -> bool {
830 self.connected.load(Ordering::SeqCst)
831 }
832
833 pub async fn client(&self) -> tokio::sync::RwLockReadGuard<'_, ClaudeClient> {
835 self.client.read().await
836 }
837
838 pub async fn client_mut(&self) -> tokio::sync::RwLockWriteGuard<'_, ClaudeClient> {
840 self.client.write().await
841 }
842
843 pub fn cancel_receiver(&self) -> broadcast::Receiver<()> {
848 self.cancel_sender.subscribe()
849 }
850
851 #[instrument(
856 name = "session_cancel",
857 skip(self),
858 fields(session_id = %self.session_id)
859 )]
860 pub async fn cancel(&self) {
861 self.cancelled.store(true, Ordering::Release);
864
865 tracing::info!(
866 session_id = %self.session_id,
867 "Sending interrupt signal to Claude CLI (cancelled=true)"
868 );
869
870 if let Ok(client) = self.client.try_read() {
872 if let Err(e) = client.interrupt().await {
873 tracing::warn!(
874 session_id = %self.session_id,
875 error = %e,
876 "Failed to send interrupt signal to Claude CLI"
877 );
878 } else {
879 tracing::info!(
880 session_id = %self.session_id,
881 "Interrupt signal sent to Claude CLI"
882 );
883 }
884 } else {
885 tracing::warn!(
886 session_id = %self.session_id,
887 "Could not acquire client lock for interrupt"
888 );
889 }
890 }
891
892 pub fn is_user_cancelled(&self) -> bool {
897 self.cancelled.load(Ordering::Acquire)
899 }
900
901 pub fn reset_cancelled(&self) {
905 self.cancelled.store(false, Ordering::Release);
908 }
909
910 pub async fn permission(&self) -> tokio::sync::RwLockReadGuard<'_, PermissionHandler> {
912 self.permission.read().await
913 }
914
915 pub async fn permission_mode(&self) -> PermissionMode {
917 self.permission.read().await.mode()
918 }
919
920 pub async fn set_permission_mode(&self, mode: PermissionMode) {
925 self.permission.write().await.set_mode(mode);
927
928 tracing::info!(
929 session_id = %self.session_id,
930 mode = mode.as_str(),
931 "Permission mode updated"
932 );
933 }
934
935 pub fn send_mode_update(&self, mode: &str) {
941 let Some(connection_cx) = self.get_connection_cx() else {
942 tracing::warn!(
943 session_id = %self.session_id,
944 mode = %mode,
945 "Connection not ready for mode update notification"
946 );
947 return;
948 };
949
950 let mode_update = CurrentModeUpdate::new(SessionModeId::new(mode));
951 let notification = SessionNotification::new(
952 SessionId::new(self.session_id.clone()),
953 SessionUpdate::CurrentModeUpdate(mode_update),
954 );
955
956 if let Err(e) = connection_cx.send_notification(notification) {
957 tracing::warn!(
958 session_id = %self.session_id,
959 mode = %mode,
960 error = %e,
961 "Failed to send CurrentModeUpdate notification"
962 );
963 } else {
964 tracing::info!(
965 session_id = %self.session_id,
966 mode = %mode,
967 "Sent CurrentModeUpdate notification"
968 );
969 }
970 }
971
972 pub async fn add_permission_allow_rule(&self, tool_name: &str) {
976 self.permission.read().await.add_allow_rule(tool_name).await;
977 }
978
979 #[allow(dead_code)]
983 pub fn current_model(&self) -> Option<String> {
984 self.current_model.get().cloned()
985 }
986
987 #[allow(dead_code)]
991 pub fn set_model(&self, model_id: String) {
992 if self.current_model.get().is_none() {
994 drop(self.current_model.set(model_id));
995 }
996 }
997
998 pub fn usage_tracker(&self) -> &UsageTracker {
1000 &self.usage_tracker
1001 }
1002
1003 pub fn converter(&self) -> &NotificationConverter {
1005 &self.converter
1006 }
1007
1008 pub fn hook_callback_registry(&self) -> &Arc<HookCallbackRegistry> {
1010 &self.hook_callback_registry
1011 }
1012
1013 pub fn permission_checker(&self) -> &Arc<RwLock<PermissionChecker>> {
1015 &self.permission_checker
1016 }
1017
1018 pub fn register_post_tool_use_callback(
1020 &self,
1021 tool_use_id: String,
1022 callback: crate::hooks::PostToolUseCallback,
1023 ) {
1024 self.hook_callback_registry
1025 .register_post_tool_use(tool_use_id, callback);
1026 }
1027
1028 pub fn acp_mcp_server(&self) -> &Arc<AcpMcpServer> {
1030 &self.acp_mcp_server
1031 }
1032
1033 pub fn background_processes(&self) -> &Arc<BackgroundProcessManager> {
1035 &self.background_processes
1036 }
1037
1038 pub async fn configure_acp_server(
1043 &self,
1044 connection_cx: JrConnectionCx<AgentToClient>,
1045 terminal_client: Option<Arc<TerminalClient>>,
1046 ) {
1047 self.acp_mcp_server.set_session_id(&self.session_id);
1048 self.acp_mcp_server.set_connection(connection_cx);
1049 self.acp_mcp_server.set_cwd(self.cwd.clone()).await;
1050 self.acp_mcp_server
1051 .set_background_processes(self.background_processes.clone());
1052 self.acp_mcp_server
1053 .set_permission_checker(self.permission_checker.clone());
1054
1055 if let Some(client) = terminal_client {
1056 self.acp_mcp_server.set_terminal_client(client);
1057 }
1058
1059 let session_id = self.session_id.clone();
1061 let cancel_sender = self.cancel_sender.clone();
1062
1063 self.acp_mcp_server
1064 .set_cancel_callback(move || {
1065 tracing::info!(
1066 session_id = %session_id,
1067 "MCP cancel callback invoked, sending cancel signal"
1068 );
1069 let _ = cancel_sender.send(());
1072 })
1073 .await;
1074 }
1075}
1076
1077#[allow(clippy::missing_fields_in_debug)]
1078impl std::fmt::Debug for Session {
1079 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1080 f.debug_struct("Session")
1081 .field("session_id", &self.session_id)
1082 .field("cwd", &self.cwd)
1083 .field("connected", &self.connected.load(Ordering::Relaxed))
1084 .finish()
1085 }
1086}
1087
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    fn test_config() -> AgentConfig {
        AgentConfig {
            base_url: None,
            api_key: None,
            model: None,
            small_fast_model: None,
            max_thinking_tokens: None,
        }
    }

    /// Builds a session rooted at `/tmp` with the default test config and no meta.
    fn test_session(id: &str) -> Arc<Session> {
        Session::new(id.to_string(), PathBuf::from("/tmp"), &test_config(), None).unwrap()
    }

    #[test]
    fn test_session_new() {
        let session = test_session("test-session-1");

        assert_eq!(session.session_id, "test-session-1");
        assert_eq!(session.cwd, PathBuf::from("/tmp"));
        assert!(!session.is_connected());
        assert!(!session.is_user_cancelled());
        assert!(session.get_connection_cx().is_none());
    }

    #[test]
    fn test_cancelled_flag_lifecycle() {
        let session = test_session("test-cancel-session");

        assert!(
            !session.is_user_cancelled(),
            "Cancelled should be false initially"
        );

        session.cancelled.store(true, Ordering::Release);
        assert!(
            session.is_user_cancelled(),
            "Cancelled should be true after setting"
        );

        session.reset_cancelled();
        assert!(
            !session.is_user_cancelled(),
            "Cancelled should be false after reset"
        );

        session.cancelled.store(true, Ordering::Release);
        assert!(
            session.is_user_cancelled(),
            "Cancelled should be true after setting again"
        );
    }

    #[tokio::test]
    async fn test_session_cancel() {
        let session = test_session("test-session-2");

        session.cancel().await;

        // `cancel` sets the flag before attempting the interrupt, so it must be set even
        // though the client was never connected.
        assert!(
            session.is_user_cancelled(),
            "cancel() should set the cancelled flag"
        );
    }

    #[tokio::test]
    async fn test_session_permission_mode() {
        let session = test_session("test-session-3");

        assert_eq!(session.permission_mode().await, PermissionMode::Default);
        session.set_permission_mode(PermissionMode::DontAsk).await;
        assert_eq!(session.permission_mode().await, PermissionMode::DontAsk);
    }

    #[test]
    fn test_permission_cache_is_single_use_and_order_insensitive() {
        let session = test_session("test-permission-cache");
        let input = json!({"command": "ls", "cwd": "/tmp"});

        assert_eq!(session.check_cached_permission(&input), None);

        session.cache_permission(&input, true);
        assert_eq!(
            session.check_cached_permission(&json!({"cwd": "/tmp", "command": "ls"})),
            Some(true),
            "Lookup should ignore object key order"
        );
        assert_eq!(
            session.check_cached_permission(&input),
            None,
            "Cached decisions are consumed on lookup"
        );
    }

    #[test]
    fn test_tool_use_id_cache_is_single_use() {
        let session = test_session("test-tool-use-id-cache");
        let input = json!({"file_path": "/tmp/a.txt"});

        session.cache_tool_use_id(&input, "toolu_123");
        assert_eq!(
            session.get_cached_tool_use_id(&input),
            Some("toolu_123".to_string())
        );
        assert_eq!(session.get_cached_tool_use_id(&input), None);
    }

    #[test]
    fn test_set_model_is_write_once() {
        let session = test_session("test-model");

        assert_eq!(session.current_model(), None);
        session.set_model("model-a".to_string());
        session.set_model("model-b".to_string());
        assert_eq!(session.current_model(), Some("model-a".to_string()));
    }

    #[test]
    fn test_external_mcp_servers_stored_once() {
        let session = test_session("test-external-mcp");

        assert!(session.external_mcp_servers.get().is_none());
        session.set_external_mcp_servers(Vec::new());
        assert_eq!(session.external_mcp_servers.get().map(Vec::len), Some(0));
        session.set_external_mcp_servers(Vec::new());
        assert_eq!(session.external_mcp_servers.get().map(Vec::len), Some(0));
    }

    #[test]
    fn test_stable_cache_key_ordering() {
        let key1 = stable_cache_key(&json!({"a": 1, "b": 2, "c": 3}));
        let key2 = stable_cache_key(&json!({"c": 3, "b": 2, "a": 1}));
        let key3 = stable_cache_key(&json!({"b": 2, "a": 1, "c": 3}));

        assert_eq!(
            key1, key2,
            "Different key ordering should produce same cache key"
        );
        assert_eq!(
            key2, key3,
            "Different key ordering should produce same cache key"
        );
    }

    #[test]
    fn test_stable_cache_key_nested_objects() {
        let json1 = json!({
            "command": "cargo build",
            "options": {"a": 1, "b": 2}
        });
        let json2 = json!({
            "options": {"b": 2, "a": 1},
            "command": "cargo build"
        });

        assert_eq!(
            stable_cache_key(&json1),
            stable_cache_key(&json2),
            "Nested objects should also produce stable keys"
        );
    }

    #[test]
    fn test_stable_cache_key_arrays() {
        let json1 = json!({
            "items": [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
        });
        let json2 = json!({
            "items": [{"b": 2, "a": 1}, {"d": 4, "c": 3}]
        });

        assert_eq!(
            stable_cache_key(&json1),
            stable_cache_key(&json2),
            "Arrays with objects should produce stable keys"
        );
    }

    #[test]
    fn test_stable_cache_key_exact_format() {
        let key = stable_cache_key(&json!({"b": 1, "a": [{"d": 2, "c": 3}]}));
        assert_eq!(key, r#"{"a":[{"c":3,"d":2}],"b":1}"#);
    }

    #[test]
    fn test_stable_cache_key_different_content() {
        let key1 = stable_cache_key(&json!({"command": "cargo build"}));
        let key2 = stable_cache_key(&json!({"command": "cargo test"}));

        assert_ne!(
            key1, key2,
            "Different content should produce different keys"
        );
    }
}