use std::fmt;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
#[cfg(not(target_arch = "wasm32"))]
use std::time::Duration;
use bytes::{Bytes, BytesMut};
use futures_core::Stream;
use futures_util::stream::StreamExt;
use reqwest::{Client, Url};
use tracing::trace;
use crate::backends::gemini::wire::{GenerateChunk, GenerateContentRequest};
use crate::error::{Error, Result};
/// Base URL that `GeminiClient::new` parses and stores; endpoint paths are
/// joined onto it for every request.
const DEFAULT_BASE_URL: &str = "https://generativelanguage.googleapis.com";
/// Timeout (300 seconds) passed to the reqwest client builder in
/// `GeminiClient::new`. Only compiled on non-wasm targets.
#[cfg(not(target_arch = "wasm32"))]
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(300);
/// Client for the Gemini `generateContent` / `streamGenerateContent`
/// endpoints used in this file.
pub struct GeminiClient {
/// Underlying reqwest client used to send every request.
http: Client,
/// API key sent in the `x-goog-api-key` request header.
api_key: Box<str>,
/// URL that the `v1beta/models/...` endpoint paths are joined onto.
base_url: Url,
}
/// Debug output shows the base URL and masks the API key so that it cannot
/// leak through logs; the inner HTTP client is not printed.
impl fmt::Debug for GeminiClient {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("GeminiClient");
        out.field("base_url", &self.base_url.as_str());
        // Never print the real key.
        out.field("api_key", &"<redacted>");
        out.finish()
    }
}
impl GeminiClient {
pub fn new(api_key: impl Into<String>) -> Result<Self> {
let builder = Client::builder()
.user_agent(concat!("localharness/", env!("CARGO_PKG_VERSION")));
#[cfg(not(target_arch = "wasm32"))]
let builder = builder.timeout(DEFAULT_TIMEOUT);
let http = builder
.build()
.map_err(|e| Error::other(format!("reqwest client build: {e}")))?;
Ok(Self {
http,
api_key: api_key.into().into_boxed_str(),
base_url: Url::parse(DEFAULT_BASE_URL).expect("default base url is valid"),
})
}
pub fn with_base_url(mut self, url: Url) -> Self {
self.base_url = url;
self
}
pub async fn generate(
&self,
model: &str,
req: &GenerateContentRequest,
) -> Result<GenerateChunk> {
let path = format!("v1beta/models/{model}:generateContent");
let url = self
.base_url
.join(&path)
.map_err(|e| Error::other(format!("invalid model url: {e}")))?;
let response = self
.http
.post(url)
.header("x-goog-api-key", self.api_key.as_ref())
.json(req)
.send()
.await
.map_err(|e| Error::other(format!("gemini POST: {e}")))?;
if !response.status().is_success() {
let status = response.status();
let body = response
.text()
.await
.unwrap_or_else(|_| "<no body>".to_string());
return Err(Error::other(format!("gemini HTTP {status}: {body}")));
}
response
.json::<GenerateChunk>()
.await
.map_err(|e| Error::other(format!("gemini JSON: {e}")))
}
pub async fn stream_generate(
&self,
model: &str,
req: &GenerateContentRequest,
) -> Result<GeminiSseStream> {
let path = format!("v1beta/models/{model}:streamGenerateContent");
let mut url = self
.base_url
.join(&path)
.map_err(|e| Error::other(format!("invalid model url: {e}")))?;
url.query_pairs_mut().append_pair("alt", "sse");
let response = self
.http
.post(url)
.header("x-goog-api-key", self.api_key.as_ref())
.header("accept", "text/event-stream")
.json(req)
.send()
.await
.map_err(|e| Error::other(format!("gemini POST: {e}")))?;
if !response.status().is_success() {
let status = response.status();
let body = response
.text()
.await
.unwrap_or_else(|_| "<no body>".to_string());
return Err(Error::other(format!("gemini HTTP {status}: {body}")));
}
let byte_stream = response.bytes_stream().map(|res| {
res.map_err(|e| Error::other(format!("gemini chunk read: {e}")))
});
Ok(GeminiSseStream::new(Box::pin(byte_stream)))
}
}
/// Boxed, pinned stream of raw body chunks consumed by `GeminiSseStream`.
/// The non-wasm variant additionally requires `Send`.
#[cfg(not(target_arch = "wasm32"))]
type ByteStream =
Pin<Box<dyn Stream<Item = Result<Bytes>> + Send + 'static>>;
/// Boxed, pinned stream of raw body chunks consumed by `GeminiSseStream`.
/// The wasm32 variant carries no `Send` bound.
#[cfg(target_arch = "wasm32")]
type ByteStream =
Pin<Box<dyn Stream<Item = Result<Bytes>> + 'static>>;
/// Stream adapter that reassembles SSE frames from raw byte chunks and yields
/// one decoded `GenerateChunk` per non-empty `data:` payload.
pub struct GeminiSseStream {
/// Source of raw body chunks.
upstream: ByteStream,
/// Bytes received so far that have not yet been split off as a frame.
buffer: BytesMut,
/// Set once upstream has ended or a `[DONE]` payload was seen; upstream is
/// not polled again afterwards.
done: bool,
}
impl GeminiSseStream {
    /// Wraps `upstream` with an empty 8 KiB reassembly buffer.
    fn new(upstream: ByteStream) -> Self {
        let buffer = BytesMut::with_capacity(8 * 1024);
        Self {
            upstream,
            buffer,
            done: false,
        }
    }

