1use agent_client_protocol::{AuthMethod, ExtNotification};
7pub use mcp_utils::display_meta::{ToolDisplayMeta, ToolResultMeta};
8use rmcp::model::ElicitationSchema;
9use serde::{Deserialize, Serialize};
10use serde_json::value::to_raw_value;
11use std::fmt;
12use std::sync::Arc;
13
14pub use mcp_utils::status::{McpServerStatus, McpServerStatusEntry};
15
/// Method name of the extension notification carrying [`SubAgentProgressParams`].
pub const SUB_AGENT_PROGRESS_METHOD: &str = "_aether/sub_agent_progress";
/// Method name of the extension notification carrying [`ContextUsageParams`].
pub const CONTEXT_USAGE_METHOD: &str = "_aether/context_usage";
/// Method name of the extension notification carrying [`ContextClearedParams`].
pub const CONTEXT_CLEARED_METHOD: &str = "_aether/context_cleared";
/// Method name shared by [`McpNotification`] and [`McpRequest`] messages.
pub const MCP_MESSAGE_METHOD: &str = "_aether/mcp";
/// Method name of the extension notification carrying [`AuthMethodsUpdatedParams`].
pub const AUTH_METHODS_UPDATED_METHOD: &str = "_aether/auth_methods_updated";

/// Method name for elicitation; unlike the constants above it has no leading underscore.
///
/// NOTE(review): not referenced elsewhere in this file — presumably paired with
/// [`ElicitationParams`] / [`ElicitationResponse`]; confirm against callers.
pub const ELICITATION_METHOD: &str = "aether/elicitation";
28
/// Parameters sent with the [`CONTEXT_USAGE_METHOD`] extension notification.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ContextUsageParams {
    /// Usage ratio, when available.
    ///
    /// NOTE(review): presumably `tokens_used / context_limit` — confirm with the producer.
    pub usage_ratio: Option<f64>,
    /// Number of tokens used.
    pub tokens_used: u32,
    /// Context limit, when known. NOTE(review): presumably measured in tokens — confirm.
    pub context_limit: Option<u32>,
}
36
/// Empty parameter object sent with the [`CONTEXT_CLEARED_METHOD`] extension notification.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
pub struct ContextClearedParams {}
40
/// Parameters sent with the [`AUTH_METHODS_UPDATED_METHOD`] extension notification.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct AuthMethodsUpdatedParams {
    /// The auth methods carried by the notification.
    ///
    /// NOTE(review): presumably the full, current list rather than a delta — confirm.
    pub auth_methods: Vec<AuthMethod>,
}
46
/// Parameters describing an elicitation: a message plus the schema of the requested input.
///
/// NOTE(review): presumably sent with [`ELICITATION_METHOD`]; this file does not show the
/// call site — confirm.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ElicitationParams {
    /// Message accompanying the elicitation.
    pub message: String,
    /// Schema (from `rmcp`) describing the requested input.
    pub schema: ElicitationSchema,
}
53
54pub use rmcp::model::ElicitationAction;
55
/// Response to an elicitation: the chosen action plus optional JSON content.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ElicitationResponse {
    /// The action taken, using the `rmcp` [`ElicitationAction`] type.
    pub action: ElicitationAction,
    /// Optional JSON content accompanying the action.
    ///
    /// NOTE(review): presumably the submitted values when the elicitation is accepted —
    /// confirm.
    pub content: Option<serde_json::Value>,
}
63
/// MCP-related message that converts to and from an [`ExtNotification`] on
/// [`MCP_MESSAGE_METHOD`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum McpNotification {
    /// Status entries for a set of MCP servers.
    ServerStatus { servers: Vec<McpServerStatusEntry> },
}
69
70impl From<McpNotification> for ExtNotification {
71 fn from(msg: McpNotification) -> Self {
72 ext_notification(MCP_MESSAGE_METHOD, &msg)
73 }
74}
75
/// MCP-related request that converts to and from an [`ExtNotification`] on
/// [`MCP_MESSAGE_METHOD`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum McpRequest {
    /// Authentication request identified by a session and a server name.
    ///
    /// NOTE(review): presumably asks the receiver to start authentication for the named
    /// MCP server within the given session — confirm against the handler.
    Authenticate {
        session_id: String,
        server_name: String,
    },
}
84
85impl From<McpRequest> for ExtNotification {
86 fn from(msg: McpRequest) -> Self {
87 ext_notification(MCP_MESSAGE_METHOD, &msg)
88 }
89}
90
91#[derive(Debug)]
93pub enum ExtNotificationParseError {
94 WrongMethod,
95 InvalidJson(serde_json::Error),
96}
97
98impl fmt::Display for ExtNotificationParseError {
99 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100 match self {
101 Self::WrongMethod => write!(f, "notification method is not {MCP_MESSAGE_METHOD}"),
102 Self::InvalidJson(e) => write!(f, "invalid JSON params: {e}"),
103 }
104 }
105}
106
107impl TryFrom<&ExtNotification> for McpRequest {
108 type Error = ExtNotificationParseError;
109
110 fn try_from(n: &ExtNotification) -> Result<Self, Self::Error> {
111 if n.method.as_ref() != MCP_MESSAGE_METHOD {
112 return Err(ExtNotificationParseError::WrongMethod);
113 }
114 serde_json::from_str(n.params.get()).map_err(ExtNotificationParseError::InvalidJson)
115 }
116}
117
118impl TryFrom<&ExtNotification> for McpNotification {
119 type Error = ExtNotificationParseError;
120
121 fn try_from(n: &ExtNotification) -> Result<Self, Self::Error> {
122 if n.method.as_ref() != MCP_MESSAGE_METHOD {
123 return Err(ExtNotificationParseError::WrongMethod);
124 }
125 serde_json::from_str(n.params.get()).map_err(ExtNotificationParseError::InvalidJson)
126 }
127}
128
129impl TryFrom<&ExtNotification> for AuthMethodsUpdatedParams {
130 type Error = ExtNotificationParseError;
131
132 fn try_from(n: &ExtNotification) -> Result<Self, Self::Error> {
133 if n.method.as_ref() != AUTH_METHODS_UPDATED_METHOD {
134 return Err(ExtNotificationParseError::WrongMethod);
135 }
136 serde_json::from_str(n.params.get()).map_err(ExtNotificationParseError::InvalidJson)
137 }
138}
139
140fn ext_notification<T: Serialize>(method: &str, params: &T) -> ExtNotification {
141 let raw_value = to_raw_value(params).expect("notification params are serializable");
142 ExtNotification::new(method, Arc::from(raw_value))
143}
144
145impl From<ContextUsageParams> for ExtNotification {
146 fn from(params: ContextUsageParams) -> Self {
147 ext_notification(CONTEXT_USAGE_METHOD, ¶ms)
148 }
149}
150
151impl From<ContextClearedParams> for ExtNotification {
152 fn from(params: ContextClearedParams) -> Self {
153 ext_notification(CONTEXT_CLEARED_METHOD, ¶ms)
154 }
155}
156
157impl From<AuthMethodsUpdatedParams> for ExtNotification {
158 fn from(params: AuthMethodsUpdatedParams) -> Self {
159 ext_notification(AUTH_METHODS_UPDATED_METHOD, ¶ms)
160 }
161}
162
/// Parameters sent with the [`SUB_AGENT_PROGRESS_METHOD`] extension notification.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubAgentProgressParams {
    /// NOTE(review): presumably the id of the parent tool call that spawned the sub-agent —
    /// confirm.
    pub parent_tool_id: String,
    /// Identifier of the task this progress event belongs to.
    pub task_id: String,
    /// Name of the sub-agent.
    pub agent_name: String,
    /// The progress event itself.
    pub event: SubAgentEvent,
}
173
174impl From<SubAgentProgressParams> for ExtNotification {
175 fn from(params: SubAgentProgressParams) -> Self {
176 ext_notification(SUB_AGENT_PROGRESS_METHOD, ¶ms)
177 }
178}
179
/// Event reported by a sub-agent inside [`SubAgentProgressParams`].
///
/// The tests in this file deserialize data-carrying variants from objects keyed by the
/// variant name (for example `{"ToolError":{"error":{...}}}`) and unit variants from plain
/// strings (for example `"Done"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SubAgentEvent {
    /// A tool call was requested.
    ToolCall { request: SubAgentToolRequest },
    /// A chunk update for a tool call.
    ToolCallUpdate { update: SubAgentToolCallUpdate },
    /// A tool call produced a result.
    ToolResult { result: SubAgentToolResult },
    /// A tool call failed.
    ToolError { error: SubAgentToolError },
    /// NOTE(review): presumably signals that the sub-agent has finished — confirm.
    Done,
    /// NOTE(review): presumably a placeholder for events with no dedicated variant —
    /// confirm.
    Other,
}
193
/// Tool call request carried by [`SubAgentEvent::ToolCall`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubAgentToolRequest {
    /// Identifier of the tool call.
    pub id: String,
    /// Name of the tool.
    pub name: String,
    /// Tool arguments as a string.
    ///
    /// NOTE(review): the test in this file passes a JSON-encoded object here; presumably
    /// always JSON text — confirm.
    pub arguments: String,
}
200
/// Tool call update carried by [`SubAgentEvent::ToolCallUpdate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubAgentToolCallUpdate {
    /// Identifier of the tool call being updated.
    pub id: String,
    /// NOTE(review): presumably a fragment of the streamed tool-call arguments — confirm.
    pub chunk: String,
}
206
/// Tool result carried by [`SubAgentEvent::ToolResult`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubAgentToolResult {
    /// Identifier of the tool call.
    pub id: String,
    /// Name of the tool.
    pub name: String,
    /// Optional metadata about the result.
    pub result_meta: Option<ToolResultMeta>,
}
213
/// Tool error carried by [`SubAgentEvent::ToolError`]; it identifies the failed call only.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubAgentToolError {
    /// Identifier of the tool call.
    pub id: String,
    /// Name of the tool.
    pub name: String,
}
219
#[cfg(test)]
mod tests {
    use agent_client_protocol::AuthMethodAgent;
    use serde_json::from_str;

