Skip to main content

acp_utils/
notifications.rs

1//! Typed wire-format types for Aether's custom ACP extension requests and
2//! notifications.
3use agent_client_protocol::schema::AuthMethod;
4use agent_client_protocol::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
5pub use mcp_utils::display_meta::{ToolDisplayMeta, ToolResultMeta};
6pub use rmcp::model::CreateElicitationRequestParams;
7use serde::{Deserialize, Serialize};
8
9pub use mcp_utils::status::{McpServerStatus, McpServerStatusEntry};
10
11/// Parameters for `_aether/context_usage` notifications.
12///
13/// Per-turn fields (`input_tokens`, `output_tokens`, `cache_read_tokens`,
14/// `cache_creation_tokens`, `reasoning_tokens`) come from the most recent
15/// API response. The `total_*` fields are cumulative across the agent's
16/// lifetime. The optional fields are `None` when the provider doesn't
17/// expose that dimension; this is semantically distinct from `Some(0)`.
18#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonRpcNotification)]
19#[notification(method = "_aether/context_usage")]
20pub struct ContextUsageParams {
21    pub usage_ratio: Option<f64>,
22    pub context_limit: Option<u32>,
23    pub input_tokens: u32,
24    #[serde(default)]
25    pub output_tokens: u32,
26    #[serde(default, skip_serializing_if = "Option::is_none")]
27    pub cache_read_tokens: Option<u32>,
28    #[serde(default, skip_serializing_if = "Option::is_none")]
29    pub cache_creation_tokens: Option<u32>,
30    #[serde(default, skip_serializing_if = "Option::is_none")]
31    pub reasoning_tokens: Option<u32>,
32    #[serde(default)]
33    pub total_input_tokens: u64,
34    #[serde(default)]
35    pub total_output_tokens: u64,
36    #[serde(default)]
37    pub total_cache_read_tokens: u64,
38    #[serde(default)]
39    pub total_cache_creation_tokens: u64,
40    #[serde(default)]
41    pub total_reasoning_tokens: u64,
42}
43
44/// Parameters for `_aether/context_cleared` notifications.
45#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default, JsonRpcNotification)]
46#[notification(method = "_aether/context_cleared")]
47pub struct ContextClearedParams {}
48
49/// Parameters for `_aether/auth_methods_updated` notifications.
50#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonRpcNotification)]
51#[notification(method = "_aether/auth_methods_updated")]
52pub struct AuthMethodsUpdatedParams {
53    pub auth_methods: Vec<AuthMethod>,
54}
55
56/// Request parameters for the `_aether/elicitation` ext method.
57///
58/// Carries the full RMCP elicitation request plus the originating server name
59/// so the client can distinguish form vs URL mode and display which server is
60/// requesting.
61#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonRpcRequest)]
62#[request(method = "_aether/elicitation", response = ElicitationResponse)]
63pub struct ElicitationParams {
64    pub server_name: String,
65    pub request: CreateElicitationRequestParams,
66}
67
68pub use rmcp::model::ElicitationAction;
69
70/// Response returned from the client for an elicitation request.
71#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonRpcResponse)]
72pub struct ElicitationResponse {
73    pub action: ElicitationAction,
74    /// Structured form data when action is "accept".
75    pub content: Option<serde_json::Value>,
76}
77
78pub use mcp_utils::client::UrlElicitationCompleteParams;
79
80/// Server→client MCP extension notifications (relay → wisp).
81#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonRpcNotification)]
82#[notification(method = "_aether/mcp_event")]
83pub enum McpNotification {
84    ServerStatus { servers: Vec<McpServerStatusEntry> },
85    UrlElicitationComplete(UrlElicitationCompleteParams),
86}
87
88/// Client→server MCP extension requests (wisp → relay).
89#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonRpcNotification)]
90#[notification(method = "_aether/mcp_request")]
91pub enum McpRequest {
92    Authenticate { session_id: String, server_name: String },
93}
94
95/// Parameters for `_aether/sub_agent_progress` notifications.
96///
97/// This is the wire format sent from the ACP server (`aether-cli`) to clients like `wisp`.
98#[derive(Debug, Clone, Serialize, Deserialize, JsonRpcNotification)]
99#[notification(method = "_aether/sub_agent_progress")]
100pub struct SubAgentProgressParams {
101    pub parent_tool_id: String,
102    pub task_id: String,
103    pub agent_name: String,
104    pub event: SubAgentEvent,
105}
106
107/// Subset of agent message variants relevant for sub-agent status display.
108///
109/// The ACP server (`aether-cli`) converts `AgentMessage` to this type before
110/// serializing, so the wire format only contains these known variants.
111#[derive(Debug, Clone, Serialize, Deserialize)]
112pub enum SubAgentEvent {
113    ToolCall { request: SubAgentToolRequest },
114    ToolCallUpdate { update: SubAgentToolCallUpdate },
115    ToolResult { result: SubAgentToolResult },
116    ToolError { error: SubAgentToolError },
117    Done,
118    Other,
119}
120
121#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct SubAgentToolRequest {
123    pub id: String,
124    pub name: String,
125    pub arguments: String,
126}
127
128#[derive(Debug, Clone, Serialize, Deserialize)]
129pub struct SubAgentToolCallUpdate {
130    pub id: String,
131    pub chunk: String,
132}
133
134#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct SubAgentToolResult {
136    pub id: String,
137    pub name: String,
138    pub result_meta: Option<ToolResultMeta>,
139}
140
141#[derive(Debug, Clone, Serialize, Deserialize)]
142pub struct SubAgentToolError {
143    pub id: String,
144    pub name: String,
145}
146
#[cfg(test)]
mod tests {
    use agent_client_protocol::JsonRpcMessage;
    use agent_client_protocol::schema::AuthMethodAgent;

