Skip to main content

gemini_cli_sdk/
client.rs

1//! Stateful client for multi-turn Gemini CLI sessions.
2//!
//! The [`Client`] ties together transport, permissions, hooks, and translation
//! into a single consumer-facing type. Call [`connect`](Client::connect) once to
//! establish the session, then [`send`](Client::send) or
//! [`send_content`](Client::send_content) for each user turn.
6//!
7//! # Architecture
8//!
9//! ```text
10//! Client
11//!   ├── AnyTransport  ──► GeminiTransport (production) | MockTransport (testing)
12//!   ├── TranslationContext — wire SessionUpdate → public Message
13//!   ├── HookContext   — lifecycle hooks (UserPromptSubmit, Stop, …)
14//!   └── notification_stream — single mpsc receiver taken from transport
15//! ```
16//!
17//! # Example
18//!
19//! ```rust,no_run
20//! use gemini_cli_sdk::{Client, ClientConfig};
21//!
22//! #[tokio::main]
23//! async fn main() -> gemini_cli_sdk::Result<()> {
24//!     let config = ClientConfig::builder()
25//!         .prompt("Build a REST API")
26//!         .build();
27//!     let mut client = Client::new(config)?;
28//!     let _info = client.connect().await?;
//!     // Start a turn with `client.send("...").await?` and consume the returned
//!     // stream; pin it first (e.g. `tokio::pin!` or `futures::pin_mut!`), as it
//!     // is not guaranteed to be `Unpin`.
31//!     client.close().await?;
32//!     Ok(())
33//! }
34//! ```
35
36use std::pin::Pin;
37use std::sync::atomic::{AtomicBool, Ordering};
38use std::sync::Arc;
39
40use futures_core::Stream;
41use serde_json::Value;
42use tokio::sync::Mutex;
43
44use crate::callback::MessageCallback;
45use crate::config::ClientConfig;
46use crate::hooks::{self, HookContext, HookDecision, HookEvent, HookInput};
47use crate::permissions::{PermissionHandler, ToolInputCache};
48use crate::translate::TranslationContext;
49use crate::transport::{GeminiTransport, Transport};
50use crate::types::content::UserContent;
51use crate::types::messages::{Message, SessionInfo};
52use crate::wire;
53use crate::{Error, Result};
54
55// ── AnyTransport ─────────────────────────────────────────────────────────────
56
57/// Internal enum that wraps either the production [`GeminiTransport`] or the
58/// test [`MockTransport`] so both can expose request/notification helpers.
59///
60/// This sidesteps the object-safety constraint that prevents adding generic
61/// `send_request<P, R>` methods to the [`Transport`] trait directly.
62pub(crate) enum AnyTransport {
63    /// Production subprocess transport.
64    Gemini(Arc<GeminiTransport>),
65    /// In-memory transport for unit tests.
66    #[cfg(feature = "testing")]
67    Mock(Arc<crate::testing::MockTransport>),
68}
69
/// Expand to an `AnyTransport` method that forwards its call, with the same
/// arguments, to whichever transport variant is active.
///
/// Two shapes are accepted: a plain `fn`, and an `async fn` whose delegated
/// call is awaited in place. The `Mock` arm is only emitted when the
/// `testing` feature is enabled.
macro_rules! delegate_transport {
    // Synchronous forwarding.
    (fn $method:ident(&self $(, $param:ident : $param_ty:ty)*) -> $output:ty) => {
        fn $method(&self $(, $param: $param_ty)*) -> $output {
            match self {
                AnyTransport::Gemini(inner) => inner.$method($($param),*),
                #[cfg(feature = "testing")]
                AnyTransport::Mock(inner) => inner.$method($($param),*),
            }
        }
    };
    // Asynchronous forwarding: the inner future is awaited before returning.
    (async fn $method:ident(&self $(, $param:ident : $param_ty:ty)*) -> $output:ty) => {
        async fn $method(&self $(, $param: $param_ty)*) -> $output {
            match self {
                AnyTransport::Gemini(inner) => inner.$method($($param),*).await,
                #[cfg(feature = "testing")]
                AnyTransport::Mock(inner) => inner.$method($($param),*).await,
            }
        }
    };
}
94
95impl AnyTransport {
96    // ── Transport trait delegation ────────────────────────────────────────
97
98    delegate_transport!(async fn connect(&self) -> Result<()>);
99    delegate_transport!(fn read_messages(&self) -> Pin<Box<dyn Stream<Item = Result<Value>> + Send>>);
100    delegate_transport!(async fn interrupt(&self) -> Result<()>);
101    delegate_transport!(async fn close(&self) -> Result<Option<i32>>);
102
103    // ── JSON-RPC helpers ──────────────────────────────────────────────────
104
105    /// Send a typed JSON-RPC request and await the correlated response.
106    ///
107    /// For `Mock` transports the params are serialised and written to the
108    /// captures list; the result is a zero-value default-deserialised response.
109    /// Tests that need specific response values should pre-load them via
110    /// `ScenarioBuilder` or `MockTransport::push_message`.
111    async fn send_request<P, R>(&self, method: &str, params: P) -> Result<R>
112    where
113        P: serde::Serialize + Send,
114        R: serde::de::DeserializeOwned,
115    {
116        match self {
117            AnyTransport::Gemini(t) => t.send_request(method, params).await,
118            #[cfg(feature = "testing")]
119            AnyTransport::Mock(t) => {
120                // Capture the outbound request for assertion in tests.
121                let req = serde_json::json!({
122                    "jsonrpc": "2.0",
123                    "method": method,
124                    "params": serde_json::to_value(params)?,
125                    "id": 1
126                });
127                t.write(&serde_json::to_string(&req)?).await?;
128                // Return a default-deserialized result. Tests using the mock
129                // transport should override this via scenario injection when
130                // the actual response fields matter.
131                serde_json::from_value(Value::Object(Default::default())).map_err(Error::Json)
132            }
133        }
134    }
135
136    /// Send a JSON-RPC request and return the response receiver without
137    /// blocking. The caller must await the receiver concurrently.
138    async fn send_request_start<P>(
139        &self,
140        method: &str,
141        params: P,
142    ) -> Result<tokio::sync::oneshot::Receiver<crate::jsonrpc::JsonRpcResponse>>
143    where
144        P: serde::Serialize + Send,
145    {
146        match self {
147            AnyTransport::Gemini(t) => t.send_request_start(method, params).await,
148            #[cfg(feature = "testing")]
149            AnyTransport::Mock(t) => {
150                // Capture the request for test assertions.
151                let req = serde_json::json!({
152                    "jsonrpc": "2.0",
153                    "method": method,
154                    "params": serde_json::to_value(params)?,
155                    "id": 1
156                });
157                t.write(&serde_json::to_string(&req)?).await?;
158                // Return a pre-resolved receiver with a default prompt result.
159                let (tx, rx) = tokio::sync::oneshot::channel();
160                let _ = tx.send(crate::jsonrpc::JsonRpcResponse::success(
161                    crate::jsonrpc::JsonRpcId::Number(0),
162                    serde_json::json!({"stopReason": "end_turn"}),
163                ));
164                Ok(rx)
165            }
166        }
167    }
168
169    /// Send a JSON-RPC notification (fire-and-forget, no response expected).
170    async fn send_notification<P>(&self, method: &str, params: P) -> Result<()>
171    where
172        P: serde::Serialize + Send,
173    {
174        match self {
175            AnyTransport::Gemini(t) => t.send_notification(method, params).await,
176            #[cfg(feature = "testing")]
177            AnyTransport::Mock(t) => {
178                let notif = serde_json::json!({
179                    "jsonrpc": "2.0",
180                    "method": method,
181                    "params": serde_json::to_value(params)?
182                });
183                t.write(&serde_json::to_string(&notif)?).await
184            }
185        }
186    }
187
188    /// Register the reverse-request handler on the underlying transport.
189    ///
190    /// Only meaningful for the `Gemini` variant; `Mock` silently ignores it
191    /// because `MockTransport` has no background reader for reverse requests.
192    async fn set_reverse_handler(
193        &self,
194        handler: Arc<dyn crate::transport::ReverseRequestHandler>,
195    ) {
196        match self {
197            AnyTransport::Gemini(t) => t.set_reverse_handler(handler).await,
198            #[cfg(feature = "testing")]
199            AnyTransport::Mock(_) => {} // no-op — mock has no background reader
200        }
201    }
202}
203
204// ── Client ───────────────────────────────────────────────────────────────────
205
206/// Stateful client for multi-turn Gemini CLI sessions.
207///
208/// Each `Client` instance manages a single Gemini CLI subprocess and JSON-RPC
209/// session. Calls to [`send`] / [`send_content`] return streams of [`Message`]
210/// values translated from raw `session/update` notifications.
211///
212/// # Lifecycle
213///
214/// 1. Construct with [`Client::new`].
215/// 2. Call [`connect`] — this spawns the subprocess and runs the handshake.
216/// 3. Call [`send`] one or more times to converse with the model.
217/// 4. Call [`close`] when finished. Dropping without closing is safe but may
218///    leave the subprocess running briefly until the OS reclaims it.
219///
220/// # Threading
221///
222/// `Client` is `Send` but not `Sync`. Share across tasks by wrapping in
223/// `Arc<Mutex<Client>>` when concurrent access is required.
224///
225/// [`connect`]: Client::connect
226/// [`send`]: Client::send
227/// [`close`]: Client::close
228pub struct Client {
229    /// Full session configuration — fields are referenced during the connect
230    /// handshake and at the start of each prompt turn.
231    config: ClientConfig,
232    /// Concrete transport implementation, wrapped in an enum for testability.
233    transport: AnyTransport,
234    /// Session ID assigned by the server after `session/new` or `session/load`.
235    session_id: Option<String>,
236    /// Stored notification stream — taken exactly once in `connect()` via
237    /// `Transport::read_messages()`. Held behind a `Mutex` so the async-stream
238    /// closure inside `send_content()` can lock it across yield points.
239    #[allow(clippy::type_complexity)]
240    notification_stream: Mutex<Option<Pin<Box<dyn Stream<Item = Result<Value>> + Send>>>>,
241    /// Accumulated per-turn translation state (text buffer, tool calls, …).
242    translation_ctx: Mutex<Option<TranslationContext>>,
243    /// Immutable context stamped onto every hook invocation.
244    hook_context: Option<HookContext>,
245    /// `true` after a successful `connect()`.
246    connected: bool,
247    /// Guard that prevents concurrent `send_content` calls from silently
248    /// hanging on the `notification_stream` Mutex. Set to `true` at the
249    /// start of `send_content`, reset to `false` when the stream completes
250    /// or in `close()`.
251    turn_in_progress: Arc<AtomicBool>,
252}
253
/// Drop guard that clears the shared `turn_in_progress` flag.
///
/// Whoever holds a `TurnGuard` owns the current turn. When the guard goes
/// away — normal completion, an early `return`, or the owning stream being
/// dropped mid-turn — `false` is stored back into the flag with `Release`
/// ordering.
struct TurnGuard(Arc<AtomicBool>);

