Skip to main content

acp_utils/
notifications.rs

1//! Shared wire-format types for Aether's custom ACP extension notifications.
2//!
3//! These types are used on both the agent (server) and client (UI) sides of the
4//! ACP connection.
5
6use agent_client_protocol::{AuthMethod, ExtNotification};
7pub use mcp_utils::display_meta::{ToolDisplayMeta, ToolResultMeta};
8use rmcp::model::ElicitationSchema;
9use serde::{Deserialize, Serialize};
10use serde_json::value::to_raw_value;
11use std::fmt;
12use std::sync::Arc;
13
14pub use mcp_utils::status::{McpServerStatus, McpServerStatusEntry};
15
/// Custom notification method for sub-agent progress updates.
///
/// Per the ACP extensibility spec, custom notification methods must start with an
/// underscore; every notification-method constant in this group follows that rule.
pub const SUB_AGENT_PROGRESS_METHOD: &str = "_aether/sub_agent_progress";
/// Custom notification method carrying [`ContextUsageParams`].
pub const CONTEXT_USAGE_METHOD: &str = "_aether/context_usage";
/// Custom notification method carrying [`ContextClearedParams`].
pub const CONTEXT_CLEARED_METHOD: &str = "_aether/context_cleared";
/// Custom notification method shared by [`McpNotification`] (server→client) and
/// [`McpRequest`] (client→server); the payload type distinguishes the direction.
pub const MCP_MESSAGE_METHOD: &str = "_aether/mcp";
/// Custom notification method carrying [`AuthMethodsUpdatedParams`].
pub const AUTH_METHODS_UPDATED_METHOD: &str = "_aether/auth_methods_updated";

