langhub_core/llms/
google.rs1use crate::types::*;
3
4use super::{ChatMessage, LLM, LLMOptions, LLMResult};
5use serde_json::json;
6use std::future::Future;
7use std::pin::Pin;
8
/// Gemini model variants this client can target.
///
/// Each variant maps to the model identifier placed in the request path
/// `models/{id}:generateContent`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GoogleModel {
    /// `gemini-1.5-pro`
    Gemini15Pro,
    /// `gemini-1.5-flash`
    Gemini15Flash,
    /// `gemini-1.5-pro-vision`
    Gemini15ProVision,
    /// `gemini-pro`
    GeminiPro,
    /// `gemini-pro-vision`
    GeminiProVision,
    /// `gemini-ultra`
    GeminiUltra,
}
18
19impl GoogleModel {
20 fn as_str(&self) -> &'static str {
21 match self {
22 GoogleModel::Gemini15Pro => "gemini-1.5-pro",
23 GoogleModel::Gemini15Flash => "gemini-1.5-flash",
24 GoogleModel::Gemini15ProVision => "gemini-1.5-pro-vision",
25 GoogleModel::GeminiPro => "gemini-pro",
26 GoogleModel::GeminiProVision => "gemini-pro-vision",
27 GoogleModel::GeminiUltra => "gemini-ultra",
28 }
29 }
30}
31
32impl From<GoogleModel> for String {
33 fn from(model: GoogleModel) -> Self {
34 model.as_str().to_string()
35 }
36}
37
38pub struct GoogleAI {
39 api_key: String,
40 model: GoogleModel,
41 base_url: String,
42 client: reqwest::Client,
43 default_options: LLMOptions,
44}
45
46impl GoogleAI {
47 pub fn new(api_key: String) -> Self {
48 Self {
49 api_key,
50 model: GoogleModel::Gemini15Pro,
51 base_url: "https://generativelanguage.googleapis.com/v1beta".to_string(),
52 client: reqwest::Client::new(),
53 default_options: LLMOptions::default(),
54 }
55 }
56
57 pub fn with_model(mut self, model: GoogleModel) -> Self {
58 self.model = model;
59 self
60 }
61
62 pub fn gemini15_pro(self) -> Self {
63 self.with_model(GoogleModel::Gemini15Pro)
64 }
65
66 pub fn gemini15_flash(self) -> Self {
67 self.with_model(GoogleModel::Gemini15Flash)
68 }
69
70 pub fn gemini_pro(self) -> Self {
71 self.with_model(GoogleModel::GeminiPro)
72 }
73
74 pub fn with_temperature(mut self, temperature: f32) -> Self {
75 self.default_options.temperature = Some(temperature);
76 self
77 }
78
79 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
80 self.default_options.max_tokens = Some(max_tokens);
81 self
82 }
83
84 pub fn with_top_p(mut self, top_p: f32) -> Self {
85 self.default_options.top_p = Some(top_p);
86 self
87 }
88
89 pub fn with_top_k(mut self, top_k: u32) -> Self {
90 self.default_options.top_k = Some(top_k);
91 self
92 }
93
94 pub fn with_base_url(mut self, base_url: &str) -> Self {
95 self.base_url = base_url.to_string();
96 self
97 }
98
99 async fn chat_completion(
100 &self,
101 messages: &[ChatMessage],
102 options: &LLMOptions,
103 ) -> Result<String> {
104 let model_name: String = self.model.clone().into();
105
106 let contents: Vec<serde_json::Value> = messages
107 .iter()
108 .map(|m| {
109 json!({
110 "parts": [{"text": m.content}],
111 "role": if m.role == "user" { "user" } else { "model" },
112 })
113 })
114 .collect();
115
116 let mut generation_config = json!({});
117
118 if let Some(temp) = options.temperature.or(self.default_options.temperature) {
119 generation_config["temperature"] = json!(temp);
120 }
121 if let Some(max_tokens) = options.max_tokens.or(self.default_options.max_tokens) {
122 generation_config["maxOutputTokens"] = json!(max_tokens);
123 }
124 if let Some(top_p) = options.top_p.or(self.default_options.top_p) {
125 generation_config["topP"] = json!(top_p);
126 }
127 if let Some(top_k) = options.top_k.or(self.default_options.top_k) {
128 generation_config["topK"] = json!(top_k);
129 }
130
131 let request_body = json!({
132 "contents": contents,
133 "generationConfig": generation_config,
134 });
135
136 let url = format!(
137 "{}/models/{}:generateContent?key={}",
138 self.base_url, model_name, self.api_key
139 );
140
141 let response = self
142 .client
143 .post(&url)
144 .header("Content-Type", "application/json")
145 .json(&request_body)
146 .send()
147 .await
148 .map_err(|e| LangHubError::LLMError(format!("Google request error: {}", e)))?;
149
150 if !response.status().is_success() {
151 let status = response.status();
152 let error_text = response.text().await.unwrap_or_default();
153 return Err(LangHubError::LLMError(format!(
154 "Google API error ({}): {}",
155 status, error_text
156 )));
157 }
158
159 let json: serde_json::Value = response
160 .json()
161 .await
162 .map_err(|e| LangHubError::LLMError(format!("JSON parse error: {}", e)))?;
163
164 let text = json["candidates"][0]["content"]["parts"][0]["text"]
165 .as_str()
166 .ok_or_else(|| {
167 LangHubError::ParseError("Missing 'text' field in response".to_string())
168 })?
169 .to_string();
170
171 Ok(text)
172 }
173}
174
175impl LLM for GoogleAI {
176 fn generate(
177 &self,
178 prompt: &str,
179 ) -> Pin<Box<dyn Future<Output = Result<LLMResult>> + Send + '_>> {
180 let prompt = prompt.to_string();
181 let options = self.default_options.clone();
182 Box::pin(async move {
183 let messages = vec![ChatMessage::user(&prompt)];
184 let text = self.chat_completion(&messages, &options).await?;
185 Ok(LLMResult {
186 text,
187 metadata: None,
188 })
189 })
190 }
191
192 fn generate_with_options(
193 &self,
194 prompt: &str,
195 options: LLMOptions,
196 ) -> Pin<Box<dyn Future<Output = Result<LLMResult>> + Send + '_>> {
197 let prompt = prompt.to_string();
198 Box::pin(async move {
199 let messages = vec![ChatMessage::user(&prompt)];
200 let text = self.chat_completion(&messages, &options).await?;
201 Ok(LLMResult {
202 text,
203 metadata: None,
204 })
205 })
206 }
207
208 fn chat(
209 &self,
210 messages: Vec<ChatMessage>,
211 ) -> Pin<Box<dyn Future<Output = Result<LLMResult>> + Send + '_>> {
212 Box::pin(async move {
213 let text = self
214 .chat_completion(&messages, &LLMOptions::default())
215 .await?;
216 Ok(LLMResult {
217 text,
218 metadata: None,
219 })
220 })
221 }
222
223 fn get_model_name(&self) -> &str {
224 match self.model {
225 GoogleModel::Gemini15Pro => "gemini-1.5-pro",
226 GoogleModel::Gemini15Flash => "gemini-1.5-flash",
227 GoogleModel::Gemini15ProVision => "gemini-1.5-pro-vision",
228 GoogleModel::GeminiPro => "gemini-pro",
229 GoogleModel::GeminiProVision => "gemini-pro-vision",
230 GoogleModel::GeminiUltra => "gemini-ultra",
231 }
232 }
233
234 fn get_provider_name(&self) -> &str {
235 match self.model {
236 GoogleModel::Gemini15Pro => "Google-Gemini1.5-Pro",
237 GoogleModel::Gemini15Flash => "Google-Gemini1.5-Flash",
238 GoogleModel::Gemini15ProVision => "Google-Gemini1.5-Pro-Vision",
239 GoogleModel::GeminiPro => "Google-Gemini-Pro",
240 GoogleModel::GeminiProVision => "Google-Gemini-Pro-Vision",
241 GoogleModel::GeminiUltra => "Google-Gemini-Ultra",
242 }
243 }
244
245 fn get_provider_enum(&self) -> ModelProvider {
246 ModelProvider::Google
247 }
248
249 fn supports_function_calling(&self) -> bool {
250 true
251 }
252
253 fn supports_json_mode(&self) -> bool {
254 true
255 }
256
257 fn max_context_length(&self) -> Option<usize> {
258 match self.model {
259 GoogleModel::Gemini15Pro => Some(2_000_000),
260 GoogleModel::Gemini15Flash => Some(1_000_000),
261 GoogleModel::GeminiPro => Some(32768),
262 GoogleModel::GeminiUltra => Some(32768),
263 _ => Some(32768),
264 }
265 }
266}