Skip to main content

acp_utils/
notifications.rs

//! Shared wire-format types for Aether's custom ACP extension notifications.
//!
//! These types are used on both the agent (server) and client (UI) sides of the
//! ACP connection.

6use agent_client_protocol::{AuthMethod, ExtNotification};
7pub use mcp_utils::display_meta::{ToolDisplayMeta, ToolResultMeta};
8use rmcp::model::ElicitationSchema;
9use serde::{Deserialize, Serialize};
10use serde_json::value::to_raw_value;
11use std::fmt;
12use std::sync::Arc;
13
14pub use mcp_utils::status::{McpServerStatus, McpServerStatusEntry};
15
// Per the ACP extensibility spec, custom notification methods must start with an
// underscore; every notification `*_METHOD` constant below follows that rule.

/// Notification method for sub-agent progress updates (payload: [`SubAgentProgressParams`]).
pub const SUB_AGENT_PROGRESS_METHOD: &str = "_aether/sub_agent_progress";
/// Notification method for context usage updates (payload: [`ContextUsageParams`]).
pub const CONTEXT_USAGE_METHOD: &str = "_aether/context_usage";
/// Notification method for context-cleared events (payload: [`ContextClearedParams`]).
pub const CONTEXT_CLEARED_METHOD: &str = "_aether/context_cleared";
/// Notification method shared by MCP extension messages in both directions
/// ([`McpNotification`] server→client and [`McpRequest`] client→server).
pub const MCP_MESSAGE_METHOD: &str = "_aether/mcp";
/// Notification method for auth method updates (payload: [`AuthMethodsUpdatedParams`]).
pub const AUTH_METHODS_UPDATED_METHOD: &str = "_aether/auth_methods_updated";

