Skip to main content

forgeai_adapter_gemini/
lib.rs

1use async_stream::try_stream;
2use async_trait::async_trait;
3use forgeai_core::{
4    AdapterInfo, CapabilityMatrix, ChatAdapter, ChatRequest, ChatResponse, ForgeError, Role,
5    StreamEvent, StreamResult, ToolCall, Usage,
6};
7use futures_util::StreamExt;
8use reqwest::{Client as HttpClient, StatusCode};
9use serde_json::{json, Map, Value};
10use std::env;
11use url::Url;
12
13#[derive(Clone, Debug)]
14pub struct GeminiAdapter {
15    pub api_key: String,
16    pub base_url: Url,
17    pub api_version: String,
18    client: HttpClient,
19}
20
21impl GeminiAdapter {
22    pub fn new(api_key: impl Into<String>) -> Result<Self, ForgeError> {
23        let base_url = Url::parse("https://generativelanguage.googleapis.com")
24            .map_err(|e| ForgeError::Internal(e.to_string()))?;
25        Self::with_base_url(api_key, base_url)
26    }
27
28    pub fn with_base_url(api_key: impl Into<String>, base_url: Url) -> Result<Self, ForgeError> {
29        let client = HttpClient::builder()
30            .build()
31            .map_err(|e| ForgeError::Internal(format!("failed to build http client: {e}")))?;
32        Ok(Self {
33            api_key: api_key.into(),
34            base_url,
35            api_version: "v1beta".to_string(),
36            client,
37        })
38    }
39
40    pub fn from_env() -> Result<Self, ForgeError> {
41        let api_key = env::var("GEMINI_API_KEY").map_err(|_| ForgeError::Authentication)?;
42        match env::var("GEMINI_BASE_URL") {
43            Ok(raw) => {
44                let base_url = Url::parse(&raw)
45                    .map_err(|e| ForgeError::Validation(format!("invalid GEMINI_BASE_URL: {e}")))?;
46                Self::with_base_url(api_key, base_url)
47            }
48            Err(_) => Self::new(api_key),
49        }
50    }
51
52    fn endpoint_url(&self, model: &str, stream: bool) -> Result<Url, ForgeError> {
53        let action = if stream {
54            "streamGenerateContent"
55        } else {
56            "generateContent"
57        };
58        let mut url = self
59            .base_url
60            .join(&format!(
61                "{}/models/{}:{}",
62                self.api_version,
63                model.trim(),
64                action
65            ))
66            .map_err(|e| ForgeError::Internal(format!("failed to construct endpoint url: {e}")))?;
67        {
68            let mut qp = url.query_pairs_mut();
69            qp.append_pair("key", &self.api_key);
70            if stream {
71                qp.append_pair("alt", "sse");
72            }
73        }
74        Ok(url)
75    }
76}
77
78#[async_trait]
79impl ChatAdapter for GeminiAdapter {
80    fn info(&self) -> AdapterInfo {
81        AdapterInfo {
82            name: "gemini".to_string(),
83            base_url: Url::parse("https://generativelanguage.googleapis.com").ok(),
84            capabilities: CapabilityMatrix {
85                streaming: true,
86                tools: true,
87                structured_output: true,
88                multimodal_input: true,
89                citations: true,
90            },
91        }
92    }
93
94    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, ForgeError> {
95        let url = self.endpoint_url(&request.model, false)?;
96        let model = request.model.clone();
97        let response = self
98            .client
99            .post(url)
100            .json(&build_generate_body(request))
101            .send()
102            .await
103            .map_err(|e| ForgeError::Transport(format!("request failed: {e}")))?;
104
105        if !response.status().is_success() {
106            let status = response.status();
107            let text = response
108                .text()
109                .await
110                .unwrap_or_else(|_| "failed to read error body".to_string());
111            return Err(parse_http_error(status, text));
112        }
113
114        let payload = response
115            .json::<Value>()
116            .await
117            .map_err(|e| ForgeError::Provider(format!("invalid json response: {e}")))?;
118        parse_chat_response(model, payload)
119    }
120
121    async fn chat_stream(
122        &self,
123        request: ChatRequest,
124    ) -> Result<StreamResult<StreamEvent>, ForgeError> {
125        let url = self.endpoint_url(&request.model, true)?;
126        let response = self
127            .client
128            .post(url)
129            .json(&build_generate_body(request))
130            .send()
131            .await
132            .map_err(|e| ForgeError::Transport(format!("stream request failed: {e}")))?;
133
134        if !response.status().is_success() {
135            let status = response.status();
136            let text = response
137                .text()
138                .await
139                .unwrap_or_else(|_| "failed to read error body".to_string());
140            return Err(parse_http_error(status, text));
141        }
142
143        let mut bytes = response.bytes_stream();
144        let stream = try_stream! {
145            let mut buffer = String::new();
146            let mut saw_done = false;
147
148            while let Some(chunk) = bytes.next().await {
149                let chunk = chunk.map_err(|e| ForgeError::Transport(format!("stream chunk error: {e}")))?;
150                let chunk_text = std::str::from_utf8(&chunk)
151                    .map_err(|e| ForgeError::Transport(format!("invalid utf8 stream chunk: {e}")))?;
152                buffer.push_str(chunk_text);
153
154                while let Some(line_end) = buffer.find('\n') {
155                    let mut line = buffer[..line_end].to_string();
156                    buffer.drain(..=line_end);
157                    if line.ends_with('\r') {
158                        line.pop();
159                    }
160                    if line.trim().is_empty() {
161                        continue;
162                    }
163                    if let Some(data) = line.strip_prefix("data:") {
164                        let payload = data.trim();
165                        if payload == "[DONE]" {
166                            saw_done = true;
167                            yield StreamEvent::Done;
168                            continue;
169                        }
170                        for event in parse_stream_payload(payload)? {
171                            if matches!(event, StreamEvent::Done) {
172                                saw_done = true;
173                            }
174                            yield event;
175                        }
176                    }
177                }
178            }
179
180            if !buffer.trim().is_empty() {
181                let line = buffer.trim();
182                if let Some(data) = line.strip_prefix("data:") {
183                    let payload = data.trim();
184                    if payload == "[DONE]" {
185                        saw_done = true;
186                        yield StreamEvent::Done;
187                    } else {
188                        for event in parse_stream_payload(payload)? {
189                            if matches!(event, StreamEvent::Done) {
190                                saw_done = true;
191                            }
192                            yield event;
193                        }
194                    }
195                }
196            }
197
198            if !saw_done {
199                yield StreamEvent::Done;
200            }
201        };
202
203        Ok(Box::pin(stream))
204    }
205}
206
207fn build_generate_body(request: ChatRequest) -> Value {
208    let mut body = Map::new();
209    if let Some(temperature) = request.temperature {
210        body.insert(
211            "generationConfig".to_string(),
212            json!({
213                "temperature": temperature,
214                "maxOutputTokens": request.max_tokens
215            }),
216        );
217    } else if let Some(max_tokens) = request.max_tokens {
218        body.insert(
219            "generationConfig".to_string(),
220            json!({
221                "maxOutputTokens": max_tokens
222            }),
223        );
224    }
225
226    let mut contents = Vec::new();
227    let mut system_chunks = Vec::new();
228    for message in request.messages {
229        if matches!(message.role, Role::System) {
230            system_chunks.push(message.content);
231            continue;
232        }
233        let role = if matches!(message.role, Role::Assistant) {
234            "model"
235        } else {
236            "user"
237        };
238        contents.push(json!({
239            "role": role,
240            "parts": [{ "text": message.content }]
241        }));
242    }
243    body.insert("contents".to_string(), Value::Array(contents));
244
245    if !system_chunks.is_empty() {
246        body.insert(
247            "systemInstruction".to_string(),
248            json!({
249                "parts": [{
250                    "text": system_chunks.join("\n\n")
251                }]
252            }),
253        );
254    }
255
256    if !request.tools.is_empty() {
257        body.insert(
258            "tools".to_string(),
259            Value::Array(
260                request
261                    .tools
262                    .into_iter()
263                    .map(|tool| {
264                        json!({
265                            "functionDeclarations": [{
266                                "name": tool.name,
267                                "description": tool.description,
268                                "parameters": tool.input_schema
269                            }]
270                        })
271                    })
272                    .collect(),
273            ),
274        );
275    }
276
277    Value::Object(body)
278}
279
280fn parse_http_error(status: StatusCode, body: String) -> ForgeError {
281    let message = extract_provider_error(body);
282    match status {
283        StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => ForgeError::Authentication,
284        StatusCode::TOO_MANY_REQUESTS => ForgeError::RateLimited,
285        _ => ForgeError::Provider(message),
286    }
287}
288
289fn extract_provider_error(body: String) -> String {
290    serde_json::from_str::<Value>(&body)
291        .ok()
292        .and_then(|v| {
293            v.get("error")
294                .and_then(|e| e.get("message"))
295                .and_then(Value::as_str)
296                .map(ToString::to_string)
297        })
298        .unwrap_or(body)
299}
300
301fn parse_chat_response(model: String, payload: Value) -> Result<ChatResponse, ForgeError> {
302    let output_text = extract_text_from_payload(&payload);
303    let tool_calls = extract_tool_calls_from_payload(&payload);
304    let usage = extract_usage(payload.get("usageMetadata"));
305
306    Ok(ChatResponse {
307        id: payload
308            .get("responseId")
309            .and_then(Value::as_str)
310            .unwrap_or_default()
311            .to_string(),
312        model,
313        output_text,
314        tool_calls,
315        usage,
316    })
317}
318
319fn extract_text_from_payload(payload: &Value) -> String {
320    payload
321        .get("candidates")
322        .and_then(Value::as_array)
323        .map(|candidates| {
324            candidates
325                .iter()
326                .flat_map(|candidate| {
327                    candidate
328                        .get("content")
329                        .and_then(|c| c.get("parts"))
330                        .and_then(Value::as_array)
331                        .cloned()
332                        .unwrap_or_default()
333                })
334                .filter_map(|part| {
335                    part.get("text")
336                        .and_then(Value::as_str)
337                        .map(ToString::to_string)
338                })
339                .collect::<Vec<_>>()
340                .join("")
341        })
342        .unwrap_or_default()
343}
344
345fn extract_tool_calls_from_payload(payload: &Value) -> Vec<ToolCall> {
346    payload
347        .get("candidates")
348        .and_then(Value::as_array)
349        .map(|candidates| {
350            candidates
351                .iter()
352                .flat_map(|candidate| {
353                    candidate
354                        .get("content")
355                        .and_then(|c| c.get("parts"))
356                        .and_then(Value::as_array)
357                        .cloned()
358                        .unwrap_or_default()
359                })
360                .filter_map(|part| {
361                    let function_call = part.get("functionCall")?;
362                    Some(ToolCall {
363                        id: function_call
364                            .get("id")
365                            .and_then(Value::as_str)
366                            .unwrap_or_default()
367                            .to_string(),
368                        name: function_call
369                            .get("name")
370                            .and_then(Value::as_str)
371                            .unwrap_or_default()
372                            .to_string(),
373                        arguments: function_call.get("args").cloned().unwrap_or(Value::Null),
374                    })
375                })
376                .collect()
377        })
378        .unwrap_or_default()
379}
380
381fn extract_usage(raw: Option<&Value>) -> Option<Usage> {
382    let usage = raw?;
383    let input_tokens = usage
384        .get("promptTokenCount")
385        .and_then(Value::as_u64)
386        .unwrap_or(0) as u32;
387    let output_tokens = usage
388        .get("candidatesTokenCount")
389        .and_then(Value::as_u64)
390        .unwrap_or(0) as u32;
391    let total_tokens = usage
392        .get("totalTokenCount")
393        .and_then(Value::as_u64)
394        .map(|v| v as u32)
395        .unwrap_or_else(|| input_tokens.saturating_add(output_tokens));
396    Some(Usage {
397        input_tokens,
398        output_tokens,
399        total_tokens,
400    })
401}
402
403fn parse_stream_payload(payload: &str) -> Result<Vec<StreamEvent>, ForgeError> {
404    let value = serde_json::from_str::<Value>(payload)
405        .map_err(|e| ForgeError::Provider(format!("invalid stream payload: {e}")))?;
406
407    let mut events = Vec::new();
408    let text = extract_text_from_payload(&value);
409    if !text.is_empty() {
410        events.push(StreamEvent::TextDelta { delta: text });
411    }
412
413    for tool_call in extract_tool_calls_from_payload(&value) {
414        events.push(StreamEvent::ToolCallDelta {
415            call_id: tool_call.id,
416            delta: json!({
417                "name": tool_call.name,
418                "arguments": tool_call.arguments
419            }),
420        });
421    }
422
423    if let Some(usage) = extract_usage(value.get("usageMetadata")) {
424        events.push(StreamEvent::Usage { usage });
425    }
426
427    if value
428        .get("candidates")
429        .and_then(Value::as_array)
430        .and_then(|items| items.first())
431        .and_then(|c| c.get("finishReason"))
432        .is_some()
433    {
434        events.push(StreamEvent::Done);
435    }
436
437    Ok(events)
438}
439
#[cfg(test)]
mod tests {
    use super::*;
    use forgeai_core::{ChatRequest, Message, Role};
    use futures_util::StreamExt;
    use wiremock::matchers::{body_partial_json, method, path, query_param};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// API key every mock expects in the `key` query parameter.
    const API_KEY: &str = "test-key";

