//! Google (Gemini) LLM provider (`langhub_core/llms/google.rs`).
2use crate::types::*;
3
4use super::{ChatMessage, LLM, LLMOptions, LLMResult};
5use serde_json::json;
6use std::future::Future;
7use std::pin::Pin;
8
/// Gemini models selectable for [`GoogleAI`].
///
/// Each variant maps to the model id placed in the request path; see
/// [`GoogleModel::as_str`].
// NOTE(review): `gemini-1.5-pro-vision` and `gemini-ultra` do not appear to be ids
// served by the public Generative Language API — confirm before relying on them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GoogleModel {
    /// Gemini 1.5 Pro (`gemini-1.5-pro`).
    Gemini15Pro,
    /// Gemini 1.5 Flash (`gemini-1.5-flash`).
    Gemini15Flash,
    /// Gemini 1.5 Pro Vision (`gemini-1.5-pro-vision`).
    Gemini15ProVision,
    /// Gemini Pro (`gemini-pro`).
    GeminiPro,
    /// Gemini Pro Vision (`gemini-pro-vision`).
    GeminiProVision,
    /// Gemini Ultra (`gemini-ultra`).
    GeminiUltra,
}

impl GoogleModel {
    /// Returns the model id used in API request paths, e.g. `"gemini-1.5-pro"`.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            GoogleModel::Gemini15Pro => "gemini-1.5-pro",
            GoogleModel::Gemini15Flash => "gemini-1.5-flash",
            GoogleModel::Gemini15ProVision => "gemini-1.5-pro-vision",
            GoogleModel::GeminiPro => "gemini-pro",
            GoogleModel::GeminiProVision => "gemini-pro-vision",
            GoogleModel::GeminiUltra => "gemini-ultra",
        }
    }
}

