Skip to main content

claude_code_rust/app/
connect.rs

1// Claude Code Rust - A native Rust terminal interface for Claude Code
2// Copyright (C) 2025  Simon Peter Rothgang
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU Affero General Public License as
6// published by the Free Software Foundation, either version 3 of the
7// License, or (at your option) any later version.
8//
9// This program is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12// GNU Affero General Public License for more details.
13//
14// You should have received a copy of the GNU Affero General Public License
15// along with this program.  If not, see <https://www.gnu.org/licenses/>.
16
17use super::{
18    App, AppStatus, ChatViewport, FocusManager, HelpView, ModeInfo, ModeState, SelectionState,
19    TodoItem,
20};
21use crate::Cli;
22use crate::acp::client::{ClaudeClient, ClientEvent, TerminalMap};
23use crate::acp::connection;
24use agent_client_protocol::{self as acp, Agent as _};
25use std::collections::{HashMap, HashSet};
26use std::path::PathBuf;
27use std::rc::Rc;
28use std::time::Instant;
29use tokio::sync::mpsc;
30
31/// Shorten cwd for display: use `~` for the home directory prefix.
32fn shorten_cwd(cwd: &std::path::Path) -> String {
33    let cwd_str = cwd.to_string_lossy().to_string();
34    if let Some(home) = dirs::home_dir() {
35        let home_str = home.to_string_lossy().to_string();
36        if cwd_str.starts_with(&home_str) {
37            return format!("~{}", &cwd_str[home_str.len()..]);
38        }
39    }
40    cwd_str
41}
42
43/// Create the `App` struct in `Connecting` state. No I/O - returns immediately.
44pub fn create_app(cli: &Cli) -> App {
45    let cwd = cli
46        .dir
47        .clone()
48        .unwrap_or_else(|| std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")));
49
50    let (event_tx, event_rx) = mpsc::unbounded_channel();
51    let terminals: TerminalMap = Rc::new(std::cell::RefCell::new(HashMap::new()));
52
53    let cwd_display = shorten_cwd(&cwd);
54    let initial_model_name = "Connecting...".to_owned();
55
56    let mut app = App {
57        messages: vec![super::ChatMessage::welcome(&initial_model_name, &cwd_display)],
58        viewport: ChatViewport::new(),
59        input: super::InputState::new(),
60        status: AppStatus::Connecting,
61        should_quit: false,
62        session_id: None,
63        conn: None,
64        adapter_child: None,
65        model_name: initial_model_name,
66        cwd_raw: cwd.to_string_lossy().to_string(),
67        cwd: cwd_display,
68        files_accessed: 0,
69        mode: None,
70        login_hint: None,
71        pending_compact_clear: false,
72        help_view: HelpView::Keys,
73        pending_permission_ids: Vec::new(),
74        cancelled_turn_pending_hint: false,
75        event_tx,
76        event_rx,
77        spinner_frame: 0,
78        tools_collapsed: true,
79        active_task_ids: HashSet::new(),
80        terminals,
81        force_redraw: false,
82        tool_call_index: HashMap::new(),
83        todos: Vec::<TodoItem>::new(),
84        show_header: true,
85        show_todo_panel: false,
86        todo_scroll: 0,
87        todo_selected: 0,
88        focus: FocusManager::default(),
89        available_commands: Vec::new(),
90        cached_frame_area: ratatui::layout::Rect::new(0, 0, 0, 0),
91        selection: Option::<SelectionState>::None,
92        scrollbar_drag: None,
93        rendered_chat_lines: Vec::new(),
94        rendered_chat_area: ratatui::layout::Rect::new(0, 0, 0, 0),
95        rendered_input_lines: Vec::new(),
96        rendered_input_area: ratatui::layout::Rect::new(0, 0, 0, 0),
97        mention: None,
98        slash: None,
99        pending_submit: false,
100        drain_key_count: 0,
101        paste_burst: crate::app::paste_burst::PasteBurstDetector::new(),
102        pending_paste_text: String::new(),
103        file_cache: None,
104        input_wrap_cache: None,
105        cached_todo_compact: None,
106        git_branch: None,
107        cached_header_line: None,
108        cached_footer_line: None,
109        terminal_tool_calls: Vec::new(),
110        needs_redraw: true,
111        perf: cli
112            .perf_log
113            .as_deref()
114            .and_then(|path| crate::perf::PerfLogger::open(path, cli.perf_append)),
115        fps_ema: None,
116        last_frame_at: None,
117    };
118
119    app.refresh_git_branch();
120    app
121}
122
123/// Spawn the background connection task. Uses `spawn_local` so it runs on the
124/// same `LocalSet` as the TUI - `Rc<Connection>` stays on one thread.
125///
126/// On success, stores the connection in `app.conn` via a shared slot and sends
127/// `ClientEvent::Connected`. On auth error, sends `ClientEvent::AuthRequired`.
128/// On failure, sends `ClientEvent::ConnectionFailed`.
129#[allow(clippy::too_many_lines, clippy::items_after_statements, clippy::similar_names)]
130pub fn start_connection(app: &App, cli: &Cli, launchers: Vec<connection::AdapterLauncher>) {
131    let event_tx = app.event_tx.clone();
132    let terminals = Rc::clone(&app.terminals);
133    let cwd_raw = app.cwd_raw.clone();
134    let cwd = PathBuf::from(&cwd_raw);
135    let yolo = cli.yolo;
136    let model_override = cli.model.clone();
137    let resume_id = cli.resume.clone();
138
139    // Rc<Connection> is !Send, so it can't be sent through the mpsc channel.
140    // Instead, the task deposits it into a thread-local slot, then signals
141    // via ClientEvent::Connected. The event handler calls take_connection_slot().
142    let conn_slot: Rc<std::cell::RefCell<Option<ConnectionSlot>>> =
143        Rc::new(std::cell::RefCell::new(None));
144    let conn_slot_writer = Rc::clone(&conn_slot);
145
146    tokio::task::spawn_local(async move {
147        let result = connect_impl(
148            &event_tx,
149            &terminals,
150            &cwd,
151            &launchers,
152            yolo,
153            model_override.as_deref(),
154            resume_id.as_deref(),
155        )
156        .await;
157
158        match result {
159            Ok((conn, child, session_id, model_name, mode)) => {
160                // Deposit connection + child in the shared slot
161                *conn_slot_writer.borrow_mut() =
162                    Some(ConnectionSlot { conn: Rc::clone(&conn), child });
163                let _ = event_tx.send(ClientEvent::Connected { session_id, model_name, mode });
164            }
165            Err(ConnectError::AuthRequired { method_name, method_description }) => {
166                let _ =
167                    event_tx.send(ClientEvent::AuthRequired { method_name, method_description });
168            }
169            Err(ConnectError::Failed(msg)) => {
170                let _ = event_tx.send(ClientEvent::ConnectionFailed(msg));
171            }
172        }
173    });
174
175    // Store the slot in a thread-local so handle_acp_event can retrieve the
176    // Rc<Connection> when ClientEvent::Connected arrives. This is safe because
177    // start_connection() must only be called once per app lifetime.
178    CONN_SLOT.with(|slot| {
179        debug_assert!(
180            slot.borrow().is_none(),
181            "CONN_SLOT already populated -- start_connection() called twice?"
182        );
183        *slot.borrow_mut() = Some(conn_slot);
184    });
185}
186
/// Hand-off container for the results of a successful connection.
///
/// The background task spawned by `start_connection()` fills one of these in
/// after `connect_impl` succeeds; the event loop takes it out with
/// `take_connection_slot()` once `ClientEvent::Connected` arrives.
pub struct ConnectionSlot {
    /// The ACP connection to the adapter. Held in an `Rc` (hence `!Send`),
    /// which is why it travels through this slot rather than the event channel.
    pub conn: Rc<acp::ClientSideConnection>,
    /// Handle to the spawned adapter process.
    pub child: tokio::process::Child,
}
192
// Thread-local registry for the connection slot created by start_connection().
//
// start_connection() stores an `Rc` handle to its slot here. Its background task
// writes a `ConnectionSlot` into that same `Rc` on success, and
// take_connection_slot() later moves it out. A thread-local is used because the
// `Rc`s involved are `!Send`; everything runs on the single LocalSet thread that
// start_connection() spawns onto.
thread_local! {
    pub static CONN_SLOT: std::cell::RefCell<Option<Rc<std::cell::RefCell<Option<ConnectionSlot>>>>> =
        const { std::cell::RefCell::new(None) };
}
199
200/// Take the connection data from the thread-local slot. Called once when
201/// `ClientEvent::Connected` is received.
202pub(super) fn take_connection_slot() -> Option<ConnectionSlot> {
203    CONN_SLOT.with(|slot| slot.borrow().as_ref().and_then(|inner| inner.borrow_mut().take()))
204}
205
/// Internal error type for the connection task.
///
/// In `connect_impl`'s launcher loop, `AuthRequired` is returned immediately,
/// while `Failed` is recorded and the next launcher is tried.
#[derive(Debug)]
enum ConnectError {
    /// Session creation was rejected with `ErrorCode::AuthRequired`. The
    /// fields describe the first auth method the adapter advertised (with
    /// fallbacks when none is available).
    AuthRequired { method_name: String, method_description: String },
    /// Any other failure, carrying a human-readable message.
    Failed(String),
}
211
212/// The actual connection logic, extracted from the old `connect()`.
213/// Runs inside `spawn_local` - can use `Rc`, `!Send` types freely.
214#[allow(clippy::too_many_lines, clippy::similar_names)]
215async fn connect_impl(
216    event_tx: &mpsc::UnboundedSender<ClientEvent>,
217    terminals: &crate::acp::client::TerminalMap,
218    cwd: &std::path::Path,
219    launchers: &[connection::AdapterLauncher],
220    yolo: bool,
221    model_override: Option<&str>,
222    resume_id: Option<&str>,
223) -> Result<
224    (
225        Rc<acp::ClientSideConnection>,
226        tokio::process::Child,
227        acp::SessionId,
228        String,
229        Option<ModeState>,
230    ),
231    ConnectError,
232> {
233    if launchers.is_empty() {
234        return Err(ConnectError::Failed("No adapter launchers configured".into()));
235    }
236
237    let mut failures = Vec::new();
238    for launcher in launchers {
239        let started = Instant::now();
240        tracing::info!("Connecting with adapter launcher: {}", launcher.describe());
241        match connect_with_launcher(
242            event_tx,
243            terminals,
244            cwd,
245            launcher,
246            yolo,
247            model_override,
248            resume_id,
249        )
250        .await
251        {
252            Ok(result) => {
253                tracing::info!("Connected via {} in {:?}", launcher.describe(), started.elapsed());
254                return Ok(result);
255            }
256            Err(auth_required @ ConnectError::AuthRequired { .. }) => {
257                return Err(auth_required);
258            }
259            Err(ConnectError::Failed(msg)) => {
260                tracing::warn!("Launcher {} failed: {}", launcher.describe(), msg);
261                failures.push(format!("{}: {msg}", launcher.describe()));
262            }
263        }
264    }
265
266    Err(ConnectError::Failed(format!("All adapter launchers failed: {}", failures.join(" | "))))
267}
268
269#[allow(clippy::too_many_lines, clippy::similar_names)]
270async fn connect_with_launcher(
271    event_tx: &mpsc::UnboundedSender<ClientEvent>,
272    terminals: &crate::acp::client::TerminalMap,
273    cwd: &std::path::Path,
274    launcher: &connection::AdapterLauncher,
275    yolo: bool,
276    model_override: Option<&str>,
277    resume_id: Option<&str>,
278) -> Result<
279    (
280        Rc<acp::ClientSideConnection>,
281        tokio::process::Child,
282        acp::SessionId,
283        String,
284        Option<ModeState>,
285    ),
286    ConnectError,
287> {
288    let client = ClaudeClient::with_terminals(
289        event_tx.clone(),
290        yolo,
291        cwd.to_path_buf(),
292        Rc::clone(terminals),
293    );
294
295    let adapter_start = Instant::now();
296    let adapter = connection::spawn_adapter(client, launcher, cwd)
297        .await
298        .map_err(|e| ConnectError::Failed(format!("Failed to spawn adapter: {e}")))?;
299    tracing::debug!("Spawned adapter via {} in {:?}", launcher.describe(), adapter_start.elapsed());
300    let child = adapter.child;
301    let conn = Rc::new(adapter.connection);
302
303    // Initialize handshake
304    let handshake_start = Instant::now();
305    let init_response = conn
306        .initialize(
307            acp::InitializeRequest::new(acp::ProtocolVersion::LATEST)
308                .client_capabilities(
309                    acp::ClientCapabilities::new()
310                        .fs(acp::FileSystemCapability::new()
311                            .read_text_file(true)
312                            .write_text_file(true))
313                        .terminal(true),
314                )
315                .client_info(acp::Implementation::new(
316                    "claude-code-rust",
317                    env!("CARGO_PKG_VERSION"),
318                )),
319        )
320        .await
321        .map_err(|e| ConnectError::Failed(format!("Handshake failed: {e}")))?;
322    tracing::debug!(
323        "Handshake via {} completed in {:?}",
324        launcher.describe(),
325        handshake_start.elapsed()
326    );
327
328    tracing::info!("Connected to agent: {:?}", init_response);
329
330    // Create or resume session - on AuthRequired, signal back instead of blocking
331    let session_result = if let Some(sid) = resume_id {
332        let session_id = acp::SessionId::new(sid);
333        let load_req = acp::LoadSessionRequest::new(session_id.clone(), cwd);
334        match conn.load_session(load_req).await {
335            Ok(resp) => Ok((session_id, resp.models, resp.modes)),
336            Err(err) if err.code == acp::ErrorCode::AuthRequired => {
337                return Err(auth_required_error(&init_response));
338            }
339            Err(err) => Err(err),
340        }
341    } else {
342        match conn.new_session(acp::NewSessionRequest::new(cwd)).await {
343            Ok(resp) => Ok((resp.session_id, resp.models, resp.modes)),
344            Err(err) if err.code == acp::ErrorCode::AuthRequired => {
345                return Err(auth_required_error(&init_response));
346            }
347            Err(err) => Err(err),
348        }
349    };
350
351    let (session_id, resp_models, resp_modes) = session_result
352        .map_err(|e| ConnectError::Failed(format!("Session creation failed: {e}")))?;
353
354    // Extract model name
355    let mut model_name = resp_models
356        .as_ref()
357        .and_then(|m| {
358            m.available_models
359                .iter()
360                .find(|info| info.model_id == m.current_model_id)
361                .map(|info| info.name.clone())
362        })
363        .unwrap_or_else(|| "Unknown model".to_owned());
364
365    // --model override
366    if let Some(model_str) = model_override {
367        conn.set_session_model(acp::SetSessionModelRequest::new(
368            session_id.clone(),
369            acp::ModelId::new(model_str),
370        ))
371        .await
372        .map_err(|e| ConnectError::Failed(format!("Model switch failed: {e}")))?;
373        model_str.clone_into(&mut model_name);
374    }
375
376    // Extract mode state
377    let mut mode = resp_modes.map(|ms| {
378        let current_id = ms.current_mode_id.to_string();
379        let available: Vec<ModeInfo> = ms
380            .available_modes
381            .iter()
382            .map(|m| ModeInfo { id: m.id.to_string(), name: m.name.clone() })
383            .collect();
384        let current_name = available
385            .iter()
386            .find(|m| m.id == current_id)
387            .map_or_else(|| current_id.clone(), |m| m.name.clone());
388        ModeState {
389            current_mode_id: current_id,
390            current_mode_name: current_name,
391            available_modes: available,
392        }
393    });
394
395    if let Some(ref m) = mode {
396        tracing::info!(
397            "Available modes: {:?}",
398            m.available_modes.iter().map(|m| &m.id).collect::<Vec<_>>()
399        );
400        tracing::info!("Current mode: {}", m.current_mode_id);
401    }
402
403    // --yolo: switch to bypass-permissions mode
404    if yolo && let Some(ref mut ms) = mode {
405        let target_id = "bypassPermissions".to_owned();
406        let mode_id = acp::SessionModeId::new(target_id.as_str());
407        conn.set_session_mode(acp::SetSessionModeRequest::new(session_id.clone(), mode_id))
408            .await
409            .map_err(|e| ConnectError::Failed(format!("Mode switch failed: {e}")))?;
410        tracing::info!("YOLO: switched to mode '{}'", target_id);
411        let target_name = ms
412            .available_modes
413            .iter()
414            .find(|mi| mi.id == target_id)
415            .map_or_else(|| target_id.clone(), |mi| mi.name.clone());
416        ms.current_mode_id = target_id;
417        ms.current_mode_name = target_name;
418    }
419
420    tracing::info!("Session created: {:?}", session_id);
421
422    Ok((conn, child, session_id, model_name, mode))
423}
424
425/// Build a `ConnectError::AuthRequired` from the adapter's init response.
426fn auth_required_error(init_response: &acp::InitializeResponse) -> ConnectError {
427    let method = init_response.auth_methods.first();
428    ConnectError::AuthRequired {
429        method_name: method.map_or_else(|| "unknown".into(), |m| m.name.clone()),
430        method_description: method
431            .and_then(|m| m.description.clone())
432            .unwrap_or_else(|| "Sign in to continue".into()),
433    }
434}