agent_io/llm/google/
mod.rs1mod request;
4mod response;
5mod types;
6
7use async_trait::async_trait;
8use derive_builder::Builder;
9use futures::StreamExt;
10use reqwest::Client;
11use std::time::Duration;
12
13use crate::llm::{
14 BaseChatModel, ChatCompletion, ChatStream, LlmError, Message, ToolChoice, ToolDefinition,
15};
16
17use types::*;
18
/// Default base endpoint for the Generative Language API (v1beta) model routes.
///
/// Request URLs are formed as `{base}/{model}:{method}`; a caller-supplied `base_url`
/// on [`ChatGoogle`] takes precedence over this value.
const GOOGLE_API_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";
20
21#[derive(Builder, Clone)]
23#[builder(pattern = "owned", build_fn(skip))]
24pub struct ChatGoogle {
25 #[builder(setter(into))]
27 pub(super) model: String,
28 pub(super) api_key: String,
30 #[builder(setter(into, strip_option), default = "None")]
32 pub(super) base_url: Option<String>,
33 #[builder(default = "8192")]
35 pub(super) max_tokens: u64,
36 #[builder(default = "0.2")]
38 pub(super) temperature: f32,
39 #[builder(default = "None")]
41 pub(super) thinking_budget: Option<u64>,
42 #[builder(setter(skip))]
44 pub(super) client: Client,
45 #[builder(setter(skip))]
47 pub(super) context_window: u64,
48}
49
50impl ChatGoogle {
51 pub fn new(model: impl Into<String>) -> Result<Self, LlmError> {
53 let api_key = std::env::var("GOOGLE_API_KEY")
54 .or_else(|_| std::env::var("GEMINI_API_KEY"))
55 .map_err(|_| LlmError::Config("GOOGLE_API_KEY or GEMINI_API_KEY not set".into()))?;
56
57 Self::builder().model(model).api_key(api_key).build()
58 }
59
60 pub fn builder() -> ChatGoogleBuilder {
62 ChatGoogleBuilder::default()
63 }
64
65 fn api_url(&self, stream: bool) -> String {
67 let base = self.base_url.as_deref().unwrap_or(GOOGLE_API_URL);
68 let method = if stream {
69 "streamGenerateContent"
70 } else {
71 "generateContent"
72 };
73 format!("{}/{}:{}?key={}", base, self.model, method, self.api_key)
74 }
75
76 fn build_client() -> Client {
78 Client::builder()
79 .timeout(Duration::from_secs(120))
80 .build()
81 .expect("Failed to create HTTP client")
82 }
83
84 fn get_context_window(model: &str) -> u64 {
86 let model_lower = model.to_lowercase();
87
88 if model_lower.contains("gemini-1.5-pro") {
89 2_097_152 } else {
91 1_048_576 }
93 }
94
95 fn is_thinking_model(&self) -> bool {
97 let model_lower = self.model.to_lowercase();
98 model_lower.contains("gemini-2.5")
99 || model_lower.contains("thinking")
100 || model_lower.contains("gemini-exp")
101 }
102}
103
104impl ChatGoogleBuilder {
105 pub fn build(&self) -> Result<ChatGoogle, LlmError> {
106 let model = self
107 .model
108 .clone()
109 .ok_or_else(|| LlmError::Config("model is required".into()))?;
110 let api_key = self
111 .api_key
112 .clone()
113 .ok_or_else(|| LlmError::Config("api_key is required".into()))?;
114
115 Ok(ChatGoogle {
116 context_window: ChatGoogle::get_context_window(&model),
117 client: ChatGoogle::build_client(),
118 model,
119 api_key,
120 base_url: self.base_url.clone().flatten(),
121 max_tokens: self.max_tokens.unwrap_or(8192),
122 temperature: self.temperature.unwrap_or(0.2),
123 thinking_budget: self.thinking_budget.flatten(),
124 })
125 }
126}
127
128#[async_trait]
129impl BaseChatModel for ChatGoogle {
130 fn model(&self) -> &str {
131 &self.model
132 }
133
134 fn provider(&self) -> &str {
135 "google"
136 }
137
138 fn context_window(&self) -> Option<u64> {
139 Some(self.context_window)
140 }
141
142 async fn invoke(
143 &self,
144 messages: Vec<Message>,
145 tools: Option<Vec<ToolDefinition>>,
146 tool_choice: Option<ToolChoice>,
147 ) -> Result<ChatCompletion, LlmError> {
148 let request = self.build_request(messages, tools, tool_choice)?;
149
150 let response = self
151 .client
152 .post(self.api_url(false))
153 .header("Content-Type", "application/json")
154 .json(&request)
155 .send()
156 .await?;
157
158 if !response.status().is_success() {
159 let status = response.status();
160 let body = response.text().await.unwrap_or_default();
161 return Err(LlmError::Api(format!(
162 "Google API error ({}): {}",
163 status, body
164 )));
165 }
166
167 let completion: GeminiResponse = response.json().await?;
168 Ok(self.parse_response(completion))
169 }
170
171 async fn invoke_stream(
172 &self,
173 messages: Vec<Message>,
174 tools: Option<Vec<ToolDefinition>>,
175 tool_choice: Option<ToolChoice>,
176 ) -> Result<ChatStream, LlmError> {
177 let request = self.build_request(messages, tools, tool_choice)?;
178
179 let response = self
180 .client
181 .post(self.api_url(true))
182 .header("Content-Type", "application/json")
183 .json(&request)
184 .send()
185 .await?;
186
187 if !response.status().is_success() {
188 let status = response.status();
189 let body = response.text().await.unwrap_or_default();
190 return Err(LlmError::Api(format!(
191 "Google API error ({}): {}",
192 status, body
193 )));
194 }
195
196 let stream = response.bytes_stream().filter_map(|result| async move {
198 match result {
199 Ok(bytes) => {
200 let text = String::from_utf8_lossy(&bytes);
201 Self::parse_stream_chunk(&text)
202 }
203 Err(e) => Some(Err(LlmError::Stream(e.to_string()))),
204 }
205 });
206
207 Ok(Box::pin(stream))
208 }
209
210 fn supports_vision(&self) -> bool {
211 true
213 }
214}