Skip to main content

acp_utils/
notifications.rs

1//! Shared wire-format types for Aether's custom ACP extension notifications.
2//!
3//! These types are used on both the agent (server) and client (UI) sides of the
4//! ACP connection.
5
6use agent_client_protocol::{AuthMethod, ExtNotification};
7pub use mcp_utils::display_meta::{ToolDisplayMeta, ToolResultMeta};
8pub use rmcp::model::CreateElicitationRequestParams;
9use serde::{Deserialize, Serialize, de::DeserializeOwned};
10use serde_json::value::to_raw_value;
11use std::fmt;
12use std::sync::Arc;
13
14pub use mcp_utils::status::{McpServerStatus, McpServerStatusEntry};
15
/// Method name for `_aether/sub_agent_progress` notifications, which relay
/// sub-agent progress updates (payload: [`SubAgentProgressParams`]).
///
/// Per the ACP extensibility spec, custom notification methods must start with
/// an underscore; every notification `*_METHOD` constant below follows that rule.
pub const SUB_AGENT_PROGRESS_METHOD: &str = "_aether/sub_agent_progress";

/// Method name for `_aether/context_usage` notifications (payload: [`ContextUsageParams`]).
pub const CONTEXT_USAGE_METHOD: &str = "_aether/context_usage";

/// Method name for `_aether/context_cleared` notifications (payload: [`ContextClearedParams`]).
pub const CONTEXT_CLEARED_METHOD: &str = "_aether/context_cleared";

/// Method name shared by all MCP extension messages in both directions:
/// [`McpNotification`] (server→client) and [`McpRequest`] (client→server).
pub const MCP_MESSAGE_METHOD: &str = "_aether/mcp";

/// Method name for `_aether/auth_methods_updated` notifications
/// (payload: [`AuthMethodsUpdatedParams`]).
pub const AUTH_METHODS_UPDATED_METHOD: &str = "_aether/auth_methods_updated";

