use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use futures_core::Stream;
use serde_json::Value;
use tokio::sync::Mutex;
use crate::callback::MessageCallback;
use crate::config::ClientConfig;
use crate::hooks::{self, HookContext, HookDecision, HookEvent, HookInput};
use crate::permissions::{PermissionHandler, ToolInputCache};
use crate::translate::TranslationContext;
use crate::transport::{GeminiTransport, Transport};
use crate::types::content::UserContent;
use crate::types::messages::{Message, SessionInfo};
use crate::wire;
use crate::{Error, Result};
/// The concrete transport a [`Client`] drives.
///
/// Calls are dispatched to the wrapped transport with a `match` on the
/// variant (the `delegate_transport!` macro generates the forwarding methods).
pub(crate) enum AnyTransport {
    /// Production transport backed by `GeminiTransport`.
    Gemini(Arc<GeminiTransport>),
    /// In-crate mock transport; only compiled with the `testing` feature.
    #[cfg(feature = "testing")]
    Mock(Arc<crate::testing::MockTransport>),
}
// Generates an `AnyTransport` method that forwards to the method of the same
// name (with the same arguments) on whichever concrete transport the enum holds.
//
// Two arms:
// - `async fn ...`: awaits the delegated call and returns its output.
// - `fn ...`: returns the delegated call's value directly.
// Arguments must be written as `ident: Type` pairs after `&self`. The `Mock`
// match arm is only compiled with the `testing` feature.
macro_rules! delegate_transport {
    (async fn $name:ident(&self $(, $arg:ident : $arg_ty:ty)*) -> $ret:ty) => {
        async fn $name(&self $(, $arg: $arg_ty)*) -> $ret {
            match self {
                AnyTransport::Gemini(t) => t.$name($($arg),*).await,
                #[cfg(feature = "testing")]
                AnyTransport::Mock(t) => t.$name($($arg),*).await,
            }
        }
    };
    (fn $name:ident(&self $(, $arg:ident : $arg_ty:ty)*) -> $ret:ty) => {
        fn $name(&self $(, $arg: $arg_ty)*) -> $ret {
            match self {
                AnyTransport::Gemini(t) => t.$name($($arg),*),
                #[cfg(feature = "testing")]
                AnyTransport::Mock(t) => t.$name($($arg),*),
            }
        }
    };
}
impl AnyTransport {
    // Plain forwarding methods: each calls the same-named method on the
    // wrapped transport.
    delegate_transport!(async fn connect(&self) -> Result<()>);
    delegate_transport!(fn read_messages(&self) -> Pin<Box<dyn Stream<Item = Result<Value>> + Send>>);
    delegate_transport!(async fn interrupt(&self) -> Result<()>);
    delegate_transport!(async fn close(&self) -> Result<Option<i32>>);

    /// Serialize a JSON-RPC 2.0 request envelope (fixed id `1`) and write it
    /// to the mock transport.
    ///
    /// Shared by the mock arms of [`send_request`](Self::send_request) and
    /// [`send_request_start`](Self::send_request_start).
    ///
    /// # Errors
    /// Serialization errors, or whatever the mock transport's `write` returns.
    #[cfg(feature = "testing")]
    async fn write_mock_request<P>(
        t: &crate::testing::MockTransport,
        method: &str,
        params: P,
    ) -> Result<()>
    where
        P: serde::Serialize + Send,
    {
        let req = serde_json::json!({
            "jsonrpc": "2.0",
            "method": method,
            "params": serde_json::to_value(params)?,
            "id": 1
        });
        t.write(&serde_json::to_string(&req)?).await
    }

    /// Send a JSON-RPC request and wait for its result, deserialized as `R`.
    ///
    /// With the mock transport the request is only written; the "result" is
    /// an empty JSON object deserialized into `R` (so `R` must accept `{}`).
    async fn send_request<P, R>(&self, method: &str, params: P) -> Result<R>
    where
        P: serde::Serialize + Send,
        R: serde::de::DeserializeOwned,
    {
        match self {
            AnyTransport::Gemini(t) => t.send_request(method, params).await,
            #[cfg(feature = "testing")]
            AnyTransport::Mock(t) => {
                Self::write_mock_request(t, method, params).await?;
                serde_json::from_value(Value::Object(Default::default())).map_err(Error::Json)
            }
        }
    }

