1use agent_client_protocol::{AuthMethod, ExtNotification};
7pub use mcp_utils::display_meta::{ToolDisplayMeta, ToolResultMeta};
8use rmcp::model::ElicitationSchema;
9use serde::{Deserialize, Serialize};
10use serde_json::value::to_raw_value;
11use std::fmt;
12use std::sync::Arc;
13
14pub use mcp_utils::status::{McpServerStatus, McpServerStatusEntry};
15
/// Method name used when a [`SubAgentProgressParams`] is converted into an `ExtNotification`.
pub const SUB_AGENT_PROGRESS_METHOD: &str = "_aether/sub_agent_progress";
/// Method name used when a [`ContextUsageParams`] is converted into an `ExtNotification`.
pub const CONTEXT_USAGE_METHOD: &str = "_aether/context_usage";
/// Method name used when a [`ContextClearedParams`] is converted into an `ExtNotification`.
pub const CONTEXT_CLEARED_METHOD: &str = "_aether/context_cleared";
/// Method name shared by both [`McpNotification`] and [`McpRequest`] conversions.
pub const MCP_MESSAGE_METHOD: &str = "_aether/mcp";
/// Method name used when an [`AuthMethodsUpdatedParams`] is converted into an `ExtNotification`.
pub const AUTH_METHODS_UPDATED_METHOD: &str = "_aether/auth_methods_updated";

