Skip to main content

kimi_wire/
client_ext.rs

1//! Extension traits for [`WireClient`](crate::client::WireClient).
2//!
3//! Provides convenience methods that wrap the low-level `read_raw_message`
4//! API with parsing and timeout helpers, plus request/event helpers.
5
6use std::time::Duration;
7
8use crate::client::WireClient;
9use crate::error::WireError;
10use crate::message::{parse_wire_message, WireMessage};
11use crate::protocol::{
12    ApprovalResponseKind, DisplayBlock, Event, HookAction, Request, ToolOutput, ToolReturnValue,
13};
14
15/// Convenience extensions for any [`WireClient`] implementation.
16///
17/// This trait is automatically implemented for every type that already
18/// implements [`WireClient`], so you can call `.read_message()` on
19/// `InMemoryWireClient`, `TransportWireClient`, or any custom backend.
20pub trait WireClientExt: WireClient {
21    /// Read the next incoming message and parse it into a [`WireMessage`].
22    ///
23    /// # Errors
24    ///
25    /// Returns [`WireError::JsonParse`] if the raw message cannot be
26    /// deserialized into a known request / event type.
27    fn read_message(
28        &mut self,
29    ) -> impl std::future::Future<Output = Result<WireMessage, WireError>> + Send {
30        async move {
31            let raw = self.read_raw_message().await?;
32            parse_wire_message(raw)
33        }
34    }
35
36    /// Read the next incoming message with a timeout.
37    ///
38    /// If no message arrives within `timeout`, returns
39    /// [`WireError::Timeout`].
40    fn read_message_timeout(
41        &mut self,
42        timeout: Duration,
43    ) -> impl std::future::Future<Output = Result<WireMessage, WireError>> + Send {
44        async move {
45            let raw = self.read_raw_message_timeout(timeout).await?;
46            parse_wire_message(raw)
47        }
48    }
49}
50
51impl<T: WireClient + ?Sized> WireClientExt for T {}
52
53/// Convenience helpers for [`Event`].
54pub trait EventExt {
55    /// Return the Pascal-case wire type name (e.g. `"TurnBegin"`).
56    fn event_type(&self) -> String;
57
58    /// Return the snake-case normalized type name (e.g. `"turn_begin"`).
59    fn normalized_event_type(&self) -> String;
60
61    /// Serialize the event back to a JSON value.
62    fn payload(&self) -> serde_json::Value;
63}
64
65impl EventExt for Event {
66    fn event_type(&self) -> String {
67        self.type_name().to_string()
68    }
69
70    fn normalized_event_type(&self) -> String {
71        let pascal = self.type_name();
72        let mut snake = String::new();
73        for (i, ch) in pascal.chars().enumerate() {
74            if ch.is_uppercase() && i > 0 {
75                snake.push('_');
76            }
77            snake.push(ch.to_ascii_lowercase());
78        }
79        snake
80    }
81
82    fn payload(&self) -> serde_json::Value {
83        serde_json::to_value(self).map_or(serde_json::Value::Null, |v| v)
84    }
85}
86
87/// Convenience helpers for [`Request`].
88pub trait RequestExt {
89    /// Return the wire type name (e.g. `"ApprovalRequest"`).
90    fn kind(&self) -> String;
91
92    /// Generate a conservative default response for this request type.
93    ///
94    /// * Approval → auto-approve for session
95    /// * Tool call → error (tool not registered)
96    /// * Question → first option for each question
97    /// * Hook → allow (no policy configured)
98    fn default_response(&self) -> serde_json::Value;
99}
100
101impl RequestExt for Request {
102    fn kind(&self) -> String {
103        match self {
104            Self::ApprovalRequest(_) => "ApprovalRequest",
105            Self::ToolCallRequest(_) => "ToolCallRequest",
106            Self::QuestionRequest(_) => "QuestionRequest",
107            Self::HookRequest(_) => "HookRequest",
108        }
109        .to_string()
110    }
111
112    fn default_response(&self) -> serde_json::Value {
113        match self {
114            Self::ApprovalRequest(req) => serde_json::json!({
115                "request_id": req.id,
116                "response": ApprovalResponseKind::ApproveForSession,
117                "feedback": "Auto-approved by non-interactive worker."
118            }),
119            Self::ToolCallRequest(req) => serde_json::json!({
120                "tool_call_id": req.id,
121                "return_value": ToolReturnValue {
122                    is_error: true,
123                    output: ToolOutput::Text(String::new()),
124                    message: format!("External tool '{}' is not registered.", req.name),
125                    display: vec![DisplayBlock::brief("External tool unavailable.")],
126                    extras: None,
127                }
128            }),
129            Self::QuestionRequest(req) => {
130                let answers: Vec<serde_json::Value> = req
131                    .questions
132                    .iter()
133                    .map(|q| {
134                        q.options.first().map_or(serde_json::Value::Null, |o| {
135                            serde_json::Value::String(o.label.clone())
136                        })
137                    })
138                    .collect();
139                serde_json::json!({
140                    "request_id": req.id,
141                    "answers": answers,
142                    "message": "Selected default answers because workers run non-interactively."
143                })
144            }
145            Self::HookRequest(req) => serde_json::json!({
146                "request_id": req.id,
147                "action": HookAction::Allow,
148                "reason": format!(
149                    "No hook policy is configured for '{}' on '{}'.",
150                    req.event, req.target
151                )
152            }),
153        }
154    }
155}