use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::types::{
ElicitationRequest, ElicitationResult, ExitPlanModeData, PermissionRequestData, RequestId,
SessionEvent, SessionId, ToolInvocation, ToolResult,
};
/// An inbound request or notification delivered to a [`SessionHandler`].
///
/// The provided [`SessionHandler::on_event`] forwards each variant to the
/// per-event `on_*` method for that variant and wraps the method's return
/// value in the matching [`HandlerResponse`] variant.
#[derive(Debug)]
#[non_exhaustive]
pub enum HandlerEvent {
    /// A session event notification; the provided `on_event` answers it with
    /// [`HandlerResponse::Ok`].
    SessionEvent {
        /// Session the event belongs to.
        session_id: SessionId,
        /// The event payload.
        event: SessionEvent,
    },
    /// A permission request; answered with [`HandlerResponse::Permission`].
    PermissionRequest {
        /// Session issuing the request.
        session_id: SessionId,
        /// Identifier of this particular request.
        request_id: RequestId,
        /// Details of what is being requested.
        data: PermissionRequestData,
    },
    /// A question for the user; answered with [`HandlerResponse::UserInput`].
    UserInput {
        /// Session asking the question.
        session_id: SessionId,
        /// The question text.
        question: String,
        /// Optional list of predefined answers.
        choices: Option<Vec<String>>,
        /// Whether answers outside `choices` are accepted.
        /// NOTE(review): meaning of `None` is defined by the sender.
        allow_freeform: Option<bool>,
    },
    /// A tool invocation the handler is expected to execute; answered with
    /// [`HandlerResponse::ToolResult`].
    ExternalTool {
        /// The invocation to run (carries the tool name among other fields).
        invocation: ToolInvocation,
    },
    /// An elicitation request; answered with [`HandlerResponse::Elicitation`].
    ElicitationRequest {
        /// Session issuing the request.
        session_id: SessionId,
        /// Identifier of this particular request.
        request_id: RequestId,
        /// The elicitation payload.
        request: ElicitationRequest,
    },
    /// A request to leave plan mode; answered with
    /// [`HandlerResponse::ExitPlanMode`].
    ExitPlanMode {
        /// Session in plan mode.
        session_id: SessionId,
        /// Plan-mode details supplied with the request.
        data: ExitPlanModeData,
    },
    /// A prompt about switching to auto mode; answered with
    /// [`HandlerResponse::AutoModeSwitch`].
    AutoModeSwitch {
        /// Session the prompt concerns.
        session_id: SessionId,
        /// Optional error code. Presumably this is the error that triggered the
        /// prompt; confirm with the sender.
        error_code: Option<String>,
        /// Optional retry delay in seconds. TODO: confirm its exact semantics
        /// with the sender.
        retry_after_seconds: Option<u64>,
    },
}
/// Reply produced by [`SessionHandler::on_event`]. There is one payload-carrying
/// variant per kind of [`HandlerEvent`].
#[derive(Debug)]
#[non_exhaustive]
pub enum HandlerResponse {
    /// Acknowledgement without a payload. The provided `on_event` returns this
    /// for [`HandlerEvent::SessionEvent`].
    Ok,
    /// Reply to [`HandlerEvent::PermissionRequest`].
    Permission(PermissionResult),
    /// Reply to [`HandlerEvent::UserInput`]. `None` is what the default
    /// `on_user_input` returns.
    UserInput(Option<UserInputResponse>),
    /// Reply to [`HandlerEvent::ExternalTool`].
    ToolResult(ToolResult),
    /// Reply to [`HandlerEvent::ElicitationRequest`].
    Elicitation(ElicitationResult),
    /// Reply to [`HandlerEvent::ExitPlanMode`].
    ExitPlanMode(ExitPlanModeResult),
    /// Reply to [`HandlerEvent::AutoModeSwitch`].
    AutoModeSwitch(AutoModeSwitchResponse),
}
/// Outcome of a permission request, as returned by
/// [`SessionHandler::on_permission_request`].
///
/// Derives `PartialEq` so callers can compare outcomes directly
/// (`serde_json::Value` in [`PermissionResult::Custom`] supports equality).
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum PermissionResult {
    /// The request is granted.
    Approved,
    /// The request is refused. This is what the default handler returns.
    Denied,
    /// The decision is postponed.
    /// NOTE(review): the follow-up protocol is not visible in this module.
    Deferred,
    /// A custom result carried as raw JSON.
    Custom(serde_json::Value),
    /// No user was available to decide. This meaning is inferred from the name;
    /// confirm it.
    UserNotAvailable,
    /// The handler produced no result. This meaning is inferred from the name;
    /// confirm it.
    NoResult,
}
/// A user's answer to a [`HandlerEvent::UserInput`] question.
///
/// Derives `PartialEq`/`Eq` so answers can be compared directly, for example in
/// tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UserInputResponse {
    /// The answer text.
    pub answer: String,
    /// Whether the answer was entered as freeform text. Presumably this means
    /// "not one of the offered `choices`"; confirm with the consumer.
    pub was_freeform: bool,
}
/// The handler's answer to a [`HandlerEvent::ExitPlanMode`] request.
///
/// Derives `PartialEq`/`Eq` so results can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExitPlanModeResult {
    /// Whether leaving plan mode is approved.
    pub approved: bool,
    /// Optional action selected alongside the decision.
    /// NOTE(review): the allowed values are defined by the consumer.
    pub selected_action: Option<String>,
    /// Optional free-text feedback.
    pub feedback: Option<String>,
}