/// Method name for elicitation.
///
/// NOTE(review): unlike the constants above this one has no leading underscore and no
/// conversion in this file uses it — presumably it names a request rather than a
/// notification; confirm against the caller.
pub const ELICITATION_METHOD: &str = "aether/elicitation";
28
/// Payload that is serialized as the params of a [`CONTEXT_USAGE_METHOD`] notification
/// (see `impl From<ContextUsageParams> for ExtNotification`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ContextUsageParams {
    /// Optional usage ratio.
    ///
    /// NOTE(review): presumably `tokens_used / context_limit`; this file never computes
    /// it — confirm with the producer.
    pub usage_ratio: Option<f64>,
    /// Number of tokens reported as used.
    pub tokens_used: u32,
    /// Optional context limit (presumably in tokens — confirm with the producer).
    pub context_limit: Option<u32>,
}
36
/// Field-less payload that is serialized as the params of a [`CONTEXT_CLEARED_METHOD`]
/// notification (see `impl From<ContextClearedParams> for ExtNotification`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
pub struct ContextClearedParams {}
40
/// Payload of an [`AUTH_METHODS_UPDATED_METHOD`] notification.
///
/// Converts into an `ExtNotification` via `From`, and back via `TryFrom<&ExtNotification>`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct AuthMethodsUpdatedParams {
    /// The authentication methods carried by the notification.
    pub auth_methods: Vec<AuthMethod>,
}
46
/// A message paired with an `rmcp` elicitation schema.
///
/// NOTE(review): presumably the params sent with [`ELICITATION_METHOD`]; nothing in this
/// file ties the two together — confirm against the caller.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ElicitationParams {
    /// Text accompanying the schema (presumably shown to the user — confirm).
    pub message: String,
    /// Schema describing the requested input.
    pub schema: ElicitationSchema,
}
53
54pub use rmcp::model::ElicitationAction;
55
/// An elicitation action together with optional JSON content.
///
/// NOTE(review): presumably the reply to an [`ElicitationParams`] request; this file
/// does not show where it is produced or consumed — confirm.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ElicitationResponse {
    /// The action taken, using the `rmcp` action type re-exported by this module.
    pub action: ElicitationAction,
    /// Optional arbitrary JSON content accompanying the action.
    pub content: Option<serde_json::Value>,
}
63
/// MCP-related message carried under [`MCP_MESSAGE_METHOD`].
///
/// Converts into an `ExtNotification` via `From`, and back via `TryFrom<&ExtNotification>`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum McpNotification {
    /// A list of per-server status entries.
    ServerStatus { servers: Vec<McpServerStatusEntry> },
}
69
70impl From<McpNotification> for ExtNotification {
71 fn from(msg: McpNotification) -> Self {
72 ext_notification(MCP_MESSAGE_METHOD, &msg)
73 }
74}
75
/// MCP-related request carried under [`MCP_MESSAGE_METHOD`].
///
/// Converts into an `ExtNotification` via `From`, and back via `TryFrom<&ExtNotification>`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum McpRequest {
    /// Identifies a session and a server by name.
    ///
    /// NOTE(review): presumably asks for authentication against that server to start;
    /// the handling code is not in this file — confirm.
    Authenticate { session_id: String, server_name: String },
}
81
82impl From<McpRequest> for ExtNotification {
83 fn from(msg: McpRequest) -> Self {
84 ext_notification(MCP_MESSAGE_METHOD, &msg)
85 }
86}
87
88#[derive(Debug)]
90pub enum ExtNotificationParseError {
91 WrongMethod,
92 InvalidJson(serde_json::Error),
93}
94
95impl fmt::Display for ExtNotificationParseError {
96 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
97 match self {
98 Self::WrongMethod => write!(f, "notification method is not {MCP_MESSAGE_METHOD}"),
99 Self::InvalidJson(e) => write!(f, "invalid JSON params: {e}"),
100 }
101 }
102}
103
104impl TryFrom<&ExtNotification> for McpRequest {
105 type Error = ExtNotificationParseError;
106
107 fn try_from(n: &ExtNotification) -> Result<Self, Self::Error> {
108 if n.method.as_ref() != MCP_MESSAGE_METHOD {
109 return Err(ExtNotificationParseError::WrongMethod);
110 }
111 serde_json::from_str(n.params.get()).map_err(ExtNotificationParseError::InvalidJson)
112 }
113}
114
115impl TryFrom<&ExtNotification> for McpNotification {
116 type Error = ExtNotificationParseError;
117
118 fn try_from(n: &ExtNotification) -> Result<Self, Self::Error> {
119 if n.method.as_ref() != MCP_MESSAGE_METHOD {
120 return Err(ExtNotificationParseError::WrongMethod);
121 }
122 serde_json::from_str(n.params.get()).map_err(ExtNotificationParseError::InvalidJson)
123 }
124}
125
126impl TryFrom<&ExtNotification> for AuthMethodsUpdatedParams {
127 type Error = ExtNotificationParseError;
128
129 fn try_from(n: &ExtNotification) -> Result<Self, Self::Error> {
130 if n.method.as_ref() != AUTH_METHODS_UPDATED_METHOD {
131 return Err(ExtNotificationParseError::WrongMethod);
132 }
133 serde_json::from_str(n.params.get()).map_err(ExtNotificationParseError::InvalidJson)
134 }
135}
136
137fn ext_notification<T: Serialize>(method: &str, params: &T) -> ExtNotification {
138 let raw_value = to_raw_value(params).expect("notification params are serializable");
139 ExtNotification::new(method, Arc::from(raw_value))
140}
141
142impl From<ContextUsageParams> for ExtNotification {
143 fn from(params: ContextUsageParams) -> Self {
144 ext_notification(CONTEXT_USAGE_METHOD, ¶ms)
145 }
146}
147
148impl From<ContextClearedParams> for ExtNotification {
149 fn from(params: ContextClearedParams) -> Self {
150 ext_notification(CONTEXT_CLEARED_METHOD, ¶ms)
151 }
152}
153
154impl From<AuthMethodsUpdatedParams> for ExtNotification {
155 fn from(params: AuthMethodsUpdatedParams) -> Self {
156 ext_notification(AUTH_METHODS_UPDATED_METHOD, ¶ms)
157 }
158}
159
/// Payload that is serialized as the params of a [`SUB_AGENT_PROGRESS_METHOD`]
/// notification (see `impl From<SubAgentProgressParams> for ExtNotification`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubAgentProgressParams {
    /// Identifier of the parent tool (presumably the tool call that spawned the
    /// sub-agent — confirm with the producer).
    pub parent_tool_id: String,
    /// Identifier of the task.
    pub task_id: String,
    /// Name of the agent the event belongs to.
    pub agent_name: String,
    /// The event being reported.
    pub event: SubAgentEvent,
}
170
171impl From<SubAgentProgressParams> for ExtNotification {
172 fn from(params: SubAgentProgressParams) -> Self {
173 ext_notification(SUB_AGENT_PROGRESS_METHOD, ¶ms)
174 }
175}
176
/// Event carried by [`SubAgentProgressParams`].
///
/// With the derived serde impls, data variants are encoded as a single-key JSON object
/// (e.g. `{"ToolError":{"error":{...}}}`) and unit variants as a bare string
/// (e.g. `"Done"`); the tests in this file also show that unknown keys inside a variant
/// are ignored on deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SubAgentEvent {
    /// A tool call request.
    ToolCall { request: SubAgentToolRequest },
    /// An update chunk for a tool call.
    ToolCallUpdate { update: SubAgentToolCallUpdate },
    /// A tool result.
    ToolResult { result: SubAgentToolResult },
    /// A tool error.
    ToolError { error: SubAgentToolError },
    /// Unit variant; presumably signals that the sub-agent finished — confirm.
    Done,
    /// Unit variant; presumably a placeholder for events with no dedicated variant —
    /// confirm with the producer.
    Other,
}
190
/// Tool call request carried by [`SubAgentEvent::ToolCall`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubAgentToolRequest {
    /// Identifier of the tool call.
    pub id: String,
    /// Name of the tool.
    pub name: String,
    /// Tool arguments as a string; the tests in this file pass a JSON-encoded object.
    pub arguments: String,
}
197
/// Update carried by [`SubAgentEvent::ToolCallUpdate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubAgentToolCallUpdate {
    /// Identifier of the tool call being updated.
    pub id: String,
    /// The update chunk (presumably an incremental piece of the call's arguments —
    /// confirm with the producer).
    pub chunk: String,
}
203
/// Tool result carried by [`SubAgentEvent::ToolResult`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubAgentToolResult {
    /// Identifier of the tool call.
    pub id: String,
    /// Name of the tool.
    pub name: String,
    /// Optional result metadata, using the type re-exported from `mcp_utils::display_meta`.
    pub result_meta: Option<ToolResultMeta>,
}
210
/// Tool error carried by [`SubAgentEvent::ToolError`]; holds only the call id and tool
/// name, with no error message field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubAgentToolError {
    /// Identifier of the tool call.
    pub id: String,
    /// Name of the tool.
    pub name: String,
}
216
/// Serialization round-trip and conversion tests for the payload types in this module.
#[cfg(test)]
mod tests {
    use agent_client_protocol::AuthMethodAgent;
    use serde_json::from_str;