    use super::*;

    // Every `_aether/*` notification method keeps its leading underscore.
    #[test]
    fn method_constants_have_underscore_prefix() {
        assert!(SUB_AGENT_PROGRESS_METHOD.starts_with('_'));
        assert!(CONTEXT_USAGE_METHOD.starts_with('_'));
        assert!(CONTEXT_CLEARED_METHOD.starts_with('_'));
        assert!(MCP_MESSAGE_METHOD.starts_with('_'));
        assert!(AUTH_METHODS_UPDATED_METHOD.starts_with('_'));
    }

    // `McpRequest` -> `ExtNotification` -> raw JSON -> `McpRequest`.
    #[test]
    fn mcp_request_authenticate_roundtrip() {
        let msg = McpRequest::Authenticate {
            session_id: "session-0".to_string(),
            server_name: "my oauth server".to_string(),
        };

        let notification: ExtNotification = msg.clone().into();
        assert_eq!(notification.method.as_ref(), MCP_MESSAGE_METHOD);

        let parsed: McpRequest =
            serde_json::from_str(notification.params.get()).expect("valid JSON");
        assert_eq!(parsed, msg);
    }

    // `McpNotification::ServerStatus` survives a roundtrip with several status kinds.
    #[test]
    fn mcp_notification_server_status_roundtrip() {
        let msg = McpNotification::ServerStatus {
            servers: vec![
                McpServerStatusEntry {
                    name: "github".to_string(),
                    status: McpServerStatus::Connected { tool_count: 5 },
                },
                McpServerStatusEntry {
                    name: "linear".to_string(),
                    status: McpServerStatus::NeedsOAuth,
                },
                McpServerStatusEntry {
                    name: "slack".to_string(),
                    status: McpServerStatus::Failed {
                        error: "connection timeout".to_string(),
                    },
                },
            ],
        };

        let notification: ExtNotification = msg.clone().into();
        assert_eq!(notification.method.as_ref(), MCP_MESSAGE_METHOD);

        let parsed: McpNotification =
            serde_json::from_str(notification.params.get()).expect("valid JSON");
        assert_eq!(parsed, msg);
    }