/// The default result *approves* with no selected action and no feedback.
///
/// Note that this is the opposite stance from the deny-by-default permission
/// handling: a handler that keeps the trait defaults still approves exiting
/// plan mode.
impl Default for ExitPlanModeResult {
    fn default() -> Self {
        Self {
            approved: true,
            selected_action: None,
            feedback: None,
        }
    }
}
/// Answer to a [`HandlerEvent::AutoModeSwitch`] prompt.
///
/// On the wire the variants serialize in snake_case: `"yes"`, `"yes_always"`
/// and `"no"`. The default [`SessionHandler::on_auto_mode_switch`] answers
/// [`AutoModeSwitchResponse::No`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum AutoModeSwitchResponse {
    /// Accept the switch.
    Yes,
    /// Accept the switch; presumably the choice should also be remembered
    /// (confirm with the consumer).
    YesAlways,
    /// Decline the switch.
    No,
}
/// Host-side callbacks for a session.
///
/// There are two ways to implement this trait:
///
/// 1. Override [`SessionHandler::on_event`] and handle every [`HandlerEvent`]
///    yourself, returning the matching [`HandlerResponse`].
/// 2. Keep the provided `on_event` and override only the `on_*` methods you
///    need. The provided `on_event` forwards each variant to its method and
///    wraps the result in the corresponding [`HandlerResponse`] variant.
///
/// The per-event defaults are:
/// - session events are ignored;
/// - permission requests are denied;
/// - user-input requests get no answer (`None`);
/// - external tool calls fail with a "No handler registered" message;
/// - elicitations are cancelled;
/// - exit-plan-mode requests get [`ExitPlanModeResult::default()`], which
///   approves;
/// - auto-mode switches are declined.
#[async_trait]
pub trait SessionHandler: Send + Sync + 'static {
    /// Routes `event` to the per-event method for its variant and wraps the
    /// outcome in the corresponding [`HandlerResponse`].
    ///
    /// An override of this method bypasses the per-event methods entirely
    /// unless it calls them itself.
    async fn on_event(&self, event: HandlerEvent) -> HandlerResponse {
        match event {
            HandlerEvent::SessionEvent { session_id, event } => {
                // Notifications have no payload to return; acknowledge them.
                self.on_session_event(session_id, event).await;
                HandlerResponse::Ok
            }
            HandlerEvent::PermissionRequest { session_id, request_id, data } => {
                let verdict = self
                    .on_permission_request(session_id, request_id, data)
                    .await;
                HandlerResponse::Permission(verdict)
            }
            HandlerEvent::UserInput { session_id, question, choices, allow_freeform } => {
                let answer = self
                    .on_user_input(session_id, question, choices, allow_freeform)
                    .await;
                HandlerResponse::UserInput(answer)
            }
            HandlerEvent::ExternalTool { invocation } => {
                let outcome = self.on_external_tool(invocation).await;
                HandlerResponse::ToolResult(outcome)
            }
            HandlerEvent::ElicitationRequest { session_id, request_id, request } => {
                let reply = self.on_elicitation(session_id, request_id, request).await;
                HandlerResponse::Elicitation(reply)
            }
            HandlerEvent::ExitPlanMode { session_id, data } => {
                let decision = self.on_exit_plan_mode(session_id, data).await;
                HandlerResponse::ExitPlanMode(decision)
            }
            HandlerEvent::AutoModeSwitch { session_id, error_code, retry_after_seconds } => {
                let choice = self
                    .on_auto_mode_switch(session_id, error_code, retry_after_seconds)
                    .await;
                HandlerResponse::AutoModeSwitch(choice)
            }
        }
    }

    /// Receives a session event notification. The default does nothing.
    async fn on_session_event(&self, _session_id: SessionId, _event: SessionEvent) {}

    /// Decides a permission request. The default is
    /// [`PermissionResult::Denied`].
    async fn on_permission_request(
        &self,
        _session_id: SessionId,
        _request_id: RequestId,
        _data: PermissionRequestData,
    ) -> PermissionResult {
        PermissionResult::Denied
    }

    /// Answers a question for the user. The default is `None` (no answer).
    async fn on_user_input(
        &self,
        _session_id: SessionId,
        _question: String,
        _choices: Option<Vec<String>>,
        _allow_freeform: Option<bool>,
    ) -> Option<UserInputResponse> {
        None
    }

    /// Executes a tool invocation.
    ///
    /// The default returns an expanded result with `result_type` `"failure"`,
    /// no session log, and both `text_result_for_llm` and `error` set to
    /// `No handler registered for tool '<tool_name>'`.
    async fn on_external_tool(&self, invocation: ToolInvocation) -> ToolResult {
        let tool_name = &invocation.tool_name;
        let message = format!("No handler registered for tool '{tool_name}'");
        ToolResult::Expanded(crate::types::ToolResultExpanded {
            // The same text goes to the model and into the error field.
            text_result_for_llm: message.clone(),
            result_type: String::from("failure"),
            session_log: None,
            error: Some(message),
        })
    }

    /// Responds to an elicitation request. The default answers with action
    /// `"cancel"` and no content.
    async fn on_elicitation(
        &self,
        _session_id: SessionId,
        _request_id: RequestId,
        _request: ElicitationRequest,
    ) -> ElicitationResult {
        ElicitationResult {
            action: String::from("cancel"),
            content: None,
        }
    }

    /// Decides whether to leave plan mode. The default is
    /// [`ExitPlanModeResult::default()`]: approved, with no selected action and
    /// no feedback.
    async fn on_exit_plan_mode(
        &self,
        _session_id: SessionId,
        _data: ExitPlanModeData,
    ) -> ExitPlanModeResult {
        ExitPlanModeResult::default()
    }

    /// Answers an auto-mode switch prompt. The default is
    /// [`AutoModeSwitchResponse::No`].
    async fn on_auto_mode_switch(
        &self,
        _session_id: SessionId,
        _error_code: Option<String>,
        _retry_after_seconds: Option<u64>,
    ) -> AutoModeSwitchResponse {
        AutoModeSwitchResponse::No
    }
}
/// A [`SessionHandler`] that approves every permission request and keeps the
/// trait defaults for all other events.
///
/// It is a stateless unit struct, so it is `Copy` and `Default` as well.
#[derive(Debug, Clone, Copy, Default)]
pub struct ApproveAllHandler;
/// Grants every permission request. All other events fall through to the
/// [`SessionHandler`] defaults.
#[async_trait]
impl SessionHandler for ApproveAllHandler {
    async fn on_permission_request(
        &self,
        _: SessionId,
        _: RequestId,
        _: PermissionRequestData,
    ) -> PermissionResult {
        PermissionResult::Approved
    }
}
/// A [`SessionHandler`] that relies entirely on the trait defaults. In
/// particular, it denies every permission request.
///
/// It is a stateless unit struct, so it is `Copy` and `Default` as well.
#[derive(Debug, Clone, Copy, Default)]
pub struct DenyAllHandler;
// Intentionally empty: every method keeps its `SessionHandler` default. As a
// result, permission requests are denied, user input gets no answer, external
// tools fail, elicitations are cancelled and auto-mode switches are declined.
// Exit-plan-mode requests still receive `ExitPlanModeResult::default()`, which
// approves.
#[async_trait]
impl SessionHandler for DenyAllHandler {}
#[cfg(test)]
mod tests {
    use serde_json::Value;

