use crate::error::{Error, Result};
use crate::providers::{ListModels, ProviderApi, ProviderClient};
use crate::sse::drain_sse_frames;
use crate::types::{
ContentPart, DEFAULT_GEMINI_MODEL, Event, Message, Model, ModelInfo, Provider, Response,
ResponseRequest, Role, ToolCall, ToolSpec,
};
use futures_util::StreamExt;
use serde::Deserialize;
/// Deserialized body of the model-listing endpoint response.
///
/// Only the `models` array is read; a response without that key yields an
/// empty list rather than a parse error.
#[derive(Debug, Deserialize)]
struct GeminiModelsResponse {
/// Models returned by the listing call (empty when the key is absent).
#[serde(default)]
models: Vec<GeminiModel>,
}
/// One entry of the model-listing response.
///
/// The Gemini REST API spells multi-word keys in camelCase (`displayName`,
/// `inputTokenLimit`, `outputTokenLimit`). Each optional field therefore
/// accepts the camelCase key as an alias; without it those keys are silently
/// ignored and the field is always `None`. The snake_case spellings remain
/// accepted as well.
#[derive(Debug, Deserialize)]
struct GeminiModel {
    /// Resource name as sent by the API, e.g. `models/<id>`.
    name: String,
    /// Human-readable model name, if provided.
    #[serde(default, alias = "displayName")]
    display_name: Option<String>,
    /// Model version string, if provided.
    #[serde(default)]
    version: Option<String>,
    /// Maximum number of input tokens, if provided.
    #[serde(default, alias = "inputTokenLimit")]
    input_token_limit: Option<u64>,
    /// Maximum number of output tokens, if provided.
    #[serde(default, alias = "outputTokenLimit")]
    output_token_limit: Option<u64>,
}
/// Deserialized body of a content-generation response (also used for each
/// streamed SSE frame).
///
/// Only `candidates` is read; a missing key yields an empty list.
#[derive(Debug, Deserialize)]
struct GenerateContentResponse {
/// Generated candidates; the extraction helpers in this file only look at the first one.
#[serde(default)]
candidates: Vec<Candidate>,
}
/// One generated candidate of a content-generation response.
///
/// `content` falls back to an empty [`Content`] when the key is absent. The
/// API can return candidates that carry no content at all (for example when
/// generation stops for a safety reason); requiring the key would make the
/// whole response fail to deserialize instead of yielding an empty reply.
#[derive(Debug, Deserialize)]
struct Candidate {
    /// Generated content; empty when the candidate carries none.
    #[serde(default)]
    content: Content,
}

/// Content of a candidate: an ordered list of parts.
///
/// Implements `Default` (no parts) so that `Candidate::content` can be
/// defaulted when missing.
#[derive(Debug, Default, Deserialize)]
struct Content {
    /// Parts in response order (empty when the key is absent).
    #[serde(default)]
    parts: Vec<Part>,
}
/// One element of a candidate's `parts` array.
///
/// Field names are matched in camelCase, so `function_call` is read from the
/// `functionCall` key. Both fields are optional; unknown keys are ignored.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct Part {
/// Text fragment carried by this part, if any.
#[serde(default)]
text: Option<String>,
/// Function call requested by the model, if this part is one.
#[serde(default)]
function_call: Option<FunctionCallPart>,
}
/// Payload of a `functionCall` part: the function name plus optional arguments.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct FunctionCallPart {
/// Name of the function the model wants to call.
name: String,
/// Arguments as arbitrary JSON; `None` when the key is absent or null.
#[serde(default)]
args: Option<serde_json::Value>,
}
/// Stateless handle for the Gemini backend.
///
/// Carries no data: the API key and base URL are passed to each call. The
/// `ListModels` and `ProviderClient` traits are implemented on it in this file.
#[derive(Debug, Clone, Copy)]
pub struct GeminiProvider;
/// Gemini client configuration: an API key bound to a base URL.
///
/// `Debug` is implemented by hand so the API key is never written to logs or
/// panic messages; only the base URL is shown.
#[derive(Clone)]
pub struct GeminiApi {
    api_key: String,
    base_url: String,
}