/// Custom `ext_method` for tunneling MCP elicitation through ACP.
/// Note: ACP auto-prefixes `ext_method` names with `_`, so the wire method
/// becomes `_aether/elicitation`. That is why this constant, unlike the
/// notification methods above, deliberately has no leading underscore.
pub const ELICITATION_METHOD: &str = "aether/elicitation";
28
/// Parameters for `_aether/context_usage` notifications.
///
/// Converted into an [`ExtNotification`] on [`CONTEXT_USAGE_METHOD`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ContextUsageParams {
    /// Share of the context window in use, if known.
    /// NOTE(review): presumably `tokens_used / context_limit` (e.g. 0.75 for 75k of 100k) — confirm with the sender.
    pub usage_ratio: Option<f64>,
    /// Token count reported as used.
    pub tokens_used: u32,
    /// Context window size in tokens, if known.
    pub context_limit: Option<u32>,
}
36
/// Parameters for `_aether/context_cleared` notifications.
///
/// Intentionally empty (serializes as `{}`); the method name alone carries the signal.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
pub struct ContextClearedParams {}
40
/// Parameters for `_aether/auth_methods_updated` notifications.
///
/// Converted into an [`ExtNotification`] on [`AUTH_METHODS_UPDATED_METHOD`], and parsed
/// back via `TryFrom<&ExtNotification>`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct AuthMethodsUpdatedParams {
    /// The ACP auth methods being announced.
    pub auth_methods: Vec<AuthMethod>,
}
46
/// Parameters sent via `ext_method` for `aether/elicitation`.
///
/// On the wire the method appears as `_aether/elicitation` (see [`ELICITATION_METHOD`]).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ElicitationParams {
    /// Prompt text accompanying the elicitation request.
    pub message: String,
    /// MCP elicitation schema describing the structured data being requested.
    pub schema: ElicitationSchema,
}
53
54pub use rmcp::model::ElicitationAction;
55
/// Response returned from the client for an elicitation request.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ElicitationResponse {
    /// The client's decision, re-exported from `rmcp` as [`ElicitationAction`].
    pub action: ElicitationAction,
    /// Structured form data when action is "accept".
    pub content: Option<serde_json::Value>,
}
63
/// Server→client MCP extension notifications (relay → wisp).
///
/// Sent on [`MCP_MESSAGE_METHOD`]; serialized with serde's default (externally tagged)
/// enum representation, e.g. `{"ServerStatus":{"servers":[...]}}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum McpNotification {
    /// Per-server status entries for the MCP servers known to the sender.
    ServerStatus { servers: Vec<McpServerStatusEntry> },
}
69
70impl From<McpNotification> for ExtNotification {
71    fn from(msg: McpNotification) -> Self {
72        ext_notification(MCP_MESSAGE_METHOD, &msg)
73    }
74}
75
/// Client→server MCP extension requests (wisp → relay).
///
/// Despite the name, these travel as ACP extension *notifications* on
/// [`MCP_MESSAGE_METHOD`] (see the `From<McpRequest> for ExtNotification` impl).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum McpRequest {
    /// Asks the relay to authenticate `server_name` within `session_id`.
    /// NOTE(review): presumably triggers the OAuth flow for servers reporting `NeedsOAuth` — confirm.
    Authenticate { session_id: String, server_name: String },
}
81
82impl From<McpRequest> for ExtNotification {
83    fn from(msg: McpRequest) -> Self {
84        ext_notification(MCP_MESSAGE_METHOD, &msg)
85    }
86}
87
/// Error returned when converting an `ExtNotification` into a typed MCP message.
///
/// Produced by every `TryFrom<&ExtNotification>` impl in this module.
#[derive(Debug)]
pub enum ExtNotificationParseError {
    /// The notification's method is not the one the target type is sent on.
    WrongMethod,
    /// The method matched, but the params could not be deserialized into the target type.
    InvalidJson(serde_json::Error),
}
94
95impl fmt::Display for ExtNotificationParseError {
96    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
97        match self {
98            Self::WrongMethod => write!(f, "notification method is not {MCP_MESSAGE_METHOD}"),
99            Self::InvalidJson(e) => write!(f, "invalid JSON params: {e}"),
100        }
101    }
102}
103
104impl TryFrom<&ExtNotification> for McpRequest {
105    type Error = ExtNotificationParseError;
106
107    fn try_from(n: &ExtNotification) -> Result<Self, Self::Error> {
108        if n.method.as_ref() != MCP_MESSAGE_METHOD {
109            return Err(ExtNotificationParseError::WrongMethod);
110        }
111        serde_json::from_str(n.params.get()).map_err(ExtNotificationParseError::InvalidJson)
112    }
113}
114
115impl TryFrom<&ExtNotification> for McpNotification {
116    type Error = ExtNotificationParseError;
117
118    fn try_from(n: &ExtNotification) -> Result<Self, Self::Error> {
119        if n.method.as_ref() != MCP_MESSAGE_METHOD {
120            return Err(ExtNotificationParseError::WrongMethod);
121        }
122        serde_json::from_str(n.params.get()).map_err(ExtNotificationParseError::InvalidJson)
123    }
124}
125
126impl TryFrom<&ExtNotification> for AuthMethodsUpdatedParams {
127    type Error = ExtNotificationParseError;
128
129    fn try_from(n: &ExtNotification) -> Result<Self, Self::Error> {
130        if n.method.as_ref() != AUTH_METHODS_UPDATED_METHOD {
131            return Err(ExtNotificationParseError::WrongMethod);
132        }
133        serde_json::from_str(n.params.get()).map_err(ExtNotificationParseError::InvalidJson)
134    }
135}
136
137fn ext_notification<T: Serialize>(method: &str, params: &T) -> ExtNotification {
138    let raw_value = to_raw_value(params).expect("notification params are serializable");
139    ExtNotification::new(method, Arc::from(raw_value))
140}
141
142impl From<ContextUsageParams> for ExtNotification {
143    fn from(params: ContextUsageParams) -> Self {
144        ext_notification(CONTEXT_USAGE_METHOD, &params)
145    }
146}
147
148impl From<ContextClearedParams> for ExtNotification {
149    fn from(params: ContextClearedParams) -> Self {
150        ext_notification(CONTEXT_CLEARED_METHOD, &params)
151    }
152}
153
154impl From<AuthMethodsUpdatedParams> for ExtNotification {
155    fn from(params: AuthMethodsUpdatedParams) -> Self {
156        ext_notification(AUTH_METHODS_UPDATED_METHOD, &params)
157    }
158}
159
160/// Parameters for `_aether/sub_agent_progress` notifications.
161///
162/// This is the wire format sent from the ACP server (`aether-cli`) to clients like `wisp`.
163#[derive(Debug, Clone, Serialize, Deserialize)]
164pub struct SubAgentProgressParams {
165    pub parent_tool_id: String,
166    pub task_id: String,
167    pub agent_name: String,
168    pub event: SubAgentEvent,
169}
170
171impl From<SubAgentProgressParams> for ExtNotification {
172    fn from(params: SubAgentProgressParams) -> Self {
173        ext_notification(SUB_AGENT_PROGRESS_METHOD, &params)
174    }
175}
176
177/// Subset of agent message variants relevant for sub-agent status display.
178///
179/// The ACP server (`aether-cli`) converts `AgentMessage` to this type before
180/// serializing, so the wire format only contains these known variants.
181#[derive(Debug, Clone, Serialize, Deserialize)]
182pub enum SubAgentEvent {
183    ToolCall { request: SubAgentToolRequest },
184    ToolCallUpdate { update: SubAgentToolCallUpdate },
185    ToolResult { result: SubAgentToolResult },
186    ToolError { error: SubAgentToolError },
187    Done,
188    Other,
189}
190
191#[derive(Debug, Clone, Serialize, Deserialize)]
192pub struct SubAgentToolRequest {
193    pub id: String,
194    pub name: String,
195    pub arguments: String,
196}
197
198#[derive(Debug, Clone, Serialize, Deserialize)]
199pub struct SubAgentToolCallUpdate {
200    pub id: String,
201    pub chunk: String,
202}
203
204#[derive(Debug, Clone, Serialize, Deserialize)]
205pub struct SubAgentToolResult {
206    pub id: String,
207    pub name: String,
208    pub result_meta: Option<ToolResultMeta>,
209}
210
211#[derive(Debug, Clone, Serialize, Deserialize)]
212pub struct SubAgentToolError {
213    pub id: String,
214    pub name: String,
215}
216
#[cfg(test)]
mod tests {
    //! Round-trip and error-path tests for the ACP extension wire types.

