1use async_stream::try_stream;
2use async_trait::async_trait;
3use forgeai_core::{
4 AdapterInfo, CapabilityMatrix, ChatAdapter, ChatRequest, ChatResponse, ForgeError, Role,
5 StreamEvent, StreamResult, ToolCall, Usage,
6};
7use futures_util::StreamExt;
8use reqwest::{Client as HttpClient, StatusCode};
9use serde_json::{json, Map, Value};
10use std::env;
11use url::Url;
12
13#[derive(Clone, Debug)]
14pub struct GeminiAdapter {
15 pub api_key: String,
16 pub base_url: Url,
17 pub api_version: String,
18 client: HttpClient,
19}
20
21impl GeminiAdapter {
22 pub fn new(api_key: impl Into<String>) -> Result<Self, ForgeError> {
23 let base_url = Url::parse("https://generativelanguage.googleapis.com")
24 .map_err(|e| ForgeError::Internal(e.to_string()))?;
25 Self::with_base_url(api_key, base_url)
26 }
27
28 pub fn with_base_url(api_key: impl Into<String>, base_url: Url) -> Result<Self, ForgeError> {
29 let client = HttpClient::builder()
30 .build()
31 .map_err(|e| ForgeError::Internal(format!("failed to build http client: {e}")))?;
32 Ok(Self {
33 api_key: api_key.into(),
34 base_url,
35 api_version: "v1beta".to_string(),
36 client,
37 })
38 }
39
40 pub fn from_env() -> Result<Self, ForgeError> {
41 let api_key = env::var("GEMINI_API_KEY").map_err(|_| ForgeError::Authentication)?;
42 match env::var("GEMINI_BASE_URL") {
43 Ok(raw) => {
44 let base_url = Url::parse(&raw)
45 .map_err(|e| ForgeError::Validation(format!("invalid GEMINI_BASE_URL: {e}")))?;
46 Self::with_base_url(api_key, base_url)
47 }
48 Err(_) => Self::new(api_key),
49 }
50 }
51
52 fn endpoint_url(&self, model: &str, stream: bool) -> Result<Url, ForgeError> {
53 let action = if stream {
54 "streamGenerateContent"
55 } else {
56 "generateContent"
57 };
58 let mut url = self
59 .base_url
60 .join(&format!(
61 "{}/models/{}:{}",
62 self.api_version,
63 model.trim(),
64 action
65 ))
66 .map_err(|e| ForgeError::Internal(format!("failed to construct endpoint url: {e}")))?;
67 {
68 let mut qp = url.query_pairs_mut();
69 qp.append_pair("key", &self.api_key);
70 if stream {
71 qp.append_pair("alt", "sse");
72 }
73 }
74 Ok(url)
75 }
76}
77
78#[async_trait]
79impl ChatAdapter for GeminiAdapter {
80 fn info(&self) -> AdapterInfo {
81 AdapterInfo {
82 name: "gemini".to_string(),
83 base_url: Url::parse("https://generativelanguage.googleapis.com").ok(),
84 capabilities: CapabilityMatrix {
85 streaming: true,
86 tools: true,
87 structured_output: true,
88 multimodal_input: true,
89 citations: true,
90 },
91 }
92 }
93
94 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, ForgeError> {
95 let url = self.endpoint_url(&request.model, false)?;
96 let model = request.model.clone();
97 let response = self
98 .client
99 .post(url)
100 .json(&build_generate_body(request))
101 .send()
102 .await
103 .map_err(|e| ForgeError::Transport(format!("request failed: {e}")))?;
104
105 if !response.status().is_success() {
106 let status = response.status();
107 let text = response
108 .text()
109 .await
110 .unwrap_or_else(|_| "failed to read error body".to_string());
111 return Err(parse_http_error(status, text));
112 }
113
114 let payload = response
115 .json::<Value>()
116 .await
117 .map_err(|e| ForgeError::Provider(format!("invalid json response: {e}")))?;
118 parse_chat_response(model, payload)
119 }
120
121 async fn chat_stream(
122 &self,
123 request: ChatRequest,
124 ) -> Result<StreamResult<StreamEvent>, ForgeError> {
125 let url = self.endpoint_url(&request.model, true)?;
126 let response = self
127 .client
128 .post(url)
129 .json(&build_generate_body(request))
130 .send()
131 .await
132 .map_err(|e| ForgeError::Transport(format!("stream request failed: {e}")))?;
133
134 if !response.status().is_success() {
135 let status = response.status();
136 let text = response
137 .text()
138 .await
139 .unwrap_or_else(|_| "failed to read error body".to_string());
140 return Err(parse_http_error(status, text));
141 }
142
143 let mut bytes = response.bytes_stream();
144 let stream = try_stream! {
145 let mut buffer = String::new();
146 let mut saw_done = false;
147
148 while let Some(chunk) = bytes.next().await {
149 let chunk = chunk.map_err(|e| ForgeError::Transport(format!("stream chunk error: {e}")))?;
150 let chunk_text = std::str::from_utf8(&chunk)
151 .map_err(|e| ForgeError::Transport(format!("invalid utf8 stream chunk: {e}")))?;
152 buffer.push_str(chunk_text);
153
154 while let Some(line_end) = buffer.find('\n') {
155 let mut line = buffer[..line_end].to_string();
156 buffer.drain(..=line_end);
157 if line.ends_with('\r') {
158 line.pop();
159 }
160 if line.trim().is_empty() {
161 continue;
162 }
163 if let Some(data) = line.strip_prefix("data:") {
164 let payload = data.trim();
165 if payload == "[DONE]" {
166 saw_done = true;
167 yield StreamEvent::Done;
168 continue;
169 }
170 for event in parse_stream_payload(payload)? {
171 if matches!(event, StreamEvent::Done) {
172 saw_done = true;
173 }
174 yield event;
175 }
176 }
177 }
178 }
179
180 if !buffer.trim().is_empty() {
181 let line = buffer.trim();
182 if let Some(data) = line.strip_prefix("data:") {
183 let payload = data.trim();
184 if payload == "[DONE]" {
185 saw_done = true;
186 yield StreamEvent::Done;
187 } else {
188 for event in parse_stream_payload(payload)? {
189 if matches!(event, StreamEvent::Done) {
190 saw_done = true;
191 }
192 yield event;
193 }
194 }
195 }
196 }
197
198 if !saw_done {
199 yield StreamEvent::Done;
200 }
201 };
202
203 Ok(Box::pin(stream))
204 }
205}
206
207fn build_generate_body(request: ChatRequest) -> Value {
208 let mut body = Map::new();
209 if let Some(temperature) = request.temperature {
210 body.insert(
211 "generationConfig".to_string(),
212 json!({
213 "temperature": temperature,
214 "maxOutputTokens": request.max_tokens
215 }),
216 );
217 } else if let Some(max_tokens) = request.max_tokens {
218 body.insert(
219 "generationConfig".to_string(),
220 json!({
221 "maxOutputTokens": max_tokens
222 }),
223 );
224 }
225
226 let mut contents = Vec::new();
227 let mut system_chunks = Vec::new();
228 for message in request.messages {
229 if matches!(message.role, Role::System) {
230 system_chunks.push(message.content);
231 continue;
232 }
233 let role = if matches!(message.role, Role::Assistant) {
234 "model"
235 } else {
236 "user"
237 };
238 contents.push(json!({
239 "role": role,
240 "parts": [{ "text": message.content }]
241 }));
242 }
243 body.insert("contents".to_string(), Value::Array(contents));
244
245 if !system_chunks.is_empty() {
246 body.insert(
247 "systemInstruction".to_string(),
248 json!({
249 "parts": [{
250 "text": system_chunks.join("\n\n")
251 }]
252 }),
253 );
254 }
255
256 if !request.tools.is_empty() {
257 body.insert(
258 "tools".to_string(),
259 Value::Array(
260 request
261 .tools
262 .into_iter()
263 .map(|tool| {
264 json!({
265 "functionDeclarations": [{
266 "name": tool.name,
267 "description": tool.description,
268 "parameters": tool.input_schema
269 }]
270 })
271 })
272 .collect(),
273 ),
274 );
275 }
276
277 Value::Object(body)
278}
279
280fn parse_http_error(status: StatusCode, body: String) -> ForgeError {
281 let message = extract_provider_error(body);
282 match status {
283 StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => ForgeError::Authentication,
284 StatusCode::TOO_MANY_REQUESTS => ForgeError::RateLimited,
285 _ => ForgeError::Provider(message),
286 }
287}
288
289fn extract_provider_error(body: String) -> String {
290 serde_json::from_str::<Value>(&body)
291 .ok()
292 .and_then(|v| {
293 v.get("error")
294 .and_then(|e| e.get("message"))
295 .and_then(Value::as_str)
296 .map(ToString::to_string)
297 })
298 .unwrap_or(body)
299}
300
301fn parse_chat_response(model: String, payload: Value) -> Result<ChatResponse, ForgeError> {
302 let output_text = extract_text_from_payload(&payload);
303 let tool_calls = extract_tool_calls_from_payload(&payload);
304 let usage = extract_usage(payload.get("usageMetadata"));
305
306 Ok(ChatResponse {
307 id: payload
308 .get("responseId")
309 .and_then(Value::as_str)
310 .unwrap_or_default()
311 .to_string(),
312 model,
313 output_text,
314 tool_calls,
315 usage,
316 })
317}
318
319fn extract_text_from_payload(payload: &Value) -> String {
320 payload
321 .get("candidates")
322 .and_then(Value::as_array)
323 .map(|candidates| {
324 candidates
325 .iter()
326 .flat_map(|candidate| {
327 candidate
328 .get("content")
329 .and_then(|c| c.get("parts"))
330 .and_then(Value::as_array)
331 .cloned()
332 .unwrap_or_default()
333 })
334 .filter_map(|part| {
335 part.get("text")
336 .and_then(Value::as_str)
337 .map(ToString::to_string)
338 })
339 .collect::<Vec<_>>()
340 .join("")
341 })
342 .unwrap_or_default()
343}
344
345fn extract_tool_calls_from_payload(payload: &Value) -> Vec<ToolCall> {
346 payload
347 .get("candidates")
348 .and_then(Value::as_array)
349 .map(|candidates| {
350 candidates
351 .iter()
352 .flat_map(|candidate| {
353 candidate
354 .get("content")
355 .and_then(|c| c.get("parts"))
356 .and_then(Value::as_array)
357 .cloned()
358 .unwrap_or_default()
359 })
360 .filter_map(|part| {
361 let function_call = part.get("functionCall")?;
362 Some(ToolCall {
363 id: function_call
364 .get("id")
365 .and_then(Value::as_str)
366 .unwrap_or_default()
367 .to_string(),
368 name: function_call
369 .get("name")
370 .and_then(Value::as_str)
371 .unwrap_or_default()
372 .to_string(),
373 arguments: function_call.get("args").cloned().unwrap_or(Value::Null),
374 })
375 })
376 .collect()
377 })
378 .unwrap_or_default()
379}
380
381fn extract_usage(raw: Option<&Value>) -> Option<Usage> {
382 let usage = raw?;
383 let input_tokens = usage
384 .get("promptTokenCount")
385 .and_then(Value::as_u64)
386 .unwrap_or(0) as u32;
387 let output_tokens = usage
388 .get("candidatesTokenCount")
389 .and_then(Value::as_u64)
390 .unwrap_or(0) as u32;
391 let total_tokens = usage
392 .get("totalTokenCount")
393 .and_then(Value::as_u64)
394 .map(|v| v as u32)
395 .unwrap_or_else(|| input_tokens.saturating_add(output_tokens));
396 Some(Usage {
397 input_tokens,
398 output_tokens,
399 total_tokens,
400 })
401}
402
403fn parse_stream_payload(payload: &str) -> Result<Vec<StreamEvent>, ForgeError> {
404 let value = serde_json::from_str::<Value>(payload)
405 .map_err(|e| ForgeError::Provider(format!("invalid stream payload: {e}")))?;
406
407 let mut events = Vec::new();
408 let text = extract_text_from_payload(&value);
409 if !text.is_empty() {
410 events.push(StreamEvent::TextDelta { delta: text });
411 }
412
413 for tool_call in extract_tool_calls_from_payload(&value) {
414 events.push(StreamEvent::ToolCallDelta {
415 call_id: tool_call.id,
416 delta: json!({
417 "name": tool_call.name,
418 "arguments": tool_call.arguments
419 }),
420 });
421 }
422
423 if let Some(usage) = extract_usage(value.get("usageMetadata")) {
424 events.push(StreamEvent::Usage { usage });
425 }
426
427 if value
428 .get("candidates")
429 .and_then(Value::as_array)
430 .and_then(|items| items.first())
431 .and_then(|c| c.get("finishReason"))
432 .is_some()
433 {
434 events.push(StreamEvent::Done);
435 }
436
437 Ok(events)
438}
439
#[cfg(test)]
mod tests {
    use super::*;
    use forgeai_core::{ChatRequest, Message, Role};
    use futures_util::StreamExt;
    use wiremock::matchers::{body_partial_json, method, path, query_param};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// A minimal single-message request shared by the contract tests.
    fn sample_request() -> ChatRequest {
        let user_message = Message {
            role: Role::User,
            content: "Say hello".to_string(),
        };
        ChatRequest {
            model: "gemini-1.5-flash".to_string(),
            messages: vec![user_message],
            temperature: Some(0.2),
            max_tokens: Some(64),
            tools: vec![],
            metadata: json!({}),
        }
    }

