1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4use std::time::Duration;
5
6use parking_lot::Mutex as ParkingLotMutex;
7use serde_json::Value;
8use tokio::sync::oneshot;
9use tokio::task::JoinHandle;
10use tokio_util::sync::CancellationToken;
11use tracing::{Instrument, warn};
12
13use crate::generated::api_types::{
14 LogRequest, ModelSwitchToRequest, PermissionDecision, PermissionDecisionApproveOnce,
15 PermissionDecisionApproveOnceKind, PermissionDecisionReject, PermissionDecisionRejectKind,
16};
17use crate::generated::session_events::{
18 CommandExecuteData, ElicitationRequestedData, ExternalToolRequestedData, SessionErrorData,
19 SessionEventType,
20};
21use crate::handler::{
22 AutoModeSwitchResponse, ExitPlanModeResult, HandlerEvent, HandlerResponse, PermissionResult,
23 SessionHandler, UserInputResponse,
24};
25use crate::hooks::SessionHooks;
26use crate::session_fs::SessionFsProvider;
27use crate::trace_context::inject_trace_context;
28use crate::transforms::SystemMessageTransform;
29use crate::types::{
30 CommandContext, CommandDefinition, CommandHandler, CreateSessionResult, ElicitationRequest,
31 ElicitationResult, ExitPlanModeData, GetMessagesResponse, InputOptions, MessageOptions,
32 PermissionRequestData, RequestId, ResumeSessionConfig, SectionOverride, SessionCapabilities,
33 SessionConfig, SessionEvent, SessionId, SetModelOptions, SystemMessageConfig, ToolInvocation,
34 ToolResult, ToolResultResponse, TraceContext, ensure_attachment_display_names,
35};
36use crate::{Client, Error, JsonRpcResponse, SessionError, SessionEventNotification, error_codes};
37
38struct IdleWaiter {
40 tx: oneshot::Sender<Result<Option<SessionEvent>, Error>>,
41 last_assistant_message: Option<SessionEvent>,
42}
43
44struct WaiterGuard {
56 slot: Arc<ParkingLotMutex<Option<IdleWaiter>>>,
57}
58
59impl Drop for WaiterGuard {
60 fn drop(&mut self) {
61 self.slot.lock().take();
62 }
63}
64
65pub struct Session {
77 id: SessionId,
78 cwd: PathBuf,
79 workspace_path: Option<PathBuf>,
80 remote_url: Option<String>,
81 client: Client,
82 event_loop: ParkingLotMutex<Option<JoinHandle<()>>>,
87 shutdown: CancellationToken,
101 idle_waiter: Arc<ParkingLotMutex<Option<IdleWaiter>>>,
108 capabilities: Arc<parking_lot::RwLock<SessionCapabilities>>,
110 event_tx: tokio::sync::broadcast::Sender<SessionEvent>,
112}
113
impl Session {
    /// Returns the session identifier.
    pub fn id(&self) -> &SessionId {
        &self.id
    }

    /// Returns the owning client's working directory, captured when this handle
    /// was created.
    pub fn cwd(&self) -> &PathBuf {
        &self.cwd
    }

    /// Returns the workspace path reported by the server, if any.
    pub fn workspace_path(&self) -> Option<&Path> {
        self.workspace_path.as_deref()
    }

    /// Returns the remote URL reported by the server, if any.
    pub fn remote_url(&self) -> Option<&str> {
        self.remote_url.as_deref()
    }

    /// Returns a clone of the current capabilities.
    ///
    /// The event loop replaces the stored value on capabilities-changed events, so
    /// successive calls may return different values.
    pub fn capabilities(&self) -> SessionCapabilities {
        self.capabilities.read().clone()
    }

    /// Returns a child of the session's shutdown token.
    ///
    /// The child is cancelled when the session cancels its shutdown token
    /// (`stop_event_loop`, `disconnect`, or drop). Cancelling the child itself does
    /// not stop the session.
    pub fn cancellation_token(&self) -> CancellationToken {
        self.shutdown.child_token()
    }

    /// Subscribes to events broadcast by the event loop.
    ///
    /// Only events sent after this call are received. The underlying broadcast
    /// channel is bounded; how lagging is surfaced depends on `EventSubscription`.
    pub fn subscribe(&self) -> crate::subscription::EventSubscription {
        crate::subscription::EventSubscription::new(self.event_tx.subscribe())
    }

    /// Returns the client this session issues RPCs through.
    pub fn client(&self) -> &Client {
        &self.client
    }