    /// Splits the earliest complete frame off the front of the buffer and
    /// returns its `data:` payload.
    ///
    /// A frame is complete once a `\r\n\r\n` or `\n\n` terminator is buffered;
    /// the terminator is consumed together with the frame. Returns `None`
    /// when no terminator has been buffered yet.
    fn take_frame(&mut self) -> Option<Vec<u8>> {
        let buffered: &[u8] = &self.buffer;
        // Length of the frame including its terminator, for the first offset
        // at which either terminator starts.
        let frame_len = (0..buffered.len()).find_map(|start| {
            let tail = &buffered[start..];
            if tail.starts_with(b"\r\n\r\n") {
                Some(start + 4)
            } else if tail.starts_with(b"\n\n") {
                Some(start + 2)
            } else {
                None
            }
        })?;
        let frame = self.buffer.split_to(frame_len);
        Some(extract_data_payload(&frame))
    }

    /// Drains whatever is left in the buffer and returns its `data:` payload,
    /// or `None` when the buffer is already empty.
    fn take_remaining(&mut self) -> Option<Vec<u8>> {
        let pending = self.buffer.len();
        (pending > 0).then(|| {
            let tail = self.buffer.split_to(pending);
            extract_data_payload(&tail)
        })
    }
}
impl Stream for GeminiSseStream {
    type Item = Result<GenerateChunk>;

    /// Yields the next decoded chunk.
    ///
    /// Buffered complete frames are always served before upstream is polled.
    /// Empty payloads are skipped. A `[DONE]` payload discards the rest of the
    /// buffer and ends the stream. When upstream ends, any unterminated
    /// trailing bytes are decoded as a final frame.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        loop {
            if this.done {
                // Upstream is finished: drain complete frames first, then the
                // unterminated tail, then report end of stream.
                let drained = this.take_frame().or_else(|| this.take_remaining());
                match drained {
                    None => return Poll::Ready(None),
                    Some(payload) if payload.is_empty() || payload == b"[DONE]" => continue,
                    Some(payload) => return Poll::Ready(Some(decode_chunk(&payload))),
                }
            }

            match this.take_frame() {
                Some(payload) if payload.is_empty() => continue,
                Some(payload) if payload == b"[DONE]" => {
                    // Sentinel: nothing after it may be emitted.
                    this.done = true;
                    this.buffer.clear();
                    continue;
                }
                Some(payload) => return Poll::Ready(Some(decode_chunk(&payload))),
                None => {}
            }

