Skip to main content

cyril_core/protocol/
client.rs

1use std::cell::RefCell;
2use std::collections::HashMap;
3use std::rc::Rc;
4use std::str::FromStr;
5
6use agent_client_protocol as acp;
7use async_trait::async_trait;
8use tokio::sync::mpsc;
9use tokio::sync::oneshot;
10
11use crate::capabilities;
12use crate::event::{AppEvent, ExtensionEvent, InteractionRequest, InternalEvent, ProtocolEvent};
13use crate::hooks::{HookContext, HookRegistry, HookResult, HookTarget, HookTiming};
14use crate::kiro_ext::KiroCommandsPayload;
15use crate::platform::path;
16use crate::platform::terminal::{TerminalId, TerminalManager};
17
18/// Construct an ACP internal error with a message.
19fn internal_err(msg: impl Into<String>) -> acp::Error {
20    acp::Error::new(-32603, msg)
21}
22
23/// The central ACP Client implementation.
24/// Handles all agent callbacks: fs, terminal, permissions, notifications.
25///
26/// Uses `Rc<RefCell<_>>` for interior mutability since everything is `!Send`
27/// (required by `#[async_trait(?Send)]` on the ACP `Client` trait).
28pub struct KiroClient {
29    event_tx: mpsc::UnboundedSender<AppEvent>,
30    terminal_manager: RefCell<TerminalManager>,
31    hooks: RefCell<HookRegistry>,
32    /// Cache of `raw_input` from ToolCall/ToolCallUpdate notifications, keyed by tool call ID.
33    /// Permission requests arrive without `raw_input`, so we look it up here to enrich them.
34    tool_call_inputs: RefCell<HashMap<acp::ToolCallId, serde_json::Value>>,
35}
36
37impl KiroClient {
38    pub fn new(
39        event_tx: mpsc::UnboundedSender<AppEvent>,
40        hooks: HookRegistry,
41    ) -> Rc<Self> {
42        Rc::new(Self {
43            event_tx,
44            terminal_manager: RefCell::new(TerminalManager::new()),
45            hooks: RefCell::new(hooks),
46            tool_call_inputs: RefCell::new(HashMap::new()),
47        })
48    }
49
50    /// Send an event to the TUI, logging if the receiver has been dropped.
51    fn emit(&self, event: AppEvent) {
52        if self.event_tx.send(event).is_err() {
53            tracing::error!("Event channel closed — TUI receiver is gone, events are being dropped");
54        }
55    }
56}
57
58#[async_trait(?Send)]
59impl acp::Client for KiroClient {
60    async fn request_permission(
61        &self,
62        mut args: acp::RequestPermissionRequest,
63    ) -> acp::Result<acp::RequestPermissionResponse> {
64        // Permission requests arrive without raw_input — enrich from our cache
65        // so the approval UI can display details like URLs and commands.
66        if args.tool_call.fields.raw_input.is_none() {
67            if let Some(cached) = self.tool_call_inputs.borrow().get(&args.tool_call.tool_call_id) {
68                args.tool_call.fields.raw_input = Some(cached.clone());
69            }
70        }
71
72        let (tx, rx) = oneshot::channel();
73        self.emit(AppEvent::Interaction(InteractionRequest::Permission {
74            request: args,
75            responder: tx,
76        }));
77
78        rx.await.map_err(|_| internal_err("Permission request channel closed"))
79    }
80
81    async fn session_notification(
82        &self,
83        args: acp::SessionNotification,
84    ) -> acp::Result<()> {
85        match args.update {
86            acp::SessionUpdate::AgentMessageChunk(chunk) => {
87                self.emit(AppEvent::Protocol(ProtocolEvent::AgentMessage {
88                    session_id: args.session_id,
89                    chunk,
90                }));
91            }
92            acp::SessionUpdate::AgentThoughtChunk(chunk) => {
93                self.emit(AppEvent::Protocol(ProtocolEvent::AgentThought {
94                    session_id: args.session_id,
95                    chunk,
96                }));
97            }
98            acp::SessionUpdate::ToolCall(tool_call) => {
99                if let Some(ref raw_input) = tool_call.raw_input {
100                    self.tool_call_inputs
101                        .borrow_mut()
102                        .insert(tool_call.tool_call_id.clone(), raw_input.clone());
103                }
104                self.emit(AppEvent::Protocol(ProtocolEvent::ToolCallStarted {
105                    session_id: args.session_id,
106                    tool_call,
107                }));
108            }
109            acp::SessionUpdate::ToolCallUpdate(update) => {
110                if let Some(ref raw_input) = update.fields.raw_input {
111                    self.tool_call_inputs
112                        .borrow_mut()
113                        .insert(update.tool_call_id.clone(), raw_input.clone());
114                }
115                self.emit(AppEvent::Protocol(ProtocolEvent::ToolCallUpdated {
116                    session_id: args.session_id,
117                    update,
118                }));
119            }
120            acp::SessionUpdate::Plan(plan) => {
121                self.emit(AppEvent::Protocol(ProtocolEvent::PlanUpdated {
122                    session_id: args.session_id,
123                    plan,
124                }));
125            }
126            acp::SessionUpdate::AvailableCommandsUpdate(commands) => {
127                self.emit(AppEvent::Protocol(ProtocolEvent::CommandsUpdated {
128                    session_id: args.session_id,
129                    commands,
130                }));
131            }
132            acp::SessionUpdate::CurrentModeUpdate(mode) => {
133                self.emit(AppEvent::Protocol(ProtocolEvent::ModeChanged {
134                    session_id: args.session_id,
135                    mode,
136                }));
137            }
138            acp::SessionUpdate::ConfigOptionUpdate(update) => {
139                self.emit(AppEvent::Protocol(ProtocolEvent::ConfigOptionsUpdated {
140                    session_id: args.session_id,
141                    config_options: update.config_options,
142                }));
143            }
144            _ => {
145                tracing::debug!("Unhandled session notification variant");
146            }
147        }
148        Ok(())
149    }
150
151    async fn ext_notification(&self, args: acp::ExtNotification) -> acp::Result<()> {
152        tracing::info!("Received ext_notification: method={}", args.method);
153        tracing::info!("ext_notification params: {}", args.params);
154
155        if args.method.as_ref() == "kiro.dev/commands/available" {
156            match serde_json::from_str::<KiroCommandsPayload>(args.params.get()) {
157                Ok(payload) => {
158                    let commands = payload.commands();
159                    tracing::info!(
160                        "Parsed {} Kiro commands from ext_notification",
161                        commands.len()
162                    );
163                    self.emit(AppEvent::Extension(ExtensionEvent::KiroCommandsAvailable { commands }));
164                }
165                Err(e) => {
166                    tracing::warn!("Failed to parse kiro.dev/commands/available: {e}");
167                }
168            }
169        } else if args.method.as_ref() == "kiro.dev/metadata" {
170            // Log the full raw payload so we can discover all available fields
171            tracing::info!("kiro.dev/metadata raw: {}", args.params.get());
172
173            #[derive(serde::Deserialize)]
174            #[serde(rename_all = "camelCase")]
175            struct MetadataPayload {
176                session_id: String,
177                #[serde(default)]
178                context_usage_percentage: f64,
179            }
180            match serde_json::from_str::<MetadataPayload>(args.params.get()) {
181                Ok(payload) => {
182                    self.emit(AppEvent::Extension(ExtensionEvent::KiroMetadata {
183                        session_id: payload.session_id,
184                        context_usage_pct: payload.context_usage_percentage,
185                    }));
186                }
187                Err(e) => {
188                    tracing::warn!("Failed to parse kiro.dev/metadata: {e}");
189                }
190            }
191        }
192
193        Ok(())
194    }
195
196    async fn read_text_file(
197        &self,
198        args: acp::ReadTextFileRequest,
199    ) -> acp::Result<acp::ReadTextFileResponse> {
200        let native_path = path::to_native(&args.path);
201        tracing::info!("fs.readTextFile: {} -> {}", args.path.display(), native_path.display());
202
203        let hook_ctx = HookContext {
204            target: HookTarget::FsRead,
205            timing: HookTiming::Before,
206            path: Some(native_path.clone()),
207            content: None,
208            command: None,
209        };
210        if let HookResult::Blocked { reason } = self.hooks.borrow().run_before(&hook_ctx).await {
211            return Err(internal_err(reason));
212        }
213
214        let content = capabilities::fs::read_text_file(&native_path)
215            .await
216            .map_err(|e| internal_err(e.to_string()))?;
217
218        Ok(acp::ReadTextFileResponse::new(content))
219    }
220
221    async fn write_text_file(
222        &self,
223        args: acp::WriteTextFileRequest,
224    ) -> acp::Result<acp::WriteTextFileResponse> {
225        let native_path = path::to_native(&args.path);
226        tracing::info!("fs.writeTextFile: {} -> {}", args.path.display(), native_path.display());
227
228        let mut content = args.content.clone();
229
230        let hook_ctx = HookContext {
231            target: HookTarget::FsWrite,
232            timing: HookTiming::Before,
233            path: Some(native_path.clone()),
234            content: Some(content.clone()),
235            command: None,
236        };
237        match self.hooks.borrow().run_before(&hook_ctx).await {
238            HookResult::Blocked { reason } => return Err(internal_err(reason)),
239            HookResult::ModifiedArgs { content: Some(c), .. } => content = c,
240            _ => {}
241        }
242
243        capabilities::fs::write_text_file(&native_path, &content)
244            .await
245            .map_err(|e| internal_err(e.to_string()))?;
246
247        // Run after hooks
248        let after_ctx = HookContext {
249            target: HookTarget::FsWrite,
250            timing: HookTiming::After,
251            path: Some(native_path),
252            content: Some(content),
253            command: None,
254        };
255        let after_results = self.hooks.borrow().run_after(&after_ctx).await;
256        for result in after_results {
257            if let HookResult::FeedbackPrompt { text } = result {
258                self.emit(AppEvent::Internal(InternalEvent::HookFeedback { text }));
259            }
260        }
261
262        Ok(acp::WriteTextFileResponse::new())
263    }
264
265    async fn create_terminal(
266        &self,
267        args: acp::CreateTerminalRequest,
268    ) -> acp::Result<acp::CreateTerminalResponse> {
269        let command = args.command.clone();
270        tracing::info!("terminal.create: {command}");
271
272        let hook_ctx = HookContext {
273            target: HookTarget::Terminal,
274            timing: HookTiming::Before,
275            path: None,
276            content: None,
277            command: Some(command.clone()),
278        };
279        if let HookResult::Blocked { reason } = self.hooks.borrow().run_before(&hook_ctx).await {
280            return Err(internal_err(reason));
281        }
282
283        let id = self
284            .terminal_manager
285            .borrow_mut()
286            .create_terminal(&command)
287            .map_err(|e| internal_err(e.to_string()))?;
288
289        Ok(acp::CreateTerminalResponse::new(id.to_string()))
290    }
291
292    async fn terminal_output(
293        &self,
294        args: acp::TerminalOutputRequest,
295    ) -> acp::Result<acp::TerminalOutputResponse> {
296        let id = TerminalId::from_str(&args.terminal_id.to_string())
297            .map_err(|e| internal_err(format!("Invalid terminal ID: {e}")))?;
298
299        let output = self
300            .terminal_manager
301            .borrow_mut()
302            .get_output(&id)
303            .map_err(|e| internal_err(e.to_string()))?;
304
305        Ok(acp::TerminalOutputResponse::new(output, false))
306    }
307
308    async fn wait_for_terminal_exit(
309        &self,
310        args: acp::WaitForTerminalExitRequest,
311    ) -> acp::Result<acp::WaitForTerminalExitResponse> {
312        let id = TerminalId::from_str(&args.terminal_id.to_string())
313            .map_err(|e| internal_err(format!("Invalid terminal ID: {e}")))?;
314
315        let exit_code = self
316            .terminal_manager
317            .borrow_mut()
318            .wait_for_exit(&id)
319            .await
320            .map_err(|e| internal_err(e.to_string()))?;
321
322        let exit_status = acp::TerminalExitStatus::new()
323            .exit_code(exit_code.max(0) as u32);
324
325        Ok(acp::WaitForTerminalExitResponse::new(exit_status))
326    }
327
328    async fn release_terminal(
329        &self,
330        args: acp::ReleaseTerminalRequest,
331    ) -> acp::Result<acp::ReleaseTerminalResponse> {
332        let id = TerminalId::from_str(&args.terminal_id.to_string())
333            .map_err(|e| internal_err(format!("Invalid terminal ID: {e}")))?;
334
335        self.terminal_manager
336            .borrow_mut()
337            .release(&id)
338            .map_err(|e| internal_err(e.to_string()))?;
339
340        Ok(acp::ReleaseTerminalResponse::new())
341    }
342
343    async fn kill_terminal_command(
344        &self,
345        args: acp::KillTerminalCommandRequest,
346    ) -> acp::Result<acp::KillTerminalCommandResponse> {
347        let id = TerminalId::from_str(&args.terminal_id.to_string())
348            .map_err(|e| internal_err(format!("Invalid terminal ID: {e}")))?;
349
350        self.terminal_manager
351            .borrow_mut()
352            .kill(&id)
353            .await
354            .map_err(|e| internal_err(e.to_string()))?;
355
356        Ok(acp::KillTerminalCommandResponse::new())
357    }
358}