gemini_cli_sdk/client.rs
//! Stateful client for multi-turn Gemini CLI sessions.
//!
//! The [`Client`] ties together transport, permissions, hooks, and translation
//! into a single consumer-facing type. Call [`connect`] once to establish the
//! session, then [`send`] or [`send_content`] for each user turn.
//!
//! # Architecture
//!
//! ```text
//! Client
//! ├── AnyTransport ──► GeminiTransport (production) | MockTransport (testing)
//! ├── TranslationContext — wire SessionUpdate → public Message
//! ├── HookContext — lifecycle hooks (UserPromptSubmit, Stop, …)
//! └── notification_stream — notification stream taken once from the transport
//! ```
//!
//! # Example
//!
//! ```rust,no_run
//! use gemini_cli_sdk::{Client, ClientConfig};
//!
//! #[tokio::main]
//! async fn main() -> gemini_cli_sdk::Result<()> {
//!     let config = ClientConfig::builder()
//!         .prompt("Build a REST API")
//!         .build();
//!     let mut client = Client::new(config)?;
//!     let _info = client.connect().await?;
//!     // Call `client.send(...)` here to start a turn. The returned
//!     // `impl Stream` is not guaranteed to be `Unpin`, so pin it first
//!     // (e.g. `tokio::pin!` or `futures::pin_mut!`) before polling it.
//!     client.close().await?;
//!     Ok(())
//! }
//! ```
//!
//! [`connect`]: Client::connect
//! [`send`]: Client::send
//! [`send_content`]: Client::send_content
35
36use std::pin::Pin;
37use std::sync::atomic::{AtomicBool, Ordering};
38use std::sync::Arc;
39
40use futures_core::Stream;
41use serde_json::Value;
42use tokio::sync::Mutex;
43
44use crate::callback::MessageCallback;
45use crate::config::ClientConfig;
46use crate::hooks::{self, HookContext, HookDecision, HookEvent, HookInput};
47use crate::permissions::{PermissionHandler, ToolInputCache};
48use crate::translate::TranslationContext;
49use crate::transport::{GeminiTransport, Transport};
50use crate::types::content::UserContent;
51use crate::types::messages::{Message, SessionInfo};
52use crate::wire;
53use crate::{Error, Result};
54
55// ── AnyTransport ─────────────────────────────────────────────────────────────
56
57/// Internal enum that wraps either the production [`GeminiTransport`] or the
58/// test [`MockTransport`] so both can expose request/notification helpers.
59///
60/// This sidesteps the object-safety constraint that prevents adding generic
61/// `send_request<P, R>` methods to the [`Transport`] trait directly.
62pub(crate) enum AnyTransport {
63 /// Production subprocess transport.
64 Gemini(Arc<GeminiTransport>),
65 /// In-memory transport for unit tests.
66 #[cfg(feature = "testing")]
67 Mock(Arc<crate::testing::MockTransport>),
68}
69
/// Generate an `AnyTransport` method that forwards its arguments unchanged to
/// whichever transport variant is active.
///
/// Two arms exist: a plain `fn` arm and an `async fn` arm (whose forwarded
/// call is awaited). Only `&self` receivers with simple `ident: Type`
/// parameters are supported.
macro_rules! delegate_transport {
    // Synchronous method: forward and return the value directly.
    (fn $method:ident(&self $(, $param:ident : $param_ty:ty)*) -> $out:ty) => {
        fn $method(&self $(, $param: $param_ty)*) -> $out {
            match self {
                Self::Gemini(inner) => inner.$method($($param),*),
                #[cfg(feature = "testing")]
                Self::Mock(inner) => inner.$method($($param),*),
            }
        }
    };
    // Asynchronous method: forward and await the delegated future.
    (async fn $method:ident(&self $(, $param:ident : $param_ty:ty)*) -> $out:ty) => {
        async fn $method(&self $(, $param: $param_ty)*) -> $out {
            match self {
                Self::Gemini(inner) => inner.$method($($param),*).await,
                #[cfg(feature = "testing")]
                Self::Mock(inner) => inner.$method($($param),*).await,
            }
        }
    };
}
94
95impl AnyTransport {
96 // ── Transport trait delegation ────────────────────────────────────────
97
98 delegate_transport!(async fn connect(&self) -> Result<()>);
99 delegate_transport!(fn read_messages(&self) -> Pin<Box<dyn Stream<Item = Result<Value>> + Send>>);
100 delegate_transport!(async fn interrupt(&self) -> Result<()>);
101 delegate_transport!(async fn close(&self) -> Result<Option<i32>>);
102
103 // ── JSON-RPC helpers ──────────────────────────────────────────────────
104
105 /// Send a typed JSON-RPC request and await the correlated response.
106 ///
107 /// For `Mock` transports the params are serialised and written to the
108 /// captures list; the result is a zero-value default-deserialised response.
109 /// Tests that need specific response values should pre-load them via
110 /// `ScenarioBuilder` or `MockTransport::push_message`.
111 async fn send_request<P, R>(&self, method: &str, params: P) -> Result<R>
112 where
113 P: serde::Serialize + Send,
114 R: serde::de::DeserializeOwned,
115 {
116 match self {
117 AnyTransport::Gemini(t) => t.send_request(method, params).await,
118 #[cfg(feature = "testing")]
119 AnyTransport::Mock(t) => {
120 // Capture the outbound request for assertion in tests.
121 let req = serde_json::json!({
122 "jsonrpc": "2.0",
123 "method": method,
124 "params": serde_json::to_value(params)?,
125 "id": 1
126 });
127 t.write(&serde_json::to_string(&req)?).await?;
128 // Return a default-deserialized result. Tests using the mock
129 // transport should override this via scenario injection when
130 // the actual response fields matter.
131 serde_json::from_value(Value::Object(Default::default())).map_err(Error::Json)
132 }
133 }
134 }
135
136 /// Send a JSON-RPC request and return the response receiver without
137 /// blocking. The caller must await the receiver concurrently.
138 async fn send_request_start<P>(
139 &self,
140 method: &str,
141 params: P,
142 ) -> Result<tokio::sync::oneshot::Receiver<crate::jsonrpc::JsonRpcResponse>>
143 where
144 P: serde::Serialize + Send,
145 {
146 match self {
147 AnyTransport::Gemini(t) => t.send_request_start(method, params).await,
148 #[cfg(feature = "testing")]
149 AnyTransport::Mock(t) => {
150 // Capture the request for test assertions.
151 let req = serde_json::json!({
152 "jsonrpc": "2.0",
153 "method": method,
154 "params": serde_json::to_value(params)?,
155 "id": 1
156 });
157 t.write(&serde_json::to_string(&req)?).await?;
158 // Return a pre-resolved receiver with a default prompt result.
159 let (tx, rx) = tokio::sync::oneshot::channel();
160 let _ = tx.send(crate::jsonrpc::JsonRpcResponse::success(
161 crate::jsonrpc::JsonRpcId::Number(0),
162 serde_json::json!({"stopReason": "end_turn"}),
163 ));
164 Ok(rx)
165 }
166 }
167 }
168
169 /// Send a JSON-RPC notification (fire-and-forget, no response expected).
170 async fn send_notification<P>(&self, method: &str, params: P) -> Result<()>
171 where
172 P: serde::Serialize + Send,
173 {
174 match self {
175 AnyTransport::Gemini(t) => t.send_notification(method, params).await,
176 #[cfg(feature = "testing")]
177 AnyTransport::Mock(t) => {
178 let notif = serde_json::json!({
179 "jsonrpc": "2.0",
180 "method": method,
181 "params": serde_json::to_value(params)?
182 });
183 t.write(&serde_json::to_string(¬if)?).await
184 }
185 }
186 }
187
188 /// Register the reverse-request handler on the underlying transport.
189 ///
190 /// Only meaningful for the `Gemini` variant; `Mock` silently ignores it
191 /// because `MockTransport` has no background reader for reverse requests.
192 async fn set_reverse_handler(
193 &self,
194 handler: Arc<dyn crate::transport::ReverseRequestHandler>,
195 ) {
196 match self {
197 AnyTransport::Gemini(t) => t.set_reverse_handler(handler).await,
198 #[cfg(feature = "testing")]
199 AnyTransport::Mock(_) => {} // no-op — mock has no background reader
200 }
201 }
202}
203
204// ── Client ───────────────────────────────────────────────────────────────────
205
206/// Stateful client for multi-turn Gemini CLI sessions.
207///
208/// Each `Client` instance manages a single Gemini CLI subprocess and JSON-RPC
209/// session. Calls to [`send`] / [`send_content`] return streams of [`Message`]
210/// values translated from raw `session/update` notifications.
211///
212/// # Lifecycle
213///
214/// 1. Construct with [`Client::new`].
215/// 2. Call [`connect`] — this spawns the subprocess and runs the handshake.
216/// 3. Call [`send`] one or more times to converse with the model.
217/// 4. Call [`close`] when finished. Dropping without closing is safe but may
218/// leave the subprocess running briefly until the OS reclaims it.
219///
220/// # Threading
221///
222/// `Client` is `Send` but not `Sync`. Share across tasks by wrapping in
223/// `Arc<Mutex<Client>>` when concurrent access is required.
224///
225/// [`connect`]: Client::connect
226/// [`send`]: Client::send
227/// [`close`]: Client::close
228pub struct Client {
229 /// Full session configuration — fields are referenced during the connect
230 /// handshake and at the start of each prompt turn.
231 config: ClientConfig,
232 /// Concrete transport implementation, wrapped in an enum for testability.
233 transport: AnyTransport,
234 /// Session ID assigned by the server after `session/new` or `session/load`.
235 session_id: Option<String>,
236 /// Stored notification stream — taken exactly once in `connect()` via
237 /// `Transport::read_messages()`. Held behind a `Mutex` so the async-stream
238 /// closure inside `send_content()` can lock it across yield points.
239 #[allow(clippy::type_complexity)]
240 notification_stream: Mutex<Option<Pin<Box<dyn Stream<Item = Result<Value>> + Send>>>>,
241 /// Accumulated per-turn translation state (text buffer, tool calls, …).
242 translation_ctx: Mutex<Option<TranslationContext>>,
243 /// Immutable context stamped onto every hook invocation.
244 hook_context: Option<HookContext>,
245 /// `true` after a successful `connect()`.
246 connected: bool,
247 /// Guard that prevents concurrent `send_content` calls from silently
248 /// hanging on the `notification_stream` Mutex. Set to `true` at the
249 /// start of `send_content`, reset to `false` when the stream completes
250 /// or in `close()`.
251 turn_in_progress: Arc<AtomicBool>,
252}
253
/// Drop guard holding a clone of the client's `turn_in_progress` flag.
///
/// Whenever the guard goes out of scope — normal completion, an early return
/// via `?`, or the caller dropping the stream that owns it — the flag is
/// stored back to `false` so the next turn may start.
struct TurnGuard(Arc<AtomicBool>);