/// Custom `ext_method` for tunneling MCP elicitation through ACP.
///
/// Deliberately has no leading underscore: ACP auto-prefixes `ext_method` names
/// with `_`, so the wire method becomes `_aether/elicitation`.
pub const ELICITATION_METHOD: &str = "aether/elicitation";
28
29/// Parameters for `_aether/context_usage` notifications.
30#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
31pub struct ContextUsageParams {
32    pub usage_ratio: Option<f64>,
33    pub tokens_used: u32,
34    pub context_limit: Option<u32>,
35}
36
37/// Parameters for `_aether/context_cleared` notifications.
38#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
39pub struct ContextClearedParams {}
40
41/// Parameters for `_aether/auth_methods_updated` notifications.
42#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
43pub struct AuthMethodsUpdatedParams {
44    pub auth_methods: Vec<AuthMethod>,
45}
46
47/// Parameters sent via `ext_method` for `aether/elicitation`.
48#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
49pub struct ElicitationParams {
50    pub message: String,
51    pub schema: ElicitationSchema,
52}
53
54pub use rmcp::model::ElicitationAction;
55
56/// Response returned from the client for an elicitation request.
57#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
58pub struct ElicitationResponse {
59    pub action: ElicitationAction,
60    /// Structured form data when action is "accept".
61    pub content: Option<serde_json::Value>,
62}
63
64/// Server→client MCP extension notifications (relay → wisp).
65#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
66pub enum McpNotification {
67    ServerStatus { servers: Vec<McpServerStatusEntry> },
68}
69
70impl From<McpNotification> for ExtNotification {
71    fn from(msg: McpNotification) -> Self {
72        ext_notification(MCP_MESSAGE_METHOD, &msg)
73    }
74}
75
76/// Client→server MCP extension requests (wisp → relay).
77#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
78pub enum McpRequest {
79    Authenticate {
80        session_id: String,
81        server_name: String,
82    },
83}
84
85impl From<McpRequest> for ExtNotification {
86    fn from(msg: McpRequest) -> Self {
87        ext_notification(MCP_MESSAGE_METHOD, &msg)
88    }
89}
90
91/// Error returned when converting an `ExtNotification` into a typed MCP message.
92#[derive(Debug)]
93pub enum ExtNotificationParseError {
94    WrongMethod,
95    InvalidJson(serde_json::Error),
96}
97
98impl fmt::Display for ExtNotificationParseError {
99    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100        match self {
101            Self::WrongMethod => write!(f, "notification method is not {MCP_MESSAGE_METHOD}"),
102            Self::InvalidJson(e) => write!(f, "invalid JSON params: {e}"),
103        }
104    }
105}
106
107impl TryFrom<&ExtNotification> for McpRequest {
108    type Error = ExtNotificationParseError;
109
110    fn try_from(n: &ExtNotification) -> Result<Self, Self::Error> {
111        if n.method.as_ref() != MCP_MESSAGE_METHOD {
112            return Err(ExtNotificationParseError::WrongMethod);
113        }
114        serde_json::from_str(n.params.get()).map_err(ExtNotificationParseError::InvalidJson)
115    }
116}
117
118impl TryFrom<&ExtNotification> for McpNotification {
119    type Error = ExtNotificationParseError;
120
121    fn try_from(n: &ExtNotification) -> Result<Self, Self::Error> {
122        if n.method.as_ref() != MCP_MESSAGE_METHOD {
123            return Err(ExtNotificationParseError::WrongMethod);
124        }
125        serde_json::from_str(n.params.get()).map_err(ExtNotificationParseError::InvalidJson)
126    }
127}
128
129impl TryFrom<&ExtNotification> for AuthMethodsUpdatedParams {
130    type Error = ExtNotificationParseError;
131
132    fn try_from(n: &ExtNotification) -> Result<Self, Self::Error> {
133        if n.method.as_ref() != AUTH_METHODS_UPDATED_METHOD {
134            return Err(ExtNotificationParseError::WrongMethod);
135        }
136        serde_json::from_str(n.params.get()).map_err(ExtNotificationParseError::InvalidJson)
137    }
138}
139
140fn ext_notification<T: Serialize>(method: &str, params: &T) -> ExtNotification {
141    let raw_value = to_raw_value(params).expect("notification params are serializable");
142    ExtNotification::new(method, Arc::from(raw_value))
143}
144
145impl From<ContextUsageParams> for ExtNotification {
146    fn from(params: ContextUsageParams) -> Self {
147        ext_notification(CONTEXT_USAGE_METHOD, &params)
148    }
149}
150
151impl From<ContextClearedParams> for ExtNotification {
152    fn from(params: ContextClearedParams) -> Self {
153        ext_notification(CONTEXT_CLEARED_METHOD, &params)
154    }
155}
156
157impl From<AuthMethodsUpdatedParams> for ExtNotification {
158    fn from(params: AuthMethodsUpdatedParams) -> Self {
159        ext_notification(AUTH_METHODS_UPDATED_METHOD, &params)
160    }
161}
162
163/// Parameters for `_aether/sub_agent_progress` notifications.
164///
165/// This is the wire format sent from the ACP server (`aether-cli`) to clients like `wisp`.
166#[derive(Debug, Clone, Serialize, Deserialize)]
167pub struct SubAgentProgressParams {
168    pub parent_tool_id: String,
169    pub task_id: String,
170    pub agent_name: String,
171    pub event: SubAgentEvent,
172}
173
174impl From<SubAgentProgressParams> for ExtNotification {
175    fn from(params: SubAgentProgressParams) -> Self {
176        ext_notification(SUB_AGENT_PROGRESS_METHOD, &params)
177    }
178}
179
180/// Subset of agent message variants relevant for sub-agent status display.
181///
182/// The ACP server (`aether-cli`) converts `AgentMessage` to this type before
183/// serializing, so the wire format only contains these known variants.
184#[derive(Debug, Clone, Serialize, Deserialize)]
185pub enum SubAgentEvent {
186    ToolCall { request: SubAgentToolRequest },
187    ToolCallUpdate { update: SubAgentToolCallUpdate },
188    ToolResult { result: SubAgentToolResult },
189    ToolError { error: SubAgentToolError },
190    Done,
191    Other,
192}
193
194#[derive(Debug, Clone, Serialize, Deserialize)]
195pub struct SubAgentToolRequest {
196    pub id: String,
197    pub name: String,
198    pub arguments: String,
199}
200
201#[derive(Debug, Clone, Serialize, Deserialize)]
202pub struct SubAgentToolCallUpdate {
203    pub id: String,
204    pub chunk: String,
205}
206
207#[derive(Debug, Clone, Serialize, Deserialize)]
208pub struct SubAgentToolResult {
209    pub id: String,
210    pub name: String,
211    pub result_meta: Option<ToolResultMeta>,
212}
213
214#[derive(Debug, Clone, Serialize, Deserialize)]
215pub struct SubAgentToolError {
216    pub id: String,
217    pub name: String,
218}
219
#[cfg(test)]
mod tests {
    use agent_client_protocol::AuthMethodAgent;