/// Converts a model into its owned API model id (same text as [`GoogleModel::as_str`]).
impl From<GoogleModel> for String {
    fn from(model: GoogleModel) -> Self {
        model.as_str().to_string()
    }
}
37
38pub struct GoogleAI {
39    api_key: String,
40    model: GoogleModel,
41    base_url: String,
42    client: reqwest::Client,
43    default_options: LLMOptions,
44}
45
46impl GoogleAI {
47    pub fn new(api_key: String) -> Self {
48        Self {
49            api_key,
50            model: GoogleModel::Gemini15Pro,
51            base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
52            client: reqwest::Client::new(),
53            default_options: LLMOptions::default(),
54        }
55    }
56
57    pub fn with_model(mut self, model: GoogleModel) -> Self {
58        self.model = model;
59        self
60    }
61
62    pub fn gemini15_pro(self) -> Self {
63        self.with_model(GoogleModel::Gemini15Pro)
64    }
65
66    pub fn gemini15_flash(self) -> Self {
67        self.with_model(GoogleModel::Gemini15Flash)
68    }
69
70    pub fn gemini_pro(self) -> Self {
71        self.with_model(GoogleModel::GeminiPro)
72    }
73
74    pub fn with_temperature(mut self, temperature: f32) -> Self {
75        self.default_options.temperature = Some(temperature);
76        self
77    }
78
79    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
80        self.default_options.max_tokens = Some(max_tokens);
81        self
82    }
83
84    pub fn with_top_p(mut self, top_p: f32) -> Self {
85        self.default_options.top_p = Some(top_p);
86        self
87    }
88
89    pub fn with_top_k(mut self, top_k: u32) -> Self {
90        self.default_options.top_k = Some(top_k);
91        self
92    }
93
94    pub fn with_base_url(mut self, base_url: &str) -> Self {
95        self.base_url = base_url.to_string();
96        self
97    }
98
99    async fn chat_completion(
100        &self,
101        messages: &[ChatMessage],
102        options: &LLMOptions,
103    ) -> Result<String> {
104        let model_name: String = self.model.clone().into();
105
106        let contents: Vec<serde_json::Value> = messages
107            .iter()
108            .map(|m| {
109                json!({
110                    "parts": [{"text": m.content}],
111                    "role": if m.role == "user" { "user" } else { "model" },
112                })
113            })
114            .collect();
115
116        let mut generation_config = json!({});
117
118        if let Some(temp) = options.temperature.or(self.default_options.temperature) {
119            generation_config["temperature"] = json!(temp);
120        }
121        if let Some(max_tokens) = options.max_tokens.or(self.default_options.max_tokens) {
122            generation_config["maxOutputTokens"] = json!(max_tokens);
123        }
124        if let Some(top_p) = options.top_p.or(self.default_options.top_p) {
125            generation_config["topP"] = json!(top_p);
126        }
127        if let Some(top_k) = options.top_k.or(self.default_options.top_k) {
128            generation_config["topK"] = json!(top_k);
129        }
130
131        let request_body = json!({
132            "contents": contents,
133            "generationConfig": generation_config,
134        });
135
136        let url = format!(
137            "{}/models/{}:generateContent?key={}",
138            self.base_url, model_name, self.api_key
139        );
140
141        let response = self
142            .client
143            .post(&url)
144            .header("Content-Type", "application/json")
145            .json(&request_body)
146            .send()
147            .await
148            .map_err(|e| LangHubError::LLMError(format!("Google request error: {}", e)))?;
149
150        if !response.status().is_success() {
151            let status = response.status();
152            let error_text = response.text().await.unwrap_or_default();
153            return Err(LangHubError::LLMError(format!(
154                "Google API error ({}): {}",
155                status, error_text
156            )));
157        }
158
159        let json: serde_json::Value = response
160            .json()
161            .await
162            .map_err(|e| LangHubError::LLMError(format!("JSON parse error: {}", e)))?;
163
164        let text = json["candidates"][0]["content"]["parts"][0]["text"]
165            .as_str()
166            .ok_or_else(|| {
167                LangHubError::ParseError("Missing 'text' field in response".to_string())
168            })?
169            .to_string();
170
171        Ok(text)
172    }
173}
174
175impl LLM for GoogleAI {
176    fn generate(
177        &self,
178        prompt: &str,
179    ) -> Pin<Box<dyn Future<Output = Result<LLMResult>> + Send + '_>> {
180        let prompt = prompt.to_string();
181        let options = self.default_options.clone();
182        Box::pin(async move {
183            let messages = vec![ChatMessage::user(&prompt)];
184            let text = self.chat_completion(&messages, &options).await?;
185            Ok(LLMResult {
186                text,
187                metadata: None,
188            })
189        })
190    }
191
192    fn generate_with_options(
193        &self,
194        prompt: &str,
195        options: LLMOptions,
196    ) -> Pin<Box<dyn Future<Output = Result<LLMResult>> + Send + '_>> {
197        let prompt = prompt.to_string();
198        Box::pin(async move {
199            let messages = vec![ChatMessage::user(&prompt)];
200            let text = self.chat_completion(&messages, &options).await?;
201            Ok(LLMResult {
202                text,
203                metadata: None,
204            })
205        })
206    }
207
208    fn chat(
209        &self,
210        messages: Vec<ChatMessage>,
211    ) -> Pin<Box<dyn Future<Output = Result<LLMResult>> + Send + '_>> {
212        Box::pin(async move {
213            let text = self
214                .chat_completion(&messages, &LLMOptions::default())
215                .await?;
216            Ok(LLMResult {
217                text,
218                metadata: None,
219            })
220        })
221    }
222
223    fn get_model_name(&self) -> &str {
224        match self.model {
225            GoogleModel::Gemini15Pro => "gemini-1.5-pro",
226            GoogleModel::Gemini15Flash => "gemini-1.5-flash",
227            GoogleModel::Gemini15ProVision => "gemini-1.5-pro-vision",
228            GoogleModel::GeminiPro => "gemini-pro",
229            GoogleModel::GeminiProVision => "gemini-pro-vision",
230            GoogleModel::GeminiUltra => "gemini-ultra",
231        }
232    }
233
234    fn get_provider_name(&self) -> &str {
235        match self.model {
236            GoogleModel::Gemini15Pro => "Google-Gemini1.5-Pro",
237            GoogleModel::Gemini15Flash => "Google-Gemini1.5-Flash",
238            GoogleModel::Gemini15ProVision => "Google-Gemini1.5-Pro-Vision",
239            GoogleModel::GeminiPro => "Google-Gemini-Pro",
240            GoogleModel::GeminiProVision => "Google-Gemini-Pro-Vision",
241            GoogleModel::GeminiUltra => "Google-Gemini-Ultra",
242        }
243    }
244
245    fn get_provider_enum(&self) -> ModelProvider {
246        ModelProvider::Google
247    }
248
249    fn supports_function_calling(&self) -> bool {
250        true
251    }
252
253    fn supports_json_mode(&self) -> bool {
254        true
255    }
256
257    fn max_context_length(&self) -> Option<usize> {
258        match self.model {
259            GoogleModel::Gemini15Pro => Some(2_000_000),
260            GoogleModel::Gemini15Flash => Some(1_000_000),
261            GoogleModel::GeminiPro => Some(32768),
262            GoogleModel::GeminiUltra => Some(32768),
263            _ => Some(32768),
264        }
265    }
266}