llm_api_rs/providers/gemini.rs

use crate::core::client::APIClient;
use crate::core::{ChatChoice, ChatCompletionRequest, ChatCompletionResponse, ChatMessage};
use crate::error::LlmApiError;
use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize, Default)]
struct GeminiChatCompletionRequest {
    contents: Vec<GeminiChatCompletionContent>,
    // The Gemini REST API uses camelCase JSON field names, so rename on serialization.
    #[serde(skip_serializing_if = "Option::is_none", rename = "generationConfig")]
    generation_config: Option<GeminiGenerationConfig>,
}

#[derive(Debug, Deserialize)]
struct GeminiChatCompletionResponse {
    candidates: Vec<GeminiCandidate>,
}

#[derive(Debug, Serialize, Deserialize)]
struct GeminiChatCompletionContent {
    role: String,
    parts: Vec<GeminiPart>,
}

#[derive(Debug, Serialize, Deserialize)]
struct GeminiPart {
    text: String,
}

#[derive(Debug, Serialize, Deserialize, Default)]
struct GeminiGenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none", rename = "maxOutputTokens")]
    max_output_tokens: Option<u32>,
}

#[derive(Debug, Deserialize)]
struct GeminiCandidate {
    content: GeminiChatCompletionContent,
    // Returned by the API as camelCase and may be absent, hence the Option.
    #[serde(rename = "finishReason")]
    finish_reason: Option<String>,
}

/// Gemini provider backed by Google's Generative Language `generateContent` endpoint.
pub struct Gemini {
    domain: String,
    api_key: String,
    client: APIClient,
}

impl Gemini {
    pub fn new(api_key: String) -> Self {
        Self {
            domain: "https://generativelanguage.googleapis.com".to_string(),
            api_key,
            client: APIClient::new(),
        }
    }
}

#[async_trait::async_trait]
impl super::LlmProvider for Gemini {
    async fn chat_completion<'a>(
        &'a self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, LlmApiError> {
        // Gemini authenticates via the API key in the query string rather than a header.
        let url = format!(
            "{}/v1beta/models/{}:generateContent?key={}",
            self.domain, request.model, self.api_key
        );

        // Translate the provider-agnostic request into Gemini's shape: each chat message
        // becomes one content entry holding a single text part, with the role string
        // passed through unchanged.
        let req = GeminiChatCompletionRequest {
            contents: request
                .messages
                .into_iter()
                .map(|msg| GeminiChatCompletionContent {
                    role: msg.role,
                    parts: vec![GeminiPart { text: msg.content }],
                })
                .collect(),
            generation_config: Some(GeminiGenerationConfig {
                temperature: request.temperature,
                max_output_tokens: request.max_tokens,
            }),
        };

        let res: GeminiChatCompletionResponse =
            self.client.send_request(url, vec![], &req).await?;

        // Gemini does not return a response id, so the first candidate's text stands in
        // for it. The indexing assumes at least one candidate with at least one part;
        // an empty response would panic here.
        Ok(ChatCompletionResponse {
            id: res.candidates[0].content.parts[0].text.clone(),
            choices: res
                .candidates
                .into_iter()
                .map(|candidate| ChatChoice {
                    message: ChatMessage {
                        role: candidate.content.role.clone(),
                        content: candidate.content.parts[0].text.clone(),
                    },
                    finish_reason: candidate.finish_reason.unwrap_or_default(),
                })
                .collect(),
            model: request.model,
            usage: None,
        })
    }
}
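
// Usage sketch (illustrative only): it assumes the crate is driven by an async runtime
// such as Tokio, that `GEMINI_API_KEY` is set in the environment, that the crate exposes
// these paths, that `LlmApiError` implements `std::error::Error`, and that
// `ChatCompletionRequest`/`ChatMessage` have no required fields beyond the ones
// referenced in this file.
//
//     use llm_api_rs::core::{ChatCompletionRequest, ChatMessage};
//     use llm_api_rs::providers::{Gemini, LlmProvider};
//
//     #[tokio::main]
//     async fn main() -> Result<(), Box<dyn std::error::Error>> {
//         let gemini = Gemini::new(std::env::var("GEMINI_API_KEY")?);
//         let request = ChatCompletionRequest {
//             model: "gemini-1.5-flash".to_string(),
//             messages: vec![ChatMessage {
//                 role: "user".to_string(),
//                 content: "Say hello in one short sentence.".to_string(),
//             }],
//             temperature: Some(0.7),
//             max_tokens: Some(256),
//         };
//         let response = gemini.chat_completion(request).await?;
//         println!("{}", response.choices[0].message.content);
//         Ok(())
//     }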