    /// Returns the generated, session-scoped RPC wrapper.
    pub fn rpc(&self) -> crate::generated::rpc::SessionRpc<'_> {
        crate::generated::rpc::SessionRpc { session: self }
    }

    /// Cancels the event loop and waits for its task to finish. Then fails any
    /// still-installed `send_and_wait` waiter with `SessionError::EventLoopClosed`.
    ///
    /// Calling this more than once is harmless: the join handle and the waiter are
    /// each taken at most once.
    pub async fn stop_event_loop(&self) {
        self.shutdown.cancel();
        // Take the handle in its own statement so the mutex guard is released
        // before awaiting.
        let handle = self.event_loop.lock().take();
        if let Some(handle) = handle {
            let _ = handle.await;
        }
        if let Some(waiter) = self.idle_waiter.lock().take() {
            // Best-effort: the waiting caller may already have gone away.
            let _ = waiter
                .tx
                .send(Err(Error::Session(SessionError::EventLoopClosed)));
        }
    }

    /// Sends a message without waiting for the turn to finish.
    ///
    /// Returns the `messageId` from the response, or an empty string if the
    /// response has none.
    ///
    /// # Errors
    /// `SessionError::SendWhileWaiting` while a `send_and_wait` call is in progress.
    /// Also returns any serialization or RPC error from `session.send`.
    pub async fn send(&self, opts: impl Into<MessageOptions>) -> Result<String, Error> {
        if self.idle_waiter.lock().is_some() {
            return Err(Error::Session(SessionError::SendWhileWaiting));
        }
        self.send_inner(opts.into()).await
    }

    /// Builds and issues the `session.send` RPC.
    ///
    /// `mode`, `attachments` and non-empty `requestHeaders` are only included when
    /// set. An explicit `traceparent`/`tracestate` on the options takes precedence
    /// over the client's resolved trace context.
    async fn send_inner(&self, opts: MessageOptions) -> Result<String, Error> {
        let mut params = serde_json::json!({
            "sessionId": self.id,
            "prompt": opts.prompt,
        });
        if let Some(m) = opts.mode {
            params["mode"] = serde_json::to_value(m)?;
        }
        if let Some(mut a) = opts.attachments {
            // Normalize attachment display names before serializing
            // (see `ensure_attachment_display_names`).
            ensure_attachment_display_names(&mut a);
            params["attachments"] = serde_json::to_value(a)?;
        }
        if let Some(headers) = opts.request_headers
            && !headers.is_empty()
        {
            params["requestHeaders"] = serde_json::to_value(headers)?;
        }
        let trace_ctx = if opts.traceparent.is_some() || opts.tracestate.is_some() {
            TraceContext {
                traceparent: opts.traceparent,
                tracestate: opts.tracestate,
            }
        } else {
            self.client.resolve_trace_context().await
        };
        inject_trace_context(&mut params, &trace_ctx);
        let result = self.client.call("session.send", Some(params)).await?;
        // A missing or non-string `messageId` yields an empty string.
        let message_id = result
            .get("messageId")
            .and_then(|v| v.as_str())
            .map(|s| s.to_string())
            .unwrap_or_default();
        Ok(message_id)
    }

    /// Sends a message and waits until the session reports idle.
    ///
    /// Resolves with the last assistant-message event observed while waiting, or
    /// `None` if there was none. The timeout (`opts.wait_timeout`, default 60
    /// seconds) covers both the `session.send` call and the wait.
    ///
    /// # Errors
    /// - `SessionError::SendWhileWaiting` if another `send_and_wait` is in progress.
    /// - `SessionError::AgentError` if the session reports an error event.
    /// - `SessionError::EventLoopClosed` if the event loop stops before idle.
    /// - `SessionError::Timeout` if the deadline elapses.
    /// - Any error from sending the message.
    pub async fn send_and_wait(
        &self,
        opts: impl Into<MessageOptions>,
    ) -> Result<Option<SessionEvent>, Error> {
        let opts = opts.into();
        let timeout_duration = opts.wait_timeout.unwrap_or(Duration::from_secs(60));
        let (tx, rx) = oneshot::channel();

        {
            // Check and install under one lock so two concurrent callers cannot
            // both succeed.
            let mut guard = self.idle_waiter.lock();
            if guard.is_some() {
                return Err(Error::Session(SessionError::SendWhileWaiting));
            }
            *guard = Some(IdleWaiter {
                tx,
                last_assistant_message: None,
            });
        }

        // Keeps this call's waiter from lingering in the slot on every exit path,
        // including timeout and cancellation of this future.
        let _waiter_guard = WaiterGuard {
            slot: self.idle_waiter.clone(),
        };

        // `rx` is moved into the timed future, so it is dropped together with that
        // future at the end of this statement.
        let result = tokio::time::timeout(timeout_duration, async {
            self.send_inner(opts).await?;
            match rx.await {
                Ok(result) => result,
                // Sender dropped without replying.
                Err(_) => Err(Error::Session(SessionError::EventLoopClosed)),
            }
        })
        .await;

        match result {
            Ok(inner) => inner,
            Err(_) => Err(Error::Session(SessionError::Timeout(timeout_duration))),
        }
    }

    /// Fetches the session's event history via `session.getMessages`.
    pub async fn get_messages(&self) -> Result<Vec<SessionEvent>, Error> {
        let result = self
            .client
            .call(
                "session.getMessages",
                Some(serde_json::json!({ "sessionId": self.id })),
            )
            .await?;
        let response: GetMessagesResponse = serde_json::from_value(result)?;
        Ok(response.events)
    }

    /// Requests that the server abort the current turn (`session.abort`).
    pub async fn abort(&self) -> Result<(), Error> {
        self.client
            .call(
                "session.abort",
                Some(serde_json::json!({ "sessionId": self.id })),
            )
            .await?;
        Ok(())
    }

    /// Switches the session's model through the generated `model().switch_to` RPC.
    ///
    /// `opts` defaults to `SetModelOptions::default()` when `None`.
    pub async fn set_model(&self, model: &str, opts: Option<SetModelOptions>) -> Result<(), Error> {
        let opts = opts.unwrap_or_default();
        let request = ModelSwitchToRequest {
            model_id: model.to_string(),
            reasoning_effort: opts.reasoning_effort,
            model_capabilities: opts.model_capabilities,
        };
        self.rpc().model().switch_to(request).await?;
        Ok(())
    }

    /// Destroys the session on the server (`session.destroy`). Then stops the event
    /// loop and unregisters the session from the client.
    ///
    /// If the RPC fails, the error is returned and the local teardown is skipped.
    pub async fn disconnect(&self) -> Result<(), Error> {
        self.client
            .call(
                "session.destroy",
                Some(serde_json::json!({ "sessionId": self.id })),
            )
            .await?;
        self.stop_event_loop().await;
        self.client.unregister_session(&self.id);
        Ok(())
    }

    /// Equivalent to [`Session::disconnect`].
    pub async fn destroy(&self) -> Result<(), Error> {
        self.disconnect().await
    }

    /// Writes a message to the session log through the generated `log` RPC.
    pub async fn log(
        &self,
        message: &str,
        opts: Option<crate::types::LogOptions>,
    ) -> Result<(), Error> {
        let opts = opts.unwrap_or_default();
        // Convert the public level type into the generated RPC level type via a
        // JSON round-trip.
        let level = match opts.level {
            Some(level) => Some(serde_json::from_value(serde_json::to_value(level)?)?),
            None => None,
        };
        let request = LogRequest {
            message: message.to_string(),
            level,
            ephemeral: opts.ephemeral,
            url: None,
        };
        self.rpc().log(request).await?;
        Ok(())
    }

    /// Returns the elicitation-based UI helpers (`elicitation`, `confirm`,
    /// `select`, `input`).
    pub fn ui(&self) -> SessionUi<'_> {
        SessionUi { session: self }
    }

    /// Fails with `SessionError::ElicitationNotSupported` unless the current
    /// capabilities report `ui.elicitation == Some(true)`.
    fn assert_elicitation(&self) -> Result<(), Error> {
        if self
            .capabilities
            .read()
            .ui
            .as_ref()
            .and_then(|u| u.elicitation)
            != Some(true)
        {
            return Err(Error::Session(SessionError::ElicitationNotSupported));
        }
        Ok(())
    }
}
510
511impl Drop for Session {
512 fn drop(&mut self) {
513 self.shutdown.cancel();
525 self.client.unregister_session(&self.id);
526 }
527}
528
529pub struct SessionUi<'a> {
536 session: &'a Session,
537}
538
539impl<'a> SessionUi<'a> {
540 pub async fn elicitation(
548 &self,
549 message: &str,
550 schema: Value,
551 ) -> Result<ElicitationResult, Error> {
552 self.session.assert_elicitation()?;
553 let result = self
554 .session
555 .client
556 .call(
557 "session.ui.elicitation",
558 Some(serde_json::json!({
559 "sessionId": self.session.id,
560 "message": message,
561 "requestedSchema": schema,
562 })),
563 )
564 .await?;
565 let elicitation: ElicitationResult = serde_json::from_value(result)?;
566 Ok(elicitation)
567 }
568
569 pub async fn confirm(&self, message: &str) -> Result<bool, Error> {
573 self.session.assert_elicitation()?;
574 let schema = serde_json::json!({
575 "type": "object",
576 "properties": {
577 "confirmed": {
578 "type": "boolean",
579 "default": true,
580 }
581 },
582 "required": ["confirmed"]
583 });
584 let result = self.elicitation(message, schema).await?;
585 Ok(result.action == "accept"
586 && result
587 .content
588 .and_then(|c| c.get("confirmed").and_then(|v| v.as_bool()))
589 == Some(true))
590 }
591
592 pub async fn select(&self, message: &str, options: &[&str]) -> Result<Option<String>, Error> {
596 self.session.assert_elicitation()?;
597 let schema = serde_json::json!({
598 "type": "object",
599 "properties": {
600 "selection": {
601 "type": "string",
602 "enum": options,
603 }
604 },
605 "required": ["selection"]
606 });
607 let result = self.elicitation(message, schema).await?;
608 if result.action != "accept" {
609 return Ok(None);
610 }
611 let selection = result.content.and_then(|c| {
612 c.get("selection")
613 .and_then(|v| v.as_str())
614 .map(String::from)
615 });
616 Ok(selection)
617 }
618
619 pub async fn input(
624 &self,
625 message: &str,
626 options: Option<&InputOptions<'_>>,
627 ) -> Result<Option<String>, Error> {
628 self.session.assert_elicitation()?;
629 let mut field = serde_json::json!({ "type": "string" });
630 if let Some(opts) = options {
631 if let Some(title) = opts.title {
632 field["title"] = Value::String(title.to_string());
633 }
634 if let Some(desc) = opts.description {
635 field["description"] = Value::String(desc.to_string());
636 }
637 if let Some(min) = opts.min_length {
638 field["minLength"] = Value::Number(min.into());
639 }
640 if let Some(max) = opts.max_length {
641 field["maxLength"] = Value::Number(max.into());
642 }
643 if let Some(fmt) = &opts.format {
644 field["format"] = Value::String(fmt.as_str().to_string());
645 }
646 if let Some(default) = opts.default {
647 field["default"] = Value::String(default.to_string());
648 }
649 }
650 let schema = serde_json::json!({
651 "type": "object",
652 "properties": { "value": field },
653 "required": ["value"]
654 });
655 let result = self.elicitation(message, schema).await?;
656 if result.action != "accept" {
657 return Ok(None);
658 }
659 let value = result
660 .content
661 .and_then(|c| c.get("value").and_then(|v| v.as_str()).map(String::from));
662 Ok(value)
663 }
664}
665
666impl Client {
667 pub async fn create_session(&self, mut config: SessionConfig) -> Result<Session, Error> {
689 let handler = config
690 .handler
691 .take()
692 .unwrap_or_else(|| Arc::new(crate::handler::DenyAllHandler));
693 let hooks = config.hooks_handler.take();
694 let transforms = config.transform.take();
695 let command_handlers = build_command_handler_map(config.commands.as_deref());
696 let session_fs_provider = config.session_fs_provider.take();
697 if self.inner.session_fs_configured && session_fs_provider.is_none() {
698 return Err(Error::Session(SessionError::SessionFsProviderRequired));
699 }
700
701 if hooks.is_some() && config.hooks.is_none() {
702 config.hooks = Some(true);
703 }
704 if let Some(ref transforms) = transforms {
705 inject_transform_sections(&mut config, transforms.as_ref());
706 }
707 let mut params = serde_json::to_value(&config)?;
708 let trace_ctx = self.resolve_trace_context().await;
709 inject_trace_context(&mut params, &trace_ctx);
710 let result = self.call("session.create", Some(params)).await?;
711 let create_result: CreateSessionResult = serde_json::from_value(result)?;
712
713 let session_id = create_result.session_id;
714 let capabilities = Arc::new(parking_lot::RwLock::new(
715 create_result.capabilities.unwrap_or_default(),
716 ));
717 let channels = self.register_session(&session_id);
718
719 let idle_waiter = Arc::new(ParkingLotMutex::new(None));
720 let shutdown = CancellationToken::new();
721 let (event_tx, _) = tokio::sync::broadcast::channel(512);
722 let event_loop = spawn_event_loop(
723 session_id.clone(),
724 self.clone(),
725 handler,
726 hooks,
727 transforms,
728 command_handlers,
729 session_fs_provider,
730 channels,
731 idle_waiter.clone(),
732 capabilities.clone(),
733 event_tx.clone(),
734 shutdown.clone(),
735 );
736
737 Ok(Session {
738 id: session_id,
739 cwd: self.cwd().clone(),
740 workspace_path: create_result.workspace_path,
741 remote_url: create_result.remote_url,
742 client: self.clone(),
743 event_loop: ParkingLotMutex::new(Some(event_loop)),
744 shutdown,
745 idle_waiter,
746 capabilities,
747 event_tx,
748 })
749 }
750
751 pub async fn resume_session(&self, mut config: ResumeSessionConfig) -> Result<Session, Error> {
762 let handler = config
763 .handler
764 .take()
765 .unwrap_or_else(|| Arc::new(crate::handler::DenyAllHandler));
766 let hooks = config.hooks_handler.take();
767 let transforms = config.transform.take();
768 let command_handlers = build_command_handler_map(config.commands.as_deref());
769 let session_fs_provider = config.session_fs_provider.take();
770 if self.inner.session_fs_configured && session_fs_provider.is_none() {
771 return Err(Error::Session(SessionError::SessionFsProviderRequired));
772 }
773
774 if hooks.is_some() && config.hooks.is_none() {
775 config.hooks = Some(true);
776 }
777 if let Some(ref transforms) = transforms {
778 inject_transform_sections_resume(&mut config, transforms.as_ref());
779 }
780 let session_id = config.session_id.clone();
781 let mut params = serde_json::to_value(&config)?;
782 let trace_ctx = self.resolve_trace_context().await;
783 inject_trace_context(&mut params, &trace_ctx);
784 let result = self.call("session.resume", Some(params)).await?;
785
786 let cli_session_id: SessionId = result
788 .get("sessionId")
789 .and_then(|v| v.as_str())
790 .unwrap_or(&session_id)
791 .into();
792
793 let resume_capabilities: Option<SessionCapabilities> = result
794 .get("capabilities")
795 .and_then(|v| {
796 serde_json::from_value(v.clone())
797 .map_err(|e| warn!(error = %e, "failed to deserialize capabilities from resume response"))
798 .ok()
799 });
800 let remote_url = result
801 .get("remoteUrl")
802 .or_else(|| result.get("remote_url"))
803 .and_then(|value| value.as_str())
804 .map(ToString::to_string);
805
806 if let Err(e) = self
808 .call(
809 "session.skills.reload",
810 Some(serde_json::json!({ "sessionId": cli_session_id })),
811 )
812 .await
813 {
814 warn!(error = %e, "failed to reload skills after resume");
815 }
816
817 let capabilities = Arc::new(parking_lot::RwLock::new(
818 resume_capabilities.unwrap_or_default(),
819 ));
820 let channels = self.register_session(&cli_session_id);
821
822 let idle_waiter = Arc::new(ParkingLotMutex::new(None));
823 let shutdown = CancellationToken::new();
824 let (event_tx, _) = tokio::sync::broadcast::channel(512);
825 let event_loop = spawn_event_loop(
826 cli_session_id.clone(),
827 self.clone(),
828 handler,
829 hooks,
830 transforms,
831 command_handlers,
832 session_fs_provider,
833 channels,
834 idle_waiter.clone(),
835 capabilities.clone(),
836 event_tx.clone(),
837 shutdown.clone(),
838 );
839
840 Ok(Session {
841 id: cli_session_id,
842 cwd: self.cwd().clone(),
843 workspace_path: None,
844 remote_url,
845 client: self.clone(),
846 event_loop: ParkingLotMutex::new(Some(event_loop)),
847 shutdown,
848 idle_waiter,
849 capabilities,
850 event_tx,
851 })
852 }
853}
854
855type CommandHandlerMap = HashMap<String, Arc<dyn CommandHandler>>;
856
857fn build_command_handler_map(commands: Option<&[CommandDefinition]>) -> Arc<CommandHandlerMap> {
858 let map = match commands {
859 Some(commands) => commands
860 .iter()
861 .filter(|cmd| !cmd.name.is_empty())
862 .map(|cmd| (cmd.name.clone(), cmd.handler.clone()))
863 .collect(),
864 None => HashMap::new(),
865 };
866 Arc::new(map)
867}
868
869#[allow(clippy::too_many_arguments)]
870fn spawn_event_loop(
871 session_id: SessionId,
872 client: Client,
873 handler: Arc<dyn SessionHandler>,
874 hooks: Option<Arc<dyn SessionHooks>>,
875 transforms: Option<Arc<dyn SystemMessageTransform>>,
876 command_handlers: Arc<CommandHandlerMap>,
877 session_fs_provider: Option<Arc<dyn SessionFsProvider>>,
878 channels: crate::router::SessionChannels,
879 idle_waiter: Arc<ParkingLotMutex<Option<IdleWaiter>>>,
880 capabilities: Arc<parking_lot::RwLock<SessionCapabilities>>,
881 event_tx: tokio::sync::broadcast::Sender<SessionEvent>,
882 shutdown: CancellationToken,
883) -> JoinHandle<()> {
884 let crate::router::SessionChannels {
885 mut notifications,
886 mut requests,
887 } = channels;
888
889 let span = tracing::error_span!("session_event_loop", session_id = %session_id);
890 tokio::spawn(
891 async move {
892 loop {
893 tokio::select! {
904 _ = shutdown.cancelled() => break,
905 Some(notification) = notifications.recv() => {
906 handle_notification(
907 &session_id, &client, &handler, &command_handlers, notification, &idle_waiter, &capabilities, &event_tx,
908 ).await;
909 }
910 Some(request) = requests.recv() => {
911 handle_request(
912 &session_id, &client, &handler, hooks.as_deref(), transforms.as_deref(), session_fs_provider.as_ref(), request,
913 ).await;
914 }
915 else => break,
916 }
917 }
918 if let Some(waiter) = idle_waiter.lock().take() {
921 let _ = waiter
922 .tx
923 .send(Err(Error::Session(SessionError::EventLoopClosed)));
924 }
925 }
926 .instrument(span),
927 )
928}
929
930fn extract_request_id(data: &Value) -> Option<RequestId> {
931 data.get("requestId")
932 .and_then(|v| v.as_str())
933 .filter(|s| !s.is_empty())
934 .map(RequestId::new)
935}
936
937fn pending_permission_result_kind(response: &HandlerResponse) -> &'static str {
938 match response {
939 HandlerResponse::Permission(PermissionResult::Approved) => "approve-once",
940 HandlerResponse::Permission(PermissionResult::Denied) => "reject",
941 HandlerResponse::Permission(PermissionResult::NoResult) => "no-result",
942 _ => "user-not-available",
946 }
947}
948
949fn permission_request_response(response: &HandlerResponse) -> PermissionDecision {
950 match response {
951 HandlerResponse::Permission(PermissionResult::Approved) => {
952 PermissionDecision::ApproveOnce(PermissionDecisionApproveOnce {
953 kind: PermissionDecisionApproveOnceKind::ApproveOnce,
954 })
955 }
956 _ => PermissionDecision::Reject(PermissionDecisionReject {
957 kind: PermissionDecisionRejectKind::Reject,
958 feedback: None,
959 }),
960 }
961}
962
963fn notification_permission_payload(response: &HandlerResponse) -> Option<Value> {
970 match response {
971 HandlerResponse::Permission(PermissionResult::Deferred) => None,
972 HandlerResponse::Permission(PermissionResult::Custom(value)) => Some(value.clone()),
973 _ => Some(serde_json::json!({
974 "kind": pending_permission_result_kind(response),
975 })),
976 }
977}
978
979fn direct_permission_payload(response: &HandlerResponse) -> Value {
986 match response {
987 HandlerResponse::Permission(PermissionResult::Custom(value)) => value.clone(),
988 HandlerResponse::Permission(PermissionResult::Deferred) => serde_json::to_value(
989 permission_request_response(&HandlerResponse::Permission(PermissionResult::Approved)),
990 )
991 .expect("serializing direct permission response should succeed"),
992 HandlerResponse::Permission(PermissionResult::NoResult)
993 | HandlerResponse::Permission(PermissionResult::UserNotAvailable) => serde_json::json!({
994 "kind": pending_permission_result_kind(response),
995 }),
996 _ => serde_json::to_value(permission_request_response(response))
997 .expect("serializing direct permission response should succeed"),
998 }
999}
1000
1001#[allow(clippy::too_many_arguments)]
1003async fn handle_notification(
1004 session_id: &SessionId,
1005 client: &Client,
1006 handler: &Arc<dyn SessionHandler>,
1007 command_handlers: &Arc<CommandHandlerMap>,
1008 notification: SessionEventNotification,
1009 idle_waiter: &Arc<ParkingLotMutex<Option<IdleWaiter>>>,
1010 capabilities: &Arc<parking_lot::RwLock<SessionCapabilities>>,
1011 event_tx: &tokio::sync::broadcast::Sender<SessionEvent>,
1012) {
1013 let event = notification.event.clone();
1014 let event_type = event.parsed_type();
1015
1016 match event_type {
1019 SessionEventType::AssistantMessage
1020 | SessionEventType::SessionIdle
1021 | SessionEventType::SessionError => {
1022 let mut guard = idle_waiter.lock();
1023 if let Some(waiter) = guard.as_mut() {
1024 match event_type {
1025 SessionEventType::AssistantMessage => {
1026 waiter.last_assistant_message = Some(event.clone());
1027 }
1028 SessionEventType::SessionIdle | SessionEventType::SessionError => {
1029 if let Some(waiter) = guard.take() {
1030 if event_type == SessionEventType::SessionIdle {
1031 let _ = waiter.tx.send(Ok(waiter.last_assistant_message));
1032 } else {
1033 let error_msg = event
1034 .typed_data::<SessionErrorData>()
1035 .map(|d| d.message)
1036 .or_else(|| {
1037 event
1038 .data
1039 .get("message")
1040 .and_then(|v| v.as_str())
1041 .map(|s| s.to_string())
1042 })
1043 .unwrap_or_else(|| "session error".to_string());
1044 let _ = waiter
1045 .tx
1046 .send(Err(Error::Session(SessionError::AgentError(error_msg))));
1047 }
1048 }
1049 }
1050 _ => {}
1051 }
1052 }
1053 }
1054 _ => {}
1055 }
1056
1057 let _ = event_tx.send(event.clone());
1061
1062 handler
1064 .on_event(HandlerEvent::SessionEvent {
1065 session_id: session_id.clone(),
1066 event,
1067 })
1068 .await;
1069
1070 if event_type == SessionEventType::CapabilitiesChanged {
1074 match serde_json::from_value::<SessionCapabilities>(notification.event.data.clone()) {
1075 Ok(changed) => *capabilities.write() = changed,
1076 Err(e) => warn!(error = %e, "failed to deserialize capabilities.changed payload"),
1077 }
1078 }
1079
1080 match event_type {
1083 SessionEventType::PermissionRequested => {
1084 let Some(request_id) = extract_request_id(¬ification.event.data) else {
1085 return;
1086 };
1087 let client = client.clone();
1088 let handler = handler.clone();
1089 let sid = session_id.clone();
1090 let data: PermissionRequestData =
1091 serde_json::from_value(notification.event.data.clone()).unwrap_or_else(|_| {
1092 PermissionRequestData {
1093 kind: None,
1094 tool_call_id: None,
1095 extra: notification.event.data.clone(),
1096 }
1097 });
1098 tokio::spawn(async move {
1099 let response = handler
1100 .on_event(HandlerEvent::PermissionRequest {
1101 session_id: sid.clone(),
1102 request_id: request_id.clone(),
1103 data,
1104 })
1105 .await;
1106 let Some(result_value) = notification_permission_payload(&response) else {
1107 return;
1110 };
1111 let _ = client
1112 .call(
1113 "session.permissions.handlePendingPermissionRequest",
1114 Some(serde_json::json!({
1115 "sessionId": sid,
1116 "requestId": request_id,
1117 "result": result_value,
1118 })),
1119 )
1120 .await;
1121 });
1122 }
1123 SessionEventType::ExternalToolRequested => {
1124 let Some(request_id) = extract_request_id(¬ification.event.data) else {
1125 return;
1126 };
1127 let data: ExternalToolRequestedData =
1128 match serde_json::from_value(notification.event.data.clone()) {
1129 Ok(d) => d,
1130 Err(e) => {
1131 warn!(error = %e, "failed to deserialize external_tool.requested");
1132 let client = client.clone();
1133 let sid = session_id.clone();
1134 tokio::spawn(async move {
1135 let _ = client
1136 .call(
1137 "session.tools.handlePendingToolCall",
1138 Some(serde_json::json!({
1139 "sessionId": sid,
1140 "requestId": request_id,
1141 "error": format!("Failed to deserialize tool request: {e}"),
1142 })),
1143 )
1144 .await;
1145 });
1146 return;
1147 }
1148 };
1149 let client = client.clone();
1150 let handler = handler.clone();
1151 let sid = session_id.clone();
1152 tokio::spawn(async move {
1153 if data.tool_call_id.is_empty() || data.tool_name.is_empty() {
1154 let error_msg = if data.tool_call_id.is_empty() {
1155 "Missing toolCallId"
1156 } else {
1157 "Missing toolName"
1158 };
1159 let _ = client
1160 .call(
1161 "session.tools.handlePendingToolCall",
1162 Some(serde_json::json!({
1163 "sessionId": sid,
1164 "requestId": request_id,
1165 "error": error_msg,
1166 })),
1167 )
1168 .await;
1169 return;
1170 }
1171 let invocation = ToolInvocation {
1172 session_id: sid.clone(),
1173 tool_call_id: data.tool_call_id,
1174 tool_name: data.tool_name,
1175 arguments: data
1176 .arguments
1177 .unwrap_or(Value::Object(serde_json::Map::new())),
1178 traceparent: data.traceparent,
1179 tracestate: data.tracestate,
1180 };
1181 let response = handler
1182 .on_event(HandlerEvent::ExternalTool { invocation })
1183 .await;
1184 let tool_result = match response {
1185 HandlerResponse::ToolResult(r) => r,
1186 _ => ToolResult::Text("Unexpected handler response".to_string()),
1187 };
1188 let result_value = serde_json::to_value(&tool_result).unwrap_or(Value::Null);
1189 let _ = client
1190 .call(
1191 "session.tools.handlePendingToolCall",
1192 Some(serde_json::json!({
1193 "sessionId": sid,
1194 "requestId": request_id,
1195 "result": result_value,
1196 })),
1197 )
1198 .await;
1199 });
1200 }
1201 SessionEventType::UserInputRequested => {
1202 }
1209 SessionEventType::ElicitationRequested => {
1210 let Some(request_id) = extract_request_id(¬ification.event.data) else {
1211 return;
1212 };
1213 let elicitation_data: ElicitationRequestedData =
1214 match serde_json::from_value(notification.event.data.clone()) {
1215 Ok(d) => d,
1216 Err(e) => {
1217 warn!(error = %e, "failed to deserialize elicitation request");
1218 return;
1219 }
1220 };
1221 let request = ElicitationRequest {
1222 message: elicitation_data.message,
1223 requested_schema: elicitation_data
1224 .requested_schema
1225 .map(|s| serde_json::to_value(s).unwrap_or(Value::Null)),
1226 mode: elicitation_data.mode.map(|m| match m {
1227 crate::generated::session_events::ElicitationRequestedMode::Form => {
1228 crate::types::ElicitationMode::Form
1229 }
1230 crate::generated::session_events::ElicitationRequestedMode::Url => {
1231 crate::types::ElicitationMode::Url
1232 }
1233 _ => crate::types::ElicitationMode::Unknown,
1234 }),
1235 elicitation_source: elicitation_data.elicitation_source,
1236 url: elicitation_data.url,
1237 };
1238 let client = client.clone();
1239 let handler = handler.clone();
1240 let sid = session_id.clone();
1241 tokio::spawn(async move {
1242 let cancel = ElicitationResult {
1243 action: "cancel".to_string(),
1244 content: None,
1245 };
1246 let handler_task = tokio::spawn({
1249 let sid = sid.clone();
1250 let request_id = request_id.clone();
1251 async move {
1252 handler
1253 .on_event(HandlerEvent::ElicitationRequest {
1254 session_id: sid,
1255 request_id,
1256 request,
1257 })
1258 .await
1259 }
1260 });
1261 let result = match handler_task.await {
1262 Ok(HandlerResponse::Elicitation(r)) => r,
1263 _ => cancel.clone(),
1264 };
1265 if let Err(e) = client
1266 .call(
1267 "session.ui.handlePendingElicitation",
1268 Some(serde_json::json!({
1269 "sessionId": sid,
1270 "requestId": request_id,
1271 "result": result,
1272 })),
1273 )
1274 .await
1275 {
1276 warn!(error = %e, "handlePendingElicitation failed, sending cancel");
1278 let _ = client
1279 .call(
1280 "session.ui.handlePendingElicitation",
1281 Some(serde_json::json!({
1282 "sessionId": sid,
1283 "requestId": request_id,
1284 "result": cancel,
1285 })),
1286 )
1287 .await;
1288 }
1289 });
1290 }
1291 SessionEventType::CommandExecute => {
1292 let data: CommandExecuteData =
1293 match serde_json::from_value(notification.event.data.clone()) {
1294 Ok(d) => d,
1295 Err(e) => {
1296 warn!(error = %e, "failed to deserialize command.execute");
1297 return;
1298 }
1299 };
1300 let client = client.clone();
1301 let command_handlers = command_handlers.clone();
1302 let sid = session_id.clone();
1303 tokio::spawn(async move {
1304 let request_id = data.request_id;
1305 let ack_error = match command_handlers.get(&data.command_name).cloned() {
1306 None => Some(format!("Unknown command: {}", data.command_name)),
1307 Some(handler) => {
1308 let ctx = CommandContext {
1309 session_id: sid.clone(),
1310 command: data.command,
1311 command_name: data.command_name,
1312 args: data.args,
1313 };
1314 match handler.on_command(ctx).await {
1315 Ok(()) => None,
1316 Err(e) => Some(e.to_string()),
1317 }
1318 }
1319 };
1320 let mut params = serde_json::json!({
1321 "sessionId": sid,
1322 "requestId": request_id,
1323 });
1324 if let Some(error_msg) = ack_error {
1325 params["error"] = serde_json::Value::String(error_msg);
1326 }
1327 let _ = client
1328 .call("session.commands.handlePendingCommand", Some(params))
1329 .await;
1330 });
1331 }
1332 _ => {}
1333 }
1334}
1335
1336async fn handle_request(
1338 session_id: &SessionId,
1339 client: &Client,
1340 handler: &Arc<dyn SessionHandler>,
1341 hooks: Option<&dyn SessionHooks>,
1342 transforms: Option<&dyn SystemMessageTransform>,
1343 session_fs_provider: Option<&Arc<dyn SessionFsProvider>>,
1344 request: crate::JsonRpcRequest,
1345) {
1346 let sid = session_id.clone();
1347
1348 if request.method.starts_with("sessionFs.") {
1349 crate::session_fs_dispatch::dispatch(client, session_fs_provider, request).await;
1350 return;
1351 }
1352
1353 match request.method.as_str() {
1354 "hooks.invoke" => {
1355 let params = request.params.as_ref();
1356 let hook_type = params
1357 .and_then(|p| p.get("hookType"))
1358 .and_then(|v| v.as_str())
1359 .unwrap_or("");
1360 let input = params
1361 .and_then(|p| p.get("input"))
1362 .cloned()
1363 .unwrap_or(Value::Object(Default::default()));
1364
1365 let rpc_result = if let Some(hooks) = hooks {
1366 match crate::hooks::dispatch_hook(hooks, &sid, hook_type, input).await {
1367 Ok(output) => output,
1368 Err(e) => {
1369 warn!(error = %e, hook_type = hook_type, "hook dispatch failed");
1370 serde_json::json!({ "output": {} })
1371 }
1372 }
1373 } else {
1374 serde_json::json!({ "output": {} })
1375 };
1376
1377 let rpc_response = JsonRpcResponse {
1378 jsonrpc: "2.0".to_string(),
1379 id: request.id,
1380 result: Some(rpc_result),
1381 error: None,
1382 };
1383 let _ = client.send_response(&rpc_response).await;
1384 }
1385
1386 "tool.call" => {
1387 let invocation: ToolInvocation = match request
1388 .params
1389 .as_ref()
1390 .and_then(|p| serde_json::from_value::<ToolInvocation>(p.clone()).ok())
1391 {
1392 Some(inv) => inv,
1393 None => {
1394 let _ = send_error_response(
1395 client,
1396 request.id,
1397 error_codes::INVALID_PARAMS,
1398 "invalid tool.call params",
1399 )
1400 .await;
1401 return;
1402 }
1403 };
1404 let response = handler
1405 .on_event(HandlerEvent::ExternalTool { invocation })
1406 .await;
1407 let tool_result = match response {
1408 HandlerResponse::ToolResult(r) => r,
1409 _ => ToolResult::Text("Unexpected handler response".to_string()),
1410 };
1411 let rpc_response = JsonRpcResponse {
1412 jsonrpc: "2.0".to_string(),
1413 id: request.id,
1414 result: Some(serde_json::json!(ToolResultResponse {
1415 result: tool_result
1416 })),
1417 error: None,
1418 };
1419 let _ = client.send_response(&rpc_response).await;
1420 }
1421
1422 "userInput.request" => {
1423 let params = request.params.as_ref();
1424 let Some(question) = params
1425 .and_then(|p| p.get("question"))
1426 .and_then(|v| v.as_str())
1427 else {
1428 warn!("userInput.request missing 'question' field");
1429 let rpc_response = JsonRpcResponse {
1430 jsonrpc: "2.0".to_string(),
1431 id: request.id,
1432 result: None,
1433 error: Some(crate::JsonRpcError {
1434 code: error_codes::INVALID_PARAMS,
1435 message: "missing required field: question".to_string(),
1436 data: None,
1437 }),
1438 };
1439 let _ = client.send_response(&rpc_response).await;
1440 return;
1441 };
1442 let question = question.to_string();
1443 let choices = params
1444 .and_then(|p| p.get("choices"))
1445 .and_then(|v| v.as_array())
1446 .map(|arr| {
1447 arr.iter()
1448 .filter_map(|v| v.as_str().map(|s| s.to_string()))
1449 .collect()
1450 });
1451 let allow_freeform = params
1452 .and_then(|p| p.get("allowFreeform"))
1453 .and_then(|v| v.as_bool());
1454
1455 let response = handler
1456 .on_event(HandlerEvent::UserInput {
1457 session_id: sid,
1458 question,
1459 choices,
1460 allow_freeform,
1461 })
1462 .await;
1463
1464 let rpc_result = match response {
1465 HandlerResponse::UserInput(Some(UserInputResponse {
1466 answer,
1467 was_freeform,
1468 })) => serde_json::json!({
1469 "answer": answer,
1470 "wasFreeform": was_freeform,
1471 }),
1472 _ => serde_json::json!({ "noResponse": true }),
1473 };
1474 let rpc_response = JsonRpcResponse {
1475 jsonrpc: "2.0".to_string(),
1476 id: request.id,
1477 result: Some(rpc_result),
1478 error: None,
1479 };
1480 let _ = client.send_response(&rpc_response).await;
1481 }
1482
1483 "exitPlanMode.request" => {
1484 let params = request
1485 .params
1486 .as_ref()
1487 .cloned()
1488 .unwrap_or(Value::Object(serde_json::Map::new()));
1489 let data: ExitPlanModeData = match serde_json::from_value(params) {
1490 Ok(d) => d,
1491 Err(e) => {
1492 warn!(error = %e, "failed to deserialize exitPlanMode.request params, using defaults");
1493 ExitPlanModeData::default()
1494 }
1495 };
1496
1497 let response = handler
1498 .on_event(HandlerEvent::ExitPlanMode {
1499 session_id: sid,
1500 data,
1501 })
1502 .await;
1503
1504 let rpc_result = match response {
1505 HandlerResponse::ExitPlanMode(ExitPlanModeResult {
1506 approved,
1507 selected_action,
1508 feedback,
1509 }) => serde_json::json!({
1510 "approved": approved,
1511 "selectedAction": selected_action,
1512 "feedback": feedback,
1513 }),
1514 _ => serde_json::json!({ "approved": true }),
1515 };
1516 let rpc_response = JsonRpcResponse {
1517 jsonrpc: "2.0".to_string(),
1518 id: request.id,
1519 result: Some(rpc_result),
1520 error: None,
1521 };
1522 let _ = client.send_response(&rpc_response).await;
1523 }
1524
1525 "autoModeSwitch.request" => {
1526 let error_code = request
1527 .params
1528 .as_ref()
1529 .and_then(|p| p.get("errorCode"))
1530 .and_then(|v| v.as_str())
1531 .map(|s| s.to_string());
1532 let retry_after_seconds = request
1533 .params
1534 .as_ref()
1535 .and_then(|p| p.get("retryAfterSeconds"))
1536 .and_then(|v| v.as_u64());
1537
1538 let response = handler
1539 .on_event(HandlerEvent::AutoModeSwitch {
1540 session_id: sid,
1541 error_code,
1542 retry_after_seconds,
1543 })
1544 .await;
1545
1546 let answer = match response {
1547 HandlerResponse::AutoModeSwitch(answer) => answer,
1548 _ => AutoModeSwitchResponse::No,
1549 };
1550 let rpc_response = JsonRpcResponse {
1551 jsonrpc: "2.0".to_string(),
1552 id: request.id,
1553 result: Some(serde_json::json!({ "response": answer })),
1554 error: None,
1555 };
1556 let _ = client.send_response(&rpc_response).await;
1557 }
1558
1559 "permission.request" => {
1560 let Some(request_id) = request
1561 .params
1562 .as_ref()
1563 .and_then(|p| p.get("requestId"))
1564 .and_then(|v| v.as_str())
1565 .filter(|s| !s.is_empty())
1566 else {
1567 warn!("permission.request missing 'requestId' field");
1568 let rpc_response = JsonRpcResponse {
1569 jsonrpc: "2.0".to_string(),
1570 id: request.id,
1571 result: None,
1572 error: Some(crate::JsonRpcError {
1573 code: error_codes::INVALID_PARAMS,
1574 message: "missing required field: requestId".to_string(),
1575 data: None,
1576 }),
1577 };
1578 let _ = client.send_response(&rpc_response).await;
1579 return;
1580 };
1581 let request_id = RequestId::new(request_id);
1582 let raw_params = request
1583 .params
1584 .as_ref()
1585 .cloned()
1586 .unwrap_or(Value::Object(serde_json::Map::new()));
1587 let data: PermissionRequestData =
1588 serde_json::from_value(raw_params.clone()).unwrap_or(PermissionRequestData {
1589 kind: None,
1590 tool_call_id: None,
1591 extra: raw_params,
1592 });
1593
1594 let response = handler
1595 .on_event(HandlerEvent::PermissionRequest {
1596 session_id: sid,
1597 request_id,
1598 data,
1599 })
1600 .await;
1601 let rpc_response = JsonRpcResponse {
1602 jsonrpc: "2.0".to_string(),
1603 id: request.id,
1604 result: Some(direct_permission_payload(&response)),
1605 error: None,
1606 };
1607 let _ = client.send_response(&rpc_response).await;
1608 }
1609
1610 "systemMessage.transform" => {
1611 let params = request.params.as_ref();
1612 let sections: HashMap<String, crate::transforms::TransformSection> =
1613 match params.and_then(|p| p.get("sections")) {
1614 Some(v) => match serde_json::from_value(v.clone()) {
1615 Ok(s) => s,
1616 Err(e) => {
1617 let _ = send_error_response(
1618 client,
1619 request.id,
1620 error_codes::INVALID_PARAMS,
1621 &format!("invalid sections: {e}"),
1622 )
1623 .await;
1624 return;
1625 }
1626 },
1627 None => {
1628 let _ = send_error_response(
1629 client,
1630 request.id,
1631 error_codes::INVALID_PARAMS,
1632 "missing sections parameter",
1633 )
1634 .await;
1635 return;
1636 }
1637 };
1638
1639 let rpc_result = if let Some(transforms) = transforms {
1640 let response =
1641 crate::transforms::dispatch_transform(transforms, &sid, sections).await;
1642 match serde_json::to_value(response) {
1643 Ok(v) => v,
1644 Err(e) => {
1645 warn!(error = %e, "failed to serialize transform response");
1646 serde_json::json!({ "sections": {} })
1647 }
1648 }
1649 } else {
1650 let passthrough: HashMap<String, crate::transforms::TransformSection> = sections;
1652 serde_json::json!({ "sections": passthrough })
1653 };
1654
1655 let rpc_response = JsonRpcResponse {
1656 jsonrpc: "2.0".to_string(),
1657 id: request.id,
1658 result: Some(rpc_result),
1659 error: None,
1660 };
1661 let _ = client.send_response(&rpc_response).await;
1662 }
1663
1664 method => {
1665 warn!(
1666 method = method,
1667 "unhandled request method in session event loop"
1668 );
1669 let _ = send_error_response(
1670 client,
1671 request.id,
1672 error_codes::METHOD_NOT_FOUND,
1673 &format!("unknown method: {method}"),
1674 )
1675 .await;
1676 }
1677 }
1678}
1679
1680async fn send_error_response(
1681 client: &Client,
1682 id: u64,
1683 code: i32,
1684 message: &str,
1685) -> Result<(), Error> {
1686 let response = JsonRpcResponse {
1687 jsonrpc: "2.0".to_string(),
1688 id,
1689 result: None,
1690 error: Some(crate::JsonRpcError {
1691 code,
1692 message: message.to_string(),
1693 data: None,
1694 }),
1695 };
1696 client.send_response(&response).await
1697}
1698
1699fn apply_transform_sections(
1703 sys_msg: &mut SystemMessageConfig,
1704 transforms: &dyn SystemMessageTransform,
1705) {
1706 sys_msg.mode = Some("customize".to_string());
1707 let sections = sys_msg.sections.get_or_insert_with(HashMap::new);
1708 for id in transforms.section_ids() {
1709 sections.entry(id).or_insert_with(|| SectionOverride {
1710 action: Some("transform".to_string()),
1711 content: None,
1712 });
1713 }
1714}
1715
1716fn inject_transform_sections(config: &mut SessionConfig, transforms: &dyn SystemMessageTransform) {
1717 let sys_msg = config.system_message.get_or_insert_with(Default::default);
1718 apply_transform_sections(sys_msg, transforms);
1719}
1720
1721fn inject_transform_sections_resume(
1722 config: &mut ResumeSessionConfig,
1723 transforms: &dyn SystemMessageTransform,
1724) {
1725 let sys_msg = config.system_message.get_or_insert_with(Default::default);
1726 apply_transform_sections(sys_msg, transforms);
1727}
1728
#[cfg(test)]
mod tests {
    use serde_json::{Value, json};