impl std::fmt::Debug for GeminiApi {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GeminiApi")
            // The key is a credential: print a placeholder, never the value.
            .field("api_key", &"<redacted>")
            .field("base_url", &self.base_url)
            .finish()
    }
}

impl GeminiApi {
    /// Creates a client configuration from an API key and a base URL.
    ///
    /// Both values are stored as given; no validation or normalization is
    /// performed here.
    pub fn new(api_key: impl Into<String>, base_url: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            base_url: base_url.into(),
        }
    }
}
impl GeminiProvider {
    /// Builds the model-listing URL: `<base>/v1beta/models?key=<api_key>`.
    ///
    /// Trailing slashes on `base_url` are trimmed; the key is inserted verbatim.
    fn models_url(base_url: &str, api_key: &str) -> String {
        format!(
            "{}/v1beta/models?key={}",
            base_url.trim_end_matches('/'),
            api_key
        )
    }

    /// Returns `id` prefixed with `models/` unless it already starts with it.
    fn normalize_model_id(id: &str) -> String {
        if id.starts_with("models/") {
            id.to_string()
        } else {
            format!("models/{id}")
        }
    }

    /// Builds the non-streaming `generateContent` URL for `model`.
    fn generate_url(base_url: &str, api_key: &str, model: &str) -> String {
        format!(
            "{}/v1beta/{}:generateContent?key={}",
            base_url.trim_end_matches('/'),
            Self::normalize_model_id(model),
            api_key
        )
    }

    /// Builds the `streamGenerateContent` URL for `model`, requesting SSE output.
    fn stream_url(base_url: &str, api_key: &str, model: &str) -> String {
        format!(
            "{}/v1beta/{}:streamGenerateContent?alt=sse&key={}",
            base_url.trim_end_matches('/'),
            Self::normalize_model_id(model),
            api_key
        )
    }

    /// Converts stored tool-call arguments into the value sent as `args`.
    ///
    /// A JSON string is treated as encoded JSON and parsed; if it does not
    /// parse, an empty object is used. Any other value is cloned unchanged.
    fn gemini_args(arguments: &serde_json::Value) -> serde_json::Value {
        match arguments {
            serde_json::Value::String(s) => serde_json::from_str(s)
                .unwrap_or_else(|_| serde_json::Value::Object(Default::default())),
            v => v.clone(),
        }
    }

    /// Converts a tool result into the value sent as `functionResponse.response`.
    ///
    /// The Gemini API documents `response` as a JSON object. A result that
    /// already serializes to an object is passed through unchanged; anything
    /// else (plain string, number, array, null) is wrapped as
    /// `{"result": <value>}` so the request is not rejected. If the content
    /// cannot be serialized at all, an empty object is sent.
    fn function_response_payload<T>(content: &T) -> serde_json::Value
    where
        T: serde::Serialize + ?Sized,
    {
        match serde_json::to_value(content) {
            Ok(serde_json::Value::Object(fields)) => serde_json::Value::Object(fields),
            Ok(other) => serde_json::json!({ "result": other }),
            Err(_) => serde_json::Value::Object(Default::default()),
        }
    }

