use std::collections::HashMap;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use async_trait::async_trait;
use dashmap::DashMap;
use futures_core::Stream;
use serde::de::DeserializeOwned;
use serde::Serialize;
use serde_json::Value;
use tokio::io::AsyncWriteExt;
use tokio::process::{Child, ChildStdin};
use tokio::sync::{mpsc, oneshot, Mutex};
use tokio::task::JoinHandle;
use crate::jsonrpc::{JsonRpcId, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
use crate::Result;
/// Line-oriented, bidirectional message channel to an agent process.
///
/// `GeminiTransport` in this module is the implementation: `connect` spawns a
/// subprocess, `write` sends one line on its stdin, `read_messages` yields what
/// its reader task forwards, and `close` tears everything down.
#[async_trait]
pub trait Transport: Send + Sync {
    /// Establishes the connection; must succeed before messages can be written.
    async fn connect(&self) -> Result<()>;

    /// Reports whether the transport is currently usable.
    fn is_ready(&self) -> bool;

    /// Sends one message; `data` is expected to be a single serialized line.
    async fn write(&self, data: &str) -> Result<()>;

    /// Returns the stream of incoming messages (or per-message errors).
    fn read_messages(&self) -> Pin<Box<dyn Stream<Item = Result<Value>> + Send>>;

    /// Signals that no further input will be written.
    async fn end_input(&self) -> Result<()>;

    /// Asks the peer to interrupt whatever it is currently doing.
    async fn interrupt(&self) -> Result<()>;

    /// Shuts the transport down, returning the peer's exit code when one is known.
    async fn close(&self) -> Result<Option<i32>>;
}
/// Answers requests that the peer process sends *to* this client.
///
/// The transport's reader task calls this for every incoming JSON-RPC request
/// and wraps the returned value as the `result` of a success response.
#[async_trait]
pub(crate) trait ReverseRequestHandler: Send + Sync {
    /// Handles one reverse request.
    ///
    /// `id` is the request's JSON-RPC id. `params` is its `params` member, or
    /// `Value::Null` when the request carried none.
    async fn handle_permission_request(&self, id: JsonRpcId, params: Value) -> Value;
}
/// Shared handle to the child's stdin; `None` before `connect` and after `end_input`.
pub(crate) type SharedStdin = Arc<Mutex<Option<ChildStdin>>>;
/// In-flight requests keyed by id; each sender is completed by the reader task.
pub(crate) type PendingRequestMap = Arc<DashMap<JsonRpcId, oneshot::Sender<JsonRpcResponse>>>;
/// Slot holding the optional handler for requests initiated by the subprocess.
pub(crate) type ReverseHandlerSlot = Arc<Mutex<Option<Arc<dyn ReverseRequestHandler>>>>;
/// Callback that receives each stderr line of the subprocess.
pub(crate) type StderrSink = Arc<dyn Fn(String) + Send + Sync>;

/// JSON-RPC transport that talks to the Gemini CLI over a child process's stdio.
///
/// Outgoing messages are written as single lines to the child's stdin. A reader
/// task started by `connect` parses stdout and does three things: it resolves
/// pending requests, forwards notifications, and answers reverse requests.
pub struct GeminiTransport {
    /// Executable spawned by `connect`.
    pub(crate) cli_path: PathBuf,
    /// Arguments passed to the executable.
    pub(crate) args: Vec<String>,
    /// Working directory of the subprocess.
    pub(crate) cwd: PathBuf,
    /// Additional environment variables for the subprocess.
    pub(crate) env: HashMap<String, String>,
    /// The running child, if any. Uses a `std` mutex and is locked synchronously
    /// (including from `Drop`).
    pub(crate) process: std::sync::Mutex<Option<Child>>,
    /// Writer side of the child's stdin, shared with the reader task.
    pub(crate) stdin: SharedStdin,
    /// Source of request ids; starts at 1.
    pub(crate) next_request_id: AtomicU64,
    /// Requests awaiting a response, shared with the reader task.
    pub(crate) pending_requests: PendingRequestMap,
    /// Sender half of the notification channel; cloned into the reader task.
    pub(crate) message_tx: Mutex<Option<mpsc::Sender<Result<Value>>>>,
    /// Receiver half of the notification channel; taken once by `read_messages`.
    pub(crate) message_rx: Mutex<Option<mpsc::Receiver<Result<Value>>>>,
    /// Join handle of the stdout reader task, aborted by `close`.
    pub(crate) reader_handle: Mutex<Option<JoinHandle<()>>>,
    /// Handler for reverse requests, shared with the reader task.
    pub(crate) reverse_handler: ReverseHandlerSlot,
    /// Set by `connect`, cleared by `close`.
    pub(crate) ready: AtomicBool,
    /// Optional consumer of the child's stderr lines.
    pub(crate) stderr_callback: Option<StderrSink>,
    /// Maximum wait for the child to exit in `close` before it is killed.
    pub(crate) close_timeout: Option<std::time::Duration>,
}
impl GeminiTransport {
pub fn new(
cli_path: PathBuf,
args: Vec<String>,
cwd: PathBuf,
env: HashMap<String, String>,
stderr_callback: Option<Arc<dyn Fn(String) + Send + Sync>>,
close_timeout: Option<std::time::Duration>,
) -> Self {
let (tx, rx) = mpsc::channel(256);
Self {
cli_path,
args,
cwd,
env,
process: std::sync::Mutex::new(None),
stdin: Arc::new(Mutex::new(None)),
next_request_id: AtomicU64::new(1),
pending_requests: Arc::new(DashMap::new()),
message_tx: Mutex::new(Some(tx)),
message_rx: Mutex::new(Some(rx)),
reader_handle: Mutex::new(None),
reverse_handler: Arc::new(Mutex::new(None)),
ready: AtomicBool::new(false),
stderr_callback,
close_timeout,
}
}
pub fn from_config(config: &crate::config::ClientConfig) -> crate::Result<Self> {
let cli_path = if let Some(path) = &config.cli_path {
path.clone()
} else {
crate::discovery::find_cli()?
};
let cwd = config
.cwd
.clone()
.unwrap_or_else(|| std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")));
let args = config.to_cli_args();
Ok(Self::new(
cli_path,
args,
cwd,
config.env.clone(),
config.stderr_callback.clone(),
config.close_timeout,
))
}
pub(crate) async fn send_request<P: Serialize, R: DeserializeOwned>(
&self,
method: &str,
params: P,
) -> Result<R> {
let id = JsonRpcId::Number(self.next_request_id.fetch_add(1, Ordering::Relaxed));
let request =
JsonRpcRequest::new(id.clone(), method, Some(serde_json::to_value(params)?));
let (tx, rx) = oneshot::channel();
self.pending_requests.insert(id, tx);
let line = serde_json::to_string(&request)?;
self.write(&line).await?;
let response = rx.await.map_err(|_| {
crate::Error::Transport(
"Response channel closed before response received".to_string(),
)
})?;
match response.into_result() {
Ok(value) => serde_json::from_value(value).map_err(Into::into),
Err(err) => Err(crate::Error::JsonRpcError {
code: err.code,
message: err.message,
data: err.data,
}),
}
}
pub(crate) async fn send_request_start<P: Serialize>(
&self,
method: &str,
params: P,
) -> Result<oneshot::Receiver<JsonRpcResponse>> {
let id = JsonRpcId::Number(self.next_request_id.fetch_add(1, Ordering::Relaxed));
let request =
JsonRpcRequest::new(id.clone(), method, Some(serde_json::to_value(params)?));
let (tx, rx) = oneshot::channel();
self.pending_requests.insert(id, tx);
let line = serde_json::to_string(&request)?;
self.write(&line).await?;
Ok(rx)
}
pub(crate) async fn send_notification<P: Serialize>(
&self,
method: &str,
params: P,
) -> Result<()> {
let notif = JsonRpcNotification::new(method, Some(serde_json::to_value(params)?));
let line = serde_json::to_string(¬if)?;
self.write(&line).await
}
#[allow(dead_code)]
pub(crate) async fn send_response(&self, id: JsonRpcId, result: Value) -> Result<()> {
let response = JsonRpcResponse::success(id, result);
let line = serde_json::to_string(&response)?;
self.write(&line).await
}
pub(crate) async fn set_reverse_handler(&self, handler: Arc<dyn ReverseRequestHandler>) {
*self.reverse_handler.lock().await = Some(handler);
}
#[allow(dead_code)]
pub(crate) fn next_id(&self) -> JsonRpcId {
JsonRpcId::Number(self.next_request_id.fetch_add(1, Ordering::Relaxed))
}
#[allow(dead_code)]
pub(crate) fn route_response(&self, response: JsonRpcResponse) {
if let Some((_, tx)) = self.pending_requests.remove(&response.id) {
let _ = tx.send(response);
} else {
tracing::warn!(
id = %response.id,
"Received JSON-RPC response for unknown request ID — dropping"
);
}
}
#[allow(dead_code)]
pub(crate) async fn take_message_tx(&self) -> Option<mpsc::Sender<Result<Value>>> {
self.message_tx.lock().await.clone()
}
#[allow(dead_code)]
pub(crate) async fn get_reverse_handler(&self) -> Option<Arc<dyn ReverseRequestHandler>> {
self.reverse_handler.lock().await.clone()
}
async fn write_line_to_stdin(stdin: &Mutex<Option<ChildStdin>>, data: &str) -> Result<()> {
let mut guard = stdin.lock().await;
let stdin_handle = guard.as_mut().ok_or(crate::Error::NotConnected)?;
stdin_handle.write_all(data.as_bytes()).await?;
stdin_handle.write_all(b"\n").await?;
stdin_handle.flush().await?;
Ok(())
}
pub(crate) fn spawn_reader(
pending_requests: Arc<DashMap<JsonRpcId, oneshot::Sender<JsonRpcResponse>>>,
reverse_handler: Arc<Mutex<Option<Arc<dyn ReverseRequestHandler>>>>,
stdin_writer: Arc<Mutex<Option<ChildStdin>>>,
message_tx: mpsc::Sender<Result<Value>>,
stdout: tokio::process::ChildStdout,
) -> JoinHandle<()> {
tokio::spawn(async move {
use tokio::io::{AsyncBufReadExt, BufReader};
let reader = BufReader::new(stdout);
let mut lines = reader.lines();
while let Ok(Some(line)) = lines.next_line().await {
let line = line.trim().to_string();
if line.is_empty() {
continue;
}
tracing::trace!(line = %line, "reader: received line");
if std::env::var("GEMINI_SDK_DEBUG").is_ok() {
let preview = if line.len() > 200 { &line[..200] } else { &line };
eprintln!("[SDK reader] raw: {preview}");
}
let value: Value = match serde_json::from_str(&line) {
Ok(v) => v,
Err(e) => {
tracing::warn!(error = %e, line = %line, "reader: JSON parse error");
let _ = message_tx
.send(Err(crate::Error::ParseError {
message: e.to_string(),
line: line.clone(),
}))
.await;
continue;
}
};
let kind = crate::jsonrpc::JsonRpcMessage::classify(&value);
if std::env::var("GEMINI_SDK_DEBUG").is_ok() {
let method = value.get("method").and_then(|m| m.as_str()).unwrap_or("(none)");
let has_id = value.get("id").is_some();
eprintln!("[SDK reader] classify={kind:?} method={method} has_id={has_id}");
}
match kind {
Ok(crate::jsonrpc::MessageKind::Response) => {
let id = match value.get("id") {
Some(raw_id) => {
match serde_json::from_value::<JsonRpcId>(raw_id.clone()) {
Ok(id) => id,
Err(e) => {
tracing::warn!(
error = %e,
"reader: failed to parse response ID"
);
continue;
}
}
}
None => continue,
};
if let Some((_, tx)) = pending_requests.remove(&id) {
match serde_json::from_value::<JsonRpcResponse>(value) {
Ok(resp) => {
let _ = tx.send(resp);
}
Err(e) => {
tracing::warn!(
error = %e,
"reader: failed to deserialize response body"
);
}
}
} else {
tracing::warn!(
id = %id,
"reader: response for unknown request ID — dropping"
);
}
}
Ok(crate::jsonrpc::MessageKind::Notification) => {
if message_tx.send(Ok(value)).await.is_err() {
tracing::debug!("reader: message channel closed, stopping");
break;
}
}
Ok(crate::jsonrpc::MessageKind::Request) => {
if std::env::var("GEMINI_SDK_DEBUG").is_ok() {
eprintln!("[SDK reader] REVERSE REQUEST detected — dispatching to handler");
}
let request: JsonRpcRequest = match serde_json::from_value(value) {
Ok(r) => r,
Err(e) => {
tracing::warn!(
error = %e,
"reader: failed to deserialize reverse request"
);
continue;
}
};
let request_id = request.id.clone();
let method = request.method.clone();
let params = request.params.unwrap_or(Value::Null);
let maybe_handler: Option<Arc<dyn ReverseRequestHandler>> =
reverse_handler.lock().await.clone();
let response = if let Some(h) = maybe_handler {
let result = h.handle_permission_request(request_id.clone(), params).await;
JsonRpcResponse::success(request_id, result)
} else {
tracing::warn!(
method = %method,
"reader: no reverse handler registered, responding with error"
);
JsonRpcResponse::error(
request_id,
crate::jsonrpc::JsonRpcError {
code: crate::jsonrpc::error_codes::METHOD_NOT_FOUND,
message: format!("No handler registered for '{method}'"),
data: None,
},
)
};
let response_line = match serde_json::to_string(&response) {
Ok(s) => s,
Err(e) => {
tracing::warn!(
error = %e,
"reader: failed to serialize reverse-request response"
);
continue;
}
};
if std::env::var("GEMINI_SDK_DEBUG").is_ok() {
eprintln!("[SDK reader] SENDING reverse response: {response_line}");
}
if let Err(e) =
Self::write_line_to_stdin(&stdin_writer, &response_line).await
{
tracing::warn!(
error = %e,
"reader: failed to write reverse-request response to stdin"
);
} else if std::env::var("GEMINI_SDK_DEBUG").is_ok() {
eprintln!("[SDK reader] reverse response SENT OK");
}
}
Err(e) => {
tracing::warn!(error = %e, "reader: invalid JSON-RPC message shape");
}
}
}
tracing::debug!("reader: stdout closed, reader task exiting");
})
}
}
#[async_trait]
impl Transport for GeminiTransport {
async fn connect(&self) -> Result<()> {
use tokio::process::Command;
let mut cmd = Command::new(&self.cli_path);
cmd.args(&self.args)
.current_dir(&self.cwd)
.envs(&self.env)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped());
let mut child = cmd.spawn().map_err(crate::Error::SpawnFailed)?;
let stdout = child
.stdout
.take()
.ok_or_else(|| crate::Error::Transport("Failed to capture stdout".to_string()))?;
let child_stdin = child
.stdin
.take()
.ok_or_else(|| crate::Error::Transport("Failed to capture stdin".to_string()))?;
let stderr = child.stderr.take();
*self.process.lock().unwrap() = Some(child);
*self.stdin.lock().await = Some(child_stdin);
if let (Some(stderr_stream), Some(cb)) = (stderr, &self.stderr_callback) {
let cb = Arc::clone(cb);
tokio::spawn(async move {
use tokio::io::{AsyncBufReadExt, BufReader};
let mut lines = BufReader::new(stderr_stream).lines();
while let Ok(Some(line)) = lines.next_line().await {
cb(line);
}
});
}
let message_tx = self
.message_tx
.lock()
.await
.clone()
.ok_or_else(|| crate::Error::Transport("Message channel not initialized".to_string()))?;
let handle = Self::spawn_reader(
Arc::clone(&self.pending_requests),
Arc::clone(&self.reverse_handler),
Arc::clone(&self.stdin),
message_tx,
stdout,
);
*self.reader_handle.lock().await = Some(handle);
self.ready.store(true, Ordering::Release);
Ok(())
}
async fn write(&self, data: &str) -> Result<()> {
let mut guard = self.stdin.lock().await;
let stdin = guard.as_mut().ok_or(crate::Error::NotConnected)?;
stdin.write_all(data.as_bytes()).await?;
stdin.write_all(b"\n").await?;
stdin.flush().await?;
Ok(())
}
fn read_messages(&self) -> Pin<Box<dyn Stream<Item = Result<Value>> + Send>> {
let rx = self
.message_rx
.try_lock()
.ok()
.and_then(|mut guard| guard.take());
match rx {
Some(rx) => Box::pin(async_stream::stream! {
let mut receiver = rx;
while let Some(item) = receiver.recv().await {
yield item;
}
}),
None => {
tracing::warn!(
"read_messages() called after the receiver was already consumed \
or the mutex is contended; returning empty stream"
);
let empty: Vec<Result<Value>> = Vec::new();
Box::pin(async_stream::stream! {
for item in empty {
yield item;
}
})
}
}
}
async fn end_input(&self) -> Result<()> {
*self.stdin.lock().await = None;
Ok(())
}
async fn interrupt(&self) -> Result<()> {
#[cfg(unix)]
{
let pid = {
let process = self.process.lock().unwrap();
process.as_ref().and_then(|c| c.id())
};
if let Some(pid) = pid {
let _ = nix::sys::signal::kill(
nix::unistd::Pid::from_raw(pid as i32),
nix::sys::signal::Signal::SIGINT,
);
}
}
#[cfg(windows)]
{
let pid = {
let process = self.process.lock().unwrap();
process.as_ref().and_then(|c| c.id())
};
if let Some(pid) = pid {
unsafe {
let _ = windows_sys::Win32::System::Console::GenerateConsoleCtrlEvent(
windows_sys::Win32::System::Console::CTRL_BREAK_EVENT,
pid,
);
}
}
}
Ok(())
}
fn is_ready(&self) -> bool {
self.ready.load(Ordering::Relaxed)
}
async fn close(&self) -> Result<Option<i32>> {
if let Some(handle) = self.reader_handle.lock().await.take() {
handle.abort();
}
self.end_input().await?;
self.ready.store(false, Ordering::Relaxed);
let child = {
let mut process = self.process.lock().unwrap();
process.take()
};
if let Some(mut child) = child {
if let Some(duration) = self.close_timeout {
match tokio::time::timeout(duration, child.wait()).await {
Ok(Ok(status)) => return Ok(status.code()),
Ok(Err(e)) => return Err(e.into()),
Err(_) => {
tracing::warn!(
timeout_ms = duration.as_millis() as u64,
"close() timed out waiting for subprocess; killing"
);
child.start_kill().ok();
let status = child.wait().await?;
return Ok(status.code());
}
}
} else {
let status = child.wait().await?;
return Ok(status.code());
}
}
Ok(None)
}
}
impl Drop for GeminiTransport {
fn drop(&mut self) {
if let Ok(mut process) = self.process.lock() {
if let Some(ref mut child) = *process {
let _ = child.start_kill();
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Unconnected transport aimed at a CLI path that does not exist.
    fn make_transport() -> GeminiTransport {
        GeminiTransport::new(
            PathBuf::from("/nonexistent/gemini"),
            vec!["--experimental-acp".to_string()],
            PathBuf::from("/tmp"),
            HashMap::new(),
            None,
            None,
        )
    }

    /// Numeric payload of `id`; these tests only ever expect the `Number` variant.
    fn as_number(id: &JsonRpcId) -> u64 {
        match id {
            JsonRpcId::Number(n) => *n,
            JsonRpcId::String(_) => panic!("expected Number variant"),
        }
    }

    #[test]
    fn test_transport_new_does_not_panic() {
        let transport = make_transport();
        assert_eq!(transport.cli_path, PathBuf::from("/nonexistent/gemini"));
        assert!(!transport.args.is_empty());
    }

    #[test]
    fn test_is_not_ready_before_connect() {
        let ready = make_transport().is_ready();
        assert!(
            !ready,
            "transport must not report ready before connect() is called"
        );
    }

    #[test]
    fn test_next_id_increments_monotonically() {
        let transport = make_transport();
        let ids = [transport.next_id(), transport.next_id(), transport.next_id()];
        let numbers: Vec<u64> = ids.iter().map(as_number).collect();
        let increasing = numbers.windows(2).all(|pair| pair[0] < pair[1]);
        assert!(
            increasing,
            "IDs must be strictly increasing: {}, {}, {}",
            ids[0], ids[1], ids[2]
        );
    }

    #[test]
    fn test_next_id_starts_at_one() {
        let first = as_number(&make_transport().next_id());
        assert_eq!(first, 1, "first ID must be 1");
    }

    #[tokio::test]
    async fn test_write_returns_not_connected_when_stdin_absent() {
        let transport = make_transport();
        let outcome = transport.write("{}").await;
        assert!(
            outcome.is_err(),
            "write must fail when transport is not connected"
        );
        let err = outcome.unwrap_err();
        assert!(
            matches!(err, crate::Error::NotConnected),
            "error must be NotConnected"
        );
    }

    #[tokio::test]
    async fn test_end_input_clears_stdin() {
        let transport = make_transport();
        transport
            .end_input()
            .await
            .expect("end_input must not fail on None stdin");
        let stdin_cleared = transport.stdin.lock().await.is_none();
        assert!(stdin_cleared, "stdin must be None after end_input");
    }

    #[test]
    fn test_route_response_resolves_pending_request() {
        let transport = make_transport();
        let id = JsonRpcId::Number(42);
        let (tx, mut rx) = oneshot::channel();
        transport.pending_requests.insert(id.clone(), tx);
        transport.route_response(JsonRpcResponse::success(id, serde_json::json!({"ok": true})));
        let received = rx.try_recv().expect("response must be available immediately");
        assert!(received.result.is_some());
        assert_eq!(received.result, Some(serde_json::json!({"ok": true})));
    }

    #[test]
    fn test_route_response_unknown_id_does_not_panic() {
        let orphan = JsonRpcResponse::success(JsonRpcId::Number(9999), serde_json::Value::Null);
        make_transport().route_response(orphan);
    }

    #[tokio::test]
    async fn test_read_messages_yields_sent_items() {
        use tokio_stream::StreamExt;
        let transport = make_transport();
        let sender = transport.take_message_tx().await.expect("sender must be Some");
        sender
            .send(Ok(serde_json::json!({"hello": "world"})))
            .await
            .expect("send must succeed");
        // Drop every sender so the stream ends after the buffered item.
        drop(sender);
        *transport.message_tx.lock().await = None;
        let mut stream = transport.read_messages();
        let first = stream.next().await;
        assert!(first.is_some(), "stream must yield the sent item");
        let payload = first.unwrap().expect("item must be Ok");
        assert_eq!(payload, serde_json::json!({"hello": "world"}));
        let tail = stream.next().await;
        assert!(tail.is_none(), "stream must terminate after all senders are dropped");
    }

    #[tokio::test]
    async fn test_read_messages_second_call_returns_empty_stream() {
        use tokio_stream::StreamExt;
        let transport = make_transport();
        let _first_stream = transport.read_messages();
        let mut second_stream = transport.read_messages();
        let item = second_stream.next().await;
        assert!(
            item.is_none(),
            "second read_messages() call must return an empty stream"
        );
    }

    #[tokio::test]
    async fn test_take_message_tx_returns_sender() {
        let sender = make_transport().take_message_tx().await;
        assert!(sender.is_some(), "message_tx must be Some after construction");
    }

    #[tokio::test]
    async fn test_take_message_tx_is_reentrant() {
        let transport = make_transport();
        let (first, second) = (
            transport.take_message_tx().await,
            transport.take_message_tx().await,
        );
        assert!(first.is_some() && second.is_some(), "sender must be clonable");
    }
}