impl Drop for TurnGuard {
    fn drop(&mut self) {
        let Self(flag) = self;
        flag.store(false, Ordering::Release);
    }
}
263
264impl Client {
265    // ── Constructors ─────────────────────────────────────────────────────────
266
267    /// Build a `Client` from a pre-constructed transport. All public
268    /// constructors delegate to this.
269    fn from_transport(config: ClientConfig, transport: AnyTransport) -> Self {
270        Self {
271            config,
272            transport,
273            session_id: None,
274            notification_stream: Mutex::new(None),
275            translation_ctx: Mutex::new(None),
276            hook_context: None,
277            connected: false,
278            turn_in_progress: Arc::new(AtomicBool::new(false)),
279        }
280    }
281
282    /// Create a new client with the given configuration.
283    ///
284    /// Resolves the `gemini` binary path (via `config.cli_path` or `PATH`)
285    /// and constructs a [`GeminiTransport`]. The subprocess is not spawned
286    /// until [`connect`] is called.
287    ///
288    /// # Errors
289    ///
290    /// Returns [`Error::CliNotFound`] when the binary cannot be located on
291    /// `PATH` and `config.cli_path` is `None`.
292    ///
293    /// [`connect`]: Client::connect
294    /// [`Error::CliNotFound`]: crate::Error::CliNotFound
295    pub fn new(config: ClientConfig) -> Result<Self> {
296        let transport = Arc::new(GeminiTransport::from_config(&config)?);
297        Ok(Self::from_transport(config, AnyTransport::Gemini(transport)))
298    }
299
300    /// Create a client backed by a caller-supplied [`GeminiTransport`].
301    ///
302    /// Useful when the caller has already constructed the transport with custom
303    /// parameters (e.g. a non-default working directory or extra env vars).
304    pub fn with_gemini_transport(config: ClientConfig, transport: Arc<GeminiTransport>) -> Self {
305        Self::from_transport(config, AnyTransport::Gemini(transport))
306    }
307
308    /// Create a client backed by a [`MockTransport`] for unit testing.
309    ///
310    /// Only available when the `testing` crate feature is enabled. The mock
311    /// transport captures writes and yields pre-loaded messages, making it
312    /// straightforward to test connect / send behaviour without spawning a
313    /// real subprocess.
314    ///
315    /// [`MockTransport`]: crate::testing::MockTransport
316    #[cfg(feature = "testing")]
317    pub fn with_mock_transport(
318        config: ClientConfig,
319        transport: Arc<crate::testing::MockTransport>,
320    ) -> Self {
321        Self::from_transport(config, AnyTransport::Mock(transport))
322    }
323
324    // ── Accessors ─────────────────────────────────────────────────────────────
325
326    /// Return the session ID assigned by the server, or `None` before
327    /// [`connect`] is called.
328    ///
329    /// [`connect`]: Client::connect
330    pub fn session_id(&self) -> Option<&str> {
331        self.session_id.as_deref()
332    }
333
334    /// Return the `prompt` field from the config.
335    ///
336    /// Exposed as a convenience for the free-function wrappers in `lib.rs`
337    /// that need to send the initial prompt without a separate `send()` call.
338    #[inline]
339    pub fn prompt(&self) -> &str {
340        &self.config.prompt
341    }
342
343    /// Return `true` if [`connect`] has been called successfully.
344    ///
345    /// [`connect`]: Client::connect
346    #[inline]
347    pub fn is_connected(&self) -> bool {
348        self.connected
349    }
350
351    // ── connect() ────────────────────────────────────────────────────────────
352
353    /// Connect to the Gemini CLI and establish a session.
354    ///
355    /// Performs the full initialisation sequence in order:
356    ///
357    /// 1. Spawn the subprocess via [`Transport::connect`].
358    /// 2. Take the notification stream (must happen before any requests).
359    /// 3. Register the optional [`PermissionHandler`] as the reverse-request handler.
360    /// 4. Send the `initialize` JSON-RPC request and await the result.
361    /// 5. Send `session/new` (or `session/load` when `config.resume` is set).
362    /// 6. Initialise the [`TranslationContext`] and hook context.
363    ///
364    /// Returns [`SessionInfo`] describing the established session.
365    ///
366    /// # Errors
367    ///
368    /// - [`Error::Config`] — called more than once on the same client.
369    /// - [`Error::SpawnFailed`] — the subprocess could not be started.
370    /// - [`Error::JsonRpcError`] — the server rejected `initialize` or `session/new`.
371    /// - [`Error::NotConnected`] — internal transport error during the handshake.
372    ///
373    /// [`Transport::connect`]: crate::transport::Transport::connect
374    /// [`Error::Config`]: crate::Error::Config
375    /// [`Error::SpawnFailed`]: crate::Error::SpawnFailed
376    /// [`Error::JsonRpcError`]: crate::Error::JsonRpcError
377    /// [`Error::NotConnected`]: crate::Error::NotConnected
378    pub async fn connect(&mut self) -> Result<SessionInfo> {
379        if self.connected {
380            return Err(Error::Config("Already connected".to_string()));
381        }
382        match self.config.connect_timeout {
383            Some(d) => {
384                tokio::time::timeout(d, self.connect_inner())
385                    .await
386                    .map_err(|_| {
387                        Error::Timeout(format!(
388                            "connect timed out after {:.1}s",
389                            d.as_secs_f64()
390                        ))
391                    })?
392            }
393            None => self.connect_inner().await,
394        }
395    }
396
397    async fn connect_inner(&mut self) -> Result<SessionInfo> {
398        // ── Step 1: Spawn subprocess ─────────────────────────────────────────
399        self.transport.connect().await?;
400
401        // ── Step 2: Take the notification stream ─────────────────────────────
402        let stream = self.transport.read_messages();
403        *self.notification_stream.lock().await = Some(stream);
404
405        // ── Step 3: Create shared tool input cache ───────────────────────────
406        let tool_input_cache: ToolInputCache =
407            Arc::new(std::sync::Mutex::new(std::collections::HashMap::new()));
408
409        // ── Step 4: Register permission handler ──────────────────────────────
410        if let Some(callback) = self.config.can_use_tool.clone() {
411            let handler =
412                Arc::new(PermissionHandler::new(callback, Some(Arc::clone(&tool_input_cache))));
413            self.transport.set_reverse_handler(handler).await;
414        }
415
416        // ── Step 5: initialize request ───────────────────────────────────────
417        let init_params = wire::InitializeParams {
418            protocol_version: 1,
419            client_capabilities: wire::ClientCapabilities::default(),
420            client_info: wire::ClientInfo {
421                name: "gemini-cli-sdk".to_string(),
422                version: env!("CARGO_PKG_VERSION").to_string(),
423            },
424        };
425        let init_result: wire::InitializeResult = self
426            .transport
427            .send_request(wire::method::INITIALIZE, init_params)
428            .await?;
429
430        // ── Step 6: Create or resume session ─────────────────────────────────
431        let session_id = if let Some(resume_id) = self.config.resume.clone() {
432            let params = wire::SessionLoadParams {
433                session_id: resume_id,
434                extra: Value::Object(Default::default()),
435            };
436            let result: wire::SessionLoadResult = self
437                .transport
438                .send_request(wire::method::SESSION_LOAD, params)
439                .await?;
440            result.session_id
441        } else {
442            let cwd = self
443                .config
444                .cwd
445                .clone()
446                .map(Ok)
447                .unwrap_or_else(|| {
448                    std::env::current_dir()
449                        .map_err(|e| Error::Config(format!("cannot determine cwd: {e}")))
450                })?
451                .to_string_lossy()
452                .to_string();
453            let mcp_wire = crate::mcp::mcp_servers_to_wire(&self.config.mcp_servers);
454            let params = wire::SessionNewParams {
455                cwd,
456                mcp_servers: mcp_wire,
457                extra: Value::Object(Default::default()),
458            };
459            let result: wire::SessionNewResult = self
460                .transport
461                .send_request(wire::method::SESSION_NEW, params)
462                .await?;
463            result.session_id
464        };
465
466        self.session_id = Some(session_id.clone());
467
468        // ── Step 7: Initialise translation and hook contexts ─────────────────
469        let model = self
470            .config
471            .model
472            .clone()
473            .unwrap_or_else(|| "gemini-2.5-pro".to_string());
474
475        *self.translation_ctx.lock().await =
476            Some(TranslationContext::new_with_cache(session_id.clone(), model.clone(), tool_input_cache));
477
478        let cwd_str = self
479            .config
480            .cwd
481            .clone()
482            .map(Ok)
483            .unwrap_or_else(|| {
484                std::env::current_dir()
485                    .map_err(|e| Error::Config(format!("cannot determine cwd: {e}")))
486            })?
487            .to_string_lossy()
488            .to_string();
489
490        self.hook_context = Some(HookContext {
491            session_id: session_id.clone(),
492            cwd: cwd_str,
493        });
494
495        self.connected = true;
496
497        let tools = init_result.agent_capabilities.tools.unwrap_or_default();
498        Ok(SessionInfo {
499            session_id,
500            model,
501            tools,
502            extra: init_result.extra,
503        })
504    }
505
506    // ── send() / send_content() ───────────────────────────────────────────────
507
508    /// Send a plain-text prompt and return a stream of translated [`Message`]
509    /// values.
510    ///
511    /// Convenience wrapper around [`send_content`] that constructs a single
512    /// `UserContent::Text` block from the provided string slice.
513    ///
514    /// # Errors
515    ///
516    /// Returns [`Error::NotConnected`] if [`connect`] has not been called.
517    /// Returns [`Error::Config`] if a `UserPromptSubmit` hook blocks the send.
518    ///
519    /// [`send_content`]: Client::send_content
520    /// [`connect`]: Client::connect
521    /// [`Error::NotConnected`]: crate::Error::NotConnected
522    /// [`Error::Config`]: crate::Error::Config
523    pub async fn send(
524        &self,
525        message: &str,
526    ) -> Result<impl Stream<Item = Result<Message>> + '_> {
527        self.send_content(vec![UserContent::text(message)]).await
528    }
529
530    /// Send structured content and return a stream of translated [`Message`]
531    /// values.
532    ///
533    /// Accepts any mix of [`UserContent`] variants (text, base-64 image, URL
534    /// image). The content is serialised to `WireContentBlock` format and sent
535    /// to the CLI as a `session/prompt` JSON-RPC notification. The returned
536    /// stream drains `session/update` notifications from the shared channel
537    /// until the transport closes or the caller drops the stream.
538    ///
539    /// # Backpressure
540    ///
541    /// The notification channel is shared across all `send_content` calls on
542    /// the same client. Only one stream should be active at a time; concurrent
543    /// polling will contend on the `notification_stream` Mutex and may
544    /// interleave results.
545    ///
546    /// # Errors
547    ///
548    /// Returns [`Error::NotConnected`] if [`connect`] has not been called.
549    /// Returns [`Error::Config`] if a `UserPromptSubmit` hook blocks the send.
550    ///
551    /// [`connect`]: Client::connect
552    /// [`Error::NotConnected`]: crate::Error::NotConnected
553    /// [`Error::Config`]: crate::Error::Config
554    pub async fn send_content(
555        &self,
556        content: Vec<UserContent>,
557    ) -> Result<impl Stream<Item = Result<Message>> + '_> {
558        if !self.connected {
559            return Err(Error::NotConnected);
560        }
561
562        // Prevent concurrent turns: a second call would block indefinitely on
563        // the notification_stream Mutex. Fail fast with a descriptive error.
564        if self
565            .turn_in_progress
566            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
567            .is_err()
568        {
569            return Err(Error::TurnInProgress);
570        }
571        let turn_guard = TurnGuard(Arc::clone(&self.turn_in_progress));
572
573        let session_id = self
574            .session_id
575            .as_ref()
576            .ok_or(Error::NotConnected)?
577            .clone();
578
579        // ── Fire UserPromptSubmit hook ────────────────────────────────────────
580        if let Some(ctx) = &self.hook_context {
581            let prompt_text = content.iter().find_map(|c| match c {
582                UserContent::Text { text } => Some(text.clone()),
583                _ => None,
584            });
585            let hook_input = HookInput {
586                event: HookEvent::UserPromptSubmit,
587                tool_name: None,
588                tool_input: None,
589                tool_output: None,
590                prompt: prompt_text,
591                session_id: session_id.clone(),
592                extra: Value::Object(Default::default()),
593            };
594            let output = hooks::execute_hooks(
595                &self.config.hooks,
596                hook_input,
597                ctx,
598                self.config.default_hook_timeout,
599            )
600            .await;
601            if output.decision == HookDecision::Block {
602                return Err(Error::Config(
603                    output
604                        .message
605                        .unwrap_or_else(|| "Blocked by hook".to_string()),
606                ));
607            }
608        }
609
610        // ── Reset translation context for the new turn ────────────────────────
611        {
612            let mut ctx_guard = self.translation_ctx.lock().await;
613            if let Some(ctx) = ctx_guard.as_mut() {
614                ctx.reset_turn();
615            }
616        }
617
618        // ── Convert content to wire format ────────────────────────────────────
619        let wire_content: Vec<wire::WireContentBlock> = content
620            .iter()
621            .map(crate::translate::user_content_to_wire)
622            .collect();
623
624        // ── Send session/prompt as a request ──────────────────────────────────
625        //
626        // `session/prompt` is sent as a JSON-RPC request (with `id`) so that the
627        // Gemini CLI sends back a correlated response containing `stopReason`
628        // when the turn completes. The notification stream delivers `session/update`
629        // events (text deltas, tool calls, etc.) concurrently, while the request
630        // response signals the turn boundary.
631        let prompt_params = wire::SessionPromptParams {
632            session_id: session_id.clone(),
633            prompt: wire_content,
634            extra: Value::Object(Default::default()),
635        };
636        let prompt_response_rx = self
637            .transport
638            .send_request_start(wire::method::SESSION_PROMPT, prompt_params)
639            .await?;
640
641        // ── Return notification-draining stream ───────────────────────────────
642        //
643        // Borrow `self` for the stream's lifetime. The Mutex is locked inside
644        // the generator so it can be released between yields, avoiding a held
645        // lock across await points in caller code.
646        let translation_ctx = &self.translation_ctx;
647        let notification_stream = &self.notification_stream;
648        let callback: Option<MessageCallback> = self.config.message_callback.clone();
649
650        Ok(async_stream::stream! {
651            // Move the guard into the stream so it is dropped (clearing the flag)
652            // when the stream is dropped or completes, regardless of the exit path.
653            let _turn_guard = turn_guard;
654            use tokio_stream::StreamExt as _;
655
656            // Lock the notification stream for the duration of this turn.
657            // A second concurrent call will block here until this stream is
658            // dropped.
659            let mut ns_guard = notification_stream.lock().await;
660            let stream = match ns_guard.as_mut() {
661                Some(s) => s,
662                None => {
663                    yield Err(Error::NotConnected);
664                    return;
665                }
666            };
667
668            // Fuse the prompt response oneshot so we can poll it alongside
669            // the notification stream without consuming it on the first poll.
670            let mut prompt_done = prompt_response_rx;
671            let mut turn_finished = false;
672
673            #[allow(unused_assignments)] // `turn_finished = true` precedes a `break`
674            loop {
675                tokio::select! {
676                    biased;
677
678                    // Poll notifications first — drain all pending updates before
679                    // checking if the turn is complete.  With `biased`, this branch
680                    // is checked before prompt_done, ensuring mock transports (which
681                    // resolve the oneshot immediately) still deliver notifications.
682                    maybe_notif = stream.next() => {
683                        match maybe_notif {
684                            None => break, // channel closed — subprocess exited
685                            Some(Err(e)) => {
686                                yield Err(e);
687                                break;
688                            }
689                            Some(Ok(value)) => {
690                                // Filter: only process session/update notifications.
691                                let method = value
692                                    .get("method")
693                                    .and_then(|m| m.as_str())
694                                    .unwrap_or("");
695                                tracing::debug!(method, "notification received");
696                                if method != wire::method::SESSION_UPDATE {
697                                    continue;
698                                }
699
700                                let params = match value.get("params") {
701                                    Some(p) => p.clone(),
702                                    None => continue,
703                                };
704
705                                let notif: wire::SessionUpdateNotification =
706                                    match serde_json::from_value(params) {
707                                        Ok(n) => n,
708                                        Err(e) => {
709                                            tracing::warn!(
710                                                error = %e,
711                                                "client: failed to parse session/update params — skipping"
712                                            );
713                                            continue;
714                                        }
715                                    };
716
717                                // Parse the discriminator and translate to public messages.
718                                let update = notif.parse();
719                                tracing::debug!(?update, "parsed session update");
720
721                                let mut ctx_guard = translation_ctx.lock().await;
722                                if let Some(ctx) = ctx_guard.as_mut() {
723                                    let messages = ctx.translate(update);
724                                    tracing::debug!(count = messages.len(), "translated to messages");
725                                    // Drop the ctx lock before yielding so callers that
726                                    // inspect TranslationContext are not blocked.
727                                    drop(ctx_guard);
728
729                                    for msg in messages {
730                                        // Invoke optional side-effect callback.
731                                        if let Some(cb) = &callback {
732                                            cb(msg.clone()).await;
733                                        }
734                                        yield Ok(msg);
735                                    }
736                                }
737                            }
738                        }
739                    }
740
741                    // Poll prompt response — when it arrives, the turn is done.
742                    resp = &mut prompt_done, if !turn_finished => {
743                        turn_finished = true;
744
745                        // Translate the prompt result into a Message::Result.
746                        match resp {
747                            Ok(response) => {
748                                match response.into_result() {
749                                    Ok(result_value) => {
750                                        let prompt_result: wire::SessionPromptResult =
751                                            serde_json::from_value(result_value)
752                                                .unwrap_or_else(|e| {
753                                                    tracing::warn!(
754                                                        error = %e,
755                                                        "failed to parse SessionPromptResult, using default"
756                                                    );
757                                                    Default::default()
758                                                });
759
760                                        let result_msg = Message::Result(crate::types::messages::ResultMessage {
761                                            subtype: "success".to_string(),
762                                            is_error: false,
763                                            duration_ms: 0.0,
764                                            duration_api_ms: 0.0,
765                                            num_turns: 1,
766                                            session_id: session_id.clone(),
767                                            usage: crate::types::messages::Usage::default(),
768                                            stop_reason: prompt_result.stop_reason,
769                                            extra: prompt_result.extra,
770                                        });
771
772                                        if let Some(cb) = &callback {
773                                            cb(result_msg.clone()).await;
774                                        }
775                                        yield Ok(result_msg);
776                                    }
777                                    Err(err) => {
778                                        let error_msg = Message::Result(crate::types::messages::ResultMessage {
779                                            subtype: "error".to_string(),
780                                            is_error: true,
781                                            duration_ms: 0.0,
782                                            duration_api_ms: 0.0,
783                                            num_turns: 1,
784                                            session_id: session_id.clone(),
785                                            usage: crate::types::messages::Usage::default(),
786                                            stop_reason: format!(
787                                                "JSON-RPC error {}: {}",
788                                                err.code, err.message
789                                            ),
790                                            extra: serde_json::json!({
791                                                "code": err.code,
792                                                "message": err.message,
793                                                "data": err.data,
794                                            }),
795                                        });
796
797                                        if let Some(cb) = &callback {
798                                            cb(error_msg.clone()).await;
799                                        }
800                                        yield Ok(error_msg);
801                                    }
802                                }
803                            }
804                            Err(_) => {
805                                // Response channel dropped — treat as transport error.
806                                yield Err(Error::Transport(
807                                    "Prompt response channel closed unexpectedly".to_string()
808                                ));
809                            }
810                        }
811
812                        // Reset the translation context for the next turn.
813                        let mut ctx_guard = translation_ctx.lock().await;
814                        if let Some(ctx) = ctx_guard.as_mut() {
815                            ctx.reset_turn();
816                        }
817
818                        break;
819                    }
820                }
821            }
822
823            // Turn complete — _turn_guard is dropped here, clearing the flag.
824        })
825    }
826
827    // ── interrupt() ──────────────────────────────────────────────────────────
828
829    /// Interrupt the current in-progress prompt turn.
830    ///
831    /// Sends a `session/cancel` notification to the CLI (best-effort) and then
832    /// delivers a process-level interrupt signal via the transport (SIGINT on
833    /// Unix, CTRL_C_EVENT on Windows).
834    ///
835    /// # Errors
836    ///
837    /// The `session/cancel` notification errors are silently ignored. Transport
838    /// interrupt errors are propagated.
839    pub async fn interrupt(&self) -> Result<()> {
840        if let Some(session_id) = &self.session_id {
841            let params = wire::SessionCancelParams {
842                session_id: session_id.clone(),
843            };
844            // Best-effort: a cancel notification failure must not prevent the
845            // process-level interrupt below from being delivered.
846            let _ = self
847                .transport
848                .send_notification(wire::method::SESSION_CANCEL, params)
849                .await;
850        }
851        self.transport.interrupt().await
852    }
853
854    // ── close() ──────────────────────────────────────────────────────────────
855
856    /// Close the client and terminate the CLI subprocess.
857    ///
858    /// Fires the `Stop` lifecycle hook before closing the transport. The
859    /// call is idempotent — invoking it on an already-closed client is safe.
860    ///
861    /// # Errors
862    ///
863    /// Propagates errors from the transport's `close()` implementation. Hook
864    /// errors are silently ignored.
865    pub async fn close(&mut self) -> Result<()> {
866        // ── Fire Stop hook ────────────────────────────────────────────────────
867        if let Some(ctx) = &self.hook_context {
868            let hook_input = HookInput {
869                event: HookEvent::Stop,
870                tool_name: None,
871                tool_input: None,
872                tool_output: None,
873                prompt: None,
874                session_id: self.session_id.clone().unwrap_or_default(),
875                extra: Value::Object(Default::default()),
876            };
877            // Ignore errors — we are shutting down regardless.
878            let _ = hooks::execute_hooks(
879                &self.config.hooks,
880                hook_input,
881                ctx,
882                self.config.default_hook_timeout,
883            )
884            .await;
885        }
886
887        self.connected = false;
888        self.turn_in_progress.store(false, Ordering::Release);
889        self.transport.close().await?;
890        Ok(())
891    }
892}
893
894// ── Tests ─────────────────────────────────────────────────────────────────────
895
#[cfg(test)]
mod tests {
    use super::*;