    /// Builds an adapter with key `test-key` that targets the given mock server.
    fn adapter_for(server: &MockServer) -> GeminiAdapter {
        let base_url = Url::parse(&server.uri()).unwrap();
        GeminiAdapter::with_base_url("test-key", base_url).unwrap()
    }

    /// The non-streaming call hits the expected endpoint and maps id, model,
    /// text and usage from the reply.
    #[tokio::test]
    async fn chat_contract_parses_response_and_usage() {
        let server = MockServer::start().await;
        let reply = json!({
            "responseId": "resp_123",
            "candidates": [{
                "content": {
                    "parts": [{"text":"Hello from Gemini"}]
                }
            }],
            "usageMetadata": {
                "promptTokenCount": 9,
                "candidatesTokenCount": 4,
                "totalTokenCount": 13
            }
        });
        Mock::given(method("POST"))
            .and(path("/v1beta/models/gemini-1.5-flash:generateContent"))
            .and(query_param("key", "test-key"))
            .and(body_partial_json(json!({"contents": [{"role":"user"}]})))
            .respond_with(ResponseTemplate::new(200).set_body_json(reply))
            .mount(&server)
            .await;

        let response = adapter_for(&server)
            .chat(sample_request())
            .await
            .unwrap();

        assert_eq!(response.id, "resp_123");
        assert_eq!(response.model, "gemini-1.5-flash");
        assert_eq!(response.output_text, "Hello from Gemini");
        assert_eq!(response.usage.unwrap().total_tokens, 13);
    }

