// llm_api_rs/providers/gemini.rs
//
// Gemini API provider
// API reference: https://ai.google.dev/api/generate-content?hl=en
// API keys: https://aistudio.google.com/app/apikey

use serde::{Deserialize, Serialize};

use crate::core::client::APIClient;
use crate::core::{ChatChoice, ChatCompletionRequest, ChatCompletionResponse, ChatMessage};
use crate::error::LlmApiError;

10#[derive(Debug, Serialize, Deserialize, Default)]
11struct GeminiChatCompletionRequest {
12    contents: Vec<GeminiChatCompletionContent>,
13    #[serde(skip_serializing_if = "Option::is_none")]
14    generation_config: Option<GeminiGenerationConfig>,
15}
16
17#[derive(Debug, Deserialize)]
18struct GeminiChatCompletionResponse {
19    candidates: Vec<GeminiCandidate>,
20}
21
22#[derive(Debug, Serialize, Deserialize)]
23struct GeminiChatCompletionContent {
24    role: String,
25    parts: Vec<GeminiPart>,
26}
27
28#[derive(Debug, Serialize, Deserialize)]
29struct GeminiPart {
30    text: String,
31}
32
33#[derive(Debug, Serialize, Deserialize, Default)]
34struct GeminiGenerationConfig {
35    #[serde(skip_serializing_if = "Option::is_none")]
36    temperature: Option<f32>,
37    #[serde(skip_serializing_if = "Option::is_none", rename = "maxOutputTokens")]
38    max_output_tokens: Option<u32>,
39}
40
41#[derive(Debug, Deserialize)]
42struct GeminiCandidate {
43    content: GeminiChatCompletionContent,
44    #[serde(skip_serializing_if = "Option::is_none")]
45    finish_reason: Option<String>,
46}
47
48pub struct Gemini {
49    domain: String,
50    api_key: String,
51    client: APIClient,
52}
53
54impl Gemini {
55    pub fn new(api_key: String) -> Self {
56        Self {
57            domain: "https://generativelanguage.googleapis.com".to_string(),
58            api_key,
59            client: APIClient::new(),
60        }
61    }
62}
63
64#[async_trait::async_trait]
65impl super::LlmProvider for Gemini {
66    async fn chat_completion<'a>(
67        &'a self,
68        request: ChatCompletionRequest,
69    ) -> Result<ChatCompletionResponse, LlmApiError> {
70        let url = format!(
71            "{}/v1beta/models/{}:generateContent?key={}",
72            self.domain, request.model, self.api_key
73        );
74
75        let req = GeminiChatCompletionRequest {
76            contents: request
77                .messages
78                .into_iter()
79                .map(|msg| GeminiChatCompletionContent {
80                    role: msg.role,
81                    parts: vec![GeminiPart { text: msg.content }],
82                })
83                .collect(),
84            generation_config: Some(GeminiGenerationConfig {
85                temperature: request.temperature,
86                max_output_tokens: request.max_tokens,
87            }),
88        };
89
90        let res: GeminiChatCompletionResponse = self.client.send_request(url, vec![], &req).await?;
91
92        Ok(ChatCompletionResponse {
93            id: res.candidates[0].content.parts[0].text.clone(),
94            choices: res
95                .candidates
96                .into_iter()
97                .map(|candidate| ChatChoice {
98                    message: ChatMessage {
99                        role: candidate.content.role.clone(),
100                        content: candidate.content.parts[0].text.clone(),
101                    },
102                    finish_reason: candidate.finish_reason.unwrap_or_default(),
103                })
104                .collect(),
105            model: request.model,
106            usage: None,
107        })
108    }
109}