1use std::time::Duration;
16
17use reqwest::Client;
18use serde::{Deserialize, Serialize};
19
20mod embedding;
21
22#[cfg(feature = "pubsub")]
23pub mod pubsub;
24
25pub use embedding::GeminiEmbedding;
26
27#[cfg(feature = "pubsub")]
28pub use pubsub::PubSubBroker;
29
30use daimon_core::{
31 ChatRequest, ChatResponse, DaimonError, Message, Model, ResponseStream, Result, Role,
32 StopReason, StreamEvent, ToolCall, ToolSpec, Usage,
33};
34
/// Default public Gemini REST endpoint (v1beta API surface).
const DEFAULT_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta";
/// Default retry budget for retryable (429 / 5xx) responses in `generate`.
const DEFAULT_MAX_RETRIES: u32 = 3;
37
38fn build_client(timeout: Option<Duration>) -> Client {
39 let mut builder = Client::builder();
40 if let Some(t) = timeout {
41 builder = builder.timeout(t);
42 }
43 builder.build().expect("failed to build HTTP client")
44}
45
/// Client for the Google Gemini generative language REST API.
#[derive(Debug)]
pub struct Gemini {
    // Underlying HTTP client; rebuilt whenever a timeout is configured.
    client: Client,
    // API key, sent as a `?key=` query param or as a bearer token.
    api_key: String,
    // Model identifier, e.g. "gemini-2.0-flash".
    model_id: String,
    // API base URL; defaults to `DEFAULT_BASE_URL`.
    base_url: String,
    // Optional per-request timeout mirrored into `client`.
    timeout: Option<Duration>,
    // Retry budget for retryable (429 / 5xx) errors in `generate`.
    max_retries: u32,
    // When true, authenticate via `Authorization: Bearer` header.
    use_bearer_token: bool,
    // Optional server-side cached-content resource name attached to requests.
    cached_content: Option<String>,
}
62
63impl Gemini {
64 pub fn new(model_id: impl Into<String>) -> Self {
66 let api_key = std::env::var("GOOGLE_API_KEY").unwrap_or_default();
67 Self::with_api_key(model_id, api_key)
68 }
69
70 pub fn with_api_key(model_id: impl Into<String>, api_key: impl Into<String>) -> Self {
72 Self {
73 client: build_client(None),
74 api_key: api_key.into(),
75 model_id: model_id.into(),
76 base_url: DEFAULT_BASE_URL.to_string(),
77 timeout: None,
78 max_retries: DEFAULT_MAX_RETRIES,
79 use_bearer_token: false,
80 cached_content: None,
81 }
82 }
83
84 pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
86 self.base_url = url.into();
87 self
88 }
89
90 pub fn with_timeout(mut self, timeout: Duration) -> Self {
92 self.timeout = Some(timeout);
93 self.client = build_client(Some(timeout));
94 self
95 }
96
97 pub fn with_max_retries(mut self, retries: u32) -> Self {
99 self.max_retries = retries;
100 self
101 }
102
103 pub fn with_bearer_token(mut self) -> Self {
107 self.use_bearer_token = true;
108 self
109 }
110
111 pub fn with_cached_content(mut self, name: impl Into<String>) -> Self {
116 self.cached_content = Some(name.into());
117 self
118 }
119
120 fn endpoint_url(&self, method: &str) -> String {
121 format!(
122 "{}/models/{}:{}",
123 self.base_url, self.model_id, method
124 )
125 }
126
127 fn apply_auth(&self, req: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
128 if self.use_bearer_token {
129 req.bearer_auth(&self.api_key)
130 } else {
131 req.query(&[("key", &self.api_key)])
132 }
133 }
134
135 fn build_request_body(&self, request: &ChatRequest) -> GeminiRequest {
136 let mut system_instruction = None;
137 let mut contents = Vec::new();
138
139 for msg in &request.messages {
140 match msg.role {
141 Role::System => {
142 if let Some(text) = &msg.content {
143 system_instruction = Some(GeminiContent {
144 role: "user".to_string(),
145 parts: vec![GeminiPart::Text {
146 text: text.clone(),
147 }],
148 });
149 }
150 }
151 Role::User => {
152 if let Some(text) = &msg.content {
153 contents.push(GeminiContent {
154 role: "user".to_string(),
155 parts: vec![GeminiPart::Text {
156 text: text.clone(),
157 }],
158 });
159 }
160 }
161 Role::Assistant => {
162 if !msg.tool_calls.is_empty() {
163 let parts = msg
164 .tool_calls
165 .iter()
166 .map(|tc| GeminiPart::FunctionCall {
167 function_call: GeminiFunctionCall {
168 name: tc.name.clone(),
169 args: tc.arguments.clone(),
170 },
171 })
172 .collect();
173 contents.push(GeminiContent {
174 role: "model".to_string(),
175 parts,
176 });
177 } else if let Some(text) = &msg.content {
178 contents.push(GeminiContent {
179 role: "model".to_string(),
180 parts: vec![GeminiPart::Text {
181 text: text.clone(),
182 }],
183 });
184 }
185 }
186 Role::Tool => {
187 let name = msg.tool_call_id.clone().unwrap_or_default();
188 let content = msg.content.clone().unwrap_or_default();
189 let response_value: serde_json::Value =
190 serde_json::from_str(&content).unwrap_or_else(|_| {
191 serde_json::json!({ "result": content })
192 });
193 contents.push(GeminiContent {
194 role: "user".to_string(),
195 parts: vec![GeminiPart::FunctionResponse {
196 function_response: GeminiFunctionResponse {
197 name,
198 response: response_value,
199 },
200 }],
201 });
202 }
203 }
204 }
205
206 let tools = if request.tools.is_empty() {
207 None
208 } else {
209 let declarations: Vec<GeminiFunctionDeclaration> =
210 request.tools.iter().map(Into::into).collect();
211 Some(vec![GeminiToolConfig {
212 function_declarations: declarations,
213 }])
214 };
215
216 let generation_config = Some(GeminiGenerationConfig {
217 temperature: request.temperature,
218 max_output_tokens: request.max_tokens,
219 });
220
221 GeminiRequest {
222 cached_content: self.cached_content.clone(),
223 system_instruction,
224 contents,
225 tools,
226 generation_config,
227 }
228 }
229}
230
231impl Model for Gemini {
232 #[tracing::instrument(skip_all, fields(model = %self.model_id))]
233 async fn generate(&self, request: &ChatRequest) -> Result<ChatResponse> {
234 let body = self.build_request_body(request);
235 let url = self.endpoint_url("generateContent");
236
237 for attempt in 0..=self.max_retries {
238 let req = self.client.post(&url).json(&body);
239 let req = self.apply_auth(req);
240
241 tracing::debug!(attempt, "sending Gemini generateContent request");
242 let response = req
243 .send()
244 .await
245 .map_err(|e| DaimonError::Model(format!("Gemini HTTP error: {e}")))?;
246 let status = response.status();
247
248 if status.is_success() {
249 let api_resp: GeminiResponse = response
250 .json()
251 .await
252 .map_err(|e| DaimonError::Model(format!("Gemini response parse error: {e}")))?;
253 tracing::debug!("received successful Gemini response");
254 return parse_response(api_resp);
255 }
256
257 let text = response.text().await.unwrap_or_default();
258 let is_retryable = status.as_u16() == 429 || status.is_server_error();
259
260 if is_retryable && attempt < self.max_retries {
261 let delay_ms = 100 * 2u64.pow(attempt);
262 tracing::debug!(status = %status, attempt, delay_ms, "retryable error, backing off");
263 tokio::time::sleep(Duration::from_millis(delay_ms)).await;
264 } else {
265 return Err(DaimonError::Model(format!(
266 "Gemini API error ({status}): {text}"
267 )));
268 }
269 }
270
271 unreachable!("loop always returns or retries")
272 }
273
274 #[tracing::instrument(skip_all, fields(model = %self.model_id))]
275 async fn generate_stream(&self, request: &ChatRequest) -> Result<ResponseStream> {
276 let body = self.build_request_body(request);
277 let url = self.endpoint_url("streamGenerateContent");
278
279 let req = self
280 .client
281 .post(&url)
282 .query(&[("alt", "sse")])
283 .json(&body);
284 let req = self.apply_auth(req);
285
286 tracing::debug!("sending Gemini streaming request");
287 let response = req
288 .send()
289 .await
290 .map_err(|e| DaimonError::Model(format!("Gemini HTTP error: {e}")))?;
291
292 if !response.status().is_success() {
293 let status = response.status();
294 let text = response.text().await.unwrap_or_default();
295 return Err(DaimonError::Model(format!(
296 "Gemini API error ({status}): {text}"
297 )));
298 }
299
300 tracing::debug!("Gemini stream established");
301 let byte_stream = response.bytes_stream();
302
303 let stream = async_stream::try_stream! {
304 use futures::StreamExt;
305
306 let mut buffer = String::new();
307 let mut stream = Box::pin(byte_stream);
308
309 while let Some(chunk) = stream.next().await {
310 let chunk = chunk.map_err(|e| DaimonError::Model(format!("Gemini stream error: {e}")))?;
311 buffer.push_str(&String::from_utf8_lossy(&chunk));
312
313 while let Some(line_end) = buffer.find('\n') {
314 let line = buffer[..line_end].trim().to_string();
315 buffer = buffer[line_end + 1..].to_string();
316
317 if line.is_empty() {
318 continue;
319 }
320
321 if let Some(data) = line.strip_prefix("data: ") {
322 if let Ok(chunk_resp) = serde_json::from_str::<GeminiResponse>(data) {
323 for candidate in &chunk_resp.candidates {
324 for part in &candidate.content.parts {
325 match part {
326 GeminiResponsePart::Text { text } => {
327 if !text.is_empty() {
328 yield StreamEvent::TextDelta(text.clone());
329 }
330 }
331 GeminiResponsePart::FunctionCall { function_call } => {
332 let id = format!("gemini_{}", function_call.name);
333 yield StreamEvent::ToolCallStart {
334 id: id.clone(),
335 name: function_call.name.clone(),
336 };
337 let args = serde_json::to_string(&function_call.args)
338 .unwrap_or_default();
339 yield StreamEvent::ToolCallDelta {
340 id: id.clone(),
341 arguments_delta: args,
342 };
343 yield StreamEvent::ToolCallEnd { id };
344 }
345 }
346 }
347 }
348
349 let is_done = chunk_resp.candidates.iter().any(|c| {
350 c.finish_reason.as_deref() == Some("STOP")
351 || c.finish_reason.as_deref() == Some("MAX_TOKENS")
352 });
353 if is_done {
354 yield StreamEvent::Done;
355 }
356 }
357 }
358 }
359 }
360 };
361
362 Ok(Box::pin(stream))
363 }
364}
365
366fn parse_response(response: GeminiResponse) -> Result<ChatResponse> {
367 let candidate = response
368 .candidates
369 .into_iter()
370 .next()
371 .ok_or_else(|| DaimonError::Model("no candidates in Gemini response".into()))?;
372
373 let mut text_content = String::new();
374 let mut tool_calls = Vec::new();
375
376 for part in candidate.content.parts {
377 match part {
378 GeminiResponsePart::Text { text } => {
379 text_content.push_str(&text);
380 }
381 GeminiResponsePart::FunctionCall { function_call } => {
382 tool_calls.push(ToolCall {
383 id: format!("gemini_{}", function_call.name),
384 name: function_call.name,
385 arguments: function_call.args,
386 });
387 }
388 }
389 }
390
391 let stop_reason = if !tool_calls.is_empty() {
392 StopReason::ToolUse
393 } else {
394 match candidate.finish_reason.as_deref() {
395 Some("MAX_TOKENS") => StopReason::MaxTokens,
396 _ => StopReason::EndTurn,
397 }
398 };
399
400 let message = if tool_calls.is_empty() {
401 Message::assistant(text_content)
402 } else {
403 Message {
404 role: Role::Assistant,
405 content: if text_content.is_empty() {
406 None
407 } else {
408 Some(text_content)
409 },
410 tool_calls,
411 tool_call_id: None,
412 }
413 };
414
415 Ok(ChatResponse {
416 message,
417 stop_reason,
418 usage: response.usage_metadata.map(|u| Usage {
419 input_tokens: u.prompt_token_count,
420 output_tokens: u.candidates_token_count,
421 cached_tokens: u.cached_content_token_count,
422 }),
423 })
424}
425
/// Top-level request body for `generateContent` / `streamGenerateContent`.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiRequest {
    // Server-side cached-content resource name, when configured.
    #[serde(skip_serializing_if = "Option::is_none")]
    cached_content: Option<String>,
    // System prompt, transmitted separately from the conversation turns.
    #[serde(skip_serializing_if = "Option::is_none")]
    system_instruction: Option<GeminiContent>,
    // Ordered conversation turns (user / model / tool-result contents).
    contents: Vec<GeminiContent>,
    // Function declarations offered to the model, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<GeminiToolConfig>>,
    // Sampling settings (temperature, max output tokens).
    #[serde(skip_serializing_if = "Option::is_none")]
    generation_config: Option<GeminiGenerationConfig>,
}
441
/// One conversation turn: a role ("user" or "model") plus its parts.
#[derive(Serialize)]
struct GeminiContent {
    role: String,
    parts: Vec<GeminiPart>,
}
447
/// A single content part in an outgoing turn. `untagged` means each variant
/// serializes as just its field (e.g. `{"text": ...}` or `{"functionCall": ...}`).
#[derive(Serialize)]
#[serde(untagged)]
enum GeminiPart {
    // Plain text.
    Text {
        text: String,
    },
    // A tool invocation emitted by the model on a previous turn.
    FunctionCall {
        #[serde(rename = "functionCall")]
        function_call: GeminiFunctionCall,
    },
    // The result of executing a tool, echoed back to the model.
    FunctionResponse {
        #[serde(rename = "functionResponse")]
        function_response: GeminiFunctionResponse,
    },
}
463
/// Outgoing function-call payload: name plus JSON arguments.
#[derive(Serialize)]
struct GeminiFunctionCall {
    name: String,
    args: serde_json::Value,
}
469
/// Tool result sent back to the model, keyed by function name.
#[derive(Serialize)]
struct GeminiFunctionResponse {
    name: String,
    response: serde_json::Value,
}
475
/// Wrapper for the set of function declarations offered to the model.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiToolConfig {
    function_declarations: Vec<GeminiFunctionDeclaration>,
}
481
/// A single callable function advertised to the model: name, human-readable
/// description, and a JSON-schema-style parameter object.
#[derive(Serialize)]
struct GeminiFunctionDeclaration {
    name: String,
    description: String,
    parameters: serde_json::Value,
}
488
489impl From<&ToolSpec> for GeminiFunctionDeclaration {
490 fn from(spec: &ToolSpec) -> Self {
491 Self {
492 name: spec.name.clone(),
493 description: spec.description.clone(),
494 parameters: spec.parameters.clone(),
495 }
496 }
497}
498
/// Sampling configuration; omitted fields fall back to API defaults.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GeminiGenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<u32>,
}
507
/// Response body for both the unary call and each streaming chunk.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiResponse {
    // `default` tolerates chunks/responses that omit candidates entirely.
    #[serde(default)]
    candidates: Vec<GeminiCandidate>,
    usage_metadata: Option<GeminiUsageMetadata>,
}
517
/// One generated candidate: its content and an optional finish reason
/// (e.g. "STOP", "MAX_TOKENS").
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiCandidate {
    content: GeminiResponseContent,
    finish_reason: Option<String>,
}
524
/// Content of a candidate; `default` tolerates an omitted parts array.
#[derive(Deserialize)]
struct GeminiResponseContent {
    #[serde(default)]
    parts: Vec<GeminiResponsePart>,
}
530
/// A single part of a response candidate. `untagged` deserialization tries
/// variants in declaration order, so `FunctionCall` is listed first —
/// presumably so a part carrying `functionCall` is never mis-read as text;
/// NOTE(review): confirm whether a part can carry both fields at once.
#[derive(Deserialize)]
#[serde(untagged)]
enum GeminiResponsePart {
    FunctionCall {
        #[serde(rename = "functionCall")]
        function_call: GeminiResponseFunctionCall,
    },
    Text {
        text: String,
    },
}
542
/// Incoming function-call payload: name plus JSON arguments.
#[derive(Deserialize)]
struct GeminiResponseFunctionCall {
    name: String,
    args: serde_json::Value,
}
548
/// Token accounting reported by the API; `default` treats missing counters
/// as zero.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct GeminiUsageMetadata {
    #[serde(default)]
    prompt_token_count: u32,
    #[serde(default)]
    candidates_token_count: u32,
    #[serde(default)]
    cached_content_token_count: u32,
}
559
#[cfg(test)]
mod tests {
    use super::*;

