1use async_trait::async_trait;
19use serde::{Deserialize, Serialize};
20
21use crate::generated::api_types::{
22 McpOauthPendingRequestResponse, McpOauthPendingRequestResponseCancelled,
23 McpOauthPendingRequestResponseCancelledKind, McpOauthPendingRequestResponseToken,
24 McpOauthPendingRequestResponseTokenKind, PermissionDecision, PermissionDecisionApproveOnce,
25 PermissionDecisionReject, PermissionDecisionUserNotAvailable,
26};
27use crate::session_events::{
28 McpOauthRequestReason, McpOauthRequiredStaticClientConfig, McpOauthWWWAuthenticateParams,
29};
30use crate::types::{
31 ElicitationRequest, ElicitationResult, ExitPlanModeData, PermissionRequestData, RequestId,
32 SessionId,
33};
34
/// Outcome a [`PermissionHandler`] returns for one permission request.
///
/// Build values with the helper constructors on this type
/// ([`PermissionResult::approve_once`], [`PermissionResult::reject`],
/// [`PermissionResult::user_not_available`], [`PermissionResult::no_result`])
/// or via `From<PermissionDecision>`.
#[derive(Debug, Clone)]
pub enum PermissionResult {
    /// A concrete permission decision for the request.
    Decision(PermissionDecision),
    /// The handler produced no decision.
    // NOTE(review): how the caller treats a missing decision is not visible in
    // this file — confirm before relying on a particular fallback.
    NoResult,
}
49
50impl PermissionResult {
51 pub fn approve_once() -> Self {
53 Self::Decision(PermissionDecision::ApproveOnce(
54 PermissionDecisionApproveOnce::default(),
55 ))
56 }
57
58 pub fn reject(feedback: impl Into<Option<String>>) -> Self {
60 Self::Decision(PermissionDecision::Reject(PermissionDecisionReject {
61 feedback: feedback.into(),
62 ..Default::default()
63 }))
64 }
65
66 pub fn user_not_available() -> Self {
68 Self::Decision(PermissionDecision::UserNotAvailable(
69 PermissionDecisionUserNotAvailable::default(),
70 ))
71 }
72
73 pub fn no_result() -> Self {
76 Self::NoResult
77 }
78}
79
80impl From<PermissionDecision> for PermissionResult {
81 fn from(value: PermissionDecision) -> Self {
82 Self::Decision(value)
83 }
84}
85
/// Answer returned by a [`UserInputHandler`].
#[derive(Debug, Clone)]
pub struct UserInputResponse {
    /// The answer text.
    pub answer: String,
    /// Presumably `true` when `answer` is free-form text rather than one of the
    /// offered `choices` — confirm against the consumer.
    pub was_freeform: bool,
}
94
/// Result returned by an [`ExitPlanModeHandler`].
///
/// Serializes with camelCase keys (`approved`, `selectedAction`, `feedback`).
/// Optional fields that are `None` are omitted from the serialized output.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ExitPlanModeResult {
    /// Whether the plan was approved.
    pub approved: bool,
    /// The chosen follow-up action, if any.
    // NOTE(review): the set of valid action strings is not visible here — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub selected_action: Option<String>,
    /// Optional feedback text.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub feedback: Option<String>,
}
108
109impl Default for ExitPlanModeResult {
110 fn default() -> Self {
111 Self {
112 approved: true,
113 selected_action: None,
114 feedback: None,
115 }
116 }
117}
118
/// Answer returned by an [`AutoModeSwitchHandler`].
///
/// Serialized in snake_case: `"yes"`, `"yes_always"`, `"no"`. The enum is
/// `#[non_exhaustive]`, so matches outside this crate need a wildcard arm.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AutoModeSwitchResponse {
    /// Agree to the switch.
    Yes,
    /// Agree to the switch; presumably also remembered for later prompts —
    /// confirm against the consumer.
    YesAlways,
    /// Decline the switch.
    No,
}
133
/// Decides how to answer permission requests.
///
/// Implementations must be `Send + Sync + 'static`. [`ApproveAllHandler`] and
/// [`DenyAllHandler`] are ready-made implementations.
#[async_trait]
pub trait PermissionHandler: Send + Sync + 'static {
    /// Produce a [`PermissionResult`] for request `request_id` in session
    /// `session_id`; `data` carries the request details.
    async fn handle(
        &self,
        session_id: SessionId,
        request_id: RequestId,
        data: PermissionRequestData,
    ) -> PermissionResult;
}
152
/// Handles elicitation requests.
///
/// Implementations must be `Send + Sync + 'static`.
#[async_trait]
pub trait ElicitationHandler: Send + Sync + 'static {
    /// Produce an [`ElicitationResult`] for `request`, identified by
    /// `request_id` within session `session_id`.
    async fn handle(
        &self,
        session_id: SessionId,
        request_id: RequestId,
        request: ElicitationRequest,
    ) -> ElicitationResult;
}
166
/// An MCP OAuth request handed to a [`McpAuthHandler`].
#[derive(Debug, Clone)]
pub struct McpAuthRequest {
    /// Identifier of this pending request.
    pub request_id: RequestId,
    /// Name of the MCP server this request is for.
    pub server_name: String,
    /// URL of that MCP server.
    pub server_url: String,
    /// Why OAuth is being requested.
    pub reason: McpOauthRequestReason,
    /// Parameters of the server's `WWW-Authenticate` challenge, if any
    /// (presumably parsed from the HTTP response header — confirm).
    pub www_authenticate_params: Option<McpOauthWWWAuthenticateParams>,
    /// Resource metadata, if provided.
    // NOTE(review): whether this is a metadata URL or an inline document is not
    // visible here — confirm.
    pub resource_metadata: Option<String>,
    /// Statically configured OAuth client settings, if any.
    pub static_client_config: Option<McpOauthRequiredStaticClientConfig>,
}
185
/// Outcome of a [`McpAuthHandler`].
///
/// Converted into the generated wire type by the crate-internal `into_wire`.
#[derive(Debug, Clone)]
pub enum McpAuthResult {
    /// An access token for the request.
    Token {
        /// The access token itself.
        access_token: String,
        /// Token type (for example `"Bearer"`), if known.
        token_type: Option<String>,
        /// Token lifetime, if known (presumably seconds, as in OAuth
        /// `expires_in` — confirm).
        expires_in: Option<i64>,
    },
    /// The authentication request was cancelled.
    Cancelled,
}
201
202impl McpAuthResult {
203 pub(crate) fn into_wire(self) -> McpOauthPendingRequestResponse {
204 match self {
205 Self::Token {
206 access_token,
207 token_type,
208 expires_in,
209 } => McpOauthPendingRequestResponse::Token(McpOauthPendingRequestResponseToken {
210 access_token,
211 token_type,
212 expires_in,
213 kind: McpOauthPendingRequestResponseTokenKind::Token,
214 }),
215 Self::Cancelled => {
216 McpOauthPendingRequestResponse::Cancelled(McpOauthPendingRequestResponseCancelled {
217 kind: McpOauthPendingRequestResponseCancelledKind::Cancelled,
218 })
219 }
220 }
221 }
222}
223
/// Resolves MCP OAuth requests with either a token or a cancellation.
///
/// Implementations must be `Send + Sync + 'static`.
#[async_trait]
pub trait McpAuthHandler: Send + Sync + 'static {
    /// Resolve `request` for session `session_id`, returning
    /// [`McpAuthResult::Token`] or [`McpAuthResult::Cancelled`].
    // NOTE(review): `request_id` is also carried in `request.request_id`;
    // presumably the two are always equal — confirm.
    async fn handle(
        &self,
        session_id: SessionId,
        request_id: RequestId,
        request: McpAuthRequest,
    ) -> McpAuthResult;
}
235
/// Supplies answers to user-input questions.
///
/// Implementations must be `Send + Sync + 'static`.
#[async_trait]
pub trait UserInputHandler: Send + Sync + 'static {
    /// Answer `question` for session `session_id`.
    ///
    /// `choices` lists offered options, if any. `allow_freeform` presumably
    /// says whether an answer outside `choices` is acceptable (confirm).
    /// Returns `None` when no answer is produced.
    async fn handle(
        &self,
        session_id: SessionId,
        question: String,
        choices: Option<Vec<String>>,
        allow_freeform: Option<bool>,
    ) -> Option<UserInputResponse>;
}
252
/// Decides the outcome of an exit-plan-mode request.
///
/// Implementations must be `Send + Sync + 'static`.
#[async_trait]
pub trait ExitPlanModeHandler: Send + Sync + 'static {
    /// Produce an [`ExitPlanModeResult`] for `data` in session `session_id`.
    async fn handle(&self, session_id: SessionId, data: ExitPlanModeData) -> ExitPlanModeResult;
}
260
/// Decides whether to accept an auto-mode switch.
///
/// Implementations must be `Send + Sync + 'static`.
#[async_trait]
pub trait AutoModeSwitchHandler: Send + Sync + 'static {
    /// Answer an auto-mode-switch prompt for session `session_id`.
    ///
    /// `error_code` and `retry_after_seconds` are optional details that come
    /// with the prompt. NOTE(review): presumably they describe the error that
    /// triggered the switch — confirm.
    async fn handle(
        &self,
        session_id: SessionId,
        error_code: Option<String>,
        retry_after_seconds: Option<f64>,
    ) -> AutoModeSwitchResponse;
}
275
276#[derive(Debug, Clone)]
280pub struct ApproveAllHandler;
281
282#[async_trait]
283impl PermissionHandler for ApproveAllHandler {
284 async fn handle(
285 &self,
286 _session_id: SessionId,
287 _request_id: RequestId,
288 _data: PermissionRequestData,
289 ) -> PermissionResult {
290 PermissionResult::approve_once()
291 }
292}
293
294#[derive(Debug, Clone)]
296pub struct DenyAllHandler;
297
298#[async_trait]
299impl PermissionHandler for DenyAllHandler {
300 async fn handle(
301 &self,
302 _session_id: SessionId,
303 _request_id: RequestId,
304 _data: PermissionRequestData,
305 ) -> PermissionResult {
306 PermissionResult::reject(None)
307 }
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn approve_all_handler_returns_approved() {
        let result = ApproveAllHandler
            .handle(
                SessionId::from("s1"),
                RequestId::new("1"),
                PermissionRequestData::default(),
            )
            .await;
        assert!(matches!(
            result,
            PermissionResult::Decision(PermissionDecision::ApproveOnce(_))
        ));
    }

    #[tokio::test]
    async fn deny_all_handler_returns_denied() {
        let result = DenyAllHandler
            .handle(
                SessionId::from("s1"),
                RequestId::new("1"),
                PermissionRequestData::default(),
            )
            .await;
        assert!(matches!(
            result,
            PermissionResult::Decision(PermissionDecision::Reject(_))
        ));
    }

    // `reject` must carry the caller's feedback through to the decision payload.
    #[test]
    fn reject_preserves_feedback() {
        let result = PermissionResult::reject(Some("not allowed".to_string()));
        match result {
            PermissionResult::Decision(PermissionDecision::Reject(reject)) => {
                assert_eq!(reject.feedback.as_deref(), Some("not allowed"));
            }
            other => panic!("expected reject decision, got {:?}", other),
        }
    }

    #[test]
    fn user_not_available_builds_matching_decision() {
        assert!(matches!(
            PermissionResult::user_not_available(),
            PermissionResult::Decision(PermissionDecision::UserNotAvailable(_))
        ));
    }

    #[test]
    fn no_result_builds_no_result_variant() {
        assert!(matches!(
            PermissionResult::no_result(),
            PermissionResult::NoResult
        ));
    }

    // `Default` is hand-written; pin that it approves (a derive would not).
    #[test]
    fn exit_plan_mode_result_defaults_to_approved() {
        let result = ExitPlanModeResult::default();
        assert!(result.approved);
        assert!(result.selected_action.is_none());
        assert!(result.feedback.is_none());
    }

    #[test]
    fn mcp_auth_result_token_converts_to_wire_response() {
        let wire = McpAuthResult::Token {
            access_token: "host-token".to_string(),
            token_type: Some("Bearer".to_string()),
            expires_in: Some(3600),
        }
        .into_wire();

        match wire {
            McpOauthPendingRequestResponse::Token(token) => {
                assert_eq!(token.access_token, "host-token");
                assert_eq!(token.token_type.as_deref(), Some("Bearer"));
                assert_eq!(token.expires_in, Some(3600));
            }
            McpOauthPendingRequestResponse::Cancelled(_) => panic!("expected token response"),
        }
    }

    #[test]
    fn mcp_auth_result_cancelled_converts_to_wire_response() {
        assert!(matches!(
            McpAuthResult::Cancelled.into_wire(),
            McpOauthPendingRequestResponse::Cancelled(_)
        ));
    }
}