    use super::*;

    /// All notification method names checked here start with `_`.
    #[test]
    fn method_constants_have_underscore_prefix() {
        assert!(SUB_AGENT_PROGRESS_METHOD.starts_with('_'));
        assert!(CONTEXT_USAGE_METHOD.starts_with('_'));
        assert!(CONTEXT_CLEARED_METHOD.starts_with('_'));
        assert!(MCP_MESSAGE_METHOD.starts_with('_'));
        assert!(AUTH_METHODS_UPDATED_METHOD.starts_with('_'));
    }

    #[test]
    fn mcp_request_authenticate_roundtrip() {
        let msg = McpRequest::Authenticate {
            session_id: "session-0".to_string(),
            server_name: "my oauth server".to_string(),
        };

        let notification: ExtNotification = msg.clone().into();
        assert_eq!(notification.method.as_ref(), MCP_MESSAGE_METHOD);

        let parsed: McpRequest = serde_json::from_str(notification.params.get()).expect("valid JSON");
        assert_eq!(parsed, msg);
    }

    #[test]
    fn mcp_notification_server_status_roundtrip() {
        let msg = McpNotification::ServerStatus {
            servers: vec![
                McpServerStatusEntry {
                    name: "github".to_string(),
                    status: McpServerStatus::Connected { tool_count: 5 },
                },
                McpServerStatusEntry { name: "linear".to_string(), status: McpServerStatus::NeedsOAuth },
                McpServerStatusEntry {
                    name: "slack".to_string(),
                    status: McpServerStatus::Failed { error: "connection timeout".to_string() },
                },
            ],
        };

        let notification: ExtNotification = msg.clone().into();
        assert_eq!(notification.method.as_ref(), MCP_MESSAGE_METHOD);

        let parsed: McpNotification = serde_json::from_str(notification.params.get()).expect("valid JSON");
        assert_eq!(parsed, msg);
    }

