Skip to main content

acp_utils/
notifications.rs

1//! Typed wire-format types for Aether's custom ACP extension requests and
2//! notifications.
3use agent_client_protocol::schema::AuthMethod;
4use agent_client_protocol::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
5pub use mcp_utils::display_meta::{ToolDisplayMeta, ToolResultMeta};
6pub use rmcp::model::CreateElicitationRequestParams;
7use serde::{Deserialize, Serialize};
8
9pub use mcp_utils::status::{McpServerAuthCapability, McpServerStatus, McpServerStatusEntry};
10
11/// Parameters for `_aether/context_usage` notifications.
12///
13/// Per-turn fields (`input_tokens`, `output_tokens`, `cache_read_tokens`,
14/// `cache_creation_tokens`, `reasoning_tokens`) come from the most recent
15/// API response. The `total_*` fields are cumulative across the agent's
16/// lifetime. The optional fields are `None` when the provider doesn't
17/// expose that dimension; this is semantically distinct from `Some(0)`.
18#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonRpcNotification)]
19#[notification(method = "_aether/context_usage")]
20pub struct ContextUsageParams {
21    pub usage_ratio: Option<f64>,
22    pub context_limit: Option<u32>,
23    pub input_tokens: u32,
24    #[serde(default)]
25    pub output_tokens: u32,
26    #[serde(default, skip_serializing_if = "Option::is_none")]
27    pub cache_read_tokens: Option<u32>,
28    #[serde(default, skip_serializing_if = "Option::is_none")]
29    pub cache_creation_tokens: Option<u32>,
30    #[serde(default, skip_serializing_if = "Option::is_none")]
31    pub reasoning_tokens: Option<u32>,
32    #[serde(default)]
33    pub total_input_tokens: u64,
34    #[serde(default)]
35    pub total_output_tokens: u64,
36    #[serde(default)]
37    pub total_cache_read_tokens: u64,
38    #[serde(default)]
39    pub total_cache_creation_tokens: u64,
40    #[serde(default)]
41    pub total_reasoning_tokens: u64,
42}
43
44/// Parameters for `_aether/context_cleared` notifications.
45#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default, JsonRpcNotification)]
46#[notification(method = "_aether/context_cleared")]
47pub struct ContextClearedParams {}
48
49/// Parameters for `_aether/auth_methods_updated` notifications.
50#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonRpcNotification)]
51#[notification(method = "_aether/auth_methods_updated")]
52pub struct AuthMethodsUpdatedParams {
53    pub auth_methods: Vec<AuthMethod>,
54}
55
56/// Request parameters for the `_aether/elicitation` ext method.
57///
58/// Carries the full RMCP elicitation request plus the originating server name
59/// so the client can distinguish form vs URL mode and display which server is
60/// requesting.
61#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonRpcRequest)]
62#[request(method = "_aether/elicitation", response = ElicitationResponse)]
63pub struct ElicitationParams {
64    pub server_name: String,
65    pub request: CreateElicitationRequestParams,
66}
67
68pub use rmcp::model::ElicitationAction;
69
70/// Response returned from the client for an elicitation request.
71#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonRpcResponse)]
72pub struct ElicitationResponse {
73    pub action: ElicitationAction,
74    /// Structured form data when action is "accept".
75    pub content: Option<serde_json::Value>,
76}
77
78pub use mcp_utils::client::UrlElicitationCompleteParams;
79
80/// Server→client MCP extension notifications (relay → wisp).
81#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonRpcNotification)]
82#[notification(method = "_aether/mcp_event")]
83pub enum McpNotification {
84    ServerStatus { servers: Vec<McpServerStatusEntry> },
85    UrlElicitationComplete(UrlElicitationCompleteParams),
86}
87
88/// Client→server MCP extension requests (wisp → relay).
89#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonRpcNotification)]
90#[notification(method = "_aether/mcp_request")]
91pub enum McpRequest {
92    Authenticate { session_id: String, server_name: String },
93}
94
95/// Parameters for `_aether/sub_agent_progress` notifications.
96///
97/// This is the wire format sent from the ACP server (`aether-cli`) to clients like `wisp`.
98#[derive(Debug, Clone, Serialize, Deserialize, JsonRpcNotification)]
99#[notification(method = "_aether/sub_agent_progress")]
100pub struct SubAgentProgressParams {
101    pub parent_tool_id: String,
102    pub task_id: String,
103    pub agent_name: String,
104    pub event: SubAgentEvent,
105}
106
107/// Subset of agent message variants relevant for sub-agent status display.
108///
109/// The ACP server (`aether-cli`) converts `AgentMessage` to this type before
110/// serializing, so the wire format only contains these known variants.
111#[derive(Debug, Clone, Serialize, Deserialize)]
112pub enum SubAgentEvent {
113    ToolCall { request: SubAgentToolRequest },
114    ToolCallUpdate { update: SubAgentToolCallUpdate },
115    ToolResult { result: SubAgentToolResult },
116    ToolError { error: SubAgentToolError },
117    Done,
118    Other,
119}
120
121#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct SubAgentToolRequest {
123    pub id: String,
124    pub name: String,
125    pub arguments: String,
126}
127
128#[derive(Debug, Clone, Serialize, Deserialize)]
129pub struct SubAgentToolCallUpdate {
130    pub id: String,
131    pub chunk: String,
132}
133
134#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct SubAgentToolResult {
136    pub id: String,
137    pub name: String,
138    pub result_meta: Option<ToolResultMeta>,
139}
140
141#[derive(Debug, Clone, Serialize, Deserialize)]
142pub struct SubAgentToolError {
143    pub id: String,
144    pub name: String,
145}
146
#[cfg(test)]
mod tests {
    use agent_client_protocol::JsonRpcMessage;
    use agent_client_protocol::schema::AuthMethodAgent;

