Skip to main content

acp_utils/
notifications.rs

1//! Typed wire-format types for Aether's custom ACP extension requests and
2//! notifications.
3use std::path::PathBuf;
4
5use agent_client_protocol::schema::AuthMethod;
6use agent_client_protocol::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
7pub use mcp_utils::display_meta::{ToolDisplayMeta, ToolResultMeta};
8pub use rmcp::model::CreateElicitationRequestParams;
9use serde::{Deserialize, Serialize};
10
11pub use mcp_utils::status::{McpServerAuthCapability, McpServerStatus, McpServerStatusEntry};
12
/// Key under which Aether publishes its private entries inside ACP `_meta` maps.
pub const AETHER_META_NAMESPACE: &str = "contextbridge/aether";

/// Key, inside the [`AETHER_META_NAMESPACE`] object, whose boolean value
/// advertises support for `_aether/prompt_search`.
pub const PROMPT_SEARCH_CAPABILITY_KEY: &str = "promptSearch";
15
16/// Parameters for `_aether/context_usage` notifications.
17///
18/// Per-turn fields (`input_tokens`, `output_tokens`, `cache_read_tokens`,
19/// `cache_creation_tokens`, `reasoning_tokens`) come from the most recent
20/// API response. The `total_*` fields are cumulative across the agent's
21/// lifetime. The optional fields are `None` when the provider doesn't
22/// expose that dimension; this is semantically distinct from `Some(0)`.
23#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonRpcNotification)]
24#[notification(method = "_aether/context_usage")]
25pub struct ContextUsageParams {
26    pub usage_ratio: Option<f64>,
27    pub context_limit: Option<u32>,
28    pub input_tokens: u32,
29    #[serde(default)]
30    pub output_tokens: u32,
31    #[serde(default, skip_serializing_if = "Option::is_none")]
32    pub cache_read_tokens: Option<u32>,
33    #[serde(default, skip_serializing_if = "Option::is_none")]
34    pub cache_creation_tokens: Option<u32>,
35    #[serde(default, skip_serializing_if = "Option::is_none")]
36    pub reasoning_tokens: Option<u32>,
37    #[serde(default)]
38    pub total_input_tokens: u64,
39    #[serde(default)]
40    pub total_output_tokens: u64,
41    #[serde(default)]
42    pub total_cache_read_tokens: u64,
43    #[serde(default)]
44    pub total_cache_creation_tokens: u64,
45    #[serde(default)]
46    pub total_reasoning_tokens: u64,
47}
48
49/// Parameters for `_aether/context_cleared` notifications.
50#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default, JsonRpcNotification)]
51#[notification(method = "_aether/context_cleared")]
52pub struct ContextClearedParams {}
53
54/// Parameters for `_aether/auth_methods_updated` notifications.
55#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonRpcNotification)]
56#[notification(method = "_aether/auth_methods_updated")]
57pub struct AuthMethodsUpdatedParams {
58    pub auth_methods: Vec<AuthMethod>,
59}
60
61/// Request parameters for the `_aether/elicitation` ext method.
62///
63/// Carries the full RMCP elicitation request plus the originating server name
64/// so the client can distinguish form vs URL mode and display which server is
65/// requesting.
66#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonRpcRequest)]
67#[request(method = "_aether/elicitation", response = ElicitationResponse)]
68pub struct ElicitationParams {
69    pub server_name: String,
70    pub request: CreateElicitationRequestParams,
71}
72
73pub use rmcp::model::ElicitationAction;
74
75/// Parameters for the `_aether/prompt_search` request.
76#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonRpcRequest)]
77#[request(method = "_aether/prompt_search", response = PromptSearchResponse)]
78#[serde(rename_all = "camelCase")]
79pub struct PromptSearchParams {
80    pub query: String,
81    #[serde(default, skip_serializing_if = "Option::is_none")]
82    pub limit: Option<usize>,
83}
84
85/// Response for the `_aether/prompt_search` request.
86#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonRpcResponse)]
87#[serde(rename_all = "camelCase")]
88pub struct PromptSearchResponse {
89    pub query: String,
90    pub results: Vec<PromptSearchResult>,
91    pub truncated: bool,
92}
93
94/// A single prompt-history search hit.
95///
96/// `match_start` and `match_end` are UTF-8 byte offsets into `prompt` and are
97/// guaranteed to fall on char boundaries.
98#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
99#[serde(rename_all = "camelCase")]
100pub struct PromptSearchResult {
101    pub session_id: String,
102    pub cwd: PathBuf,
103    pub session_created_at: String,
104    pub prompt: String,
105    pub match_start: usize,
106    pub match_end: usize,
107}
108
109/// Metadata advertised on `PromptCapabilities::_meta` when the agent supports
110/// `_aether/prompt_search`.
111pub mod prompt_search_capability {
112    use super::{AETHER_META_NAMESPACE, PROMPT_SEARCH_CAPABILITY_KEY};
113    use agent_client_protocol::schema::Meta;
114    use serde_json::{Value, json};
115
116    #[must_use]
117    pub fn to_meta() -> Meta {
118        let mut meta = Meta::new();
119        meta.insert(AETHER_META_NAMESPACE.to_string(), json!({ PROMPT_SEARCH_CAPABILITY_KEY: true }));
120        meta
121    }
122
123    #[must_use]
124    pub fn is_advertised(meta: Option<&Meta>) -> bool {
125        meta.and_then(|m| m.get(AETHER_META_NAMESPACE))
126            .and_then(Value::as_object)
127            .and_then(|aether| aether.get(PROMPT_SEARCH_CAPABILITY_KEY))
128            .and_then(Value::as_bool)
129            .unwrap_or(false)
130    }
131}
132
133/// Response returned from the client for an elicitation request.
134#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonRpcResponse)]
135pub struct ElicitationResponse {
136    pub action: ElicitationAction,
137    /// Structured form data when action is "accept".
138    pub content: Option<serde_json::Value>,
139}
140
141pub use mcp_utils::client::UrlElicitationCompleteParams;
142
143/// Server→client MCP extension notifications (relay → wisp).
144#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonRpcNotification)]
145#[notification(method = "_aether/mcp_event")]
146pub enum McpNotification {
147    ServerStatus { servers: Vec<McpServerStatusEntry> },
148    UrlElicitationComplete(UrlElicitationCompleteParams),
149}
150
151/// Client→server MCP extension requests (wisp → relay).
152#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, JsonRpcNotification)]
153#[notification(method = "_aether/mcp_request")]
154pub enum McpRequest {
155    Authenticate { session_id: String, server_name: String },
156}
157
158/// Parameters for `_aether/sub_agent_progress` notifications.
159///
160/// This is the wire format sent from the ACP server (`aether-cli`) to clients like `wisp`.
161#[derive(Debug, Clone, Serialize, Deserialize, JsonRpcNotification)]
162#[notification(method = "_aether/sub_agent_progress")]
163pub struct SubAgentProgressParams {
164    pub parent_tool_id: String,
165    pub task_id: String,
166    pub agent_name: String,
167    pub event: SubAgentEvent,
168}
169
170/// Subset of agent message variants relevant for sub-agent status display.
171///
172/// The ACP server (`aether-cli`) converts `AgentMessage` to this type before
173/// serializing, so the wire format only contains these known variants.
174#[derive(Debug, Clone, Serialize, Deserialize)]
175pub enum SubAgentEvent {
176    ToolCall { request: SubAgentToolRequest },
177    ToolCallUpdate { update: SubAgentToolCallUpdate },
178    ToolResult { result: SubAgentToolResult },
179    ToolError { error: SubAgentToolError },
180    Done,
181    Other,
182}
183
184#[derive(Debug, Clone, Serialize, Deserialize)]
185pub struct SubAgentToolRequest {
186    pub id: String,
187    pub name: String,
188    pub arguments: String,
189}
190
191#[derive(Debug, Clone, Serialize, Deserialize)]
192pub struct SubAgentToolCallUpdate {
193    pub id: String,
194    pub chunk: String,
195}
196
197#[derive(Debug, Clone, Serialize, Deserialize)]
198pub struct SubAgentToolResult {
199    pub id: String,
200    pub name: String,
201    pub result_meta: Option<ToolResultMeta>,
202}
203
204#[derive(Debug, Clone, Serialize, Deserialize)]
205pub struct SubAgentToolError {
206    pub id: String,
207    pub name: String,
208}
209
#[cfg(test)]
mod tests {
    // Wire-format tests for the Aether extension messages:
    // - method names;
    // - serde round trips through the typed JSON-RPC helpers;
    // - field naming and omission rules.