    /// Minimal single-message request shared by the contract tests.
    fn sample_request() -> ChatRequest {
        let user_message = Message {
            role: Role::User,
            content: "Say hello".to_string(),
        };
        ChatRequest {
            model: "gemini-1.5-flash".to_string(),
            messages: vec![user_message],
            temperature: Some(0.2),
            max_tokens: Some(64),
            tools: vec![],
            metadata: json!({}),
        }
    }

    /// Adapter whose base URL points at the given mock server.
    fn adapter_for(server: &MockServer) -> GeminiAdapter {
        let base_url = Url::parse(&server.uri()).unwrap();
        GeminiAdapter::with_base_url(API_KEY, base_url).unwrap()
    }

    /// A non-streaming call hits the expected endpoint and its id, text and
    /// usage are parsed from the JSON reply.
    #[tokio::test]
    async fn chat_contract_parses_response_and_usage() {
        let server = MockServer::start().await;
        let reply = json!({
            "responseId": "resp_123",
            "candidates": [{
                "content": {
                    "parts": [{"text":"Hello from Gemini"}]
                }
            }],
            "usageMetadata": {
                "promptTokenCount": 9,
                "candidatesTokenCount": 4,
                "totalTokenCount": 13
            }
        });
        Mock::given(method("POST"))
            .and(path("/v1beta/models/gemini-1.5-flash:generateContent"))
            .and(query_param("key", API_KEY))
            .and(body_partial_json(json!({"contents": [{"role":"user"}]})))
            .respond_with(ResponseTemplate::new(200).set_body_json(reply))
            .mount(&server)
            .await;

        let response = adapter_for(&server)
            .chat(sample_request())
            .await
            .unwrap();

        assert_eq!(response.id, "resp_123");
        assert_eq!(response.model, "gemini-1.5-flash");
        assert_eq!(response.output_text, "Hello from Gemini");
        assert_eq!(response.usage.unwrap().total_tokens, 13);
    }

