use std::sync::Arc;
use crate::{
builder::LLMBackend,
chat::{
ChatMessage, ChatProvider, ChatResponse, ChatRole, MessageType, StructuredOutputFormat,
Tool, Usage,
},
completion::{CompletionProvider, CompletionRequest, CompletionResponse},
embedding::EmbeddingProvider,
error::LLMError,
models::{ModelListRawEntry, ModelListRequest, ModelListResponse, ModelsProvider},
stt::SpeechToTextProvider,
tts::TextToSpeechProvider,
FunctionCall, LLMProvider, ToolCall,
};
use async_trait::async_trait;
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use chrono::{DateTime, Utc};
use futures::{stream::Stream, StreamExt};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// Configuration shared by all clones of a [`Google`] provider instance.
#[derive(Debug)]
pub struct GoogleConfig {
    /// API key, sent as the `key` query parameter on every request.
    pub api_key: String,
    /// Model id, e.g. `gemini-1.5-flash`.
    pub model: String,
    /// Maximum number of tokens to generate (`maxOutputTokens`).
    pub max_tokens: Option<u32>,
    /// Sampling temperature.
    pub temperature: Option<f32>,
    /// Optional system prompt, sent as a leading user turn.
    pub system: Option<String>,
    /// Request timeout in seconds; applied both to the HTTP client at
    /// construction and again per request.
    pub timeout_seconds: Option<u64>,
    /// Nucleus sampling parameter (`topP`).
    pub top_p: Option<f32>,
    /// Top-k sampling parameter (`topK`).
    pub top_k: Option<u32>,
    /// Structured-output schema, mapped to `responseSchema` when set.
    pub json_schema: Option<StructuredOutputFormat>,
    /// Tools exposed through `LLMProvider::tools`.
    pub tools: Option<Vec<Tool>>,
}
/// Google Gemini provider. Cloning is cheap: the configuration is behind an
/// `Arc` and `reqwest::Client` is internally reference-counted.
#[derive(Debug, Clone)]
pub struct Google {
    /// Shared, immutable provider configuration.
    pub config: Arc<GoogleConfig>,
    /// HTTP client used for all API calls.
    pub client: Client,
}
/// Top-level request body for `generateContent` / `streamGenerateContent`.
#[derive(Serialize)]
struct GoogleChatRequest<'a> {
    /// Conversation turns, oldest first.
    contents: Vec<GoogleChatContent<'a>>,
    /// Sampling / structured-output settings; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none", rename = "generationConfig")]
    generation_config: Option<GoogleGenerationConfig>,
    /// Function declarations advertised to the model; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<GoogleTool>>,
}
/// A single conversation turn: a role plus its content parts.
#[derive(Serialize)]
struct GoogleChatContent<'a> {
    /// One of `"user"`, `"model"`, or `"function"` (for tool results).
    role: &'a str,
    /// Ordered content parts for this turn.
    parts: Vec<GoogleContentPart<'a>>,
}
/// One element of a turn's `parts` array. Externally tagged, so variants
/// serialize as `{"text": ...}`, `{"inlineData": {...}}`, `{"functionCall":
/// {...}}`, or `{"functionResponse": {...}}`.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
enum GoogleContentPart<'a> {
    /// Plain text part.
    #[serde(rename = "text")]
    Text(&'a str),
    /// Base64-encoded binary payload (images, PDFs).
    InlineData(GoogleInlineData),
    /// A function call previously issued by the assistant.
    FunctionCall(GoogleFunctionCall),
    /// The result of an executed function call.
    #[serde(rename = "functionResponse")]
    FunctionResponse(GoogleFunctionResponse),
}
#[derive(Serialize)]
struct GoogleInlineData {
mime_type: String,
data: String,
}
/// The `generationConfig` object of a Gemini request.
///
/// All keys are camelCase on the wire. The struct-level `rename_all`
/// replaces the previous per-field renames and also fixes
/// `response_mime_type` / `response_schema`, which had no rename at all and
/// were emitted as snake_case instead of the documented `responseMimeType` /
/// `responseSchema`.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleGenerationConfig {
    /// Hard cap on generated tokens (`maxOutputTokens`).
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<u32>,
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Nucleus sampling cutoff (`topP`).
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    /// Top-k sampling cutoff (`topK`).
    #[serde(skip_serializing_if = "Option::is_none")]
    top_k: Option<u32>,
    /// Expected response MIME type (`responseMimeType`), used for structured
    /// output.
    #[serde(skip_serializing_if = "Option::is_none")]
    response_mime_type: Option<GoogleResponseMimeType>,
    /// JSON schema constraining the response (`responseSchema`).
    #[serde(skip_serializing_if = "Option::is_none")]
    response_schema: Option<Value>,
}
/// Successful `generateContent` response body (the subset this crate reads).
#[derive(Deserialize, Debug)]
struct GoogleChatResponse {
    /// Generated candidates; only the first is consumed downstream.
    candidates: Vec<GoogleCandidate>,
    /// Token accounting, when the API reports it.
    #[serde(rename = "usageMetadata")]
    usage_metadata: Option<GoogleUsageMetadata>,
}
/// Token usage block attached to (streaming and non-streaming) responses.
#[derive(Deserialize, Debug)]
struct GoogleUsageMetadata {
    /// Tokens consumed by the prompt.
    #[serde(rename = "promptTokenCount")]
    prompt_token_count: Option<u32>,
    /// Tokens produced across candidates.
    #[serde(rename = "candidatesTokenCount")]
    candidates_token_count: Option<u32>,
    /// Total reported by the API; recomputed from the other two when absent.
    #[serde(rename = "totalTokenCount")]
    total_token_count: Option<u32>,
}
/// One SSE event body from `streamGenerateContent`.
#[derive(Deserialize, Debug)]
struct GoogleStreamResponse {
    /// Incremental candidates; may be absent on usage-only events.
    candidates: Option<Vec<GoogleCandidate>>,
    /// Usage info, when present.
    #[serde(rename = "usageMetadata")]
    usage_metadata: Option<GoogleUsageMetadata>,
}
impl std::fmt::Display for GoogleChatResponse {
    /// Renders any tool calls first, then the text content. Produces an
    /// empty string when the response carries neither.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let Some(calls) = self.tool_calls() {
            for call in calls {
                write!(f, "{call}")?;
            }
        }
        if let Some(text) = self.text() {
            write!(f, "{text}")?;
        }
        Ok(())
    }
}
/// A single generated candidate.
#[derive(Deserialize, Debug)]
struct GoogleCandidate {
    /// The candidate's content (parts and possible function calls).
    content: GoogleResponseContent,
}
/// Content of a response candidate.
///
/// This struct only derives `Deserialize`, so the `skip_serializing_if`
/// attributes its optional fields used to carry were inert; they have been
/// removed. Missing optional fields deserialize to `None` without any
/// attribute.
#[derive(Deserialize, Debug)]
struct GoogleResponseContent {
    /// Response parts; defaults to empty when the API omits the array.
    #[serde(default)]
    parts: Vec<GoogleResponsePart>,
    /// Content-level single function call, if reported at this level.
    function_call: Option<GoogleFunctionCall>,
    /// Content-level list of function calls, if reported at this level.
    function_calls: Option<Vec<GoogleFunctionCall>>,
}
impl ChatResponse for GoogleChatResponse {
fn text(&self) -> Option<String> {
self.candidates
.first()
.map(|c| c.content.parts.iter().map(|p| p.text.clone()).collect())
}
fn tool_calls(&self) -> Option<Vec<ToolCall>> {
self.candidates.first().and_then(|c| {
let part_function_calls: Vec<ToolCall> = c
.content
.parts
.iter()
.filter_map(|part| {
part.function_call.as_ref().map(|f| ToolCall {
id: format!("call_{}", f.name),
call_type: "function".to_string(),
function: FunctionCall {
name: f.name.clone(),
arguments: serde_json::to_string(&f.args).unwrap_or_default(),
},
})
})
.collect();
if !part_function_calls.is_empty() {
return Some(part_function_calls);
}
if let Some(fc) = &c.content.function_calls {
Some(
fc.iter()
.map(|f| ToolCall {
id: format!("call_{}", f.name),
call_type: "function".to_string(),
function: FunctionCall {
name: f.name.clone(),
arguments: serde_json::to_string(&f.args).unwrap_or_default(),
},
})
.collect(),
)
} else {
c.content.function_call.as_ref().map(|f| {
vec![ToolCall {
id: format!("call_{}", f.name),
call_type: "function".to_string(),
function: FunctionCall {
name: f.name.clone(),
arguments: serde_json::to_string(&f.args).unwrap_or_default(),
},
}]
})
}
})
}
fn usage(&self) -> Option<Usage> {
self.usage_metadata.as_ref().and_then(|metadata| {
match (metadata.prompt_token_count, metadata.candidates_token_count) {
(Some(prompt_tokens), Some(completion_tokens)) => Some(Usage {
prompt_tokens,
completion_tokens,
total_tokens: metadata
.total_token_count
.unwrap_or(prompt_tokens + completion_tokens),
completion_tokens_details: None,
prompt_tokens_details: None,
}),
_ => None,
}
})
}
}
/// One part of a candidate's content.
#[derive(Deserialize, Debug)]
struct GoogleResponsePart {
    /// Text payload; empty string when the part is not textual.
    #[serde(default)]
    text: String,
    /// Function call requested by the model, if this part carries one.
    #[serde(rename = "functionCall")]
    function_call: Option<GoogleFunctionCall>,
}
/// Allowed `responseMimeType` values for structured output.
#[derive(Deserialize, Debug, Serialize)]
enum GoogleResponseMimeType {
    /// Free-form text.
    #[serde(rename = "text/plain")]
    PlainText,
    /// JSON output, constrained by `responseSchema` when one is set.
    #[serde(rename = "application/json")]
    Json,
    /// Enum-of-strings output.
    #[serde(rename = "text/x.enum")]
    Enum,
}
/// A tool group in the request body; wraps the function declarations.
#[derive(Serialize, Debug)]
struct GoogleTool {
    /// Declarations of callable functions exposed to the model.
    #[serde(rename = "functionDeclarations")]
    function_declarations: Vec<GoogleFunctionDeclaration>,
}
/// Declaration of a single callable function.
#[derive(Serialize, Debug)]
struct GoogleFunctionDeclaration {
    /// Function name the model will refer to.
    name: String,
    /// Human-readable description shown to the model.
    description: String,
    /// JSON-schema-like description of the parameters.
    parameters: GoogleFunctionParameters,
}
impl From<&crate::chat::Tool> for GoogleFunctionDeclaration {
    /// Translates a crate-level tool definition into Gemini's function
    /// declaration shape, extracting `properties` and `required` from the
    /// tool's JSON-schema parameters object.
    fn from(tool: &crate::chat::Tool) -> Self {
        let params = &tool.function.parameters;
        // Missing `properties` becomes an empty object.
        let properties = match params.get("properties") {
            Some(value) => value.clone(),
            None => serde_json::Value::Object(serde_json::Map::new()),
        };
        // Missing or non-array `required` becomes an empty list; non-string
        // entries are skipped.
        let required: Vec<String> = params
            .get("required")
            .and_then(serde_json::Value::as_array)
            .into_iter()
            .flatten()
            .filter_map(|v| v.as_str())
            .map(str::to_string)
            .collect();
        GoogleFunctionDeclaration {
            name: tool.function.name.clone(),
            description: tool.function.description.clone(),
            parameters: GoogleFunctionParameters {
                schema_type: "object".to_string(),
                properties,
                required,
            },
        }
    }
}
/// JSON-schema-like parameter description for a function declaration.
#[derive(Serialize, Debug)]
struct GoogleFunctionParameters {
    /// Schema type; always `"object"` as constructed by the `From` impl.
    #[serde(rename = "type")]
    schema_type: String,
    /// Schema of each parameter, keyed by parameter name.
    properties: Value,
    /// Names of required parameters.
    required: Vec<String>,
}
/// A function call, as both emitted by and echoed back to the API.
#[derive(Deserialize, Debug, Serialize)]
struct GoogleFunctionCall {
    /// Name of the function to invoke.
    name: String,
    /// Arguments as an arbitrary JSON value; `Null` when absent.
    #[serde(default)]
    args: Value,
}
/// A tool-execution result sent back to the model.
#[derive(Deserialize, Debug, Serialize)]
struct GoogleFunctionResponse {
    /// Name of the function that was executed.
    name: String,
    /// Wrapper around the result payload.
    response: GoogleFunctionResponseContent,
}
/// Payload of a function response (name repeated alongside the content).
#[derive(Deserialize, Debug, Serialize)]
struct GoogleFunctionResponseContent {
    /// Function name, duplicated from the parent response.
    name: String,
    /// The function's result as arbitrary JSON.
    content: Value,
}
/// Request body for the `embedContent` endpoint.
#[derive(Serialize)]
struct GoogleEmbeddingRequest<'a> {
    /// Fully-qualified embedding model name, e.g. `models/text-embedding-004`.
    model: &'a str,
    /// The content to embed.
    content: GoogleEmbeddingContent<'a>,
}
/// Content wrapper for an embedding request; only text parts are used.
#[derive(Serialize)]
struct GoogleEmbeddingContent<'a> {
    parts: Vec<GoogleContentPart<'a>>,
}
/// Response body of the `embedContent` endpoint.
#[derive(Deserialize)]
struct GoogleEmbeddingResponse {
    embedding: GoogleEmbedding,
}
/// The embedding vector itself.
#[derive(Deserialize)]
struct GoogleEmbedding {
    values: Vec<f32>,
}
impl Google {
    /// Creates a provider with a freshly-built HTTP client.
    ///
    /// `timeout_seconds`, when set, is applied to the client here and is
    /// re-applied per request when sending. Panics if the reqwest client
    /// cannot be constructed.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        api_key: impl Into<String>,
        model: Option<String>,
        max_tokens: Option<u32>,
        temperature: Option<f32>,
        timeout_seconds: Option<u64>,
        system: Option<String>,
        top_p: Option<f32>,
        top_k: Option<u32>,
        json_schema: Option<StructuredOutputFormat>,
        tools: Option<Vec<Tool>>,
    ) -> Self {
        let client = match timeout_seconds {
            Some(secs) => Client::builder().timeout(std::time::Duration::from_secs(secs)),
            None => Client::builder(),
        }
        .build()
        .expect("Failed to build reqwest Client");
        Self::with_client(
            client,
            api_key,
            model,
            max_tokens,
            temperature,
            timeout_seconds,
            system,
            top_p,
            top_k,
            json_schema,
            tools,
        )
    }

    /// Creates a provider around an existing HTTP client.
    ///
    /// The model defaults to `gemini-1.5-flash` when none is supplied.
    #[allow(clippy::too_many_arguments)]
    pub fn with_client(
        client: Client,
        api_key: impl Into<String>,
        model: Option<String>,
        max_tokens: Option<u32>,
        temperature: Option<f32>,
        timeout_seconds: Option<u64>,
        system: Option<String>,
        top_p: Option<f32>,
        top_k: Option<u32>,
        json_schema: Option<StructuredOutputFormat>,
        tools: Option<Vec<Tool>>,
    ) -> Self {
        let config = GoogleConfig {
            api_key: api_key.into(),
            model: model.unwrap_or_else(|| "gemini-1.5-flash".to_string()),
            max_tokens,
            temperature,
            system,
            timeout_seconds,
            top_p,
            top_k,
            json_schema,
            tools,
        };
        Self {
            config: Arc::new(config),
            client,
        }
    }

    /// API key used for requests.
    pub fn api_key(&self) -> &str {
        self.config.api_key.as_str()
    }

    /// Configured model id.
    pub fn model(&self) -> &str {
        self.config.model.as_str()
    }

    /// Configured maximum output tokens, if any.
    pub fn max_tokens(&self) -> Option<u32> {
        self.config.max_tokens
    }

    /// Configured sampling temperature, if any.
    pub fn temperature(&self) -> Option<f32> {
        self.config.temperature
    }

    /// Configured request timeout in seconds, if any.
    pub fn timeout_seconds(&self) -> Option<u64> {
        self.config.timeout_seconds
    }

    /// Configured system prompt, if any.
    pub fn system(&self) -> Option<&str> {
        self.config.system.as_deref()
    }

    /// Configured nucleus sampling parameter, if any.
    pub fn top_p(&self) -> Option<f32> {
        self.config.top_p
    }

    /// Configured top-k sampling parameter, if any.
    pub fn top_k(&self) -> Option<u32> {
        self.config.top_k
    }

    /// Configured structured-output schema, if any.
    pub fn json_schema(&self) -> Option<&StructuredOutputFormat> {
        self.config.json_schema.as_ref()
    }

    /// Configured tools, if any.
    pub fn tools(&self) -> Option<&[Tool]> {
        self.config.tools.as_deref()
    }

    /// The underlying HTTP client.
    pub fn client(&self) -> &Client {
        &self.client
    }
}
/// Error message used when audio chat messages reach this provider, which
/// rejects them via `ensure_no_audio`.
const AUDIO_UNSUPPORTED: &str = "Audio messages are not supported by Google chat";
#[async_trait]
impl ChatProvider for Google {
async fn chat(&self, messages: &[ChatMessage]) -> Result<Box<dyn ChatResponse>, LLMError> {
crate::chat::ensure_no_audio(messages, AUDIO_UNSUPPORTED)?;
if self.config.api_key.is_empty() {
return Err(LLMError::AuthError("Missing Google API key".to_string()));
}
let mut chat_contents = Vec::with_capacity(messages.len());
if let Some(system) = &self.config.system {
chat_contents.push(GoogleChatContent {
role: "user",
parts: vec![GoogleContentPart::Text(system)],
});
}
for msg in messages {
let role = match &msg.message_type {
MessageType::ToolResult(_) => "function",
_ => match msg.role {
ChatRole::User => "user",
ChatRole::Assistant => "model",
},
};
chat_contents.push(GoogleChatContent {
role,
parts: match &msg.message_type {
MessageType::Text => vec![GoogleContentPart::Text(&msg.content)],
MessageType::Image((image_mime, raw_bytes)) => {
vec![GoogleContentPart::InlineData(GoogleInlineData {
mime_type: image_mime.mime_type().to_string(),
data: BASE64.encode(raw_bytes),
})]
}
MessageType::ImageURL(_) => unimplemented!(),
MessageType::Pdf(raw_bytes) => {
vec![GoogleContentPart::InlineData(GoogleInlineData {
mime_type: "application/pdf".to_string(),
data: BASE64.encode(raw_bytes),
})]
}
MessageType::Audio(_) => vec![],
MessageType::ToolUse(calls) => calls
.iter()
.map(|call| {
GoogleContentPart::FunctionCall(GoogleFunctionCall {
name: call.function.name.clone(),
args: serde_json::from_str(&call.function.arguments)
.unwrap_or(serde_json::Value::Null),
})
})
.collect(),
MessageType::ToolResult(result) => result
.iter()
.map(|result| {
let parsed_args =
serde_json::from_str::<Value>(&result.function.arguments)
.unwrap_or(serde_json::Value::Null);
GoogleContentPart::FunctionResponse(GoogleFunctionResponse {
name: result.function.name.clone(),
response: GoogleFunctionResponseContent {
name: result.function.name.clone(),
content: parsed_args,
},
})
})
.collect(),
},
});
}
let generation_config = if self.config.max_tokens.is_none()
&& self.config.temperature.is_none()
&& self.config.top_p.is_none()
&& self.config.top_k.is_none()
&& self.config.json_schema.is_none()
{
None
} else {
let (response_mime_type, response_schema) =
if let Some(json_schema) = &self.config.json_schema {
if let Some(schema) = &json_schema.schema {
let mut schema = schema.clone();
if let Some(obj) = schema.as_object_mut() {
obj.remove("additionalProperties");
}
(Some(GoogleResponseMimeType::Json), Some(schema))
} else {
(None, None)
}
} else {
(None, None)
};
Some(GoogleGenerationConfig {
max_output_tokens: self.config.max_tokens,
temperature: self.config.temperature,
top_p: self.config.top_p,
top_k: self.config.top_k,
response_mime_type,
response_schema,
})
};
let req_body = GoogleChatRequest {
contents: chat_contents,
generation_config,
tools: None,
};
if log::log_enabled!(log::Level::Trace) {
if let Ok(json) = serde_json::to_string(&req_body) {
log::trace!("Google Gemini request payload: {}", json);
}
}
let url = format!(
"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={key}",
model = self.config.model,
key = self.config.api_key
);
let mut request = self.client.post(&url).json(&req_body);
if let Some(timeout) = self.config.timeout_seconds {
request = request.timeout(std::time::Duration::from_secs(timeout));
}
let resp = request.send().await?;
log::debug!("Google Gemini HTTP status: {}", resp.status());
let resp = resp.error_for_status()?;
let resp_text = resp.text().await?;
let json_resp: Result<GoogleChatResponse, serde_json::Error> =
serde_json::from_str(&resp_text);
match json_resp {
Ok(response) => Ok(Box::new(response)),
Err(e) => {
Err(LLMError::ResponseFormatError {
message: format!("Failed to decode Google API response: {e}"),
raw_response: resp_text,
})
}
}
}
async fn chat_with_tools(
&self,
messages: &[ChatMessage],
tools: Option<&[Tool]>,
) -> Result<Box<dyn ChatResponse>, LLMError> {
crate::chat::ensure_no_audio(messages, AUDIO_UNSUPPORTED)?;
if self.config.api_key.is_empty() {
return Err(LLMError::AuthError("Missing Google API key".to_string()));
}
let mut chat_contents = Vec::with_capacity(messages.len());
if let Some(system) = &self.config.system {
chat_contents.push(GoogleChatContent {
role: "user",
parts: vec![GoogleContentPart::Text(system)],
});
}
for msg in messages {
let role = match &msg.message_type {
MessageType::ToolResult(_) => "function",
_ => match msg.role {
ChatRole::User => "user",
ChatRole::Assistant => "model",
},
};
chat_contents.push(GoogleChatContent {
role,
parts: match &msg.message_type {
MessageType::Text => vec![GoogleContentPart::Text(&msg.content)],
MessageType::Image((image_mime, raw_bytes)) => {
vec![GoogleContentPart::InlineData(GoogleInlineData {
mime_type: image_mime.mime_type().to_string(),
data: BASE64.encode(raw_bytes),
})]
}
MessageType::ImageURL(_) => unimplemented!(),
MessageType::Pdf(raw_bytes) => {
vec![GoogleContentPart::InlineData(GoogleInlineData {
mime_type: "application/pdf".to_string(),
data: BASE64.encode(raw_bytes),
})]
}
MessageType::Audio(_) => vec![],
MessageType::ToolUse(calls) => calls
.iter()
.map(|call| {
GoogleContentPart::FunctionCall(GoogleFunctionCall {
name: call.function.name.clone(),
args: serde_json::from_str(&call.function.arguments)
.unwrap_or(serde_json::Value::Null),
})
})
.collect(),
MessageType::ToolResult(result) => result
.iter()
.map(|result| {
let parsed_args =
serde_json::from_str::<Value>(&result.function.arguments)
.unwrap_or(serde_json::Value::Null);
GoogleContentPart::FunctionResponse(GoogleFunctionResponse {
name: result.function.name.clone(),
response: GoogleFunctionResponseContent {
name: result.function.name.clone(),
content: parsed_args,
},
})
})
.collect(),
},
});
}
let google_tools = tools.map(|t| {
vec![GoogleTool {
function_declarations: t.iter().map(GoogleFunctionDeclaration::from).collect(),
}]
});
let generation_config = {
let (response_mime_type, response_schema) =
if let Some(json_schema) = &self.config.json_schema {
if let Some(schema) = &json_schema.schema {
let mut schema = schema.clone();
if let Some(obj) = schema.as_object_mut() {
obj.remove("additionalProperties");
}
(Some(GoogleResponseMimeType::Json), Some(schema))
} else {
(None, None)
}
} else {
(None, None)
};
Some(GoogleGenerationConfig {
max_output_tokens: self.config.max_tokens,
temperature: self.config.temperature,
top_p: self.config.top_p,
top_k: self.config.top_k,
response_mime_type,
response_schema,
})
};
let req_body = GoogleChatRequest {
contents: chat_contents,
generation_config,
tools: google_tools,
};
if log::log_enabled!(log::Level::Trace) {
if let Ok(json) = serde_json::to_string(&req_body) {
log::trace!("Google Gemini request payload (tool): {}", json);
}
}
let url = format!(
"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={key}",
model = self.config.model,
key = self.config.api_key
);
let mut request = self.client.post(&url).json(&req_body);
if let Some(timeout) = self.config.timeout_seconds {
request = request.timeout(std::time::Duration::from_secs(timeout));
}
let resp = request.send().await?;
log::debug!("Google Gemini HTTP status (tool): {}", resp.status());
let resp = resp.error_for_status()?;
let resp_text = resp.text().await?;
let json_resp: Result<GoogleChatResponse, serde_json::Error> =
serde_json::from_str(&resp_text);
match json_resp {
Ok(response) => Ok(Box::new(response)),
Err(e) => {
Err(LLMError::ResponseFormatError {
message: format!("Failed to decode Google API response: {e}"),
raw_response: resp_text,
})
}
}
}
async fn chat_stream(
&self,
messages: &[ChatMessage],
) -> Result<std::pin::Pin<Box<dyn Stream<Item = Result<String, LLMError>> + Send>>, LLMError>
{
let struct_stream = self.chat_stream_struct(messages).await?;
let content_stream = struct_stream.filter_map(|result| async move {
match result {
Ok(stream_response) => {
if let Some(choice) = stream_response.choices.first() {
if let Some(content) = &choice.delta.content {
if !content.is_empty() {
return Some(Ok(content.clone()));
}
}
}
None
}
Err(e) => Some(Err(e)),
}
});
Ok(Box::pin(content_stream))
}
async fn chat_stream_struct(
&self,
messages: &[ChatMessage],
) -> Result<
std::pin::Pin<Box<dyn Stream<Item = Result<crate::chat::StreamResponse, LLMError>> + Send>>,
LLMError,
> {
crate::chat::ensure_no_audio(messages, AUDIO_UNSUPPORTED)?;
if self.config.api_key.is_empty() {
return Err(LLMError::AuthError("Missing Google API key".to_string()));
}
let mut chat_contents = Vec::with_capacity(messages.len());
if let Some(system) = &self.config.system {
chat_contents.push(GoogleChatContent {
role: "user",
parts: vec![GoogleContentPart::Text(system)],
});
}
for msg in messages {
let role = match msg.role {
ChatRole::User => "user",
ChatRole::Assistant => "model",
};
chat_contents.push(GoogleChatContent {
role,
parts: match &msg.message_type {
MessageType::Text => vec![GoogleContentPart::Text(&msg.content)],
MessageType::Image((image_mime, raw_bytes)) => {
vec![GoogleContentPart::InlineData(GoogleInlineData {
mime_type: image_mime.mime_type().to_string(),
data: BASE64.encode(raw_bytes),
})]
}
MessageType::Pdf(raw_bytes) => {
vec![GoogleContentPart::InlineData(GoogleInlineData {
mime_type: "application/pdf".to_string(),
data: BASE64.encode(raw_bytes),
})]
}
_ => vec![GoogleContentPart::Text(&msg.content)],
},
});
}
let generation_config = if self.config.max_tokens.is_none()
&& self.config.temperature.is_none()
&& self.config.top_p.is_none()
&& self.config.top_k.is_none()
{
None
} else {
Some(GoogleGenerationConfig {
max_output_tokens: self.config.max_tokens,
temperature: self.config.temperature,
top_p: self.config.top_p,
top_k: self.config.top_k,
response_mime_type: None,
response_schema: None,
})
};
let req_body = GoogleChatRequest {
contents: chat_contents,
generation_config,
tools: None,
};
let url = format!(
"https://generativelanguage.googleapis.com/v1beta/models/{model}:streamGenerateContent?alt=sse&key={key}",
model = self.config.model,
key = self.config.api_key
);
let mut request = self.client.post(&url).json(&req_body);
if let Some(timeout) = self.config.timeout_seconds {
request = request.timeout(std::time::Duration::from_secs(timeout));
}
let response = request.send().await?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await?;
return Err(LLMError::ResponseFormatError {
message: format!("Google API returned error status: {status}"),
raw_response: error_text,
});
}
Ok(create_google_sse_stream(response))
}
}
#[async_trait]
impl CompletionProvider for Google {
    /// Runs a one-shot completion by wrapping the prompt in a single user
    /// chat message and returning the chat response's text.
    async fn complete(&self, req: &CompletionRequest) -> Result<CompletionResponse, LLMError> {
        let message = ChatMessage::user().content(req.prompt.clone()).build();
        let answer = self.chat(std::slice::from_ref(&message)).await?;
        match answer.text() {
            Some(text) => Ok(CompletionResponse { text }),
            None => Err(LLMError::ProviderError(
                "No answer returned by Google".to_string(),
            )),
        }
    }
}
#[async_trait]
impl EmbeddingProvider for Google {
    /// Embeds each input string with `text-embedding-004`, issuing one
    /// sequential request per text.
    async fn embed(&self, texts: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
        if self.config.api_key.is_empty() {
            return Err(LLMError::AuthError("Missing Google API key".to_string()));
        }
        // The endpoint is identical for every text, so build it once.
        let endpoint = format!(
            "https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent?key={}",
            self.config.api_key
        );
        let mut all_values = Vec::with_capacity(texts.len());
        for text in &texts {
            let body = GoogleEmbeddingRequest {
                model: "models/text-embedding-004",
                content: GoogleEmbeddingContent {
                    parts: vec![GoogleContentPart::Text(text)],
                },
            };
            let response = self
                .client
                .post(&endpoint)
                .json(&body)
                .send()
                .await?
                .error_for_status()?;
            let parsed: GoogleEmbeddingResponse = response.json().await?;
            all_values.push(parsed.embedding.values);
        }
        Ok(all_values)
    }
}
#[async_trait]
impl SpeechToTextProvider for Google {
    /// Speech-to-text is not offered through this provider; always errors.
    async fn transcribe(&self, _audio: Vec<u8>) -> Result<String, LLMError> {
        let msg = "Google does not implement speech to text endpoint yet.";
        Err(LLMError::ProviderError(msg.into()))
    }
}
impl LLMProvider for Google {
    /// Exposes the tools configured at construction time.
    fn tools(&self) -> Option<&[Tool]> {
        self.config.tools.as_deref()
    }
}
/// Adapts a streaming HTTP response into a stream of parsed
/// `StreamResponse` items.
///
/// Each network chunk is parsed independently by `parse_google_sse_chunk`;
/// chunks that yield no event are filtered out of the stream.
///
/// NOTE(review): an SSE event split across two network chunks will fail to
/// parse and be silently dropped — this assumes the server delivers whole
/// events per chunk. TODO: confirm, or buffer partial lines across chunks.
fn create_google_sse_stream(
    response: reqwest::Response,
) -> std::pin::Pin<Box<dyn Stream<Item = Result<crate::chat::StreamResponse, LLMError>> + Send>> {
    let stream = response
        .bytes_stream()
        .map(move |chunk| match chunk {
            Ok(bytes) => {
                // Lossy conversion: invalid UTF-8 bytes become U+FFFD.
                let text = String::from_utf8_lossy(&bytes);
                parse_google_sse_chunk(&text)
            }
            Err(e) => Err(LLMError::HttpError(e.to_string())),
        })
        .filter_map(|result| async move {
            match result {
                Ok(Some(response)) => Some(Ok(response)),
                // Chunks without content or usage are dropped silently.
                Ok(None) => None,
                Err(e) => Some(Err(e)),
            }
        });
    Box::pin(stream)
}
/// Parses one HTTP chunk of the Gemini SSE stream into at most one
/// `StreamResponse`.
///
/// A chunk may carry several `data: ` lines. The previous implementation
/// returned on the first decodable line, silently dropping any further
/// events in the same chunk; now every line is decoded and the results are
/// merged — text deltas are concatenated in order and the last usage report
/// wins. Undecodable lines are skipped, and `Ok(None)` is returned when the
/// chunk yields neither content nor usage.
fn parse_google_sse_chunk(chunk: &str) -> Result<Option<crate::chat::StreamResponse>, LLMError> {
    let mut merged_content = String::new();
    let mut usage = None;
    for line in chunk.lines() {
        let line = line.trim();
        if let Some(data) = line.strip_prefix("data: ") {
            if let Ok(response) = serde_json::from_str::<GoogleStreamResponse>(data) {
                // Collect the text of the first part of the first candidate,
                // as before, but from every event in the chunk.
                if let Some(candidate) =
                    response.candidates.as_ref().and_then(|c| c.first())
                {
                    if let Some(part) = candidate.content.parts.first() {
                        if !part.text.is_empty() {
                            merged_content.push_str(&part.text);
                        }
                    }
                }
                if let Some(meta) = &response.usage_metadata {
                    if let (Some(prompt_tokens), Some(completion_tokens)) =
                        (meta.prompt_token_count, meta.candidates_token_count)
                    {
                        usage = Some(Usage {
                            prompt_tokens,
                            completion_tokens,
                            total_tokens: meta
                                .total_token_count
                                .unwrap_or(prompt_tokens + completion_tokens),
                            completion_tokens_details: None,
                            prompt_tokens_details: None,
                        });
                    }
                }
            }
        }
    }
    let content = if merged_content.is_empty() {
        None
    } else {
        Some(merged_content)
    };
    if content.is_some() || usage.is_some() {
        return Ok(Some(crate::chat::StreamResponse {
            choices: vec![crate::chat::StreamChoice {
                delta: crate::chat::StreamDelta {
                    content,
                    tool_calls: None,
                },
            }],
            usage,
        }));
    }
    Ok(None)
}
// Marker impl: relies entirely on the trait's default method bodies.
#[async_trait]
impl TextToSpeechProvider for Google {}
/// One model entry from the Gemini `ListModels` endpoint.
///
/// The API returns camelCase JSON keys (`displayName`, `inputTokenLimit`,
/// `supportedGenerationMethods`, ...). Without `rename_all = "camelCase"`
/// the required snake_case `String` fields could never match, so the whole
/// model-list response failed to deserialize. `#[serde(default)]` is added
/// to the string/list fields the API may omit, so a sparse entry does not
/// abort the entire list.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GoogleModelEntry {
    /// Fully-qualified name, e.g. `models/gemini-1.5-flash`.
    pub name: String,
    #[serde(default)]
    pub version: String,
    #[serde(default)]
    pub display_name: String,
    #[serde(default)]
    pub description: String,
    pub input_token_limit: Option<u32>,
    pub output_token_limit: Option<u32>,
    #[serde(default)]
    pub supported_generation_methods: Vec<String>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub top_k: Option<u32>,
    /// Any additional response fields, kept verbatim.
    #[serde(flatten)]
    pub extra: Value,
}
impl ModelListRawEntry for GoogleModelEntry {
    /// Uses the fully-qualified model name as the identifier.
    fn get_id(&self) -> String {
        self.name.clone()
    }
    /// The Google models API does not report a creation time, so the Unix
    /// epoch is returned as a placeholder.
    fn get_created_at(&self) -> DateTime<Utc> {
        DateTime::<Utc>::UNIX_EPOCH
    }
    /// Extra fields not captured by the typed struct (collected via
    /// `#[serde(flatten)]`).
    fn get_raw(&self) -> Value {
        self.extra.clone()
    }
}
/// Response body of the Gemini `ListModels` endpoint.
#[derive(Clone, Debug, Deserialize)]
pub struct GoogleModelListResponse {
    pub models: Vec<GoogleModelEntry>,
}
impl ModelListResponse for GoogleModelListResponse {
    /// Model identifiers in the order returned by the API.
    fn get_models(&self) -> Vec<String> {
        let mut ids = Vec::with_capacity(self.models.len());
        for entry in &self.models {
            ids.push(entry.name.clone());
        }
        ids
    }
    /// Type-erased entries carrying each model's raw metadata.
    fn get_models_raw(&self) -> Vec<Box<dyn ModelListRawEntry>> {
        self.models
            .iter()
            .cloned()
            .map(|entry| Box::new(entry) as Box<dyn ModelListRawEntry>)
            .collect()
    }
    /// This response always belongs to the Google backend.
    fn get_backend(&self) -> LLMBackend {
        LLMBackend::Google
    }
}
#[async_trait]
impl ModelsProvider for Google {
    /// Fetches the models available to this API key from the Gemini
    /// `models` endpoint. The optional request filter is ignored.
    async fn list_models(
        &self,
        _request: Option<&ModelListRequest>,
    ) -> Result<Box<dyn ModelListResponse>, LLMError> {
        if self.config.api_key.is_empty() {
            return Err(LLMError::AuthError("Missing Google API key".to_string()));
        }
        let endpoint = format!(
            "https://generativelanguage.googleapis.com/v1beta/models?key={}",
            self.config.api_key
        );
        let response = self.client.get(&endpoint).send().await?.error_for_status()?;
        let parsed: GoogleModelListResponse = response.json().await?;
        Ok(Box::new(parsed))
    }
}