    #[test]
    fn auth_methods_updated_params_roundtrip() {
        let params = AuthMethodsUpdatedParams {
            auth_methods: vec![
                AuthMethod::Agent(AuthMethodAgent::new("anthropic", "Anthropic").description("authenticated")),
                AuthMethod::Agent(AuthMethodAgent::new("openrouter", "OpenRouter")),
            ],
        };

        let notification: ExtNotification = params.clone().into();
        let parsed: AuthMethodsUpdatedParams = from_str(notification.params.get()).expect("valid JSON");

        assert_eq!(parsed, params);
        assert_eq!(notification.method.as_ref(), AUTH_METHODS_UPDATED_METHOD);
    }

    #[test]
    fn mcp_server_status_entry_serde_roundtrip() {
        let entry = McpServerStatusEntry {
            name: "test-server".to_string(),
            status: McpServerStatus::Connected { tool_count: 3 },
        };

        let json = serde_json::to_string(&entry).unwrap();
        let parsed: McpServerStatusEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, entry);
    }

    #[test]
    fn elicitation_params_roundtrip() {
        use rmcp::model::EnumSchema;

        let params = ElicitationParams {
            message: "Pick a color".to_string(),
            schema: ElicitationSchema::builder()
                .required_enum_schema(
                    "color",
                    EnumSchema::builder(vec!["red".into(), "green".into(), "blue".into()]).untitled().build(),
                )
                .build()
                .unwrap(),
        };

        let json = serde_json::to_string(&params).unwrap();
        let parsed: ElicitationParams = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, params);
    }

    #[test]
    fn context_usage_params_roundtrip() {
        let params = ContextUsageParams { usage_ratio: Some(0.75), tokens_used: 75000, context_limit: Some(100_000) };

        let notification: ExtNotification = params.clone().into();
        assert_eq!(notification.method.as_ref(), CONTEXT_USAGE_METHOD);

        let parsed: ContextUsageParams = serde_json::from_str(notification.params.get()).expect("valid JSON");
        assert_eq!(parsed, params);
    }

    #[test]
    fn context_cleared_params_roundtrip() {
        let params = ContextClearedParams::default();

        let notification: ExtNotification = params.clone().into();
        assert_eq!(notification.method.as_ref(), CONTEXT_CLEARED_METHOD);

        let parsed: ContextClearedParams = serde_json::from_str(notification.params.get()).expect("valid JSON");
        assert_eq!(parsed, params);
    }

    #[test]
    fn sub_agent_progress_params_roundtrip() {
        let params = SubAgentProgressParams {
            parent_tool_id: "call_123".to_string(),
            task_id: "task_abc".to_string(),
            agent_name: "explorer".to_string(),
            event: SubAgentEvent::Done,
        };

        let notification: ExtNotification = params.into();
        assert_eq!(notification.method.as_ref(), SUB_AGENT_PROGRESS_METHOD);

        // The sub-agent types do not derive `PartialEq`, so fields are checked individually.
        let parsed: SubAgentProgressParams = serde_json::from_str(notification.params.get()).expect("valid JSON");
        assert!(matches!(parsed.event, SubAgentEvent::Done));
        assert_eq!(parsed.parent_tool_id, "call_123");
    }

    /// The extra `model_name` key must be ignored during deserialization.
    #[test]
    fn deserialize_tool_call_event() {
        let json = r#"{"ToolCall":{"request":{"id":"c1","name":"grep","arguments":"{\"pattern\":\"test\"}"},"model_name":"m"}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        assert!(matches!(event, SubAgentEvent::ToolCall { .. }));
    }

    /// The extra `model_name` key must be ignored during deserialization.
    #[test]
    fn deserialize_tool_call_update_event() {
        let json = r#"{"ToolCallUpdate":{"update":{"id":"c1","chunk":"{\"pattern\":\"test\"}"},"model_name":"m"}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        assert!(matches!(event, SubAgentEvent::ToolCallUpdate { .. }));
    }

    #[test]
    fn deserialize_tool_result_event() {
        let json = r#"{"ToolResult":{"result":{"id":"c1","name":"grep","result_meta":{"display":{"title":"Grep","value":"'test' in src (3 matches)"}}}}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        match event {
            SubAgentEvent::ToolResult { result } => {
                let result_meta = result.result_meta.expect("expected result_meta");
                assert_eq!(result_meta.display.title, "Grep");
            }
            other => panic!("Expected ToolResult, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_tool_error_event() {
        let json = r#"{"ToolError":{"error":{"id":"c1","name":"grep"}}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        assert!(matches!(event, SubAgentEvent::ToolError { .. }));
    }

    #[test]
    fn deserialize_done_event() {
        let event: SubAgentEvent = serde_json::from_str(r#""Done""#).unwrap();
        assert!(matches!(event, SubAgentEvent::Done));
    }

    #[test]
    fn deserialize_other_variant() {
        let event: SubAgentEvent = serde_json::from_str(r#""Other""#).unwrap();
        assert!(matches!(event, SubAgentEvent::Other));
    }

    #[test]
    fn tool_result_meta_map_roundtrip() {
        let meta: ToolResultMeta = ToolDisplayMeta::new("Read file", "Cargo.toml, 156 lines").into();
        let map = meta.clone().into_map();
        let parsed = ToolResultMeta::from_map(&map).expect("should deserialize ToolResultMeta");
        assert_eq!(parsed, meta);
    }

    #[test]
    fn mcp_request_try_from_roundtrip() {
        let msg = McpRequest::Authenticate {
            session_id: "session-0".to_string(),
            server_name: "my oauth server".to_string(),
        };

        let notification: ExtNotification = msg.clone().into();
        let parsed = McpRequest::try_from(&notification).expect("should parse McpRequest");
        assert_eq!(parsed, msg);
    }

    #[test]
    fn mcp_notification_try_from_roundtrip() {
        let msg = McpNotification::ServerStatus {
            servers: vec![McpServerStatusEntry {
                name: "github".to_string(),
                status: McpServerStatus::Connected { tool_count: 5 },
            }],
        };

        let notification: ExtNotification = msg.clone().into();
        let parsed = McpNotification::try_from(&notification).expect("should parse McpNotification");
        assert_eq!(parsed, msg);
    }

    #[test]
    fn auth_methods_updated_try_from_roundtrip() {
        let params = AuthMethodsUpdatedParams {
            auth_methods: vec![AuthMethod::Agent(
                AuthMethodAgent::new("anthropic", "Anthropic").description("authenticated"),
            )],
        };

        let notification: ExtNotification = params.clone().into();
        let parsed = AuthMethodsUpdatedParams::try_from(&notification).expect("should parse auth methods");
        assert_eq!(parsed, params);
    }

    /// A notification sent under a different method must be rejected before any parsing.
    #[test]
    fn try_from_wrong_method_returns_error() {
        let notification = ext_notification(
            CONTEXT_USAGE_METHOD,
            &ContextUsageParams { usage_ratio: Some(0.5), tokens_used: 50000, context_limit: Some(100_000) },
        );

        let result = McpRequest::try_from(&notification);
        assert!(matches!(result, Err(ExtNotificationParseError::WrongMethod)));
    }

    /// The method matches, but a bare JSON string is not a valid `McpRequest`.
    #[test]
    fn try_from_invalid_json_returns_error() {
        let notification = ext_notification(MCP_MESSAGE_METHOD, &"not a valid McpRequest");

        let result = McpRequest::try_from(&notification);
        assert!(matches!(result, Err(ExtNotificationParseError::InvalidJson(_))));
    }

    #[test]
    fn ext_notification_parse_error_display() {
        let wrong = ExtNotificationParseError::WrongMethod;
        assert!(wrong.to_string().contains(MCP_MESSAGE_METHOD));

        let json_err = serde_json::from_str::<McpRequest>("{}").unwrap_err();
        let invalid = ExtNotificationParseError::InvalidJson(json_err);
        assert!(invalid.to_string().contains("invalid JSON"));
    }
}