    use agent_client_protocol::JsonRpcMessage;
    use agent_client_protocol::schema::AuthMethodAgent;

    use super::*;

    #[test]
    fn wire_method_names_are_prefixed() {
        assert_eq!(ContextClearedParams::default().method(), "_aether/context_cleared");
        assert_eq!(AuthMethodsUpdatedParams { auth_methods: vec![] }.method(), "_aether/auth_methods_updated");
        assert_eq!(McpNotification::ServerStatus { servers: vec![] }.method(), "_aether/mcp_event");
        assert_eq!(
            McpRequest::Authenticate { session_id: String::new(), server_name: String::new() }.method(),
            "_aether/mcp_request"
        );
        assert_eq!(PromptSearchParams { query: String::new(), limit: None }.method(), "_aether/prompt_search");
    }

    #[test]
    fn prompt_search_capability_meta_roundtrip() {
        let meta = prompt_search_capability::to_meta();
        assert!(prompt_search_capability::is_advertised(Some(&meta)));
        assert!(!prompt_search_capability::is_advertised(None));
        assert!(!prompt_search_capability::is_advertised(Some(&agent_client_protocol::schema::Meta::new())));
    }

    #[test]
    fn prompt_search_params_omit_unset_limit() {
        let params = PromptSearchParams { query: "fix".to_string(), limit: None };

        let raw = serde_json::to_string(&params).unwrap();
        assert_eq!(raw, r#"{"query":"fix"}"#);

        let parsed: PromptSearchParams = serde_json::from_str(&raw).unwrap();
        assert_eq!(parsed, params);
    }

    #[test]
    fn prompt_search_response_uses_camel_case() {
        let response = PromptSearchResponse {
            query: "fix".to_string(),
            results: vec![PromptSearchResult {
                session_id: "s1".to_string(),
                cwd: std::path::PathBuf::from("/tmp/project"),
                session_created_at: "2024-01-01T00:00:00Z".to_string(),
                prompt: "fix the bug".to_string(),
                match_start: 0,
                match_end: 3,
            }],
            truncated: false,
        };

        let raw = serde_json::to_string(&response).unwrap();
        assert!(raw.contains("\"sessionId\":\"s1\""));
        assert!(raw.contains("\"sessionCreatedAt\""));
        assert!(raw.contains("\"matchEnd\":3"));

        let parsed: PromptSearchResponse = serde_json::from_str(&raw).unwrap();
        assert_eq!(parsed, response);
    }

    #[test]
    fn context_usage_params_roundtrip() {
        let params = ContextUsageParams {
            usage_ratio: Some(0.75),
            context_limit: Some(100_000),
            input_tokens: 75_000,
            output_tokens: 1_200,
            cache_read_tokens: Some(40_000),
            cache_creation_tokens: Some(2_000),
            reasoning_tokens: Some(500),
            total_input_tokens: 200_000,
            total_output_tokens: 8_000,
            total_cache_read_tokens: 90_000,
            total_cache_creation_tokens: 5_000,
            total_reasoning_tokens: 1_500,
        };

        let untyped = params.to_untyped_message().expect("serializable");
        assert_eq!(untyped.method(), "_aether/context_usage");
        let parsed = ContextUsageParams::parse_message(untyped.method(), untyped.params()).expect("roundtrip");
        assert_eq!(parsed, params);
    }

    #[test]
    fn context_usage_params_omits_unset_optional_token_fields() {
        let params = ContextUsageParams {
            usage_ratio: Some(0.1),
            context_limit: Some(1_000),
            input_tokens: 100,
            output_tokens: 0,
            cache_read_tokens: None,
            cache_creation_tokens: None,
            reasoning_tokens: None,
            total_input_tokens: 0,
            total_output_tokens: 0,
            total_cache_read_tokens: 0,
            total_cache_creation_tokens: 0,
            total_reasoning_tokens: 0,
        };

        let raw = serde_json::to_string(&params).unwrap();
        assert!(!raw.contains("\"cache_read_tokens\""));
        assert!(!raw.contains("\"cache_creation_tokens\""));
        assert!(!raw.contains("\"reasoning_tokens\""));

        // Omitted keys must read back as `None`, not fail or become `Some(0)`.
        let parsed: ContextUsageParams = serde_json::from_str(&raw).unwrap();
        assert_eq!(parsed, params);
    }

    #[test]
    fn context_usage_params_defaults_missing_fields() {
        // Only `input_tokens` is mandatory; everything else defaults.
        let parsed: ContextUsageParams = serde_json::from_str(r#"{"input_tokens":42}"#).unwrap();
        assert_eq!(parsed.usage_ratio, None);
        assert_eq!(parsed.context_limit, None);
        assert_eq!(parsed.input_tokens, 42);
        assert_eq!(parsed.output_tokens, 0);
        assert_eq!(parsed.cache_read_tokens, None);
        assert_eq!(parsed.cache_creation_tokens, None);
        assert_eq!(parsed.reasoning_tokens, None);
        assert_eq!(parsed.total_input_tokens, 0);
        assert_eq!(parsed.total_output_tokens, 0);
        assert_eq!(parsed.total_cache_read_tokens, 0);
        assert_eq!(parsed.total_cache_creation_tokens, 0);
        assert_eq!(parsed.total_reasoning_tokens, 0);
    }

    #[test]
    fn context_cleared_params_roundtrip() {
        let params = ContextClearedParams::default();
        let untyped = params.to_untyped_message().expect("serializable");
        assert_eq!(untyped.method(), "_aether/context_cleared");
        let parsed = ContextClearedParams::parse_message(untyped.method(), untyped.params()).expect("roundtrip");
        assert_eq!(parsed, params);
    }

    #[test]
    fn auth_methods_updated_roundtrip() {
        let params = AuthMethodsUpdatedParams {
            auth_methods: vec![
                AuthMethod::Agent(AuthMethodAgent::new("anthropic", "Anthropic").description("authenticated")),
                AuthMethod::Agent(AuthMethodAgent::new("openrouter", "OpenRouter")),
            ],
        };

        let untyped = params.to_untyped_message().expect("serializable");
        assert_eq!(untyped.method(), "_aether/auth_methods_updated");
        let parsed = AuthMethodsUpdatedParams::parse_message(untyped.method(), untyped.params()).expect("roundtrip");
        assert_eq!(parsed, params);
    }

    #[test]
    fn mcp_request_authenticate_roundtrip() {
        let msg = McpRequest::Authenticate {
            session_id: "session-0".to_string(),
            server_name: "my oauth server".to_string(),
        };

        let untyped = msg.to_untyped_message().expect("serializable");
        assert_eq!(untyped.method(), "_aether/mcp_request");
        let parsed = McpRequest::parse_message(untyped.method(), untyped.params()).expect("roundtrip");
        assert_eq!(parsed, msg);
    }

    #[test]
    fn mcp_notification_server_status_roundtrip() {
        let msg = McpNotification::ServerStatus {
            servers: vec![
                McpServerStatusEntry::new("github", McpServerStatus::Connected { tool_count: 5 }),
                McpServerStatusEntry::new("linear", McpServerStatus::NeedsOAuth)
                    .with_auth_capability(McpServerAuthCapability::OAuth),
                McpServerStatusEntry::new("slack", McpServerStatus::Failed { error: "connection timeout".to_string() }),
            ],
        };

        let untyped = msg.to_untyped_message().expect("serializable");
        assert_eq!(untyped.method(), "_aether/mcp_event");
        let parsed = McpNotification::parse_message(untyped.method(), untyped.params()).expect("roundtrip");
        assert_eq!(parsed, msg);
    }

    #[test]
    fn mcp_notification_url_elicitation_complete_roundtrip() {
        let msg = McpNotification::UrlElicitationComplete(UrlElicitationCompleteParams {
            server_name: "github".to_string(),
            elicitation_id: "el-456".to_string(),
        });

        let untyped = msg.to_untyped_message().expect("serializable");
        assert_eq!(untyped.method(), "_aether/mcp_event");
        let parsed = McpNotification::parse_message(untyped.method(), untyped.params()).expect("roundtrip");
        assert_eq!(parsed, msg);
    }

    #[test]
    fn sub_agent_progress_params_roundtrip() {
        let params = SubAgentProgressParams {
            parent_tool_id: "call_123".to_string(),
            task_id: "task_abc".to_string(),
            agent_name: "explorer".to_string(),
            event: SubAgentEvent::Done,
        };

        let untyped = params.to_untyped_message().expect("serializable");
        assert_eq!(untyped.method(), "_aether/sub_agent_progress");

        // Parse back and compare field by field, so this test does not depend on
        // `SubAgentProgressParams: PartialEq`.
        let parsed = SubAgentProgressParams::parse_message(untyped.method(), untyped.params()).expect("roundtrip");
        assert_eq!(parsed.parent_tool_id, "call_123");
        assert_eq!(parsed.task_id, "task_abc");
        assert_eq!(parsed.agent_name, "explorer");
        assert!(matches!(parsed.event, SubAgentEvent::Done));
    }

    #[test]
    fn elicitation_params_roundtrip() {
        use rmcp::model::{ElicitationSchema, EnumSchema};

        let params = ElicitationParams {
            server_name: "github".to_string(),
            request: CreateElicitationRequestParams::FormElicitationParams {
                meta: None,
                message: "Pick a color".to_string(),
                requested_schema: ElicitationSchema::builder()
                    .required_enum_schema(
                        "color",
                        EnumSchema::builder(vec!["red".into(), "green".into(), "blue".into()]).untitled().build(),
                    )
                    .build()
                    .unwrap(),
            },
        };

        let untyped = params.to_untyped_message().expect("serializable");
        assert_eq!(untyped.method(), "_aether/elicitation");
        let parsed = ElicitationParams::parse_message(untyped.method(), untyped.params()).expect("roundtrip");
        assert_eq!(parsed, params);
    }

    #[test]
    fn elicitation_params_url_variant_has_mode_field() {
        let params = ElicitationParams {
            server_name: "github".to_string(),
            request: CreateElicitationRequestParams::UrlElicitationParams {
                meta: None,
                message: "Authorize GitHub".to_string(),
                url: "https://github.com/login/oauth".to_string(),
                elicitation_id: "el-123".to_string(),
            },
        };

        let json = serde_json::to_string(&params).unwrap();
        assert!(json.contains("\"mode\":\"url\""));
        assert!(json.contains("\"server_name\":\"github\""));
    }

    #[test]
    fn elicitation_response_roundtrip() {
        let response = ElicitationResponse {
            action: ElicitationAction::Accept,
            content: Some(serde_json::json!({ "color": "red" })),
        };

        let json = serde_json::to_string(&response).unwrap();
        let parsed: ElicitationResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, response);
    }

    #[test]
    fn mcp_server_status_entry_serde_roundtrip() {
        let entry = McpServerStatusEntry::new("test-server", McpServerStatus::Connected { tool_count: 3 })
            .with_auth_capability(McpServerAuthCapability::OAuth);

        let json = serde_json::to_string(&entry).unwrap();
        assert!(json.contains("\"auth_capability\":\"OAuth\""));
        assert!(json.contains("\"proxied\":false"));
        let parsed: McpServerStatusEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, entry);
        assert!(!parsed.proxied);
        assert!(parsed.can_authenticate());
    }