            // No complete frame buffered; pull more bytes.
            match this.upstream.as_mut().poll_next(cx) {
                Poll::Ready(Some(Ok(bytes))) => {
                    trace!(len = bytes.len(), "gemini sse bytes");
                    this.buffer.extend_from_slice(&bytes);
                }
                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
                Poll::Ready(None) => this.done = true,
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}
/// Collects the values of all `data:` lines in one SSE frame into a single
/// payload.
///
/// Lines are separated by `\n`, and trailing `\r` bytes are stripped from each
/// line. For every line starting with `data:`, one optional leading space is
/// removed from the value, and values are joined with `\n`. Other lines
/// (comments, `event:`, `id:`, ...) are ignored. Returns an empty vector when
/// the frame carries no data.
///
/// The frame is processed as raw bytes rather than as `str`, so a frame that
/// contains invalid UTF-8 is passed through and rejected by the JSON decoder
/// with a visible error, instead of being silently reduced to an empty
/// payload and skipped.
fn extract_data_payload(frame: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(frame.len());
    for line in frame.split(|&b| b == b'\n') {
        // Strip every trailing CR so CRLF-terminated lines parse like LF ones.
        let end = line
            .iter()
            .rposition(|&b| b != b'\r')
            .map_or(0, |last| last + 1);
        let line = &line[..end];
        if let Some(rest) = line.strip_prefix(b"data:") {
            // At most one space after the colon is part of the field syntax.
            let rest = rest.strip_prefix(b" ").unwrap_or(rest);
            if !out.is_empty() {
                out.push(b'\n');
            }
            out.extend_from_slice(rest);
        }
    }
    out
}
fn decode_chunk(payload: &[u8]) -> Result<GenerateChunk> {
serde_json::from_slice::<GenerateChunk>(payload)
.map_err(|e| Error::other(format!("gemini sse decode: {e}; payload: {}",
String::from_utf8_lossy(payload))))
}
/// Reference-counted handle to a `GeminiClient`.
pub type SharedClient = Arc<GeminiClient>;
// Unit tests for the SSE framing and decoding in `GeminiSseStream`, driven by
// in-memory byte streams instead of real HTTP responses.
#[cfg(test)]
mod tests {
use super::*;
use futures_util::stream;
// Builds a `ByteStream` that yields each slice in `parts` as one `Ok` chunk.
fn bytes_from(parts: &[&[u8]]) -> ByteStream {
let owned: Vec<Bytes> = parts.iter().map(|b| Bytes::copy_from_slice(b)).collect();
Box::pin(stream::iter(owned.into_iter().map(Ok)))
}
// Two LF-terminated frames followed by `[DONE]` decode to exactly two chunks.
#[tokio::test]
async fn decodes_two_frames() {
let bytes = bytes_from(&[
b"data: {\"candidates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"text\":\"hi\"}]}}]}\n\n",
b"data: {\"candidates\":[{\"finishReason\":\"STOP\"}]}\n\ndata: [DONE]\n\n",
]);
let mut s = GeminiSseStream::new(bytes);
let first = s.next().await.unwrap().unwrap();
assert_eq!(first.candidates.len(), 1);
let second = s.next().await.unwrap().unwrap();
assert_eq!(second.candidates[0].finish_reason.unwrap(),
crate::backends::gemini::wire::FinishReason::Stop);
assert!(s.next().await.is_none());
}
// Same as above, but with CRLF CRLF frame terminators.
#[tokio::test]
async fn decodes_crlf_terminated_frames() {
let bytes = bytes_from(&[
b"data: {\"candidates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"text\":\"hi\"}]}}]}\r\n\r\n",
b"data: {\"candidates\":[{\"finishReason\":\"STOP\"}]}\r\n\r\ndata: [DONE]\r\n\r\n",
]);
let mut s = GeminiSseStream::new(bytes);
let first = s.next().await.unwrap().unwrap();
assert_eq!(first.candidates.len(), 1);
let second = s.next().await.unwrap().unwrap();
assert_eq!(
second.candidates[0].finish_reason.unwrap(),
crate::backends::gemini::wire::FinishReason::Stop
);
assert!(s.next().await.is_none());
}
// A single frame whose JSON is cut in the middle by a chunk boundary.
#[tokio::test]
async fn handles_split_across_chunks() {
let bytes = bytes_from(&[
b"data: {\"candi",
b"dates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"text\":\"hi\"}]}}]}\n\n",
]);
let mut s = GeminiSseStream::new(bytes);
let first = s.next().await.unwrap().unwrap();
assert_eq!(first.candidates[0].content.as_ref().unwrap().parts.len(), 1);
}
// Drains the stream and returns every text part in order. Panics (via
// `unwrap`) if any item is an error.
async fn collect_texts(mut s: GeminiSseStream) -> Vec<String> {
let mut out = Vec::new();
while let Some(chunk) = s.next().await {
let chunk = chunk.unwrap();
for cand in chunk.candidates {
if let Some(content) = cand.content {
for part in content.parts {
if let crate::backends::gemini::wire::Part::Text { text } = part {
out.push(text);
}
}
}
}
}
out
}
// The last frame has no terminator at all; it must still be decoded at EOF.
#[tokio::test]
async fn flushes_final_frame_without_trailing_blank_line() {
let bytes = bytes_from(&[
b"data: {\"candidates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"text\":\"hi\"}]}}]}\n\n",
b"data: {\"candidates\":[{\"finishReason\":\"STOP\"}]}",
]);
let mut s = GeminiSseStream::new(bytes);
let first = s.next().await.unwrap().unwrap();
assert_eq!(first.candidates.len(), 1);
let second = s.next().await.unwrap().unwrap();
assert_eq!(
second.candidates[0].finish_reason.unwrap(),
crate::backends::gemini::wire::FinishReason::Stop,
"the final unterminated frame must still be decoded"
);
assert!(s.next().await.is_none());
}
// The only frame in the stream is unterminated.
#[tokio::test]
async fn flushes_single_unterminated_frame() {
let bytes = bytes_from(&[
b"data: {\"candidates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"text\":\"only\"}]}}]}",
]);
let s = GeminiSseStream::new(bytes);
assert_eq!(collect_texts(s).await, vec!["only".to_string()]);
}
// The final frame ends with a single `\n` (line end, but no blank line).
#[tokio::test]
async fn flushes_final_frame_with_single_newline() {
let bytes = bytes_from(&[b"data: {\"candidates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"text\":\"x\"}]}}]}\n"]);
let s = GeminiSseStream::new(bytes);
assert_eq!(collect_texts(s).await, vec!["x".to_string()]);
}
// Three complete frames delivered in one chunk are yielded in order.
#[tokio::test]
async fn multiple_events_in_one_chunk() {
let bytes = bytes_from(&[concat!(
"data: {\"candidates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"text\":\"a\"}]}}]}\n\n",
"data: {\"candidates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"text\":\"b\"}]}}]}\n\n",
"data: {\"candidates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"text\":\"c\"}]}}]}\n\n",
).as_bytes()]);
let s = GeminiSseStream::new(bytes);
assert_eq!(collect_texts(s).await, vec!["a", "b", "c"]);
}
// Frames that follow `[DONE]` in the same chunk must not be emitted.
#[tokio::test]
async fn done_sentinel_terminates() {
let bytes = bytes_from(&[concat!(
"data: {\"candidates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"text\":\"a\"}]}}]}\n\n",
"data: [DONE]\n\n",
"data: {\"candidates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"text\":\"after\"}]}}]}\n\n",
).as_bytes()]);
let s = GeminiSseStream::new(bytes);
assert_eq!(collect_texts(s).await, vec!["a".to_string()]);
}
// Comment-only frames, frames with only `event:`/`id:` fields, and an empty
// `data:` field produce no items.
#[tokio::test]
async fn skips_empty_and_non_data_lines() {
let bytes = bytes_from(&[concat!(
": keepalive comment\n\n",
"event: message\nid: 42\n\n",
"data:\n\n", "data: {\"candidates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"text\":\"real\"}]}}]}\n\n",
).as_bytes()]);
let s = GeminiSseStream::new(bytes);
assert_eq!(collect_texts(s).await, vec!["real".to_string()]);
}
// `data:` immediately followed by the value (no space) is accepted.
#[tokio::test]
async fn data_field_without_leading_space() {
let bytes = bytes_from(&[b"data:{\"candidates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"text\":\"ns\"}]}}]}\n\n"]);
let s = GeminiSseStream::new(bytes);
assert_eq!(collect_texts(s).await, vec!["ns".to_string()]);
}
// The CRLF CRLF terminator straddles the chunk boundary (`\r\n\r` | `\n`).
// Only the first frame carries text; the second must still decode cleanly.
#[tokio::test]
async fn crlf_terminator_split_across_chunks() {
let bytes = bytes_from(&[
b"data: {\"candidates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"text\":\"hi\"}]}}]}\r\n\r",
b"\ndata: {\"candidates\":[{\"finishReason\":\"STOP\"}]}\r\n\r\n",
]);
let s = GeminiSseStream::new(bytes);
let texts = collect_texts(s).await;
assert_eq!(texts, vec!["hi".to_string()]);
}
// The chunk boundary falls right after the first byte (0xC3) of the two-byte
// UTF-8 encoding of 'é'; the character must survive reassembly.
#[tokio::test]
async fn multibyte_char_split_across_chunks() {
let full = "data: {\"candidates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"text\":\"é\"}]}}]}\n\n";
let raw = full.as_bytes();
let split_at = raw.iter().position(|&b| b == 0xC3).unwrap();
let (head, tail) = raw.split_at(split_at + 1); let bytes = bytes_from(&[head, tail]);
let s = GeminiSseStream::new(bytes);
assert_eq!(collect_texts(s).await, vec!["é".to_string()]);
}
// One LF-terminated frame followed by one CRLF-terminated frame.
#[tokio::test]
async fn mixed_lf_and_crlf_terminators() {
let bytes = bytes_from(&[concat!(
"data: {\"candidates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"text\":\"lf\"}]}}]}\n\n",
"data: {\"candidates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"text\":\"crlf\"}]}}]}\r\n\r\n",
).as_bytes()]);
let s = GeminiSseStream::new(bytes);
assert_eq!(collect_texts(s).await, vec!["lf", "crlf"]);
}
// Invalid JSON in a frame surfaces as an `Err` item rather than a panic.
#[tokio::test]
async fn malformed_json_yields_error_not_panic() {
let bytes = bytes_from(&[b"data: {not json}\n\n"]);
let mut s = GeminiSseStream::new(bytes);
let item = s.next().await.unwrap();
assert!(item.is_err(), "malformed JSON must be an Err, got {item:?}");
}
// Frames consisting solely of a terminator are skipped.
#[tokio::test]
async fn bare_blank_frames_skipped() {
let bytes = bytes_from(&[concat!(
"\n\n",
"\r\n\r\n",
"data: {\"candidates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"text\":\"v\"}]}}]}\n\n",
).as_bytes()]);
let s = GeminiSseStream::new(bytes);
assert_eq!(collect_texts(s).await, vec!["v".to_string()]);
}
// An upstream with no chunks ends immediately, and polling again after the
// end keeps returning `None`.
#[tokio::test]
async fn empty_stream_yields_nothing() {
let bytes = bytes_from(&[]);
let mut s = GeminiSseStream::new(bytes);
assert!(s.next().await.is_none());
assert!(s.next().await.is_none());
}
// `[DONE]` arrives without a terminator and is only seen at EOF; it must be
// swallowed instead of being decoded as JSON.
#[tokio::test]
async fn done_sentinel_unterminated_at_eof() {
let bytes = bytes_from(&[concat!(
"data: {\"candidates\":[{\"content\":{\"role\":\"model\",\"parts\":[{\"text\":\"a\"}]}}]}\n\n",
"data: [DONE]", ).as_bytes()]);
let s = GeminiSseStream::new(bytes);
assert_eq!(collect_texts(s).await, vec!["a".to_string()]);
}
}