1use std::{path::PathBuf, sync::Arc};
4
5use anyhow::Result;
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8
9use crate::{ClientCapabilities, ContentBlock, Error, ProtocolVersion, SessionId};
10
11pub trait Agent {
12 fn initialize(
13 &self,
14 arguments: InitializeRequest,
15 ) -> impl Future<Output = Result<InitializeResponse, Error>>;
16
17 fn authenticate(
18 &self,
19 arguments: AuthenticateRequest,
20 ) -> impl Future<Output = Result<(), Error>>;
21
22 fn new_session(
23 &self,
24 arguments: NewSessionRequest,
25 ) -> impl Future<Output = Result<NewSessionResponse, Error>>;
26
27 fn load_session(
28 &self,
29 arguments: LoadSessionRequest,
30 ) -> impl Future<Output = Result<LoadSessionResponse, Error>>;
31
32 fn prompt(&self, arguments: PromptRequest) -> impl Future<Output = Result<(), Error>>;
33
34 fn cancelled(&self, args: CancelledNotification) -> impl Future<Output = Result<(), Error>>;
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
40#[serde(rename_all = "camelCase")]
41pub struct InitializeRequest {
42 pub protocol_version: ProtocolVersion,
44 #[serde(default)]
46 pub client_capabilities: ClientCapabilities,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
50#[serde(rename_all = "camelCase")]
51pub struct InitializeResponse {
52 pub protocol_version: ProtocolVersion,
57 #[serde(default)]
59 pub agent_capabilities: AgentCapabilities,
60 #[serde(default)]
62 pub auth_methods: Vec<AuthMethod>,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
68#[serde(rename_all = "camelCase")]
69pub struct AuthenticateRequest {
70 pub method_id: AuthMethodId,
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
74#[serde(transparent)]
75pub struct AuthMethodId(pub Arc<str>);
76
77#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
78#[serde(rename_all = "camelCase")]
79pub struct AuthMethod {
80 pub id: AuthMethodId,
81 pub label: String,
82 pub description: Option<String>,
83}
84
85#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
88#[serde(rename_all = "camelCase")]
89pub struct NewSessionRequest {
90 pub mcp_servers: Vec<McpServer>,
91 pub cwd: PathBuf,
92}
93
94#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
95#[serde(rename_all = "camelCase")]
96pub struct NewSessionResponse {
97 pub session_id: Option<SessionId>,
101}
102
103#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
106#[serde(rename_all = "camelCase")]
107pub struct LoadSessionRequest {
108 pub mcp_servers: Vec<McpServer>,
109 pub cwd: PathBuf,
110 pub session_id: SessionId,
111}
112
113#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
114#[serde(rename_all = "camelCase")]
115pub struct LoadSessionResponse {
116 pub auth_required: bool,
117 #[serde(default)]
118 pub auth_methods: Vec<AuthMethod>,
119}
120
121#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
124#[serde(rename_all = "camelCase")]
125pub struct McpServer {
126 pub name: String,
127 pub command: PathBuf,
128 pub args: Vec<String>,
129 pub env: Vec<EnvVariable>,
130}
131
132#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
133#[serde(rename_all = "camelCase")]
134pub struct EnvVariable {
135 pub name: String,
136 pub value: String,
137}
138
139#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
142#[serde(rename_all = "camelCase")]
143pub struct PromptRequest {
144 pub session_id: SessionId,
145 pub prompt: Vec<ContentBlock>,
146}
147
148#[derive(Default, Debug, Clone, Serialize, Deserialize, JsonSchema)]
151#[serde(rename_all = "camelCase")]
152pub struct AgentCapabilities {
153 #[serde(default)]
155 load_session: bool,
156}
157
158#[derive(Debug, Clone, Serialize, Deserialize)]
161pub struct AgentMethodNames {
162 pub initialize: &'static str,
163 pub authenticate: &'static str,
164 pub session_new: &'static str,
165 pub session_load: &'static str,
166 pub session_prompt: &'static str,
167 pub session_cancelled: &'static str,
168}
169
170pub const AGENT_METHOD_NAMES: AgentMethodNames = AgentMethodNames {
171 initialize: INITIALIZE_METHOD_NAME,
172 authenticate: AUTHENTICATE_METHOD_NAME,
173 session_new: SESSION_NEW_METHOD_NAME,
174 session_load: SESSION_LOAD_METHOD_NAME,
175 session_prompt: SESSION_PROMPT_METHOD_NAME,
176 session_cancelled: SESSION_CANCELLED_METHOD_NAME,
177};
178
/// Method-name string `"initialize"`.
pub const INITIALIZE_METHOD_NAME: &str = "initialize";
/// Method-name string `"authenticate"`.
pub const AUTHENTICATE_METHOD_NAME: &str = "authenticate";
/// Method-name string `"session/new"`.
pub const SESSION_NEW_METHOD_NAME: &str = "session/new";
/// Method-name string `"session/load"`.
pub const SESSION_LOAD_METHOD_NAME: &str = "session/load";
/// Method-name string `"session/prompt"`.
pub const SESSION_PROMPT_METHOD_NAME: &str = "session/prompt";
/// Method-name string `"session/cancelled"`.
pub const SESSION_CANCELLED_METHOD_NAME: &str = "session/cancelled";
185
186#[derive(Debug, Serialize, Deserialize, JsonSchema)]
188#[serde(untagged)]
189pub enum ClientRequest {
190 InitializeRequest(InitializeRequest),
191 AuthenticateRequest(AuthenticateRequest),
192 NewSessionRequest(NewSessionRequest),
193 LoadSessionRequest(LoadSessionRequest),
194 PromptRequest(PromptRequest),
195}
196
197#[derive(Debug, Serialize, Deserialize, JsonSchema)]
199#[serde(untagged)]
200pub enum AgentResponse {
201 InitializeResponse(InitializeResponse),
202 AuthenticateResponse,
203 NewSessionResponse(NewSessionResponse),
204 LoadSessionResponse(LoadSessionResponse),
205 PromptResponse,
206}
207
208#[derive(Debug, Serialize, Deserialize, JsonSchema)]
210#[serde(untagged)]
211pub enum ClientNotification {
212 CancelledNotification(CancelledNotification),
213}
214
215#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
216#[serde(rename_all = "camelCase")]
217pub struct CancelledNotification {
218 pub session_id: SessionId,
219}