    // ── Helpers ───────────────────────────────────────────────────────────────

    /// Build a `GeminiTransport` that points to a nonexistent binary, so tests
    /// never accidentally spawn a real subprocess.
    ///
    /// The working directory is the platform temp dir rather than a hard-coded
    /// `/tmp`, so the helper is also meaningful on non-Unix hosts.
    fn make_fake_transport() -> Arc<GeminiTransport> {
        Arc::new(GeminiTransport::new(
            std::path::PathBuf::from("/nonexistent/gemini"),
            vec!["--experimental-acp".to_string()],
            std::env::temp_dir(),
            std::collections::HashMap::new(),
            None,
            None,
        ))
    }

    /// Smallest valid configuration: only a prompt is set.
    fn minimal_config() -> ClientConfig {
        ClientConfig::builder().prompt("test prompt").build()
    }

    // ── test_client_session_id ────────────────────────────────────────────────

    /// Before `connect()`, `session_id()` must return `None`.
    #[test]
    fn test_client_session_id() {
        let client = Client::with_gemini_transport(minimal_config(), make_fake_transport());
        assert!(
            client.session_id().is_none(),
            "session_id must be None before connect() is called"
        );
    }

    // ── test_client_not_connected_error ───────────────────────────────────────

    /// Calling `send()` before `connect()` must return `Err(Error::NotConnected)`.
    #[tokio::test]
    async fn test_client_not_connected_error() {
        let client = Client::with_gemini_transport(minimal_config(), make_fake_transport());
        // `.err().expect(..)` rather than `expect_err`: the Ok type is a stream
        // that need not implement `Debug`.
        let err = client
            .send("hello")
            .await
            .err()
            .expect("send() before connect() must fail");
        assert!(
            matches!(err, Error::NotConnected),
            "error must be Error::NotConnected, got: {err:?}"
        );
    }

