1use std::{path::PathBuf, sync::Arc};
4
5use anyhow::Result;
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8
9use crate::{ClientCapabilities, ContentBlock, Error, ProtocolVersion, SessionId};
10
11pub trait Agent {
12 fn initialize(
13 &self,
14 arguments: InitializeRequest,
15 ) -> impl Future<Output = Result<InitializeResponse, Error>>;
16
17 fn authenticate(
18 &self,
19 arguments: AuthenticateRequest,
20 ) -> impl Future<Output = Result<(), Error>>;
21
22 fn new_session(
23 &self,
24 arguments: NewSessionRequest,
25 ) -> impl Future<Output = Result<NewSessionResponse, Error>>;
26
27 fn load_session(
28 &self,
29 arguments: LoadSessionRequest,
30 ) -> impl Future<Output = Result<(), Error>>;
31
32 fn prompt(
33 &self,
34 arguments: PromptRequest,
35 ) -> impl Future<Output = Result<PromptResponse, Error>>;
36
37 fn cancel(&self, args: CancelNotification) -> impl Future<Output = Result<(), Error>>;
38}
39
/// Payload accepted by [`Agent::initialize`].
///
/// Fields are (de)serialized with camelCase names
/// (`protocolVersion`, `clientCapabilities`).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct InitializeRequest {
    /// Protocol version carried by the request.
    pub protocol_version: ProtocolVersion,
    /// Capabilities declared by the client. When the field is missing from the
    /// input, `ClientCapabilities::default()` is used.
    #[serde(default)]
    pub client_capabilities: ClientCapabilities,
}
51
/// Successful result of [`Agent::initialize`].
///
/// Fields are (de)serialized with camelCase names
/// (`protocolVersion`, `agentCapabilities`, `authMethods`).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct InitializeResponse {
    /// Protocol version carried by the response.
    pub protocol_version: ProtocolVersion,
    /// Capabilities declared by the agent. When the field is missing from the
    /// input, `AgentCapabilities::default()` is used.
    #[serde(default)]
    pub agent_capabilities: AgentCapabilities,
    /// Authentication methods listed by the agent. Defaults to an empty list
    /// when the field is missing from the input.
    #[serde(default)]
    pub auth_methods: Vec<AuthMethod>,
}
67
/// Payload accepted by [`Agent::authenticate`].
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct AuthenticateRequest {
    /// Identifier of the authentication method to use (serialized as `methodId`).
    // NOTE(review): presumably one of the ids listed in
    // `InitializeResponse::auth_methods` — confirm against callers.
    pub method_id: AuthMethodId,
}
75
/// Identifier of an [`AuthMethod`].
///
/// A newtype over `Arc<str>`. Because of `#[serde(transparent)]` it is
/// (de)serialized exactly like the inner string, with no wrapper object. It
/// derives `PartialEq`, `Eq` and `Hash`, so it can be compared and used as a
/// hash-map key.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
#[serde(transparent)]
pub struct AuthMethodId(pub Arc<str>);
79
/// One authentication method entry, as listed in
/// [`InitializeResponse::auth_methods`].
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct AuthMethod {
    /// Identifier of this method.
    pub id: AuthMethodId,
    /// Name of this method.
    pub name: String,
    /// Optional description of this method; `None` when not provided.
    pub description: Option<String>,
}
87
/// Payload accepted by [`Agent::new_session`].
///
/// Neither field has a serde default, so both must be present when
/// deserializing.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct NewSessionRequest {
    /// MCP server descriptions supplied with the request (serialized as `mcpServers`).
    pub mcp_servers: Vec<McpServer>,
    /// Path supplied for the session.
    // NOTE(review): presumably the session's working directory — confirm.
    pub cwd: PathBuf,
}
96
/// Successful result of [`Agent::new_session`].
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct NewSessionResponse {
    /// Identifier returned for the session (serialized as `sessionId`).
    pub session_id: SessionId,
}
102
/// Payload accepted by [`Agent::load_session`].
///
/// Carries the same `mcp_servers` and `cwd` fields as [`NewSessionRequest`]
/// plus a session identifier. None of the fields has a serde default, so all
/// three must be present when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct LoadSessionRequest {
    /// MCP server descriptions supplied with the request (serialized as `mcpServers`).
    pub mcp_servers: Vec<McpServer>,
    /// Path supplied for the session.
    // NOTE(review): presumably the session's working directory — confirm.
    pub cwd: PathBuf,
    /// Identifier of the session to load (serialized as `sessionId`).
    pub session_id: SessionId,
}
112
/// Description of one MCP server, as carried in [`NewSessionRequest`] and
/// [`LoadSessionRequest`].
// NOTE(review): the fields look like a process launch specification
// (command + args + env); how the agent uses them is not visible here.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct McpServer {
    /// Name of the server.
    pub name: String,
    /// Path of the server's command.
    pub command: PathBuf,
    /// Arguments associated with `command`.
    pub args: Vec<String>,
    /// Environment variables associated with the server.
    pub env: Vec<EnvVariable>,
}
123
/// A single environment variable as a name/value pair, used by
/// [`McpServer::env`].
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct EnvVariable {
    /// Name of the variable.
    pub name: String,
    /// Value of the variable.
    pub value: String,
}
130
/// Payload accepted by [`Agent::prompt`].
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct PromptRequest {
    /// Identifier of the session the prompt belongs to (serialized as `sessionId`).
    pub session_id: SessionId,
    /// The prompt content, as an ordered list of content blocks.
    pub prompt: Vec<ContentBlock>,
}
155
/// Successful result of [`Agent::prompt`].
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct PromptResponse {
    /// Reason reported for stopping (serialized as `stopReason`).
    pub stop_reason: StopReason,
}
161
/// Reason carried in [`PromptResponse::stop_reason`].
///
/// Variants are (de)serialized as snake_case strings.
// NOTE(review): the exact conditions under which each variant is produced are
// not visible in this file; the per-variant notes only record the wire value.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum StopReason {
    /// Serialized as `end_turn`.
    EndTurn,
    /// Serialized as `max_tokens`.
    MaxTokens,
    /// Serialized as `max_turn_requests`.
    MaxTurnRequests,
    /// Serialized as `refusal`.
    Refusal,
    /// Serialized as `cancelled`.
    Cancelled,
}
179
/// Capabilities declared by the agent in
/// [`InitializeResponse::agent_capabilities`].
///
/// Derives `Default`, and every field also has a serde default, so any subset
/// of fields may be omitted from the input.
#[derive(Default, Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct AgentCapabilities {
    /// Defaults to `false` (serialized as `loadSession`).
    // NOTE(review): presumably signals support for `Agent::load_session` — confirm.
    #[serde(default)]
    pub load_session: bool,
    /// Prompt-related capabilities; defaults to `PromptCapabilities::default()`.
    #[serde(default)]
    pub prompt_capabilities: PromptCapabilities,
}
192
/// Prompt-related capability flags, nested in
/// [`AgentCapabilities::prompt_capabilities`].
///
/// All flags default to `false`, both via `Default` and when omitted from the
/// input.
// NOTE(review): presumably each flag states whether the matching kind of
// content is accepted in `PromptRequest::prompt` — confirm.
#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct PromptCapabilities {
    /// Image flag; defaults to `false`.
    #[serde(default)]
    pub image: bool,
    /// Audio flag; defaults to `false`.
    #[serde(default)]
    pub audio: bool,
    /// Embedded-context flag (serialized as `embeddedContext`); defaults to `false`.
    #[serde(default)]
    pub embedded_context: bool,
}
215
/// Table of method-name strings, one field per agent-side method.
///
/// Unlike the payload types in this file it has no `rename_all` attribute, so
/// field names are serialized as written, and it does not derive `JsonSchema`.
// NOTE(review): `Deserialize` with `&'static str` fields can only borrow from
// input that itself lives for `'static` — confirm whether deserialization of
// this type is actually used anywhere.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentMethodNames {
    /// Name string for the initialize method.
    pub initialize: &'static str,
    /// Name string for the authenticate method.
    pub authenticate: &'static str,
    /// Name string for the new-session method.
    pub session_new: &'static str,
    /// Name string for the load-session method.
    pub session_load: &'static str,
    /// Name string for the prompt method.
    pub session_prompt: &'static str,
    /// Name string for the cancel notification.
    pub session_cancel: &'static str,
}
227
/// The [`AgentMethodNames`] table, with each field populated from the
/// corresponding `*_METHOD_NAME` constant defined in this file.
pub const AGENT_METHOD_NAMES: AgentMethodNames = AgentMethodNames {
    initialize: INITIALIZE_METHOD_NAME,
    authenticate: AUTHENTICATE_METHOD_NAME,
    session_new: SESSION_NEW_METHOD_NAME,
    session_load: SESSION_LOAD_METHOD_NAME,
    session_prompt: SESSION_PROMPT_METHOD_NAME,
    session_cancel: SESSION_CANCEL_METHOD_NAME,
};
236
// Method-name strings collected in `AGENT_METHOD_NAMES`.
// NOTE(review): presumably each maps to the `Agent` trait method of the
// matching name — the dispatch code is not visible in this file.