    use agent_client_protocol::AuthMethodAgent;
    use serde_json::from_str;

    use super::*;

    #[test]
    fn method_constants_have_underscore_prefix() {
        assert!(SUB_AGENT_PROGRESS_METHOD.starts_with('_'));
        assert!(CONTEXT_USAGE_METHOD.starts_with('_'));
        assert!(CONTEXT_CLEARED_METHOD.starts_with('_'));
        assert!(MCP_MESSAGE_METHOD.starts_with('_'));
        assert!(AUTH_METHODS_UPDATED_METHOD.starts_with('_'));
    }

    #[test]
    fn mcp_request_authenticate_roundtrip() {
        let msg = McpRequest::Authenticate {
            session_id: "session-0".to_string(),
            server_name: "my oauth server".to_string(),
        };

        let notification: ExtNotification = msg.clone().into();
        assert_eq!(notification.method.as_ref(), MCP_MESSAGE_METHOD);

        let parsed: McpRequest = serde_json::from_str(notification.params.get()).expect("valid JSON");
        assert_eq!(parsed, msg);
    }

    #[test]
    fn mcp_notification_server_status_roundtrip() {
        let msg = McpNotification::ServerStatus {
            servers: vec![
                McpServerStatusEntry {
                    name: "github".to_string(),
                    status: McpServerStatus::Connected { tool_count: 5 },
                },
                McpServerStatusEntry { name: "linear".to_string(), status: McpServerStatus::NeedsOAuth },
                McpServerStatusEntry {
                    name: "slack".to_string(),
                    status: McpServerStatus::Failed { error: "connection timeout".to_string() },
                },
            ],
        };

        let notification: ExtNotification = msg.clone().into();
        assert_eq!(notification.method.as_ref(), MCP_MESSAGE_METHOD);

        let parsed: McpNotification = serde_json::from_str(notification.params.get()).expect("valid JSON");
        assert_eq!(parsed, msg);
    }

    #[test]
    fn auth_methods_updated_params_roundtrip() {
        let params = AuthMethodsUpdatedParams {
            auth_methods: vec![
                AuthMethod::Agent(AuthMethodAgent::new("anthropic", "Anthropic").description("authenticated")),
                AuthMethod::Agent(AuthMethodAgent::new("openrouter", "OpenRouter")),
            ],
        };

        let notification: ExtNotification = params.clone().into();
        let parsed: AuthMethodsUpdatedParams = from_str(notification.params.get()).expect("valid JSON");

        assert_eq!(parsed, params);
        assert_eq!(notification.method.as_ref(), AUTH_METHODS_UPDATED_METHOD);
    }