    use super::{
        direct_permission_payload, notification_permission_payload, pending_permission_result_kind,
        permission_request_response,
    };
    use crate::handler::{HandlerResponse, PermissionResult};

    /// Wraps a permission outcome in the handler response the payload helpers consume.
    fn permission(result: PermissionResult) -> HandlerResponse {
        HandlerResponse::Permission(result)
    }

    /// A non-standard decision payload, used to check that `Custom` results are
    /// forwarded verbatim.
    fn custom_decision() -> Value {
        json!({
            "kind": "approve-and-remember",
            "allowlist": ["ls", "grep"],
        })
    }

    #[test]
    fn pending_permission_requests_use_decision_kinds() {
        let cases = [
            (permission(PermissionResult::Approved), "approve-once"),
            (permission(PermissionResult::Denied), "reject"),
            (HandlerResponse::Ok, "user-not-available"),
        ];
        for (response, expected_kind) in cases {
            assert_eq!(pending_permission_result_kind(&response), expected_kind);
        }
    }

    #[test]
    fn direct_permission_requests_use_decision_response_kinds() {
        let cases = [
            (
                permission(PermissionResult::Approved),
                "serializing approved permission response should succeed",
                json!({ "kind": "approve-once" }),
            ),
            (
                permission(PermissionResult::Denied),
                "serializing denied permission response should succeed",
                json!({ "kind": "reject" }),
            ),
            (
                HandlerResponse::Ok,
                "serializing fallback permission response should succeed",
                json!({ "kind": "reject" }),
            ),
        ];
        for (response, context, expected) in cases {
            let serialized =
                serde_json::to_value(permission_request_response(&response)).expect(context);
            assert_eq!(serialized, expected);
        }
    }