    /// A streaming call hits the SSE endpoint and yields both text deltas, the
    /// usage event and a final `Done`.
    #[tokio::test]
    async fn chat_stream_contract_parses_sse_events() {
        let server = MockServer::start().await;
        let sse_body = concat!(
            r#"data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}]}"#,
            "\n\n",
            r#"data: {"candidates":[{"content":{"parts":[{"text":" world"}]}}]}"#,
            "\n\n",
            r#"data: {"usageMetadata":{"promptTokenCount":9,"candidatesTokenCount":2,"totalTokenCount":11},"candidates":[{"finishReason":"STOP"}]}"#,
            "\n\n"
        );

        Mock::given(method("POST"))
            .and(path(
                "/v1beta/models/gemini-1.5-flash:streamGenerateContent",
            ))
            .and(query_param("key", API_KEY))
            .and(query_param("alt", "sse"))
            .respond_with(ResponseTemplate::new(200).set_body_raw(sse_body, "text/event-stream"))
            .mount(&server)
            .await;

        let mut stream = adapter_for(&server)
            .chat_stream(sample_request())
            .await
            .unwrap();

        // Collect events up to and including the first `Done`.
        let mut events = Vec::new();
        while let Some(item) = stream.next().await {
            events.push(item.unwrap());
            if matches!(events.last(), Some(StreamEvent::Done)) {
                break;
            }
        }

        let has_text = |expected: &str| {
            events
                .iter()
                .any(|e| matches!(e, StreamEvent::TextDelta { delta } if delta == expected))
        };
        assert!(has_text("Hello"));
        assert!(has_text(" world"));
        assert!(events
            .iter()
            .any(|e| matches!(e, StreamEvent::Usage { usage } if usage.total_tokens == 11)));
        assert!(events.iter().any(|e| matches!(e, StreamEvent::Done)));
    }
}