    // `AuthMethodsUpdatedParams` roundtrips through the raw notification params.
    #[test]
    fn auth_methods_updated_params_roundtrip() {
        let params = AuthMethodsUpdatedParams {
            auth_methods: vec![
                AuthMethod::Agent(
                    AuthMethodAgent::new("anthropic", "Anthropic").description("authenticated"),
                ),
                AuthMethod::Agent(AuthMethodAgent::new("openrouter", "OpenRouter")),
            ],
        };

        let notification: ExtNotification = params.clone().into();
        let parsed: AuthMethodsUpdatedParams =
            from_str(notification.params.get()).expect("valid JSON");

        assert_eq!(parsed, params);
        assert_eq!(notification.method.as_ref(), AUTH_METHODS_UPDATED_METHOD);
    }

    // The re-exported status entry type serializes and deserializes losslessly.
    #[test]
    fn mcp_server_status_entry_serde_roundtrip() {
        let entry = McpServerStatusEntry {
            name: "test-server".to_string(),
            status: McpServerStatus::Connected { tool_count: 3 },
        };

        let json = serde_json::to_string(&entry).unwrap();
        let parsed: McpServerStatusEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, entry);
    }

    // `ElicitationParams` with an enum schema roundtrips through JSON.
    #[test]
    fn elicitation_params_roundtrip() {
        use rmcp::model::EnumSchema;

        let params = ElicitationParams {
            message: "Pick a color".to_string(),
            schema: ElicitationSchema::builder()
                .required_enum_schema(
                    "color",
                    EnumSchema::builder(vec!["red".into(), "green".into(), "blue".into()])
                        .untitled()
                        .build(),
                )
                .build()
                .unwrap(),
        };

        let json = serde_json::to_string(&params).unwrap();
        let parsed: ElicitationParams = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, params);
    }

    // `ContextUsageParams` roundtrips and is sent on `CONTEXT_USAGE_METHOD`.
    #[test]
    fn context_usage_params_roundtrip() {
        let params = ContextUsageParams {
            usage_ratio: Some(0.75),
            tokens_used: 75000,
            context_limit: Some(100_000),
        };

        let notification: ExtNotification = params.clone().into();
        assert_eq!(notification.method.as_ref(), CONTEXT_USAGE_METHOD);

        let parsed: ContextUsageParams =
            serde_json::from_str(notification.params.get()).expect("valid JSON");
        assert_eq!(parsed, params);
    }

    // The empty `ContextClearedParams` roundtrips and is sent on `CONTEXT_CLEARED_METHOD`.
    #[test]
    fn context_cleared_params_roundtrip() {
        let params = ContextClearedParams::default();

        let notification: ExtNotification = params.clone().into();
        assert_eq!(notification.method.as_ref(), CONTEXT_CLEARED_METHOD);

        let parsed: ContextClearedParams =
            serde_json::from_str(notification.params.get()).expect("valid JSON");
        assert_eq!(parsed, params);
    }

    // `SubAgentProgressParams` has no `PartialEq`, so individual fields are checked.
    #[test]
    fn sub_agent_progress_params_roundtrip() {
        let params = SubAgentProgressParams {
            parent_tool_id: "call_123".to_string(),
            task_id: "task_abc".to_string(),
            agent_name: "explorer".to_string(),
            event: SubAgentEvent::Done,
        };

        let notification: ExtNotification = params.into();
        assert_eq!(notification.method.as_ref(), SUB_AGENT_PROGRESS_METHOD);

        let parsed: SubAgentProgressParams =
            serde_json::from_str(notification.params.get()).expect("valid JSON");
        assert!(matches!(parsed.event, SubAgentEvent::Done));
        assert_eq!(parsed.parent_tool_id, "call_123");
    }

    // The extra `model_name` key in the payload must be tolerated.
    #[test]
    fn deserialize_tool_call_event() {
        let json = r#"{"ToolCall":{"request":{"id":"c1","name":"grep","arguments":"{\"pattern\":\"test\"}"},"model_name":"m"}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        assert!(matches!(event, SubAgentEvent::ToolCall { .. }));
    }

    // The extra `model_name` key in the payload must be tolerated.
    #[test]
    fn deserialize_tool_call_update_event() {
        let json = r#"{"ToolCallUpdate":{"update":{"id":"c1","chunk":"{\"pattern\":\"test\"}"},"model_name":"m"}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        assert!(matches!(event, SubAgentEvent::ToolCallUpdate { .. }));
    }

    // A tool result carries its display metadata through deserialization.
    #[test]
    fn deserialize_tool_result_event() {
        let json = r#"{"ToolResult":{"result":{"id":"c1","name":"grep","result_meta":{"display":{"title":"Grep","value":"'test' in src (3 matches)"}}}}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        match event {
            SubAgentEvent::ToolResult { result } => {
                let result_meta = result.result_meta.expect("expected result_meta");
                assert_eq!(result_meta.display.title, "Grep");
            }
            other => panic!("Expected ToolResult, got {other:?}"),
        }
    }

    // A tool error payload deserializes into `SubAgentEvent::ToolError`.
    #[test]
    fn deserialize_tool_error_event() {
        let json = r#"{"ToolError":{"error":{"id":"c1","name":"grep"}}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        assert!(matches!(event, SubAgentEvent::ToolError { .. }));
    }

    // Unit variants deserialize from a bare JSON string.
    #[test]
    fn deserialize_done_event() {
        let event: SubAgentEvent = serde_json::from_str(r#""Done""#).unwrap();
        assert!(matches!(event, SubAgentEvent::Done));
    }

    // Unit variants deserialize from a bare JSON string.
    #[test]
    fn deserialize_other_variant() {
        let event: SubAgentEvent = serde_json::from_str(r#""Other""#).unwrap();
        assert!(matches!(event, SubAgentEvent::Other));
    }

    // `ToolResultMeta` converts to a map and back without loss.
    #[test]
    fn tool_result_meta_map_roundtrip() {
        let meta: ToolResultMeta =
            ToolDisplayMeta::new("Read file", "Cargo.toml, 156 lines").into();
        let map = meta.clone().into_map();
        let parsed = ToolResultMeta::from_map(&map).expect("should deserialize ToolResultMeta");
        assert_eq!(parsed, meta);
    }

    // `TryFrom<&ExtNotification>` recovers the original `McpRequest`.
    #[test]
    fn mcp_request_try_from_roundtrip() {
        let msg = McpRequest::Authenticate {
            session_id: "session-0".to_string(),
            server_name: "my oauth server".to_string(),
        };

        let notification: ExtNotification = msg.clone().into();
        let parsed = McpRequest::try_from(&notification).expect("should parse McpRequest");
        assert_eq!(parsed, msg);
    }

    // `TryFrom<&ExtNotification>` recovers the original `McpNotification`.
    #[test]
    fn mcp_notification_try_from_roundtrip() {
        let msg = McpNotification::ServerStatus {
            servers: vec![McpServerStatusEntry {
                name: "github".to_string(),
                status: McpServerStatus::Connected { tool_count: 5 },
            }],
        };

        let notification: ExtNotification = msg.clone().into();
        let parsed =
            McpNotification::try_from(&notification).expect("should parse McpNotification");
        assert_eq!(parsed, msg);
    }

    // `TryFrom<&ExtNotification>` recovers the original `AuthMethodsUpdatedParams`.
    #[test]
    fn auth_methods_updated_try_from_roundtrip() {
        let params = AuthMethodsUpdatedParams {
            auth_methods: vec![AuthMethod::Agent(
                AuthMethodAgent::new("anthropic", "Anthropic").description("authenticated"),
            )],
        };

        let notification: ExtNotification = params.clone().into();
        let parsed =
            AuthMethodsUpdatedParams::try_from(&notification).expect("should parse auth methods");
        assert_eq!(parsed, params);
    }

    // A notification sent on a different method is rejected before any JSON parsing.
    #[test]
    fn try_from_wrong_method_returns_error() {
        let notification = ext_notification(
            CONTEXT_USAGE_METHOD,
            &ContextUsageParams {
                usage_ratio: Some(0.5),
                tokens_used: 50000,
                context_limit: Some(100_000),
            },
        );

        let result = McpRequest::try_from(&notification);
        assert!(matches!(
            result,
            Err(ExtNotificationParseError::WrongMethod)
        ));
    }

    // The right method with params that are not an `McpRequest` yields `InvalidJson`.
    #[test]
    fn try_from_invalid_json_returns_error() {
        let notification = ext_notification(MCP_MESSAGE_METHOD, &"not a valid McpRequest");

        let result = McpRequest::try_from(&notification);
        assert!(matches!(
            result,
            Err(ExtNotificationParseError::InvalidJson(_))
        ));
    }

    // The `Display` output names the expected method / the JSON failure.
    #[test]
    fn ext_notification_parse_error_display() {
        let wrong = ExtNotificationParseError::WrongMethod;
        assert!(wrong.to_string().contains(MCP_MESSAGE_METHOD));

        let json_err = serde_json::from_str::<McpRequest>("{}").unwrap_err();
        let invalid = ExtNotificationParseError::InvalidJson(json_err);
        assert!(invalid.to_string().contains("invalid JSON"));
    }
}