    // ── test_client_send_content_not_connected ────────────────────────────────

    /// `send_content()` before connect must also return `Error::NotConnected`.
    #[tokio::test]
    async fn test_client_send_content_not_connected() {
        let client = Client::with_gemini_transport(minimal_config(), make_fake_transport());
        let err = client
            .send_content(vec![UserContent::text("hi")])
            .await
            .err()
            .expect("send_content() before connect() must fail");
        assert!(
            matches!(err, Error::NotConnected),
            "send_content before connect must return Error::NotConnected, got: {err:?}"
        );
    }

    // ── test_client_prompt_accessor ───────────────────────────────────────────

    /// `prompt()` must reflect the value set in the config.
    #[test]
    fn test_client_prompt_accessor() {
        let config = ClientConfig::builder().prompt("my test prompt").build();
        let client = Client::with_gemini_transport(config, make_fake_transport());
        assert_eq!(client.prompt(), "my test prompt");
    }

    // ── test_client_is_connected_default ─────────────────────────────────────

    /// A freshly constructed client must not report itself as connected.
    #[test]
    fn test_client_is_connected_default() {
        let client = Client::with_gemini_transport(minimal_config(), make_fake_transport());
        assert!(
            !client.is_connected(),
            "is_connected must be false before connect()"
        );
    }