    #[test]
    fn notification_payload_handles_deferred_and_custom() {
        // A deferred decision yields no notification payload at all.
        let deferred = permission(PermissionResult::Deferred);
        assert!(notification_permission_payload(&deferred).is_none());

        // A custom payload is passed through unchanged.
        let custom = custom_decision();
        let custom_response = permission(PermissionResult::Custom(custom.clone()));
        assert_eq!(notification_permission_payload(&custom_response), Some(custom));

        // Standard outcomes map onto their decision kinds.
        let approved = permission(PermissionResult::Approved);
        assert_eq!(
            notification_permission_payload(&approved),
            Some(json!({ "kind": "approve-once" }))
        );
        let denied = permission(PermissionResult::Denied);
        assert_eq!(
            notification_permission_payload(&denied),
            Some(json!({ "kind": "reject" }))
        );
    }

    #[test]
    fn direct_payload_handles_deferred_and_custom() {
        // A custom payload is passed through unchanged.
        let custom = custom_decision();
        let custom_response = permission(PermissionResult::Custom(custom.clone()));
        assert_eq!(direct_permission_payload(&custom_response), custom);

        // Deferred, approved and denied outcomes, in the original assertion order.
        let cases = [
            (permission(PermissionResult::Deferred), json!({ "kind": "approve-once" })),
            (permission(PermissionResult::Approved), json!({ "kind": "approve-once" })),
            (permission(PermissionResult::Denied), json!({ "kind": "reject" })),
        ];
        for (response, expected) in cases {
            assert_eq!(direct_permission_payload(&response), expected);
        }
    }
}