use std::{path::PathBuf, sync::Arc};
use anyhow::Result;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use crate::{ClientCapabilities, ContentBlock, Error, ProtocolVersion, SessionId};
pub trait Agent {
fn initialize(
&self,
arguments: InitializeRequest,
) -> impl Future<Output = Result<InitializeResponse, Error>>;
fn authenticate(
&self,
arguments: AuthenticateRequest,
) -> impl Future<Output = Result<(), Error>>;
fn new_session(
&self,
arguments: NewSessionRequest,
) -> impl Future<Output = Result<NewSessionResponse, Error>>;
fn load_session(
&self,
arguments: LoadSessionRequest,
) -> impl Future<Output = Result<(), Error>>;
fn prompt(
&self,
arguments: PromptRequest,
) -> impl Future<Output = Result<PromptResponse, Error>>;
fn cancel(&self, args: CancelNotification) -> impl Future<Output = Result<(), Error>>;
}
/// Parameters of the `initialize` request, handled by [`Agent::initialize`].
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
#[schemars(extend("x-side" = "agent", "x-method" = "initialize"))]
pub struct InitializeRequest {
    /// Protocol version sent by the client (`protocolVersion` on the wire).
    pub protocol_version: ProtocolVersion,
    /// Capabilities advertised by the client; `ClientCapabilities::default()` when absent.
    #[serde(default)]
    pub client_capabilities: ClientCapabilities,
}
/// Result of the `initialize` request, produced by [`Agent::initialize`].
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
#[schemars(extend("x-side" = "agent", "x-method" = "initialize"))]
pub struct InitializeResponse {
    /// Protocol version reported by the agent (`protocolVersion` on the wire).
    pub protocol_version: ProtocolVersion,
    /// Capabilities of the agent; `AgentCapabilities::default()` when absent.
    #[serde(default)]
    pub agent_capabilities: AgentCapabilities,
    /// Authentication methods offered by the agent; empty when absent.
    #[serde(default)]
    pub auth_methods: Vec<AuthMethod>,
}
/// Parameters of the `authenticate` request, handled by [`Agent::authenticate`].
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
#[schemars(extend("x-side" = "agent", "x-method" = "authenticate"))]
pub struct AuthenticateRequest {
    /// Id of the chosen method, matching an [`AuthMethod::id`] (`methodId` on the wire).
    pub method_id: AuthMethodId,
}
/// Identifier of an [`AuthMethod`].
///
/// `#[serde(transparent)]` makes it (de)serialize exactly like the wrapped string. The
/// `Arc<str>` storage makes clones a reference-count bump rather than a copy.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
#[serde(transparent)]
pub struct AuthMethodId(pub Arc<str>);
/// An authentication method listed in [`InitializeResponse::auth_methods`].
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct AuthMethod {
    /// Identifier the client passes back in [`AuthenticateRequest::method_id`].
    pub id: AuthMethodId,
    /// Name of the method.
    pub name: String,
    /// Optional description; deserializes to `None` when the field is missing.
    pub description: Option<String>,
}
/// Parameters of the `session/new` request, handled by [`Agent::new_session`].
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
#[schemars(extend("x-side" = "agent", "x-method" = "session/new"))]
pub struct NewSessionRequest {
    /// Working directory for the session.
    pub cwd: PathBuf,
    /// MCP servers for the session (`mcpServers` on the wire; required).
    pub mcp_servers: Vec<McpServer>,
}
/// Result of the `session/new` request, produced by [`Agent::new_session`].
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
#[schemars(extend("x-side" = "agent", "x-method" = "session/new"))]
pub struct NewSessionResponse {
    /// Identifier of the created session (`sessionId` on the wire).
    pub session_id: SessionId,
}
/// Parameters of the `session/load` request, handled by [`Agent::load_session`].
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
#[schemars(extend("x-side" = "agent", "x-method" = "session/load"))]
pub struct LoadSessionRequest {
    /// MCP servers for the session (`mcpServers` on the wire; required).
    pub mcp_servers: Vec<McpServer>,
    /// Working directory for the session.
    pub cwd: PathBuf,
    /// Identifier of the session to load (`sessionId` on the wire).
    pub session_id: SessionId,
}
/// An MCP server entry: a named `command` together with its `args` and `env`.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct McpServer {
    /// Name of the server.
    pub name: String,
    /// Path of the command.
    pub command: PathBuf,
    /// Arguments for `command`, in order.
    pub args: Vec<String>,
    /// Environment variables as name/value pairs.
    pub env: Vec<EnvVariable>,
}
/// A single environment variable, expressed as a name/value pair.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct EnvVariable {
    /// Variable name.
    pub name: String,
    /// Variable value.
    pub value: String,
}
/// Parameters of the `session/prompt` request, handled by [`Agent::prompt`].
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
#[schemars(extend("x-side" = "agent", "x-method" = "session/prompt"))]
pub struct PromptRequest {
    /// Session the prompt is addressed to (`sessionId` on the wire).
    pub session_id: SessionId,
    /// Prompt content as an ordered list of blocks.
    pub prompt: Vec<ContentBlock>,
}
/// Result of the `session/prompt` request, produced by [`Agent::prompt`].
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
#[schemars(extend("x-side" = "agent", "x-method" = "session/prompt"))]
pub struct PromptResponse {
    /// Why the agent stopped (`stopReason` on the wire).
    pub stop_reason: StopReason,
}
/// Reason reported in [`PromptResponse::stop_reason`].
///
/// Variants are serialized in `snake_case`, e.g. `MaxTurnRequests` as `"max_turn_requests"`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    EndTurn,
    MaxTokens,
    MaxTurnRequests,
    Refusal,
    Cancelled,
}
/// Capabilities reported in [`InitializeResponse::agent_capabilities`].
///
/// Every field falls back to its default when missing from the payload.
#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct AgentCapabilities {
    /// Whether `session/load` is advertised as supported; `false` when absent.
    #[serde(default)]
    pub load_session: bool,
    /// Prompt content flags; all `false` when absent.
    #[serde(default)]
    pub prompt_capabilities: PromptCapabilities,
}
/// Flags for the kinds of prompt content the agent advertises support for.
///
/// Each flag is `false` when missing from the payload.
#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct PromptCapabilities {
    /// Image content flag.
    #[serde(default)]
    pub image: bool,
    /// Audio content flag.
    #[serde(default)]
    pub audio: bool,
    /// Embedded-context flag (`embeddedContext` on the wire).
    #[serde(default)]
    pub embedded_context: bool,
}
/// Wire names of every method an [`Agent`] serves, one field per method.
///
/// The canonical instance is [`AGENT_METHOD_NAMES`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AgentMethodNames {
    /// Name for [`Agent::initialize`].
    pub initialize: &'static str,
    /// Name for [`Agent::authenticate`].
    pub authenticate: &'static str,
    /// Name for [`Agent::new_session`].
    pub session_new: &'static str,
    /// Name for [`Agent::load_session`].
    pub session_load: &'static str,
    /// Name for [`Agent::prompt`].
    pub session_prompt: &'static str,
    /// Name for [`Agent::cancel`].
    pub session_cancel: &'static str,
}
/// Public table of the wire method names handled by an [`Agent`].
///
/// It is built from the crate-private `*_METHOD_NAME` constants, so the two cannot drift apart.
pub const AGENT_METHOD_NAMES: AgentMethodNames = AgentMethodNames {
    // Session-independent handshake methods.
    initialize: INITIALIZE_METHOD_NAME,
    authenticate: AUTHENTICATE_METHOD_NAME,
    // Session-scoped methods.
    session_new: SESSION_NEW_METHOD_NAME,
    session_load: SESSION_LOAD_METHOD_NAME,
    session_prompt: SESSION_PROMPT_METHOD_NAME,
    session_cancel: SESSION_CANCEL_METHOD_NAME,
};
/// Wire name for `Agent::initialize`.
pub(crate) const INITIALIZE_METHOD_NAME: &str = "initialize";
/// Wire name for `Agent::authenticate`.
pub(crate) const AUTHENTICATE_METHOD_NAME: &str = "authenticate";
/// Wire name for `Agent::new_session`.
pub(crate) const SESSION_NEW_METHOD_NAME: &str = "session/new";
/// Wire name for `Agent::load_session`.
pub(crate) const SESSION_LOAD_METHOD_NAME: &str = "session/load";
/// Wire name for `Agent::prompt`.
pub(crate) const SESSION_PROMPT_METHOD_NAME: &str = "session/prompt";
/// Wire name for the `Agent::cancel` notification.
pub(crate) const SESSION_CANCEL_METHOD_NAME: &str = "session/cancel";
/// Every request the client can send to the agent, as a single untagged enum.
///
/// Serialization emits only the inner payload. Deserialization tries the variants in
/// declaration order and keeps the first whose required fields are all present.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
#[schemars(extend("x-docs-ignore" = true))]
pub enum ClientRequest {
    InitializeRequest(InitializeRequest),
    AuthenticateRequest(AuthenticateRequest),
    // Order matters: `LoadSessionRequest` requires every field that `NewSessionRequest`
    // requires (`cwd`, `mcpServers`) plus `sessionId`. Neither struct denies unknown
    // fields. If `NewSessionRequest` were tried first, every `session/load` payload
    // would decode as a new-session request and its `sessionId` would be silently
    // dropped. The more specific variant must come first.
    LoadSessionRequest(LoadSessionRequest),
    NewSessionRequest(NewSessionRequest),
    PromptRequest(PromptRequest),
}
/// Every response the agent can send back for a [`ClientRequest`], as a single untagged enum.
///
/// Serialization emits only the inner payload. The two unit variants serialize as unit,
/// mirroring the `()` results of [`Agent::authenticate`] and [`Agent::load_session`].
// NOTE(review): with `untagged`, deserialization tries variants in declaration order, so a
// unit payload always decodes as `AuthenticateResponse`, never `LoadSessionResponse`.
// Callers presumably select the expected response type from the request's method — confirm.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
#[schemars(extend("x-docs-ignore" = true))]
pub enum AgentResponse {
    InitializeResponse(InitializeResponse),
    AuthenticateResponse,
    NewSessionResponse(NewSessionResponse),
    LoadSessionResponse,
    PromptResponse(PromptResponse),
}
/// Every notification the client can send to the agent, as a single untagged enum.
///
/// It currently holds only [`CancelNotification`], which serializes as its bare payload.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
#[schemars(extend("x-docs-ignore" = true))]
pub enum ClientNotification {
    CancelNotification(CancelNotification),
}
/// Payload of the `session/cancel` notification, handled by [`Agent::cancel`].
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
#[schemars(extend("x-side" = "agent", "x-method" = "session/cancel"))]
pub struct CancelNotification {
    /// Session the notification targets (`sessionId` on the wire).
    pub session_id: SessionId,
}