    use super::*;
    use crate::types::{PermissionRequestData, RequestId, SessionId};

    /// Default permission payload shared by the permission tests.
    fn perm_data() -> PermissionRequestData {
        PermissionRequestData::default()
    }

    /// Session id used by every test event.
    fn sid() -> SessionId {
        SessionId::from("s1".to_string())
    }

    /// Request id used by every test event that carries one.
    fn rid() -> RequestId {
        RequestId::new("r1")
    }

    /// A permission request for session `s1`, request `r1`.
    fn permission_event() -> HandlerEvent {
        HandlerEvent::PermissionRequest {
            session_id: sid(),
            request_id: rid(),
            data: perm_data(),
        }
    }

    /// Approves permissions by overriding only the per-event method.
    struct ApproveViaPerMethod;

    #[async_trait]
    impl SessionHandler for ApproveViaPerMethod {
        async fn on_permission_request(
            &self,
            _: SessionId,
            _: RequestId,
            _: PermissionRequestData,
        ) -> PermissionResult {
            PermissionResult::Approved
        }
    }

    /// Approves permissions by overriding `on_event` directly.
    struct ApproveViaOnEvent;

    #[async_trait]
    impl SessionHandler for ApproveViaOnEvent {
        async fn on_event(&self, event: HandlerEvent) -> HandlerResponse {
            if matches!(event, HandlerEvent::PermissionRequest { .. }) {
                HandlerResponse::Permission(PermissionResult::Approved)
            } else {
                HandlerResponse::Ok
            }
        }
    }

