1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use parking_lot::Mutex as ParkingLotMutex;
7use serde_json::Value;
8use tokio::sync::oneshot;
9use tokio::task::JoinHandle;
10use tokio_util::sync::CancellationToken;
11use tracing::{Instrument, warn};
12
13use crate::canvas::CanvasHandler;
14use crate::generated::api_types::{
15 LogRequest, ModelSwitchToRequest, OpenCanvasInstance, RegisterEventInterestParams, rpc_methods,
16};
17use crate::generated::session_events::{
18 CommandExecuteData, ElicitationRequestedData, ExternalToolRequestedData, McpOauthRequiredData,
19 SessionCanvasClosedData, SessionErrorData, SessionEventType,
20};
21use crate::handler::{
22 AutoModeSwitchHandler, AutoModeSwitchResponse, ElicitationHandler, ExitPlanModeHandler,
23 McpAuthHandler, McpAuthRequest, McpAuthResult, PermissionHandler, PermissionResult,
24 UserInputHandler, UserInputResponse,
25};
26use crate::hooks::SessionHooks;
27use crate::provider_token::BearerTokenProvider;
28use crate::session_fs::SessionFsProvider;
29use crate::trace_context::inject_trace_context;
30use crate::transforms::SystemMessageTransform;
31use crate::types::{
32 CommandContext, CommandDefinition, CommandHandler, CreateSessionResult, ElicitationRequest,
33 ElicitationResult, ExitPlanModeData, GetMessagesResponse, MessageOptions,
34 PermissionRequestData, RequestId, ResumeSessionConfig, ResumeSessionResult, SectionOverride,
35 SessionCapabilities, SessionConfig, SessionEvent, SessionId, SetModelOptions,
36 SystemMessageConfig, ToolInvocation, ToolResult, ToolResultExpanded, TraceContext,
37 UiInputOptions, ensure_attachment_display_names,
38};
39use crate::{
40 Client, Error, ErrorKind, JsonRpcResponse, SessionErrorKind, SessionEventNotification,
41 error_codes,
42};
43
/// Callback handlers taken from a session config and handed to the session's
/// event loop.
///
/// Built by `Client::create_session` / `Client::resume_session` from the runtime
/// half of the config. A `None` field means no handler of that kind was configured.
#[derive(Clone)]
pub(crate) struct SessionHandlers {
    /// Permission handler as resolved by `crate::permission::resolve_handler` from
    /// the configured handler and/or permission policy.
    pub permission: Option<Arc<dyn PermissionHandler>>,
    /// Handler for elicitation requests, if configured.
    pub elicitation: Option<Arc<dyn ElicitationHandler>>,
    /// Handler for MCP auth requests, if configured. When present, the create/resume
    /// paths also register MCP auth interest with the server.
    pub mcp_auth: Option<Arc<dyn McpAuthHandler>>,
    /// Handler for user-input requests, if configured.
    pub user_input: Option<Arc<dyn UserInputHandler>>,
    /// Handler for exit-plan-mode requests, if configured.
    pub exit_plan_mode: Option<Arc<dyn ExitPlanModeHandler>>,
    /// Handler for auto-mode-switch requests, if configured.
    pub auto_mode_switch: Option<Arc<dyn AutoModeSwitchHandler>>,
    /// Tool handlers from the config, keyed by string (presumably the tool name —
    /// TODO confirm against `into_wire`).
    pub tools: Arc<HashMap<String, Arc<dyn crate::tool::ToolHandler>>>,
}
61
/// State of the single in-flight `Session::send_and_wait` call.
///
/// Lives in the session's shared idle-waiter slot. At most one exists at a time,
/// because `send_and_wait` rejects a second call with `SendWhileWaiting`.
struct IdleWaiter {
    /// Delivers the final result to the waiting `send_and_wait` call. It is also used
    /// by `Session::stop_event_loop` to fail the call with `EventLoopClosed`.
    tx: oneshot::Sender<Result<Option<SessionEvent>, Error>>,
    /// Starts as `None`. Presumably updated by the event loop with the latest
    /// assistant message of the turn — TODO confirm in `spawn_event_loop`.
    last_assistant_message: Option<SessionEvent>,
    /// Start time of the owning `send_and_wait` call (its `total_start`).
    started_at: Instant,
    /// Starts as `false`. Presumably set by the event loop once the first assistant
    /// message arrives — TODO confirm.
    first_assistant_message_seen: bool,
}
69
70struct WaiterGuard {
82 slot: Arc<ParkingLotMutex<Option<IdleWaiter>>>,
83}
84
85impl Drop for WaiterGuard {
86 fn drop(&mut self) {
87 self.slot.lock().take();
88 }
89}
90
91struct PendingSessionRegistration {
92 client: Client,
93 session_id: SessionId,
94 shutdown: CancellationToken,
95 disarmed: bool,
96}
97
98impl PendingSessionRegistration {
99 fn new(client: Client, session_id: SessionId, shutdown: CancellationToken) -> Self {
100 Self {
101 client,
102 session_id,
103 shutdown,
104 disarmed: false,
105 }
106 }
107
108 async fn cleanup(mut self, event_loop: JoinHandle<()>) {
109 self.shutdown.cancel();
110 let _ = event_loop.await;
111 self.client.unregister_session(&self.session_id);
112 self.disarmed = true;
113 }
114
115 fn disarm(&mut self) {
116 self.disarmed = true;
117 }
118}
119
120impl Drop for PendingSessionRegistration {
121 fn drop(&mut self) {
122 if !self.disarmed {
123 self.shutdown.cancel();
124 self.client.unregister_session(&self.session_id);
125 }
126 }
127}
128
/// Handle to a live session created or resumed through a [`Client`].
///
/// Owns the session's event-loop task and the state shared with it. Dropping the
/// handle cancels the shutdown token passed to the event loop and unregisters the
/// session from the client. `disconnect` additionally sends `session.destroy` first.
pub struct Session {
    /// Session id: chosen locally or assigned by the server (see `Client::create_session`).
    id: SessionId,
    /// The client's working directory, captured at create/resume time.
    cwd: PathBuf,
    /// Workspace path reported in the create/resume response, if any.
    workspace_path: Option<PathBuf>,
    /// Remote URL reported in the create/resume response, if any.
    remote_url: Option<String>,
    /// Client used for this session's RPCs.
    client: Client,
    /// Event-loop task handle; taken and awaited by `stop_event_loop`.
    event_loop: ParkingLotMutex<Option<JoinHandle<()>>>,
    /// Cancelled by `stop_event_loop` and on drop; a clone is given to the event loop.
    shutdown: CancellationToken,
    /// Slot for the single pending `send_and_wait`; shared with the event loop.
    idle_waiter: Arc<ParkingLotMutex<Option<IdleWaiter>>>,
    /// Current capabilities; seeded from the create/resume response and shared with
    /// the event loop.
    capabilities: Arc<parking_lot::RwLock<SessionCapabilities>>,
    /// Canvases tracked as open; shared with the event loop and seeded on resume.
    open_canvases: Arc<parking_lot::RwLock<Vec<OpenCanvasInstance>>>,
    /// Broadcast sender for session events; `subscribe` hands out receivers.
    event_tx: tokio::sync::broadcast::Sender<SessionEvent>,
}
180
impl Session {
    /// Returns this session's id.
    pub fn id(&self) -> &SessionId {
        &self.id
    }

    /// Returns the working directory recorded for this session.
    ///
    /// This is the owning client's `cwd()` at the time the session was created or
    /// resumed.
    pub fn cwd(&self) -> &PathBuf {
        &self.cwd
    }

    /// Returns the workspace path from the create/resume response, if the server
    /// reported one.
    pub fn workspace_path(&self) -> Option<&Path> {
        self.workspace_path.as_deref()
    }

    /// Returns the remote URL from the create/resume response, if the server
    /// reported one.
    pub fn remote_url(&self) -> Option<&str> {
        self.remote_url.as_deref()
    }

    /// Returns a snapshot (clone) of the session's capabilities.
    ///
    /// The value is seeded from the create/resume response. The shared value is also
    /// handed to the event loop, which may update it (TODO confirm in
    /// `spawn_event_loop`).
    pub fn capabilities(&self) -> SessionCapabilities {
        self.capabilities.read().clone()
    }

    /// Returns a snapshot (clone) of the canvases currently tracked as open.
    ///
    /// On resume the list is seeded from the server's response. The shared list is
    /// also handed to the event loop, which presumably keeps it current.
    pub fn open_canvases(&self) -> Vec<OpenCanvasInstance> {
        self.open_canvases.read().clone()
    }

    /// Returns a child of the session's shutdown token.
    ///
    /// The child is cancelled when the session shuts down (`stop_event_loop`,
    /// `disconnect`, or drop). Cancelling the child does not shut the session down.
    pub fn cancellation_token(&self) -> CancellationToken {
        self.shutdown.child_token()
    }

    /// Subscribes to the session's event broadcast.
    ///
    /// Only events broadcast after this call reach the new subscription.
    pub fn subscribe(&self) -> crate::subscription::EventSubscription {
        crate::subscription::EventSubscription::new(self.event_tx.subscribe())
    }

    /// Returns the client this session communicates through.
    pub fn client(&self) -> &Client {
        &self.client
    }