    /// Converts conversation messages into Gemini `contents` entries.
    ///
    /// * `Role::Assistant` maps to role `"model"`; every other role maps to `"user"`.
    /// * Empty text / thinking parts and citations produce no part.
    /// * Image URLs are always tagged with MIME type `image/jpeg`.
    /// * Tool results without a non-blank function name are skipped.
    /// * Messages that end up with no parts are omitted entirely.
    fn contents_from_messages(messages: &[Message]) -> Vec<serde_json::Value> {
        let mut out = Vec::new();
        for m in messages {
            let role = match m.role {
                Role::System | Role::User | Role::Tool => "user",
                Role::Assistant => "model",
            };
            let mut parts: Vec<serde_json::Value> = Vec::new();
            for p in &m.content {
                match p {
                    ContentPart::Text(t) => {
                        if !t.is_empty() {
                            parts.push(serde_json::json!({ "text": t }));
                        }
                    }
                    ContentPart::ImageBase64 { media_type, data } => {
                        parts.push(serde_json::json!({
                            "inlineData": {
                                "mimeType": media_type,
                                "data": data
                            }
                        }));
                    }
                    ContentPart::ImageUrl { url } => {
                        parts.push(serde_json::json!({
                            "fileData": {
                                "mimeType": "image/jpeg",
                                "fileUri": url
                            }
                        }));
                    }
                    ContentPart::Thinking { text, .. } => {
                        // Thinking text is forwarded as an ordinary text part.
                        if !text.is_empty() {
                            parts.push(serde_json::json!({ "text": text }));
                        }
                    }
                    ContentPart::Citation { .. } => {}
                    ContentPart::ToolCall {
                        name, arguments, ..
                    } => {
                        parts.push(serde_json::json!({
                            "functionCall": {
                                "name": name,
                                "args": Self::gemini_args(arguments)
                            }
                        }));
                    }
                    ContentPart::ToolResult {
                        function_name,
                        content,
                        ..
                    } => {
                        // A function response needs a name; without one the
                        // result is dropped rather than sent unnamed.
                        let Some(fname) = function_name
                            .as_ref()
                            .map(|s| s.trim())
                            .filter(|s| !s.is_empty())
                        else {
                            continue;
                        };
                        parts.push(serde_json::json!({
                            "functionResponse": {
                                "name": fname,
                                "response": Self::function_response_payload(content)
                            }
                        }));
                    }
                }
            }
            if parts.is_empty() {
                continue;
            }
            out.push(serde_json::json!({
                "role": role,
                "parts": parts
            }));
        }
        out
    }

    /// Converts tool specs into the request's `tools` array.
    ///
    /// Returns an empty vector when there are no tools; otherwise a single
    /// entry holding every function declaration.
    fn tools_from_tools(tools: &[ToolSpec]) -> Vec<serde_json::Value> {
        if tools.is_empty() {
            return Vec::new();
        }
        vec![serde_json::json!({
            "function_declarations": tools.iter().map(|t| serde_json::json!({
                "name": t.name,
                "description": t.description,
                "parameters": t.parameters,
            })).collect::<Vec<_>>()
        })]
    }

    /// Picks a model when the request did not name one.
    ///
    /// Uses the first model returned by the listing call and falls back to
    /// `DEFAULT_GEMINI_MODEL` when listing fails or returns nothing.
    // NOTE(review): the first listed model is used without checking that it
    // supports content generation — confirm this is intended.
    async fn resolve_model(http: &reqwest::Client, api_key: &str, base_url: &str) -> String {
        match ListModels::list_models(&GeminiProvider, http, api_key, base_url).await {
            Ok(models) => models
                .first()
                .map(|m| m.id.clone())
                .unwrap_or_else(|| DEFAULT_GEMINI_MODEL.to_string()),
            Err(_) => DEFAULT_GEMINI_MODEL.to_string(),
        }
    }

    /// Concatenates the text of every part of the first candidate.
    ///
    /// Returns an empty string when there is no candidate or no text part.
    fn extract_text(resp: &GenerateContentResponse) -> String {
        let mut out = String::new();
        if let Some(c) = resp.candidates.first() {
            for p in &c.content.parts {
                if let Some(t) = &p.text {
                    out.push_str(t);
                }
            }
        }
        out
    }

    /// Collects the function calls found in the first candidate.
    ///
    /// Each call gets a synthetic id `gemini:<name>:<part index>`; missing
    /// arguments become an empty JSON object.
    fn extract_tool_calls(resp: &GenerateContentResponse) -> Vec<ToolCall> {
        let mut out = Vec::new();
        let Some(c) = resp.candidates.first() else {
            return out;
        };
        for (i, p) in c.content.parts.iter().enumerate() {
            let Some(fc) = &p.function_call else {
                continue;
            };
            let args = fc
                .args
                .clone()
                .unwrap_or_else(|| serde_json::Value::Object(Default::default()));
            out.push(ToolCall {
                id: Some(format!("gemini:{}:{i}", fc.name)),
                name: fc.name.clone(),
                arguments: args,
            });
        }
        out
    }
}
#[async_trait::async_trait(?Send)]
impl ProviderApi for GeminiApi {
    /// Reports which backend this client talks to.
    fn provider(&self) -> Provider {
        Provider::Gemini
    }

    /// Sends a single request, forwarding the stored key and base URL to
    /// `GeminiProvider`.
    async fn send(&self, http: &reqwest::Client, req: ResponseRequest) -> Result<Response> {
        let (key, base) = (self.api_key.as_str(), self.base_url.as_str());
        GeminiProvider.send(http, key, base, req).await
    }