    use super::*;

    #[test]
    fn wire_method_names_are_prefixed() {
        // `assert_eq!` (rather than `assert!(a == b)`) reports both sides on failure.
        assert_eq!(ContextClearedParams::default().method(), "_aether/context_cleared");
        assert_eq!(AuthMethodsUpdatedParams { auth_methods: vec![] }.method(), "_aether/auth_methods_updated");
        assert_eq!(McpNotification::ServerStatus { servers: vec![] }.method(), "_aether/mcp_event");
        assert_eq!(
            McpRequest::Authenticate { session_id: String::new(), server_name: String::new() }.method(),
            "_aether/mcp_request"
        );
    }

    #[test]
    fn context_usage_params_roundtrip() {
        let params = ContextUsageParams {
            usage_ratio: Some(0.75),
            context_limit: Some(100_000),
            input_tokens: 75_000,
            output_tokens: 1_200,
            cache_read_tokens: Some(40_000),
            cache_creation_tokens: Some(2_000),
            reasoning_tokens: Some(500),
            total_input_tokens: 200_000,
            total_output_tokens: 8_000,
            total_cache_read_tokens: 90_000,
            total_cache_creation_tokens: 5_000,
            total_reasoning_tokens: 1_500,
        };

        let untyped = params.to_untyped_message().expect("serializable");
        assert_eq!(untyped.method(), "_aether/context_usage");
        let parsed = ContextUsageParams::parse_message(untyped.method(), untyped.params()).expect("roundtrip");
        assert_eq!(parsed, params);
    }

    #[test]
    fn context_usage_params_omits_unset_optional_token_fields() {
        let params = ContextUsageParams {
            usage_ratio: Some(0.1),
            context_limit: Some(1_000),
            input_tokens: 100,
            output_tokens: 0,
            cache_read_tokens: None,
            cache_creation_tokens: None,
            reasoning_tokens: None,
            total_input_tokens: 0,
            total_output_tokens: 0,
            total_cache_read_tokens: 0,
            total_cache_creation_tokens: 0,
            total_reasoning_tokens: 0,
        };

        let raw = serde_json::to_string(&params).unwrap();
        assert!(!raw.contains("\"cache_read_tokens\""));
        assert!(!raw.contains("\"cache_creation_tokens\""));
        assert!(!raw.contains("\"reasoning_tokens\""));
    }

    #[test]
    fn context_cleared_params_roundtrip() {
        let params = ContextClearedParams::default();
        let untyped = params.to_untyped_message().expect("serializable");
        assert_eq!(untyped.method(), "_aether/context_cleared");
        let parsed = ContextClearedParams::parse_message(untyped.method(), untyped.params()).expect("roundtrip");
        assert_eq!(parsed, params);
    }