    use super::*;

    #[test]
    fn wire_method_names_are_prefixed() {
        // `assert_eq!` (rather than `assert!(a == b)`) reports both sides on failure.
        assert_eq!(ContextClearedParams::default().method(), "_aether/context_cleared");
        assert_eq!(AuthMethodsUpdatedParams { auth_methods: vec![] }.method(), "_aether/auth_methods_updated");
        assert_eq!(McpNotification::ServerStatus { servers: vec![] }.method(), "_aether/mcp_event");
        assert_eq!(
            McpRequest::Authenticate { session_id: String::new(), server_name: String::new() }.method(),
            "_aether/mcp_request"
        );
    }

    #[test]
    fn context_usage_params_roundtrip() {
        let params = ContextUsageParams {
            usage_ratio: Some(0.75),
            context_limit: Some(100_000),
            input_tokens: 75_000,
            output_tokens: 1_200,
            cache_read_tokens: Some(40_000),
            cache_creation_tokens: Some(2_000),
            reasoning_tokens: Some(500),
            total_input_tokens: 200_000,
            total_output_tokens: 8_000,
            total_cache_read_tokens: 90_000,
            total_cache_creation_tokens: 5_000,
            total_reasoning_tokens: 1_500,
        };

        let untyped = params.to_untyped_message().expect("serializable");
        assert_eq!(untyped.method(), "_aether/context_usage");
        let parsed = ContextUsageParams::parse_message(untyped.method(), untyped.params()).expect("roundtrip");
        assert_eq!(parsed, params);
    }

    #[test]
    fn context_usage_params_omits_unset_optional_token_fields() {
        let params = ContextUsageParams {
            usage_ratio: Some(0.1),
            context_limit: Some(1_000),
            input_tokens: 100,
            output_tokens: 0,
            cache_read_tokens: None,
            cache_creation_tokens: None,
            reasoning_tokens: None,
            total_input_tokens: 0,
            total_output_tokens: 0,
            total_cache_read_tokens: 0,
            total_cache_creation_tokens: 0,
            total_reasoning_tokens: 0,
        };

        let raw = serde_json::to_string(&params).unwrap();
        assert!(!raw.contains("\"cache_read_tokens\""));
        assert!(!raw.contains("\"cache_creation_tokens\""));
        assert!(!raw.contains("\"reasoning_tokens\""));
    }

    #[test]
    fn context_cleared_params_roundtrip() {
        let params = ContextClearedParams::default();
        let untyped = params.to_untyped_message().expect("serializable");
        assert_eq!(untyped.method(), "_aether/context_cleared");
        let parsed = ContextClearedParams::parse_message(untyped.method(), untyped.params()).expect("roundtrip");
        assert_eq!(parsed, params);
    }