    use super::*;

    /// A sample `McpRequest` shared by the MCP request tests.
    fn authenticate_request() -> McpRequest {
        McpRequest::Authenticate {
            session_id: "session-0".to_string(),
            server_name: "my oauth server".to_string(),
        }
    }

    /// A well-formed notification sent under `CONTEXT_USAGE_METHOD`. The MCP and
    /// auth-method parsers must reject it as the wrong method.
    fn context_usage_notification() -> ExtNotification {
        ext_notification(
            CONTEXT_USAGE_METHOD,
            &ContextUsageParams {
                usage_ratio: Some(0.5),
                tokens_used: 50000,
                context_limit: Some(100_000),
            },
        )
    }

    #[test]
    fn method_constants_have_underscore_prefix() {
        assert!(SUB_AGENT_PROGRESS_METHOD.starts_with('_'));
        assert!(CONTEXT_USAGE_METHOD.starts_with('_'));
        assert!(CONTEXT_CLEARED_METHOD.starts_with('_'));
        assert!(MCP_MESSAGE_METHOD.starts_with('_'));
        assert!(AUTH_METHODS_UPDATED_METHOD.starts_with('_'));
        // ACP adds the underscore to `ext_method` names itself.
        assert!(!ELICITATION_METHOD.starts_with('_'));
    }

    #[test]
    fn mcp_request_authenticate_roundtrip() {
        let msg = authenticate_request();

        let notification: ExtNotification = msg.clone().into();
        assert_eq!(notification.method.as_ref(), MCP_MESSAGE_METHOD);

        let parsed: McpRequest =
            serde_json::from_str(notification.params.get()).expect("valid JSON");
        assert_eq!(parsed, msg);
    }

    #[test]
    fn mcp_notification_server_status_roundtrip() {
        let msg = McpNotification::ServerStatus {
            servers: vec![
                McpServerStatusEntry {
                    name: "github".to_string(),
                    status: McpServerStatus::Connected { tool_count: 5 },
                },
                McpServerStatusEntry {
                    name: "linear".to_string(),
                    status: McpServerStatus::NeedsOAuth,
                },
                McpServerStatusEntry {
                    name: "slack".to_string(),
                    status: McpServerStatus::Failed {
                        error: "connection timeout".to_string(),
                    },
                },
            ],
        };

        let notification: ExtNotification = msg.clone().into();
        assert_eq!(notification.method.as_ref(), MCP_MESSAGE_METHOD);

        let parsed: McpNotification =
            serde_json::from_str(notification.params.get()).expect("valid JSON");
        assert_eq!(parsed, msg);
    }

    #[test]
    fn auth_methods_updated_params_roundtrip() {
        let params = AuthMethodsUpdatedParams {
            auth_methods: vec![
                AuthMethod::Agent(
                    AuthMethodAgent::new("anthropic", "Anthropic").description("authenticated"),
                ),
                AuthMethod::Agent(AuthMethodAgent::new("openrouter", "OpenRouter")),
            ],
        };

        let notification: ExtNotification = params.clone().into();
        let parsed: AuthMethodsUpdatedParams =
            serde_json::from_str(notification.params.get()).expect("valid JSON");

        assert_eq!(parsed, params);
        assert_eq!(notification.method.as_ref(), AUTH_METHODS_UPDATED_METHOD);
    }