/// Method name `"initialize"`.
pub const INITIALIZE_METHOD_NAME: &str = "initialize";
/// Method name `"authenticate"`.
pub const AUTHENTICATE_METHOD_NAME: &str = "authenticate";
/// Method name `"session/new"`.
pub const SESSION_NEW_METHOD_NAME: &str = "session/new";
/// Method name `"session/load"`.
pub const SESSION_LOAD_METHOD_NAME: &str = "session/load";
/// Method name `"session/prompt"`.
pub const SESSION_PROMPT_METHOD_NAME: &str = "session/prompt";
/// Method name `"session/cancel"`.
pub const SESSION_CANCEL_METHOD_NAME: &str = "session/cancel";
243
/// Union of the request payload types defined in this file, one variant per
/// payload.
///
/// `#[serde(untagged)]` means no variant tag is written: a value serializes
/// exactly like its inner payload. When deserializing, serde tries the
/// variants from top to bottom and keeps the first one that matches.
// NOTE(review): `NewSessionRequest` is listed before `LoadSessionRequest`, and
// its required fields (`mcpServers`, `cwd`) are a subset of the latter's.
// Neither struct sets `deny_unknown_fields`, so deserializing a load-session
// payload directly into this enum would likely produce `NewSessionRequest`.
// Confirm that callers choose the payload type by method name rather than
// relying on untagged deserialization.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
pub enum ClientRequest {
    /// Payload for [`Agent::initialize`].
    InitializeRequest(InitializeRequest),
    /// Payload for [`Agent::authenticate`].
    AuthenticateRequest(AuthenticateRequest),
    /// Payload for [`Agent::new_session`].
    NewSessionRequest(NewSessionRequest),
    /// Payload for [`Agent::load_session`].
    LoadSessionRequest(LoadSessionRequest),
    /// Payload for [`Agent::prompt`].
    PromptRequest(PromptRequest),
}
254
/// Union of the response shapes for the request-handling [`Agent`] methods.
///
/// `#[serde(untagged)]` means no variant tag is written. The two unit variants
/// mirror the methods whose success type is `()` and carry no data.
// NOTE(review): with an untagged representation the two unit variants are
// indistinguishable on the wire; when deserializing, the first matching
// variant in declaration order wins. Confirm callers do not depend on telling
// them apart after a round trip.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
pub enum AgentResponse {
    /// Result of [`Agent::initialize`].
    InitializeResponse(InitializeResponse),
    /// Result of [`Agent::authenticate`]; carries no data.
    AuthenticateResponse,
    /// Result of [`Agent::new_session`].
    NewSessionResponse(NewSessionResponse),
    /// Result of [`Agent::load_session`]; carries no data.
    LoadSessionResponse,
    /// Result of [`Agent::prompt`].
    PromptResponse(PromptResponse),
}
265
/// Union of the notification payload types defined in this file; currently a
/// single variant.
///
/// `#[serde(untagged)]` means a value serializes exactly like its inner
/// payload, with no variant tag.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
#[serde(untagged)]
pub enum ClientNotification {
    /// Payload for [`Agent::cancel`].
    CancelNotification(CancelNotification),
}
272
/// Payload accepted by [`Agent::cancel`].
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub struct CancelNotification {
    /// Identifier of the session the notification refers to (serialized as `sessionId`).
    pub session_id: SessionId,
}