    #[test]
    fn auth_methods_updated_roundtrip() {
        let params = AuthMethodsUpdatedParams {
            auth_methods: vec![
                AuthMethod::Agent(AuthMethodAgent::new("anthropic", "Anthropic").description("authenticated")),
                AuthMethod::Agent(AuthMethodAgent::new("openrouter", "OpenRouter")),
            ],
        };

        let untyped = params.to_untyped_message().expect("serializable");
        assert_eq!(untyped.method(), "_aether/auth_methods_updated");
        let parsed = AuthMethodsUpdatedParams::parse_message(untyped.method(), untyped.params()).expect("roundtrip");
        assert_eq!(parsed, params);
    }

    #[test]
    fn mcp_request_authenticate_roundtrip() {
        let msg = McpRequest::Authenticate {
            session_id: "session-0".to_string(),
            server_name: "my oauth server".to_string(),
        };

        let untyped = msg.to_untyped_message().expect("serializable");
        assert_eq!(untyped.method(), "_aether/mcp_request");
        let parsed = McpRequest::parse_message(untyped.method(), untyped.params()).expect("roundtrip");
        assert_eq!(parsed, msg);
    }

    #[test]
    fn mcp_notification_server_status_roundtrip() {
        let msg = McpNotification::ServerStatus {
            servers: vec![
                McpServerStatusEntry {
                    name: "github".to_string(),
                    status: McpServerStatus::Connected { tool_count: 5 },
                },
                McpServerStatusEntry { name: "linear".to_string(), status: McpServerStatus::NeedsOAuth },
                McpServerStatusEntry {
                    name: "slack".to_string(),
                    status: McpServerStatus::Failed { error: "connection timeout".to_string() },
                },
            ],
        };

        let untyped = msg.to_untyped_message().expect("serializable");
        assert_eq!(untyped.method(), "_aether/mcp_event");
        let parsed = McpNotification::parse_message(untyped.method(), untyped.params()).expect("roundtrip");
        assert_eq!(parsed, msg);
    }

    #[test]
    fn mcp_notification_url_elicitation_complete_roundtrip() {
        let msg = McpNotification::UrlElicitationComplete(UrlElicitationCompleteParams {
            server_name: "github".to_string(),
            elicitation_id: "el-456".to_string(),
        });

        let untyped = msg.to_untyped_message().expect("serializable");
        assert_eq!(untyped.method(), "_aether/mcp_event");
        let parsed = McpNotification::parse_message(untyped.method(), untyped.params()).expect("roundtrip");
        assert_eq!(parsed, msg);
    }

    #[test]
    fn sub_agent_progress_params_roundtrip() {
        let params = SubAgentProgressParams {
            parent_tool_id: "call_123".to_string(),
            task_id: "task_abc".to_string(),
            agent_name: "explorer".to_string(),
            event: SubAgentEvent::Done,
        };

        let untyped = params.to_untyped_message().expect("serializable");
        assert_eq!(untyped.method(), "_aether/sub_agent_progress");
        let parsed = SubAgentProgressParams::parse_message(untyped.method(), untyped.params()).expect("roundtrip");
        // Compare through JSON so this check holds regardless of which traits the
        // sub-agent types derive.
        assert_eq!(serde_json::to_value(&parsed).unwrap(), serde_json::to_value(&params).unwrap());
    }

    #[test]
    fn sub_agent_progress_params_tool_result_roundtrip() {
        let params = SubAgentProgressParams {
            parent_tool_id: "call_456".to_string(),
            task_id: "task_def".to_string(),
            agent_name: "explorer".to_string(),
            event: SubAgentEvent::ToolResult {
                result: SubAgentToolResult {
                    id: "c1".to_string(),
                    name: "grep".to_string(),
                    result_meta: Some(ToolDisplayMeta::new("Grep", "'test' in src (3 matches)").into()),
                },
            },
        };

        let untyped = params.to_untyped_message().expect("serializable");
        let parsed = SubAgentProgressParams::parse_message(untyped.method(), untyped.params()).expect("roundtrip");
        assert_eq!(serde_json::to_value(&parsed).unwrap(), serde_json::to_value(&params).unwrap());
    }