    #[test]
    fn mcp_server_status_entry_serde_roundtrip() {
        let entry = McpServerStatusEntry {
            name: "test-server".to_string(),
            status: McpServerStatus::Connected { tool_count: 3 },
        };

        let json = serde_json::to_string(&entry).unwrap();
        let parsed: McpServerStatusEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, entry);
    }

    #[test]
    fn elicitation_params_roundtrip() {
        use rmcp::model::EnumSchema;

        let params = ElicitationParams {
            message: "Pick a color".to_string(),
            schema: ElicitationSchema::builder()
                .required_enum_schema(
                    "color",
                    EnumSchema::builder(vec!["red".into(), "green".into(), "blue".into()])
                        .untitled()
                        .build(),
                )
                .build()
                .unwrap(),
        };

        let json = serde_json::to_string(&params).unwrap();
        let parsed: ElicitationParams = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, params);
    }

    #[test]
    fn elicitation_response_roundtrip() {
        let response = ElicitationResponse {
            action: ElicitationAction::Accept,
            content: Some(serde_json::json!({ "color": "red" })),
        };

        let json = serde_json::to_string(&response).unwrap();
        let parsed: ElicitationResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, response);
    }

    #[test]
    fn elicitation_response_without_content_field_parses_as_none() {
        let response = ElicitationResponse {
            action: ElicitationAction::Decline,
            content: None,
        };

        // Clients may omit `content` entirely instead of sending `null`.
        let mut value = serde_json::to_value(&response).unwrap();
        value
            .as_object_mut()
            .expect("ElicitationResponse serializes to a JSON object")
            .remove("content");

        let parsed: ElicitationResponse = serde_json::from_value(value).unwrap();
        assert_eq!(parsed, response);
    }

    #[test]
    fn context_usage_params_roundtrip() {
        let params = ContextUsageParams {
            usage_ratio: Some(0.75),
            tokens_used: 75000,
            context_limit: Some(100_000),
        };

        let notification: ExtNotification = params.clone().into();
        assert_eq!(notification.method.as_ref(), CONTEXT_USAGE_METHOD);

        let parsed: ContextUsageParams =
            serde_json::from_str(notification.params.get()).expect("valid JSON");
        assert_eq!(parsed, params);
    }

    #[test]
    fn context_cleared_params_roundtrip() {
        let params = ContextClearedParams::default();

        let notification: ExtNotification = params.clone().into();
        assert_eq!(notification.method.as_ref(), CONTEXT_CLEARED_METHOD);
        // The empty braced struct must go over the wire as `{}`, not `null`.
        assert_eq!(notification.params.get(), "{}");

        let parsed: ContextClearedParams =
            serde_json::from_str(notification.params.get()).expect("valid JSON");
        assert_eq!(parsed, params);
    }

    #[test]
    fn sub_agent_progress_params_roundtrip() {
        let params = SubAgentProgressParams {
            parent_tool_id: "call_123".to_string(),
            task_id: "task_abc".to_string(),
            agent_name: "explorer".to_string(),
            event: SubAgentEvent::Done,
        };

        let notification: ExtNotification = params.into();
        assert_eq!(notification.method.as_ref(), SUB_AGENT_PROGRESS_METHOD);

        let parsed: SubAgentProgressParams =
            serde_json::from_str(notification.params.get()).expect("valid JSON");
        assert!(matches!(parsed.event, SubAgentEvent::Done));
        assert_eq!(parsed.parent_tool_id, "call_123");
    }

    #[test]
    fn deserialize_tool_call_event() {
        // "model_name" comes from AgentMessage serialization and is ignored by serde.
        let json = r#"{"ToolCall":{"request":{"id":"c1","name":"grep","arguments":"{\"pattern\":\"test\"}"},"model_name":"m"}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        assert!(matches!(event, SubAgentEvent::ToolCall { .. }));
    }

    #[test]
    fn deserialize_tool_call_update_event() {
        // "model_name" is present because the wire format comes from AgentMessage serialization;
        // SubAgentEvent::ToolCallUpdate has no model_name field, so serde silently ignores it.
        let json = r#"{"ToolCallUpdate":{"update":{"id":"c1","chunk":"{\"pattern\":\"test\"}"},"model_name":"m"}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        assert!(matches!(event, SubAgentEvent::ToolCallUpdate { .. }));
    }

    #[test]
    fn deserialize_tool_result_event() {
        let json = r#"{"ToolResult":{"result":{"id":"c1","name":"grep","result_meta":{"display":{"title":"Grep","value":"'test' in src (3 matches)"}}}}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        match event {
            SubAgentEvent::ToolResult { result } => {
                let result_meta = result.result_meta.expect("expected result_meta");
                assert_eq!(result_meta.display.title, "Grep");
            }
            other => panic!("Expected ToolResult, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_tool_error_event() {
        let json = r#"{"ToolError":{"error":{"id":"c1","name":"grep"}}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        assert!(matches!(event, SubAgentEvent::ToolError { .. }));
    }

    #[test]
    fn deserialize_done_event() {
        let event: SubAgentEvent = serde_json::from_str(r#""Done""#).unwrap();
        assert!(matches!(event, SubAgentEvent::Done));
    }

    #[test]
    fn deserialize_other_variant() {
        let event: SubAgentEvent = serde_json::from_str(r#""Other""#).unwrap();
        assert!(matches!(event, SubAgentEvent::Other));
    }

    #[test]
    fn tool_result_meta_map_roundtrip() {
        let meta: ToolResultMeta =
            ToolDisplayMeta::new("Read file", "Cargo.toml, 156 lines").into();
        let map = meta.clone().into_map();
        let parsed = ToolResultMeta::from_map(&map).expect("should deserialize ToolResultMeta");
        assert_eq!(parsed, meta);
    }

    #[test]
    fn mcp_request_try_from_roundtrip() {
        let msg = authenticate_request();

        let notification: ExtNotification = msg.clone().into();
        let parsed = McpRequest::try_from(&notification).expect("should parse McpRequest");
        assert_eq!(parsed, msg);
    }

    #[test]
    fn mcp_notification_try_from_roundtrip() {
        let msg = McpNotification::ServerStatus {
            servers: vec![McpServerStatusEntry {
                name: "github".to_string(),
                status: McpServerStatus::Connected { tool_count: 5 },
            }],
        };

        let notification: ExtNotification = msg.clone().into();
        let parsed =
            McpNotification::try_from(&notification).expect("should parse McpNotification");
        assert_eq!(parsed, msg);
    }

    #[test]
    fn auth_methods_updated_try_from_roundtrip() {
        let params = AuthMethodsUpdatedParams {
            auth_methods: vec![AuthMethod::Agent(
                AuthMethodAgent::new("anthropic", "Anthropic").description("authenticated"),
            )],
        };

        let notification: ExtNotification = params.clone().into();
        let parsed =
            AuthMethodsUpdatedParams::try_from(&notification).expect("should parse auth methods");
        assert_eq!(parsed, params);
    }

    #[test]
    fn try_from_wrong_method_returns_error() {
        let result = McpRequest::try_from(&context_usage_notification());
        assert!(matches!(
            result,
            Err(ExtNotificationParseError::WrongMethod)
        ));
    }

    #[test]
    fn mcp_notification_try_from_wrong_method_returns_error() {
        let result = McpNotification::try_from(&context_usage_notification());
        assert!(matches!(
            result,
            Err(ExtNotificationParseError::WrongMethod)
        ));
    }

    #[test]
    fn auth_methods_updated_try_from_wrong_method_returns_error() {
        // An MCP message must not be mistaken for an auth-methods update.
        let notification: ExtNotification = authenticate_request().into();

        let result = AuthMethodsUpdatedParams::try_from(&notification);
        assert!(matches!(
            result,
            Err(ExtNotificationParseError::WrongMethod)
        ));
    }

    #[test]
    fn mcp_request_try_from_rejects_mcp_notification_payload() {
        // Both directions share MCP_MESSAGE_METHOD, so the method check passes and
        // only the payload shape distinguishes requests from notifications.
        let notification: ExtNotification =
            McpNotification::ServerStatus { servers: vec![] }.into();

        let result = McpRequest::try_from(&notification);
        assert!(matches!(
            result,
            Err(ExtNotificationParseError::InvalidJson(_))
        ));
    }

    #[test]
    fn try_from_invalid_json_returns_error() {
        let notification = ext_notification(MCP_MESSAGE_METHOD, &"not a valid McpRequest");

        let result = McpRequest::try_from(&notification);
        assert!(matches!(
            result,
            Err(ExtNotificationParseError::InvalidJson(_))
        ));
    }

    #[test]
    fn ext_notification_parse_error_display() {
        let wrong = ExtNotificationParseError::WrongMethod;
        assert!(wrong.to_string().contains(MCP_MESSAGE_METHOD));

        let json_err = serde_json::from_str::<McpRequest>("{}").unwrap_err();
        let invalid = ExtNotificationParseError::InvalidJson(json_err);
        assert!(invalid.to_string().contains("invalid JSON"));
    }
}