use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::generated::api_types::{
McpOauthPendingRequestResponse, McpOauthPendingRequestResponseCancelled,
McpOauthPendingRequestResponseCancelledKind, McpOauthPendingRequestResponseToken,
McpOauthPendingRequestResponseTokenKind, PermissionDecision, PermissionDecisionApproveOnce,
PermissionDecisionReject, PermissionDecisionUserNotAvailable,
};
use crate::session_events::{
McpOauthRequestReason, McpOauthRequiredStaticClientConfig, McpOauthWWWAuthenticateParams,
};
use crate::types::{
ElicitationRequest, ElicitationResult, ExitPlanModeData, PermissionRequestData, RequestId,
SessionId,
};
/// Outcome produced by a [`PermissionHandler`] for one permission request.
///
/// Either carries a concrete [`PermissionDecision`], or is
/// [`PermissionResult::NoResult`] when the handler supplies no decision. How a
/// `NoResult` is treated is decided by whoever consumes this value, not here.
#[derive(Debug, Clone)]
pub enum PermissionResult {
    /// A concrete decision, e.g. approve-once, reject, or user-not-available.
    Decision(PermissionDecision),
    /// The handler did not produce a decision.
    NoResult,
}
impl PermissionResult {
pub fn approve_once() -> Self {
Self::Decision(PermissionDecision::ApproveOnce(
PermissionDecisionApproveOnce::default(),
))
}
pub fn reject(feedback: impl Into<Option<String>>) -> Self {
Self::Decision(PermissionDecision::Reject(PermissionDecisionReject {
feedback: feedback.into(),
..Default::default()
}))
}
pub fn user_not_available() -> Self {
Self::Decision(PermissionDecision::UserNotAvailable(
PermissionDecisionUserNotAvailable::default(),
))
}
pub fn no_result() -> Self {
Self::NoResult
}
}
impl From<PermissionDecision> for PermissionResult {
fn from(value: PermissionDecision) -> Self {
Self::Decision(value)
}
}
/// An answer returned by a [`UserInputHandler`].
///
/// This is plain data, so it also derives `PartialEq`/`Eq`. That lets callers and
/// tests compare responses directly instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UserInputResponse {
    /// The answer text.
    pub answer: String,
    /// Whether the answer is free-form text. Presumably `false` means it is one of
    /// the offered choices; confirm with the consumer.
    pub was_freeform: bool,
}
/// Result returned by an [`ExitPlanModeHandler`].
///
/// Serializes with camelCase keys (`approved`, `selectedAction`, `feedback`). The
/// optional fields are omitted from the output entirely when they are `None`.
/// Also derives `PartialEq`/`Eq` because it is plain data, so results can be
/// compared directly.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ExitPlanModeResult {
    /// Whether the plan was approved.
    pub approved: bool,
    /// Optional action chosen alongside the answer; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub selected_action: Option<String>,
    /// Optional free-text feedback; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub feedback: Option<String>,
}
impl Default for ExitPlanModeResult {
fn default() -> Self {
Self {
approved: true,
selected_action: None,
feedback: None,
}
}
}
/// The answer an [`AutoModeSwitchHandler`] gives to an auto-mode switch prompt.
///
/// Serialized and deserialized in `snake_case`, so the wire values are `"yes"`,
/// `"yes_always"` and `"no"`. The enum is `#[non_exhaustive]`, so downstream
/// `match` expressions must include a wildcard arm.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AutoModeSwitchResponse {
    /// Affirmative answer (`"yes"`).
    Yes,
    /// Affirmative answer (`"yes_always"`). Judging by the name, it should also
    /// cover future prompts; confirm with the consumer.
    YesAlways,
    /// Negative answer (`"no"`).
    No,
}
/// Decides how to answer permission requests.
///
/// The `Send + Sync + 'static` bounds let one implementation be shared across
/// threads and held for an arbitrary lifetime.
#[async_trait]
pub trait PermissionHandler: Send + Sync + 'static {
    /// Produces a [`PermissionResult`] for request `request_id` in session
    /// `session_id`, given the request payload `data`.
    ///
    /// Return [`PermissionResult::NoResult`] to supply no decision.
    async fn handle(
        &self,
        session_id: SessionId,
        request_id: RequestId,
        data: PermissionRequestData,
    ) -> PermissionResult;
}
/// Answers elicitation requests.
///
/// The `Send + Sync + 'static` bounds let one implementation be shared across
/// threads and held for an arbitrary lifetime.
#[async_trait]
pub trait ElicitationHandler: Send + Sync + 'static {
    /// Produces an [`ElicitationResult`] for `request`, which is identified by
    /// `request_id` within session `session_id`.
    async fn handle(
        &self,
        session_id: SessionId,
        request_id: RequestId,
        request: ElicitationRequest,
    ) -> ElicitationResult;
}
/// An MCP OAuth request handed to an [`McpAuthHandler`].
///
/// The optional fields carry extra OAuth context when it is available.
#[derive(Debug, Clone)]
pub struct McpAuthRequest {
    /// Identifier of this pending request.
    pub request_id: RequestId,
    /// Name of the MCP server that needs authorization.
    pub server_name: String,
    /// URL of that MCP server.
    pub server_url: String,
    /// Why the authorization is being requested.
    pub reason: McpOauthRequestReason,
    /// Parsed `WWW-Authenticate` parameters, if any were supplied.
    pub www_authenticate_params: Option<McpOauthWWWAuthenticateParams>,
    /// Optional resource-metadata value, as a string. Presumably a metadata URL;
    /// confirm.
    pub resource_metadata: Option<String>,
    /// Optional static client configuration, if one was supplied.
    pub static_client_config: Option<McpOauthRequiredStaticClientConfig>,
}
/// The outcome an [`McpAuthHandler`] reports for an MCP OAuth request.
///
/// It is converted to the generated wire type by `McpAuthResult::into_wire`.
#[derive(Debug, Clone)]
pub enum McpAuthResult {
    /// A token was obtained.
    Token {
        /// The access token string.
        access_token: String,
        /// Optional token type, e.g. `"Bearer"`.
        token_type: Option<String>,
        /// Optional expiry. Presumably in seconds, as in OAuth token responses;
        /// confirm.
        expires_in: Option<i64>,
    },
    /// The request was cancelled and no token is provided.
    Cancelled,
}
impl McpAuthResult {
    /// Converts this result into the generated wire type
    /// [`McpOauthPendingRequestResponse`] and sets the matching `kind` tag.
    ///
    /// The token fields are moved across unchanged.
    pub(crate) fn into_wire(self) -> McpOauthPendingRequestResponse {
        match self {
            Self::Cancelled => {
                let cancelled = McpOauthPendingRequestResponseCancelled {
                    kind: McpOauthPendingRequestResponseCancelledKind::Cancelled,
                };
                McpOauthPendingRequestResponse::Cancelled(cancelled)
            }
            Self::Token {
                access_token,
                token_type,
                expires_in,
            } => {
                let token = McpOauthPendingRequestResponseToken {
                    kind: McpOauthPendingRequestResponseTokenKind::Token,
                    access_token,
                    token_type,
                    expires_in,
                };
                McpOauthPendingRequestResponse::Token(token)
            }
        }
    }
}
/// Handles MCP OAuth requests by returning a token or cancelling.
///
/// The `Send + Sync + 'static` bounds let one implementation be shared across
/// threads and held for an arbitrary lifetime.
#[async_trait]
pub trait McpAuthHandler: Send + Sync + 'static {
    /// Produces an [`McpAuthResult`] for `request` in session `session_id`.
    ///
    /// NOTE(review): `request_id` is passed here and also stored in
    /// `request.request_id`. Presumably they are the same value; confirm at the
    /// call site.
    async fn handle(
        &self,
        session_id: SessionId,
        request_id: RequestId,
        request: McpAuthRequest,
    ) -> McpAuthResult;
}
/// Asks the user a question and returns their answer.
///
/// The `Send + Sync + 'static` bounds let one implementation be shared across
/// threads and held for an arbitrary lifetime.
#[async_trait]
pub trait UserInputHandler: Send + Sync + 'static {
    /// Answers `question` for session `session_id`.
    ///
    /// - `choices`: optional list of suggested answers.
    /// - `allow_freeform`: optional flag for whether free text is acceptable.
    ///   How `None` is interpreted is not defined here.
    ///
    /// Returns `Some(response)` with the answer. Presumably `None` means no answer
    /// was obtained; confirm with the dispatcher.
    async fn handle(
        &self,
        session_id: SessionId,
        question: String,
        choices: Option<Vec<String>>,
        allow_freeform: Option<bool>,
    ) -> Option<UserInputResponse>;
}
/// Decides the outcome of an exit-plan-mode prompt.
///
/// The `Send + Sync + 'static` bounds let one implementation be shared across
/// threads and held for an arbitrary lifetime.
#[async_trait]
pub trait ExitPlanModeHandler: Send + Sync + 'static {
    /// Produces an [`ExitPlanModeResult`] for session `session_id` from the prompt
    /// payload `data`.
    ///
    /// Unlike the other handlers, this one receives no request id.
    async fn handle(&self, session_id: SessionId, data: ExitPlanModeData) -> ExitPlanModeResult;
}
/// Answers auto-mode switch prompts.
///
/// The `Send + Sync + 'static` bounds let one implementation be shared across
/// threads and held for an arbitrary lifetime.
#[async_trait]
pub trait AutoModeSwitchHandler: Send + Sync + 'static {
    /// Produces an [`AutoModeSwitchResponse`] for session `session_id`.
    ///
    /// - `error_code`: optional error code that accompanies the prompt.
    /// - `retry_after_seconds`: optional retry delay. The unit (seconds) comes
    ///   from the parameter name only.
    async fn handle(
        &self,
        session_id: SessionId,
        error_code: Option<String>,
        retry_after_seconds: Option<f64>,
    ) -> AutoModeSwitchResponse;
}
/// A [`PermissionHandler`] that approves every request once, without looking at
/// the session, request id or payload.
///
/// It grants everything it is asked, so only use it where every request is
/// acceptable.
#[derive(Debug, Clone)]
pub struct ApproveAllHandler;