impl Drop for TurnGuard {
    fn drop(&mut self) {
        let Self(flag) = self;
        flag.store(false, Ordering::Release);
    }
}
263
264impl Client {
265 // ── Constructors ─────────────────────────────────────────────────────────
266
267 /// Build a `Client` from a pre-constructed transport. All public
268 /// constructors delegate to this.
269 fn from_transport(config: ClientConfig, transport: AnyTransport) -> Self {
270 Self {
271 config,
272 transport,
273 session_id: None,
274 notification_stream: Mutex::new(None),
275 translation_ctx: Mutex::new(None),
276 hook_context: None,
277 connected: false,
278 turn_in_progress: Arc::new(AtomicBool::new(false)),
279 }
280 }
281
282 /// Create a new client with the given configuration.
283 ///
284 /// Resolves the `gemini` binary path (via `config.cli_path` or `PATH`)
285 /// and constructs a [`GeminiTransport`]. The subprocess is not spawned
286 /// until [`connect`] is called.
287 ///
288 /// # Errors
289 ///
290 /// Returns [`Error::CliNotFound`] when the binary cannot be located on
291 /// `PATH` and `config.cli_path` is `None`.
292 ///
293 /// [`connect`]: Client::connect
294 /// [`Error::CliNotFound`]: crate::Error::CliNotFound
295 pub fn new(config: ClientConfig) -> Result<Self> {
296 let transport = Arc::new(GeminiTransport::from_config(&config)?);
297 Ok(Self::from_transport(config, AnyTransport::Gemini(transport)))
298 }
299
300 /// Create a client backed by a caller-supplied [`GeminiTransport`].
301 ///
302 /// Useful when the caller has already constructed the transport with custom
303 /// parameters (e.g. a non-default working directory or extra env vars).
304 pub fn with_gemini_transport(config: ClientConfig, transport: Arc<GeminiTransport>) -> Self {
305 Self::from_transport(config, AnyTransport::Gemini(transport))
306 }
307
308 /// Create a client backed by a [`MockTransport`] for unit testing.
309 ///
310 /// Only available when the `testing` crate feature is enabled. The mock
311 /// transport captures writes and yields pre-loaded messages, making it
312 /// straightforward to test connect / send behaviour without spawning a
313 /// real subprocess.
314 ///
315 /// [`MockTransport`]: crate::testing::MockTransport
316 #[cfg(feature = "testing")]
317 pub fn with_mock_transport(
318 config: ClientConfig,
319 transport: Arc<crate::testing::MockTransport>,
320 ) -> Self {
321 Self::from_transport(config, AnyTransport::Mock(transport))
322 }
323
324 // ── Accessors ─────────────────────────────────────────────────────────────
325
326 /// Return the session ID assigned by the server, or `None` before
327 /// [`connect`] is called.
328 ///
329 /// [`connect`]: Client::connect
330 pub fn session_id(&self) -> Option<&str> {
331 self.session_id.as_deref()
332 }
333
334 /// Return the `prompt` field from the config.
335 ///
336 /// Exposed as a convenience for the free-function wrappers in `lib.rs`
337 /// that need to send the initial prompt without a separate `send()` call.
338 #[inline]
339 pub fn prompt(&self) -> &str {
340 &self.config.prompt
341 }
342
343 /// Return `true` if [`connect`] has been called successfully.
344 ///
345 /// [`connect`]: Client::connect
346 #[inline]
347 pub fn is_connected(&self) -> bool {
348 self.connected
349 }
350
351 // ── connect() ────────────────────────────────────────────────────────────
352
353 /// Connect to the Gemini CLI and establish a session.
354 ///
355 /// Performs the full initialisation sequence in order:
356 ///
357 /// 1. Spawn the subprocess via [`Transport::connect`].
358 /// 2. Take the notification stream (must happen before any requests).
359 /// 3. Register the optional [`PermissionHandler`] as the reverse-request handler.
360 /// 4. Send the `initialize` JSON-RPC request and await the result.
361 /// 5. Send `session/new` (or `session/load` when `config.resume` is set).
362 /// 6. Initialise the [`TranslationContext`] and hook context.
363 ///
364 /// Returns [`SessionInfo`] describing the established session.
365 ///
366 /// # Errors
367 ///
368 /// - [`Error::Config`] — called more than once on the same client.
369 /// - [`Error::SpawnFailed`] — the subprocess could not be started.
370 /// - [`Error::JsonRpcError`] — the server rejected `initialize` or `session/new`.
371 /// - [`Error::NotConnected`] — internal transport error during the handshake.
372 ///
373 /// [`Transport::connect`]: crate::transport::Transport::connect
374 /// [`Error::Config`]: crate::Error::Config
375 /// [`Error::SpawnFailed`]: crate::Error::SpawnFailed
376 /// [`Error::JsonRpcError`]: crate::Error::JsonRpcError
377 /// [`Error::NotConnected`]: crate::Error::NotConnected
378 pub async fn connect(&mut self) -> Result<SessionInfo> {
379 if self.connected {
380 return Err(Error::Config("Already connected".to_string()));
381 }
382 match self.config.connect_timeout {
383 Some(d) => {
384 tokio::time::timeout(d, self.connect_inner())
385 .await
386 .map_err(|_| {
387 Error::Timeout(format!(
388 "connect timed out after {:.1}s",
389 d.as_secs_f64()
390 ))
391 })?
392 }
393 None => self.connect_inner().await,
394 }
395 }
396
397 async fn connect_inner(&mut self) -> Result<SessionInfo> {
398 // ── Step 1: Spawn subprocess ─────────────────────────────────────────
399 self.transport.connect().await?;
400
401 // ── Step 2: Take the notification stream ─────────────────────────────
402 let stream = self.transport.read_messages();
403 *self.notification_stream.lock().await = Some(stream);
404
405 // ── Step 3: Create shared tool input cache ───────────────────────────
406 let tool_input_cache: ToolInputCache =
407 Arc::new(std::sync::Mutex::new(std::collections::HashMap::new()));
408
409 // ── Step 4: Register permission handler ──────────────────────────────
410 if let Some(callback) = self.config.can_use_tool.clone() {
411 let handler =
412 Arc::new(PermissionHandler::new(callback, Some(Arc::clone(&tool_input_cache))));
413 self.transport.set_reverse_handler(handler).await;
414 }
415
416 // ── Step 5: initialize request ───────────────────────────────────────
417 let init_params = wire::InitializeParams {
418 protocol_version: 1,
419 client_capabilities: wire::ClientCapabilities::default(),
420 client_info: wire::ClientInfo {
421 name: "gemini-cli-sdk".to_string(),
422 version: env!("CARGO_PKG_VERSION").to_string(),
423 },
424 };
425 let init_result: wire::InitializeResult = self
426 .transport
427 .send_request(wire::method::INITIALIZE, init_params)
428 .await?;
429
430 // ── Step 6: Create or resume session ─────────────────────────────────
431 let session_id = if let Some(resume_id) = self.config.resume.clone() {
432 let params = wire::SessionLoadParams {
433 session_id: resume_id,
434 extra: Value::Object(Default::default()),
435 };
436 let result: wire::SessionLoadResult = self
437 .transport
438 .send_request(wire::method::SESSION_LOAD, params)
439 .await?;
440 result.session_id
441 } else {
442 let cwd = self
443 .config
444 .cwd
445 .clone()
446 .map(Ok)
447 .unwrap_or_else(|| {
448 std::env::current_dir()
449 .map_err(|e| Error::Config(format!("cannot determine cwd: {e}")))
450 })?
451 .to_string_lossy()
452 .to_string();
453 let mcp_wire = crate::mcp::mcp_servers_to_wire(&self.config.mcp_servers);
454 let params = wire::SessionNewParams {
455 cwd,
456 mcp_servers: mcp_wire,
457 extra: Value::Object(Default::default()),
458 };
459 let result: wire::SessionNewResult = self
460 .transport
461 .send_request(wire::method::SESSION_NEW, params)
462 .await?;
463 result.session_id
464 };
465
466 self.session_id = Some(session_id.clone());
467
468 // ── Step 7: Initialise translation and hook contexts ─────────────────
469 let model = self
470 .config
471 .model
472 .clone()
473 .unwrap_or_else(|| "gemini-2.5-pro".to_string());
474
475 *self.translation_ctx.lock().await =
476 Some(TranslationContext::new_with_cache(session_id.clone(), model.clone(), tool_input_cache));
477
478 let cwd_str = self
479 .config
480 .cwd
481 .clone()
482 .map(Ok)
483 .unwrap_or_else(|| {
484 std::env::current_dir()
485 .map_err(|e| Error::Config(format!("cannot determine cwd: {e}")))
486 })?
487 .to_string_lossy()
488 .to_string();
489
490 self.hook_context = Some(HookContext {
491 session_id: session_id.clone(),
492 cwd: cwd_str,
493 });
494
495 self.connected = true;
496
497 let tools = init_result.agent_capabilities.tools.unwrap_or_default();
498 Ok(SessionInfo {
499 session_id,
500 model,
501 tools,
502 extra: init_result.extra,
503 })
504 }
505
506 // ── send() / send_content() ───────────────────────────────────────────────
507
508 /// Send a plain-text prompt and return a stream of translated [`Message`]
509 /// values.
510 ///
511 /// Convenience wrapper around [`send_content`] that constructs a single
512 /// `UserContent::Text` block from the provided string slice.
513 ///
514 /// # Errors
515 ///
516 /// Returns [`Error::NotConnected`] if [`connect`] has not been called.
517 /// Returns [`Error::Config`] if a `UserPromptSubmit` hook blocks the send.
518 ///
519 /// [`send_content`]: Client::send_content
520 /// [`connect`]: Client::connect
521 /// [`Error::NotConnected`]: crate::Error::NotConnected
522 /// [`Error::Config`]: crate::Error::Config
523 pub async fn send(
524 &self,
525 message: &str,
526 ) -> Result<impl Stream<Item = Result<Message>> + '_> {
527 self.send_content(vec![UserContent::text(message)]).await
528 }
529
530 /// Send structured content and return a stream of translated [`Message`]
531 /// values.
532 ///
533 /// Accepts any mix of [`UserContent`] variants (text, base-64 image, URL
534 /// image). The content is serialised to `WireContentBlock` format and sent
535 /// to the CLI as a `session/prompt` JSON-RPC notification. The returned
536 /// stream drains `session/update` notifications from the shared channel
537 /// until the transport closes or the caller drops the stream.
538 ///
539 /// # Backpressure
540 ///
541 /// The notification channel is shared across all `send_content` calls on
542 /// the same client. Only one stream should be active at a time; concurrent
543 /// polling will contend on the `notification_stream` Mutex and may
544 /// interleave results.
545 ///
546 /// # Errors
547 ///
548 /// Returns [`Error::NotConnected`] if [`connect`] has not been called.
549 /// Returns [`Error::Config`] if a `UserPromptSubmit` hook blocks the send.
550 ///
551 /// [`connect`]: Client::connect
552 /// [`Error::NotConnected`]: crate::Error::NotConnected
553 /// [`Error::Config`]: crate::Error::Config
554 pub async fn send_content(
555 &self,
556 content: Vec<UserContent>,
557 ) -> Result<impl Stream<Item = Result<Message>> + '_> {
558 if !self.connected {
559 return Err(Error::NotConnected);
560 }
561
562 // Prevent concurrent turns: a second call would block indefinitely on
563 // the notification_stream Mutex. Fail fast with a descriptive error.
564 if self
565 .turn_in_progress
566 .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
567 .is_err()
568 {
569 return Err(Error::TurnInProgress);
570 }
571 let turn_guard = TurnGuard(Arc::clone(&self.turn_in_progress));
572
573 let session_id = self
574 .session_id
575 .as_ref()
576 .ok_or(Error::NotConnected)?
577 .clone();
578
579 // ── Fire UserPromptSubmit hook ────────────────────────────────────────
580 if let Some(ctx) = &self.hook_context {
581 let prompt_text = content.iter().find_map(|c| match c {
582 UserContent::Text { text } => Some(text.clone()),
583 _ => None,
584 });
585 let hook_input = HookInput {
586 event: HookEvent::UserPromptSubmit,
587 tool_name: None,
588 tool_input: None,
589 tool_output: None,
590 prompt: prompt_text,
591 session_id: session_id.clone(),
592 extra: Value::Object(Default::default()),
593 };
594 let output = hooks::execute_hooks(
595 &self.config.hooks,
596 hook_input,
597 ctx,
598 self.config.default_hook_timeout,
599 )
600 .await;
601 if output.decision == HookDecision::Block {
602 return Err(Error::Config(
603 output
604 .message
605 .unwrap_or_else(|| "Blocked by hook".to_string()),
606 ));
607 }
608 }
609
610 // ── Reset translation context for the new turn ────────────────────────
611 {
612 let mut ctx_guard = self.translation_ctx.lock().await;
613 if let Some(ctx) = ctx_guard.as_mut() {
614 ctx.reset_turn();
615 }
616 }
617
618 // ── Convert content to wire format ────────────────────────────────────
619 let wire_content: Vec<wire::WireContentBlock> = content
620 .iter()
621 .map(crate::translate::user_content_to_wire)
622 .collect();
623
624 // ── Send session/prompt as a request ──────────────────────────────────
625 //
626 // `session/prompt` is sent as a JSON-RPC request (with `id`) so that the
627 // Gemini CLI sends back a correlated response containing `stopReason`
628 // when the turn completes. The notification stream delivers `session/update`
629 // events (text deltas, tool calls, etc.) concurrently, while the request
630 // response signals the turn boundary.
631 let prompt_params = wire::SessionPromptParams {
632 session_id: session_id.clone(),
633 prompt: wire_content,
634 extra: Value::Object(Default::default()),
635 };
636 let prompt_response_rx = self
637 .transport
638 .send_request_start(wire::method::SESSION_PROMPT, prompt_params)
639 .await?;
640
641 // ── Return notification-draining stream ───────────────────────────────
642 //
643 // Borrow `self` for the stream's lifetime. The Mutex is locked inside
644 // the generator so it can be released between yields, avoiding a held
645 // lock across await points in caller code.
646 let translation_ctx = &self.translation_ctx;
647 let notification_stream = &self.notification_stream;
648 let callback: Option<MessageCallback> = self.config.message_callback.clone();
649
650 Ok(async_stream::stream! {
651 // Move the guard into the stream so it is dropped (clearing the flag)
652 // when the stream is dropped or completes, regardless of the exit path.
653 let _turn_guard = turn_guard;
654 use tokio_stream::StreamExt as _;
655
656 // Lock the notification stream for the duration of this turn.
657 // A second concurrent call will block here until this stream is
658 // dropped.
659 let mut ns_guard = notification_stream.lock().await;
660 let stream = match ns_guard.as_mut() {
661 Some(s) => s,
662 None => {
663 yield Err(Error::NotConnected);
664 return;
665 }
666 };
667
668 // Fuse the prompt response oneshot so we can poll it alongside
669 // the notification stream without consuming it on the first poll.
670 let mut prompt_done = prompt_response_rx;
671 let mut turn_finished = false;
672
673 #[allow(unused_assignments)] // `turn_finished = true` precedes a `break`
674 loop {
675 tokio::select! {
676 biased;
677
678 // Poll notifications first — drain all pending updates before
679 // checking if the turn is complete. With `biased`, this branch
680 // is checked before prompt_done, ensuring mock transports (which
681 // resolve the oneshot immediately) still deliver notifications.
682 maybe_notif = stream.next() => {
683 match maybe_notif {
684 None => break, // channel closed — subprocess exited
685 Some(Err(e)) => {
686 yield Err(e);
687 break;
688 }
689 Some(Ok(value)) => {
690 // Filter: only process session/update notifications.
691 let method = value
692 .get("method")
693 .and_then(|m| m.as_str())
694 .unwrap_or("");
695 tracing::debug!(method, "notification received");
696 if method != wire::method::SESSION_UPDATE {
697 continue;
698 }
699
700 let params = match value.get("params") {
701 Some(p) => p.clone(),
702 None => continue,
703 };
704
705 let notif: wire::SessionUpdateNotification =
706 match serde_json::from_value(params) {
707 Ok(n) => n,
708 Err(e) => {
709 tracing::warn!(
710 error = %e,
711 "client: failed to parse session/update params — skipping"
712 );
713 continue;
714 }
715 };
716
717 // Parse the discriminator and translate to public messages.
718 let update = notif.parse();
719 tracing::debug!(?update, "parsed session update");
720
721 let mut ctx_guard = translation_ctx.lock().await;
722 if let Some(ctx) = ctx_guard.as_mut() {
723 let messages = ctx.translate(update);
724 tracing::debug!(count = messages.len(), "translated to messages");
725 // Drop the ctx lock before yielding so callers that
726 // inspect TranslationContext are not blocked.
727 drop(ctx_guard);
728
729 for msg in messages {
730 // Invoke optional side-effect callback.
731 if let Some(cb) = &callback {
732 cb(msg.clone()).await;
733 }
734 yield Ok(msg);
735 }
736 }
737 }
738 }
739 }
740
741 // Poll prompt response — when it arrives, the turn is done.
742 resp = &mut prompt_done, if !turn_finished => {
743 turn_finished = true;
744
745 // Translate the prompt result into a Message::Result.
746 match resp {
747 Ok(response) => {
748 match response.into_result() {
749 Ok(result_value) => {
750 let prompt_result: wire::SessionPromptResult =
751 serde_json::from_value(result_value)
752 .unwrap_or_else(|e| {
753 tracing::warn!(
754 error = %e,
755 "failed to parse SessionPromptResult, using default"
756 );
757 Default::default()
758 });
759
760 let result_msg = Message::Result(crate::types::messages::ResultMessage {
761 subtype: "success".to_string(),
762 is_error: false,
763 duration_ms: 0.0,
764 duration_api_ms: 0.0,
765 num_turns: 1,
766 session_id: session_id.clone(),
767 usage: crate::types::messages::Usage::default(),
768 stop_reason: prompt_result.stop_reason,
769 extra: prompt_result.extra,
770 });
771
772 if let Some(cb) = &callback {
773 cb(result_msg.clone()).await;
774 }
775 yield Ok(result_msg);
776 }
777 Err(err) => {
778 let error_msg = Message::Result(crate::types::messages::ResultMessage {
779 subtype: "error".to_string(),
780 is_error: true,
781 duration_ms: 0.0,
782 duration_api_ms: 0.0,
783 num_turns: 1,
784 session_id: session_id.clone(),
785 usage: crate::types::messages::Usage::default(),
786 stop_reason: format!(
787 "JSON-RPC error {}: {}",
788 err.code, err.message
789 ),
790 extra: serde_json::json!({
791 "code": err.code,
792 "message": err.message,
793 "data": err.data,
794 }),
795 });
796
797 if let Some(cb) = &callback {
798 cb(error_msg.clone()).await;
799 }
800 yield Ok(error_msg);
801 }
802 }
803 }
804 Err(_) => {
805 // Response channel dropped — treat as transport error.
806 yield Err(Error::Transport(
807 "Prompt response channel closed unexpectedly".to_string()
808 ));
809 }
810 }
811
812 // Reset the translation context for the next turn.
813 let mut ctx_guard = translation_ctx.lock().await;
814 if let Some(ctx) = ctx_guard.as_mut() {
815 ctx.reset_turn();
816 }
817
818 break;
819 }
820 }
821 }
822
823 // Turn complete — _turn_guard is dropped here, clearing the flag.
824 })
825 }
826
827 // ── interrupt() ──────────────────────────────────────────────────────────
828
829 /// Interrupt the current in-progress prompt turn.
830 ///
831 /// Sends a `session/cancel` notification to the CLI (best-effort) and then
832 /// delivers a process-level interrupt signal via the transport (SIGINT on
833 /// Unix, CTRL_C_EVENT on Windows).
834 ///
835 /// # Errors
836 ///
837 /// The `session/cancel` notification errors are silently ignored. Transport
838 /// interrupt errors are propagated.
839 pub async fn interrupt(&self) -> Result<()> {
840 if let Some(session_id) = &self.session_id {
841 let params = wire::SessionCancelParams {
842 session_id: session_id.clone(),
843 };
844 // Best-effort: a cancel notification failure must not prevent the
845 // process-level interrupt below from being delivered.
846 let _ = self
847 .transport
848 .send_notification(wire::method::SESSION_CANCEL, params)
849 .await;
850 }
851 self.transport.interrupt().await
852 }
853
854 // ── close() ──────────────────────────────────────────────────────────────
855
856 /// Close the client and terminate the CLI subprocess.
857 ///
858 /// Fires the `Stop` lifecycle hook before closing the transport. The
859 /// call is idempotent — invoking it on an already-closed client is safe.
860 ///
861 /// # Errors
862 ///
863 /// Propagates errors from the transport's `close()` implementation. Hook
864 /// errors are silently ignored.
865 pub async fn close(&mut self) -> Result<()> {
866 // ── Fire Stop hook ────────────────────────────────────────────────────
867 if let Some(ctx) = &self.hook_context {
868 let hook_input = HookInput {
869 event: HookEvent::Stop,
870 tool_name: None,
871 tool_input: None,
872 tool_output: None,
873 prompt: None,
874 session_id: self.session_id.clone().unwrap_or_default(),
875 extra: Value::Object(Default::default()),
876 };
877 // Ignore errors — we are shutting down regardless.
878 let _ = hooks::execute_hooks(
879 &self.config.hooks,
880 hook_input,
881 ctx,
882 self.config.default_hook_timeout,
883 )
884 .await;
885 }
886
887 self.connected = false;
888 self.turn_in_progress.store(false, Ordering::Release);
889 self.transport.close().await?;
890 Ok(())
891 }
892}
893
// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    // ── Helpers ───────────────────────────────────────────────────────────────

    /// A `GeminiTransport` aimed at a binary path that cannot exist, so no test
    /// in this module can spawn a real `gemini` process by accident.
    fn make_fake_transport() -> Arc<GeminiTransport> {
        let binary = std::path::PathBuf::from("/nonexistent/gemini");
        let args = vec!["--experimental-acp".to_string()];
        let workdir = std::path::PathBuf::from("/tmp");
        let env = std::collections::HashMap::new();
        Arc::new(GeminiTransport::new(binary, args, workdir, env, None, None))
    }

    /// Smallest valid configuration: only the prompt is set.
    fn minimal_config() -> ClientConfig {
        ClientConfig::builder().prompt("test prompt").build()
    }

    /// A never-connected client over the fake transport and minimal config.
    fn fake_client() -> Client {
        Client::with_gemini_transport(minimal_config(), make_fake_transport())
    }

    // ── test_client_session_id ────────────────────────────────────────────────

    /// No session ID exists until `connect()` has run.
    #[test]
    fn test_client_session_id() {
        let client = fake_client();
        assert!(
            client.session_id().is_none(),
            "session_id must be None before connect() is called"
        );
    }

    // ── test_client_not_connected_error ───────────────────────────────────────

    /// `send()` on an unconnected client is rejected with `Error::NotConnected`.
    #[tokio::test]
    async fn test_client_not_connected_error() {
        let client = fake_client();
        match client.send("hello").await {
            Ok(_) => panic!("send() before connect() must fail"),
            Err(err) => assert!(
                matches!(err, Error::NotConnected),
                "error must be Error::NotConnected, got: {err:?}"
            ),
        }
    }

    // ── test_client_send_content_not_connected ────────────────────────────────

    /// `send_content()` shares the same not-connected guard as `send()`.
    #[tokio::test]
    async fn test_client_send_content_not_connected() {
        let client = fake_client();
        match client.send_content(vec![UserContent::text("hi")]).await {
            Ok(_) => panic!("expected an error"),
            Err(err) => assert!(
                matches!(err, Error::NotConnected),
                "send_content before connect must return Error::NotConnected, got: {err:?}"
            ),
        }
    }

    // ── test_client_prompt_accessor ───────────────────────────────────────────

    /// `prompt()` echoes the prompt stored in the config.
    #[test]
    fn test_client_prompt_accessor() {
        let client = Client::with_gemini_transport(
            ClientConfig::builder().prompt("my test prompt").build(),
            make_fake_transport(),
        );
        assert_eq!(client.prompt(), "my test prompt");
    }

    // ── test_client_is_connected_default ─────────────────────────────────────

    /// Construction alone never marks a client as connected.
    #[test]
    fn test_client_is_connected_default() {
        assert!(
            !fake_client().is_connected(),
            "is_connected must be false before connect()"
        );
    }

    // ── test_client_double_connect_error ─────────────────────────────────────

    /// A second `connect()` is refused with `Error::Config`. The connected
    /// state is forced by writing the private field, so no subprocess runs.
    #[tokio::test]
    async fn test_client_double_connect_error() {
        let mut client = fake_client();
        client.connected = true;
        let result = client.connect().await;
        assert!(result.is_err());
        assert!(
            matches!(result.unwrap_err(), Error::Config(_)),
            "second connect must return Error::Config"
        );
    }

    // ── test_client_interrupt_before_connect ──────────────────────────────────

    /// Interrupting before any session exists is a harmless no-op that
    /// returns `Ok(())` rather than panicking.
    #[tokio::test]
    async fn test_client_interrupt_before_connect() {
        let client = fake_client();
        assert!(
            client.interrupt().await.is_ok(),
            "interrupt before connect must not return an error"
        );
    }

    // ── test_client_mock_transport_constructor ────────────────────────────────

    /// The mock-backed constructor yields a fresh, unconnected client.
    #[cfg(feature = "testing")]
    #[test]
    fn test_client_mock_transport_constructor() {
        use crate::testing::MockTransport;
        let client =
            Client::with_mock_transport(minimal_config(), Arc::new(MockTransport::new(vec![])));
        assert!(!client.is_connected());
        assert!(client.session_id().is_none());
    }
}