agent_client_protocol/
agent.rs

1//! Methods and notifications the agent handles/receives
2
3use std::{path::PathBuf, sync::Arc};
4
5use anyhow::Result;
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8
9use crate::{ClientCapabilities, ContentBlock, Error, ProtocolVersion, SessionId};
10
11pub trait Agent {
12    fn initialize(
13        &self,
14        arguments: InitializeRequest,
15    ) -> impl Future<Output = Result<InitializeResponse, Error>>;
16
17    fn authenticate(
18        &self,
19        arguments: AuthenticateRequest,
20    ) -> impl Future<Output = Result<(), Error>>;
21
22    fn new_session(
23        &self,
24        arguments: NewSessionRequest,
25    ) -> impl Future<Output = Result<NewSessionResponse, Error>>;
26
27    fn load_session(
28        &self,
29        arguments: LoadSessionRequest,
30    ) -> impl Future<Output = Result<LoadSessionResponse, Error>>;
31
32    fn prompt(&self, arguments: PromptRequest) -> impl Future<Output = Result<(), Error>>;
33
34    fn cancelled(&self, args: CancelledNotification) -> impl Future<Output = Result<(), Error>>;
35}
36
37// Initialize
38
39#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
40#[serde(rename_all = "camelCase")]
41pub struct InitializeRequest {
42    /// The latest protocol version supported by the client
43    pub protocol_version: ProtocolVersion,
44    /// Capabilities supported by the client
45    #[serde(default)]
46    pub client_capabilities: ClientCapabilities,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
50#[serde(rename_all = "camelCase")]
51pub struct InitializeResponse {
52    /// The protocol version the client specified if supported by the agent,
53    /// or the latest protocol version supported by the agent.
54    ///
55    /// The client should disconnect, if it doesn't support this version.
56    pub protocol_version: ProtocolVersion,
57    /// Capabilities supported by the agent
58    #[serde(default)]
59    pub agent_capabilities: AgentCapabilities,
60    /// Authentication methods supported by the agent
61    #[serde(default)]
62    pub auth_methods: Vec<AuthMethod>,
63}
64
// Authentication
66
67#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
68#[serde(rename_all = "camelCase")]
69pub struct AuthenticateRequest {
70    pub method_id: AuthMethodId,
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
74#[serde(transparent)]
75pub struct AuthMethodId(pub Arc<str>);
76
77#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
78#[serde(rename_all = "camelCase")]
79pub struct AuthMethod {
80    pub id: AuthMethodId,
81    pub label: String,
82    pub description: Option<String>,
83}
84
85// New session
86
87#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
88#[serde(rename_all = "camelCase")]
89pub struct NewSessionRequest {
90    pub mcp_servers: Vec<McpServer>,
91    pub cwd: PathBuf,
92}
93
94#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
95#[serde(rename_all = "camelCase")]
96pub struct NewSessionResponse {
97    /// The session id if one was created, or null if authentication is required
98    // Note: It'd be nicer to use an enum here, but MCP requires the output schema
99    // to be a non-union object and adding another level seemed impractical.
100    pub session_id: Option<SessionId>,
101}
102
103// Load session
104
105#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
106#[serde(rename_all = "camelCase")]
107pub struct LoadSessionRequest {
108    pub mcp_servers: Vec<McpServer>,
109    pub cwd: PathBuf,
110    pub session_id: SessionId,
111}
112
113#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
114#[serde(rename_all = "camelCase")]
115pub struct LoadSessionResponse {
116    pub auth_required: bool,
117    #[serde(default)]
118    pub auth_methods: Vec<AuthMethod>,
119}
120
121// MCP
122
123#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
124#[serde(rename_all = "camelCase")]
125pub struct McpServer {
126    pub name: String,
127    pub command: PathBuf,
128    pub args: Vec<String>,
129    pub env: Vec<EnvVariable>,
130}
131
132#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
133#[serde(rename_all = "camelCase")]
134pub struct EnvVariable {
135    pub name: String,
136    pub value: String,
137}
138
139// Prompt
140
141#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
142#[serde(rename_all = "camelCase")]
143pub struct PromptRequest {
144    pub session_id: SessionId,
145    pub prompt: Vec<ContentBlock>,
146}
147
148// Capabilities
149
150#[derive(Default, Debug, Clone, Serialize, Deserialize, JsonSchema)]
151#[serde(rename_all = "camelCase")]
152pub struct AgentCapabilities {
153    /// Agent supports `session/load`
154    #[serde(default)]
155    load_session: bool,
156}
157
158// Method schema
159
160#[derive(Debug, Clone, Serialize, Deserialize)]
161pub struct AgentMethodNames {
162    pub initialize: &'static str,
163    pub authenticate: &'static str,
164    pub session_new: &'static str,
165    pub session_load: &'static str,
166    pub session_prompt: &'static str,
167    pub session_cancelled: &'static str,
168}
169
170pub const AGENT_METHOD_NAMES: AgentMethodNames = AgentMethodNames {
171    initialize: INITIALIZE_METHOD_NAME,
172    authenticate: AUTHENTICATE_METHOD_NAME,
173    session_new: SESSION_NEW_METHOD_NAME,
174    session_load: SESSION_LOAD_METHOD_NAME,
175    session_prompt: SESSION_PROMPT_METHOD_NAME,
176    session_cancelled: SESSION_CANCELLED_METHOD_NAME,
177};
178
179pub const INITIALIZE_METHOD_NAME: &str = "initialize";
180pub const AUTHENTICATE_METHOD_NAME: &str = "authenticate";
181pub const SESSION_NEW_METHOD_NAME: &str = "session/new";
182pub const SESSION_LOAD_METHOD_NAME: &str = "session/load";
183pub const SESSION_PROMPT_METHOD_NAME: &str = "session/prompt";
184pub const SESSION_CANCELLED_METHOD_NAME: &str = "session/cancelled";
185
186/// Requests the client sends to the agent
187#[derive(Debug, Serialize, Deserialize, JsonSchema)]
188#[serde(untagged)]
189pub enum ClientRequest {
190    InitializeRequest(InitializeRequest),
191    AuthenticateRequest(AuthenticateRequest),
192    NewSessionRequest(NewSessionRequest),
193    LoadSessionRequest(LoadSessionRequest),
194    PromptRequest(PromptRequest),
195}
196
197/// Responses the agent sends to the client
198#[derive(Debug, Serialize, Deserialize, JsonSchema)]
199#[serde(untagged)]
200pub enum AgentResponse {
201    InitializeResponse(InitializeResponse),
202    AuthenticateResponse,
203    NewSessionResponse(NewSessionResponse),
204    LoadSessionResponse(LoadSessionResponse),
205    PromptResponse,
206}
207
208/// Notifications the client sends to the agent
209#[derive(Debug, Serialize, Deserialize, JsonSchema)]
210#[serde(untagged)]
211pub enum ClientNotification {
212    CancelledNotification(CancelledNotification),
213}
214
215#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
216#[serde(rename_all = "camelCase")]
217pub struct CancelledNotification {
218    pub session_id: SessionId,
219}