    #[test]
    fn elicitation_params_roundtrip() {
        use rmcp::model::{ElicitationSchema, EnumSchema};

        let params = ElicitationParams {
            server_name: "github".to_string(),
            request: CreateElicitationRequestParams::FormElicitationParams {
                meta: None,
                message: "Pick a color".to_string(),
                requested_schema: ElicitationSchema::builder()
                    .required_enum_schema(
                        "color",
                        EnumSchema::builder(vec!["red".into(), "green".into(), "blue".into()]).untitled().build(),
                    )
                    .build()
                    .unwrap(),
            },
        };

        let untyped = params.to_untyped_message().expect("serializable");
        assert_eq!(untyped.method(), "_aether/elicitation");
        let parsed = ElicitationParams::parse_message(untyped.method(), untyped.params()).expect("roundtrip");
        assert_eq!(parsed, params);
    }

    #[test]
    fn elicitation_params_url_variant_has_mode_field() {
        let params = ElicitationParams {
            server_name: "github".to_string(),
            request: CreateElicitationRequestParams::UrlElicitationParams {
                meta: None,
                message: "Authorize GitHub".to_string(),
                url: "https://github.com/login/oauth".to_string(),
                elicitation_id: "el-123".to_string(),
            },
        };

        let json = serde_json::to_string(&params).unwrap();
        assert!(json.contains("\"mode\":\"url\""));
        assert!(json.contains("\"server_name\":\"github\""));
    }

    #[test]
    fn elicitation_response_roundtrip() {
        let response = ElicitationResponse {
            action: ElicitationAction::Accept,
            content: Some(serde_json::json!({ "color": "red" })),
        };

        let json = serde_json::to_string(&response).unwrap();
        let parsed: ElicitationResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, response);
    }

    #[test]
    fn elicitation_response_missing_content_is_none() {
        let mut value =
            serde_json::to_value(ElicitationResponse { action: ElicitationAction::Accept, content: None }).unwrap();
        value.as_object_mut().expect("struct serializes as a JSON object").remove("content");

        let parsed: ElicitationResponse = serde_json::from_value(value).unwrap();
        assert_eq!(parsed.content, None);
    }

    #[test]
    fn mcp_server_status_entry_serde_roundtrip() {
        let entry = McpServerStatusEntry {
            name: "test-server".to_string(),
            status: McpServerStatus::Connected { tool_count: 3 },
        };

        let json = serde_json::to_string(&entry).unwrap();
        let parsed: McpServerStatusEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, entry);
    }

    #[test]
    fn deserialize_tool_call_event() {
        let json = r#"{"ToolCall":{"request":{"id":"c1","name":"grep","arguments":"{\"pattern\":\"test\"}"},"model_name":"m"}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        assert!(matches!(event, SubAgentEvent::ToolCall { .. }));
    }

    #[test]
    fn deserialize_tool_call_update_event() {
        let json = r#"{"ToolCallUpdate":{"update":{"id":"c1","chunk":"{\"pattern\":\"test\"}"},"model_name":"m"}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        assert!(matches!(event, SubAgentEvent::ToolCallUpdate { .. }));
    }

    #[test]
    fn deserialize_tool_result_event() {
        let json = r#"{"ToolResult":{"result":{"id":"c1","name":"grep","result_meta":{"display":{"title":"Grep","value":"'test' in src (3 matches)"}}}}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        match event {
            SubAgentEvent::ToolResult { result } => {
                let result_meta = result.result_meta.expect("expected result_meta");
                assert_eq!(result_meta.display.title, "Grep");
            }
            other => panic!("Expected ToolResult, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_tool_error_event() {
        let json = r#"{"ToolError":{"error":{"id":"c1","name":"grep"}}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        assert!(matches!(event, SubAgentEvent::ToolError { .. }));
    }

    #[test]
    fn deserialize_done_event() {
        let event: SubAgentEvent = serde_json::from_str(r#""Done""#).unwrap();
        assert!(matches!(event, SubAgentEvent::Done));
    }

    #[test]
    fn deserialize_other_variant() {
        let event: SubAgentEvent = serde_json::from_str(r#""Other""#).unwrap();
        assert!(matches!(event, SubAgentEvent::Other));
    }

    #[test]
    fn tool_result_meta_map_roundtrip() {
        let meta: ToolResultMeta = ToolDisplayMeta::new("Read file", "Cargo.toml, 156 lines").into();
        let map = meta.clone().into_map();
        let parsed = ToolResultMeta::from_map(&map).expect("should deserialize ToolResultMeta");
        assert_eq!(parsed, meta);
    }
}