1use std::{collections::HashMap, fmt::Display, ops::Deref, path::PathBuf};
2
3use derive_more::{Deref, Display, FromStr};
4use schemars::JsonSchema;
5use semver::Version;
6use serde::{Deserialize, Serialize};
7use serde_json::value::RawValue;
8
9#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
10pub struct Error {
11 pub code: i32,
12 pub message: String,
13 #[serde(skip_serializing_if = "Option::is_none")]
14 pub data: Option<serde_json::Value>,
15}
16
17impl Error {
18 pub fn new(code: impl Into<(i32, String)>) -> Self {
19 let (code, message) = code.into();
20 Error {
21 code,
22 message,
23 data: None,
24 }
25 }
26
27 pub fn with_data(mut self, data: impl Into<serde_json::Value>) -> Self {
28 self.data = Some(data.into());
29 self
30 }
31
32 pub fn parse_error() -> Self {
34 Error::new(ErrorCode::PARSE_ERROR)
35 }
36
37 pub fn invalid_request() -> Self {
39 Error::new(ErrorCode::INVALID_REQUEST)
40 }
41
42 pub fn method_not_found() -> Self {
44 Error::new(ErrorCode::METHOD_NOT_FOUND)
45 }
46
47 pub fn invalid_params() -> Self {
49 Error::new(ErrorCode::INVALID_PARAMS)
50 }
51
52 pub fn internal_error() -> Self {
54 Error::new(ErrorCode::INTERNAL_ERROR)
55 }
56
57 pub fn into_internal_error(err: impl std::error::Error) -> Self {
58 Error::internal_error().with_data(err.to_string())
59 }
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
63pub struct ErrorCode {
64 code: i32,
65 message: &'static str,
66}
67
68impl ErrorCode {
69 pub const PARSE_ERROR: ErrorCode = ErrorCode {
70 code: -32700,
71 message: "Parse error",
72 };
73
74 pub const INVALID_REQUEST: ErrorCode = ErrorCode {
75 code: -32600,
76 message: "Invalid Request",
77 };
78
79 pub const METHOD_NOT_FOUND: ErrorCode = ErrorCode {
80 code: -32601,
81 message: "Method not found",
82 };
83
84 pub const INVALID_PARAMS: ErrorCode = ErrorCode {
85 code: -32602,
86 message: "Invalid params",
87 };
88
89 pub const INTERNAL_ERROR: ErrorCode = ErrorCode {
90 code: -32603,
91 message: "Internal error",
92 };
93}
94
95impl From<ErrorCode> for (i32, String) {
96 fn from(error_code: ErrorCode) -> Self {
97 (error_code.code, error_code.message.to_string())
98 }
99}
100
101impl From<ErrorCode> for Error {
102 fn from(error_code: ErrorCode) -> Self {
103 Error::new(error_code)
104 }
105}
106
107impl std::error::Error for Error {}
108
109impl Display for Error {
110 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111 if self.message.is_empty() {
112 write!(f, "{}", self.code)?;
113 } else {
114 write!(f, "{}", self.message)?;
115 }
116
117 if let Some(data) = &self.data {
118 write!(f, ": {data}")?;
119 }
120
121 Ok(())
122 }
123}
124
125impl From<anyhow::Error> for Error {
126 fn from(error: anyhow::Error) -> Self {
127 Error::into_internal_error(error.deref())
128 }
129}
130
131#[derive(Serialize, JsonSchema)]
132#[serde(rename_all = "camelCase")]
133pub struct Method {
134 pub name: &'static str,
135 pub request_type: &'static str,
136 pub param_payload: bool,
137 pub response_type: &'static str,
138 pub response_payload: bool,
139}
140
141pub trait AnyRequest: Serialize + Sized + std::fmt::Debug + 'static {
142 type Response: Serialize + 'static;
143 fn from_method_and_params(method: &str, params: &RawValue) -> Result<Self, Error>;
144 fn response_from_method_and_result(
145 method: &str,
146 params: &RawValue,
147 ) -> Result<Self::Response, Error>;
148}
149
/// Generates the typed RPC surface for one side of the protocol from a single
/// table of methods.
///
/// Arguments, in order: the handler trait name, the per-request trait name,
/// the request enum name, the response enum name, the static method-table
/// name, then one row per method:
/// `(rust_method, "wireMethod", RequestType, has_params, ResponseType, has_result)`.
///
/// It emits:
/// - a handler trait with one async method per row plus a default `call` that
///   dispatches a request enum value to the matching method;
/// - a request trait, implemented for every request type, that converts the
///   request into the request enum and extracts its typed response from the
///   response enum;
/// - untagged request and response enums over all rows;
/// - an [`AnyRequest`] impl that decodes raw JSON by method name;
/// - a static [`Method`] table describing every row.
///
/// When `has_params` is `false`, the handler method takes no argument and the
/// request value is produced from the type's own name (so in practice a unit
/// struct). When `has_result` is `false`, the handler returns `()` and the
/// response value is likewise produced from the type's name.
macro_rules! acp_peer {
    (
        $handler_trait_name:ident,
        $request_trait_name:ident,
        $request_enum_name:ident,
        $response_enum_name:ident,
        $method_map_name:ident,
        $(($request_method:ident, $request_method_string:expr, $request_name:ident, $param_payload: tt, $response_name:ident, $response_payload: tt)),*
        $(,)?
    ) => {
        // Body of one `call` match arm: invoke the handler method (with or
        // without `params`) and wrap its output — or, for payload-less
        // responses, the unit response value — in the response enum.
        macro_rules! handler_trait_call_req {
            ($self: ident, $method: ident, false, $resp_name: ident, false, $params: ident) => {
                {
                    $self.$method().await?;
                    Ok($response_enum_name::$resp_name($resp_name))
                }
            };
            ($self: ident, $method: ident, false, $resp_name: ident, true, $params: ident) => {
                {
                    let resp = $self.$method().await?;
                    Ok($response_enum_name::$resp_name(resp))
                }
            };
            ($self: ident, $method: ident, true, $resp_name: ident, false, $params: ident) => {
                {
                    $self.$method($params).await?;
                    Ok($response_enum_name::$resp_name($resp_name))
                }
            };
            ($self: ident, $method: ident, true, $resp_name: ident, true, $params: ident) => {
                {
                    let resp = $self.$method($params).await?;
                    Ok($response_enum_name::$resp_name(resp))
                }
            }
        }

        // Declares one handler method; its signature depends on whether the
        // request carries params and whether the response carries a payload.
        macro_rules! handler_trait_req_method {
            ($method: ident, $req: ident, false, $resp: tt, false) => {
                fn $method(&self) -> impl Future<Output = Result<(), Error>>;
            };
            ($method: ident, $req: ident, false, $resp: tt, true) => {
                fn $method(&self) -> impl Future<Output = Result<$resp, Error>>;
            };
            ($method: ident, $req: ident, true, $resp: tt, false) => {
                fn $method(&self, request: $req) -> impl Future<Output = Result<(), Error>>;
            };
            ($method: ident, $req: ident, true, $resp: tt, true) => {
                fn $method(&self, request: $req) -> impl Future<Output = Result<$resp, Error>>;
            }
        }

        /// Implemented by the side that handles these requests.
        pub trait $handler_trait_name {
            /// Dispatches a decoded request to the matching handler method and
            /// wraps its output in the response enum.
            fn call(&self, params: $request_enum_name) -> impl Future<Output = Result<$response_enum_name, Error>> {
                async move {
                    match params {
                        $(#[allow(unused_variables)]
                        $request_enum_name::$request_name(params) => {
                            handler_trait_call_req!(self, $request_method, $param_payload, $response_name, $response_payload, params)
                        }),*
                    }
                }
            }

            $(
                handler_trait_req_method!($request_method, $request_name, $param_payload, $response_name, $response_payload);
            )*
        }

        /// Implemented by every request type: converts it into the request
        /// enum and extracts its typed response from the response enum.
        pub trait $request_trait_name {
            type Response;
            fn into_any(self) -> $request_enum_name;
            fn response_from_any(any: $response_enum_name) -> Result<Self::Response, Error>;
        }

        /// Any request of this peer, serialized as the bare inner value.
        #[derive(Serialize, JsonSchema, Debug)]
        #[serde(untagged)]
        pub enum $request_enum_name {
            $(
                $request_name($request_name),
            )*
        }

        /// Any response of this peer, serialized as the bare inner value.
        #[derive(Serialize, Deserialize, JsonSchema)]
        #[serde(untagged)]
        pub enum $response_enum_name {
            $(
                $response_name($response_name),
            )*
        }

        // Decodes request params for one method. The payload is already valid
        // JSON (a `RawValue`), so a deserialization failure here means the
        // params do not match the expected type: JSON-RPC 2.0 reports that as
        // `-32602 Invalid params`, not `-32700 Parse error`.
        macro_rules! request_from_method_and_params {
            ($req_name: ident, false, $params: tt) => {
                Ok($request_enum_name::$req_name($req_name))
            };
            ($req_name: ident, true, $params: tt) => {
                match serde_json::from_str($params.get()) {
                    Ok(params) => Ok($request_enum_name::$req_name(params)),
                    Err(e) => Err(Error::invalid_params().with_data(e.to_string())),
                }
            };
        }

        // Decodes a response result for one method; payload-less responses
        // ignore the raw value and produce the unit response.
        macro_rules! response_from_method_and_result {
            ($resp_name: ident, false, $result: tt) => {
                Ok($response_enum_name::$resp_name($resp_name))
            };
            ($resp_name: ident, true, $result: tt) => {
                match serde_json::from_str($result.get()) {
                    Ok(result) => Ok($response_enum_name::$resp_name(result)),
                    Err(e) => Err(Error::parse_error().with_data(e.to_string())),
                }
            };
        }

        impl AnyRequest for $request_enum_name {
            type Response = $response_enum_name;

            fn from_method_and_params(method: &str, params: &RawValue) -> Result<Self, Error> {
                match method {
                    $(
                        $request_method_string => {
                            request_from_method_and_params!($request_name, $param_payload, params)
                        }
                    )*
                    _ => Err(Error::method_not_found()),
                }
            }

            fn response_from_method_and_result(method: &str, params: &RawValue) -> Result<Self::Response, Error> {
                match method {
                    $(
                        $request_method_string => {
                            response_from_method_and_result!($response_name, $response_payload, params)
                        }
                    )*
                    _ => Err(Error::method_not_found()),
                }
            }
        }

        impl $request_enum_name {
            /// Wire name of the method this request is sent as.
            pub fn method_name(&self) -> &'static str {
                match self {
                    $(
                        $request_enum_name::$request_name(_) => $request_method_string,
                    )*
                }
            }
        }

        /// Description of every method in this peer's table, in table order.
        pub static $method_map_name: &[Method] = &[
            $(
                Method {
                    name: $request_method_string,
                    request_type: stringify!($request_name),
                    param_payload: $param_payload,
                    response_type: stringify!($response_name),
                    response_payload: $response_payload,
                },
            )*
        ];

        // Wraps a request in the request enum; payload-less requests are
        // produced from the type name rather than from `self`.
        macro_rules! req_into_any {
            ($self: ident, $req_name: ident, false) => {
                $request_enum_name::$req_name($req_name)
            };
            ($self: ident, $req_name: ident, true) => {
                $request_enum_name::$req_name($self)
            };
        }

        // The typed response a caller sees: `()` when there is no payload.
        macro_rules! resp_type {
            ($resp_name: ident, false) => {
                ()
            };
            ($resp_name: ident, true) => {
                $resp_name
            };
        }

        // Extracts the expected variant from the response enum, reporting any
        // other variant as an internal error.
        macro_rules! resp_from_any {
            ($any: ident, $resp_name: ident, false) => {
                match $any {
                    $response_enum_name::$resp_name(_) => Ok(()),
                    _ => Err(Error::internal_error().with_data("Unexpected Response"))
                }
            };
            ($any: ident, $resp_name: ident, true) => {
                match $any {
                    $response_enum_name::$resp_name(this) => Ok(this),
                    _ => Err(Error::internal_error().with_data("Unexpected Response"))
                }
            };
        }

        $(
            impl $request_trait_name for $request_name {
                type Response = resp_type!($response_name, $response_payload);

                fn into_any(self) -> $request_enum_name {
                    req_into_any!(self, $request_name, $param_payload)
                }

                fn response_from_any(any: $response_enum_name) -> Result<Self::Response, Error> {
                    resp_from_any!(any, $response_name, $response_payload)
                }
            }
        )*
    };
}
363
364acp_peer!(
366 Client,
367 ClientRequest,
368 AnyClientRequest,
369 AnyClientResult,
370 CLIENT_METHODS,
371 (
372 stream_assistant_message_chunk,
373 "streamAssistantMessageChunk",
374 StreamAssistantMessageChunkParams,
375 true,
376 StreamAssistantMessageChunkResponse,
377 false
378 ),
379 (
380 request_tool_call_confirmation,
381 "requestToolCallConfirmation",
382 RequestToolCallConfirmationParams,
383 true,
384 RequestToolCallConfirmationResponse,
385 true
386 ),
387 (
388 push_tool_call,
389 "pushToolCall",
390 PushToolCallParams,
391 true,
392 PushToolCallResponse,
393 true
394 ),
395 (
396 update_tool_call,
397 "updateToolCall",
398 UpdateToolCallParams,
399 true,
400 UpdateToolCallResponse,
401 false
402 ),
403 (
404 update_plan,
405 "updatePlan",
406 UpdatePlanParams,
407 true,
408 UpdatePlanResponse,
409 false
410 ),
411 (
412 write_text_file,
413 "writeTextFile",
414 WriteTextFileParams,
415 true,
416 WriteTextFileResponse,
417 false
418 ),
419 (
420 read_text_file,
421 "readTextFile",
422 ReadTextFileParams,
423 true,
424 ReadTextFileResponse,
425 true
426 ),
427);
428
429acp_peer!(
431 Agent,
432 AgentRequest,
433 AnyAgentRequest,
434 AnyAgentResult,
435 AGENT_METHODS,
436 (
437 initialize,
438 "initialize",
439 InitializeParams,
440 true,
441 InitializeResponse,
442 true
443 ),
444 (
445 authenticate,
446 "authenticate",
447 AuthenticateParams,
448 false,
449 AuthenticateResponse,
450 false
451 ),
452 (
453 send_user_message,
454 "sendUserMessage",
455 SendUserMessageParams,
456 true,
457 SendUserMessageResponse,
458 false
459 ),
460 (
461 cancel_send_message,
462 "cancelSendMessage",
463 CancelSendMessageParams,
464 false,
465 CancelSendMessageResponse,
466 false
467 )
468);
469
470#[derive(Debug, Serialize, Deserialize, JsonSchema)]
479#[serde(rename_all = "camelCase")]
480pub struct InitializeParams {
481 pub protocol_version: ProtocolVersion,
484 #[serde(skip_serializing_if = "HashMap::is_empty")]
487 pub context_servers: HashMap<String, ContextServer>,
488}
489
490#[derive(Debug, Serialize, Deserialize, JsonSchema)]
491#[serde(rename_all = "camelCase")]
492pub enum ContextServer {
493 Stdio {
494 command: String,
495 args: Vec<String>,
496 env: HashMap<String, String>,
497 },
498 Http {
499 url: String,
500 headers: HashMap<String, String>,
501 },
502}
503
504#[derive(Clone, Debug, Deref, Display, FromStr, Serialize, Deserialize, JsonSchema)]
505#[serde(transparent)]
506pub struct ProtocolVersion(Version);
507
508impl ProtocolVersion {
509 pub fn latest() -> Self {
510 Self(env!("CARGO_PKG_VERSION").parse().expect("Invalid version"))
511 }
512}
513
514#[derive(Debug, Serialize, Deserialize, JsonSchema)]
515#[serde(rename_all = "camelCase")]
516pub struct InitializeResponse {
517 pub protocol_version: ProtocolVersion,
521 pub is_authenticated: bool,
524}
525
526#[derive(Debug, Serialize, Deserialize, JsonSchema)]
535#[serde(rename_all = "camelCase")]
536pub struct AuthenticateParams;
537
538#[derive(Debug, Serialize, Deserialize, JsonSchema)]
539#[serde(rename_all = "camelCase")]
540pub struct AuthenticateResponse;
541
542#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
547#[serde(rename_all = "camelCase")]
548pub struct SendUserMessageParams {
549 pub chunks: Vec<UserMessageChunk>,
550}
551
552#[derive(Debug, Serialize, Deserialize, JsonSchema)]
553#[serde(rename_all = "camelCase")]
554pub struct SendUserMessageResponse;
555
556#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
558#[serde(untagged, rename_all = "camelCase")]
559pub enum UserMessageChunk {
560 Text { text: String },
562 Path { path: PathBuf },
564}
565
566#[derive(Debug, Serialize, Deserialize, JsonSchema)]
569#[serde(rename_all = "camelCase")]
570pub struct CancelSendMessageParams;
571
572#[derive(Debug, Serialize, Deserialize, JsonSchema)]
573#[serde(rename_all = "camelCase")]
574pub struct CancelSendMessageResponse;
575
576#[derive(Debug, Serialize, Deserialize, JsonSchema)]
580#[serde(rename_all = "camelCase")]
581pub struct StreamAssistantMessageChunkParams {
582 pub chunk: AssistantMessageChunk,
583}
584
585#[derive(Debug, Serialize, Deserialize, JsonSchema)]
586#[serde(rename_all = "camelCase")]
587pub struct StreamAssistantMessageChunkResponse;
588
589#[derive(Debug, Serialize, Deserialize, JsonSchema)]
590#[serde(untagged, rename_all = "camelCase")]
591pub enum AssistantMessageChunk {
592 Text { text: String },
593 Thought { thought: String },
594}
595
596#[derive(Debug, Serialize, Deserialize, JsonSchema)]
601#[serde(rename_all = "camelCase")]
602pub struct RequestToolCallConfirmationParams {
603 #[serde(flatten)]
604 pub tool_call: PushToolCallParams,
605 pub confirmation: ToolCallConfirmation,
606}
607
608#[derive(Debug, Serialize, Deserialize, JsonSchema)]
609#[serde(rename_all = "camelCase")]
610pub enum Icon {
613 FileSearch,
614 Folder,
615 Globe,
616 Hammer,
617 LightBulb,
618 Pencil,
619 Regex,
620 Terminal,
621}
622
623#[derive(Debug, Serialize, Deserialize, JsonSchema)]
624#[serde(tag = "type", rename_all = "camelCase")]
625pub enum ToolCallConfirmation {
626 #[serde(rename_all = "camelCase")]
627 Edit {
628 #[serde(skip_serializing_if = "Option::is_none")]
629 description: Option<String>,
630 },
631 #[serde(rename_all = "camelCase")]
632 Execute {
633 command: String,
634 root_command: String,
635 #[serde(skip_serializing_if = "Option::is_none")]
636 description: Option<String>,
637 },
638 #[serde(rename_all = "camelCase")]
639 Mcp {
640 server_name: String,
641 tool_name: String,
642 tool_display_name: String,
643 #[serde(skip_serializing_if = "Option::is_none")]
644 description: Option<String>,
645 },
646 #[serde(rename_all = "camelCase")]
647 Fetch {
648 urls: Vec<String>,
649 #[serde(skip_serializing_if = "Option::is_none")]
650 description: Option<String>,
651 },
652 #[serde(rename_all = "camelCase")]
653 Other { description: String },
654}
655
656#[derive(Debug, Serialize, Deserialize, JsonSchema)]
657#[serde(tag = "type", rename_all = "camelCase")]
658pub struct RequestToolCallConfirmationResponse {
659 pub id: ToolCallId,
660 pub outcome: ToolCallConfirmationOutcome,
661}
662
663#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
664#[serde(rename_all = "camelCase")]
665pub enum ToolCallConfirmationOutcome {
666 Allow,
668 AlwaysAllow,
670 AlwaysAllowMcpServer,
672 AlwaysAllowTool,
674 Reject,
676 Cancel,
678}
679
680#[derive(Debug, Serialize, Deserialize, JsonSchema)]
686#[serde(rename_all = "camelCase")]
687pub struct PushToolCallParams {
688 pub label: String,
689 pub icon: Icon,
690 #[serde(skip_serializing_if = "Option::is_none")]
691 pub content: Option<ToolCallContent>,
692 #[serde(default, skip_serializing_if = "Vec::is_empty")]
693 pub locations: Vec<ToolCallLocation>,
694}
695
696#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
697#[serde(tag = "type", rename_all = "camelCase")]
698pub struct ToolCallLocation {
699 pub path: PathBuf,
700 #[serde(skip_serializing_if = "Option::is_none")]
701 pub line: Option<u32>,
702}
703
704#[derive(Debug, Serialize, Deserialize, JsonSchema)]
705#[serde(tag = "type", rename_all = "camelCase")]
706pub struct PushToolCallResponse {
707 pub id: ToolCallId,
708}
709
710#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, Eq, PartialEq, Hash)]
711#[serde(rename_all = "camelCase")]
712pub struct ToolCallId(pub u64);
713
714#[derive(Debug, Serialize, Deserialize, JsonSchema)]
721#[serde(rename_all = "camelCase")]
722pub struct UpdateToolCallParams {
723 pub tool_call_id: ToolCallId,
724 pub status: ToolCallStatus,
725 pub content: Option<ToolCallContent>,
726}
727
728#[derive(Debug, Serialize, Deserialize, JsonSchema)]
729pub struct UpdateToolCallResponse;
730
731#[derive(Debug, Serialize, Deserialize, JsonSchema)]
736#[serde(rename_all = "camelCase")]
737pub struct UpdatePlanParams {
738 pub entries: Vec<PlanEntry>,
740}
741
742#[derive(Debug, Serialize, Deserialize, JsonSchema)]
743pub struct UpdatePlanResponse;
744
745#[derive(Debug, Serialize, Deserialize, JsonSchema)]
750#[serde(rename_all = "camelCase")]
751pub struct PlanEntry {
752 pub content: String,
754 pub priority: PlanEntryPriority,
756 pub status: PlanEntryStatus,
758}
759
760#[derive(Deserialize, Serialize, JsonSchema, Debug)]
765#[serde(rename_all = "snake_case")]
766pub enum PlanEntryPriority {
767 High,
768 Medium,
769 Low,
770}
771
772#[derive(Deserialize, Serialize, JsonSchema, Debug)]
776#[serde(rename_all = "snake_case")]
777pub enum PlanEntryStatus {
778 Pending,
779 InProgress,
780 Completed,
781}
782
783#[derive(Debug, Serialize, Deserialize, JsonSchema)]
784#[serde(rename_all = "camelCase")]
785pub enum ToolCallStatus {
786 Running,
788 Finished,
790 Error,
792}
793
794#[derive(Debug, Serialize, Deserialize, JsonSchema)]
795#[serde(tag = "type", rename_all = "camelCase")]
796pub enum ToolCallContent {
797 #[serde(rename_all = "camelCase")]
798 Markdown { markdown: String },
799 #[serde(rename_all = "camelCase")]
800 Diff {
801 #[serde(flatten)]
802 diff: Diff,
803 },
804}
805
806#[derive(Debug, Serialize, Deserialize, JsonSchema)]
807#[serde(rename_all = "camelCase")]
808pub struct Diff {
809 pub path: PathBuf,
810 pub old_text: Option<String>,
811 pub new_text: String,
812}
813
814#[derive(Debug, Serialize, Deserialize, JsonSchema)]
815#[serde(rename_all = "camelCase")]
816pub struct WriteTextFileParams {
817 pub path: PathBuf,
818 pub content: String,
819}
820
821#[derive(Debug, Serialize, Deserialize, JsonSchema)]
822pub struct WriteTextFileResponse;
823
824#[derive(Debug, Serialize, Deserialize, JsonSchema)]
825#[serde(rename_all = "camelCase")]
826pub struct ReadTextFileParams {
827 pub path: PathBuf,
828 #[serde(skip_serializing_if = "Option::is_none")]
829 pub line: Option<u32>,
830 #[serde(skip_serializing_if = "Option::is_none")]
831 pub limit: Option<u32>,
832}
833
834#[derive(Debug, Serialize, Deserialize, JsonSchema)]
835pub struct ReadTextFileResponse {
836 pub content: String,
837}