Skip to main content

kimi_wire/
client_ext.rs

//! Extension traits for [`WireClient`](crate::client::WireClient).
//!
//! Provides convenience methods that wrap the low-level `read_raw_message`
//! API with parsing and timeout helpers, plus helper traits for inspecting
//! `Event` values and building default replies to `Request` values.
5
6use std::time::Duration;
7
8use crate::client::WireClient;
9use crate::error::WireError;
10use crate::message::{parse_wire_message, WireMessage};
11use crate::protocol::{
12    ApprovalResponseKind, DisplayBlock, Event, HookAction, Request, ToolOutput, ToolReturnValue,
13};
14
15/// Convenience extensions for any [`WireClient`] implementation.
16///
17/// This trait is automatically implemented for every type that already
18/// implements [`WireClient`], so you can call `.read_message()` on
19/// `InMemoryWireClient`, `TransportWireClient`, or any custom backend.
20pub trait WireClientExt: WireClient {
21    /// Read the next incoming message and parse it into a [`WireMessage`].
22    ///
23    /// # Errors
24    ///
25    /// Returns [`WireError::JsonParse`] if the raw message cannot be
26    /// deserialized into a known request / event type.
27    fn read_message(
28        &mut self,
29    ) -> impl std::future::Future<Output = Result<WireMessage, WireError>> + Send {
30        async move {
31            let raw = self.read_raw_message().await?;
32            parse_wire_message(raw)
33        }
34    }
35
36    /// Read the next incoming message with a timeout.
37    ///
38    /// If no message arrives within `timeout`, returns
39    /// [`WireError::Timeout`].
40    fn read_message_timeout(
41        &mut self,
42        timeout: Duration,
43    ) -> impl std::future::Future<Output = Result<WireMessage, WireError>> + Send {
44        async move {
45            let raw = self.read_raw_message_timeout(timeout).await?;
46            parse_wire_message(raw)
47        }
48    }
49}
50
51impl<T: WireClient + ?Sized> WireClientExt for T {}
52
53/// Convenience helpers for [`Event`].
54pub trait EventExt {
55    /// Return the Pascal-case wire type name (e.g. `"TurnBegin"`).
56    fn event_type(&self) -> String;
57
58    /// Return the snake-case normalized type name (e.g. `"turn_begin"`).
59    fn normalized_event_type(&self) -> String;
60
61    /// Serialize the event back to a JSON value.
62    fn payload(&self) -> serde_json::Value;
63}
64
65impl EventExt for Event {
66    fn event_type(&self) -> String {
67        self.type_name().to_string()
68    }
69
70    fn normalized_event_type(&self) -> String {
71        let pascal = self.type_name();
72        let mut snake = String::new();
73        for (i, ch) in pascal.chars().enumerate() {
74            if ch.is_uppercase() && i > 0 {
75                snake.push('_');
76            }
77            snake.push(ch.to_ascii_lowercase());
78        }
79        snake
80    }
81
82    fn payload(&self) -> serde_json::Value {
83        match serde_json::to_value(self) {
84            Ok(v) => v,
85            Err(_) => serde_json::Value::Null,
86        }
87    }
88}
89
90/// Convenience helpers for [`Request`].
91pub trait RequestExt {
92    /// Return the wire type name (e.g. `"ApprovalRequest"`).
93    fn kind(&self) -> String;
94
95    /// Generate a conservative default response for this request type.
96    ///
97    /// * Approval → auto-approve for session
98    /// * Tool call → error (tool not registered)
99    /// * Question → first option for each question
100    /// * Hook → allow (no policy configured)
101    fn default_response(&self) -> serde_json::Value;
102}
103
104impl RequestExt for Request {
105    fn kind(&self) -> String {
106        match self {
107            Request::ApprovalRequest(_) => "ApprovalRequest",
108            Request::ToolCallRequest(_) => "ToolCallRequest",
109            Request::QuestionRequest(_) => "QuestionRequest",
110            Request::HookRequest(_) => "HookRequest",
111        }
112        .to_string()
113    }
114
115    fn default_response(&self) -> serde_json::Value {
116        match self {
117            Request::ApprovalRequest(req) => serde_json::json!({
118                "request_id": req.id,
119                "response": ApprovalResponseKind::ApproveForSession,
120                "feedback": "Auto-approved by non-interactive worker."
121            }),
122            Request::ToolCallRequest(req) => serde_json::json!({
123                "tool_call_id": req.id,
124                "return_value": ToolReturnValue {
125                    is_error: true,
126                    output: ToolOutput::Text(String::new()),
127                    message: format!("External tool '{}' is not registered.", req.name),
128                    display: vec![DisplayBlock::brief("External tool unavailable.")],
129                    extras: None,
130                }
131            }),
132            Request::QuestionRequest(req) => {
133                let answers: Vec<serde_json::Value> = req
134                    .questions
135                    .iter()
136                    .map(|q| {
137                        q.options
138                            .first()
139                            .map(|o| serde_json::Value::String(o.label.clone()))
140                            .unwrap_or(serde_json::Value::Null)
141                    })
142                    .collect();
143                serde_json::json!({
144                    "request_id": req.id,
145                    "answers": answers,
146                    "message": "Selected default answers because workers run non-interactively."
147                })
148            }
149            Request::HookRequest(req) => serde_json::json!({
150                "request_id": req.id,
151                "action": HookAction::Allow,
152                "reason": format!(
153                    "No hook policy is configured for '{}' on '{}'.",
154                    req.event, req.target
155                )
156            }),
157        }
158    }
159}