/// Custom `ext_method` for tunneling MCP elicitation through ACP.
///
/// Unlike the notification constants above, this name deliberately has no
/// leading underscore: ACP auto-prefixes `ext_method` names with `_`, so the
/// wire method becomes `_aether/elicitation`.
pub const ELICITATION_METHOD: &str = "aether/elicitation";
28
29/// Parameters for `_aether/context_usage` notifications.
30///
31/// Per-turn fields (`input_tokens`, `output_tokens`, `cache_read_tokens`,
32/// `cache_creation_tokens`, `reasoning_tokens`) come from the most recent
33/// API response. The `total_*` fields are cumulative across the agent's
34/// lifetime. The optional fields are `None` when the provider doesn't
35/// expose that dimension; this is semantically distinct from `Some(0)`.
36#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
37pub struct ContextUsageParams {
38    pub usage_ratio: Option<f64>,
39    pub context_limit: Option<u32>,
40    pub input_tokens: u32,
41    #[serde(default)]
42    pub output_tokens: u32,
43    #[serde(default, skip_serializing_if = "Option::is_none")]
44    pub cache_read_tokens: Option<u32>,
45    #[serde(default, skip_serializing_if = "Option::is_none")]
46    pub cache_creation_tokens: Option<u32>,
47    #[serde(default, skip_serializing_if = "Option::is_none")]
48    pub reasoning_tokens: Option<u32>,
49    #[serde(default)]
50    pub total_input_tokens: u64,
51    #[serde(default)]
52    pub total_output_tokens: u64,
53    #[serde(default)]
54    pub total_cache_read_tokens: u64,
55    #[serde(default)]
56    pub total_cache_creation_tokens: u64,
57    #[serde(default)]
58    pub total_reasoning_tokens: u64,
59}
60
61/// Parameters for `_aether/context_cleared` notifications.
62#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
63pub struct ContextClearedParams {}
64
65/// Parameters for `_aether/auth_methods_updated` notifications.
66#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
67pub struct AuthMethodsUpdatedParams {
68    pub auth_methods: Vec<AuthMethod>,
69}
70
71/// Parameters sent via `ext_method` for `aether/elicitation`.
72///
73/// Carries the full RMCP elicitation request plus the originating server name
74/// so the client can distinguish form vs URL mode and display which
75/// server is requesting.
76#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
77pub struct ElicitationParams {
78    pub server_name: String,
79    pub request: CreateElicitationRequestParams,
80}
81
82pub use rmcp::model::ElicitationAction;
83
84/// Response returned from the client for an elicitation request.
85#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
86pub struct ElicitationResponse {
87    pub action: ElicitationAction,
88    /// Structured form data when action is "accept".
89    pub content: Option<serde_json::Value>,
90}
91
92pub use mcp_utils::client::UrlElicitationCompleteParams;
93
94/// Server→client MCP extension notifications (relay → wisp).
95#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
96pub enum McpNotification {
97    ServerStatus { servers: Vec<McpServerStatusEntry> },
98    UrlElicitationComplete(UrlElicitationCompleteParams),
99}
100
101impl From<McpNotification> for ExtNotification {
102    fn from(msg: McpNotification) -> Self {
103        ext_notification(MCP_MESSAGE_METHOD, &msg)
104    }
105}
106
107/// Client→server MCP extension requests (wisp → relay).
108#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
109pub enum McpRequest {
110    Authenticate { session_id: String, server_name: String },
111}
112
113impl From<McpRequest> for ExtNotification {
114    fn from(msg: McpRequest) -> Self {
115        ext_notification(MCP_MESSAGE_METHOD, &msg)
116    }
117}
118
119/// Error returned when converting an `ExtNotification` into a typed MCP message.
120#[derive(Debug)]
121pub enum ExtNotificationParseError {
122    WrongMethod { expected: &'static str, actual: String },
123    InvalidJson { method: &'static str, source: serde_json::Error },
124}
125
126impl fmt::Display for ExtNotificationParseError {
127    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128        match self {
129            Self::WrongMethod { expected, actual } => {
130                write!(f, "notification method mismatch: expected {expected}, got {actual}")
131            }
132            Self::InvalidJson { method, source } => write!(f, "invalid JSON params for {method}: {source}"),
133        }
134    }
135}
136
137fn parse_ext_notification<T: DeserializeOwned>(
138    notification: &ExtNotification,
139    method: &'static str,
140) -> Result<T, ExtNotificationParseError> {
141    if notification.method.as_ref() != method {
142        return Err(ExtNotificationParseError::WrongMethod {
143            expected: method,
144            actual: notification.method.as_ref().to_string(),
145        });
146    }
147
148    serde_json::from_str(notification.params.get())
149        .map_err(|source| ExtNotificationParseError::InvalidJson { method, source })
150}
151
152impl TryFrom<&ExtNotification> for McpRequest {
153    type Error = ExtNotificationParseError;
154
155    fn try_from(n: &ExtNotification) -> Result<Self, Self::Error> {
156        parse_ext_notification(n, MCP_MESSAGE_METHOD)
157    }
158}
159
160impl TryFrom<&ExtNotification> for McpNotification {
161    type Error = ExtNotificationParseError;
162
163    fn try_from(n: &ExtNotification) -> Result<Self, Self::Error> {
164        parse_ext_notification(n, MCP_MESSAGE_METHOD)
165    }
166}
167
168impl TryFrom<&ExtNotification> for AuthMethodsUpdatedParams {
169    type Error = ExtNotificationParseError;
170
171    fn try_from(n: &ExtNotification) -> Result<Self, Self::Error> {
172        parse_ext_notification(n, AUTH_METHODS_UPDATED_METHOD)
173    }
174}
175
176fn ext_notification<T: Serialize>(method: &str, params: &T) -> ExtNotification {
177    let raw_value = to_raw_value(params).expect("notification params are serializable");
178    ExtNotification::new(method, Arc::from(raw_value))
179}
180
181impl From<ContextUsageParams> for ExtNotification {
182    fn from(params: ContextUsageParams) -> Self {
183        ext_notification(CONTEXT_USAGE_METHOD, &params)
184    }
185}
186
187impl From<ContextClearedParams> for ExtNotification {
188    fn from(params: ContextClearedParams) -> Self {
189        ext_notification(CONTEXT_CLEARED_METHOD, &params)
190    }
191}
192
193impl From<AuthMethodsUpdatedParams> for ExtNotification {
194    fn from(params: AuthMethodsUpdatedParams) -> Self {
195        ext_notification(AUTH_METHODS_UPDATED_METHOD, &params)
196    }
197}
198
199/// Parameters for `_aether/sub_agent_progress` notifications.
200///
201/// This is the wire format sent from the ACP server (`aether-cli`) to clients like `wisp`.
202#[derive(Debug, Clone, Serialize, Deserialize)]
203pub struct SubAgentProgressParams {
204    pub parent_tool_id: String,
205    pub task_id: String,
206    pub agent_name: String,
207    pub event: SubAgentEvent,
208}
209
210impl From<SubAgentProgressParams> for ExtNotification {
211    fn from(params: SubAgentProgressParams) -> Self {
212        ext_notification(SUB_AGENT_PROGRESS_METHOD, &params)
213    }
214}
215
216/// Subset of agent message variants relevant for sub-agent status display.
217///
218/// The ACP server (`aether-cli`) converts `AgentMessage` to this type before
219/// serializing, so the wire format only contains these known variants.
220#[derive(Debug, Clone, Serialize, Deserialize)]
221pub enum SubAgentEvent {
222    ToolCall { request: SubAgentToolRequest },
223    ToolCallUpdate { update: SubAgentToolCallUpdate },
224    ToolResult { result: SubAgentToolResult },
225    ToolError { error: SubAgentToolError },
226    Done,
227    Other,
228}
229
230#[derive(Debug, Clone, Serialize, Deserialize)]
231pub struct SubAgentToolRequest {
232    pub id: String,
233    pub name: String,
234    pub arguments: String,
235}
236
237#[derive(Debug, Clone, Serialize, Deserialize)]
238pub struct SubAgentToolCallUpdate {
239    pub id: String,
240    pub chunk: String,
241}
242
243#[derive(Debug, Clone, Serialize, Deserialize)]
244pub struct SubAgentToolResult {
245    pub id: String,
246    pub name: String,
247    pub result_meta: Option<ToolResultMeta>,
248}
249
250#[derive(Debug, Clone, Serialize, Deserialize)]
251pub struct SubAgentToolError {
252    pub id: String,
253    pub name: String,
254}
255
#[cfg(test)]
mod tests {
    use agent_client_protocol::AuthMethodAgent;
    use serde_json::from_str;

    use super::*;

    #[test]
    fn method_constants_have_underscore_prefix() {
        assert!(SUB_AGENT_PROGRESS_METHOD.starts_with('_'));
        assert!(CONTEXT_USAGE_METHOD.starts_with('_'));
        assert!(CONTEXT_CLEARED_METHOD.starts_with('_'));
        assert!(MCP_MESSAGE_METHOD.starts_with('_'));
        assert!(AUTH_METHODS_UPDATED_METHOD.starts_with('_'));
        // ACP auto-prefixes `ext_method` names with `_`, so this one must NOT carry it.
        assert!(!ELICITATION_METHOD.starts_with('_'));
    }

    #[test]
    fn mcp_request_authenticate_roundtrip() {
        let msg = McpRequest::Authenticate {
            session_id: "session-0".to_string(),
            server_name: "my oauth server".to_string(),
        };

        let notification: ExtNotification = msg.clone().into();
        assert_eq!(notification.method.as_ref(), MCP_MESSAGE_METHOD);

        let parsed: McpRequest = serde_json::from_str(notification.params.get()).expect("valid JSON");
        assert_eq!(parsed, msg);
    }

    #[test]
    fn mcp_notification_server_status_roundtrip() {
        let msg = McpNotification::ServerStatus {
            servers: vec![
                McpServerStatusEntry {
                    name: "github".to_string(),
                    status: McpServerStatus::Connected { tool_count: 5 },
                },
                McpServerStatusEntry { name: "linear".to_string(), status: McpServerStatus::NeedsOAuth },
                McpServerStatusEntry {
                    name: "slack".to_string(),
                    status: McpServerStatus::Failed { error: "connection timeout".to_string() },
                },
            ],
        };

        let notification: ExtNotification = msg.clone().into();
        assert_eq!(notification.method.as_ref(), MCP_MESSAGE_METHOD);

        let parsed: McpNotification = serde_json::from_str(notification.params.get()).expect("valid JSON");
        assert_eq!(parsed, msg);
    }

    #[test]
    fn auth_methods_updated_params_roundtrip() {
        let params = AuthMethodsUpdatedParams {
            auth_methods: vec![
                AuthMethod::Agent(AuthMethodAgent::new("anthropic", "Anthropic").description("authenticated")),
                AuthMethod::Agent(AuthMethodAgent::new("openrouter", "OpenRouter")),
            ],
        };

        let notification: ExtNotification = params.clone().into();
        let parsed: AuthMethodsUpdatedParams = from_str(notification.params.get()).expect("valid JSON");

        assert_eq!(parsed, params);
        assert_eq!(notification.method.as_ref(), AUTH_METHODS_UPDATED_METHOD);
    }

    #[test]
    fn mcp_server_status_entry_serde_roundtrip() {
        let entry = McpServerStatusEntry {
            name: "test-server".to_string(),
            status: McpServerStatus::Connected { tool_count: 3 },
        };

        let json = serde_json::to_string(&entry).unwrap();
        let parsed: McpServerStatusEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, entry);
    }

    #[test]
    fn elicitation_params_roundtrip() {
        use rmcp::model::{ElicitationSchema, EnumSchema};

        let params = ElicitationParams {
            server_name: "github".to_string(),
            request: CreateElicitationRequestParams::FormElicitationParams {
                meta: None,
                message: "Pick a color".to_string(),
                requested_schema: ElicitationSchema::builder()
                    .required_enum_schema(
                        "color",
                        EnumSchema::builder(vec!["red".into(), "green".into(), "blue".into()]).untitled().build(),
                    )
                    .build()
                    .unwrap(),
            },
        };

        let json = serde_json::to_string(&params).unwrap();
        let parsed: ElicitationParams = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, params);
    }

    #[test]
    fn elicitation_params_url_roundtrip() {
        let params = ElicitationParams {
            server_name: "github".to_string(),
            request: CreateElicitationRequestParams::UrlElicitationParams {
                meta: None,
                message: "Authorize GitHub".to_string(),
                url: "https://github.com/login/oauth".to_string(),
                elicitation_id: "el-123".to_string(),
            },
        };

        let json = serde_json::to_string(&params).unwrap();
        assert!(json.contains("\"mode\":\"url\""));
        assert!(json.contains("\"server_name\":\"github\""));
        let parsed: ElicitationParams = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, params);
    }

    #[test]
    fn mcp_notification_url_elicitation_complete_roundtrip() {
        let msg = McpNotification::UrlElicitationComplete(UrlElicitationCompleteParams {
            server_name: "github".to_string(),
            elicitation_id: "el-456".to_string(),
        });

        let notification: ExtNotification = msg.clone().into();
        assert_eq!(notification.method.as_ref(), MCP_MESSAGE_METHOD);

        let parsed: McpNotification = serde_json::from_str(notification.params.get()).expect("valid JSON");
        assert_eq!(parsed, msg);
    }

    #[test]
    fn context_usage_params_roundtrip() {
        let params = ContextUsageParams {
            usage_ratio: Some(0.75),
            context_limit: Some(100_000),
            input_tokens: 75_000,
            output_tokens: 1_200,
            cache_read_tokens: Some(40_000),
            cache_creation_tokens: Some(2_000),
            reasoning_tokens: Some(500),
            total_input_tokens: 200_000,
            total_output_tokens: 8_000,
            total_cache_read_tokens: 90_000,
            total_cache_creation_tokens: 5_000,
            total_reasoning_tokens: 1_500,
        };

        let notification: ExtNotification = params.clone().into();
        assert_eq!(notification.method.as_ref(), CONTEXT_USAGE_METHOD);

        let parsed: ContextUsageParams = serde_json::from_str(notification.params.get()).expect("valid JSON");
        assert_eq!(parsed, params);
    }

    #[test]
    fn context_usage_params_omits_unset_optional_token_fields() {
        let params = ContextUsageParams {
            usage_ratio: Some(0.1),
            context_limit: Some(1_000),
            input_tokens: 100,
            output_tokens: 0,
            cache_read_tokens: None,
            cache_creation_tokens: None,
            reasoning_tokens: None,
            total_input_tokens: 0,
            total_output_tokens: 0,
            total_cache_read_tokens: 0,
            total_cache_creation_tokens: 0,
            total_reasoning_tokens: 0,
        };

        let notification: ExtNotification = params.into();
        let raw = notification.params.get();
        assert!(!raw.contains("\"cache_read_tokens\""));
        assert!(!raw.contains("\"cache_creation_tokens\""));
        assert!(!raw.contains("\"reasoning_tokens\""));
    }

    #[test]
    fn context_usage_params_missing_defaulted_fields_deserialize_to_defaults() {
        // Senders that predate the newer counters only emit the original fields;
        // the `#[serde(default)]` attributes must keep those payloads parseable.
        let json = r#"{"usage_ratio":0.5,"context_limit":1000,"input_tokens":10}"#;
        let parsed: ContextUsageParams = from_str(json).expect("defaulted fields may be omitted");

        assert_eq!(parsed.usage_ratio, Some(0.5));
        assert_eq!(parsed.context_limit, Some(1_000));
        assert_eq!(parsed.input_tokens, 10);
        assert_eq!(parsed.output_tokens, 0);
        assert_eq!(parsed.cache_read_tokens, None);
        assert_eq!(parsed.cache_creation_tokens, None);
        assert_eq!(parsed.reasoning_tokens, None);
        assert_eq!(parsed.total_input_tokens, 0);
        assert_eq!(parsed.total_output_tokens, 0);
        assert_eq!(parsed.total_cache_read_tokens, 0);
        assert_eq!(parsed.total_cache_creation_tokens, 0);
        assert_eq!(parsed.total_reasoning_tokens, 0);
    }

    #[test]
    fn context_cleared_params_roundtrip() {
        let params = ContextClearedParams::default();

        let notification: ExtNotification = params.clone().into();
        assert_eq!(notification.method.as_ref(), CONTEXT_CLEARED_METHOD);

        let parsed: ContextClearedParams = serde_json::from_str(notification.params.get()).expect("valid JSON");
        assert_eq!(parsed, params);
    }

    #[test]
    fn sub_agent_progress_params_roundtrip() {
        let params = SubAgentProgressParams {
            parent_tool_id: "call_123".to_string(),
            task_id: "task_abc".to_string(),
            agent_name: "explorer".to_string(),
            event: SubAgentEvent::Done,
        };

        let notification: ExtNotification = params.into();
        assert_eq!(notification.method.as_ref(), SUB_AGENT_PROGRESS_METHOD);

        let parsed: SubAgentProgressParams = serde_json::from_str(notification.params.get()).expect("valid JSON");
        assert!(matches!(parsed.event, SubAgentEvent::Done));
        assert_eq!(parsed.parent_tool_id, "call_123");
    }

    #[test]
    fn sub_agent_tool_call_event_roundtrip() {
        let event = SubAgentEvent::ToolCall {
            request: SubAgentToolRequest {
                id: "c1".to_string(),
                name: "grep".to_string(),
                arguments: r#"{"pattern":"test"}"#.to_string(),
            },
        };

        let json = serde_json::to_string(&event).unwrap();
        let parsed: SubAgentEvent = serde_json::from_str(&json).unwrap();
        match parsed {
            SubAgentEvent::ToolCall { request } => {
                assert_eq!(request.id, "c1");
                assert_eq!(request.name, "grep");
                assert_eq!(request.arguments, r#"{"pattern":"test"}"#);
            }
            other => panic!("Expected ToolCall, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_tool_call_event() {
        let json = r#"{"ToolCall":{"request":{"id":"c1","name":"grep","arguments":"{\"pattern\":\"test\"}"},"model_name":"m"}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        assert!(matches!(event, SubAgentEvent::ToolCall { .. }));
    }

    #[test]
    fn deserialize_tool_call_update_event() {
        // "model_name" is present because the wire format comes from AgentMessage serialization;
        // SubAgentEvent::ToolCallUpdate has no model_name field, so serde silently ignores it.
        let json = r#"{"ToolCallUpdate":{"update":{"id":"c1","chunk":"{\"pattern\":\"test\"}"},"model_name":"m"}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        assert!(matches!(event, SubAgentEvent::ToolCallUpdate { .. }));
    }

    #[test]
    fn deserialize_tool_result_event() {
        let json = r#"{"ToolResult":{"result":{"id":"c1","name":"grep","result_meta":{"display":{"title":"Grep","value":"'test' in src (3 matches)"}}}}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        match event {
            SubAgentEvent::ToolResult { result } => {
                let result_meta = result.result_meta.expect("expected result_meta");
                assert_eq!(result_meta.display.title, "Grep");
            }
            other => panic!("Expected ToolResult, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_tool_error_event() {
        let json = r#"{"ToolError":{"error":{"id":"c1","name":"grep"}}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        assert!(matches!(event, SubAgentEvent::ToolError { .. }));
    }

    #[test]
    fn deserialize_done_event() {
        let event: SubAgentEvent = serde_json::from_str(r#""Done""#).unwrap();
        assert!(matches!(event, SubAgentEvent::Done));
    }

    #[test]
    fn deserialize_other_variant() {
        let event: SubAgentEvent = serde_json::from_str(r#""Other""#).unwrap();
        assert!(matches!(event, SubAgentEvent::Other));
    }

    #[test]
    fn tool_result_meta_map_roundtrip() {
        let meta: ToolResultMeta = ToolDisplayMeta::new("Read file", "Cargo.toml, 156 lines").into();
        let map = meta.clone().into_map();
        let parsed = ToolResultMeta::from_map(&map).expect("should deserialize ToolResultMeta");
        assert_eq!(parsed, meta);
    }

    #[test]
    fn mcp_request_try_from_roundtrip() {
        let msg = McpRequest::Authenticate {
            session_id: "session-0".to_string(),
            server_name: "my oauth server".to_string(),
        };

        let notification: ExtNotification = msg.clone().into();
        let parsed = McpRequest::try_from(&notification).expect("should parse McpRequest");
        assert_eq!(parsed, msg);
    }

    #[test]
    fn mcp_notification_try_from_roundtrip() {
        let msg = McpNotification::ServerStatus {
            servers: vec![McpServerStatusEntry {
                name: "github".to_string(),
                status: McpServerStatus::Connected { tool_count: 5 },
            }],
        };

        let notification: ExtNotification = msg.clone().into();
        let parsed = McpNotification::try_from(&notification).expect("should parse McpNotification");
        assert_eq!(parsed, msg);
    }

    #[test]
    fn auth_methods_updated_try_from_roundtrip() {
        let params = AuthMethodsUpdatedParams {
            auth_methods: vec![AuthMethod::Agent(
                AuthMethodAgent::new("anthropic", "Anthropic").description("authenticated"),
            )],
        };

        let notification: ExtNotification = params.clone().into();
        let parsed = AuthMethodsUpdatedParams::try_from(&notification).expect("should parse auth methods");
        assert_eq!(parsed, params);
    }

    #[test]
    fn auth_methods_updated_try_from_wrong_method_returns_error() {
        let notification: ExtNotification = ContextClearedParams::default().into();

        let result = AuthMethodsUpdatedParams::try_from(&notification);
        assert!(matches!(
            result,
            Err(ExtNotificationParseError::WrongMethod { expected, actual })
                if expected == AUTH_METHODS_UPDATED_METHOD && actual == CONTEXT_CLEARED_METHOD
        ));
    }

    #[test]
    fn try_from_wrong_method_returns_error() {
        let notification = ext_notification(
            CONTEXT_USAGE_METHOD,
            &ContextUsageParams {
                usage_ratio: Some(0.5),
                context_limit: Some(100_000),
                input_tokens: 50_000,
                output_tokens: 0,
                cache_read_tokens: None,
                cache_creation_tokens: None,
                reasoning_tokens: None,
                total_input_tokens: 0,
                total_output_tokens: 0,
                total_cache_read_tokens: 0,
                total_cache_creation_tokens: 0,
                total_reasoning_tokens: 0,
            },
        );

        let result = McpRequest::try_from(&notification);
        assert!(matches!(
            result,
            Err(ExtNotificationParseError::WrongMethod { expected, actual })
                if expected == MCP_MESSAGE_METHOD && actual == CONTEXT_USAGE_METHOD
        ));
    }

    #[test]
    fn try_from_invalid_json_returns_error() {
        let notification = ext_notification(MCP_MESSAGE_METHOD, &"not a valid McpRequest");

        let result = McpRequest::try_from(&notification);
        assert!(matches!(
            result,
            Err(ExtNotificationParseError::InvalidJson { method, .. }) if method == MCP_MESSAGE_METHOD
        ));
    }

    #[test]
    fn ext_notification_parse_error_display() {
        let wrong = ExtNotificationParseError::WrongMethod {
            expected: MCP_MESSAGE_METHOD,
            actual: CONTEXT_USAGE_METHOD.to_string(),
        };
        assert!(wrong.to_string().contains(MCP_MESSAGE_METHOD));
        assert!(wrong.to_string().contains(CONTEXT_USAGE_METHOD));

        let json_err = serde_json::from_str::<McpRequest>("{}").unwrap_err();
        let invalid = ExtNotificationParseError::InvalidJson { method: MCP_MESSAGE_METHOD, source: json_err };
        assert!(invalid.to_string().contains("invalid JSON"));
        assert!(invalid.to_string().contains(MCP_MESSAGE_METHOD));
    }
}