    /// Send a JSON-RPC request and return a receiver that resolves with the
    /// response, without waiting for it here.
    ///
    /// With the mock transport the request is written and the receiver is
    /// pre-filled with a success response whose result is
    /// `{"stopReason": "end_turn"}`.
    async fn send_request_start<P>(
        &self,
        method: &str,
        params: P,
    ) -> Result<tokio::sync::oneshot::Receiver<crate::jsonrpc::JsonRpcResponse>>
    where
        P: serde::Serialize + Send,
    {
        match self {
            AnyTransport::Gemini(t) => t.send_request_start(method, params).await,
            #[cfg(feature = "testing")]
            AnyTransport::Mock(t) => {
                Self::write_mock_request(t, method, params).await?;
                let (tx, rx) = tokio::sync::oneshot::channel();
                // The receiver is still alive, so this send cannot fail.
                let _ = tx.send(crate::jsonrpc::JsonRpcResponse::success(
                    crate::jsonrpc::JsonRpcId::Number(0),
                    serde_json::json!({"stopReason": "end_turn"}),
                ));
                Ok(rx)
            }
        }
    }

    /// Send a JSON-RPC notification (no `id`, no response expected).
    async fn send_notification<P>(&self, method: &str, params: P) -> Result<()>
    where
        P: serde::Serialize + Send,
    {
        match self {
            AnyTransport::Gemini(t) => t.send_notification(method, params).await,
            #[cfg(feature = "testing")]
            AnyTransport::Mock(t) => {
                let notif = serde_json::json!({
                    "jsonrpc": "2.0",
                    "method": method,
                    "params": serde_json::to_value(params)?
                });
                t.write(&serde_json::to_string(&notif)?).await
            }
        }
    }

    /// Install the handler for agent-to-client (reverse) requests.
    ///
    /// The mock transport ignores the handler.
    async fn set_reverse_handler(
        &self,
        handler: Arc<dyn crate::transport::ReverseRequestHandler>,
    ) {
        match self {
            AnyTransport::Gemini(t) => t.set_reverse_handler(handler).await,
            #[cfg(feature = "testing")]
            AnyTransport::Mock(_) => {}
        }
    }
}
/// Client for a Gemini agent session.
///
/// A `Client` starts disconnected. `connect()` starts the transport and opens
/// (or resumes) a session. After that, `send()` / `send_content()` run turns,
/// one at a time.
pub struct Client {
    /// Configuration supplied at construction.
    config: ClientConfig,
    /// Transport the client drives.
    transport: AnyTransport,
    /// Session id obtained during `connect()`; `None` before that.
    session_id: Option<String>,
    /// Incoming message stream taken from the transport in `connect()`.
    /// A turn's response stream holds this lock for the duration of the turn.
    // Boxed `dyn Stream` inside `Option` inside `Mutex` trips clippy's
    // type-complexity lint; a type alias would add nothing here.
    #[allow(clippy::type_complexity)]
    notification_stream: Mutex<Option<Pin<Box<dyn Stream<Item = Result<Value>> + Send>>>>,
    /// Translates session updates into `Message`s; created in `connect()`.
    translation_ctx: Mutex<Option<TranslationContext>>,
    /// Session id and working directory passed to hooks; set in `connect()`.
    hook_context: Option<HookContext>,
    /// Set at the end of a successful `connect()`, cleared by `close()`.
    connected: bool,
    /// True while a turn is active; used to reject concurrent turns.
    turn_in_progress: Arc<AtomicBool>,
}
/// Drop guard for the shared "turn in progress" flag.
///
/// Holding a `TurnGuard` marks a turn as active. Dropping it clears the flag,
/// so the flag is released on every exit path: normal completion, an early
/// error return, or the caller dropping the turn's stream part-way.
struct TurnGuard(Arc<AtomicBool>);