    /// Returns the generated, typed RPC surface scoped to this session.
    pub fn rpc(&self) -> crate::generated::rpc::SessionRpc<'_> {
        crate::generated::rpc::SessionRpc { session: self }
    }

    /// Stops the session's event loop and waits for the task to finish.
    ///
    /// Cancels the shutdown token, then awaits the event-loop task. Only the first
    /// call finds a handle; later calls skip the wait. Any pending `send_and_wait` is
    /// then failed with `SessionErrorKind::EventLoopClosed`.
    pub async fn stop_event_loop(&self) {
        self.shutdown.cancel();
        // Take the handle in its own statement so the mutex guard is released before
        // the `.await` below.
        let handle = self.event_loop.lock().take();
        if let Some(handle) = handle {
            let _ = handle.await;
        }
        if let Some(waiter) = self.idle_waiter.lock().take() {
            let _ = waiter.tx.send(Err(
                ErrorKind::Session(SessionErrorKind::EventLoopClosed).into()
            ));
        }
    }

    /// Sends a message and returns the message id from the `session.send` response.
    ///
    /// This does not wait for the turn to finish; use `send_and_wait` for that. The
    /// returned id is an empty string if the response has no string `messageId`.
    ///
    /// # Errors
    ///
    /// - `SessionErrorKind::SendWhileWaiting` while a `send_and_wait` call is pending.
    /// - Any serialisation or RPC error from `session.send`.
    pub async fn send(&self, opts: impl Into<MessageOptions>) -> Result<String, Error> {
        if self.idle_waiter.lock().is_some() {
            return Err(ErrorKind::Session(SessionErrorKind::SendWhileWaiting).into());
        }
        self.send_inner(opts.into()).await
    }

    /// Builds and sends the `session.send` request. Shared by `send` and
    /// `send_and_wait`.
    async fn send_inner(&self, opts: MessageOptions) -> Result<String, Error> {
        let mut params = serde_json::json!({
            "sessionId": self.id,
            "prompt": opts.prompt,
        });
        // Optional fields are only added when set; an empty header map is omitted.
        if let Some(m) = opts.mode {
            params["mode"] = serde_json::to_value(m)?;
        }
        if let Some(am) = opts.agent_mode {
            params["agentMode"] = serde_json::to_value(am)?;
        }
        if let Some(mut a) = opts.attachments {
            ensure_attachment_display_names(&mut a);
            params["attachments"] = serde_json::to_value(a)?;
        }
        if let Some(headers) = opts.request_headers
            && !headers.is_empty()
        {
            params["requestHeaders"] = serde_json::to_value(headers)?;
        }
        if let Some(display_prompt) = opts.display_prompt {
            params["displayPrompt"] = serde_json::to_value(display_prompt)?;
        }
        // A trace context set on the message wins. Otherwise fall back to the
        // client's resolved context.
        let trace_ctx = if opts.traceparent.is_some() || opts.tracestate.is_some() {
            TraceContext {
                traceparent: opts.traceparent,
                tracestate: opts.tracestate,
            }
        } else {
            self.client.resolve_trace_context().await
        };
        inject_trace_context(&mut params, &trace_ctx);
        let rpc_start = Instant::now();
        let result = self.client.call("session.send", Some(params)).await?;
        // A missing or non-string `messageId` yields an empty id, not an error.
        let message_id = result
            .get("messageId")
            .and_then(|v| v.as_str())
            .map(|s| s.to_string())
            .unwrap_or_default();
        tracing::debug!(
            elapsed_ms = rpc_start.elapsed().as_millis(),
            session_id = %self.id,
            message_id = %message_id,
            "Session::send completed successfully"
        );
        Ok(message_id)
    }

    /// Sends a message, then waits for the event loop to complete the call through
    /// the idle waiter.
    ///
    /// The wait is bounded by `opts.wait_timeout`, which defaults to 60 seconds. The
    /// deadline covers both the `session.send` request and the wait afterwards. On
    /// timeout this method does not call `abort`. The returned event is whatever the
    /// event loop delivers, presumably the last assistant message of the turn (TODO
    /// confirm in `spawn_event_loop`).
    ///
    /// # Errors
    ///
    /// - `SessionErrorKind::SendWhileWaiting` if another `send_and_wait` is pending.
    /// - `SessionErrorKind::Timeout(duration)` if the deadline elapses.
    /// - `SessionErrorKind::EventLoopClosed` if the waiter is dropped without a
    ///   result, or if `stop_event_loop` runs.
    /// - Any error from sending the message, or an error delivered by the event loop.
    pub async fn send_and_wait(
        &self,
        opts: impl Into<MessageOptions>,
    ) -> Result<Option<SessionEvent>, Error> {
        let total_start = Instant::now();
        let opts = opts.into();
        let timeout_duration = opts.wait_timeout.unwrap_or(Duration::from_secs(60));
        let (tx, rx) = oneshot::channel();

        // Check and install under one lock so two concurrent callers cannot both
        // install a waiter.
        {
            let mut guard = self.idle_waiter.lock();
            if guard.is_some() {
                return Err(ErrorKind::Session(SessionErrorKind::SendWhileWaiting).into());
            }
            *guard = Some(IdleWaiter {
                tx,
                last_assistant_message: None,
                started_at: total_start,
                first_assistant_message_seen: false,
            });
        }

        // Clears the slot on every exit path, including timeout and cancellation.
        // NOTE(review): the guard clears whatever waiter occupies the slot at drop
        // time. Suppose the event loop has already taken and answered this waiter,
        // and a concurrent `send_and_wait` installs its own waiter before this guard
        // drops. That newer waiter would then be lost. Worth confirming, and tagging
        // waiters with an id.
        let _waiter_guard = WaiterGuard {
            slot: self.idle_waiter.clone(),
        };

        let result = tokio::time::timeout(timeout_duration, async {
            self.send_inner(opts).await?;
            match rx.await {
                Ok(result) => result,
                // The sender was dropped without sending a result.
                Err(_) => Err(ErrorKind::Session(SessionErrorKind::EventLoopClosed).into()),
            }
        })
        .await;

        match result {
            Ok(inner) => {
                tracing::debug!(
                    elapsed_ms = total_start.elapsed().as_millis(),
                    session_id = %self.id,
                    completed_by = if inner.is_ok() { "idle" } else { "error" },
                    "Session::send_and_wait complete"
                );
                inner
            }
            Err(_) => {
                tracing::warn!(
                    elapsed_ms = total_start.elapsed().as_millis(),
                    session_id = %self.id,
                    completed_by = "timeout",
                    "Session::send_and_wait failed"
                );
                Err(ErrorKind::Session(SessionErrorKind::Timeout(timeout_duration)).into())
            }
        }
    }

    /// Fetches the session's event history via `session.getMessages`.
    pub async fn get_events(&self) -> Result<Vec<SessionEvent>, Error> {
        let result = self
            .client
            .call(
                "session.getMessages",
                Some(serde_json::json!({ "sessionId": self.id })),
            )
            .await?;
        let response: GetMessagesResponse = serde_json::from_value(result)?;
        Ok(response.events)
    }

    /// Deprecated alias for [`Session::get_events`].
    #[deprecated(since = "0.1.0", note = "Use `get_events()` instead")]
    pub async fn get_messages(&self) -> Result<Vec<SessionEvent>, Error> {
        self.get_events().await
    }

    /// Sends `session.abort` for this session.
    pub async fn abort(&self) -> Result<(), Error> {
        self.client
            .call(
                "session.abort",
                Some(serde_json::json!({ "sessionId": self.id })),
            )
            .await?;
        Ok(())
    }

    /// Switches the session to `model` via the generated `model().switch_to` RPC.
    ///
    /// `opts` falls back to `SetModelOptions::default()` when `None`.
    pub async fn set_model(&self, model: &str, opts: Option<SetModelOptions>) -> Result<(), Error> {
        let opts = opts.unwrap_or_default();
        let request = ModelSwitchToRequest {
            model_id: model.to_string(),
            reasoning_effort: opts.reasoning_effort,
            reasoning_summary: opts.reasoning_summary,
            context_tier: opts.context_tier,
            model_capabilities: opts.model_capabilities,
        };
        self.rpc().model().switch_to(request).await?;
        Ok(())
    }

    /// Sends `session.destroy`, then stops the event loop and unregisters the session.
    ///
    /// If the RPC fails, the error is returned before the event loop is stopped. The
    /// shutdown token is still cancelled when the `Session` is dropped.
    pub async fn disconnect(&self) -> Result<(), Error> {
        self.client
            .call(
                "session.destroy",
                Some(serde_json::json!({ "sessionId": self.id })),
            )
            .await?;
        self.stop_event_loop().await;
        self.client.unregister_session(&self.id);
        Ok(())
    }

    /// Deprecated alias for [`Session::disconnect`].
    #[deprecated(since = "0.1.0", note = "Use `disconnect()` instead")]
    pub async fn destroy(&self) -> Result<(), Error> {
        self.disconnect().await
    }

    /// Sends a log message via the generated `log` RPC.
    ///
    /// The level in `LogOptions` is converted to the request's level type by a JSON
    /// round-trip. The `type`, `tip` and `url` fields are always left unset.
    pub async fn log(
        &self,
        message: &str,
        opts: Option<crate::types::LogOptions>,
    ) -> Result<(), Error> {
        let opts = opts.unwrap_or_default();
        let level = match opts.level {
            Some(level) => Some(serde_json::from_value(serde_json::to_value(level)?)?),
            None => None,
        };
        let request = LogRequest {
            message: message.to_string(),
            level,
            ephemeral: opts.ephemeral,
            r#type: None,
            tip: None,
            url: None,
        };
        self.rpc().log(request).await?;
        Ok(())
    }

    /// Returns the UI helpers (`elicitation`, `confirm`, `select`, `input`) for this
    /// session.
    pub fn ui(&self) -> SessionUi<'_> {
        SessionUi { session: self }
    }

    /// Fails with `SessionErrorKind::ElicitationNotSupported` unless the current
    /// capabilities report `ui.elicitation == Some(true)`.
    fn assert_elicitation(&self) -> Result<(), Error> {
        if self
            .capabilities
            .read()
            .ui
            .as_ref()
            .and_then(|u| u.elicitation)
            != Some(true)
        {
            return Err(ErrorKind::Session(SessionErrorKind::ElicitationNotSupported).into());
        }
        Ok(())
    }
}
625
626impl Drop for Session {
627 fn drop(&mut self) {
628 self.shutdown.cancel();
640 self.client.unregister_session(&self.id);
641 }
642}
643
644pub struct SessionUi<'a> {
651 session: &'a Session,
652}
653
654impl<'a> SessionUi<'a> {
655 pub async fn elicitation(
663 &self,
664 message: &str,
665 schema: Value,
666 ) -> Result<ElicitationResult, Error> {
667 self.session.assert_elicitation()?;
668 let result = self
669 .session
670 .client
671 .call(
672 "session.ui.elicitation",
673 Some(serde_json::json!({
674 "sessionId": self.session.id,
675 "message": message,
676 "requestedSchema": schema,
677 })),
678 )
679 .await?;
680 let elicitation: ElicitationResult = serde_json::from_value(result)?;
681 Ok(elicitation)
682 }
683
684 pub async fn confirm(&self, message: &str) -> Result<bool, Error> {
688 self.session.assert_elicitation()?;
689 let schema = serde_json::json!({
690 "type": "object",
691 "properties": {
692 "confirmed": {
693 "type": "boolean",
694 "default": true,
695 }
696 },
697 "required": ["confirmed"]
698 });
699 let result = self.elicitation(message, schema).await?;
700 Ok(result.action == "accept"
701 && result
702 .content
703 .and_then(|c| c.get("confirmed").and_then(|v| v.as_bool()))
704 == Some(true))
705 }
706
707 pub async fn select(&self, message: &str, options: &[&str]) -> Result<Option<String>, Error> {
711 self.session.assert_elicitation()?;
712 let schema = serde_json::json!({
713 "type": "object",
714 "properties": {
715 "selection": {
716 "type": "string",
717 "enum": options,
718 }
719 },
720 "required": ["selection"]
721 });
722 let result = self.elicitation(message, schema).await?;
723 if result.action != "accept" {
724 return Ok(None);
725 }
726 let selection = result.content.and_then(|c| {
727 c.get("selection")
728 .and_then(|v| v.as_str())
729 .map(String::from)
730 });
731 Ok(selection)
732 }
733
734 pub async fn input(
739 &self,
740 message: &str,
741 options: Option<&UiInputOptions<'_>>,
742 ) -> Result<Option<String>, Error> {
743 self.session.assert_elicitation()?;
744 let mut field = serde_json::json!({ "type": "string" });
745 if let Some(opts) = options {
746 if let Some(title) = opts.title {
747 field["title"] = Value::String(title.to_string());
748 }
749 if let Some(desc) = opts.description {
750 field["description"] = Value::String(desc.to_string());
751 }
752 if let Some(min) = opts.min_length {
753 field["minLength"] = Value::Number(min.into());
754 }
755 if let Some(max) = opts.max_length {
756 field["maxLength"] = Value::Number(max.into());
757 }
758 if let Some(fmt) = &opts.format {
759 field["format"] = Value::String(fmt.as_str().to_string());
760 }
761 if let Some(default) = opts.default {
762 field["default"] = Value::String(default.to_string());
763 }
764 }
765 let schema = serde_json::json!({
766 "type": "object",
767 "properties": { "value": field },
768 "required": ["value"]
769 });
770 let result = self.elicitation(message, schema).await?;
771 if result.action != "accept" {
772 return Ok(None);
773 }
774 let value = result
775 .content
776 .and_then(|c| c.get("value").and_then(|v| v.as_str()).map(String::from));
777 Ok(value)
778 }
779}
780
781impl Client {
782 pub async fn create_session(&self, mut config: SessionConfig) -> Result<Session, Error> {
804 let total_start = Instant::now();
805 let caller_session_id = config.session_id.clone();
813 let use_server_generated_id = config.cloud.is_some() && caller_session_id.is_none();
814 let local_session_id: Option<SessionId> = if use_server_generated_id {
815 None
816 } else {
817 Some(
818 caller_session_id
819 .clone()
820 .unwrap_or_else(|| SessionId::new(uuid::Uuid::new_v4().to_string())),
821 )
822 };
823 if config.hooks_handler.is_some() && config.hooks.is_none() {
824 config.hooks = Some(true);
825 }
826 if let Some(transforms) = config.system_message_transform.clone() {
827 inject_transform_sections(&mut config, transforms.as_ref());
828 }
829 let mode = self.inner.mode;
830 if mode == crate::ClientMode::Empty && config.available_tools.is_none() {
831 return Err(Error::with_message(
832 ErrorKind::InvalidConfig,
833 "ClientMode::Empty requires available_tools to be set on the session config. \
834 Use ToolSet to specify which tools the session may use (e.g. \
835 ToolSet::new().add_builtin_many(BUILTIN_TOOLS_ISOLATED)).",
836 ));
837 }
838 crate::mode::validate_tool_filter_list(
839 "available_tools",
840 config.available_tools.as_deref(),
841 )?;
842 crate::mode::validate_tool_filter_list("excluded_tools", config.excluded_tools.as_deref())?;
843 config.system_message =
844 crate::mode::system_message_for_mode(mode, config.system_message.take());
845 config.memory = crate::mode::memory_for_mode(mode, config.memory.take());
846 if mode == crate::ClientMode::Empty {
847 if config.enable_session_telemetry.is_none() {
848 config.enable_session_telemetry = Some(false);
849 }
850 if config.skip_embedding_retrieval.is_none() {
851 config.skip_embedding_retrieval = Some(true);
852 }
853 if config.enable_on_demand_instruction_discovery.is_none() {
854 config.enable_on_demand_instruction_discovery = Some(false);
855 }
856 if config.enable_file_hooks.is_none() {
857 config.enable_file_hooks = Some(false);
858 }
859 if config.enable_host_git_operations.is_none() {
860 config.enable_host_git_operations = Some(false);
861 }
862 if config.enable_session_store.is_none() {
863 config.enable_session_store = Some(false);
864 }
865 if config.enable_skills.is_none() {
866 config.enable_skills = Some(false);
867 }
868 }
869 if mode == crate::ClientMode::Empty && config.mcp_oauth_token_storage.is_none() {
870 config.mcp_oauth_token_storage = Some("in-memory".into());
871 }
872 if mode == crate::ClientMode::Empty && config.embedding_cache_storage.is_none() {
873 config.embedding_cache_storage = Some("in-memory".into());
874 }
875 let opt_skip_custom_instructions = config.skip_custom_instructions;
876 let opt_custom_agents_local_only = config.custom_agents_local_only;
877 let opt_coauthor_enabled = config.coauthor_enabled;
878 let opt_manage_schedule_enabled = config.manage_schedule_enabled;
879 let (mut wire, mut runtime) = config.into_wire(local_session_id.clone())?;
880 wire.enable_github_telemetry_forwarding =
881 self.inner.on_github_telemetry.is_some().then_some(true);
882
883 let permission_handler = crate::permission::resolve_handler(
884 runtime.permission_handler.take(),
885 runtime.permission_policy.take(),
886 );
887 let handlers = SessionHandlers {
888 permission: permission_handler,
889 elicitation: runtime.elicitation_handler.take(),
890 mcp_auth: runtime.mcp_auth_handler.take(),
891 user_input: runtime.user_input_handler.take(),
892 exit_plan_mode: runtime.exit_plan_mode_handler.take(),
893 auto_mode_switch: runtime.auto_mode_switch_handler.take(),
894 tools: Arc::new(std::mem::take(&mut runtime.tool_handlers)),
895 };
896 let hooks = runtime.hooks_handler.take();
897 let transforms = runtime.system_message_transform.take();
898 let tools_count = wire.tools.as_ref().map_or(0, Vec::len);
899 let commands_count = runtime.commands.as_ref().map_or(0, Vec::len);
900 let has_hooks = hooks.is_some();
901 let command_handlers = build_command_handler_map(runtime.commands.as_deref());
902 let canvas_handler = runtime.canvas_handler.take();
903 let session_fs_provider = runtime.session_fs_provider.take();
904 let bearer_token_providers = std::mem::take(&mut runtime.bearer_token_providers);
905 let has_mcp_auth_handler = handlers.mcp_auth.is_some();
906 if self.inner.session_fs_configured && session_fs_provider.is_none() {
907 return Err(ErrorKind::Session(SessionErrorKind::SessionFsProviderRequired).into());
908 }
909 if self.inner.session_fs_sqlite_declared
910 && let Some(ref provider) = session_fs_provider
911 && provider.sqlite().is_none()
912 {
913 return Err(Error::with_message(
914 ErrorKind::InvalidConfig,
915 "SessionFs capabilities declare SQLite support but the provider \
916 does not implement SessionFsSqliteProvider",
917 ));
918 }
919
920 let mut params = serde_json::to_value(&wire)?;
921 let trace_ctx = self.resolve_trace_context().await;
922 inject_trace_context(&mut params, &trace_ctx);
923
924 let setup_start = Instant::now();
925 let capabilities = Arc::new(parking_lot::RwLock::new(SessionCapabilities::default()));
926 let idle_waiter = Arc::new(ParkingLotMutex::new(None));
927 let open_canvases = Arc::new(parking_lot::RwLock::new(Vec::new()));
928 let shutdown = CancellationToken::new();
929 let (event_tx, _) = tokio::sync::broadcast::channel(512);
930
931 let inline_stash: Arc<
937 ParkingLotMutex<Option<(SessionId, crate::router::SessionChannels)>>,
938 > = Arc::new(ParkingLotMutex::new(None));
939
940 let inline_callback: Option<crate::jsonrpc::InlineResponseCallback> = if let Some(ref sid) =
941 local_session_id
942 {
943 let channels = self.register_session(sid);
944 *inline_stash.lock() = Some((sid.clone(), channels));
945 None
946 } else {
947 let client = self.clone();
948 let stash = inline_stash.clone();
949 let expected = caller_session_id.clone();
950 Some(Box::new(move |response| {
951 let result = response.result.as_ref().ok_or_else(|| {
952 Error::with_message(ErrorKind::Json, "session.create response had no result")
953 })?;
954 let parsed: CreateSessionResult =
955 serde_json::from_value(result.clone()).map_err(Error::from)?;
956 if let Some(requested) = expected.as_ref()
957 && parsed.session_id != *requested
958 {
959 return Err(ErrorKind::Session(SessionErrorKind::SessionIdMismatch {
960 requested: requested.clone(),
961 returned: parsed.session_id,
962 })
963 .into());
964 }
965 let channels = client.register_session(&parsed.session_id);
966 *stash.lock() = Some((parsed.session_id, channels));
967 Ok(())
968 }))
969 };
970
971 let rpc_start = Instant::now();
972 let result = match self
973 .call_with_inline_callback("session.create", Some(params), inline_callback)
974 .await
975 {
976 Ok(result) => result,
977 Err(error) => {
978 if let Some((id, _channels)) = inline_stash.lock().take() {
979 self.unregister_session(&id);
980 }
981 return Err(error);
982 }
983 };
984 tracing::debug!(
985 elapsed_ms = rpc_start.elapsed().as_millis(),
986 "Client::create_session session creation request completed successfully"
987 );
988 let create_result: CreateSessionResult = match serde_json::from_value(result) {
989 Ok(result) => result,
990 Err(error) => {
991 if let Some((id, _channels)) = inline_stash.lock().take() {
992 self.unregister_session(&id);
993 }
994 return Err(error.into());
995 }
996 };
997
998 if let Some(ref requested) = local_session_id
999 && create_result.session_id != *requested
1000 {
1001 if let Some((id, _channels)) = inline_stash.lock().take() {
1002 self.unregister_session(&id);
1003 }
1004 return Err(ErrorKind::Session(SessionErrorKind::SessionIdMismatch {
1005 requested: requested.clone(),
1006 returned: create_result.session_id.clone(),
1007 })
1008 .into());
1009 }
1010
1011 let (session_id, channels) = inline_stash
1012 .lock()
1013 .take()
1014 .expect("session registration must have populated stash on success");
1015 let event_loop = spawn_event_loop(
1016 session_id.clone(),
1017 self.clone(),
1018 handlers,
1019 hooks,
1020 transforms,
1021 command_handlers,
1022 canvas_handler,
1023 session_fs_provider,
1024 bearer_token_providers,
1025 channels,
1026 idle_waiter.clone(),
1027 capabilities.clone(),
1028 open_canvases.clone(),
1029 event_tx.clone(),
1030 shutdown.clone(),
1031 );
1032 tracing::debug!(
1033 elapsed_ms = setup_start.elapsed().as_millis(),
1034 session_id = %session_id,
1035 tools_count,
1036 commands_count,
1037 has_hooks,
1038 "Client::create_session local setup complete"
1039 );
1040 *capabilities.write() = create_result.capabilities.unwrap_or_default();
1041 if has_mcp_auth_handler {
1042 register_mcp_auth_interest(self, &session_id).await?;
1043 }
1044
1045 tracing::debug!(
1046 elapsed_ms = total_start.elapsed().as_millis(),
1047 session_id = %session_id,
1048 "Client::create_session complete"
1049 );
1050 let session = Session {
1051 id: session_id,
1052 cwd: self.cwd().clone(),
1053 workspace_path: create_result.workspace_path,
1054 remote_url: create_result.remote_url,
1055 client: self.clone(),
1056 event_loop: ParkingLotMutex::new(Some(event_loop)),
1057 shutdown,
1058 idle_waiter,
1059 capabilities,
1060 open_canvases,
1061 event_tx,
1062 };
1063 apply_mode_post_create_patch(
1064 &session,
1065 mode,
1066 opt_skip_custom_instructions,
1067 opt_custom_agents_local_only,
1068 opt_coauthor_enabled,
1069 opt_manage_schedule_enabled,
1070 )
1071 .await?;
1072 Ok(session)
1073 }
1074
1075 pub async fn resume_session(&self, mut config: ResumeSessionConfig) -> Result<Session, Error> {
1086 let total_start = Instant::now();
1087 let session_id = config.session_id.clone();
1088 if config.hooks_handler.is_some() && config.hooks.is_none() {
1089 config.hooks = Some(true);
1090 }
1091 if let Some(transforms) = config.system_message_transform.clone() {
1092 inject_transform_sections_resume(&mut config, transforms.as_ref());
1093 }
1094 let mode = self.inner.mode;
1095 if mode == crate::ClientMode::Empty && config.available_tools.is_none() {
1096 return Err(Error::with_message(
1097 ErrorKind::InvalidConfig,
1098 "ClientMode::Empty requires available_tools to be set on the session config. \
1099 Use ToolSet to specify which tools the session may use (e.g. \
1100 ToolSet::new().add_builtin_many(BUILTIN_TOOLS_ISOLATED)).",
1101 ));
1102 }
1103 crate::mode::validate_tool_filter_list(
1104 "available_tools",
1105 config.available_tools.as_deref(),
1106 )?;
1107 crate::mode::validate_tool_filter_list("excluded_tools", config.excluded_tools.as_deref())?;
1108 config.system_message =
1109 crate::mode::system_message_for_mode(mode, config.system_message.take());
1110 config.memory = crate::mode::memory_for_mode(mode, config.memory.take());
1111 if mode == crate::ClientMode::Empty {
1112 if config.enable_session_telemetry.is_none() {
1113 config.enable_session_telemetry = Some(false);
1114 }
1115 if config.skip_embedding_retrieval.is_none() {
1116 config.skip_embedding_retrieval = Some(true);
1117 }
1118 if config.enable_on_demand_instruction_discovery.is_none() {
1119 config.enable_on_demand_instruction_discovery = Some(false);
1120 }
1121 if config.enable_file_hooks.is_none() {
1122 config.enable_file_hooks = Some(false);
1123 }
1124 if config.enable_host_git_operations.is_none() {
1125 config.enable_host_git_operations = Some(false);
1126 }
1127 if config.enable_session_store.is_none() {
1128 config.enable_session_store = Some(false);
1129 }
1130 if config.enable_skills.is_none() {
1131 config.enable_skills = Some(false);
1132 }
1133 }
1134 if mode == crate::ClientMode::Empty && config.mcp_oauth_token_storage.is_none() {
1135 config.mcp_oauth_token_storage = Some("in-memory".into());
1136 }
1137 if mode == crate::ClientMode::Empty && config.embedding_cache_storage.is_none() {
1138 config.embedding_cache_storage = Some("in-memory".into());
1139 }
1140 let opt_skip_custom_instructions = config.skip_custom_instructions;
1141 let opt_custom_agents_local_only = config.custom_agents_local_only;
1142 let opt_coauthor_enabled = config.coauthor_enabled;
1143 let opt_manage_schedule_enabled = config.manage_schedule_enabled;
1144 let (mut wire, mut runtime) = config.into_wire()?;
1145 wire.enable_github_telemetry_forwarding =
1146 self.inner.on_github_telemetry.is_some().then_some(true);
1147
1148 let permission_handler = crate::permission::resolve_handler(
1149 runtime.permission_handler.take(),
1150 runtime.permission_policy.take(),
1151 );
1152 let handlers = SessionHandlers {
1153 permission: permission_handler,
1154 elicitation: runtime.elicitation_handler.take(),
1155 mcp_auth: runtime.mcp_auth_handler.take(),
1156 user_input: runtime.user_input_handler.take(),
1157 exit_plan_mode: runtime.exit_plan_mode_handler.take(),
1158 auto_mode_switch: runtime.auto_mode_switch_handler.take(),
1159 tools: Arc::new(std::mem::take(&mut runtime.tool_handlers)),
1160 };
1161 let hooks = runtime.hooks_handler.take();
1162 let transforms = runtime.system_message_transform.take();
1163 let tools_count = wire.tools.as_ref().map_or(0, Vec::len);
1164 let commands_count = runtime.commands.as_ref().map_or(0, Vec::len);
1165 let has_hooks = hooks.is_some();
1166 let command_handlers = build_command_handler_map(runtime.commands.as_deref());
1167 let canvas_handler = runtime.canvas_handler.take();
1168 let session_fs_provider = runtime.session_fs_provider.take();
1169 let bearer_token_providers = std::mem::take(&mut runtime.bearer_token_providers);
1170 let has_mcp_auth_handler = handlers.mcp_auth.is_some();
1171 if self.inner.session_fs_configured && session_fs_provider.is_none() {
1172 return Err(ErrorKind::Session(SessionErrorKind::SessionFsProviderRequired).into());
1173 }
1174 if self.inner.session_fs_sqlite_declared
1175 && let Some(ref provider) = session_fs_provider
1176 && provider.sqlite().is_none()
1177 {
1178 return Err(Error::with_message(
1179 ErrorKind::InvalidConfig,
1180 "SessionFs capabilities declare SQLite support but the provider \
1181 does not implement SessionFsSqliteProvider",
1182 ));
1183 }
1184
1185 let mut params = serde_json::to_value(&wire)?;
1186 let trace_ctx = self.resolve_trace_context().await;
1187 inject_trace_context(&mut params, &trace_ctx);
1188
1189 let capabilities = Arc::new(parking_lot::RwLock::new(SessionCapabilities::default()));
1190 let setup_start = Instant::now();
1191 let channels = self.register_session(&session_id);
1192 let idle_waiter = Arc::new(ParkingLotMutex::new(None));
1193 let open_canvases = Arc::new(parking_lot::RwLock::new(Vec::new()));
1194 let shutdown = CancellationToken::new();
1195 let (event_tx, _) = tokio::sync::broadcast::channel(512);
1196 let event_loop = spawn_event_loop(
1197 session_id.clone(),
1198 self.clone(),
1199 handlers,
1200 hooks,
1201 transforms,
1202 command_handlers,
1203 canvas_handler,
1204 session_fs_provider,
1205 bearer_token_providers,
1206 channels,
1207 idle_waiter.clone(),
1208 capabilities.clone(),
1209 open_canvases.clone(),
1210 event_tx.clone(),
1211 shutdown.clone(),
1212 );
1213 let mut registration =
1214 PendingSessionRegistration::new(self.clone(), session_id.clone(), shutdown.clone());
1215 tracing::debug!(
1216 elapsed_ms = setup_start.elapsed().as_millis(),
1217 session_id = %session_id,
1218 tools_count,
1219 commands_count,
1220 has_hooks,
1221 "Client::resume_session local setup complete"
1222 );
1223
1224 let rpc_start = Instant::now();
1225 let result = match self.call("session.resume", Some(params)).await {
1226 Ok(result) => result,
1227 Err(error) => {
1228 registration.cleanup(event_loop).await;
1229 return Err(error);
1230 }
1231 };
1232 tracing::debug!(
1233 elapsed_ms = rpc_start.elapsed().as_millis(),
1234 session_id = %session_id,
1235 "Client::resume_session session resume request completed successfully"
1236 );
1237
1238 let resume_result: ResumeSessionResult = match serde_json::from_value(result) {
1239 Ok(result) => result,
1240 Err(error) => {
1241 registration.cleanup(event_loop).await;
1242 return Err(error.into());
1243 }
1244 };
1245 let cli_session_id = resume_result
1246 .session_id
1247 .clone()
1248 .unwrap_or_else(|| session_id.clone());
1249 if cli_session_id != session_id {
1250 registration.cleanup(event_loop).await;
1251 return Err(ErrorKind::Session(SessionErrorKind::SessionIdMismatch {
1252 requested: session_id,
1253 returned: cli_session_id,
1254 })
1255 .into());
1256 }
1257 if has_mcp_auth_handler {
1258 register_mcp_auth_interest(self, &session_id).await?;
1259 }
1260
1261 let skills_reload_start = Instant::now();
1263 if let Err(e) = self
1264 .call(
1265 "session.skills.reload",
1266 Some(serde_json::json!({ "sessionId": session_id })),
1267 )
1268 .await
1269 {
1270 warn!(
1271 elapsed_ms = skills_reload_start.elapsed().as_millis(),
1272 session_id = %session_id,
1273 error = %e,
1274 "Client::resume_session skills reload request failed"
1275 );
1276 } else {
1277 tracing::debug!(
1278 elapsed_ms = skills_reload_start.elapsed().as_millis(),
1279 session_id = %session_id,
1280 "Client::resume_session skills reload request completed successfully"
1281 );
1282 }
1283
1284 *capabilities.write() = resume_result.capabilities.unwrap_or_default();
1285 {
1290 let mut snapshots = open_canvases.write();
1291 for snapshot in resume_result.open_canvases.unwrap_or_default() {
1292 upsert_open_canvas_snapshot(&mut snapshots, snapshot);
1293 }
1294 }
1295
1296 tracing::debug!(
1297 elapsed_ms = total_start.elapsed().as_millis(),
1298 session_id = %session_id,
1299 "Client::resume_session complete"
1300 );
1301 registration.disarm();
1302 let session = Session {
1303 id: session_id,
1304 cwd: self.cwd().clone(),
1305 workspace_path: resume_result.workspace_path,
1306 remote_url: resume_result.remote_url,
1307 client: self.clone(),
1308 event_loop: ParkingLotMutex::new(Some(event_loop)),
1309 shutdown,
1310 idle_waiter,
1311 capabilities,
1312 open_canvases,
1313 event_tx,
1314 };
1315 apply_mode_post_create_patch(
1316 &session,
1317 mode,
1318 opt_skip_custom_instructions,
1319 opt_custom_agents_local_only,
1320 opt_coauthor_enabled,
1321 opt_manage_schedule_enabled,
1322 )
1323 .await?;
1324 Ok(session)
1325 }
1326}
1327
1328type CommandHandlerMap = HashMap<String, Arc<dyn CommandHandler>>;
1329
1330async fn apply_mode_post_create_patch(
1331 session: &Session,
1332 mode: crate::ClientMode,
1333 opt_skip_custom_instructions: Option<bool>,
1334 opt_custom_agents_local_only: Option<bool>,
1335 opt_coauthor_enabled: Option<bool>,
1336 opt_manage_schedule_enabled: Option<bool>,
1337) -> Result<(), Error> {
1338 use crate::generated::api_types::SessionUpdateOptionsParams;
1339 let mut patch = SessionUpdateOptionsParams::default();
1340 let should_send = if mode == crate::ClientMode::Empty {
1341 patch.skip_custom_instructions = Some(opt_skip_custom_instructions.unwrap_or(true));
1342 patch.custom_agents_local_only = Some(opt_custom_agents_local_only.unwrap_or(true));
1343 patch.coauthor_enabled = Some(opt_coauthor_enabled.unwrap_or(false));
1344 patch.manage_schedule_enabled = Some(opt_manage_schedule_enabled.unwrap_or(false));
1345 patch.installed_plugins = Some(Vec::new());
1346 true
1347 } else {
1348 let mut any = false;
1349 if let Some(v) = opt_skip_custom_instructions {
1350 patch.skip_custom_instructions = Some(v);
1351 any = true;
1352 }
1353 if let Some(v) = opt_custom_agents_local_only {
1354 patch.custom_agents_local_only = Some(v);
1355 any = true;
1356 }
1357 if let Some(v) = opt_coauthor_enabled {
1358 patch.coauthor_enabled = Some(v);
1359 any = true;
1360 }
1361 if let Some(v) = opt_manage_schedule_enabled {
1362 patch.manage_schedule_enabled = Some(v);
1363 any = true;
1364 }
1365 any
1366 };
1367 if !should_send {
1368 return Ok(());
1369 }
1370 if let Err(error) = session.rpc().options().update(patch).await {
1371 let _ = session.disconnect().await;
1372 return Err(error);
1373 }
1374 Ok(())
1375}
1376
1377fn build_command_handler_map(commands: Option<&[CommandDefinition]>) -> Arc<CommandHandlerMap> {
1378 let map = match commands {
1379 Some(commands) => commands
1380 .iter()
1381 .filter(|cmd| !cmd.name.is_empty())
1382 .map(|cmd| (cmd.name.clone(), cmd.handler.clone()))
1383 .collect(),
1384 None => HashMap::new(),
1385 };
1386 Arc::new(map)
1387}
1388
1389fn upsert_open_canvas_snapshot(
1390 snapshots: &mut Vec<OpenCanvasInstance>,
1391 snapshot: OpenCanvasInstance,
1392) {
1393 if let Some(existing) = snapshots
1394 .iter_mut()
1395 .find(|open| open.instance_id == snapshot.instance_id)
1396 {
1397 *existing = snapshot;
1398 } else {
1399 snapshots.push(snapshot);
1400 }
1401}
1402
1403fn remove_open_canvas_snapshot(snapshots: &mut Vec<OpenCanvasInstance>, instance_id: &str) {
1404 snapshots.retain(|open| open.instance_id != instance_id);
1405}
1406
1407#[allow(clippy::too_many_arguments)]
1408fn spawn_event_loop(
1409 session_id: SessionId,
1410 client: Client,
1411 handlers: SessionHandlers,
1412 hooks: Option<Arc<dyn SessionHooks>>,
1413 transforms: Option<Arc<dyn SystemMessageTransform>>,
1414 command_handlers: Arc<CommandHandlerMap>,
1415 canvas_handler: Option<Arc<dyn CanvasHandler>>,
1416 session_fs_provider: Option<Arc<dyn SessionFsProvider>>,
1417 bearer_token_providers: HashMap<String, Arc<dyn BearerTokenProvider>>,
1418 channels: crate::router::SessionChannels,
1419 idle_waiter: Arc<ParkingLotMutex<Option<IdleWaiter>>>,
1420 capabilities: Arc<parking_lot::RwLock<SessionCapabilities>>,
1421 open_canvases: Arc<parking_lot::RwLock<Vec<OpenCanvasInstance>>>,
1422 event_tx: tokio::sync::broadcast::Sender<SessionEvent>,
1423 shutdown: CancellationToken,
1424) -> JoinHandle<()> {
1425 let crate::router::SessionChannels {
1426 mut notifications,
1427 mut requests,
1428 } = channels;
1429
1430 let span = tracing::error_span!("session_event_loop", session_id = %session_id);
1431 tokio::spawn(
1432 async move {
1433 loop {
1434 tokio::select! {
1445 _ = shutdown.cancelled() => break,
1446 Some(notification) = notifications.recv() => {
1447 handle_notification(
1448 &session_id, &client, &handlers, &command_handlers, notification, &idle_waiter, &capabilities, &open_canvases, &event_tx,
1449 ).await;
1450 }
1451 Some(request) = requests.recv() => {
1452 let ctx = RequestDispatchContext {
1453 client: &client,
1454 handlers: &handlers,
1455 hooks: hooks.as_deref(),
1456 transforms: transforms.as_deref(),
1457 canvas_handler: canvas_handler.as_ref(),
1458 session_fs_provider: session_fs_provider.as_ref(),
1459 bearer_token_providers: &bearer_token_providers,
1460 };
1461 handle_request(&session_id, ctx, request).await;
1462 }
1463 else => break,
1464 }
1465 }
1466 if let Some(waiter) = idle_waiter.lock().take() {
1469 let _ = waiter
1470 .tx
1471 .send(Err(ErrorKind::Session(SessionErrorKind::EventLoopClosed).into()));
1472 }
1473 }
1474 .instrument(span),
1475 )
1476}
1477
1478fn extract_request_id(data: &Value) -> Option<RequestId> {
1479 data.get("requestId")
1480 .and_then(|v| v.as_str())
1481 .filter(|s| !s.is_empty())
1482 .map(RequestId::new)
1483}
1484
1485fn notification_permission_payload(result: &PermissionResult) -> Option<Value> {
1490 match result {
1491 PermissionResult::NoResult => None,
1492 PermissionResult::Decision(decision) => Some(
1493 serde_json::to_value(decision).expect("serializing permission decision should succeed"),
1494 ),
1495 }
1496}
1497
1498async fn register_mcp_auth_interest(client: &Client, session_id: &SessionId) -> Result<(), Error> {
1499 let mut params = serde_json::to_value(RegisterEventInterestParams {
1500 event_type: "mcp.oauth_required".to_string(),
1501 })?;
1502 params["sessionId"] = Value::String(session_id.to_string());
1503 client
1504 .call(rpc_methods::SESSION_EVENTLOG_REGISTERINTEREST, Some(params))
1505 .await?;
1506 Ok(())
1507}
1508
1509fn tool_failure_result(message: impl Into<String>) -> ToolResult {
1510 let message = message.into();
1511 ToolResult::Expanded(ToolResultExpanded {
1512 text_result_for_llm: message.clone(),
1513 result_type: "failure".to_string(),
1514 binary_results_for_llm: None,
1515 session_log: None,
1516 error: Some(message),
1517 tool_telemetry: None,
1518 })
1519}
1520
1521#[allow(clippy::too_many_arguments)]
1523async fn handle_notification(
1524 session_id: &SessionId,
1525 client: &Client,
1526 handlers: &SessionHandlers,
1527 command_handlers: &Arc<CommandHandlerMap>,
1528 notification: SessionEventNotification,
1529 idle_waiter: &Arc<ParkingLotMutex<Option<IdleWaiter>>>,
1530 capabilities: &Arc<parking_lot::RwLock<SessionCapabilities>>,
1531 open_canvases: &Arc<parking_lot::RwLock<Vec<OpenCanvasInstance>>>,
1532 event_tx: &tokio::sync::broadcast::Sender<SessionEvent>,
1533) {
1534 let dispatch_start = Instant::now();
1535 let event = notification.event.clone();
1536 let event_type = event.parsed_type();
1537 if event_type == SessionEventType::PermissionRequested {
1538 tracing::debug!(
1539 session_id = %session_id,
1540 event_type = %event.event_type,
1541 "Session::handle_notification permission request received"
1542 );
1543 }
1544
1545 match event_type {
1548 SessionEventType::AssistantMessage
1549 | SessionEventType::SessionIdle
1550 | SessionEventType::SessionError => {
1551 let mut guard = idle_waiter.lock();
1552 if let Some(waiter) = guard.as_mut() {
1553 match event_type {
1554 SessionEventType::AssistantMessage => {
1555 if !waiter.first_assistant_message_seen {
1556 waiter.first_assistant_message_seen = true;
1557 tracing::debug!(
1558 elapsed_ms = waiter.started_at.elapsed().as_millis(),
1559 session_id = %session_id,
1560 "Session::send_and_wait first assistant message"
1561 );
1562 }
1563 waiter.last_assistant_message = Some(event.clone());
1564 }
1565 SessionEventType::SessionIdle | SessionEventType::SessionError => {
1566 if let Some(waiter) = guard.take() {
1567 if event_type == SessionEventType::SessionIdle {
1568 tracing::debug!(
1569 elapsed_ms = waiter.started_at.elapsed().as_millis(),
1570 session_id = %session_id,
1571 "Session::send_and_wait idle received"
1572 );
1573 let _ = waiter.tx.send(Ok(waiter.last_assistant_message));
1574 } else {
1575 let error_msg = event
1576 .typed_data::<SessionErrorData>()
1577 .map(|d| d.message)
1578 .or_else(|| {
1579 event
1580 .data
1581 .get("message")
1582 .and_then(|v| v.as_str())
1583 .map(|s| s.to_string())
1584 })
1585 .unwrap_or_else(|| "session error".to_string());
1586 let _ = waiter.tx.send(Err(Error::with_message(
1587 ErrorKind::Session(SessionErrorKind::AgentError),
1588 error_msg,
1589 )));
1590 }
1591 }
1592 }
1593 _ => {}
1594 }
1595 }
1596 }
1597 _ => {}
1598 }
1599
1600 if event_type == SessionEventType::CapabilitiesChanged {
1604 match serde_json::from_value::<SessionCapabilities>(notification.event.data.clone()) {
1605 Ok(changed) => *capabilities.write() = changed,
1606 Err(e) => warn!(error = %e, "failed to deserialize capabilities.changed payload"),
1607 }
1608 }
1609 if event_type == SessionEventType::SessionCanvasOpened {
1610 match serde_json::from_value::<OpenCanvasInstance>(notification.event.data.clone()) {
1611 Ok(open_canvas) => {
1612 upsert_open_canvas_snapshot(&mut open_canvases.write(), open_canvas);
1613 }
1614 Err(e) => warn!(error = %e, "failed to deserialize session.canvas.opened payload"),
1615 }
1616 }
1617 if event_type == SessionEventType::SessionCanvasClosed {
1618 match serde_json::from_value::<SessionCanvasClosedData>(notification.event.data.clone()) {
1619 Ok(closed) => {
1620 if closed.instance_id.is_empty() {
1621 warn!("failed to deserialize session.canvas.closed payload");
1622 } else {
1623 remove_open_canvas_snapshot(&mut open_canvases.write(), &closed.instance_id);
1624 }
1625 }
1626 Err(e) => warn!(error = %e, "failed to deserialize session.canvas.closed payload"),
1627 }
1628 }
1629
1630 let _ = event_tx.send(event.clone());
1634
1635 tracing::debug!(
1636 elapsed_ms = dispatch_start.elapsed().as_millis(),
1637 session_id = %session_id,
1638 event_type = %notification.event.event_type,
1639 "Session::handle_notification dispatch"
1640 );
1641
1642 match event_type {
1645 SessionEventType::PermissionRequested => {
1646 let Some(request_id) = extract_request_id(¬ification.event.data) else {
1647 return;
1648 };
1649 if notification
1653 .event
1654 .data
1655 .get("resolvedByHook")
1656 .and_then(|v| v.as_bool())
1657 .unwrap_or(false)
1658 {
1659 return;
1660 }
1661 let Some(permission_handler) = handlers.permission.clone() else {
1665 return;
1666 };
1667 let client = client.clone();
1668 let sid = session_id.clone();
1669 let data: PermissionRequestData =
1670 serde_json::from_value(notification.event.data.clone()).unwrap_or_else(|_| {
1671 PermissionRequestData {
1672 kind: None,
1673 tool_call_id: None,
1674 extra: notification.event.data.clone(),
1675 }
1676 });
1677 let span = tracing::error_span!(
1678 "permission_request_handler",
1679 session_id = %sid,
1680 request_id = %request_id
1681 );
1682 tokio::spawn(
1683 async move {
1684 let handler_start = Instant::now();
1685 let result = permission_handler
1686 .handle(sid.clone(), request_id.clone(), data)
1687 .await;
1688 tracing::debug!(
1689 elapsed_ms = handler_start.elapsed().as_millis(),
1690 session_id = %sid,
1691 request_id = %request_id,
1692 "PermissionHandler::handle dispatch"
1693 );
1694 let Some(result_value) = notification_permission_payload(&result) else {
1695 return;
1699 };
1700 let rpc_start = Instant::now();
1701 let _ = client
1702 .call(
1703 "session.permissions.handlePendingPermissionRequest",
1704 Some(serde_json::json!({
1705 "sessionId": sid,
1706 "requestId": request_id,
1707 "result": result_value,
1708 })),
1709 )
1710 .await;
1711 tracing::debug!(
1712 elapsed_ms = rpc_start.elapsed().as_millis(),
1713 session_id = %sid,
1714 request_id = %request_id,
1715 "Session::handle_notification response sent successfully"
1716 );
1717 }
1718 .instrument(span),
1719 );
1720 }
1721 SessionEventType::ExternalToolRequested => {
1722 let Some(request_id) = extract_request_id(¬ification.event.data) else {
1723 return;
1724 };
1725 let data: ExternalToolRequestedData =
1726 match serde_json::from_value(notification.event.data.clone()) {
1727 Ok(d) => d,
1728 Err(e) => {
1729 warn!(error = %e, "failed to deserialize external_tool.requested");
1730 let client = client.clone();
1731 let sid = session_id.clone();
1732 let span = tracing::error_span!(
1733 "external_tool_deserialize_error",
1734 session_id = %sid,
1735 request_id = %request_id
1736 );
1737 tokio::spawn(
1738 async move {
1739 let rpc_start = Instant::now();
1740 let _ = client
1741 .call(
1742 "session.tools.handlePendingToolCall",
1743 Some(serde_json::json!({
1744 "sessionId": sid,
1745 "requestId": request_id,
1746 "error": format!("Failed to deserialize tool request: {e}"),
1747 })),
1748 )
1749 .await;
1750 tracing::debug!(
1751 elapsed_ms = rpc_start.elapsed().as_millis(),
1752 session_id = %sid,
1753 request_id = %request_id,
1754 "Session::handle_notification response sent successfully"
1755 );
1756 }
1757 .instrument(span),
1758 );
1759 return;
1760 }
1761 };
1762 let tool_handler = if data.tool_name.is_empty() {
1766 None
1767 } else {
1768 handlers.tools.get(&data.tool_name).cloned()
1769 };
1770 let Some(tool_handler) = tool_handler else {
1771 return;
1772 };
1773 let client = client.clone();
1774 let sid = session_id.clone();
1775 let span = tracing::error_span!(
1776 "external_tool_handler",
1777 session_id = %sid,
1778 request_id = %request_id
1779 );
1780 tokio::spawn(
1781 async move {
1782 if data.tool_call_id.is_empty() {
1787 let error_msg = "Missing toolCallId";
1788 let rpc_start = Instant::now();
1789 let _ = client
1790 .call(
1791 "session.tools.handlePendingToolCall",
1792 Some(serde_json::json!({
1793 "sessionId": sid,
1794 "requestId": request_id,
1795 "error": error_msg,
1796 })),
1797 )
1798 .await;
1799 tracing::debug!(
1800 elapsed_ms = rpc_start.elapsed().as_millis(),
1801 session_id = %sid,
1802 request_id = %request_id,
1803 "Session::handle_notification response sent successfully"
1804 );
1805 return;
1806 }
1807 let tool_call_id = data.tool_call_id.clone();
1808 let tool_name = data.tool_name.clone();
1809 let invocation = ToolInvocation {
1810 session_id: sid.clone(),
1811 tool_call_id: data.tool_call_id,
1812 tool_name: data.tool_name,
1813 arguments: data
1814 .arguments
1815 .unwrap_or(Value::Object(serde_json::Map::new())),
1816 traceparent: data.traceparent,
1817 tracestate: data.tracestate,
1818 };
1819 let handler_start = Instant::now();
1820 let tool_result = match tool_handler.call(invocation).await {
1821 Ok(r) => r,
1822 Err(e) => tool_failure_result(e.to_string()),
1823 };
1824 tracing::debug!(
1825 elapsed_ms = handler_start.elapsed().as_millis(),
1826 session_id = %sid,
1827 request_id = %request_id,
1828 tool_call_id = %tool_call_id,
1829 tool_name = %tool_name,
1830 "ToolHandler::call dispatch"
1831 );
1832 let result_value = serde_json::to_value(tool_result).unwrap_or(Value::Null);
1833 let rpc_start = Instant::now();
1834 let _ = client
1835 .call(
1836 "session.tools.handlePendingToolCall",
1837 Some(serde_json::json!({
1838 "sessionId": sid,
1839 "requestId": request_id,
1840 "result": result_value,
1841 })),
1842 )
1843 .await;
1844 tracing::debug!(
1845 elapsed_ms = rpc_start.elapsed().as_millis(),
1846 session_id = %sid,
1847 request_id = %request_id,
1848 tool_call_id = %tool_call_id,
1849 tool_name = %tool_name,
1850 "Session::handle_notification response sent successfully"
1851 );
1852 }
1853 .instrument(span),
1854 );
1855 }
1856 SessionEventType::UserInputRequested => {
1857 }
1864 SessionEventType::ElicitationRequested => {
1865 let Some(request_id) = extract_request_id(¬ification.event.data) else {
1866 return;
1867 };
1868 let Some(elicitation_handler) = handlers.elicitation.clone() else {
1872 return;
1873 };
1874 let elicitation_data: ElicitationRequestedData =
1875 match serde_json::from_value(notification.event.data.clone()) {
1876 Ok(d) => d,
1877 Err(e) => {
1878 warn!(error = %e, "failed to deserialize elicitation request");
1879 return;
1880 }
1881 };
1882 let request = ElicitationRequest {
1883 message: elicitation_data.message,
1884 requested_schema: elicitation_data
1885 .requested_schema
1886 .map(|s| serde_json::to_value(s).unwrap_or(Value::Null)),
1887 mode: elicitation_data.mode.map(|m| match m {
1888 crate::generated::session_events::ElicitationRequestedMode::Form => {
1889 crate::types::ElicitationMode::Form
1890 }
1891 crate::generated::session_events::ElicitationRequestedMode::Url => {
1892 crate::types::ElicitationMode::Url
1893 }
1894 _ => crate::types::ElicitationMode::Unknown,
1895 }),
1896 elicitation_source: elicitation_data.elicitation_source,
1897 url: elicitation_data.url,
1898 };
1899 let client = client.clone();
1900 let sid = session_id.clone();
1901 let span = tracing::error_span!(
1902 "elicitation_request_handler",
1903 session_id = %sid,
1904 request_id = %request_id
1905 );
1906 tokio::spawn(
1907 async move {
1908 let cancel = ElicitationResult {
1909 action: "cancel".to_string(),
1910 content: None,
1911 };
1912 let handler_task = tokio::spawn({
1914 let sid = sid.clone();
1915 let request_id = request_id.clone();
1916 let span = tracing::error_span!(
1917 "elicitation_callback",
1918 session_id = %sid,
1919 request_id = %request_id
1920 );
1921 async move {
1922 let handler_start = Instant::now();
1923 let response = elicitation_handler
1924 .handle(sid.clone(), request_id.clone(), request)
1925 .await;
1926 tracing::debug!(
1927 elapsed_ms = handler_start.elapsed().as_millis(),
1928 session_id = %sid,
1929 request_id = %request_id,
1930 "ElicitationHandler::handle dispatch"
1931 );
1932 response
1933 }
1934 .instrument(span)
1935 });
1936 let result = match handler_task.await {
1937 Ok(r) => r,
1938 Err(_) => cancel.clone(),
1939 };
1940 let rpc_start = Instant::now();
1941 if let Err(e) = client
1942 .call(
1943 "session.ui.handlePendingElicitation",
1944 Some(serde_json::json!({
1945 "sessionId": sid,
1946 "requestId": request_id,
1947 "result": result,
1948 })),
1949 )
1950 .await
1951 {
1952 warn!(error = %e, "handlePendingElicitation failed, sending cancel");
1954 let _ = client
1955 .call(
1956 "session.ui.handlePendingElicitation",
1957 Some(serde_json::json!({
1958 "sessionId": sid,
1959 "requestId": request_id,
1960 "result": cancel,
1961 })),
1962 )
1963 .await;
1964 } else {
1965 tracing::debug!(
1966 elapsed_ms = rpc_start.elapsed().as_millis(),
1967 session_id = %sid,
1968 request_id = %request_id,
1969 "Session::handle_notification response sent successfully"
1970 );
1971 }
1972 }
1973 .instrument(span),
1974 );
1975 }
1976 SessionEventType::McpOauthRequired => {
1977 let Some(request_id) = extract_request_id(¬ification.event.data) else {
1978 return;
1979 };
1980 let Some(mcp_auth_handler) = handlers.mcp_auth.clone() else {
1981 warn!(
1982 session_id = %session_id,
1983 request_id = %request_id,
1984 "received MCP OAuth request without a registered MCP auth handler"
1985 );
1986 return;
1987 };
1988 let data: McpOauthRequiredData =
1989 match serde_json::from_value(notification.event.data.clone()) {
1990 Ok(d) => d,
1991 Err(e) => {
1992 warn!(error = %e, "failed to deserialize MCP OAuth request");
1993 return;
1994 }
1995 };
1996 let request = McpAuthRequest {
1997 request_id: request_id.clone(),
1998 server_name: data.server_name,
1999 server_url: data.server_url,
2000 reason: data.reason,
2001 www_authenticate_params: data.www_authenticate_params,
2002 resource_metadata: data.resource_metadata,
2003 static_client_config: data.static_client_config,
2004 };
2005 let client = client.clone();
2006 let sid = session_id.clone();
2007 let span = tracing::error_span!(
2008 "mcp_auth_request_handler",
2009 session_id = %sid,
2010 request_id = %request_id
2011 );
2012 tokio::spawn(
2013 async move {
2014 let cancel = McpAuthResult::Cancelled;
2015 let handler_task = tokio::spawn({
2016 let sid = sid.clone();
2017 let request_id = request_id.clone();
2018 let span = tracing::error_span!(
2019 "mcp_auth_callback",
2020 session_id = %sid,
2021 request_id = %request_id
2022 );
2023 async move {
2024 let handler_start = Instant::now();
2025 let response = mcp_auth_handler
2026 .handle(sid.clone(), request_id.clone(), request)
2027 .await;
2028 tracing::debug!(
2029 elapsed_ms = handler_start.elapsed().as_millis(),
2030 session_id = %sid,
2031 request_id = %request_id,
2032 "McpAuthHandler::handle dispatch"
2033 );
2034 response
2035 }
2036 .instrument(span)
2037 });
2038 let result = match handler_task.await {
2039 Ok(result) => result,
2040 Err(_) => cancel,
2041 };
2042 let rpc_start = Instant::now();
2043 let _ = client
2044 .call(
2045 "session.mcp.oauth.handlePendingRequest",
2046 Some(serde_json::json!({
2047 "sessionId": sid,
2048 "requestId": request_id,
2049 "result": result.into_wire(),
2050 })),
2051 )
2052 .await;
2053 tracing::debug!(
2054 elapsed_ms = rpc_start.elapsed().as_millis(),
2055 "Session::handle_notification MCP auth response sent"
2056 );
2057 }
2058 .instrument(span),
2059 );
2060 }
2061 SessionEventType::CommandExecute => {
2062 let data: CommandExecuteData =
2063 match serde_json::from_value(notification.event.data.clone()) {
2064 Ok(d) => d,
2065 Err(e) => {
2066 warn!(error = %e, "failed to deserialize command.execute");
2067 return;
2068 }
2069 };
2070 let client = client.clone();
2071 let command_handlers = command_handlers.clone();
2072 let sid = session_id.clone();
2073 let span = tracing::error_span!("command_handler", session_id = %sid);
2074 tokio::spawn(
2075 async move {
2076 let request_id = data.request_id;
2077 let ack_error = match command_handlers.get(&data.command_name).cloned() {
2078 None => Some(format!("Unknown command: {}", data.command_name)),
2079 Some(handler) => {
2080 let command_name = data.command_name.clone();
2081 let ctx = CommandContext {
2082 session_id: sid.clone(),
2083 command: data.command,
2084 command_name: data.command_name,
2085 args: data.args,
2086 };
2087 let handler_start = Instant::now();
2088 let result = handler.on_command(ctx).await;
2089 tracing::debug!(
2090 elapsed_ms = handler_start.elapsed().as_millis(),
2091 session_id = %sid,
2092 request_id = %request_id,
2093 command_name = %command_name,
2094 "CommandHandler::call dispatch"
2095 );
2096 match result {
2097 Ok(()) => None,
2098 Err(e) => Some(e.to_string()),
2099 }
2100 }
2101 };
2102 let mut params = serde_json::json!({
2103 "sessionId": sid,
2104 "requestId": request_id,
2105 });
2106 if let Some(error_msg) = ack_error {
2107 params["error"] = serde_json::Value::String(error_msg);
2108 }
2109 let rpc_start = Instant::now();
2110 let _ = client
2111 .call("session.commands.handlePendingCommand", Some(params))
2112 .await;
2113 tracing::debug!(
2114 elapsed_ms = rpc_start.elapsed().as_millis(),
2115 session_id = %sid,
2116 request_id = %request_id,
2117 "Session::handle_notification response sent successfully"
2118 );
2119 }
2120 .instrument(span),
2121 );
2122 }
2123 _ => {}
2124 }
2125}
2126
2127struct RequestDispatchContext<'a> {
2128 client: &'a Client,
2129 handlers: &'a SessionHandlers,
2130 hooks: Option<&'a dyn SessionHooks>,
2131 transforms: Option<&'a dyn SystemMessageTransform>,
2132 canvas_handler: Option<&'a Arc<dyn CanvasHandler>>,
2133 session_fs_provider: Option<&'a Arc<dyn SessionFsProvider>>,
2134 bearer_token_providers: &'a HashMap<String, Arc<dyn BearerTokenProvider>>,
2135}
2136
2137async fn handle_request(
2139 session_id: &SessionId,
2140 ctx: RequestDispatchContext<'_>,
2141 request: crate::JsonRpcRequest,
2142) {
2143 let sid = session_id.clone();
2144 let client = ctx.client;
2145 let handlers = ctx.handlers;
2146 let hooks = ctx.hooks;
2147 let transforms = ctx.transforms;
2148 let canvas_handler = ctx.canvas_handler;
2149 let session_fs_provider = ctx.session_fs_provider;
2150 let bearer_token_providers = ctx.bearer_token_providers;
2151
2152 if request.method.starts_with("sessionFs.") {
2153 crate::session_fs_dispatch::dispatch(client, session_fs_provider, request).await;
2154 return;
2155 }
2156
2157 if request.method.starts_with("canvas.") {
2158 crate::canvas_dispatch::dispatch(client, canvas_handler, request).await;
2159 return;
2160 }
2161
2162 if request.method == crate::generated::api_types::rpc_methods::PROVIDERTOKEN_GETTOKEN {
2163 crate::provider_token_dispatch::dispatch(client, bearer_token_providers, request).await;
2164 return;
2165 }
2166
2167 match request.method.as_str() {
2168 "hooks.invoke" => {
2169 let params = request.params.as_ref();
2170 let hook_type = params
2171 .and_then(|p| p.get("hookType"))
2172 .and_then(|v| v.as_str())
2173 .unwrap_or("");
2174 let input = params
2175 .and_then(|p| p.get("input"))
2176 .cloned()
2177 .unwrap_or(Value::Object(Default::default()));
2178
2179 let rpc_result = if let Some(hooks) = hooks {
2180 match crate::hooks::dispatch_hook(hooks, &sid, hook_type, input).await {
2181 Ok(output) => output,
2182 Err(e) => {
2183 warn!(error = %e, hook_type = hook_type, "hook dispatch failed");
2184 serde_json::json!({ "output": {} })
2185 }
2186 }
2187 } else {
2188 serde_json::json!({ "output": {} })
2189 };
2190
2191 let rpc_response = JsonRpcResponse {
2192 jsonrpc: "2.0".to_string(),
2193 id: request.id,
2194 result: Some(rpc_result),
2195 error: None,
2196 };
2197 let _ = client.send_response(&rpc_response).await;
2198 }
2199
2200 "userInput.request" => {
2201 let params = request.params.as_ref();
2202 let Some(question) = params
2203 .and_then(|p| p.get("question"))
2204 .and_then(|v| v.as_str())
2205 else {
2206 warn!("userInput.request missing 'question' field");
2207 let rpc_response = JsonRpcResponse {
2208 jsonrpc: "2.0".to_string(),
2209 id: request.id,
2210 result: None,
2211 error: Some(crate::JsonRpcError {
2212 code: error_codes::INVALID_PARAMS,
2213 message: "missing required field: question".to_string(),
2214 data: None,
2215 }),
2216 };
2217 let _ = client.send_response(&rpc_response).await;
2218 return;
2219 };
2220 let question = question.to_string();
2221 let choices = params
2222 .and_then(|p| p.get("choices"))
2223 .and_then(|v| v.as_array())
2224 .map(|arr| {
2225 arr.iter()
2226 .filter_map(|v| v.as_str().map(|s| s.to_string()))
2227 .collect()
2228 });
2229 let allow_freeform = params
2230 .and_then(|p| p.get("allowFreeform"))
2231 .and_then(|v| v.as_bool());
2232
2233 let handler_start = Instant::now();
2234 let response = if let Some(user_input_handler) = handlers.user_input.as_ref() {
2235 user_input_handler
2236 .handle(sid.clone(), question, choices, allow_freeform)
2237 .await
2238 } else {
2239 None
2240 };
2241 tracing::debug!(
2242 elapsed_ms = handler_start.elapsed().as_millis(),
2243 session_id = %sid,
2244 "UserInputHandler::handle dispatch"
2245 );
2246
2247 let rpc_result = match response {
2248 Some(UserInputResponse {
2249 answer,
2250 was_freeform,
2251 }) => serde_json::json!({
2252 "answer": answer,
2253 "wasFreeform": was_freeform,
2254 }),
2255 None => serde_json::json!({ "noResponse": true }),
2256 };
2257 let rpc_response = JsonRpcResponse {
2258 jsonrpc: "2.0".to_string(),
2259 id: request.id,
2260 result: Some(rpc_result),
2261 error: None,
2262 };
2263 let _ = client.send_response(&rpc_response).await;
2264 }
2265
2266 "exitPlanMode.request" => {
2267 let params = request
2268 .params
2269 .as_ref()
2270 .cloned()
2271 .unwrap_or(Value::Object(serde_json::Map::new()));
2272 let data: ExitPlanModeData = match serde_json::from_value(params) {
2273 Ok(d) => d,
2274 Err(e) => {
2275 warn!(error = %e, "failed to deserialize exitPlanMode.request params, using defaults");
2276 ExitPlanModeData::default()
2277 }
2278 };
2279
2280 let rpc_result = if let Some(exit_plan_handler) = handlers.exit_plan_mode.as_ref() {
2281 let result = exit_plan_handler.handle(sid, data).await;
2282 serde_json::to_value(result).expect("ExitPlanModeResult serialization cannot fail")
2283 } else {
2284 serde_json::json!({ "approved": true })
2285 };
2286 let rpc_response = JsonRpcResponse {
2287 jsonrpc: "2.0".to_string(),
2288 id: request.id,
2289 result: Some(rpc_result),
2290 error: None,
2291 };
2292 let _ = client.send_response(&rpc_response).await;
2293 }
2294
2295 "autoModeSwitch.request" => {
2296 let error_code = request
2297 .params
2298 .as_ref()
2299 .and_then(|p| p.get("errorCode"))
2300 .and_then(|v| v.as_str())
2301 .map(|s| s.to_string());
2302 let retry_after_seconds = request
2303 .params
2304 .as_ref()
2305 .and_then(|p| p.get("retryAfterSeconds"))
2306 .and_then(|v| v.as_f64());
2307
2308 let answer = if let Some(auto_mode_handler) = handlers.auto_mode_switch.as_ref() {
2309 auto_mode_handler
2310 .handle(sid, error_code, retry_after_seconds)
2311 .await
2312 } else {
2313 AutoModeSwitchResponse::No
2314 };
2315 let rpc_response = JsonRpcResponse {
2316 jsonrpc: "2.0".to_string(),
2317 id: request.id,
2318 result: Some(serde_json::json!({ "response": answer })),
2319 error: None,
2320 };
2321 let _ = client.send_response(&rpc_response).await;
2322 }
2323
2324 "systemMessage.transform" => {
2325 let params = request.params.as_ref();
2326 let sections: HashMap<String, crate::transforms::TransformSection> =
2327 match params.and_then(|p| p.get("sections")) {
2328 Some(v) => match serde_json::from_value(v.clone()) {
2329 Ok(s) => s,
2330 Err(e) => {
2331 let _ = send_error_response(
2332 client,
2333 request.id,
2334 error_codes::INVALID_PARAMS,
2335 &format!("invalid sections: {e}"),
2336 )
2337 .await;
2338 return;
2339 }
2340 },
2341 None => {
2342 let _ = send_error_response(
2343 client,
2344 request.id,
2345 error_codes::INVALID_PARAMS,
2346 "missing sections parameter",
2347 )
2348 .await;
2349 return;
2350 }
2351 };
2352
2353 let rpc_result = if let Some(transforms) = transforms {
2354 let transform_start = Instant::now();
2355 let response =
2356 crate::transforms::dispatch_transform(transforms, &sid, sections).await;
2357 tracing::debug!(
2358 elapsed_ms = transform_start.elapsed().as_millis(),
2359 session_id = %sid,
2360 "SystemMessageTransform::transform_section dispatch"
2361 );
2362 match serde_json::to_value(response) {
2363 Ok(v) => v,
2364 Err(e) => {
2365 warn!(error = %e, "failed to serialize transform response");
2366 serde_json::json!({ "sections": {} })
2367 }
2368 }
2369 } else {
2370 let passthrough: HashMap<String, crate::transforms::TransformSection> = sections;
2372 serde_json::json!({ "sections": passthrough })
2373 };
2374
2375 let rpc_response = JsonRpcResponse {
2376 jsonrpc: "2.0".to_string(),
2377 id: request.id,
2378 result: Some(rpc_result),
2379 error: None,
2380 };
2381 let _ = client.send_response(&rpc_response).await;
2382 }
2383
2384 method => {
2385 warn!(
2386 method = method,
2387 "unhandled request method in session event loop"
2388 );
2389 let _ = send_error_response(
2390 client,
2391 request.id,
2392 error_codes::METHOD_NOT_FOUND,
2393 &format!("unknown method: {method}"),
2394 )
2395 .await;
2396 }
2397 }
2398}
2399
2400async fn send_error_response(
2401 client: &Client,
2402 id: u64,
2403 code: i32,
2404 message: &str,
2405) -> Result<(), Error> {
2406 let response = JsonRpcResponse {
2407 jsonrpc: "2.0".to_string(),
2408 id,
2409 result: None,
2410 error: Some(crate::JsonRpcError {
2411 code,
2412 message: message.to_string(),
2413 data: None,
2414 }),
2415 };
2416 client.send_response(&response).await
2417}
2418
2419fn apply_transform_sections(
2423 sys_msg: &mut SystemMessageConfig,
2424 transforms: &dyn SystemMessageTransform,
2425) {
2426 sys_msg.mode = Some("customize".to_string());
2427 let sections = sys_msg.sections.get_or_insert_with(HashMap::new);
2428 for id in transforms.section_ids() {
2429 sections.entry(id).or_insert_with(|| SectionOverride {
2430 action: Some("transform".to_string()),
2431 content: None,
2432 });
2433 }
2434}
2435
2436fn inject_transform_sections(config: &mut SessionConfig, transforms: &dyn SystemMessageTransform) {
2437 let sys_msg = config.system_message.get_or_insert_with(Default::default);
2438 apply_transform_sections(sys_msg, transforms);
2439}
2440
2441fn inject_transform_sections_resume(
2442 config: &mut ResumeSessionConfig,
2443 transforms: &dyn SystemMessageTransform,
2444) {
2445 let sys_msg = config.system_message.get_or_insert_with(Default::default);
2446 apply_transform_sections(sys_msg, transforms);
2447}
2448
#[cfg(test)]
mod tests {
    use serde_json::{Value, json};

    use super::notification_permission_payload;
    use crate::handler::PermissionResult;

    #[test]
    fn notification_payload_suppresses_no_result() {
        let payload = notification_permission_payload(&PermissionResult::NoResult);
        assert!(payload.is_none());
    }

    #[test]
    fn notification_payload_serializes_decisions() {
        // Each decision paired with the exact wire payload it must produce.
        let cases: [(PermissionResult, Value); 4] = [
            (
                PermissionResult::approve_once(),
                json!({ "kind": "approve-once" }),
            ),
            (PermissionResult::reject(None), json!({ "kind": "reject" })),
            (
                PermissionResult::reject(Some("bad".to_string())),
                json!({ "kind": "reject", "feedback": "bad" }),
            ),
            (
                PermissionResult::user_not_available(),
                json!({ "kind": "user-not-available" }),
            ),
        ];
        for (decision, expected) in &cases {
            assert_eq!(
                notification_permission_payload(decision),
                Some(expected.clone())
            );
        }
    }
}