#[async_trait]
impl PermissionHandler for ApproveAllHandler {
    /// Always returns [`PermissionResult::approve_once`].
    async fn handle(
        &self,
        _session_id: SessionId,
        _request_id: RequestId,
        _data: PermissionRequestData,
    ) -> PermissionResult {
        let approval = PermissionResult::approve_once();
        approval
    }
}
/// A [`PermissionHandler`] that rejects every request without feedback, without
/// looking at the session, request id or payload.
#[derive(Debug, Clone)]
pub struct DenyAllHandler;

#[async_trait]
impl PermissionHandler for DenyAllHandler {
    /// Always returns [`PermissionResult::reject`] with no feedback text.
    async fn handle(
        &self,
        _session_id: SessionId,
        _request_id: RequestId,
        _data: PermissionRequestData,
    ) -> PermissionResult {
        let rejection = PermissionResult::reject(None);
        rejection
    }
}
/// Unit tests for the handler helper types.
///
/// Covers the built-in handlers, every `PermissionResult` constructor, the `From`
/// conversion, both branches of `McpAuthResult::into_wire`, and the `approved: true`
/// default of `ExitPlanModeResult`.
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn approve_all_handler_returns_approved() {
        let result = ApproveAllHandler
            .handle(
                SessionId::from("s1"),
                RequestId::new("1"),
                PermissionRequestData::default(),
            )
            .await;
        assert!(matches!(
            result,
            PermissionResult::Decision(PermissionDecision::ApproveOnce(_))
        ));
    }

    #[tokio::test]
    async fn deny_all_handler_returns_denied() {
        let result = DenyAllHandler
            .handle(
                SessionId::from("s1"),
                RequestId::new("1"),
                PermissionRequestData::default(),
            )
            .await;
        match result {
            // DenyAllHandler passes `None`, so the reject payload carries no feedback.
            PermissionResult::Decision(PermissionDecision::Reject(reject)) => {
                assert!(reject.feedback.is_none());
            }
            other => panic!("expected reject decision, got {:?}", other),
        }
    }

    #[test]
    fn reject_carries_feedback_text() {
        match PermissionResult::reject(Some("too risky".to_string())) {
            PermissionResult::Decision(PermissionDecision::Reject(reject)) => {
                assert_eq!(reject.feedback.as_deref(), Some("too risky"));
            }
            other => panic!("expected reject decision, got {:?}", other),
        }
    }

    #[test]
    fn user_not_available_and_no_result_constructors() {
        assert!(matches!(
            PermissionResult::user_not_available(),
            PermissionResult::Decision(PermissionDecision::UserNotAvailable(_))
        ));
        assert!(matches!(
            PermissionResult::no_result(),
            PermissionResult::NoResult
        ));
    }

    #[test]
    fn permission_decision_converts_into_result() {
        let result: PermissionResult =
            PermissionDecision::ApproveOnce(PermissionDecisionApproveOnce::default()).into();
        assert!(matches!(
            result,
            PermissionResult::Decision(PermissionDecision::ApproveOnce(_))
        ));
    }

    #[test]
    fn exit_plan_mode_result_defaults_to_approved() {
        let result = ExitPlanModeResult::default();
        assert!(result.approved);
        assert!(result.selected_action.is_none());
        assert!(result.feedback.is_none());
    }

    #[test]
    fn mcp_auth_result_token_converts_to_wire_response() {
        let wire = McpAuthResult::Token {
            access_token: "host-token".to_string(),
            token_type: Some("Bearer".to_string()),
            expires_in: Some(3600),
        }
        .into_wire();
        match wire {
            McpOauthPendingRequestResponse::Token(token) => {
                assert_eq!(token.access_token, "host-token");
                assert_eq!(token.token_type.as_deref(), Some("Bearer"));
                assert_eq!(token.expires_in, Some(3600));
            }
            McpOauthPendingRequestResponse::Cancelled(_) => panic!("expected token response"),
        }
    }

    #[test]
    fn mcp_auth_result_token_without_optional_fields_converts_to_wire_response() {
        let wire = McpAuthResult::Token {
            access_token: "bare".to_string(),
            token_type: None,
            expires_in: None,
        }
        .into_wire();
        match wire {
            McpOauthPendingRequestResponse::Token(token) => {
                assert_eq!(token.access_token, "bare");
                assert!(token.token_type.is_none());
                assert!(token.expires_in.is_none());
            }
            McpOauthPendingRequestResponse::Cancelled(_) => panic!("expected token response"),
        }
    }

    #[test]
    fn mcp_auth_result_cancelled_converts_to_wire_response() {
        let wire = McpAuthResult::Cancelled.into_wire();
        assert!(matches!(
            wire,
            McpOauthPendingRequestResponse::Cancelled(_)
        ));
    }
}