impl Drop for TurnGuard {
    fn drop(&mut self) {
        let TurnGuard(flag) = self;
        // Release pairs with the Acquire compare-exchange that claims a turn.
        flag.store(false, Ordering::Release);
    }
}
impl Client {
fn from_transport(config: ClientConfig, transport: AnyTransport) -> Self {
Self {
config,
transport,
session_id: None,
notification_stream: Mutex::new(None),
translation_ctx: Mutex::new(None),
hook_context: None,
connected: false,
turn_in_progress: Arc::new(AtomicBool::new(false)),
}
}
pub fn new(config: ClientConfig) -> Result<Self> {
let transport = Arc::new(GeminiTransport::from_config(&config)?);
Ok(Self::from_transport(config, AnyTransport::Gemini(transport)))
}
pub fn with_gemini_transport(config: ClientConfig, transport: Arc<GeminiTransport>) -> Self {
Self::from_transport(config, AnyTransport::Gemini(transport))
}
#[cfg(feature = "testing")]
pub fn with_mock_transport(
config: ClientConfig,
transport: Arc<crate::testing::MockTransport>,
) -> Self {
Self::from_transport(config, AnyTransport::Mock(transport))
}
pub fn session_id(&self) -> Option<&str> {
self.session_id.as_deref()
}
#[inline]
pub fn prompt(&self) -> &str {
&self.config.prompt
}
#[inline]
pub fn is_connected(&self) -> bool {
self.connected
}
pub async fn connect(&mut self) -> Result<SessionInfo> {
if self.connected {
return Err(Error::Config("Already connected".to_string()));
}
match self.config.connect_timeout {
Some(d) => {
tokio::time::timeout(d, self.connect_inner())
.await
.map_err(|_| {
Error::Timeout(format!(
"connect timed out after {:.1}s",
d.as_secs_f64()
))
})?
}
None => self.connect_inner().await,
}
}
async fn connect_inner(&mut self) -> Result<SessionInfo> {
self.transport.connect().await?;
let stream = self.transport.read_messages();
*self.notification_stream.lock().await = Some(stream);
let tool_input_cache: ToolInputCache =
Arc::new(std::sync::Mutex::new(std::collections::HashMap::new()));
if let Some(callback) = self.config.can_use_tool.clone() {
let handler =
Arc::new(PermissionHandler::new(callback, Some(Arc::clone(&tool_input_cache))));
self.transport.set_reverse_handler(handler).await;
}
let init_params = wire::InitializeParams {
protocol_version: 1,
client_capabilities: wire::ClientCapabilities::default(),
client_info: wire::ClientInfo {
name: "gemini-cli-sdk".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
},
};
let init_result: wire::InitializeResult = self
.transport
.send_request(wire::method::INITIALIZE, init_params)
.await?;
let session_id = if let Some(resume_id) = self.config.resume.clone() {
let params = wire::SessionLoadParams {
session_id: resume_id,
extra: Value::Object(Default::default()),
};
let result: wire::SessionLoadResult = self
.transport
.send_request(wire::method::SESSION_LOAD, params)
.await?;
result.session_id
} else {
let cwd = self
.config
.cwd
.clone()
.map(Ok)
.unwrap_or_else(|| {
std::env::current_dir()
.map_err(|e| Error::Config(format!("cannot determine cwd: {e}")))
})?
.to_string_lossy()
.to_string();
let mcp_wire = crate::mcp::mcp_servers_to_wire(&self.config.mcp_servers);
let params = wire::SessionNewParams {
cwd,
mcp_servers: mcp_wire,
extra: Value::Object(Default::default()),
};
let result: wire::SessionNewResult = self
.transport
.send_request(wire::method::SESSION_NEW, params)
.await?;
result.session_id
};
self.session_id = Some(session_id.clone());
let model = self
.config
.model
.clone()
.unwrap_or_else(|| "gemini-2.5-pro".to_string());
*self.translation_ctx.lock().await =
Some(TranslationContext::new_with_cache(session_id.clone(), model.clone(), tool_input_cache));
let cwd_str = self
.config
.cwd
.clone()
.map(Ok)
.unwrap_or_else(|| {
std::env::current_dir()
.map_err(|e| Error::Config(format!("cannot determine cwd: {e}")))
})?
.to_string_lossy()
.to_string();
self.hook_context = Some(HookContext {
session_id: session_id.clone(),
cwd: cwd_str,
});
self.connected = true;
let tools = init_result.agent_capabilities.tools.unwrap_or_default();
Ok(SessionInfo {
session_id,
model,
tools,
extra: init_result.extra,
})
}
pub async fn send(
&self,
message: &str,
) -> Result<impl Stream<Item = Result<Message>> + '_> {
self.send_content(vec![UserContent::text(message)]).await
}
pub async fn send_content(
&self,
content: Vec<UserContent>,
) -> Result<impl Stream<Item = Result<Message>> + '_> {
if !self.connected {
return Err(Error::NotConnected);
}
if self
.turn_in_progress
.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
.is_err()
{
return Err(Error::TurnInProgress);
}
let turn_guard = TurnGuard(Arc::clone(&self.turn_in_progress));
let session_id = self
.session_id
.as_ref()
.ok_or(Error::NotConnected)?
.clone();
if let Some(ctx) = &self.hook_context {
let prompt_text = content.iter().find_map(|c| match c {
UserContent::Text { text } => Some(text.clone()),
_ => None,
});
let hook_input = HookInput {
event: HookEvent::UserPromptSubmit,
tool_name: None,
tool_input: None,
tool_output: None,
prompt: prompt_text,
session_id: session_id.clone(),
extra: Value::Object(Default::default()),
};
let output = hooks::execute_hooks(
&self.config.hooks,
hook_input,
ctx,
self.config.default_hook_timeout,
)
.await;
if output.decision == HookDecision::Block {
return Err(Error::Config(
output
.message
.unwrap_or_else(|| "Blocked by hook".to_string()),
));
}
}
{
let mut ctx_guard = self.translation_ctx.lock().await;
if let Some(ctx) = ctx_guard.as_mut() {
ctx.reset_turn();
}
}
let wire_content: Vec<wire::WireContentBlock> = content
.iter()
.map(crate::translate::user_content_to_wire)
.collect();
let prompt_params = wire::SessionPromptParams {
session_id: session_id.clone(),
prompt: wire_content,
extra: Value::Object(Default::default()),
};
let prompt_response_rx = self
.transport
.send_request_start(wire::method::SESSION_PROMPT, prompt_params)
.await?;
let translation_ctx = &self.translation_ctx;
let notification_stream = &self.notification_stream;
let callback: Option<MessageCallback> = self.config.message_callback.clone();
Ok(async_stream::stream! {
let _turn_guard = turn_guard;
use tokio_stream::StreamExt as _;
let mut ns_guard = notification_stream.lock().await;
let stream = match ns_guard.as_mut() {
Some(s) => s,
None => {
yield Err(Error::NotConnected);
return;
}
};
let mut prompt_done = prompt_response_rx;
let mut turn_finished = false;
#[allow(unused_assignments)] loop {
tokio::select! {
biased;
maybe_notif = stream.next() => {
match maybe_notif {
None => break, Some(Err(e)) => {
yield Err(e);
break;
}
Some(Ok(value)) => {
let method = value
.get("method")
.and_then(|m| m.as_str())
.unwrap_or("");
tracing::debug!(method, "notification received");
if method != wire::method::SESSION_UPDATE {
continue;
}
let params = match value.get("params") {
Some(p) => p.clone(),
None => continue,
};
let notif: wire::SessionUpdateNotification =
match serde_json::from_value(params) {
Ok(n) => n,
Err(e) => {
tracing::warn!(
error = %e,
"client: failed to parse session/update params — skipping"
);
continue;
}
};
let update = notif.parse();
tracing::debug!(?update, "parsed session update");
let mut ctx_guard = translation_ctx.lock().await;
if let Some(ctx) = ctx_guard.as_mut() {
let messages = ctx.translate(update);
tracing::debug!(count = messages.len(), "translated to messages");
drop(ctx_guard);
for msg in messages {
if let Some(cb) = &callback {
cb(msg.clone()).await;
}
yield Ok(msg);
}
}
}
}
}
resp = &mut prompt_done, if !turn_finished => {
turn_finished = true;
match resp {
Ok(response) => {
match response.into_result() {
Ok(result_value) => {
let prompt_result: wire::SessionPromptResult =
serde_json::from_value(result_value)
.unwrap_or_else(|e| {
tracing::warn!(
error = %e,
"failed to parse SessionPromptResult, using default"
);
Default::default()
});
let result_msg = Message::Result(crate::types::messages::ResultMessage {
subtype: "success".to_string(),
is_error: false,
duration_ms: 0.0,
duration_api_ms: 0.0,
num_turns: 1,
session_id: session_id.clone(),
usage: crate::types::messages::Usage::default(),
stop_reason: prompt_result.stop_reason,
extra: prompt_result.extra,
});
if let Some(cb) = &callback {
cb(result_msg.clone()).await;
}
yield Ok(result_msg);
}
Err(err) => {
let error_msg = Message::Result(crate::types::messages::ResultMessage {
subtype: "error".to_string(),
is_error: true,
duration_ms: 0.0,
duration_api_ms: 0.0,
num_turns: 1,
session_id: session_id.clone(),
usage: crate::types::messages::Usage::default(),
stop_reason: format!(
"JSON-RPC error {}: {}",
err.code, err.message
),
extra: serde_json::json!({
"code": err.code,
"message": err.message,
"data": err.data,
}),
});
if let Some(cb) = &callback {
cb(error_msg.clone()).await;
}
yield Ok(error_msg);
}
}
}
Err(_) => {
yield Err(Error::Transport(
"Prompt response channel closed unexpectedly".to_string()
));
}
}
let mut ctx_guard = translation_ctx.lock().await;
if let Some(ctx) = ctx_guard.as_mut() {
ctx.reset_turn();
}
break;
}
}
}
})
}
pub async fn interrupt(&self) -> Result<()> {
if let Some(session_id) = &self.session_id {
let params = wire::SessionCancelParams {
session_id: session_id.clone(),
};
let _ = self
.transport
.send_notification(wire::method::SESSION_CANCEL, params)
.await;
}
self.transport.interrupt().await
}
pub async fn close(&mut self) -> Result<()> {
if let Some(ctx) = &self.hook_context {
let hook_input = HookInput {
event: HookEvent::Stop,
tool_name: None,
tool_input: None,
tool_output: None,
prompt: None,
session_id: self.session_id.clone().unwrap_or_default(),
extra: Value::Object(Default::default()),
};
let _ = hooks::execute_hooks(
&self.config.hooks,
hook_input,
ctx,
self.config.default_hook_timeout,
)
.await;
}
self.connected = false;
self.turn_in_progress.store(false, Ordering::Release);
self.transport.close().await?;
Ok(())
}
}
// Unit tests for the connection-state checks of `Client`. None of them spawn
// a real agent: the Gemini transport points at a nonexistent binary.
#[cfg(test)]
mod tests {
    use super::*;
    // Transport aimed at a binary path that does not exist, so nothing can
    // actually be launched through it.
    fn make_fake_transport() -> Arc<GeminiTransport> {
        Arc::new(GeminiTransport::new(
            std::path::PathBuf::from("/nonexistent/gemini"),
            vec!["--experimental-acp".to_string()],
            std::path::PathBuf::from("/tmp"),
            std::collections::HashMap::new(),
            None,
            None,
        ))
    }
    // Smallest config accepted by the builder: only the prompt is set.
    fn minimal_config() -> ClientConfig {
        ClientConfig::builder().prompt("test prompt").build()
    }
    #[test]
    fn test_client_session_id() {
        let client = Client::with_gemini_transport(minimal_config(), make_fake_transport());
        assert!(
            client.session_id().is_none(),
            "session_id must be None before connect() is called"
        );
    }
    #[tokio::test]
    async fn test_client_not_connected_error() {
        let client = Client::with_gemini_transport(minimal_config(), make_fake_transport());
        let result = client.send("hello").await;
        assert!(result.is_err(), "send() before connect() must fail");
        let err = result.err().expect("expected an error");
        assert!(
            matches!(err, Error::NotConnected),
            "error must be Error::NotConnected, got: {err:?}"
        );
    }
    #[tokio::test]
    async fn test_client_send_content_not_connected() {
        let client = Client::with_gemini_transport(minimal_config(), make_fake_transport());
        let result = client.send_content(vec![UserContent::text("hi")]).await;
        let err = result.err().expect("expected an error");
        assert!(
            matches!(err, Error::NotConnected),
            "send_content before connect must return Error::NotConnected, got: {err:?}"
        );
    }
    #[test]
    fn test_client_prompt_accessor() {
        let config = ClientConfig::builder().prompt("my test prompt").build();
        let client = Client::with_gemini_transport(config, make_fake_transport());
        assert_eq!(client.prompt(), "my test prompt");
    }
    #[test]
    fn test_client_is_connected_default() {
        let client = Client::with_gemini_transport(minimal_config(), make_fake_transport());
        assert!(
            !client.is_connected(),
            "is_connected must be false before connect()"
        );
    }
    #[tokio::test]
    async fn test_client_double_connect_error() {
        let mut client =
            Client::with_gemini_transport(minimal_config(), make_fake_transport());
        // Force the connected flag directly; the fake transport could not
        // complete a real connect().
        client.connected = true;
        let result = client.connect().await;
        assert!(result.is_err());
        assert!(
            matches!(result.unwrap_err(), Error::Config(_)),
            "second connect must return Error::Config"
        );
    }
    #[tokio::test]
    async fn test_client_interrupt_before_connect() {
        // With no session, interrupt() skips the cancel notification and only
        // calls the transport's interrupt.
        let client = Client::with_gemini_transport(minimal_config(), make_fake_transport());
        let result = client.interrupt().await;
        assert!(
            result.is_ok(),
            "interrupt before connect must not return an error"
        );
    }
    #[cfg(feature = "testing")]
    #[test]
    fn test_client_mock_transport_constructor() {
        use crate::testing::MockTransport;
        let transport = Arc::new(MockTransport::new(vec![]));
        let client = Client::with_mock_transport(minimal_config(), transport);
        assert!(!client.is_connected());
        assert!(client.session_id().is_none());
    }
}