1use std::collections::BTreeMap;
12
13use async_stream::stream;
14use async_trait::async_trait;
15use eventsource_stream::Eventsource;
16use futures::StreamExt;
17use serde::Deserialize;
18use serde_json::{json, Value};
19
20use crate::error::{Error, Result};
21use crate::providers::Provider;
22use crate::retry::{classify_status, parse_retry_after, with_retry, Attempt, RetryConfig};
23use crate::stream::AssistantMessageEventStream;
24use crate::types::{
25 now_ms, AssistantMessage, AssistantMessageEvent, Content, Context, Message, Model, StopReason,
26 StreamOptions, Usage,
27};
28
29#[derive(Deserialize, Debug)]
30struct Chunk {
31 #[serde(default)]
32 candidates: Vec<Candidate>,
33 #[serde(default)]
34 usage_metadata: Option<UsageMetadata>,
35 #[serde(default)]
36 model_version: Option<String>,
37}
38
39#[derive(Deserialize, Debug)]
40struct Candidate {
41 #[serde(default)]
42 content: Option<CandidateContent>,
43 #[serde(default)]
44 finish_reason: Option<String>,
45}
46
47#[derive(Deserialize, Debug)]
48struct CandidateContent {
49 #[serde(default)]
50 parts: Vec<Part>,
51}
52
53#[derive(Deserialize, Debug)]
54struct Part {
55 #[serde(default)]
56 text: Option<String>,
57 #[serde(default)]
58 function_call: Option<FunctionCall>,
59}
60
61#[derive(Deserialize, Debug)]
62struct FunctionCall {
63 #[serde(default)]
64 name: String,
65 #[serde(default)]
66 args: Value,
67}
68
69#[derive(Deserialize, Debug, Default)]
70struct UsageMetadata {
71 #[serde(default)]
72 prompt_token_count: u64,
73 #[serde(default)]
74 candidates_token_count: u64,
75 #[serde(default)]
76 total_token_count: u64,
77}
78
79fn convert_messages(messages: &[Message]) -> Vec<Value> {
80 let mut out: Vec<Value> = Vec::new();
81 for m in messages {
82 match m {
83 Message::User { content, .. } => {
84 let parts: Vec<Value> = content
85 .iter()
86 .filter_map(|c| c.as_text().map(|t| json!({"text": t})))
87 .collect();
88 out.push(json!({"role": "user", "parts": parts}));
89 }
90 Message::Assistant(a) => {
91 let mut parts: Vec<Value> = Vec::new();
92 for c in &a.content {
93 match c {
94 Content::Text { text } => parts.push(json!({"text": text})),
95 Content::ToolCall {
96 name, arguments, ..
97 } => {
98 parts.push(json!({
99 "functionCall": {"name": name, "args": arguments}
100 }));
101 }
102 _ => {}
103 }
104 }
105 out.push(json!({"role": "model", "parts": parts}));
106 }
107 Message::ToolResult(tr) => {
108 let text = tr
109 .content
110 .iter()
111 .filter_map(|c| c.as_text().map(|s| s.to_string()))
112 .collect::<Vec<_>>()
113 .join("");
114 out.push(json!({
115 "role": "user",
116 "parts": [{
117 "functionResponse": {
118 "name": tr.tool_name,
119 "response": {"output": text, "is_error": tr.is_error}
120 }
121 }]
122 }));
123 }
124 }
125 }
126 out
127}
128
129fn build_body(context: &Context, options: &StreamOptions) -> Value {
130 let mut body = json!({
131 "contents": convert_messages(&context.messages),
132 });
133 if let Some(sp) = &context.system_prompt {
134 body["systemInstruction"] = json!({"role": "system", "parts": [{"text": sp}]});
135 }
136 if let Some(t) = options.temperature {
137 body["generationConfig"] = json!({"temperature": t});
138 }
139 if !context.tools.is_empty() {
140 let decls: Vec<Value> = context
141 .tools
142 .iter()
143 .map(|t| {
144 json!({
145 "name": t.name,
146 "description": t.description,
147 "parameters": t.parameters,
148 })
149 })
150 .collect();
151 body["tools"] = json!([{"functionDeclarations": decls}]);
152 }
153 body
154}
155
156pub struct GoogleProvider {
157 client: reqwest::Client,
158}
159
160impl GoogleProvider {
161 pub fn new() -> Self {
162 Self {
163 client: reqwest::Client::new(),
164 }
165 }
166}
167
168impl Default for GoogleProvider {
169 fn default() -> Self {
170 Self::new()
171 }
172}
173
174#[async_trait]
175impl Provider for GoogleProvider {
176 async fn stream(
177 &self,
178 model: &Model,
179 context: &Context,
180 options: &StreamOptions,
181 ) -> Result<AssistantMessageEventStream> {
182 let api_key = options
183 .api_key
184 .clone()
185 .or_else(|| std::env::var("GOOGLE_API_KEY").ok())
186 .or_else(|| std::env::var("GEMINI_API_KEY").ok())
187 .ok_or_else(|| Error::MissingApiKey("google".into()))?;
188 let base_url = options
189 .base_url
190 .clone()
191 .unwrap_or_else(|| model.base_url.clone());
192 let url = format!(
193 "{}/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
194 base_url.trim_end_matches('/'),
195 model.id,
196 api_key,
197 );
198 let body = build_body(context, options);
199 let cancel = options.cancel.clone();
200 let extra_headers: BTreeMap<String, String> = options.headers.clone();
201
202 let resp = with_retry(&RetryConfig::default(), cancel.as_ref(), |_| {
203 let client = self.client.clone();
204 let url = url.clone();
205 let body = body.clone();
206 let extra_headers = extra_headers.clone();
207 async move {
208 let mut req = client
209 .post(&url)
210 .header("accept", "text/event-stream")
211 .header("content-type", "application/json");
212 for (k, v) in extra_headers {
213 req = req.header(k, v);
214 }
215 let r = match req.json(&body).send().await {
216 Ok(r) => r,
217 Err(e) => {
218 return if e.is_timeout() || e.is_connect() {
219 Attempt::Retry {
220 error: Error::Http(e),
221 retry_after: None,
222 }
223 } else {
224 Attempt::Fatal(Error::Http(e))
225 };
226 }
227 };
228 let status = r.status();
229 if status.is_success() {
230 return Attempt::Ok(r);
231 }
232 let retry_after = r
233 .headers()
234 .get("retry-after")
235 .and_then(|v| v.to_str().ok())
236 .and_then(parse_retry_after);
237 let body_text = r.text().await.unwrap_or_default();
238 let err = Error::ProviderError {
239 status: status.as_u16(),
240 body: body_text,
241 };
242 match classify_status(status.as_u16()) {
243 Some(_) => Attempt::Retry {
244 error: err,
245 retry_after,
246 },
247 None => Attempt::Fatal(err),
248 }
249 }
250 })
251 .await?;
252
253 let api = model.api.clone();
254 let provider = model.provider.clone();
255 let model_id = model.id.clone();
256 let cancel_for_stream = cancel.clone();
257
258 let s = stream! {
259 yield Ok(AssistantMessageEvent::Start);
260 let mut sse = resp.bytes_stream().eventsource();
261
262 let mut text_buf = String::new();
263 let mut text_started = false;
264 let mut text_index: usize = 0;
265 let mut tool_blocks: Vec<(String, String, Value)> = Vec::new();
266 let mut stop = StopReason::Stop;
267 let mut usage = Usage::default();
268 let mut response_model: Option<String> = None;
269
270 while let Some(ev) = sse.next().await {
271 if let Some(c) = &cancel_for_stream {
272 if c.is_cancelled() { yield Err(Error::Cancelled); return; }
273 }
274 let ev = match ev {
275 Ok(e) => e,
276 Err(e) => { yield Err(Error::InvalidResponse(format!("sse: {e}"))); return; }
277 };
278 if ev.data.is_empty() { continue; }
279 let chunk: Chunk = match serde_json::from_str(&ev.data) {
280 Ok(c) => c,
281 Err(_) => continue,
282 };
283 if let Some(m) = chunk.model_version { response_model = Some(m); }
284 if let Some(u) = chunk.usage_metadata {
285 usage.input = u.prompt_token_count;
286 usage.output = u.candidates_token_count;
287 usage.total_tokens = u.total_token_count;
288 }
289 for cand in chunk.candidates {
290 if let Some(reason) = cand.finish_reason {
291 stop = match reason.as_str() {
292 "STOP" => StopReason::Stop,
293 "MAX_TOKENS" => StopReason::Length,
294 _ => StopReason::Stop,
295 };
296 }
297 if let Some(content) = cand.content {
298 for part in content.parts {
299 if let Some(t) = part.text {
300 if !t.is_empty() {
301 if !text_started {
302 text_started = true;
303 yield Ok(AssistantMessageEvent::TextStart { content_index: text_index });
304 }
305 text_buf.push_str(&t);
306 yield Ok(AssistantMessageEvent::TextDelta { content_index: text_index, delta: t });
307 }
308 }
309 if let Some(fc) = part.function_call {
310 let id = format!("call_{}", tool_blocks.len() + 1);
311 let block_index = text_index + if text_started { 1 } else { 0 } + tool_blocks.len();
312 yield Ok(AssistantMessageEvent::ToolCallStart {
313 content_index: block_index,
314 id: id.clone(),
315 name: fc.name.clone(),
316 });
317 yield Ok(AssistantMessageEvent::ToolCallEnd {
318 content_index: block_index,
319 id: id.clone(),
320 name: fc.name.clone(),
321 arguments: fc.args.clone(),
322 });
323 if fc.finish_reason_set_to_tool_use() { stop = StopReason::ToolUse; }
324 tool_blocks.push((id, fc.name, fc.args));
325 }
326 }
327 }
328 }
329 }
330
331 if text_started {
332 yield Ok(AssistantMessageEvent::TextEnd { content_index: text_index, content: text_buf.clone() });
333 text_index += 1;
334 }
335 if !tool_blocks.is_empty() && stop == StopReason::Stop {
336 stop = StopReason::ToolUse;
337 }
338 let mut out_content: Vec<Content> = Vec::new();
339 if text_started {
340 out_content.push(Content::Text { text: text_buf });
341 }
342 for (id, name, args) in tool_blocks {
343 out_content.push(Content::ToolCall { id, name, arguments: args });
344 }
345 let _ = text_index;
346 let message = AssistantMessage {
347 content: out_content,
348 api,
349 provider,
350 model: response_model.unwrap_or(model_id),
351 usage,
352 stop_reason: stop,
353 error_message: None,
354 timestamp: now_ms(),
355 };
356 yield Ok(AssistantMessageEvent::Done { reason: stop, message });
357 };
358
359 Ok(s.boxed())
360 }
361}
362
363impl FunctionCall {
366 fn finish_reason_set_to_tool_use(&self) -> bool {
367 true
368 }
369}