    /// The streaming call hits the SSE endpoint and yields both text deltas,
    /// the usage event and a terminal `Done`.
    #[tokio::test]
    async fn chat_stream_contract_parses_sse_events() {
        let server = MockServer::start().await;
        let sse_body = concat!(
            "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Hello\"}]}}]}\n\n",
            "data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\" world\"}]}}]}\n\n",
            "data: {\"usageMetadata\":{\"promptTokenCount\":9,\"candidatesTokenCount\":2,\"totalTokenCount\":11},\"candidates\":[{\"finishReason\":\"STOP\"}]}\n\n"
        );

        Mock::given(method("POST"))
            .and(path(
                "/v1beta/models/gemini-1.5-flash:streamGenerateContent",
            ))
            .and(query_param("key", "test-key"))
            .and(query_param("alt", "sse"))
            .respond_with(ResponseTemplate::new(200).set_body_raw(sse_body, "text/event-stream"))
            .mount(&server)
            .await;

        let mut stream = adapter_for(&server)
            .chat_stream(sample_request())
            .await
            .unwrap();

        // Collect events up to and including the first `Done`.
        let mut events = Vec::new();
        loop {
            match stream.next().await {
                None => break,
                Some(item) => {
                    events.push(item.unwrap());
                    if matches!(events.last(), Some(StreamEvent::Done)) {
                        break;
                    }
                }
            }
        }

        let has_text = |expected: &str| {
            events
                .iter()
                .any(|e| matches!(e, StreamEvent::TextDelta { delta } if delta == expected))
        };
        assert!(has_text("Hello"));
        assert!(has_text(" world"));
        assert!(events
            .iter()
            .any(|e| matches!(e, StreamEvent::Usage { usage } if usage.total_tokens == 11)));
        assert!(events.iter().any(|e| matches!(e, StreamEvent::Done)));
    }
}