    /// Streams a request through `GeminiProvider`, passing `on_event` along
    /// so the caller receives events as they are produced.
    async fn stream(
        &self,
        http: &reqwest::Client,
        req: ResponseRequest,
        on_event: &mut dyn FnMut(Event),
    ) -> Result<Response> {
        let (key, base) = (self.api_key.as_str(), self.base_url.as_str());
        GeminiProvider.stream(http, key, base, req, on_event).await
    }

    /// Lists the available models using the stored key and base URL.
    async fn list_models(&self, http: &reqwest::Client) -> Result<Vec<ModelInfo>> {
        let (key, base) = (self.api_key.as_str(), self.base_url.as_str());
        GeminiProvider.list_models(http, key, base).await
    }
}
impl ListModels for GeminiProvider {
fn list_models(
&self,
http: &reqwest::Client,
api_key: &str,
base_url: &str,
) -> impl std::future::Future<Output = Result<Vec<ModelInfo>>> + Send {
let url = Self::models_url(base_url, api_key);
let http = http.clone();
async move {
let resp = http.get(url).send().await?;
let status = resp.status();
let text = resp.text().await?;
if !status.is_success() {
return Err(Error::Api {
provider: Provider::Gemini,
status: status.as_u16(),
body: text,
});
}
let parsed: GeminiModelsResponse = serde_json::from_str(&text)?;
Ok(parsed
.models
.into_iter()
.map(|m| ModelInfo {
id: m.name.trim_start_matches("models/").to_string(),
display_name: m.display_name.or(m.version),
provider: Provider::Gemini,
created_at: None,
max_input_tokens: m.input_token_limit.and_then(|n| u32::try_from(n).ok()),
max_output_tokens: m.output_token_limit.and_then(|n| u32::try_from(n).ok()),
})
.collect())
}
}
}
impl ProviderClient for GeminiProvider {
async fn send(
&self,
http: &reqwest::Client,
api_key: &str,
base_url: &str,
req: ResponseRequest,
) -> Result<Response> {
let model = match req.model {
Some(m) => m.0,
None => Self::resolve_model(http, api_key, base_url).await,
};
let url = Self::generate_url(base_url, api_key, &model);
let contents = Self::contents_from_messages(&req.messages);
let mut body = serde_json::json!({
"contents": contents,
});
if let Some(max) = req.max_output_tokens {
body["generationConfig"] = serde_json::json!({ "maxOutputTokens": max });
}
if !req.tools.is_empty() {
body["tools"] = serde_json::json!(Self::tools_from_tools(&req.tools));
}
let resp = http.post(url).json(&body).send().await?;
let status = resp.status();
let text = resp.text().await?;
if !status.is_success() {
return Err(Error::Api {
provider: Provider::Gemini,
status: status.as_u16(),
body: text,
});
}
let parsed: GenerateContentResponse = serde_json::from_str(&text)?;
let out = Self::extract_text(&parsed);
let tool_calls = Self::extract_tool_calls(&parsed);
Ok(Response {
model: Model::new(model),
message: Message::text(Role::Assistant, out),
tool_calls,
metadata: serde_json::Value::Null,
#[cfg(feature = "raw-json")]
raw_json: serde_json::from_str::<serde_json::Value>(&text).ok(),
})
}
async fn stream<F>(
&self,
http: &reqwest::Client,
api_key: &str,
base_url: &str,
req: ResponseRequest,
on_event: &mut F,
) -> Result<Response>
where
F: FnMut(Event) + ?Sized,
{
let model = match req.model {
Some(m) => m.0,
None => Self::resolve_model(http, api_key, base_url).await,
};
let url = Self::stream_url(base_url, api_key, &model);
let contents = Self::contents_from_messages(&req.messages);
let mut body = serde_json::json!({
"contents": contents,
});
if let Some(max) = req.max_output_tokens {
body["generationConfig"] = serde_json::json!({ "maxOutputTokens": max });
}
if !req.tools.is_empty() {
body["tools"] = serde_json::json!(Self::tools_from_tools(&req.tools));
}
let resp = http.post(url).json(&body).send().await?;
let status = resp.status();
if !status.is_success() {
let text = resp.text().await?;
return Err(Error::Api {
provider: Provider::Gemini,
status: status.as_u16(),
body: text,
});
}
let mut out = String::new();
let mut tool_calls: Vec<ToolCall> = Vec::new();
let mut last_tool_emit = String::new();
let mut buf = String::new();
let mut frames = Vec::new();
let mut bytes = resp.bytes_stream();
while let Some(chunk) = bytes.next().await {
let chunk = chunk?;
buf.push_str(&String::from_utf8_lossy(&chunk));
drain_sse_frames(&mut buf, &mut frames)?;
for frame in frames.drain(..) {
let parsed: GenerateContentResponse = match serde_json::from_str(&frame.data) {
Ok(v) => v,
Err(_) => continue,
};
let delta = Self::extract_text(&parsed);
if !delta.is_empty() {
out.push_str(&delta);
on_event(Event::TextDelta(delta));
}
let calls = Self::extract_tool_calls(&parsed);
if !calls.is_empty() {
tool_calls = calls;
let sig = serde_json::to_string(&tool_calls).unwrap_or_default();
if sig != last_tool_emit {
last_tool_emit = sig;
for c in &tool_calls {
on_event(Event::ToolCall(c.clone()));
}
}
}
}
}
let resp = Response {
model: Model::new(model),
message: Message::text(Role::Assistant, out),
tool_calls,
metadata: serde_json::Value::Null,
#[cfg(feature = "raw-json")]
raw_json: None,
};
on_event(Event::Completed(resp.clone()));
Ok(resp)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A models payload maps to `ModelInfo` with the `models/` prefix removed
    /// and the token limits narrowed to `u32`.
    #[test]
    fn parses_models_response_into_model_info() {
        let payload = r#"
{
"models": [
{
"name": "models/gemini-x",
"display_name": "Gemini X",
"input_token_limit": 100,
"output_token_limit": 10
}
]
}
"#;
        let listing = serde_json::from_str::<GeminiModelsResponse>(payload).unwrap();
        let mut infos: Vec<ModelInfo> = Vec::new();
        for model in listing.models {
            infos.push(ModelInfo {
                id: model.name.trim_start_matches("models/").to_string(),
                display_name: model.display_name.or(model.version),
                provider: Provider::Gemini,
                created_at: None,
                max_input_tokens: model.input_token_limit.and_then(|n| u32::try_from(n).ok()),
                max_output_tokens: model
                    .output_token_limit
                    .and_then(|n| u32::try_from(n).ok()),
            });
        }
        assert_eq!(infos.len(), 1);
        let info = &infos[0];
        assert_eq!(info.id, "gemini-x");
        assert_eq!(info.display_name.as_deref(), Some("Gemini X"));
        assert_eq!(info.max_input_tokens, Some(100));
        assert_eq!(info.max_output_tokens, Some(10));
    }

    /// Text parts of the first candidate are concatenated in order.
    #[test]
    fn extract_text_from_generate_content_response() {
        let payload = r#"
{
"candidates": [
{
"content": {
"parts": [
{ "text": "hello" },
{ "text": " world" }
]
}
}
]
}
"#;
        let response = serde_json::from_str::<GenerateContentResponse>(payload).unwrap();
        let text = GeminiProvider::extract_text(&response);
        assert_eq!(text, "hello world");
    }

    /// A `functionCall` part becomes a `ToolCall` with a `gemini:<name>:` id.
    #[test]
    fn extract_tool_calls_from_generate_content_response() {
        let payload = r#"
{
"candidates": [
{
"content": {
"parts": [
{
"functionCall": {
"name": "add",
"args": { "a": 19, "b": 23 }
}
}
]
}
}
]
}
"#;
        let response = serde_json::from_str::<GenerateContentResponse>(payload).unwrap();
        let calls = GeminiProvider::extract_tool_calls(&response);
        assert_eq!(calls.len(), 1);
        let call = &calls[0];
        assert_eq!(call.name, "add");
        assert_eq!(call.arguments["a"], 19);
        assert_eq!(call.arguments["b"], 23);
        let id = call.id.as_deref().unwrap();
        assert!(id.starts_with("gemini:add:"));
    }
}