    #[test]
    fn mcp_server_status_entry_proxied_serde_roundtrip() {
        let entry = McpServerStatusEntry::new("math", McpServerStatus::NeedsOAuth)
            .with_auth_capability(McpServerAuthCapability::OAuth)
            .with_proxied(true);

        let json = serde_json::to_string(&entry).unwrap();
        assert!(json.contains("\"proxied\":true"));
        let parsed: McpServerStatusEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, entry);
    }

    #[test]
    fn deserialize_tool_call_event() {
        // The extra `model_name` key mirrors what the server may send; it must be ignored.
        let json = r#"{"ToolCall":{"request":{"id":"c1","name":"grep","arguments":"{\"pattern\":\"test\"}"},"model_name":"m"}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        match event {
            SubAgentEvent::ToolCall { request } => {
                assert_eq!(request.id, "c1");
                assert_eq!(request.name, "grep");
                assert_eq!(request.arguments, r#"{"pattern":"test"}"#);
            }
            other => panic!("Expected ToolCall, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_tool_call_update_event() {
        let json = r#"{"ToolCallUpdate":{"update":{"id":"c1","chunk":"{\"pattern\":\"test\"}"},"model_name":"m"}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        match event {
            SubAgentEvent::ToolCallUpdate { update } => {
                assert_eq!(update.id, "c1");
                assert_eq!(update.chunk, r#"{"pattern":"test"}"#);
            }
            other => panic!("Expected ToolCallUpdate, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_tool_result_event() {
        let json = r#"{"ToolResult":{"result":{"id":"c1","name":"grep","result_meta":{"display":{"title":"Grep","value":"'test' in src (3 matches)"}}}}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        match event {
            SubAgentEvent::ToolResult { result } => {
                let result_meta = result.result_meta.expect("expected result_meta");
                assert_eq!(result_meta.display.title, "Grep");
            }
            other => panic!("Expected ToolResult, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_tool_error_event() {
        let json = r#"{"ToolError":{"error":{"id":"c1","name":"grep"}}}"#;
        let event: SubAgentEvent = serde_json::from_str(json).unwrap();
        match event {
            SubAgentEvent::ToolError { error } => {
                assert_eq!(error.id, "c1");
                assert_eq!(error.name, "grep");
            }
            other => panic!("Expected ToolError, got {other:?}"),
        }
    }

    #[test]
    fn deserialize_done_event() {
        // Unit variants are externally tagged as a bare string in both directions.
        assert_eq!(serde_json::to_string(&SubAgentEvent::Done).unwrap(), r#""Done""#);
        let event: SubAgentEvent = serde_json::from_str(r#""Done""#).unwrap();
        assert!(matches!(event, SubAgentEvent::Done));
    }

    #[test]
    fn deserialize_other_variant() {
        let event: SubAgentEvent = serde_json::from_str(r#""Other""#).unwrap();
        assert!(matches!(event, SubAgentEvent::Other));
    }

    #[test]
    fn tool_result_meta_map_roundtrip() {
        let meta: ToolResultMeta = ToolDisplayMeta::new("Read file", "Cargo.toml, 156 lines").into();
        let map = meta.clone().into_map();
        let parsed = ToolResultMeta::from_map(&map).expect("should deserialize ToolResultMeta");
        assert_eq!(parsed, meta);
    }
}