1use agent_client_protocol::{AuthMethod, ExtNotification};
7pub use mcp_utils::display_meta::{ToolDisplayMeta, ToolResultMeta};
8pub use rmcp::model::CreateElicitationRequestParams;
9use serde::{Deserialize, Serialize, de::DeserializeOwned};
10use serde_json::value::to_raw_value;
11use std::fmt;
12use std::sync::Arc;
13
14pub use mcp_utils::status::{McpServerStatus, McpServerStatusEntry};
15
16pub const SUB_AGENT_PROGRESS_METHOD: &str = "_aether/sub_agent_progress";
19pub const CONTEXT_USAGE_METHOD: &str = "_aether/context_usage";
20pub const CONTEXT_CLEARED_METHOD: &str = "_aether/context_cleared";
21pub const MCP_MESSAGE_METHOD: &str = "_aether/mcp";
22pub const AUTH_METHODS_UPDATED_METHOD: &str = "_aether/auth_methods_updated";
23
24pub const ELICITATION_METHOD: &str = "aether/elicitation";
28
29#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
37pub struct ContextUsageParams {
38 pub usage_ratio: Option<f64>,
39 pub context_limit: Option<u32>,
40 pub input_tokens: u32,
41 #[serde(default)]
42 pub output_tokens: u32,
43 #[serde(default, skip_serializing_if = "Option::is_none")]
44 pub cache_read_tokens: Option<u32>,
45 #[serde(default, skip_serializing_if = "Option::is_none")]
46 pub cache_creation_tokens: Option<u32>,
47 #[serde(default, skip_serializing_if = "Option::is_none")]
48 pub reasoning_tokens: Option<u32>,
49 #[serde(default)]
50 pub total_input_tokens: u64,
51 #[serde(default)]
52 pub total_output_tokens: u64,
53 #[serde(default)]
54 pub total_cache_read_tokens: u64,
55 #[serde(default)]
56 pub total_cache_creation_tokens: u64,
57 #[serde(default)]
58 pub total_reasoning_tokens: u64,
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
63pub struct ContextClearedParams {}
64
65#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
67pub struct AuthMethodsUpdatedParams {
68 pub auth_methods: Vec<AuthMethod>,
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
77pub struct ElicitationParams {
78 pub server_name: String,
79 pub request: CreateElicitationRequestParams,
80}
81
82pub use rmcp::model::ElicitationAction;
83
84#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
86pub struct ElicitationResponse {
87 pub action: ElicitationAction,
88 pub content: Option<serde_json::Value>,
90}
91
92pub use mcp_utils::client::UrlElicitationCompleteParams;
93
94#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
96pub enum McpNotification {
97 ServerStatus { servers: Vec<McpServerStatusEntry> },
98 UrlElicitationComplete(UrlElicitationCompleteParams),
99}
100
101impl From<McpNotification> for ExtNotification {
102 fn from(msg: McpNotification) -> Self {
103 ext_notification(MCP_MESSAGE_METHOD, &msg)
104 }
105}
106
107#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
109pub enum McpRequest {
110 Authenticate { session_id: String, server_name: String },
111}
112
113impl From<McpRequest> for ExtNotification {
114 fn from(msg: McpRequest) -> Self {
115 ext_notification(MCP_MESSAGE_METHOD, &msg)
116 }
117}
118
119#[derive(Debug)]
121pub enum ExtNotificationParseError {
122 WrongMethod { expected: &'static str, actual: String },
123 InvalidJson { method: &'static str, source: serde_json::Error },
124}
125
126impl fmt::Display for ExtNotificationParseError {
127 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128 match self {
129 Self::WrongMethod { expected, actual } => {
130 write!(f, "notification method mismatch: expected {expected}, got {actual}")
131 }
132 Self::InvalidJson { method, source } => write!(f, "invalid JSON params for {method}: {source}"),
133 }
134 }
135}
136
137fn parse_ext_notification<T: DeserializeOwned>(
138 notification: &ExtNotification,
139 method: &'static str,
140) -> Result<T, ExtNotificationParseError> {
141 if notification.method.as_ref() != method {
142 return Err(ExtNotificationParseError::WrongMethod {
143 expected: method,
144 actual: notification.method.as_ref().to_string(),
145 });
146 }
147
148 serde_json::from_str(notification.params.get())
149 .map_err(|source| ExtNotificationParseError::InvalidJson { method, source })
150}
151
152impl TryFrom<&ExtNotification> for McpRequest {
153 type Error = ExtNotificationParseError;
154
155 fn try_from(n: &ExtNotification) -> Result<Self, Self::Error> {
156 parse_ext_notification(n, MCP_MESSAGE_METHOD)
157 }
158}
159
160impl TryFrom<&ExtNotification> for McpNotification {
161 type Error = ExtNotificationParseError;
162
163 fn try_from(n: &ExtNotification) -> Result<Self, Self::Error> {
164 parse_ext_notification(n, MCP_MESSAGE_METHOD)
165 }
166}
167
168impl TryFrom<&ExtNotification> for AuthMethodsUpdatedParams {
169 type Error = ExtNotificationParseError;
170
171 fn try_from(n: &ExtNotification) -> Result<Self, Self::Error> {
172 parse_ext_notification(n, AUTH_METHODS_UPDATED_METHOD)
173 }
174}
175
176fn ext_notification<T: Serialize>(method: &str, params: &T) -> ExtNotification {
177 let raw_value = to_raw_value(params).expect("notification params are serializable");
178 ExtNotification::new(method, Arc::from(raw_value))
179}
180
181impl From<ContextUsageParams> for ExtNotification {
182 fn from(params: ContextUsageParams) -> Self {
183 ext_notification(CONTEXT_USAGE_METHOD, ¶ms)
184 }
185}
186
187impl From<ContextClearedParams> for ExtNotification {
188 fn from(params: ContextClearedParams) -> Self {
189 ext_notification(CONTEXT_CLEARED_METHOD, ¶ms)
190 }
191}
192
193impl From<AuthMethodsUpdatedParams> for ExtNotification {
194 fn from(params: AuthMethodsUpdatedParams) -> Self {
195 ext_notification(AUTH_METHODS_UPDATED_METHOD, ¶ms)
196 }
197}
198
199#[derive(Debug, Clone, Serialize, Deserialize)]
203pub struct SubAgentProgressParams {
204 pub parent_tool_id: String,
205 pub task_id: String,
206 pub agent_name: String,
207 pub event: SubAgentEvent,
208}
209
210impl From<SubAgentProgressParams> for ExtNotification {
211 fn from(params: SubAgentProgressParams) -> Self {
212 ext_notification(SUB_AGENT_PROGRESS_METHOD, ¶ms)
213 }
214}
215
216#[derive(Debug, Clone, Serialize, Deserialize)]
221pub enum SubAgentEvent {
222 ToolCall { request: SubAgentToolRequest },
223 ToolCallUpdate { update: SubAgentToolCallUpdate },
224 ToolResult { result: SubAgentToolResult },
225 ToolError { error: SubAgentToolError },
226 Done,
227 Other,
228}
229
230#[derive(Debug, Clone, Serialize, Deserialize)]
231pub struct SubAgentToolRequest {
232 pub id: String,
233 pub name: String,
234 pub arguments: String,
235}
236
237#[derive(Debug, Clone, Serialize, Deserialize)]
238pub struct SubAgentToolCallUpdate {
239 pub id: String,
240 pub chunk: String,
241}
242
243#[derive(Debug, Clone, Serialize, Deserialize)]
244pub struct SubAgentToolResult {
245 pub id: String,
246 pub name: String,
247 pub result_meta: Option<ToolResultMeta>,
248}
249
250#[derive(Debug, Clone, Serialize, Deserialize)]
251pub struct SubAgentToolError {
252 pub id: String,
253 pub name: String,
254}
255
256#[cfg(test)]
257mod tests {
258 use agent_client_protocol::AuthMethodAgent;
259 use serde_json::from_str;
260
261 use super::*;
262
263 #[test]
264 fn method_constants_have_underscore_prefix() {
265 assert!(SUB_AGENT_PROGRESS_METHOD.starts_with('_'));
266 assert!(CONTEXT_USAGE_METHOD.starts_with('_'));
267 assert!(CONTEXT_CLEARED_METHOD.starts_with('_'));
268 assert!(MCP_MESSAGE_METHOD.starts_with('_'));
269 assert!(AUTH_METHODS_UPDATED_METHOD.starts_with('_'));
270 }
271
272 #[test]
273 fn mcp_request_authenticate_roundtrip() {
274 let msg = McpRequest::Authenticate {
275 session_id: "session-0".to_string(),
276 server_name: "my oauth server".to_string(),
277 };
278
279 let notification: ExtNotification = msg.clone().into();
280 assert_eq!(notification.method.as_ref(), MCP_MESSAGE_METHOD);
281
282 let parsed: McpRequest = serde_json::from_str(notification.params.get()).expect("valid JSON");
283 assert_eq!(parsed, msg);
284 }
285
286 #[test]
287 fn mcp_notification_server_status_roundtrip() {
288 let msg = McpNotification::ServerStatus {
289 servers: vec![
290 McpServerStatusEntry {
291 name: "github".to_string(),
292 status: McpServerStatus::Connected { tool_count: 5 },
293 },
294 McpServerStatusEntry { name: "linear".to_string(), status: McpServerStatus::NeedsOAuth },
295 McpServerStatusEntry {
296 name: "slack".to_string(),
297 status: McpServerStatus::Failed { error: "connection timeout".to_string() },
298 },
299 ],
300 };
301
302 let notification: ExtNotification = msg.clone().into();
303 assert_eq!(notification.method.as_ref(), MCP_MESSAGE_METHOD);
304
305 let parsed: McpNotification = serde_json::from_str(notification.params.get()).expect("valid JSON");
306 assert_eq!(parsed, msg);
307 }
308
309 #[test]
310 fn auth_methods_updated_params_roundtrip() {
311 let params = AuthMethodsUpdatedParams {
312 auth_methods: vec![
313 AuthMethod::Agent(AuthMethodAgent::new("anthropic", "Anthropic").description("authenticated")),
314 AuthMethod::Agent(AuthMethodAgent::new("openrouter", "OpenRouter")),
315 ],
316 };
317
318 let notification: ExtNotification = params.clone().into();
319 let parsed: AuthMethodsUpdatedParams = from_str(notification.params.get()).expect("valid JSON");
320
321 assert_eq!(parsed, params);
322 assert_eq!(notification.method.as_ref(), AUTH_METHODS_UPDATED_METHOD);
323 }
324
325 #[test]
326 fn mcp_server_status_entry_serde_roundtrip() {
327 let entry = McpServerStatusEntry {
328 name: "test-server".to_string(),
329 status: McpServerStatus::Connected { tool_count: 3 },
330 };
331
332 let json = serde_json::to_string(&entry).unwrap();
333 let parsed: McpServerStatusEntry = serde_json::from_str(&json).unwrap();
334 assert_eq!(parsed, entry);
335 }
336
337 #[test]
338 fn elicitation_params_roundtrip() {
339 use rmcp::model::{ElicitationSchema, EnumSchema};
340
341 let params = ElicitationParams {
342 server_name: "github".to_string(),
343 request: CreateElicitationRequestParams::FormElicitationParams {
344 meta: None,
345 message: "Pick a color".to_string(),
346 requested_schema: ElicitationSchema::builder()
347 .required_enum_schema(
348 "color",
349 EnumSchema::builder(vec!["red".into(), "green".into(), "blue".into()]).untitled().build(),
350 )
351 .build()
352 .unwrap(),
353 },
354 };
355
356 let json = serde_json::to_string(¶ms).unwrap();
357 let parsed: ElicitationParams = serde_json::from_str(&json).unwrap();
358 assert_eq!(parsed, params);
359 }
360
361 #[test]
362 fn elicitation_params_url_roundtrip() {
363 let params = ElicitationParams {
364 server_name: "github".to_string(),
365 request: CreateElicitationRequestParams::UrlElicitationParams {
366 meta: None,
367 message: "Authorize GitHub".to_string(),
368 url: "https://github.com/login/oauth".to_string(),
369 elicitation_id: "el-123".to_string(),
370 },
371 };
372
373 let json = serde_json::to_string(¶ms).unwrap();
374 assert!(json.contains("\"mode\":\"url\""));
375 assert!(json.contains("\"server_name\":\"github\""));
376 let parsed: ElicitationParams = serde_json::from_str(&json).unwrap();
377 assert_eq!(parsed, params);
378 }
379
380 #[test]
381 fn mcp_notification_url_elicitation_complete_roundtrip() {
382 let msg = McpNotification::UrlElicitationComplete(UrlElicitationCompleteParams {
383 server_name: "github".to_string(),
384 elicitation_id: "el-456".to_string(),
385 });
386
387 let notification: ExtNotification = msg.clone().into();
388 assert_eq!(notification.method.as_ref(), MCP_MESSAGE_METHOD);
389
390 let parsed: McpNotification = serde_json::from_str(notification.params.get()).expect("valid JSON");
391 assert_eq!(parsed, msg);
392 }
393
394 #[test]
395 fn context_usage_params_roundtrip() {
396 let params = ContextUsageParams {
397 usage_ratio: Some(0.75),
398 context_limit: Some(100_000),
399 input_tokens: 75_000,
400 output_tokens: 1_200,
401 cache_read_tokens: Some(40_000),
402 cache_creation_tokens: Some(2_000),
403 reasoning_tokens: Some(500),
404 total_input_tokens: 200_000,
405 total_output_tokens: 8_000,
406 total_cache_read_tokens: 90_000,
407 total_cache_creation_tokens: 5_000,
408 total_reasoning_tokens: 1_500,
409 };
410
411 let notification: ExtNotification = params.clone().into();
412 assert_eq!(notification.method.as_ref(), CONTEXT_USAGE_METHOD);
413
414 let parsed: ContextUsageParams = serde_json::from_str(notification.params.get()).expect("valid JSON");
415 assert_eq!(parsed, params);
416 }
417
418 #[test]
419 fn context_usage_params_omits_unset_optional_token_fields() {
420 let params = ContextUsageParams {
421 usage_ratio: Some(0.1),
422 context_limit: Some(1_000),
423 input_tokens: 100,
424 output_tokens: 0,
425 cache_read_tokens: None,
426 cache_creation_tokens: None,
427 reasoning_tokens: None,
428 total_input_tokens: 0,
429 total_output_tokens: 0,
430 total_cache_read_tokens: 0,
431 total_cache_creation_tokens: 0,
432 total_reasoning_tokens: 0,
433 };
434
435 let notification: ExtNotification = params.clone().into();
436 let raw = notification.params.get();
437 assert!(!raw.contains("\"cache_read_tokens\""));
438 assert!(!raw.contains("\"cache_creation_tokens\""));
439 assert!(!raw.contains("\"reasoning_tokens\""));
440 }
441
442 #[test]
443 fn context_cleared_params_roundtrip() {
444 let params = ContextClearedParams::default();
445
446 let notification: ExtNotification = params.clone().into();
447 assert_eq!(notification.method.as_ref(), CONTEXT_CLEARED_METHOD);
448
449 let parsed: ContextClearedParams = serde_json::from_str(notification.params.get()).expect("valid JSON");
450 assert_eq!(parsed, params);
451 }
452
453 #[test]
454 fn sub_agent_progress_params_roundtrip() {
455 let params = SubAgentProgressParams {
456 parent_tool_id: "call_123".to_string(),
457 task_id: "task_abc".to_string(),
458 agent_name: "explorer".to_string(),
459 event: SubAgentEvent::Done,
460 };
461
462 let notification: ExtNotification = params.into();
463 assert_eq!(notification.method.as_ref(), SUB_AGENT_PROGRESS_METHOD);
464
465 let parsed: SubAgentProgressParams = serde_json::from_str(notification.params.get()).expect("valid JSON");
466 assert!(matches!(parsed.event, SubAgentEvent::Done));
467 assert_eq!(parsed.parent_tool_id, "call_123");
468 }
469
470 #[test]
471 fn deserialize_tool_call_event() {
472 let json = r#"{"ToolCall":{"request":{"id":"c1","name":"grep","arguments":"{\"pattern\":\"test\"}"},"model_name":"m"}}"#;
473 let event: SubAgentEvent = serde_json::from_str(json).unwrap();
474 assert!(matches!(event, SubAgentEvent::ToolCall { .. }));
475 }
476
477 #[test]
478 fn deserialize_tool_call_update_event() {
479 let json = r#"{"ToolCallUpdate":{"update":{"id":"c1","chunk":"{\"pattern\":\"test\"}"},"model_name":"m"}}"#;
482 let event: SubAgentEvent = serde_json::from_str(json).unwrap();
483 assert!(matches!(event, SubAgentEvent::ToolCallUpdate { .. }));
484 }
485
486 #[test]
487 fn deserialize_tool_result_event() {
488 let json = r#"{"ToolResult":{"result":{"id":"c1","name":"grep","result_meta":{"display":{"title":"Grep","value":"'test' in src (3 matches)"}}}}}"#;
489 let event: SubAgentEvent = serde_json::from_str(json).unwrap();
490 match event {
491 SubAgentEvent::ToolResult { result } => {
492 let result_meta = result.result_meta.expect("expected result_meta");
493 assert_eq!(result_meta.display.title, "Grep");
494 }
495 other => panic!("Expected ToolResult, got {other:?}"),
496 }
497 }
498
499 #[test]
500 fn deserialize_tool_error_event() {
501 let json = r#"{"ToolError":{"error":{"id":"c1","name":"grep"}}}"#;
502 let event: SubAgentEvent = serde_json::from_str(json).unwrap();
503 assert!(matches!(event, SubAgentEvent::ToolError { .. }));
504 }
505
506 #[test]
507 fn deserialize_done_event() {
508 let event: SubAgentEvent = serde_json::from_str(r#""Done""#).unwrap();
509 assert!(matches!(event, SubAgentEvent::Done));
510 }
511
512 #[test]
513 fn deserialize_other_variant() {
514 let event: SubAgentEvent = serde_json::from_str(r#""Other""#).unwrap();
515 assert!(matches!(event, SubAgentEvent::Other));
516 }
517
518 #[test]
519 fn tool_result_meta_map_roundtrip() {
520 let meta: ToolResultMeta = ToolDisplayMeta::new("Read file", "Cargo.toml, 156 lines").into();
521 let map = meta.clone().into_map();
522 let parsed = ToolResultMeta::from_map(&map).expect("should deserialize ToolResultMeta");
523 assert_eq!(parsed, meta);
524 }
525
526 #[test]
527 fn mcp_request_try_from_roundtrip() {
528 let msg = McpRequest::Authenticate {
529 session_id: "session-0".to_string(),
530 server_name: "my oauth server".to_string(),
531 };
532
533 let notification: ExtNotification = msg.clone().into();
534 let parsed = McpRequest::try_from(¬ification).expect("should parse McpRequest");
535 assert_eq!(parsed, msg);
536 }
537
538 #[test]
539 fn mcp_notification_try_from_roundtrip() {
540 let msg = McpNotification::ServerStatus {
541 servers: vec![McpServerStatusEntry {
542 name: "github".to_string(),
543 status: McpServerStatus::Connected { tool_count: 5 },
544 }],
545 };
546
547 let notification: ExtNotification = msg.clone().into();
548 let parsed = McpNotification::try_from(¬ification).expect("should parse McpNotification");
549 assert_eq!(parsed, msg);
550 }
551
552 #[test]
553 fn auth_methods_updated_try_from_roundtrip() {
554 let params = AuthMethodsUpdatedParams {
555 auth_methods: vec![AuthMethod::Agent(
556 AuthMethodAgent::new("anthropic", "Anthropic").description("authenticated"),
557 )],
558 };
559
560 let notification: ExtNotification = params.clone().into();
561 let parsed = AuthMethodsUpdatedParams::try_from(¬ification).expect("should parse auth methods");
562 assert_eq!(parsed, params);
563 }
564
565 #[test]
566 fn try_from_wrong_method_returns_error() {
567 let notification = ext_notification(
568 CONTEXT_USAGE_METHOD,
569 &ContextUsageParams {
570 usage_ratio: Some(0.5),
571 context_limit: Some(100_000),
572 input_tokens: 50_000,
573 output_tokens: 0,
574 cache_read_tokens: None,
575 cache_creation_tokens: None,
576 reasoning_tokens: None,
577 total_input_tokens: 0,
578 total_output_tokens: 0,
579 total_cache_read_tokens: 0,
580 total_cache_creation_tokens: 0,
581 total_reasoning_tokens: 0,
582 },
583 );
584
585 let result = McpRequest::try_from(¬ification);
586 assert!(matches!(
587 result,
588 Err(ExtNotificationParseError::WrongMethod { expected, actual })
589 if expected == MCP_MESSAGE_METHOD && actual == CONTEXT_USAGE_METHOD
590 ));
591 }
592
593 #[test]
594 fn try_from_invalid_json_returns_error() {
595 let notification = ext_notification(MCP_MESSAGE_METHOD, &"not a valid McpRequest");
596
597 let result = McpRequest::try_from(¬ification);
598 assert!(matches!(
599 result,
600 Err(ExtNotificationParseError::InvalidJson { method, .. }) if method == MCP_MESSAGE_METHOD
601 ));
602 }
603
604 #[test]
605 fn ext_notification_parse_error_display() {
606 let wrong = ExtNotificationParseError::WrongMethod {
607 expected: MCP_MESSAGE_METHOD,
608 actual: CONTEXT_USAGE_METHOD.to_string(),
609 };
610 assert!(wrong.to_string().contains(MCP_MESSAGE_METHOD));
611 assert!(wrong.to_string().contains(CONTEXT_USAGE_METHOD));
612
613 let json_err = serde_json::from_str::<McpRequest>("{}").unwrap_err();
614 let invalid = ExtNotificationParseError::InvalidJson { method: MCP_MESSAGE_METHOD, source: json_err };
615 assert!(invalid.to_string().contains("invalid JSON"));
616 assert!(invalid.to_string().contains(MCP_MESSAGE_METHOD));
617 }
618}