    #[test]
    fn mcp_server_status_entry_serde_roundtrip() {
        let entry = McpServerStatusEntry {
            name: "test-server".to_string(),
            status: McpServerStatus::Connected { tool_count: 3 },
        };

        let json = serde_json::to_string(&entry).unwrap();
        let parsed: McpServerStatusEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, entry);
    }

    #[test]
    fn elicitation_params_roundtrip() {
        use rmcp::model::EnumSchema;

        let params = ElicitationParams {
            message: "Pick a color".to_string(),
            schema: ElicitationSchema::builder()
                .required_enum_schema(
                    "color",
                    EnumSchema::builder(vec!["red".into(), "green".into(), "blue".into()]).untitled().build(),
                )
                .build()
                .unwrap(),
        };

        let json = serde_json::to_string(&params).unwrap();
        let parsed: ElicitationParams = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, params);
    }

    #[test]
    fn context_usage_params_roundtrip() {
        let params = ContextUsageParams { usage_ratio: Some(0.75), tokens_used: 75000, context_limit: Some(100_000) };

        let notification: ExtNotification = params.clone().into();
        assert_eq!(notification.method.as_ref(), CONTEXT_USAGE_METHOD);

        let parsed: ContextUsageParams = serde_json::from_str(notification.params.get()).expect("valid JSON");
        assert_eq!(parsed, params);
    }

    #[test]
    fn context_cleared_params_roundtrip() {
        let params = ContextClearedParams::default();

        let notification: ExtNotification = params.clone().into();
        assert_eq!(notification.method.as_ref(), CONTEXT_CLEARED_METHOD);

        let parsed: ContextClearedParams = serde_json::from_str(notification.params.get()).expect("valid JSON");
        assert_eq!(parsed, params);
    }

    #[test]
    fn sub_agent_progress_params_roundtrip() {
        let params = SubAgentProgressParams {
            parent_tool_id: "call_123".to_string(),
            task_id: "task_abc".to_string(),
            agent_name: "explorer".to_string(),
            event: SubAgentEvent::Done,
        };

        let notification: ExtNotification = params.into();
        assert_eq!(notification.method.as_ref(), SUB_AGENT_PROGRESS_METHOD);

        let parsed: SubAgentProgressParams = serde_json::from_str(notification.params.get()).expect("valid JSON");
        assert!(matches!(parsed.event, SubAgentEvent::Done));
        assert_eq!(parsed.parent_tool_id, "call_123");
    }

    #[test]
    fn deserialize_tool_call_event() {
        // As in the ToolCallUpdate case below, the extra "model_name" key mirrors
        // AgentMessage serialization and must be ignored.
        let json = r#"{"ToolCall":{"request":{"id":"c1","name":"grep","arguments":"{\"pattern\":\"test\"}"},"model_name":"m"}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        assert!(matches!(event, SubAgentEvent::ToolCall { .. }));
    }

    #[test]
    fn deserialize_tool_call_update_event() {
        // "model_name" is present because the wire format comes from AgentMessage serialization;
        // SubAgentEvent::ToolCallUpdate has no model_name field, so serde silently ignores it.
        let json = r#"{"ToolCallUpdate":{"update":{"id":"c1","chunk":"{\"pattern\":\"test\"}"},"model_name":"m"}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        assert!(matches!(event, SubAgentEvent::ToolCallUpdate { .. }));
    }

    #[test]
    fn deserialize_tool_result_event() {
        let json = r#"{"ToolResult":{"result":{"id":"c1","name":"grep","result_meta":{"display":{"title":"Grep","value":"'test' in src (3 matches)"}}}}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        match event {
            SubAgentEvent::ToolResult { result } => {
                let result_meta = result.result_meta.expect("expected result_meta");
                assert_eq!(result_meta.display.title, "Grep");
            }
            other => panic!("Expected ToolResult, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_tool_error_event() {
        let json = r#"{"ToolError":{"error":{"id":"c1","name":"grep"}}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        assert!(matches!(event, SubAgentEvent::ToolError { .. }));
    }

    #[test]
    fn deserialize_done_event() {
        let event: SubAgentEvent = serde_json::from_str(r#""Done""#).unwrap();
        assert!(matches!(event, SubAgentEvent::Done));
    }

    #[test]
    fn deserialize_other_variant() {
        let event: SubAgentEvent = serde_json::from_str(r#""Other""#).unwrap();
        assert!(matches!(event, SubAgentEvent::Other));
    }

    #[test]
    fn tool_result_meta_map_roundtrip() {
        let meta: ToolResultMeta = ToolDisplayMeta::new("Read file", "Cargo.toml, 156 lines").into();
        let map = meta.clone().into_map();
        let parsed = ToolResultMeta::from_map(&map).expect("should deserialize ToolResultMeta");
        assert_eq!(parsed, meta);
    }

    #[test]
    fn mcp_request_try_from_roundtrip() {
        let msg = McpRequest::Authenticate {
            session_id: "session-0".to_string(),
            server_name: "my oauth server".to_string(),
        };

        let notification: ExtNotification = msg.clone().into();
        let parsed = McpRequest::try_from(&notification).expect("should parse McpRequest");
        assert_eq!(parsed, msg);
    }

    #[test]
    fn mcp_notification_try_from_roundtrip() {
        let msg = McpNotification::ServerStatus {
            servers: vec![McpServerStatusEntry {
                name: "github".to_string(),
                status: McpServerStatus::Connected { tool_count: 5 },
            }],
        };

        let notification: ExtNotification = msg.clone().into();
        let parsed = McpNotification::try_from(&notification).expect("should parse McpNotification");
        assert_eq!(parsed, msg);
    }

    #[test]
    fn auth_methods_updated_try_from_roundtrip() {
        let params = AuthMethodsUpdatedParams {
            auth_methods: vec![AuthMethod::Agent(
                AuthMethodAgent::new("anthropic", "Anthropic").description("authenticated"),
            )],
        };

        let notification: ExtNotification = params.clone().into();
        let parsed = AuthMethodsUpdatedParams::try_from(&notification).expect("should parse auth methods");
        assert_eq!(parsed, params);
    }

    #[test]
    fn try_from_wrong_method_returns_error() {
        let notification = ext_notification(
            CONTEXT_USAGE_METHOD,
            &ContextUsageParams { usage_ratio: Some(0.5), tokens_used: 50000, context_limit: Some(100_000) },
        );

        let result = McpRequest::try_from(&notification);
        assert!(matches!(result, Err(ExtNotificationParseError::WrongMethod)));
    }

    #[test]
    fn mcp_notification_try_from_wrong_method_returns_error() {
        let notification: ExtNotification = ContextClearedParams::default().into();

        let result = McpNotification::try_from(&notification);
        assert!(matches!(result, Err(ExtNotificationParseError::WrongMethod)));
    }

    #[test]
    fn auth_methods_updated_try_from_wrong_method_returns_error() {
        // An `_aether/mcp` message must be rejected by method, not by payload shape.
        let notification: ExtNotification = McpRequest::Authenticate {
            session_id: "session-0".to_string(),
            server_name: "my oauth server".to_string(),
        }
        .into();

        let result = AuthMethodsUpdatedParams::try_from(&notification);
        assert!(matches!(result, Err(ExtNotificationParseError::WrongMethod)));
    }

    #[test]
    fn try_from_invalid_json_returns_error() {
        let notification = ext_notification(MCP_MESSAGE_METHOD, &"not a valid McpRequest");

        let result = McpRequest::try_from(&notification);
        assert!(matches!(result, Err(ExtNotificationParseError::InvalidJson(_))));
    }

    #[test]
    fn auth_methods_updated_try_from_invalid_json_returns_error() {
        let notification = ext_notification(AUTH_METHODS_UPDATED_METHOD, &"not an auth methods update");

        let result = AuthMethodsUpdatedParams::try_from(&notification);
        assert!(matches!(result, Err(ExtNotificationParseError::InvalidJson(_))));
    }

    #[test]
    fn ext_notification_parse_error_display() {
        let wrong = ExtNotificationParseError::WrongMethod;
        assert!(wrong.to_string().contains(MCP_MESSAGE_METHOD));

        let json_err = serde_json::from_str::<McpRequest>("{}").unwrap_err();
        let invalid = ExtNotificationParseError::InvalidJson(json_err);
        assert!(invalid.to_string().contains("invalid JSON"));
    }
}