    #[tokio::test]
    async fn per_method_override_dispatches_via_default_on_event() {
        let handler = ApproveViaPerMethod;
        let response = handler.on_event(permission_event()).await;
        assert!(matches!(
            response,
            HandlerResponse::Permission(PermissionResult::Approved)
        ));
    }

    #[tokio::test]
    async fn on_event_override_short_circuits_per_method_defaults() {
        let handler = ApproveViaOnEvent;
        let response = handler.on_event(permission_event()).await;
        assert!(matches!(
            response,
            HandlerResponse::Permission(PermissionResult::Approved)
        ));
    }

    #[tokio::test]
    async fn deny_all_handler_uses_default_permission_deny() {
        let handler = DenyAllHandler;
        let response = handler.on_event(permission_event()).await;
        assert!(matches!(
            response,
            HandlerResponse::Permission(PermissionResult::Denied)
        ));
    }

    #[tokio::test]
    async fn default_on_external_tool_returns_failure() {
        let handler = DenyAllHandler;
        let invocation = crate::types::ToolInvocation {
            session_id: sid(),
            tool_call_id: "tc1".to_string(),
            tool_name: "missing".to_string(),
            arguments: Value::Null,
            traceparent: None,
            tracestate: None,
        };
        let response = handler
            .on_event(HandlerEvent::ExternalTool { invocation })
            .await;
        match response {
            HandlerResponse::ToolResult(crate::types::ToolResult::Expanded(expanded)) => {
                assert_eq!(expanded.result_type, "failure");
                assert!(expanded.text_result_for_llm.contains("missing"));
                assert_eq!(
                    expanded.error.as_deref(),
                    Some(expanded.text_result_for_llm.as_str())
                );
            }
            other => panic!("unexpected response: {other:?}"),
        }
    }

    #[tokio::test]
    async fn default_on_elicitation_returns_cancel() {
        let handler = DenyAllHandler;
        let request = crate::types::ElicitationRequest {
            message: "test".to_string(),
            requested_schema: None,
            mode: Some(crate::types::ElicitationMode::Form),
            elicitation_source: None,
            url: None,
        };
        let response = handler
            .on_event(HandlerEvent::ElicitationRequest {
                session_id: sid(),
                request_id: rid(),
                request,
            })
            .await;
        match response {
            HandlerResponse::Elicitation(reply) => assert_eq!(reply.action, "cancel"),
            other => panic!("unexpected response: {other:?}"),
        }
    }
}