    #[test]
    fn auth_methods_updated_roundtrip() {
        let params = AuthMethodsUpdatedParams {
            auth_methods: vec![
                AuthMethod::Agent(AuthMethodAgent::new("anthropic", "Anthropic").description("authenticated")),
                AuthMethod::Agent(AuthMethodAgent::new("openrouter", "OpenRouter")),
            ],
        };

        let untyped = params.to_untyped_message().expect("serializable");
        assert_eq!(untyped.method(), "_aether/auth_methods_updated");
        let parsed = AuthMethodsUpdatedParams::parse_message(untyped.method(), untyped.params()).expect("roundtrip");
        assert_eq!(parsed, params);
    }

    #[test]
    fn mcp_request_authenticate_roundtrip() {
        let msg = McpRequest::Authenticate {
            session_id: "session-0".to_string(),
            server_name: "my oauth server".to_string(),
        };

        let untyped = msg.to_untyped_message().expect("serializable");
        assert_eq!(untyped.method(), "_aether/mcp_request");
        let parsed = McpRequest::parse_message(untyped.method(), untyped.params()).expect("roundtrip");
        assert_eq!(parsed, msg);
    }

    #[test]
    fn mcp_notification_server_status_roundtrip() {
        let msg = McpNotification::ServerStatus {
            servers: vec![
                McpServerStatusEntry::new("github", McpServerStatus::Connected { tool_count: 5 }),
                McpServerStatusEntry::new("linear", McpServerStatus::NeedsOAuth)
                    .with_auth_capability(McpServerAuthCapability::OAuth),
                McpServerStatusEntry::new("slack", McpServerStatus::Failed { error: "connection timeout".to_string() }),
            ],
        };

        let untyped = msg.to_untyped_message().expect("serializable");
        assert_eq!(untyped.method(), "_aether/mcp_event");
        let parsed = McpNotification::parse_message(untyped.method(), untyped.params()).expect("roundtrip");
        assert_eq!(parsed, msg);
    }

    #[test]
    fn mcp_notification_url_elicitation_complete_roundtrip() {
        let msg = McpNotification::UrlElicitationComplete(UrlElicitationCompleteParams {
            server_name: "github".to_string(),
            elicitation_id: "el-456".to_string(),
        });

        let untyped = msg.to_untyped_message().expect("serializable");
        // Every McpNotification variant shares one method name.
        assert_eq!(untyped.method(), "_aether/mcp_event");
        let parsed = McpNotification::parse_message(untyped.method(), untyped.params()).expect("roundtrip");
        assert_eq!(parsed, msg);
    }

    #[test]
    fn sub_agent_progress_params_roundtrip() {
        let params = SubAgentProgressParams {
            parent_tool_id: "call_123".to_string(),
            task_id: "task_abc".to_string(),
            agent_name: "explorer".to_string(),
            event: SubAgentEvent::Done,
        };

        let untyped = params.to_untyped_message().expect("serializable");
        assert_eq!(untyped.method(), "_aether/sub_agent_progress");

        // Parse back and compare field-by-field so the check is a genuine roundtrip.
        let parsed = SubAgentProgressParams::parse_message(untyped.method(), untyped.params()).expect("roundtrip");
        assert_eq!(parsed.parent_tool_id, params.parent_tool_id);
        assert_eq!(parsed.task_id, params.task_id);
        assert_eq!(parsed.agent_name, params.agent_name);
        assert!(matches!(parsed.event, SubAgentEvent::Done));
    }

    #[test]
    fn sub_agent_tool_result_event_serde_roundtrip() {
        let event = SubAgentEvent::ToolResult {
            result: SubAgentToolResult {
                id: "c1".to_string(),
                name: "grep".to_string(),
                result_meta: Some(ToolDisplayMeta::new("Grep", "'test' in src (3 matches)").into()),
            },
        };

        let json = serde_json::to_string(&event).unwrap();
        let parsed: SubAgentEvent = serde_json::from_str(&json).unwrap();
        match (parsed, event) {
            (SubAgentEvent::ToolResult { result: got }, SubAgentEvent::ToolResult { result: want }) => {
                assert_eq!(got.id, want.id);
                assert_eq!(got.name, want.name);
                assert_eq!(got.result_meta, want.result_meta);
            }
            (other, _) => panic!("Expected ToolResult, got {other:?}"),
        }
    }

