1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use parking_lot::Mutex as ParkingLotMutex;
7use serde_json::Value;
8use tokio::sync::oneshot;
9use tokio::task::JoinHandle;
10use tokio_util::sync::CancellationToken;
11use tracing::{Instrument, warn};
12
13use crate::canvas::CanvasHandler;
14use crate::generated::api_types::{
15 LogRequest, ModelSwitchToRequest, OpenCanvasInstance, RegisterEventInterestParams, rpc_methods,
16};
17use crate::generated::session_events::{
18 CommandExecuteData, ElicitationRequestedData, ExternalToolRequestedData, McpOauthRequiredData,
19 SessionCanvasClosedData, SessionErrorData, SessionEventType,
20};
21use crate::handler::{
22 AutoModeSwitchHandler, AutoModeSwitchResponse, ElicitationHandler, ExitPlanModeHandler,
23 McpAuthHandler, McpAuthRequest, McpAuthResult, PermissionHandler, PermissionResult,
24 UserInputHandler, UserInputResponse,
25};
26use crate::hooks::SessionHooks;
27use crate::provider_token::BearerTokenProvider;
28use crate::session_fs::SessionFsProvider;
29use crate::trace_context::inject_trace_context;
30use crate::transforms::SystemMessageTransform;
31use crate::types::{
32 CommandContext, CommandDefinition, CommandHandler, CreateSessionResult, ElicitationRequest,
33 ElicitationResult, ExitPlanModeData, GetMessagesResponse, MessageOptions,
34 PermissionRequestData, RequestId, ResumeSessionConfig, ResumeSessionResult, SectionOverride,
35 SessionCapabilities, SessionConfig, SessionEvent, SessionId, SetModelOptions,
36 SystemMessageConfig, ToolInvocation, ToolResult, ToolResultExpanded, TraceContext,
37 UiInputOptions, ensure_attachment_display_names,
38};
39use crate::{
40 Client, Error, ErrorKind, JsonRpcResponse, SessionErrorKind, SessionEventNotification,
41 error_codes,
42};
43
/// Callback handlers attached to one session and handed to its event loop.
///
/// Assembled from the runtime half of the session config in `Client::create_session` and
/// `Client::resume_session`. Every handler is optional; `tools` is always present but may be
/// empty.
#[derive(Clone)]
pub(crate) struct SessionHandlers {
    /// Permission handler produced by `crate::permission::resolve_handler` from the configured
    /// handler and permission policy.
    pub permission: Option<Arc<dyn PermissionHandler>>,
    /// Handler for elicitation requests, if the caller supplied one.
    pub elicitation: Option<Arc<dyn ElicitationHandler>>,
    /// Handler for MCP OAuth requests. When present, session setup also registers event
    /// interest for MCP auth via `register_mcp_auth_interest`.
    pub mcp_auth: Option<Arc<dyn McpAuthHandler>>,
    /// Handler for user-input requests, if the caller supplied one.
    pub user_input: Option<Arc<dyn UserInputHandler>>,
    /// Handler for exit-plan-mode requests, if the caller supplied one.
    pub exit_plan_mode: Option<Arc<dyn ExitPlanModeHandler>>,
    /// Handler for auto-mode-switch requests, if the caller supplied one.
    pub auto_mode_switch: Option<Arc<dyn AutoModeSwitchHandler>>,
    /// Tool handlers keyed by `String` (presumably the tool name — confirm against
    /// `crate::tool`). Held behind an `Arc` so cloning the handler set does not copy the map.
    pub tools: Arc<HashMap<String, Arc<dyn crate::tool::ToolHandler>>>,
}
61
/// State for a single in-flight `Session::send_and_wait` call.
///
/// Stored in the session's `idle_waiter` slot. At most one waiter exists at a time:
/// `send_and_wait` and `send` reject new work with `SendWhileWaiting` while the slot is
/// occupied.
struct IdleWaiter {
    /// Delivers the final outcome back to the waiting `send_and_wait` call. `stop_event_loop`
    /// uses it to fail the call with `EventLoopClosed`.
    tx: oneshot::Sender<Result<Option<SessionEvent>, Error>>,
    /// Starts as `None`. Presumably updated by the event loop with the latest assistant message
    /// event and returned on idle — confirm against the event loop.
    last_assistant_message: Option<SessionEvent>,
    /// Instant at which `send_and_wait` started, used for timing.
    started_at: Instant,
    /// Starts as `false`. Presumably flipped by the event loop when the first assistant message
    /// arrives (e.g. for latency logging) — confirm against the event loop.
    first_assistant_message_seen: bool,
}
69
70struct WaiterGuard {
82 slot: Arc<ParkingLotMutex<Option<IdleWaiter>>>,
83}
84
85impl Drop for WaiterGuard {
86 fn drop(&mut self) {
87 self.slot.lock().take();
88 }
89}
90
91struct PendingSessionRegistration {
92 client: Client,
93 session_id: SessionId,
94 shutdown: CancellationToken,
95 disarmed: bool,
96}
97
98impl PendingSessionRegistration {
99 fn new(client: Client, session_id: SessionId, shutdown: CancellationToken) -> Self {
100 Self {
101 client,
102 session_id,
103 shutdown,
104 disarmed: false,
105 }
106 }
107
108 async fn cleanup(mut self, event_loop: JoinHandle<()>) {
109 self.shutdown.cancel();
110 let _ = event_loop.await;
111 self.client.unregister_session(&self.session_id);
112 self.disarmed = true;
113 }
114
115 fn disarm(&mut self) {
116 self.disarmed = true;
117 }
118}
119
120impl Drop for PendingSessionRegistration {
121 fn drop(&mut self) {
122 if !self.disarmed {
123 self.shutdown.cancel();
124 self.client.unregister_session(&self.session_id);
125 }
126 }
127}
128
/// A live conversation session created by `Client::create_session` or resumed by
/// `Client::resume_session`.
///
/// Owns the handle of the session's background event loop and the shared state that loop
/// updates: the idle waiter slot, capabilities, open canvases, and the event broadcast
/// channel. Dropping a `Session` cancels its shutdown token and unregisters it from the client.
pub struct Session {
    /// Session identifier as confirmed by the server.
    id: SessionId,
    /// Working directory of the owning client, captured at creation/resume time.
    cwd: PathBuf,
    /// Workspace path reported by the server in the create/resume result, if any.
    workspace_path: Option<PathBuf>,
    /// Remote URL reported by the server in the create/resume result, if any.
    remote_url: Option<String>,
    /// Client used for all RPC calls made on behalf of this session.
    client: Client,
    /// Event-loop task handle. Taken (and awaited) by `stop_event_loop`, so it becomes `None`
    /// after the first stop.
    event_loop: ParkingLotMutex<Option<JoinHandle<()>>>,
    /// Token shared with the event loop. Cancelled by `stop_event_loop` and by `Drop`.
    shutdown: CancellationToken,
    /// Slot holding the single in-flight `send_and_wait` waiter, shared with the event loop.
    idle_waiter: Arc<ParkingLotMutex<Option<IdleWaiter>>>,
    /// Capabilities reported by the server. Shared with the event loop, which presumably
    /// refreshes them — confirm against the event loop.
    capabilities: Arc<parking_lot::RwLock<SessionCapabilities>>,
    /// Snapshot of open canvas instances. Seeded from the resume result and shared with the
    /// event loop.
    open_canvases: Arc<parking_lot::RwLock<Vec<OpenCanvasInstance>>>,
    /// Broadcast sender for session events; `subscribe` hands out receivers.
    event_tx: tokio::sync::broadcast::Sender<SessionEvent>,
}
180
181impl Session {
182 pub fn id(&self) -> &SessionId {
184 &self.id
185 }
186
187 pub fn cwd(&self) -> &PathBuf {
189 &self.cwd
190 }
191
192 pub fn workspace_path(&self) -> Option<&Path> {
194 self.workspace_path.as_deref()
195 }
196
197 pub fn remote_url(&self) -> Option<&str> {
199 self.remote_url.as_deref()
200 }
201
202 pub fn capabilities(&self) -> SessionCapabilities {
207 self.capabilities.read().clone()
208 }
209
210 pub fn open_canvases(&self) -> Vec<OpenCanvasInstance> {
213 self.open_canvases.read().clone()
214 }
215
216 pub fn cancellation_token(&self) -> CancellationToken {
243 self.shutdown.child_token()
244 }
245
246 pub fn subscribe(&self) -> crate::subscription::EventSubscription {
285 crate::subscription::EventSubscription::new(self.event_tx.subscribe())
286 }
287
288 pub fn client(&self) -> &Client {
290 &self.client
291 }
292
293 pub fn rpc(&self) -> crate::generated::rpc::SessionRpc<'_> {
304 crate::generated::rpc::SessionRpc { session: self }
305 }
306
307 pub async fn stop_event_loop(&self) {
315 self.shutdown.cancel();
316 let handle = self.event_loop.lock().take();
317 if let Some(handle) = handle {
318 let _ = handle.await;
319 }
320 if let Some(waiter) = self.idle_waiter.lock().take() {
322 let _ = waiter.tx.send(Err(
323 ErrorKind::Session(SessionErrorKind::EventLoopClosed).into()
324 ));
325 }
326 }
327
328 pub async fn send(&self, opts: impl Into<MessageOptions>) -> Result<String, Error> {
353 if self.idle_waiter.lock().is_some() {
354 return Err(ErrorKind::Session(SessionErrorKind::SendWhileWaiting).into());
355 }
356 self.send_inner(opts.into()).await
357 }
358
359 async fn send_inner(&self, opts: MessageOptions) -> Result<String, Error> {
360 let mut params = serde_json::json!({
361 "sessionId": self.id,
362 "prompt": opts.prompt,
363 });
364 if let Some(m) = opts.mode {
365 params["mode"] = serde_json::to_value(m)?;
366 }
367 if let Some(am) = opts.agent_mode {
368 params["agentMode"] = serde_json::to_value(am)?;
369 }
370 if let Some(mut a) = opts.attachments {
371 ensure_attachment_display_names(&mut a);
372 params["attachments"] = serde_json::to_value(a)?;
373 }
374 if let Some(headers) = opts.request_headers
375 && !headers.is_empty()
376 {
377 params["requestHeaders"] = serde_json::to_value(headers)?;
378 }
379 if let Some(display_prompt) = opts.display_prompt {
380 params["displayPrompt"] = serde_json::to_value(display_prompt)?;
381 }
382 let trace_ctx = if opts.traceparent.is_some() || opts.tracestate.is_some() {
383 TraceContext {
384 traceparent: opts.traceparent,
385 tracestate: opts.tracestate,
386 }
387 } else {
388 self.client.resolve_trace_context().await
389 };
390 inject_trace_context(&mut params, &trace_ctx);
391 let rpc_start = Instant::now();
392 let result = self.client.call("session.send", Some(params)).await?;
393 let message_id = result
394 .get("messageId")
395 .and_then(|v| v.as_str())
396 .map(|s| s.to_string())
397 .unwrap_or_default();
398 tracing::debug!(
399 elapsed_ms = rpc_start.elapsed().as_millis(),
400 session_id = %self.id,
401 message_id = %message_id,
402 "Session::send completed successfully"
403 );
404 Ok(message_id)
405 }
406
407 pub async fn send_and_wait(
427 &self,
428 opts: impl Into<MessageOptions>,
429 ) -> Result<Option<SessionEvent>, Error> {
430 let total_start = Instant::now();
431 let opts = opts.into();
432 let timeout_duration = opts.wait_timeout.unwrap_or(Duration::from_secs(60));
433 let (tx, rx) = oneshot::channel();
434
435 {
436 let mut guard = self.idle_waiter.lock();
437 if guard.is_some() {
438 return Err(ErrorKind::Session(SessionErrorKind::SendWhileWaiting).into());
439 }
440 *guard = Some(IdleWaiter {
441 tx,
442 last_assistant_message: None,
443 started_at: total_start,
444 first_assistant_message_seen: false,
445 });
446 }
447
448 let _waiter_guard = WaiterGuard {
453 slot: self.idle_waiter.clone(),
454 };
455
456 let result = tokio::time::timeout(timeout_duration, async {
457 self.send_inner(opts).await?;
458 match rx.await {
459 Ok(result) => result,
460 Err(_) => Err(ErrorKind::Session(SessionErrorKind::EventLoopClosed).into()),
461 }
462 })
463 .await;
464
465 match result {
466 Ok(inner) => {
467 tracing::debug!(
468 elapsed_ms = total_start.elapsed().as_millis(),
469 session_id = %self.id,
470 completed_by = if inner.is_ok() { "idle" } else { "error" },
471 "Session::send_and_wait complete"
472 );
473 inner
474 }
475 Err(_) => {
476 tracing::warn!(
477 elapsed_ms = total_start.elapsed().as_millis(),
478 session_id = %self.id,
479 completed_by = "timeout",
480 "Session::send_and_wait failed"
481 );
482 Err(ErrorKind::Session(SessionErrorKind::Timeout(timeout_duration)).into())
483 }
484 }
485 }
486
487 pub async fn get_events(&self) -> Result<Vec<SessionEvent>, Error> {
489 let result = self
490 .client
491 .call(
492 "session.getMessages",
493 Some(serde_json::json!({ "sessionId": self.id })),
494 )
495 .await?;
496 let response: GetMessagesResponse = serde_json::from_value(result)?;
497 Ok(response.events)
498 }
499
500 #[deprecated(since = "0.1.0", note = "Use `get_events()` instead")]
502 pub async fn get_messages(&self) -> Result<Vec<SessionEvent>, Error> {
503 self.get_events().await
504 }
505
506 pub async fn abort(&self) -> Result<(), Error> {
514 self.client
515 .call(
516 "session.abort",
517 Some(serde_json::json!({ "sessionId": self.id })),
518 )
519 .await?;
520 Ok(())
521 }
522
523 pub async fn set_model(&self, model: &str, opts: Option<SetModelOptions>) -> Result<(), Error> {
527 let opts = opts.unwrap_or_default();
528 let request = ModelSwitchToRequest {
529 model_id: model.to_string(),
530 reasoning_effort: opts.reasoning_effort,
531 reasoning_summary: opts.reasoning_summary,
532 context_tier: opts.context_tier,
533 model_capabilities: opts.model_capabilities,
534 };
535 self.rpc().model().switch_to(request).await?;
536 Ok(())
537 }
538
539 pub async fn disconnect(&self) -> Result<(), Error> {
556 self.client
557 .call(
558 "session.destroy",
559 Some(serde_json::json!({ "sessionId": self.id })),
560 )
561 .await?;
562 self.stop_event_loop().await;
563 self.client.unregister_session(&self.id);
564 Ok(())
565 }
566
567 #[deprecated(since = "0.1.0", note = "Use `disconnect()` instead")]
572 pub async fn destroy(&self) -> Result<(), Error> {
573 self.disconnect().await
574 }
575
576 pub async fn log(
580 &self,
581 message: &str,
582 opts: Option<crate::types::LogOptions>,
583 ) -> Result<(), Error> {
584 let opts = opts.unwrap_or_default();
585 let level = match opts.level {
586 Some(level) => Some(serde_json::from_value(serde_json::to_value(level)?)?),
587 None => None,
588 };
589 let request = LogRequest {
590 message: message.to_string(),
591 level,
592 ephemeral: opts.ephemeral,
593 r#type: None,
594 tip: None,
595 url: None,
596 };
597 self.rpc().log(request).await?;
598 Ok(())
599 }
600
601 pub fn ui(&self) -> SessionUi<'_> {
607 SessionUi { session: self }
608 }
609
610 fn assert_elicitation(&self) -> Result<(), Error> {
612 if self
613 .capabilities
614 .read()
615 .ui
616 .as_ref()
617 .and_then(|u| u.elicitation)
618 != Some(true)
619 {
620 return Err(ErrorKind::Session(SessionErrorKind::ElicitationNotSupported).into());
621 }
622 Ok(())
623 }
624}
625
626impl Drop for Session {
627 fn drop(&mut self) {
628 self.shutdown.cancel();
640 self.client.unregister_session(&self.id);
641 }
642}
643
/// Borrowed handle for UI interactions on a [`Session`], obtained via `Session::ui`.
///
/// Every method requires the session to advertise elicitation support.
pub struct SessionUi<'a> {
    /// Session the UI requests are issued against.
    session: &'a Session,
}
653
654impl<'a> SessionUi<'a> {
655 pub async fn elicitation(
663 &self,
664 message: &str,
665 schema: Value,
666 ) -> Result<ElicitationResult, Error> {
667 self.session.assert_elicitation()?;
668 let result = self
669 .session
670 .client
671 .call(
672 "session.ui.elicitation",
673 Some(serde_json::json!({
674 "sessionId": self.session.id,
675 "message": message,
676 "requestedSchema": schema,
677 })),
678 )
679 .await?;
680 let elicitation: ElicitationResult = serde_json::from_value(result)?;
681 Ok(elicitation)
682 }
683
684 pub async fn confirm(&self, message: &str) -> Result<bool, Error> {
688 self.session.assert_elicitation()?;
689 let schema = serde_json::json!({
690 "type": "object",
691 "properties": {
692 "confirmed": {
693 "type": "boolean",
694 "default": true,
695 }
696 },
697 "required": ["confirmed"]
698 });
699 let result = self.elicitation(message, schema).await?;
700 Ok(result.action == "accept"
701 && result
702 .content
703 .and_then(|c| c.get("confirmed").and_then(|v| v.as_bool()))
704 == Some(true))
705 }
706
707 pub async fn select(&self, message: &str, options: &[&str]) -> Result<Option<String>, Error> {
711 self.session.assert_elicitation()?;
712 let schema = serde_json::json!({
713 "type": "object",
714 "properties": {
715 "selection": {
716 "type": "string",
717 "enum": options,
718 }
719 },
720 "required": ["selection"]
721 });
722 let result = self.elicitation(message, schema).await?;
723 if result.action != "accept" {
724 return Ok(None);
725 }
726 let selection = result.content.and_then(|c| {
727 c.get("selection")
728 .and_then(|v| v.as_str())
729 .map(String::from)
730 });
731 Ok(selection)
732 }
733
734 pub async fn input(
739 &self,
740 message: &str,
741 options: Option<&UiInputOptions<'_>>,
742 ) -> Result<Option<String>, Error> {
743 self.session.assert_elicitation()?;
744 let mut field = serde_json::json!({ "type": "string" });
745 if let Some(opts) = options {
746 if let Some(title) = opts.title {
747 field["title"] = Value::String(title.to_string());
748 }
749 if let Some(desc) = opts.description {
750 field["description"] = Value::String(desc.to_string());
751 }
752 if let Some(min) = opts.min_length {
753 field["minLength"] = Value::Number(min.into());
754 }
755 if let Some(max) = opts.max_length {
756 field["maxLength"] = Value::Number(max.into());
757 }
758 if let Some(fmt) = &opts.format {
759 field["format"] = Value::String(fmt.as_str().to_string());
760 }
761 if let Some(default) = opts.default {
762 field["default"] = Value::String(default.to_string());
763 }
764 }
765 let schema = serde_json::json!({
766 "type": "object",
767 "properties": { "value": field },
768 "required": ["value"]
769 });
770 let result = self.elicitation(message, schema).await?;
771 if result.action != "accept" {
772 return Ok(None);
773 }
774 let value = result
775 .content
776 .and_then(|c| c.get("value").and_then(|v| v.as_str()).map(String::from));
777 Ok(value)
778 }
779}
780
781impl Client {
782 pub async fn create_session(&self, mut config: SessionConfig) -> Result<Session, Error> {
804 let total_start = Instant::now();
805 let caller_session_id = config.session_id.clone();
813 let use_server_generated_id = config.cloud.is_some() && caller_session_id.is_none();
814 let local_session_id: Option<SessionId> = if use_server_generated_id {
815 None
816 } else {
817 Some(
818 caller_session_id
819 .clone()
820 .unwrap_or_else(|| SessionId::new(uuid::Uuid::new_v4().to_string())),
821 )
822 };
823 if config.hooks_handler.is_some() && config.hooks.is_none() {
824 config.hooks = Some(true);
825 }
826 if let Some(transforms) = config.system_message_transform.clone() {
827 inject_transform_sections(&mut config, transforms.as_ref());
828 }
829 let mode = self.inner.mode;
830 if mode == crate::ClientMode::Empty && config.available_tools.is_none() {
831 return Err(Error::with_message(
832 ErrorKind::InvalidConfig,
833 "ClientMode::Empty requires available_tools to be set on the session config. \
834 Use ToolSet to specify which tools the session may use (e.g. \
835 ToolSet::new().add_builtin_many(BUILTIN_TOOLS_ISOLATED)).",
836 ));
837 }
838 crate::mode::validate_tool_filter_list(
839 "available_tools",
840 config.available_tools.as_deref(),
841 )?;
842 crate::mode::validate_tool_filter_list("excluded_tools", config.excluded_tools.as_deref())?;
843 config.system_message =
844 crate::mode::system_message_for_mode(mode, config.system_message.take());
845 config.memory = crate::mode::memory_for_mode(mode, config.memory.take());
846 if mode == crate::ClientMode::Empty {
847 if config.enable_session_telemetry.is_none() {
848 config.enable_session_telemetry = Some(false);
849 }
850 if config.skip_embedding_retrieval.is_none() {
851 config.skip_embedding_retrieval = Some(true);
852 }
853 if config.enable_on_demand_instruction_discovery.is_none() {
854 config.enable_on_demand_instruction_discovery = Some(false);
855 }
856 if config.enable_file_hooks.is_none() {
857 config.enable_file_hooks = Some(false);
858 }
859 if config.enable_host_git_operations.is_none() {
860 config.enable_host_git_operations = Some(false);
861 }
862 if config.enable_session_store.is_none() {
863 config.enable_session_store = Some(false);
864 }
865 if config.enable_skills.is_none() {
866 config.enable_skills = Some(false);
867 }
868 }
869 if mode == crate::ClientMode::Empty && config.mcp_oauth_token_storage.is_none() {
870 config.mcp_oauth_token_storage = Some("in-memory".into());
871 }
872 if mode == crate::ClientMode::Empty && config.embedding_cache_storage.is_none() {
873 config.embedding_cache_storage = Some("in-memory".into());
874 }
875 let opt_skip_custom_instructions = config.skip_custom_instructions;
876 let opt_custom_agents_local_only = config.custom_agents_local_only;
877 let opt_coauthor_enabled = config.coauthor_enabled;
878 let opt_manage_schedule_enabled = config.manage_schedule_enabled;
879 let (wire, mut runtime) = config.into_wire(local_session_id.clone())?;
880
881 let permission_handler = crate::permission::resolve_handler(
882 runtime.permission_handler.take(),
883 runtime.permission_policy.take(),
884 );
885 let handlers = SessionHandlers {
886 permission: permission_handler,
887 elicitation: runtime.elicitation_handler.take(),
888 mcp_auth: runtime.mcp_auth_handler.take(),
889 user_input: runtime.user_input_handler.take(),
890 exit_plan_mode: runtime.exit_plan_mode_handler.take(),
891 auto_mode_switch: runtime.auto_mode_switch_handler.take(),
892 tools: Arc::new(std::mem::take(&mut runtime.tool_handlers)),
893 };
894 let hooks = runtime.hooks_handler.take();
895 let transforms = runtime.system_message_transform.take();
896 let tools_count = wire.tools.as_ref().map_or(0, Vec::len);
897 let commands_count = runtime.commands.as_ref().map_or(0, Vec::len);
898 let has_hooks = hooks.is_some();
899 let command_handlers = build_command_handler_map(runtime.commands.as_deref());
900 let canvas_handler = runtime.canvas_handler.take();
901 let session_fs_provider = runtime.session_fs_provider.take();
902 let bearer_token_providers = std::mem::take(&mut runtime.bearer_token_providers);
903 let has_mcp_auth_handler = handlers.mcp_auth.is_some();
904 if self.inner.session_fs_configured && session_fs_provider.is_none() {
905 return Err(ErrorKind::Session(SessionErrorKind::SessionFsProviderRequired).into());
906 }
907 if self.inner.session_fs_sqlite_declared
908 && let Some(ref provider) = session_fs_provider
909 && provider.sqlite().is_none()
910 {
911 return Err(Error::with_message(
912 ErrorKind::InvalidConfig,
913 "SessionFs capabilities declare SQLite support but the provider \
914 does not implement SessionFsSqliteProvider",
915 ));
916 }
917
918 let mut params = serde_json::to_value(&wire)?;
919 let trace_ctx = self.resolve_trace_context().await;
920 inject_trace_context(&mut params, &trace_ctx);
921
922 let setup_start = Instant::now();
923 let capabilities = Arc::new(parking_lot::RwLock::new(SessionCapabilities::default()));
924 let idle_waiter = Arc::new(ParkingLotMutex::new(None));
925 let open_canvases = Arc::new(parking_lot::RwLock::new(Vec::new()));
926 let shutdown = CancellationToken::new();
927 let (event_tx, _) = tokio::sync::broadcast::channel(512);
928
929 let inline_stash: Arc<
935 ParkingLotMutex<Option<(SessionId, crate::router::SessionChannels)>>,
936 > = Arc::new(ParkingLotMutex::new(None));
937
938 let inline_callback: Option<crate::jsonrpc::InlineResponseCallback> = if let Some(ref sid) =
939 local_session_id
940 {
941 let channels = self.register_session(sid);
942 *inline_stash.lock() = Some((sid.clone(), channels));
943 None
944 } else {
945 let client = self.clone();
946 let stash = inline_stash.clone();
947 let expected = caller_session_id.clone();
948 Some(Box::new(move |response| {
949 let result = response.result.as_ref().ok_or_else(|| {
950 Error::with_message(ErrorKind::Json, "session.create response had no result")
951 })?;
952 let parsed: CreateSessionResult =
953 serde_json::from_value(result.clone()).map_err(Error::from)?;
954 if let Some(requested) = expected.as_ref()
955 && parsed.session_id != *requested
956 {
957 return Err(ErrorKind::Session(SessionErrorKind::SessionIdMismatch {
958 requested: requested.clone(),
959 returned: parsed.session_id,
960 })
961 .into());
962 }
963 let channels = client.register_session(&parsed.session_id);
964 *stash.lock() = Some((parsed.session_id, channels));
965 Ok(())
966 }))
967 };
968
969 let rpc_start = Instant::now();
970 let result = match self
971 .call_with_inline_callback("session.create", Some(params), inline_callback)
972 .await
973 {
974 Ok(result) => result,
975 Err(error) => {
976 if let Some((id, _channels)) = inline_stash.lock().take() {
977 self.unregister_session(&id);
978 }
979 return Err(error);
980 }
981 };
982 tracing::debug!(
983 elapsed_ms = rpc_start.elapsed().as_millis(),
984 "Client::create_session session creation request completed successfully"
985 );
986 let create_result: CreateSessionResult = match serde_json::from_value(result) {
987 Ok(result) => result,
988 Err(error) => {
989 if let Some((id, _channels)) = inline_stash.lock().take() {
990 self.unregister_session(&id);
991 }
992 return Err(error.into());
993 }
994 };
995
996 if let Some(ref requested) = local_session_id
997 && create_result.session_id != *requested
998 {
999 if let Some((id, _channels)) = inline_stash.lock().take() {
1000 self.unregister_session(&id);
1001 }
1002 return Err(ErrorKind::Session(SessionErrorKind::SessionIdMismatch {
1003 requested: requested.clone(),
1004 returned: create_result.session_id.clone(),
1005 })
1006 .into());
1007 }
1008
1009 let (session_id, channels) = inline_stash
1010 .lock()
1011 .take()
1012 .expect("session registration must have populated stash on success");
1013 let event_loop = spawn_event_loop(
1014 session_id.clone(),
1015 self.clone(),
1016 handlers,
1017 hooks,
1018 transforms,
1019 command_handlers,
1020 canvas_handler,
1021 session_fs_provider,
1022 bearer_token_providers,
1023 channels,
1024 idle_waiter.clone(),
1025 capabilities.clone(),
1026 open_canvases.clone(),
1027 event_tx.clone(),
1028 shutdown.clone(),
1029 );
1030 tracing::debug!(
1031 elapsed_ms = setup_start.elapsed().as_millis(),
1032 session_id = %session_id,
1033 tools_count,
1034 commands_count,
1035 has_hooks,
1036 "Client::create_session local setup complete"
1037 );
1038 *capabilities.write() = create_result.capabilities.unwrap_or_default();
1039 if has_mcp_auth_handler {
1040 register_mcp_auth_interest(self, &session_id).await?;
1041 }
1042
1043 tracing::debug!(
1044 elapsed_ms = total_start.elapsed().as_millis(),
1045 session_id = %session_id,
1046 "Client::create_session complete"
1047 );
1048 let session = Session {
1049 id: session_id,
1050 cwd: self.cwd().clone(),
1051 workspace_path: create_result.workspace_path,
1052 remote_url: create_result.remote_url,
1053 client: self.clone(),
1054 event_loop: ParkingLotMutex::new(Some(event_loop)),
1055 shutdown,
1056 idle_waiter,
1057 capabilities,
1058 open_canvases,
1059 event_tx,
1060 };
1061 apply_mode_post_create_patch(
1062 &session,
1063 mode,
1064 opt_skip_custom_instructions,
1065 opt_custom_agents_local_only,
1066 opt_coauthor_enabled,
1067 opt_manage_schedule_enabled,
1068 )
1069 .await?;
1070 Ok(session)
1071 }
1072
1073 pub async fn resume_session(&self, mut config: ResumeSessionConfig) -> Result<Session, Error> {
1084 let total_start = Instant::now();
1085 let session_id = config.session_id.clone();
1086 if config.hooks_handler.is_some() && config.hooks.is_none() {
1087 config.hooks = Some(true);
1088 }
1089 if let Some(transforms) = config.system_message_transform.clone() {
1090 inject_transform_sections_resume(&mut config, transforms.as_ref());
1091 }
1092 let mode = self.inner.mode;
1093 if mode == crate::ClientMode::Empty && config.available_tools.is_none() {
1094 return Err(Error::with_message(
1095 ErrorKind::InvalidConfig,
1096 "ClientMode::Empty requires available_tools to be set on the session config. \
1097 Use ToolSet to specify which tools the session may use (e.g. \
1098 ToolSet::new().add_builtin_many(BUILTIN_TOOLS_ISOLATED)).",
1099 ));
1100 }
1101 crate::mode::validate_tool_filter_list(
1102 "available_tools",
1103 config.available_tools.as_deref(),
1104 )?;
1105 crate::mode::validate_tool_filter_list("excluded_tools", config.excluded_tools.as_deref())?;
1106 config.system_message =
1107 crate::mode::system_message_for_mode(mode, config.system_message.take());
1108 config.memory = crate::mode::memory_for_mode(mode, config.memory.take());
1109 if mode == crate::ClientMode::Empty {
1110 if config.enable_session_telemetry.is_none() {
1111 config.enable_session_telemetry = Some(false);
1112 }
1113 if config.skip_embedding_retrieval.is_none() {
1114 config.skip_embedding_retrieval = Some(true);
1115 }
1116 if config.enable_on_demand_instruction_discovery.is_none() {
1117 config.enable_on_demand_instruction_discovery = Some(false);
1118 }
1119 if config.enable_file_hooks.is_none() {
1120 config.enable_file_hooks = Some(false);
1121 }
1122 if config.enable_host_git_operations.is_none() {
1123 config.enable_host_git_operations = Some(false);
1124 }
1125 if config.enable_session_store.is_none() {
1126 config.enable_session_store = Some(false);
1127 }
1128 if config.enable_skills.is_none() {
1129 config.enable_skills = Some(false);
1130 }
1131 }
1132 if mode == crate::ClientMode::Empty && config.mcp_oauth_token_storage.is_none() {
1133 config.mcp_oauth_token_storage = Some("in-memory".into());
1134 }
1135 if mode == crate::ClientMode::Empty && config.embedding_cache_storage.is_none() {
1136 config.embedding_cache_storage = Some("in-memory".into());
1137 }
1138 let opt_skip_custom_instructions = config.skip_custom_instructions;
1139 let opt_custom_agents_local_only = config.custom_agents_local_only;
1140 let opt_coauthor_enabled = config.coauthor_enabled;
1141 let opt_manage_schedule_enabled = config.manage_schedule_enabled;
1142 let (wire, mut runtime) = config.into_wire()?;
1143
1144 let permission_handler = crate::permission::resolve_handler(
1145 runtime.permission_handler.take(),
1146 runtime.permission_policy.take(),
1147 );
1148 let handlers = SessionHandlers {
1149 permission: permission_handler,
1150 elicitation: runtime.elicitation_handler.take(),
1151 mcp_auth: runtime.mcp_auth_handler.take(),
1152 user_input: runtime.user_input_handler.take(),
1153 exit_plan_mode: runtime.exit_plan_mode_handler.take(),
1154 auto_mode_switch: runtime.auto_mode_switch_handler.take(),
1155 tools: Arc::new(std::mem::take(&mut runtime.tool_handlers)),
1156 };
1157 let hooks = runtime.hooks_handler.take();
1158 let transforms = runtime.system_message_transform.take();
1159 let tools_count = wire.tools.as_ref().map_or(0, Vec::len);
1160 let commands_count = runtime.commands.as_ref().map_or(0, Vec::len);
1161 let has_hooks = hooks.is_some();
1162 let command_handlers = build_command_handler_map(runtime.commands.as_deref());
1163 let canvas_handler = runtime.canvas_handler.take();
1164 let session_fs_provider = runtime.session_fs_provider.take();
1165 let bearer_token_providers = std::mem::take(&mut runtime.bearer_token_providers);
1166 let has_mcp_auth_handler = handlers.mcp_auth.is_some();
1167 if self.inner.session_fs_configured && session_fs_provider.is_none() {
1168 return Err(ErrorKind::Session(SessionErrorKind::SessionFsProviderRequired).into());
1169 }
1170 if self.inner.session_fs_sqlite_declared
1171 && let Some(ref provider) = session_fs_provider
1172 && provider.sqlite().is_none()
1173 {
1174 return Err(Error::with_message(
1175 ErrorKind::InvalidConfig,
1176 "SessionFs capabilities declare SQLite support but the provider \
1177 does not implement SessionFsSqliteProvider",
1178 ));
1179 }
1180
1181 let mut params = serde_json::to_value(&wire)?;
1182 let trace_ctx = self.resolve_trace_context().await;
1183 inject_trace_context(&mut params, &trace_ctx);
1184
1185 let capabilities = Arc::new(parking_lot::RwLock::new(SessionCapabilities::default()));
1186 let setup_start = Instant::now();
1187 let channels = self.register_session(&session_id);
1188 let idle_waiter = Arc::new(ParkingLotMutex::new(None));
1189 let open_canvases = Arc::new(parking_lot::RwLock::new(Vec::new()));
1190 let shutdown = CancellationToken::new();
1191 let (event_tx, _) = tokio::sync::broadcast::channel(512);
1192 let event_loop = spawn_event_loop(
1193 session_id.clone(),
1194 self.clone(),
1195 handlers,
1196 hooks,
1197 transforms,
1198 command_handlers,
1199 canvas_handler,
1200 session_fs_provider,
1201 bearer_token_providers,
1202 channels,
1203 idle_waiter.clone(),
1204 capabilities.clone(),
1205 open_canvases.clone(),
1206 event_tx.clone(),
1207 shutdown.clone(),
1208 );
1209 let mut registration =
1210 PendingSessionRegistration::new(self.clone(), session_id.clone(), shutdown.clone());
1211 tracing::debug!(
1212 elapsed_ms = setup_start.elapsed().as_millis(),
1213 session_id = %session_id,
1214 tools_count,
1215 commands_count,
1216 has_hooks,
1217 "Client::resume_session local setup complete"
1218 );
1219
1220 let rpc_start = Instant::now();
1221 let result = match self.call("session.resume", Some(params)).await {
1222 Ok(result) => result,
1223 Err(error) => {
1224 registration.cleanup(event_loop).await;
1225 return Err(error);
1226 }
1227 };
1228 tracing::debug!(
1229 elapsed_ms = rpc_start.elapsed().as_millis(),
1230 session_id = %session_id,
1231 "Client::resume_session session resume request completed successfully"
1232 );
1233
1234 let resume_result: ResumeSessionResult = match serde_json::from_value(result) {
1235 Ok(result) => result,
1236 Err(error) => {
1237 registration.cleanup(event_loop).await;
1238 return Err(error.into());
1239 }
1240 };
1241 let cli_session_id = resume_result
1242 .session_id
1243 .clone()
1244 .unwrap_or_else(|| session_id.clone());
1245 if cli_session_id != session_id {
1246 registration.cleanup(event_loop).await;
1247 return Err(ErrorKind::Session(SessionErrorKind::SessionIdMismatch {
1248 requested: session_id,
1249 returned: cli_session_id,
1250 })
1251 .into());
1252 }
1253 if has_mcp_auth_handler {
1254 register_mcp_auth_interest(self, &session_id).await?;
1255 }
1256
1257 let skills_reload_start = Instant::now();
1259 if let Err(e) = self
1260 .call(
1261 "session.skills.reload",
1262 Some(serde_json::json!({ "sessionId": session_id })),
1263 )
1264 .await
1265 {
1266 warn!(
1267 elapsed_ms = skills_reload_start.elapsed().as_millis(),
1268 session_id = %session_id,
1269 error = %e,
1270 "Client::resume_session skills reload request failed"
1271 );
1272 } else {
1273 tracing::debug!(
1274 elapsed_ms = skills_reload_start.elapsed().as_millis(),
1275 session_id = %session_id,
1276 "Client::resume_session skills reload request completed successfully"
1277 );
1278 }
1279
1280 *capabilities.write() = resume_result.capabilities.unwrap_or_default();
1281 {
1286 let mut snapshots = open_canvases.write();
1287 for snapshot in resume_result.open_canvases.unwrap_or_default() {
1288 upsert_open_canvas_snapshot(&mut snapshots, snapshot);
1289 }
1290 }
1291
1292 tracing::debug!(
1293 elapsed_ms = total_start.elapsed().as_millis(),
1294 session_id = %session_id,
1295 "Client::resume_session complete"
1296 );
1297 registration.disarm();
1298 let session = Session {
1299 id: session_id,
1300 cwd: self.cwd().clone(),
1301 workspace_path: resume_result.workspace_path,
1302 remote_url: resume_result.remote_url,
1303 client: self.clone(),
1304 event_loop: ParkingLotMutex::new(Some(event_loop)),
1305 shutdown,
1306 idle_waiter,
1307 capabilities,
1308 open_canvases,
1309 event_tx,
1310 };
1311 apply_mode_post_create_patch(
1312 &session,
1313 mode,
1314 opt_skip_custom_instructions,
1315 opt_custom_agents_local_only,
1316 opt_coauthor_enabled,
1317 opt_manage_schedule_enabled,
1318 )
1319 .await?;
1320 Ok(session)
1321 }
1322}
1323
/// Slash-command handlers keyed by command name (see `build_command_handler_map`).
type CommandHandlerMap = HashMap<String, Arc<dyn CommandHandler>>;
1325
1326async fn apply_mode_post_create_patch(
1327 session: &Session,
1328 mode: crate::ClientMode,
1329 opt_skip_custom_instructions: Option<bool>,
1330 opt_custom_agents_local_only: Option<bool>,
1331 opt_coauthor_enabled: Option<bool>,
1332 opt_manage_schedule_enabled: Option<bool>,
1333) -> Result<(), Error> {
1334 use crate::generated::api_types::SessionUpdateOptionsParams;
1335 let mut patch = SessionUpdateOptionsParams::default();
1336 let should_send = if mode == crate::ClientMode::Empty {
1337 patch.skip_custom_instructions = Some(opt_skip_custom_instructions.unwrap_or(true));
1338 patch.custom_agents_local_only = Some(opt_custom_agents_local_only.unwrap_or(true));
1339 patch.coauthor_enabled = Some(opt_coauthor_enabled.unwrap_or(false));
1340 patch.manage_schedule_enabled = Some(opt_manage_schedule_enabled.unwrap_or(false));
1341 patch.installed_plugins = Some(Vec::new());
1342 true
1343 } else {
1344 let mut any = false;
1345 if let Some(v) = opt_skip_custom_instructions {
1346 patch.skip_custom_instructions = Some(v);
1347 any = true;
1348 }
1349 if let Some(v) = opt_custom_agents_local_only {
1350 patch.custom_agents_local_only = Some(v);
1351 any = true;
1352 }
1353 if let Some(v) = opt_coauthor_enabled {
1354 patch.coauthor_enabled = Some(v);
1355 any = true;
1356 }
1357 if let Some(v) = opt_manage_schedule_enabled {
1358 patch.manage_schedule_enabled = Some(v);
1359 any = true;
1360 }
1361 any
1362 };
1363 if !should_send {
1364 return Ok(());
1365 }
1366 if let Err(error) = session.rpc().options().update(patch).await {
1367 let _ = session.disconnect().await;
1368 return Err(error);
1369 }
1370 Ok(())
1371}
1372
1373fn build_command_handler_map(commands: Option<&[CommandDefinition]>) -> Arc<CommandHandlerMap> {
1374 let map = match commands {
1375 Some(commands) => commands
1376 .iter()
1377 .filter(|cmd| !cmd.name.is_empty())
1378 .map(|cmd| (cmd.name.clone(), cmd.handler.clone()))
1379 .collect(),
1380 None => HashMap::new(),
1381 };
1382 Arc::new(map)
1383}
1384
1385fn upsert_open_canvas_snapshot(
1386 snapshots: &mut Vec<OpenCanvasInstance>,
1387 snapshot: OpenCanvasInstance,
1388) {
1389 if let Some(existing) = snapshots
1390 .iter_mut()
1391 .find(|open| open.instance_id == snapshot.instance_id)
1392 {
1393 *existing = snapshot;
1394 } else {
1395 snapshots.push(snapshot);
1396 }
1397}
1398
1399fn remove_open_canvas_snapshot(snapshots: &mut Vec<OpenCanvasInstance>, instance_id: &str) {
1400 snapshots.retain(|open| open.instance_id != instance_id);
1401}
1402
1403#[allow(clippy::too_many_arguments)]
1404fn spawn_event_loop(
1405 session_id: SessionId,
1406 client: Client,
1407 handlers: SessionHandlers,
1408 hooks: Option<Arc<dyn SessionHooks>>,
1409 transforms: Option<Arc<dyn SystemMessageTransform>>,
1410 command_handlers: Arc<CommandHandlerMap>,
1411 canvas_handler: Option<Arc<dyn CanvasHandler>>,
1412 session_fs_provider: Option<Arc<dyn SessionFsProvider>>,
1413 bearer_token_providers: HashMap<String, Arc<dyn BearerTokenProvider>>,
1414 channels: crate::router::SessionChannels,
1415 idle_waiter: Arc<ParkingLotMutex<Option<IdleWaiter>>>,
1416 capabilities: Arc<parking_lot::RwLock<SessionCapabilities>>,
1417 open_canvases: Arc<parking_lot::RwLock<Vec<OpenCanvasInstance>>>,
1418 event_tx: tokio::sync::broadcast::Sender<SessionEvent>,
1419 shutdown: CancellationToken,
1420) -> JoinHandle<()> {
1421 let crate::router::SessionChannels {
1422 mut notifications,
1423 mut requests,
1424 } = channels;
1425
1426 let span = tracing::error_span!("session_event_loop", session_id = %session_id);
1427 tokio::spawn(
1428 async move {
1429 loop {
1430 tokio::select! {
1441 _ = shutdown.cancelled() => break,
1442 Some(notification) = notifications.recv() => {
1443 handle_notification(
1444 &session_id, &client, &handlers, &command_handlers, notification, &idle_waiter, &capabilities, &open_canvases, &event_tx,
1445 ).await;
1446 }
1447 Some(request) = requests.recv() => {
1448 let ctx = RequestDispatchContext {
1449 client: &client,
1450 handlers: &handlers,
1451 hooks: hooks.as_deref(),
1452 transforms: transforms.as_deref(),
1453 canvas_handler: canvas_handler.as_ref(),
1454 session_fs_provider: session_fs_provider.as_ref(),
1455 bearer_token_providers: &bearer_token_providers,
1456 };
1457 handle_request(&session_id, ctx, request).await;
1458 }
1459 else => break,
1460 }
1461 }
1462 if let Some(waiter) = idle_waiter.lock().take() {
1465 let _ = waiter
1466 .tx
1467 .send(Err(ErrorKind::Session(SessionErrorKind::EventLoopClosed).into()));
1468 }
1469 }
1470 .instrument(span),
1471 )
1472}
1473
1474fn extract_request_id(data: &Value) -> Option<RequestId> {
1475 data.get("requestId")
1476 .and_then(|v| v.as_str())
1477 .filter(|s| !s.is_empty())
1478 .map(RequestId::new)
1479}
1480
1481fn notification_permission_payload(result: &PermissionResult) -> Option<Value> {
1486 match result {
1487 PermissionResult::NoResult => None,
1488 PermissionResult::Decision(decision) => Some(
1489 serde_json::to_value(decision).expect("serializing permission decision should succeed"),
1490 ),
1491 }
1492}
1493
1494async fn register_mcp_auth_interest(client: &Client, session_id: &SessionId) -> Result<(), Error> {
1495 let mut params = serde_json::to_value(RegisterEventInterestParams {
1496 event_type: "mcp.oauth_required".to_string(),
1497 })?;
1498 params["sessionId"] = Value::String(session_id.to_string());
1499 client
1500 .call(rpc_methods::SESSION_EVENTLOG_REGISTERINTEREST, Some(params))
1501 .await?;
1502 Ok(())
1503}
1504
1505fn tool_failure_result(message: impl Into<String>) -> ToolResult {
1506 let message = message.into();
1507 ToolResult::Expanded(ToolResultExpanded {
1508 text_result_for_llm: message.clone(),
1509 result_type: "failure".to_string(),
1510 binary_results_for_llm: None,
1511 session_log: None,
1512 error: Some(message),
1513 tool_telemetry: None,
1514 })
1515}
1516
1517#[allow(clippy::too_many_arguments)]
1519async fn handle_notification(
1520 session_id: &SessionId,
1521 client: &Client,
1522 handlers: &SessionHandlers,
1523 command_handlers: &Arc<CommandHandlerMap>,
1524 notification: SessionEventNotification,
1525 idle_waiter: &Arc<ParkingLotMutex<Option<IdleWaiter>>>,
1526 capabilities: &Arc<parking_lot::RwLock<SessionCapabilities>>,
1527 open_canvases: &Arc<parking_lot::RwLock<Vec<OpenCanvasInstance>>>,
1528 event_tx: &tokio::sync::broadcast::Sender<SessionEvent>,
1529) {
1530 let dispatch_start = Instant::now();
1531 let event = notification.event.clone();
1532 let event_type = event.parsed_type();
1533 if event_type == SessionEventType::PermissionRequested {
1534 tracing::debug!(
1535 session_id = %session_id,
1536 event_type = %event.event_type,
1537 "Session::handle_notification permission request received"
1538 );
1539 }
1540
1541 match event_type {
1544 SessionEventType::AssistantMessage
1545 | SessionEventType::SessionIdle
1546 | SessionEventType::SessionError => {
1547 let mut guard = idle_waiter.lock();
1548 if let Some(waiter) = guard.as_mut() {
1549 match event_type {
1550 SessionEventType::AssistantMessage => {
1551 if !waiter.first_assistant_message_seen {
1552 waiter.first_assistant_message_seen = true;
1553 tracing::debug!(
1554 elapsed_ms = waiter.started_at.elapsed().as_millis(),
1555 session_id = %session_id,
1556 "Session::send_and_wait first assistant message"
1557 );
1558 }
1559 waiter.last_assistant_message = Some(event.clone());
1560 }
1561 SessionEventType::SessionIdle | SessionEventType::SessionError => {
1562 if let Some(waiter) = guard.take() {
1563 if event_type == SessionEventType::SessionIdle {
1564 tracing::debug!(
1565 elapsed_ms = waiter.started_at.elapsed().as_millis(),
1566 session_id = %session_id,
1567 "Session::send_and_wait idle received"
1568 );
1569 let _ = waiter.tx.send(Ok(waiter.last_assistant_message));
1570 } else {
1571 let error_msg = event
1572 .typed_data::<SessionErrorData>()
1573 .map(|d| d.message)
1574 .or_else(|| {
1575 event
1576 .data
1577 .get("message")
1578 .and_then(|v| v.as_str())
1579 .map(|s| s.to_string())
1580 })
1581 .unwrap_or_else(|| "session error".to_string());
1582 let _ = waiter.tx.send(Err(Error::with_message(
1583 ErrorKind::Session(SessionErrorKind::AgentError),
1584 error_msg,
1585 )));
1586 }
1587 }
1588 }
1589 _ => {}
1590 }
1591 }
1592 }
1593 _ => {}
1594 }
1595
1596 if event_type == SessionEventType::CapabilitiesChanged {
1600 match serde_json::from_value::<SessionCapabilities>(notification.event.data.clone()) {
1601 Ok(changed) => *capabilities.write() = changed,
1602 Err(e) => warn!(error = %e, "failed to deserialize capabilities.changed payload"),
1603 }
1604 }
1605 if event_type == SessionEventType::SessionCanvasOpened {
1606 match serde_json::from_value::<OpenCanvasInstance>(notification.event.data.clone()) {
1607 Ok(open_canvas) => {
1608 upsert_open_canvas_snapshot(&mut open_canvases.write(), open_canvas);
1609 }
1610 Err(e) => warn!(error = %e, "failed to deserialize session.canvas.opened payload"),
1611 }
1612 }
1613 if event_type == SessionEventType::SessionCanvasClosed {
1614 match serde_json::from_value::<SessionCanvasClosedData>(notification.event.data.clone()) {
1615 Ok(closed) => {
1616 if closed.instance_id.is_empty() {
1617 warn!("failed to deserialize session.canvas.closed payload");
1618 } else {
1619 remove_open_canvas_snapshot(&mut open_canvases.write(), &closed.instance_id);
1620 }
1621 }
1622 Err(e) => warn!(error = %e, "failed to deserialize session.canvas.closed payload"),
1623 }
1624 }
1625
1626 let _ = event_tx.send(event.clone());
1630
1631 tracing::debug!(
1632 elapsed_ms = dispatch_start.elapsed().as_millis(),
1633 session_id = %session_id,
1634 event_type = %notification.event.event_type,
1635 "Session::handle_notification dispatch"
1636 );
1637
1638 match event_type {
1641 SessionEventType::PermissionRequested => {
1642 let Some(request_id) = extract_request_id(¬ification.event.data) else {
1643 return;
1644 };
1645 if notification
1649 .event
1650 .data
1651 .get("resolvedByHook")
1652 .and_then(|v| v.as_bool())
1653 .unwrap_or(false)
1654 {
1655 return;
1656 }
1657 let Some(permission_handler) = handlers.permission.clone() else {
1661 return;
1662 };
1663 let client = client.clone();
1664 let sid = session_id.clone();
1665 let data: PermissionRequestData =
1666 serde_json::from_value(notification.event.data.clone()).unwrap_or_else(|_| {
1667 PermissionRequestData {
1668 kind: None,
1669 tool_call_id: None,
1670 extra: notification.event.data.clone(),
1671 }
1672 });
1673 let span = tracing::error_span!(
1674 "permission_request_handler",
1675 session_id = %sid,
1676 request_id = %request_id
1677 );
1678 tokio::spawn(
1679 async move {
1680 let handler_start = Instant::now();
1681 let result = permission_handler
1682 .handle(sid.clone(), request_id.clone(), data)
1683 .await;
1684 tracing::debug!(
1685 elapsed_ms = handler_start.elapsed().as_millis(),
1686 session_id = %sid,
1687 request_id = %request_id,
1688 "PermissionHandler::handle dispatch"
1689 );
1690 let Some(result_value) = notification_permission_payload(&result) else {
1691 return;
1695 };
1696 let rpc_start = Instant::now();
1697 let _ = client
1698 .call(
1699 "session.permissions.handlePendingPermissionRequest",
1700 Some(serde_json::json!({
1701 "sessionId": sid,
1702 "requestId": request_id,
1703 "result": result_value,
1704 })),
1705 )
1706 .await;
1707 tracing::debug!(
1708 elapsed_ms = rpc_start.elapsed().as_millis(),
1709 session_id = %sid,
1710 request_id = %request_id,
1711 "Session::handle_notification response sent successfully"
1712 );
1713 }
1714 .instrument(span),
1715 );
1716 }
1717 SessionEventType::ExternalToolRequested => {
1718 let Some(request_id) = extract_request_id(¬ification.event.data) else {
1719 return;
1720 };
1721 let data: ExternalToolRequestedData =
1722 match serde_json::from_value(notification.event.data.clone()) {
1723 Ok(d) => d,
1724 Err(e) => {
1725 warn!(error = %e, "failed to deserialize external_tool.requested");
1726 let client = client.clone();
1727 let sid = session_id.clone();
1728 let span = tracing::error_span!(
1729 "external_tool_deserialize_error",
1730 session_id = %sid,
1731 request_id = %request_id
1732 );
1733 tokio::spawn(
1734 async move {
1735 let rpc_start = Instant::now();
1736 let _ = client
1737 .call(
1738 "session.tools.handlePendingToolCall",
1739 Some(serde_json::json!({
1740 "sessionId": sid,
1741 "requestId": request_id,
1742 "error": format!("Failed to deserialize tool request: {e}"),
1743 })),
1744 )
1745 .await;
1746 tracing::debug!(
1747 elapsed_ms = rpc_start.elapsed().as_millis(),
1748 session_id = %sid,
1749 request_id = %request_id,
1750 "Session::handle_notification response sent successfully"
1751 );
1752 }
1753 .instrument(span),
1754 );
1755 return;
1756 }
1757 };
1758 let tool_handler = if data.tool_name.is_empty() {
1762 None
1763 } else {
1764 handlers.tools.get(&data.tool_name).cloned()
1765 };
1766 let Some(tool_handler) = tool_handler else {
1767 return;
1768 };
1769 let client = client.clone();
1770 let sid = session_id.clone();
1771 let span = tracing::error_span!(
1772 "external_tool_handler",
1773 session_id = %sid,
1774 request_id = %request_id
1775 );
1776 tokio::spawn(
1777 async move {
1778 if data.tool_call_id.is_empty() {
1783 let error_msg = "Missing toolCallId";
1784 let rpc_start = Instant::now();
1785 let _ = client
1786 .call(
1787 "session.tools.handlePendingToolCall",
1788 Some(serde_json::json!({
1789 "sessionId": sid,
1790 "requestId": request_id,
1791 "error": error_msg,
1792 })),
1793 )
1794 .await;
1795 tracing::debug!(
1796 elapsed_ms = rpc_start.elapsed().as_millis(),
1797 session_id = %sid,
1798 request_id = %request_id,
1799 "Session::handle_notification response sent successfully"
1800 );
1801 return;
1802 }
1803 let tool_call_id = data.tool_call_id.clone();
1804 let tool_name = data.tool_name.clone();
1805 let invocation = ToolInvocation {
1806 session_id: sid.clone(),
1807 tool_call_id: data.tool_call_id,
1808 tool_name: data.tool_name,
1809 arguments: data
1810 .arguments
1811 .unwrap_or(Value::Object(serde_json::Map::new())),
1812 traceparent: data.traceparent,
1813 tracestate: data.tracestate,
1814 };
1815 let handler_start = Instant::now();
1816 let tool_result = match tool_handler.call(invocation).await {
1817 Ok(r) => r,
1818 Err(e) => tool_failure_result(e.to_string()),
1819 };
1820 tracing::debug!(
1821 elapsed_ms = handler_start.elapsed().as_millis(),
1822 session_id = %sid,
1823 request_id = %request_id,
1824 tool_call_id = %tool_call_id,
1825 tool_name = %tool_name,
1826 "ToolHandler::call dispatch"
1827 );
1828 let result_value = serde_json::to_value(tool_result).unwrap_or(Value::Null);
1829 let rpc_start = Instant::now();
1830 let _ = client
1831 .call(
1832 "session.tools.handlePendingToolCall",
1833 Some(serde_json::json!({
1834 "sessionId": sid,
1835 "requestId": request_id,
1836 "result": result_value,
1837 })),
1838 )
1839 .await;
1840 tracing::debug!(
1841 elapsed_ms = rpc_start.elapsed().as_millis(),
1842 session_id = %sid,
1843 request_id = %request_id,
1844 tool_call_id = %tool_call_id,
1845 tool_name = %tool_name,
1846 "Session::handle_notification response sent successfully"
1847 );
1848 }
1849 .instrument(span),
1850 );
1851 }
1852 SessionEventType::UserInputRequested => {
1853 }
1860 SessionEventType::ElicitationRequested => {
1861 let Some(request_id) = extract_request_id(¬ification.event.data) else {
1862 return;
1863 };
1864 let Some(elicitation_handler) = handlers.elicitation.clone() else {
1868 return;
1869 };
1870 let elicitation_data: ElicitationRequestedData =
1871 match serde_json::from_value(notification.event.data.clone()) {
1872 Ok(d) => d,
1873 Err(e) => {
1874 warn!(error = %e, "failed to deserialize elicitation request");
1875 return;
1876 }
1877 };
1878 let request = ElicitationRequest {
1879 message: elicitation_data.message,
1880 requested_schema: elicitation_data
1881 .requested_schema
1882 .map(|s| serde_json::to_value(s).unwrap_or(Value::Null)),
1883 mode: elicitation_data.mode.map(|m| match m {
1884 crate::generated::session_events::ElicitationRequestedMode::Form => {
1885 crate::types::ElicitationMode::Form
1886 }
1887 crate::generated::session_events::ElicitationRequestedMode::Url => {
1888 crate::types::ElicitationMode::Url
1889 }
1890 _ => crate::types::ElicitationMode::Unknown,
1891 }),
1892 elicitation_source: elicitation_data.elicitation_source,
1893 url: elicitation_data.url,
1894 };
1895 let client = client.clone();
1896 let sid = session_id.clone();
1897 let span = tracing::error_span!(
1898 "elicitation_request_handler",
1899 session_id = %sid,
1900 request_id = %request_id
1901 );
1902 tokio::spawn(
1903 async move {
1904 let cancel = ElicitationResult {
1905 action: "cancel".to_string(),
1906 content: None,
1907 };
1908 let handler_task = tokio::spawn({
1910 let sid = sid.clone();
1911 let request_id = request_id.clone();
1912 let span = tracing::error_span!(
1913 "elicitation_callback",
1914 session_id = %sid,
1915 request_id = %request_id
1916 );
1917 async move {
1918 let handler_start = Instant::now();
1919 let response = elicitation_handler
1920 .handle(sid.clone(), request_id.clone(), request)
1921 .await;
1922 tracing::debug!(
1923 elapsed_ms = handler_start.elapsed().as_millis(),
1924 session_id = %sid,
1925 request_id = %request_id,
1926 "ElicitationHandler::handle dispatch"
1927 );
1928 response
1929 }
1930 .instrument(span)
1931 });
1932 let result = match handler_task.await {
1933 Ok(r) => r,
1934 Err(_) => cancel.clone(),
1935 };
1936 let rpc_start = Instant::now();
1937 if let Err(e) = client
1938 .call(
1939 "session.ui.handlePendingElicitation",
1940 Some(serde_json::json!({
1941 "sessionId": sid,
1942 "requestId": request_id,
1943 "result": result,
1944 })),
1945 )
1946 .await
1947 {
1948 warn!(error = %e, "handlePendingElicitation failed, sending cancel");
1950 let _ = client
1951 .call(
1952 "session.ui.handlePendingElicitation",
1953 Some(serde_json::json!({
1954 "sessionId": sid,
1955 "requestId": request_id,
1956 "result": cancel,
1957 })),
1958 )
1959 .await;
1960 } else {
1961 tracing::debug!(
1962 elapsed_ms = rpc_start.elapsed().as_millis(),
1963 session_id = %sid,
1964 request_id = %request_id,
1965 "Session::handle_notification response sent successfully"
1966 );
1967 }
1968 }
1969 .instrument(span),
1970 );
1971 }
1972 SessionEventType::McpOauthRequired => {
1973 let Some(request_id) = extract_request_id(¬ification.event.data) else {
1974 return;
1975 };
1976 let Some(mcp_auth_handler) = handlers.mcp_auth.clone() else {
1977 warn!(
1978 session_id = %session_id,
1979 request_id = %request_id,
1980 "received MCP OAuth request without a registered MCP auth handler"
1981 );
1982 return;
1983 };
1984 let data: McpOauthRequiredData =
1985 match serde_json::from_value(notification.event.data.clone()) {
1986 Ok(d) => d,
1987 Err(e) => {
1988 warn!(error = %e, "failed to deserialize MCP OAuth request");
1989 return;
1990 }
1991 };
1992 let request = McpAuthRequest {
1993 request_id: request_id.clone(),
1994 server_name: data.server_name,
1995 server_url: data.server_url,
1996 reason: data.reason,
1997 www_authenticate_params: data.www_authenticate_params,
1998 resource_metadata: data.resource_metadata,
1999 static_client_config: data.static_client_config,
2000 };
2001 let client = client.clone();
2002 let sid = session_id.clone();
2003 let span = tracing::error_span!(
2004 "mcp_auth_request_handler",
2005 session_id = %sid,
2006 request_id = %request_id
2007 );
2008 tokio::spawn(
2009 async move {
2010 let cancel = McpAuthResult::Cancelled;
2011 let handler_task = tokio::spawn({
2012 let sid = sid.clone();
2013 let request_id = request_id.clone();
2014 let span = tracing::error_span!(
2015 "mcp_auth_callback",
2016 session_id = %sid,
2017 request_id = %request_id
2018 );
2019 async move {
2020 let handler_start = Instant::now();
2021 let response = mcp_auth_handler
2022 .handle(sid.clone(), request_id.clone(), request)
2023 .await;
2024 tracing::debug!(
2025 elapsed_ms = handler_start.elapsed().as_millis(),
2026 session_id = %sid,
2027 request_id = %request_id,
2028 "McpAuthHandler::handle dispatch"
2029 );
2030 response
2031 }
2032 .instrument(span)
2033 });
2034 let result = match handler_task.await {
2035 Ok(result) => result,
2036 Err(_) => cancel,
2037 };
2038 let rpc_start = Instant::now();
2039 let _ = client
2040 .call(
2041 "session.mcp.oauth.handlePendingRequest",
2042 Some(serde_json::json!({
2043 "sessionId": sid,
2044 "requestId": request_id,
2045 "result": result.into_wire(),
2046 })),
2047 )
2048 .await;
2049 tracing::debug!(
2050 elapsed_ms = rpc_start.elapsed().as_millis(),
2051 "Session::handle_notification MCP auth response sent"
2052 );
2053 }
2054 .instrument(span),
2055 );
2056 }
2057 SessionEventType::CommandExecute => {
2058 let data: CommandExecuteData =
2059 match serde_json::from_value(notification.event.data.clone()) {
2060 Ok(d) => d,
2061 Err(e) => {
2062 warn!(error = %e, "failed to deserialize command.execute");
2063 return;
2064 }
2065 };
2066 let client = client.clone();
2067 let command_handlers = command_handlers.clone();
2068 let sid = session_id.clone();
2069 let span = tracing::error_span!("command_handler", session_id = %sid);
2070 tokio::spawn(
2071 async move {
2072 let request_id = data.request_id;
2073 let ack_error = match command_handlers.get(&data.command_name).cloned() {
2074 None => Some(format!("Unknown command: {}", data.command_name)),
2075 Some(handler) => {
2076 let command_name = data.command_name.clone();
2077 let ctx = CommandContext {
2078 session_id: sid.clone(),
2079 command: data.command,
2080 command_name: data.command_name,
2081 args: data.args,
2082 };
2083 let handler_start = Instant::now();
2084 let result = handler.on_command(ctx).await;
2085 tracing::debug!(
2086 elapsed_ms = handler_start.elapsed().as_millis(),
2087 session_id = %sid,
2088 request_id = %request_id,
2089 command_name = %command_name,
2090 "CommandHandler::call dispatch"
2091 );
2092 match result {
2093 Ok(()) => None,
2094 Err(e) => Some(e.to_string()),
2095 }
2096 }
2097 };
2098 let mut params = serde_json::json!({
2099 "sessionId": sid,
2100 "requestId": request_id,
2101 });
2102 if let Some(error_msg) = ack_error {
2103 params["error"] = serde_json::Value::String(error_msg);
2104 }
2105 let rpc_start = Instant::now();
2106 let _ = client
2107 .call("session.commands.handlePendingCommand", Some(params))
2108 .await;
2109 tracing::debug!(
2110 elapsed_ms = rpc_start.elapsed().as_millis(),
2111 session_id = %sid,
2112 request_id = %request_id,
2113 "Session::handle_notification response sent successfully"
2114 );
2115 }
2116 .instrument(span),
2117 );
2118 }
2119 _ => {}
2120 }
2121}
2122
2123struct RequestDispatchContext<'a> {
2124 client: &'a Client,
2125 handlers: &'a SessionHandlers,
2126 hooks: Option<&'a dyn SessionHooks>,
2127 transforms: Option<&'a dyn SystemMessageTransform>,
2128 canvas_handler: Option<&'a Arc<dyn CanvasHandler>>,
2129 session_fs_provider: Option<&'a Arc<dyn SessionFsProvider>>,
2130 bearer_token_providers: &'a HashMap<String, Arc<dyn BearerTokenProvider>>,
2131}
2132
2133async fn handle_request(
2135 session_id: &SessionId,
2136 ctx: RequestDispatchContext<'_>,
2137 request: crate::JsonRpcRequest,
2138) {
2139 let sid = session_id.clone();
2140 let client = ctx.client;
2141 let handlers = ctx.handlers;
2142 let hooks = ctx.hooks;
2143 let transforms = ctx.transforms;
2144 let canvas_handler = ctx.canvas_handler;
2145 let session_fs_provider = ctx.session_fs_provider;
2146 let bearer_token_providers = ctx.bearer_token_providers;
2147
2148 if request.method.starts_with("sessionFs.") {
2149 crate::session_fs_dispatch::dispatch(client, session_fs_provider, request).await;
2150 return;
2151 }
2152
2153 if request.method.starts_with("canvas.") {
2154 crate::canvas_dispatch::dispatch(client, canvas_handler, request).await;
2155 return;
2156 }
2157
2158 if request.method == crate::generated::api_types::rpc_methods::PROVIDERTOKEN_GETTOKEN {
2159 crate::provider_token_dispatch::dispatch(client, bearer_token_providers, request).await;
2160 return;
2161 }
2162
2163 match request.method.as_str() {
2164 "hooks.invoke" => {
2165 let params = request.params.as_ref();
2166 let hook_type = params
2167 .and_then(|p| p.get("hookType"))
2168 .and_then(|v| v.as_str())
2169 .unwrap_or("");
2170 let input = params
2171 .and_then(|p| p.get("input"))
2172 .cloned()
2173 .unwrap_or(Value::Object(Default::default()));
2174
2175 let rpc_result = if let Some(hooks) = hooks {
2176 match crate::hooks::dispatch_hook(hooks, &sid, hook_type, input).await {
2177 Ok(output) => output,
2178 Err(e) => {
2179 warn!(error = %e, hook_type = hook_type, "hook dispatch failed");
2180 serde_json::json!({ "output": {} })
2181 }
2182 }
2183 } else {
2184 serde_json::json!({ "output": {} })
2185 };
2186
2187 let rpc_response = JsonRpcResponse {
2188 jsonrpc: "2.0".to_string(),
2189 id: request.id,
2190 result: Some(rpc_result),
2191 error: None,
2192 };
2193 let _ = client.send_response(&rpc_response).await;
2194 }
2195
2196 "userInput.request" => {
2197 let params = request.params.as_ref();
2198 let Some(question) = params
2199 .and_then(|p| p.get("question"))
2200 .and_then(|v| v.as_str())
2201 else {
2202 warn!("userInput.request missing 'question' field");
2203 let rpc_response = JsonRpcResponse {
2204 jsonrpc: "2.0".to_string(),
2205 id: request.id,
2206 result: None,
2207 error: Some(crate::JsonRpcError {
2208 code: error_codes::INVALID_PARAMS,
2209 message: "missing required field: question".to_string(),
2210 data: None,
2211 }),
2212 };
2213 let _ = client.send_response(&rpc_response).await;
2214 return;
2215 };
2216 let question = question.to_string();
2217 let choices = params
2218 .and_then(|p| p.get("choices"))
2219 .and_then(|v| v.as_array())
2220 .map(|arr| {
2221 arr.iter()
2222 .filter_map(|v| v.as_str().map(|s| s.to_string()))
2223 .collect()
2224 });
2225 let allow_freeform = params
2226 .and_then(|p| p.get("allowFreeform"))
2227 .and_then(|v| v.as_bool());
2228
2229 let handler_start = Instant::now();
2230 let response = if let Some(user_input_handler) = handlers.user_input.as_ref() {
2231 user_input_handler
2232 .handle(sid.clone(), question, choices, allow_freeform)
2233 .await
2234 } else {
2235 None
2236 };
2237 tracing::debug!(
2238 elapsed_ms = handler_start.elapsed().as_millis(),
2239 session_id = %sid,
2240 "UserInputHandler::handle dispatch"
2241 );
2242
2243 let rpc_result = match response {
2244 Some(UserInputResponse {
2245 answer,
2246 was_freeform,
2247 }) => serde_json::json!({
2248 "answer": answer,
2249 "wasFreeform": was_freeform,
2250 }),
2251 None => serde_json::json!({ "noResponse": true }),
2252 };
2253 let rpc_response = JsonRpcResponse {
2254 jsonrpc: "2.0".to_string(),
2255 id: request.id,
2256 result: Some(rpc_result),
2257 error: None,
2258 };
2259 let _ = client.send_response(&rpc_response).await;
2260 }
2261
2262 "exitPlanMode.request" => {
2263 let params = request
2264 .params
2265 .as_ref()
2266 .cloned()
2267 .unwrap_or(Value::Object(serde_json::Map::new()));
2268 let data: ExitPlanModeData = match serde_json::from_value(params) {
2269 Ok(d) => d,
2270 Err(e) => {
2271 warn!(error = %e, "failed to deserialize exitPlanMode.request params, using defaults");
2272 ExitPlanModeData::default()
2273 }
2274 };
2275
2276 let rpc_result = if let Some(exit_plan_handler) = handlers.exit_plan_mode.as_ref() {
2277 let result = exit_plan_handler.handle(sid, data).await;
2278 serde_json::to_value(result).expect("ExitPlanModeResult serialization cannot fail")
2279 } else {
2280 serde_json::json!({ "approved": true })
2281 };
2282 let rpc_response = JsonRpcResponse {
2283 jsonrpc: "2.0".to_string(),
2284 id: request.id,
2285 result: Some(rpc_result),
2286 error: None,
2287 };
2288 let _ = client.send_response(&rpc_response).await;
2289 }
2290
2291 "autoModeSwitch.request" => {
2292 let error_code = request
2293 .params
2294 .as_ref()
2295 .and_then(|p| p.get("errorCode"))
2296 .and_then(|v| v.as_str())
2297 .map(|s| s.to_string());
2298 let retry_after_seconds = request
2299 .params
2300 .as_ref()
2301 .and_then(|p| p.get("retryAfterSeconds"))
2302 .and_then(|v| v.as_f64());
2303
2304 let answer = if let Some(auto_mode_handler) = handlers.auto_mode_switch.as_ref() {
2305 auto_mode_handler
2306 .handle(sid, error_code, retry_after_seconds)
2307 .await
2308 } else {
2309 AutoModeSwitchResponse::No
2310 };
2311 let rpc_response = JsonRpcResponse {
2312 jsonrpc: "2.0".to_string(),
2313 id: request.id,
2314 result: Some(serde_json::json!({ "response": answer })),
2315 error: None,
2316 };
2317 let _ = client.send_response(&rpc_response).await;
2318 }
2319
2320 "systemMessage.transform" => {
2321 let params = request.params.as_ref();
2322 let sections: HashMap<String, crate::transforms::TransformSection> =
2323 match params.and_then(|p| p.get("sections")) {
2324 Some(v) => match serde_json::from_value(v.clone()) {
2325 Ok(s) => s,
2326 Err(e) => {
2327 let _ = send_error_response(
2328 client,
2329 request.id,
2330 error_codes::INVALID_PARAMS,
2331 &format!("invalid sections: {e}"),
2332 )
2333 .await;
2334 return;
2335 }
2336 },
2337 None => {
2338 let _ = send_error_response(
2339 client,
2340 request.id,
2341 error_codes::INVALID_PARAMS,
2342 "missing sections parameter",
2343 )
2344 .await;
2345 return;
2346 }
2347 };
2348
2349 let rpc_result = if let Some(transforms) = transforms {
2350 let transform_start = Instant::now();
2351 let response =
2352 crate::transforms::dispatch_transform(transforms, &sid, sections).await;
2353 tracing::debug!(
2354 elapsed_ms = transform_start.elapsed().as_millis(),
2355 session_id = %sid,
2356 "SystemMessageTransform::transform_section dispatch"
2357 );
2358 match serde_json::to_value(response) {
2359 Ok(v) => v,
2360 Err(e) => {
2361 warn!(error = %e, "failed to serialize transform response");
2362 serde_json::json!({ "sections": {} })
2363 }
2364 }
2365 } else {
2366 let passthrough: HashMap<String, crate::transforms::TransformSection> = sections;
2368 serde_json::json!({ "sections": passthrough })
2369 };
2370
2371 let rpc_response = JsonRpcResponse {
2372 jsonrpc: "2.0".to_string(),
2373 id: request.id,
2374 result: Some(rpc_result),
2375 error: None,
2376 };
2377 let _ = client.send_response(&rpc_response).await;
2378 }
2379
2380 method => {
2381 warn!(
2382 method = method,
2383 "unhandled request method in session event loop"
2384 );
2385 let _ = send_error_response(
2386 client,
2387 request.id,
2388 error_codes::METHOD_NOT_FOUND,
2389 &format!("unknown method: {method}"),
2390 )
2391 .await;
2392 }
2393 }
2394}
2395
2396async fn send_error_response(
2397 client: &Client,
2398 id: u64,
2399 code: i32,
2400 message: &str,
2401) -> Result<(), Error> {
2402 let response = JsonRpcResponse {
2403 jsonrpc: "2.0".to_string(),
2404 id,
2405 result: None,
2406 error: Some(crate::JsonRpcError {
2407 code,
2408 message: message.to_string(),
2409 data: None,
2410 }),
2411 };
2412 client.send_response(&response).await
2413}
2414
2415fn apply_transform_sections(
2419 sys_msg: &mut SystemMessageConfig,
2420 transforms: &dyn SystemMessageTransform,
2421) {
2422 sys_msg.mode = Some("customize".to_string());
2423 let sections = sys_msg.sections.get_or_insert_with(HashMap::new);
2424 for id in transforms.section_ids() {
2425 sections.entry(id).or_insert_with(|| SectionOverride {
2426 action: Some("transform".to_string()),
2427 content: None,
2428 });
2429 }
2430}
2431
2432fn inject_transform_sections(config: &mut SessionConfig, transforms: &dyn SystemMessageTransform) {
2433 let sys_msg = config.system_message.get_or_insert_with(Default::default);
2434 apply_transform_sections(sys_msg, transforms);
2435}
2436
2437fn inject_transform_sections_resume(
2438 config: &mut ResumeSessionConfig,
2439 transforms: &dyn SystemMessageTransform,
2440) {
2441 let sys_msg = config.system_message.get_or_insert_with(Default::default);
2442 apply_transform_sections(sys_msg, transforms);
2443}
2444
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::notification_permission_payload;
    use crate::handler::PermissionResult;

    #[test]
    fn notification_payload_suppresses_no_result() {
        let payload = notification_permission_payload(&PermissionResult::NoResult);
        assert_eq!(payload, None);
    }

    #[test]
    fn notification_payload_serializes_decisions() {
        // Each decision paired with the wire payload it must serialize to.
        let cases = [
            (PermissionResult::approve_once(), json!({ "kind": "approve-once" })),
            (PermissionResult::reject(None), json!({ "kind": "reject" })),
            (
                PermissionResult::reject(Some("bad".to_string())),
                json!({ "kind": "reject", "feedback": "bad" }),
            ),
            (
                PermissionResult::user_not_available(),
                json!({ "kind": "user-not-available" }),
            ),
        ];
        for (decision, expected) in cases {
            assert_eq!(notification_permission_payload(&decision), Some(expected));
        }
    }
}