Skip to main content

agent_core_runtime/client/providers/gemini/
mod.rs

1//! Google Gemini API provider implementation.
2
3mod sse;
4mod types;
5
6use async_stream::stream;
7use futures::Stream;
8
9use crate::client::error::LlmError;
10use crate::client::http::HttpClient;
11use crate::client::models::{Message, MessageOptions, StreamEvent};
12use crate::client::traits::LlmProvider;
13use std::future::Future;
14use std::pin::Pin;
15
// =============================================================================
// Constants
// =============================================================================

/// Error code attached to the [`LlmError`] yielded when the streamed response
/// body cannot be decoded as UTF-8.
const ERROR_SSE_DECODE: &str = "SSE_DECODE_ERROR";

/// Message paired with [`ERROR_SSE_DECODE`] for undecodable (non-UTF-8) stream data.
const MSG_INVALID_UTF8: &str = "Invalid UTF-8 in stream";
25
// =============================================================================
// Provider
// =============================================================================

/// Provider for the Google Gemini API.
///
/// Stores the API key and the default model. A request's `MessageOptions::model`,
/// when present, is used instead of this default.
pub struct GeminiProvider {
    /// Gemini API key, passed along when request headers are built.
    api_key: String,
    /// Default model identifier, e.g. `"gemini-1.5-pro"` or `"gemini-1.5-flash"`.
    model: String,
}

impl GeminiProvider {
    /// Creates a provider from an API key and a default model identifier.
    ///
    /// Both values are stored as given; nothing is validated here.
    pub fn new(api_key: String, model: String) -> Self {
        GeminiProvider {
            model,
            api_key,
        }
    }

    /// Returns the default model identifier this provider was created with.
    pub fn model(&self) -> &str {
        self.model.as_str()
    }
}
49
50impl LlmProvider for GeminiProvider {
51    fn send_msg(
52        &self,
53        client: &HttpClient,
54        messages: &[Message],
55        options: &MessageOptions,
56    ) -> Pin<Box<dyn Future<Output = Result<Message, LlmError>> + Send>> {
57        // Clone data for the async block
58        let client = client.clone();
59        let api_key = self.api_key.clone();
60        let model = options.model.as_deref().unwrap_or(&self.model).to_string();
61        let messages = messages.to_vec();
62        let options = options.clone();
63
64        Box::pin(async move {
65            // Build request body
66            let body = types::build_request_body(&messages, &options)?;
67
68            // Get headers (validates API key)
69            let headers = types::get_request_headers(&api_key)?;
70            let headers_ref: Vec<(&str, &str)> = headers
71                .iter()
72                .map(|(k, v)| (*k, v.as_str()))
73                .collect();
74
75            // Get the API URL for this model
76            let url = types::get_api_url(&model);
77
78            // Make the API call
79            let response = client.post(&url, &headers_ref, &body).await?;
80
81            // Parse and return the response
82            types::parse_response(&response)
83        })
84    }
85
86    fn send_msg_stream(
87        &self,
88        client: &HttpClient,
89        messages: &[Message],
90        options: &MessageOptions,
91    ) -> Pin<Box<dyn Future<Output = Result<Pin<Box<dyn Stream<Item = Result<StreamEvent, LlmError>> + Send>>, LlmError>> + Send>> {
92        // Clone data for the async block
93        let client = client.clone();
94        let api_key = self.api_key.clone();
95        let model = options.model.as_deref().unwrap_or(&self.model).to_string();
96        let messages = messages.to_vec();
97        let options = options.clone();
98
99        Box::pin(async move {
100            // Build request body (same format for streaming and non-streaming)
101            let body = types::build_request_body(&messages, &options)?;
102
103            // Get headers (validates API key)
104            let headers = types::get_request_headers(&api_key)?;
105            let headers_ref: Vec<(&str, &str)> = headers
106                .iter()
107                .map(|(k, v)| (*k, v.as_str()))
108                .collect();
109
110            // Get the streaming API URL for this model
111            let url = types::get_streaming_api_url(&model);
112
113            // Make the streaming API call
114            let byte_stream = client.post_stream(&url, &headers_ref, &body).await?;
115
116            // Convert byte stream to SSE events stream
117            use futures::StreamExt;
118            let event_stream = stream! {
119                let mut buffer = String::new();
120                let mut byte_stream = byte_stream;
121                let mut message_started = false;
122                let mut stream_state = sse::StreamState::default();
123
124                while let Some(chunk_result) = byte_stream.next().await {
125                    match chunk_result {
126                        Ok(bytes) => {
127                            // Append new bytes to buffer
128                            if let Ok(text) = std::str::from_utf8(&bytes) {
129                                buffer.push_str(text);
130                            } else {
131                                yield Err(LlmError::new(ERROR_SSE_DECODE, MSG_INVALID_UTF8));
132                                break;
133                            }
134
135                            // Parse complete SSE events from buffer
136                            let (events, remaining) = sse::parse_sse_chunk(&buffer);
137                            buffer = remaining;
138
139                            // Convert and yield each SSE event
140                            for sse_event in events {
141                                match sse::parse_stream_event(&sse_event, &mut stream_state) {
142                                    Ok(stream_events) => {
143                                        // Emit MessageStart on first content
144                                        if !message_started && !stream_events.is_empty() {
145                                            message_started = true;
146                                            yield Ok(StreamEvent::MessageStart {
147                                                message_id: String::new(),
148                                                model: model.clone(),
149                                            });
150                                        }
151
152                                        for stream_event in stream_events {
153                                            yield Ok(stream_event);
154                                        }
155                                    }
156                                    Err(e) => {
157                                        yield Err(e);
158                                        return;
159                                    }
160                                }
161                            }
162                        }
163                        Err(e) => {
164                            yield Err(e);
165                            break;
166                        }
167                    }
168                }
169
170                // Emit MessageStop at the end
171                if message_started {
172                    yield Ok(StreamEvent::MessageStop);
173                }
174            };
175
176            Ok(Box::pin(event_stream) as Pin<Box<dyn Stream<Item = Result<StreamEvent, LlmError>> + Send>>)
177        })
178    }
179}