    // ── test_client_double_connect_error ─────────────────────────────────────

    /// Calling `connect()` after a successful connection must return
    /// `Err(Error::Config)`. The test simulates the connected state by setting
    /// the field directly, so no real subprocess is involved.
    #[tokio::test]
    async fn test_client_double_connect_error() {
        let mut client =
            Client::with_gemini_transport(minimal_config(), make_fake_transport());
        // Bypass the real connect sequence by directly marking as connected.
        client.connected = true;
        let err = client
            .connect()
            .await
            .expect_err("second connect must fail");
        assert!(
            matches!(err, Error::Config(_)),
            "second connect must return Error::Config, got: {err:?}"
        );
    }

    // ── test_client_interrupt_before_connect ──────────────────────────────────

    /// `interrupt()` before connect must not panic and must return `Ok(())`.
    /// When there is no subprocess, the platform-level signal is a no-op.
    #[tokio::test]
    async fn test_client_interrupt_before_connect() {
        let client = Client::with_gemini_transport(minimal_config(), make_fake_transport());
        // No session_id, no subprocess — must not panic.
        let result = client.interrupt().await;
        assert!(
            result.is_ok(),
            "interrupt before connect must not return an error"
        );
    }

    // ── test_client_mock_transport_constructor ────────────────────────────────

    /// A client built on the mock transport starts disconnected, with no session.
    #[cfg(feature = "testing")]
    #[test]
    fn test_client_mock_transport_constructor() {
        use crate::testing::MockTransport;
        let transport = Arc::new(MockTransport::new(vec![]));
        let client = Client::with_mock_transport(minimal_config(), transport);
        assert!(!client.is_connected());
        assert!(client.session_id().is_none());
    }
}