1use std::pin::Pin;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use futures_core::Stream;
6use serde::{Deserialize, Serialize};
7
8use crate::auth::{ApiKey, AuthStore};
9use crate::error::{Error, Result};
10use crate::message::{AssistantMessage, ContentBlock, Message, StopReason, ToolResultMessage};
11use crate::model::{Capabilities, Model, ModelMeta, ModelPricing};
12use crate::provider::{Context, Provider, RequestOptions, ThinkingLevel, ToolDefinition};
13use crate::stream::StreamEvent;
14use crate::usage::Usage;
15
/// Base URL of the Generative Language API `v1beta` model endpoints.
const API_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models";
/// Thinking-budget ceiling used for model ids containing `"flash"` (24 Ki tokens).
const FLASH_MAX_THINKING_BUDGET: i32 = 24 * 1024;
/// Thinking-budget ceiling used for every other model id (32 Ki tokens).
const PRO_MAX_THINKING_BUDGET: i32 = 32 * 1024;
19
20#[derive(Debug, Serialize)]
25struct ApiRequest {
26 contents: Vec<ApiContent>,
27 #[serde(rename = "systemInstruction", skip_serializing_if = "Option::is_none")]
28 system_instruction: Option<ApiInstruction>,
29 #[serde(skip_serializing_if = "Vec::is_empty")]
30 tools: Vec<ApiTool>,
31 #[serde(rename = "generationConfig")]
32 generation_config: ApiGenerationConfig,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
36struct ApiInstruction {
37 parts: Vec<ApiPart>,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
41struct ApiContent {
42 role: String,
43 parts: Vec<ApiPart>,
44}
45
46#[derive(Debug, Clone, Default, Serialize, Deserialize)]
47struct ApiPart {
48 #[serde(skip_serializing_if = "Option::is_none")]
49 text: Option<String>,
50 #[serde(skip_serializing_if = "Option::is_none")]
51 thought: Option<bool>,
52 #[serde(rename = "functionCall", skip_serializing_if = "Option::is_none")]
53 function_call: Option<ApiFunctionCall>,
54 #[serde(rename = "functionResponse", skip_serializing_if = "Option::is_none")]
55 function_response: Option<ApiFunctionResponse>,
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
59struct ApiFunctionCall {
60 #[serde(skip_serializing_if = "Option::is_none")]
61 id: Option<String>,
62 name: String,
63 args: serde_json::Value,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
67struct ApiFunctionResponse {
68 #[serde(skip_serializing_if = "Option::is_none")]
69 id: Option<String>,
70 name: String,
71 response: serde_json::Value,
72}
73
74#[derive(Debug, Serialize)]
75struct ApiTool {
76 #[serde(rename = "functionDeclarations")]
77 function_declarations: Vec<ApiFunctionDeclaration>,
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
81struct ApiFunctionDeclaration {
82 name: String,
83 description: String,
84 parameters: serde_json::Value,
85}
86
87#[derive(Debug, Serialize)]
88struct ApiGenerationConfig {
89 #[serde(rename = "maxOutputTokens", skip_serializing_if = "Option::is_none")]
90 max_output_tokens: Option<u32>,
91 #[serde(skip_serializing_if = "Option::is_none")]
92 temperature: Option<f32>,
93 #[serde(rename = "thinkingConfig", skip_serializing_if = "Option::is_none")]
94 thinking_config: Option<ApiThinkingConfig>,
95}
96
97#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
98struct ApiThinkingConfig {
99 #[serde(rename = "includeThoughts")]
100 include_thoughts: bool,
101 #[serde(rename = "thinkingBudget")]
102 thinking_budget: i32,
103}
104
105#[derive(Debug, Clone, Deserialize)]
110struct GenerateContentResponse {
111 #[serde(default)]
112 candidates: Vec<ApiCandidate>,
113 #[serde(rename = "usageMetadata")]
114 usage_metadata: Option<ApiUsageMetadata>,
115}
116
117#[derive(Debug, Clone, Deserialize)]
118struct ApiCandidate {
119 content: Option<ApiContent>,
120 #[serde(rename = "finishReason")]
121 finish_reason: Option<String>,
122}
123
124#[derive(Debug, Clone, Deserialize)]
125struct ApiUsageMetadata {
126 #[serde(rename = "promptTokenCount", default)]
127 prompt_token_count: u32,
128 #[serde(rename = "candidatesTokenCount", default)]
129 candidates_token_count: u32,
130 #[serde(rename = "thoughtsTokenCount", default)]
131 thoughts_token_count: u32,
132 #[serde(rename = "cachedContentTokenCount", default)]
133 cached_content_token_count: u32,
134}
135
/// Stream-processing state for one response part, kept in `StreamState::parts`
/// at the same index as the part's position in the candidate's `parts` array.
#[derive(Debug, Clone)]
enum PartState {
    /// Visible model text recorded for this part index. An empty string is
    /// also used as the placeholder created by `StreamState::ensure_index`.
    Text(String),
    /// Reasoning text (parts flagged `thought: true`) recorded for this index.
    Thinking(String),
    /// A function call seen at this part index.
    ToolCall {
        id: String,
        name: String,
        arguments: serde_json::Value,
        /// Whether a `StreamEvent::ToolCall` has already been emitted for this
        /// call, so a repeated chunk with the same id and name is not re-emitted.
        emitted: bool,
    },
}
151
/// Mutable state threaded through `process_response` for one streamed request.
#[derive(Debug)]
struct StreamState {
    /// Model id reported in `StreamEvent::MessageStart`.
    model: String,
    /// Set once `MessageStart` has been emitted.
    started: bool,
    /// Set once `MessageEnd` has been emitted.
    finished: bool,
    /// Per-part state, indexed by part position within the candidate.
    parts: Vec<PartState>,
    /// Latest token usage copied from `usageMetadata`.
    usage: Usage,
    /// Most recent `finishReason` seen, if any.
    finish_reason: Option<String>,
    /// Whether any function call was seen; forces `StopReason::ToolUse`.
    saw_tool_call: bool,
}
162
163impl StreamState {
164 fn new(model: String) -> Self {
165 Self {
166 model,
167 started: false,
168 finished: false,
169 parts: Vec::new(),
170 usage: Usage::default(),
171 finish_reason: None,
172 saw_tool_call: false,
173 }
174 }
175
176 fn ensure_index(&mut self, index: usize) {
177 while self.parts.len() <= index {
178 self.parts.push(PartState::Text(String::new()));
179 }
180 }
181
182 fn stop_reason(&self) -> StopReason {
183 if self.saw_tool_call {
184 return StopReason::ToolUse;
185 }
186
187 match self.finish_reason.as_deref() {
188 Some("STOP") | Some("FINISH_REASON_UNSPECIFIED") | None => StopReason::EndTurn,
189 Some("MAX_TOKENS") => StopReason::MaxTokens,
190 Some(other) => StopReason::Error(other.to_string()),
191 }
192 }
193
194 fn build_message(&self) -> AssistantMessage {
195 let content = self
196 .parts
197 .iter()
198 .filter_map(|part| match part {
199 PartState::Text(text) if !text.is_empty() => {
200 Some(ContentBlock::Text { text: text.clone() })
201 }
202 PartState::Thinking(text) if !text.is_empty() => {
203 Some(ContentBlock::Thinking { text: text.clone() })
204 }
205 PartState::ToolCall {
206 id,
207 name,
208 arguments,
209 ..
210 } => Some(ContentBlock::ToolCall {
211 id: id.clone(),
212 name: name.clone(),
213 arguments: arguments.clone(),
214 }),
215 _ => None,
216 })
217 .collect();
218
219 AssistantMessage {
220 content,
221 usage: Some(self.usage.clone()),
222 stop_reason: self.stop_reason(),
223 timestamp: crate::now(),
224 }
225 }
226}
227
/// Provider backed by the Gemini (Generative Language) streaming API.
pub struct GoogleProvider {
    /// HTTP client used for requests, obtained from `super::streaming_http_client()` in `new`.
    client: reqwest::Client,
    /// Model catalogue returned by `Provider::models`, filled from `builtin_models()` in `new`.
    models: Vec<ModelMeta>,
}
233
234impl Default for GoogleProvider {
235 fn default() -> Self {
236 Self::new()
237 }
238}
239
240impl GoogleProvider {
241 pub fn new() -> Self {
242 Self {
243 client: super::streaming_http_client(),
244 models: builtin_models(),
245 }
246 }
247
248 pub fn into_arc(self) -> Arc<Self> {
249 Arc::new(self)
250 }
251}
252
253fn max_thinking_budget(model_id: &str) -> i32 {
258 if model_id.contains("flash") {
259 FLASH_MAX_THINKING_BUDGET
260 } else {
261 PRO_MAX_THINKING_BUDGET
262 }
263}
264
265fn thinking_budget(model: &Model, level: ThinkingLevel) -> Option<i32> {
266 let budget = match level {
267 ThinkingLevel::Off => return None,
268 ThinkingLevel::Minimal => 1024,
269 ThinkingLevel::Low => 4096,
270 ThinkingLevel::Medium => 10_000,
271 ThinkingLevel::High => 24_576,
272 ThinkingLevel::XHigh => max_thinking_budget(&model.meta.id),
273 };
274
275 Some(budget.min(max_thinking_budget(&model.meta.id)))
276}
277
278fn default_max_output_tokens(model: &Model, thinking_budget: Option<i32>) -> u32 {
279 let base = model.meta.max_output_tokens.min(8_192);
280 match thinking_budget {
281 Some(budget) => base.max((budget as u32).saturating_add(1024)),
282 None => base,
283 }
284}
285
286fn build_request(model: &Model, context: Context, options: RequestOptions) -> ApiRequest {
287 let thinking_config =
288 thinking_budget(model, options.thinking_level).map(|thinking_budget| ApiThinkingConfig {
289 include_thoughts: true,
290 thinking_budget,
291 });
292
293 ApiRequest {
294 contents: build_messages(&context.messages),
295 system_instruction: build_system_instruction(&options.system_prompt),
296 tools: build_tools(&options.tools),
297 generation_config: ApiGenerationConfig {
298 max_output_tokens: options.max_tokens.or(Some(default_max_output_tokens(
299 model,
300 thinking_budget(model, options.thinking_level),
301 ))),
302 temperature: options.temperature,
303 thinking_config,
304 },
305 }
306}
307
308fn build_system_instruction(prompt: &str) -> Option<ApiInstruction> {
309 if prompt.is_empty() {
310 return None;
311 }
312
313 Some(ApiInstruction {
314 parts: vec![ApiPart {
315 text: Some(prompt.to_string()),
316 ..Default::default()
317 }],
318 })
319}
320
321fn build_tools(tools: &[ToolDefinition]) -> Vec<ApiTool> {
322 if tools.is_empty() {
323 return Vec::new();
324 }
325
326 vec![ApiTool {
327 function_declarations: tools.iter().map(convert_tool_def).collect(),
328 }]
329}
330
331fn build_messages(messages: &[Message]) -> Vec<ApiContent> {
332 messages.iter().map(convert_message).collect()
333}
334
335fn convert_message(message: &Message) -> ApiContent {
336 match message {
337 Message::User(user) => ApiContent {
338 role: "user".into(),
339 parts: user
340 .content
341 .iter()
342 .filter_map(convert_content_block)
343 .collect(),
344 },
345 Message::Assistant(assistant) => ApiContent {
346 role: "model".into(),
347 parts: assistant
348 .content
349 .iter()
350 .filter_map(convert_content_block)
351 .collect(),
352 },
353 Message::ToolResult(tool_result) => ApiContent {
354 role: "user".into(),
355 parts: vec![ApiPart {
356 function_response: Some(ApiFunctionResponse {
357 id: Some(tool_result.tool_call_id.clone()),
358 name: tool_result.tool_name.clone(),
359 response: convert_tool_result_response(tool_result),
360 }),
361 ..Default::default()
362 }],
363 },
364 }
365}
366
367fn convert_content_block(block: &ContentBlock) -> Option<ApiPart> {
368 match block {
369 ContentBlock::Text { text } => Some(ApiPart {
370 text: Some(text.clone()),
371 ..Default::default()
372 }),
373 ContentBlock::Thinking { text } => Some(ApiPart {
374 text: Some(text.clone()),
375 thought: Some(true),
376 ..Default::default()
377 }),
378 ContentBlock::ToolCall {
379 id,
380 name,
381 arguments,
382 } => Some(ApiPart {
383 function_call: Some(ApiFunctionCall {
384 id: Some(id.clone()),
385 name: name.clone(),
386 args: arguments.clone(),
387 }),
388 ..Default::default()
389 }),
390 ContentBlock::Image { .. } => None,
391 }
392}
393
394fn convert_tool_result_response(tool_result: &ToolResultMessage) -> serde_json::Value {
395 let output = tool_result
396 .content
397 .iter()
398 .filter_map(|block| match block {
399 ContentBlock::Text { text } => Some(text.as_str()),
400 _ => None,
401 })
402 .collect::<Vec<_>>()
403 .join("\n");
404
405 let mut response = serde_json::Map::new();
406 response.insert("result".into(), serde_json::Value::String(output));
407
408 if tool_result.is_error {
409 response.insert("isError".into(), serde_json::Value::Bool(true));
410 }
411
412 if !tool_result.details.is_null() {
413 response.insert("details".into(), tool_result.details.clone());
414 }
415
416 serde_json::Value::Object(response)
417}
418
419fn convert_tool_def(tool: &ToolDefinition) -> ApiFunctionDeclaration {
420 ApiFunctionDeclaration {
421 name: tool.name.clone(),
422 description: tool.description.clone(),
423 parameters: tool.parameters.clone(),
424 }
425}
426
427fn parse_sse_event(data: &str) -> Result<Option<GenerateContentResponse>> {
432 let trimmed = data.trim();
433 if trimmed.is_empty() || trimmed == "[DONE]" {
434 return Ok(None);
435 }
436
437 serde_json::from_str(trimmed)
438 .map(Some)
439 .map_err(|e| Error::Stream(format!("Failed to parse Gemini SSE data: {e}: {trimmed}")))
440}
441
/// Returns the part of `current` that follows `previous`.
///
/// If `current` does not start with `previous`, the whole of `current` is
/// returned.
fn text_delta(previous: &str, current: &str) -> String {
    match current.strip_prefix(previous) {
        Some(suffix) => suffix.to_owned(),
        None => current.to_owned(),
    }
}
448
449fn update_usage(usage: &ApiUsageMetadata, state: &mut StreamState) {
450 state.usage.input_tokens = usage.prompt_token_count;
451 state.usage.output_tokens = usage.candidates_token_count + usage.thoughts_token_count;
452 state.usage.cache_read_tokens = usage.cached_content_token_count;
453 state.usage.cache_write_tokens = 0;
454}
455
456fn process_response(
457 response: GenerateContentResponse,
458 state: &mut StreamState,
459) -> Vec<StreamEvent> {
460 let mut out = Vec::new();
461
462 if !state.started {
463 state.started = true;
464 out.push(StreamEvent::MessageStart {
465 model: state.model.clone(),
466 });
467 }
468
469 if let Some(usage) = &response.usage_metadata {
470 update_usage(usage, state);
471 }
472
473 if let Some(candidate) = response.candidates.first() {
474 if let Some(content) = &candidate.content {
475 for (index, part) in content.parts.iter().enumerate() {
476 if let Some(function_call) = &part.function_call {
477 state.ensure_index(index);
478 let id = function_call
479 .id
480 .clone()
481 .unwrap_or_else(|| format!("call_{index}"));
482 let name = function_call.name.clone();
483 let arguments = function_call.args.clone();
484
485 let emit = match state.parts.get_mut(index) {
486 Some(PartState::ToolCall {
487 id: existing_id,
488 name: existing_name,
489 arguments: existing_arguments,
490 emitted,
491 }) if *existing_id == id && *existing_name == name => {
492 *existing_arguments = arguments.clone();
493 if *emitted {
494 false
495 } else {
496 *emitted = true;
497 true
498 }
499 }
500 Some(slot) => {
501 *slot = PartState::ToolCall {
502 id: id.clone(),
503 name: name.clone(),
504 arguments: arguments.clone(),
505 emitted: true,
506 };
507 true
508 }
509 None => false,
510 };
511
512 state.saw_tool_call = true;
513 if emit {
514 out.push(StreamEvent::ToolCall {
515 id,
516 name,
517 arguments,
518 });
519 }
520 continue;
521 }
522
523 let Some(text) = part.text.as_deref() else {
524 continue;
525 };
526
527 state.ensure_index(index);
528 if part.thought.unwrap_or(false) {
529 let previous = match &state.parts[index] {
530 PartState::Thinking(existing) => existing.clone(),
531 _ => String::new(),
532 };
533 state.parts[index] = PartState::Thinking(text.to_string());
534 let delta = text_delta(&previous, text);
535 if !delta.is_empty() {
536 out.push(StreamEvent::ThinkingDelta { text: delta });
537 }
538 } else {
539 let previous = match &state.parts[index] {
540 PartState::Text(existing) => existing.clone(),
541 _ => String::new(),
542 };
543 state.parts[index] = PartState::Text(text.to_string());
544 let delta = text_delta(&previous, text);
545 if !delta.is_empty() {
546 out.push(StreamEvent::TextDelta { text: delta });
547 }
548 }
549 }
550 }
551
552 if let Some(reason) = &candidate.finish_reason {
553 state.finish_reason = Some(reason.clone());
554 }
555
556 if candidate.finish_reason.is_some() && !state.finished {
557 state.finished = true;
558 out.push(StreamEvent::MessageEnd {
559 message: state.build_message(),
560 });
561 }
562 }
563
564 out
565}
566
/// Test helper: runs every `data: ` line of an in-memory SSE body through
/// `parse_sse_event` and `process_response`, collecting events and errors in order.
#[cfg(test)]
fn parse_sse_stream(raw: &str, state: &mut StreamState) -> Vec<Result<StreamEvent>> {
    raw.lines()
        .filter_map(|line| line.trim().strip_prefix("data: "))
        .flat_map(|data| -> Vec<Result<StreamEvent>> {
            match parse_sse_event(data) {
                Ok(Some(response)) => process_response(response, state)
                    .into_iter()
                    .map(Ok)
                    .collect(),
                Ok(None) => Vec::new(),
                Err(error) => vec![Err(error)],
            }
        })
        .collect()
}
588
589fn stream_response(
594 client: reqwest::Client,
595 model_id: String,
596 api_key: String,
597 request: ApiRequest,
598) -> Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>> {
599 let (tx, rx) = futures::channel::mpsc::unbounded();
600
601 tokio::spawn(async move {
602 let result = client
603 .post(format!("{API_BASE_URL}/{model_id}:streamGenerateContent"))
604 .query(&[("alt", "sse"), ("key", api_key.as_str())])
605 .header("content-type", "application/json")
606 .json(&request)
607 .send()
608 .await;
609
610 let response = match result {
611 Ok(response) => response,
612 Err(error) => {
613 let _ = tx.unbounded_send(Err(Error::Http(error)));
614 return;
615 }
616 };
617
618 let status = response.status();
619 if !status.is_success() {
620 let body = response.text().await.unwrap_or_default();
621 let _ = tx.unbounded_send(Err(Error::Provider(format!("HTTP {status}: {body}"))));
622 return;
623 }
624
625 let mut state = StreamState::new(model_id);
626 let mut buffer = String::new();
627 let mut byte_stream = response.bytes_stream();
628
629 use futures::StreamExt;
630 while let Some(chunk) = byte_stream.next().await {
631 match chunk {
632 Ok(bytes) => {
633 buffer.push_str(&String::from_utf8_lossy(&bytes));
634
635 while let Some(pos) = buffer.find('\n') {
636 let line = buffer[..pos].to_string();
637 buffer = buffer[pos + 1..].to_string();
638
639 let trimmed = line.trim();
640 if let Some(data) = trimmed.strip_prefix("data: ") {
641 match parse_sse_event(data) {
642 Ok(Some(response)) => {
643 for event in process_response(response, &mut state) {
644 if tx.unbounded_send(Ok(event)).is_err() {
645 return;
646 }
647 }
648 }
649 Ok(None) => {}
650 Err(error) => {
651 if tx.unbounded_send(Err(error)).is_err() {
652 return;
653 }
654 }
655 }
656 }
657 }
658 }
659 Err(error) => {
660 let _ = tx.unbounded_send(Err(Error::Http(error)));
661 return;
662 }
663 }
664 }
665
666 let trimmed = buffer.trim();
667 if let Some(data) = trimmed.strip_prefix("data: ") {
668 match parse_sse_event(data) {
669 Ok(Some(response)) => {
670 for event in process_response(response, &mut state) {
671 if tx.unbounded_send(Ok(event)).is_err() {
672 return;
673 }
674 }
675 }
676 Ok(None) => {}
677 Err(error) => {
678 let _ = tx.unbounded_send(Err(error));
679 return;
680 }
681 }
682 }
683
684 if !state.finished {
685 let _ = tx.unbounded_send(Err(Error::Stream(
686 "Google stream ended before terminal finishReason".into(),
687 )));
688 }
689 });
690
691 Box::pin(rx)
692}
693
694#[async_trait]
695impl Provider for GoogleProvider {
696 fn stream(
697 &self,
698 model: &Model,
699 context: Context,
700 options: RequestOptions,
701 api_key: &str,
702 ) -> Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>> {
703 let request = build_request(model, context, options);
704 stream_response(
705 self.client.clone(),
706 model.meta.id.clone(),
707 api_key.to_string(),
708 request,
709 )
710 }
711
712 async fn resolve_auth(&self, auth: &AuthStore) -> Result<ApiKey> {
713 auth.resolve("google")
714 }
715
716 fn id(&self) -> &str {
717 "google"
718 }
719
720 fn models(&self) -> &[ModelMeta] {
721 &self.models
722 }
723}
724
725fn builtin_models() -> Vec<ModelMeta> {
726 vec![
727 ModelMeta {
728 id: "gemini-2.5-pro".into(),
729 provider: "google".into(),
730 name: "Gemini 2.5 Pro".into(),
731 context_window: 1_048_576,
732 max_output_tokens: 65_536,
733 pricing: ModelPricing {
734 input_per_mtok: 1.25,
735 output_per_mtok: 10.0,
736 cache_read_per_mtok: 0.315,
737 cache_write_per_mtok: 1.25,
738 },
739 capabilities: Capabilities {
740 reasoning: true,
741 images: true,
742 tool_use: true,
743 },
744 },
745 ModelMeta {
746 id: "gemini-2.5-flash".into(),
747 provider: "google".into(),
748 name: "Gemini 2.5 Flash".into(),
749 context_window: 1_048_576,
750 max_output_tokens: 65_536,
751 pricing: ModelPricing {
752 input_per_mtok: 0.15,
753 output_per_mtok: 3.5,
754 cache_read_per_mtok: 0.0375,
755 cache_write_per_mtok: 0.15,
756 },
757 capabilities: Capabilities {
758 reasoning: true,
759 images: true,
760 tool_use: true,
761 },
762 },
763 ]
764}
765
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::UserMessage;

    /// Looks up a built-in Gemini model by id and pairs it with a fresh provider.
    fn test_model(id: &str) -> Model {
        let provider = GoogleProvider::new().into_arc();
        let meta = builtin_models()
            .into_iter()
            .find(|meta| meta.id == id)
            .expect("test model should exist");
        Model { meta, provider }
    }

    /// Parses `raw` with a fresh `gemini-2.5-pro` stream state and fails the
    /// test on any error event.
    fn parse_ok(raw: &str) -> Vec<StreamEvent> {
        let mut state = StreamState::new("gemini-2.5-pro".into());
        parse_sse_stream(raw, &mut state)
            .into_iter()
            .collect::<std::result::Result<Vec<_>, _>>()
            .unwrap()
    }

    /// Returns the message carried by a `MessageEnd` event, panicking otherwise.
    fn end_message(event: &StreamEvent) -> &AssistantMessage {
        match event {
            StreamEvent::MessageEnd { message } => message,
            _ => panic!("expected MessageEnd"),
        }
    }

    #[test]
    fn serialize_text_user_message() {
        let message = Message::User(UserMessage {
            content: vec![ContentBlock::Text {
                text: "Hello Gemini".into(),
            }],
            timestamp: 0,
        });

        let json = serde_json::to_value(convert_message(&message)).unwrap();

        assert_eq!(json["role"], "user");
        assert_eq!(json["parts"][0]["text"], "Hello Gemini");
    }

    #[test]
    fn serialize_assistant_tool_call_block() {
        let message = Message::Assistant(AssistantMessage {
            content: vec![ContentBlock::ToolCall {
                id: "call_1".into(),
                name: "bash".into(),
                arguments: serde_json::json!({"command": "ls"}),
            }],
            usage: None,
            stop_reason: StopReason::ToolUse,
            timestamp: 0,
        });

        let json = serde_json::to_value(convert_message(&message)).unwrap();
        let call = &json["parts"][0]["functionCall"];

        assert_eq!(json["role"], "model");
        assert_eq!(call["id"], "call_1");
        assert_eq!(call["name"], "bash");
        assert_eq!(call["args"]["command"], "ls");
    }

    #[test]
    fn serialize_tool_result_message() {
        let message = Message::ToolResult(ToolResultMessage {
            tool_call_id: "call_1".into(),
            tool_name: "bash".into(),
            content: vec![ContentBlock::Text {
                text: "README.md\nsrc/".into(),
            }],
            is_error: false,
            details: serde_json::json!({"cwd": "/tmp"}),
            timestamp: 0,
        });

        let json = serde_json::to_value(convert_message(&message)).unwrap();
        let function_response = &json["parts"][0]["functionResponse"];

        assert_eq!(json["role"], "user");
        assert_eq!(function_response["id"], "call_1");
        assert_eq!(function_response["name"], "bash");
        assert_eq!(function_response["response"]["result"], "README.md\nsrc/");
        assert_eq!(function_response["response"]["details"]["cwd"], "/tmp");
    }

    #[test]
    fn thinking_budget_mapping_matches_model_limits() {
        let pro = test_model("gemini-2.5-pro");
        let flash = test_model("gemini-2.5-flash");

        let pro_cases = [
            (ThinkingLevel::Off, None),
            (ThinkingLevel::Minimal, Some(1024)),
            (ThinkingLevel::Low, Some(4096)),
            (ThinkingLevel::Medium, Some(10_000)),
            (ThinkingLevel::High, Some(24_576)),
            (ThinkingLevel::XHigh, Some(32_768)),
        ];
        for (level, expected) in pro_cases {
            assert_eq!(thinking_budget(&pro, level), expected);
        }
        assert_eq!(thinking_budget(&flash, ThinkingLevel::XHigh), Some(24_576));
    }

    #[test]
    fn default_max_output_tokens_caps_google_models_without_thinking() {
        let pro = test_model("gemini-2.5-pro");
        assert_eq!(default_max_output_tokens(&pro, None), 8_192);
    }

    #[test]
    fn default_max_output_tokens_grows_for_google_thinking_budget() {
        let pro = test_model("gemini-2.5-pro");
        assert_eq!(default_max_output_tokens(&pro, Some(24_576)), 25_600);
    }

    #[test]
    fn build_request_serializes_system_tools_and_thinking() {
        let model = test_model("gemini-2.5-pro");
        let assistant_turn = Message::Assistant(AssistantMessage {
            content: vec![ContentBlock::ToolCall {
                id: "call_1".into(),
                name: "bash".into(),
                arguments: serde_json::json!({"command": "ls"}),
            }],
            usage: None,
            stop_reason: StopReason::ToolUse,
            timestamp: 0,
        });
        let tool_turn = Message::ToolResult(ToolResultMessage {
            tool_call_id: "call_1".into(),
            tool_name: "bash".into(),
            content: vec![ContentBlock::Text {
                text: "Cargo.toml\nsrc/".into(),
            }],
            is_error: false,
            details: serde_json::Value::Null,
            timestamp: 0,
        });
        let context = Context {
            messages: vec![
                Message::user("List the files in this directory."),
                assistant_turn,
                tool_turn,
            ],
        };
        let bash_tool = ToolDefinition {
            name: "bash".into(),
            description: "Run a shell command".into(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "command": { "type": "string" }
                },
                "required": ["command"]
            }),
        };
        let options = RequestOptions {
            system_prompt: "You are a helpful coding assistant.".into(),
            max_tokens: Some(2048),
            temperature: Some(0.2),
            thinking_level: ThinkingLevel::High,
            tools: vec![bash_tool],
            ..Default::default()
        };

        let json = serde_json::to_value(build_request(&model, context, options)).unwrap();
        let contents = &json["contents"];
        let config = &json["generationConfig"];

        assert_eq!(
            json["systemInstruction"]["parts"][0]["text"],
            "You are a helpful coding assistant."
        );
        assert_eq!(contents.as_array().unwrap().len(), 3);
        assert_eq!(contents[0]["role"], "user");
        assert_eq!(contents[1]["role"], "model");
        assert_eq!(contents[1]["parts"][0]["functionCall"]["name"], "bash");
        assert_eq!(contents[2]["parts"][0]["functionResponse"]["name"], "bash");
        assert_eq!(json["tools"][0]["functionDeclarations"][0]["name"], "bash");
        assert_eq!(config["maxOutputTokens"], 2048);
        let temperature = config["temperature"]
            .as_f64()
            .expect("temperature should be numeric");
        assert!((temperature - 0.2).abs() < 1e-6);
        assert_eq!(config["thinkingConfig"]["includeThoughts"], true);
        assert_eq!(config["thinkingConfig"]["thinkingBudget"], 24_576);
    }

    #[test]
    fn parse_text_and_thinking_deltas() {
        let raw = concat!(
            r#"data: {"candidates":[{"content":{"role":"model","parts":[{"thought":true,"text":"Plan"}]}}]}"#,
            "\n",
            "\n",
            r#"data: {"candidates":[{"content":{"role":"model","parts":[{"thought":true,"text":"Planning"},{"text":"Answer"}]}}]}"#,
            "\n",
            "\n",
            r#"data: {"candidates":[{"content":{"role":"model","parts":[{"thought":true,"text":"Planning"},{"text":"Answer done"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"thoughtsTokenCount":3}}"#,
            "\n",
        );

        let events = parse_ok(raw);

        assert!(
            matches!(&events[0], StreamEvent::MessageStart { model } if model == "gemini-2.5-pro")
        );
        assert!(matches!(&events[1], StreamEvent::ThinkingDelta { text } if text == "Plan"));
        assert!(matches!(&events[2], StreamEvent::ThinkingDelta { text } if text == "ning"));
        assert!(matches!(&events[3], StreamEvent::TextDelta { text } if text == "Answer"));
        assert!(matches!(&events[4], StreamEvent::TextDelta { text } if text == " done"));

        let message = end_message(&events[5]);
        assert_eq!(message.stop_reason, StopReason::EndTurn);
        let usage = message.usage.as_ref().unwrap();
        assert_eq!(usage.input_tokens, 10);
        assert_eq!(usage.output_tokens, 8);
        assert_eq!(message.content.len(), 2);
    }

    #[test]
    fn parse_tool_call_response() {
        let raw = concat!(
            r#"data: {"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"id":"call_1","name":"read","args":{"path":"src/lib.rs"}}}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":12,"candidatesTokenCount":4}}"#,
            "\n",
        );

        let events = parse_ok(raw);

        assert_eq!(events.len(), 3);
        assert!(matches!(&events[0], StreamEvent::MessageStart { .. }));
        assert!(
            matches!(&events[1], StreamEvent::ToolCall { id, name, arguments } if id == "call_1" && name == "read" && arguments["path"] == "src/lib.rs")
        );
        assert_eq!(end_message(&events[2]).stop_reason, StopReason::ToolUse);
    }

    #[test]
    fn parse_invalid_sse_event_returns_error() {
        assert!(matches!(
            parse_sse_event("not json"),
            Err(Error::Stream(_))
        ));
    }

    #[test]
    fn builtin_models_include_flash_and_pro() {
        let models = builtin_models();
        assert_eq!(models.len(), 2);
        for expected in ["gemini-2.5-pro", "gemini-2.5-flash"] {
            assert!(
                models.iter().any(|model| model.id == expected),
                "missing {expected}"
            );
        }
    }

    #[test]
    fn parse_multi_part_response_text_and_tool_call() {
        let raw = concat!(
            r#"data: {"candidates":[{"content":{"role":"model","parts":[{"text":"Let me check"},{"functionCall":{"id":"call_1","name":"read","args":{"path":"a.rs"}}}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":8,"candidatesTokenCount":6}}"#,
            "\n",
        );

        let events = parse_ok(raw);

        assert_eq!(events.len(), 4);
        assert!(matches!(&events[0], StreamEvent::MessageStart { .. }));
        assert!(matches!(&events[1], StreamEvent::TextDelta { text } if text == "Let me check"));
        assert!(matches!(&events[2], StreamEvent::ToolCall { name, .. } if name == "read"));
        assert_eq!(end_message(&events[3]).stop_reason, StopReason::ToolUse);
    }

    #[test]
    fn parse_usage_metadata_extraction() {
        let raw = concat!(
            r#"data: {"candidates":[{"content":{"role":"model","parts":[{"text":"Hi"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":42,"candidatesTokenCount":10,"thoughtsTokenCount":5,"cachedContentTokenCount":3}}"#,
            "\n",
        );

        let events = parse_ok(raw);
        let usage = end_message(events.last().unwrap())
            .usage
            .as_ref()
            .unwrap();

        assert_eq!(usage.input_tokens, 42);
        assert_eq!(usage.output_tokens, 15);
        assert_eq!(usage.cache_read_tokens, 3);
    }

    #[test]
    fn stop_reason_mapping() {
        let mut state = StreamState::new("test".into());
        let cases = [
            (Some("STOP"), StopReason::EndTurn),
            (Some("MAX_TOKENS"), StopReason::MaxTokens),
            (Some("SAFETY"), StopReason::Error("SAFETY".into())),
            (None, StopReason::EndTurn),
        ];
        for (finish_reason, expected) in cases {
            state.finish_reason = finish_reason.map(String::from);
            assert_eq!(state.stop_reason(), expected);
        }

        state.saw_tool_call = true;
        assert_eq!(state.stop_reason(), StopReason::ToolUse);
    }

    #[test]
    fn empty_candidates_produces_no_content_events() {
        let raw = concat!(
            r#"data: {"candidates":[],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":0}}"#,
            "\n",
        );

        let events = parse_ok(raw);

        assert_eq!(events.len(), 1);
        assert!(matches!(&events[0], StreamEvent::MessageStart { .. }));
    }

    #[test]
    fn parse_sse_event_done_marker_returns_none() {
        assert!(parse_sse_event("[DONE]").unwrap().is_none());
    }

    #[test]
    fn empty_system_prompt_produces_no_instruction() {
        assert!(build_system_instruction("").is_none());
    }

    #[test]
    fn empty_tools_produces_empty_vec() {
        assert!(build_tools(&[]).is_empty());
    }
}