    // --- builder / configuration ---

    #[test]
    fn test_gemini_new_default() {
        let model = Gemini::new("gemini-2.0-flash");
        assert_eq!(model.model_id, "gemini-2.0-flash");
        assert_eq!(model.base_url, DEFAULT_BASE_URL);
        assert_eq!(model.max_retries, DEFAULT_MAX_RETRIES);
        assert!(!model.use_bearer_token);
    }

    #[test]
    fn test_with_base_url() {
        let model = Gemini::new("gemini-pro").with_base_url("https://vertex.example.com");
        assert_eq!(model.base_url, "https://vertex.example.com");
    }

    #[test]
    fn test_with_timeout() {
        let model = Gemini::new("gemini-pro").with_timeout(Duration::from_secs(30));
        assert_eq!(model.timeout, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_with_max_retries() {
        let model = Gemini::new("gemini-pro").with_max_retries(5);
        assert_eq!(model.max_retries, 5);
    }

    #[test]
    fn test_with_bearer_token() {
        let model = Gemini::new("gemini-pro").with_bearer_token();
        assert!(model.use_bearer_token);
    }

    #[test]
    fn test_endpoint_url() {
        let model = Gemini::new("gemini-2.0-flash");
        assert_eq!(
            model.endpoint_url("generateContent"),
            "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent"
        );
    }

    // --- tool spec conversion ---

    #[test]
    fn test_tool_spec_conversion() {
        let spec = ToolSpec {
            name: "search".into(),
            description: "Web search".into(),
            parameters: serde_json::json!({"type": "object"}),
        };
        let decl: GeminiFunctionDeclaration = (&spec).into();
        assert_eq!(decl.name, "search");
        assert_eq!(decl.description, "Web search");
    }

    // --- response parsing ---

    #[test]
    fn test_parse_response_text() {
        let raw = GeminiResponse {
            candidates: vec![GeminiCandidate {
                content: GeminiResponseContent {
                    parts: vec![GeminiResponsePart::Text {
                        text: "Hello world".into(),
                    }],
                },
                finish_reason: Some("STOP".into()),
            }],
            usage_metadata: Some(GeminiUsageMetadata {
                prompt_token_count: 10,
                candidates_token_count: 5,
                cached_content_token_count: 0,
            }),
        };
        let resp = parse_response(raw).unwrap();
        assert_eq!(resp.text(), "Hello world");
        assert_eq!(resp.stop_reason, StopReason::EndTurn);
        assert!(!resp.has_tool_calls());
        assert_eq!(resp.usage.unwrap().input_tokens, 10);
    }

    #[test]
    fn test_parse_response_function_call() {
        let raw = GeminiResponse {
            candidates: vec![GeminiCandidate {
                content: GeminiResponseContent {
                    parts: vec![GeminiResponsePart::FunctionCall {
                        function_call: GeminiResponseFunctionCall {
                            name: "calculator".into(),
                            args: serde_json::json!({"expr": "2+2"}),
                        },
                    }],
                },
                finish_reason: Some("STOP".into()),
            }],
            usage_metadata: None,
        };
        let resp = parse_response(raw).unwrap();
        assert!(resp.has_tool_calls());
        assert_eq!(resp.tool_calls()[0].name, "calculator");
        // Tool calls override the "STOP" finish reason.
        assert_eq!(resp.stop_reason, StopReason::ToolUse);
    }

    #[test]
    fn test_parse_response_no_candidates() {
        let raw = GeminiResponse {
            candidates: vec![],
            usage_metadata: None,
        };
        assert!(parse_response(raw).is_err());
    }

    // --- request body construction ---

    #[test]
    fn test_build_request_with_system_prompt() {
        let model = Gemini::with_api_key("gemini-pro", "key");
        let request = ChatRequest {
            messages: vec![Message::system("Be helpful"), Message::user("Hello")],
            tools: vec![],
            temperature: Some(0.7),
            max_tokens: Some(1024),
        };
        let body = model.build_request_body(&request);
        // System message lands in system_instruction, not in contents.
        assert!(body.system_instruction.is_some());
        assert_eq!(body.contents.len(), 1);
        assert_eq!(
            body.generation_config.as_ref().unwrap().temperature,
            Some(0.7)
        );
    }

    #[test]
    fn test_build_request_with_tools() {
        let model = Gemini::with_api_key("gemini-pro", "key");
        let request = ChatRequest {
            messages: vec![Message::user("hi")],
            tools: vec![ToolSpec {
                name: "calc".into(),
                description: "Calculator".into(),
                parameters: serde_json::json!({"type": "object"}),
            }],
            temperature: None,
            max_tokens: None,
        };
        let body = model.build_request_body(&request);
        assert!(body.tools.is_some());
        assert_eq!(body.tools.unwrap()[0].function_declarations.len(), 1);
    }

    #[test]
    fn test_build_request_with_tool_results() {
        let model = Gemini::with_api_key("gemini-pro", "key");
        let request = ChatRequest {
            messages: vec![
                Message::user("calc 2+2"),
                Message::assistant_with_tool_calls(vec![ToolCall {
                    id: "gemini_calc".into(),
                    name: "calc".into(),
                    arguments: serde_json::json!({"expr": "2+2"}),
                }]),
                Message::tool_result("calc", "4"),
            ],
            tools: vec![],
            temperature: None,
            max_tokens: None,
        };
        let body = model.build_request_body(&request);
        // user + model(function call) + user(function response) turns.
        assert_eq!(body.contents.len(), 3);
    }

    #[test]
    fn test_builder_chain() {
        let model = Gemini::with_api_key("gemini-2.0-flash", "key")
            .with_base_url("https://custom.example.com")
            .with_timeout(Duration::from_secs(60))
            .with_max_retries(5)
            .with_bearer_token();

        assert_eq!(model.model_id, "gemini-2.0-flash");
        assert_eq!(model.base_url, "https://custom.example.com");
        assert_eq!(model.timeout, Some(Duration::from_secs(60)));
        assert_eq!(model.max_retries, 5);
        assert!(model.use_bearer_token);
    }
}