    #[test]
    fn elicitation_params_roundtrip() {
        use rmcp::model::{ElicitationSchema, EnumSchema};

        let params = ElicitationParams {
            server_name: "github".to_string(),
            request: CreateElicitationRequestParams::FormElicitationParams {
                meta: None,
                message: "Pick a color".to_string(),
                requested_schema: ElicitationSchema::builder()
                    .required_enum_schema(
                        "color",
                        EnumSchema::builder(vec!["red".into(), "green".into(), "blue".into()]).untitled().build(),
                    )
                    .build()
                    .unwrap(),
            },
        };

        let untyped = params.to_untyped_message().expect("serializable");
        assert_eq!(untyped.method(), "_aether/elicitation");
        let parsed = ElicitationParams::parse_message(untyped.method(), untyped.params()).expect("roundtrip");
        assert_eq!(parsed, params);
    }

    #[test]
    fn elicitation_params_url_variant_has_mode_field() {
        let params = ElicitationParams {
            server_name: "github".to_string(),
            request: CreateElicitationRequestParams::UrlElicitationParams {
                meta: None,
                message: "Authorize GitHub".to_string(),
                url: "https://github.com/login/oauth".to_string(),
                elicitation_id: "el-123".to_string(),
            },
        };

        let json = serde_json::to_string(&params).unwrap();
        assert!(json.contains("\"mode\":\"url\""));
        assert!(json.contains("\"server_name\":\"github\""));
    }

    #[test]
    fn mcp_server_status_entry_serde_roundtrip() {
        let entry = McpServerStatusEntry::new("test-server", McpServerStatus::Connected { tool_count: 3 })
            .with_auth_capability(McpServerAuthCapability::OAuth);

        let json = serde_json::to_string(&entry).unwrap();
        assert!(json.contains("\"auth_capability\":\"OAuth\""));
        assert!(json.contains("\"proxied\":false"));
        let parsed: McpServerStatusEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, entry);
        assert!(!parsed.proxied);
        assert!(parsed.can_authenticate());
    }

    #[test]
    fn mcp_server_status_entry_proxied_serde_roundtrip() {
        let entry = McpServerStatusEntry::new("math", McpServerStatus::NeedsOAuth)
            .with_auth_capability(McpServerAuthCapability::OAuth)
            .with_proxied(true);

        let json = serde_json::to_string(&entry).unwrap();
        assert!(json.contains("\"proxied\":true"));
        let parsed: McpServerStatusEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, entry);
    }

    #[test]
    fn deserialize_tool_call_event() {
        let json = r#"{"ToolCall":{"request":{"id":"c1","name":"grep","arguments":"{\"pattern\":\"test\"}"},"model_name":"m"}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        assert!(matches!(event, SubAgentEvent::ToolCall { .. }));
    }

    #[test]
    fn deserialize_tool_call_update_event() {
        let json = r#"{"ToolCallUpdate":{"update":{"id":"c1","chunk":"{\"pattern\":\"test\"}"},"model_name":"m"}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        assert!(matches!(event, SubAgentEvent::ToolCallUpdate { .. }));
    }

    #[test]
    fn deserialize_tool_result_event() {
        let json = r#"{"ToolResult":{"result":{"id":"c1","name":"grep","result_meta":{"display":{"title":"Grep","value":"'test' in src (3 matches)"}}}}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        match event {
            SubAgentEvent::ToolResult { result } => {
                let result_meta = result.result_meta.expect("expected result_meta");
                assert_eq!(result_meta.display.title, "Grep");
            }
            other => panic!("Expected ToolResult, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_tool_error_event() {
        let json = r#"{"ToolError":{"error":{"id":"c1","name":"grep"}}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        assert!(matches!(event, SubAgentEvent::ToolError { .. }));
    }

    #[test]
    fn deserialize_done_event() {
        let event: SubAgentEvent = serde_json::from_str(r#""Done""#).unwrap();
        assert!(matches!(event, SubAgentEvent::Done));
    }

    #[test]
    fn deserialize_other_variant() {
        let event: SubAgentEvent = serde_json::from_str(r#""Other""#).unwrap();
        assert!(matches!(event, SubAgentEvent::Other));
    }

    #[test]
    fn tool_result_meta_map_roundtrip() {
        let meta: ToolResultMeta = ToolDisplayMeta::new("Read file", "Cargo.toml, 156 lines").into();
        let map = meta.clone().into_map();
        let parsed = ToolResultMeta::from_map(&map).expect("should deserialize ToolResultMeta");
        assert_eq!(parsed, meta);
    }
}