Skip to main content

aether_ai/
gemini.rs

1//! Google Gemini provider implementation.
2//!
3//! Supports Gemini Pro and other Google AI models.
4
5use aether_core::{
6    AetherError, AiProvider, ProviderConfig, Result,
7    provider::{GenerationRequest, GenerationResponse},
8    SlotKind,
9};
10use async_trait::async_trait;
11use reqwest::Client;
12use serde::{Deserialize, Serialize};
13use tracing::{debug, instrument};
14use aether_core::provider::StreamResponse;
15use futures::stream::{BoxStream, StreamExt};
16
/// Base URL of the Gemini `v1beta` REST API; the model name and method are appended to it.
const GEMINI_API_BASE: &str = "https://generativelanguage.googleapis.com/v1beta/models";
18
19/// Google Gemini provider for code generation.
20#[derive(Debug, Clone)]
21pub struct GeminiProvider {
22    client: Client,
23    config: ProviderConfig,
24}
25
26// Request structures
27#[derive(Debug, Serialize)]
28struct GeminiRequest {
29    contents: Vec<Content>,
30    #[serde(skip_serializing_if = "Option::is_none")]
31    generation_config: Option<GenerationConfig>,
32}
33
34#[derive(Debug, Serialize)]
35struct Content {
36    parts: Vec<Part>,
37    role: String,
38}
39
40#[derive(Debug, Serialize)]
41struct Part {
42    text: String,
43}
44
45#[derive(Debug, Serialize)]
46#[serde(rename_all = "camelCase")]
47struct GenerationConfig {
48    #[serde(skip_serializing_if = "Option::is_none")]
49    temperature: Option<f32>,
50    #[serde(skip_serializing_if = "Option::is_none")]
51    max_output_tokens: Option<u32>,
52}
53
54// Response structures
55#[derive(Debug, Deserialize)]
56struct GeminiResponse {
57    candidates: Option<Vec<Candidate>>,
58    usage_metadata: Option<UsageMetadata>,
59}
60
61#[derive(Debug, Deserialize)]
62struct Candidate {
63    content: ContentResponse,
64}
65
66#[derive(Debug, Deserialize)]
67struct ContentResponse {
68    parts: Vec<PartResponse>,
69}
70
71#[derive(Debug, Deserialize)]
72struct PartResponse {
73    text: String,
74}
75
76#[derive(Debug, Deserialize)]
77#[serde(rename_all = "camelCase")]
78struct UsageMetadata {
79    total_token_count: u32,
80}
81
82impl GeminiProvider {
83    /// Create a new Gemini provider with the given configuration.
84    pub fn new(config: ProviderConfig) -> Result<Self> {
85        let timeout = config.timeout_seconds.unwrap_or(60);
86        let client = Client::builder()
87            .timeout(std::time::Duration::from_secs(timeout))
88            .build()
89            .map_err(|e| AetherError::NetworkError(e.to_string()))?;
90
91        Ok(Self { client, config })
92    }
93
94    /// Create a provider from environment variables.
95    ///
96    /// Reads `GOOGLE_API_KEY` and optionally `GEMINI_MODEL`.
97    pub fn from_env() -> Result<Self> {
98        let api_key = std::env::var("GOOGLE_API_KEY")
99            .map_err(|_| AetherError::ConfigError("GOOGLE_API_KEY not set".to_string()))?;
100
101        let model = std::env::var("GEMINI_MODEL").unwrap_or_else(|_| "gemini-1.5-pro".to_string());
102        
103        // Google API key is query param, not header like OpenAI
104        // We store it in config.api_key but will use it in URL
105        let config = ProviderConfig::new(api_key, model);
106        Self::new(config)
107    }
108
109    /// Build the specific prompt for Gemini
110    fn build_prompt(&self, kind: &SlotKind, context: Option<&str>, user_prompt: &str) -> String {
111        let base_instructions = match kind {
112            SlotKind::Html => "Generate valid HTML5 markup.",
113            SlotKind::Css => "Generate valid CSS styles.",
114            SlotKind::JavaScript => "Generate valid JavaScript code.",
115            SlotKind::Function => "Generate a complete function definition.",
116            SlotKind::Class => "Generate a complete class/struct definition.",
117            SlotKind::Component => "Generate a complete component with HTML, CSS, and JavaScript as needed.",
118            _ => "Generate code based on the request.",
119        };
120
121        let context_str = context
122            .map(|c| format!("\nContext:\n{}", c))
123            .unwrap_or_default();
124
125        format!(
126            "Role: Code Generator. Task: {}\n{}\nRequest: {}\nOutput only raw code, no markdown.",
127            base_instructions, context_str, user_prompt
128        )
129    }
130}
131
132#[async_trait]
133impl AiProvider for GeminiProvider {
134    fn name(&self) -> &str {
135        "gemini"
136    }
137
138    #[instrument(skip(self, request), fields(slot = %request.slot.name))]
139    async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse> {
140        debug!("Generating code with Gemini for slot: {}", request.slot.name);
141
142        let api_key = self.config.resolve_api_key().await?;
143
144        // Gemini API is slightly different (no system role in v1beta easily)
145        // so we verify robust prompt engineering in the user message
146        let full_prompt = self.build_prompt(&request.slot.kind, request.context.as_deref(), &request.slot.prompt);
147
148        let contents = vec![Content {
149            role: "user".to_string(),
150            parts: vec![Part { text: full_prompt }],
151        }];
152
153        let temperature = request.slot.temperature.or(self.config.temperature);
154        let api_request = GeminiRequest {
155            contents,
156            generation_config: Some(GenerationConfig {
157                temperature,
158                max_output_tokens: request.max_tokens.or(self.config.max_tokens),
159            }),
160        };
161
162        let model = request.model.clone().unwrap_or_else(|| self.config.model.clone());
163        let url = format!(
164            "{}/{}:generateContent?key={}",
165            GEMINI_API_BASE, model, api_key
166        );
167
168        let response = self
169            .client
170            .post(&url)
171            .header("Content-Type", "application/json")
172            .json(&api_request)
173            .send()
174            .await
175            .map_err(|e| AetherError::NetworkError(e.to_string()))?;
176
177        if !response.status().is_success() {
178            let status = response.status();
179            let body = response.text().await.unwrap_or_default();
180            return Err(AetherError::ProviderError(format!(
181                "API error {}: {}",
182                status, body
183            )));
184        }
185
186        let gemini_response: GeminiResponse = response
187            .json()
188            .await
189            .map_err(|e| AetherError::ProviderError(e.to_string()))?;
190
191        // Extract text from the first candidate
192        let code = gemini_response
193            .candidates
194            .as_ref()
195            .and_then(|c| c.first())
196            .and_then(|c| c.content.parts.first())
197            .map(|p| p.text.clone())
198            .ok_or_else(|| AetherError::ProviderError("No content generated".to_string()))?;
199
200        // Clean up markdown
201        let code = code.trim().trim_start_matches("```").trim_end_matches("```");
202        // Sometimes it includes the language name like ```rust ... ```
203        let code = if let Some(newline_idx) = code.find('\n') {
204            if code[..newline_idx].chars().all(char::is_alphanumeric) {
205                &code[newline_idx + 1..]
206            } else {
207                code
208            }
209        } else {
210            code
211        };
212
213        Ok(GenerationResponse {
214            code: code.to_string(),
215            tokens_used: gemini_response.usage_metadata.map(|u| u.total_token_count),
216            metadata: None,
217        })
218    }
219
220    fn generate_stream(
221        &self,
222        request: GenerationRequest,
223    ) -> BoxStream<'static, Result<StreamResponse>> {
224        let client = self.client.clone();
225        let config = self.config.clone();
226        let full_prompt = self.build_prompt(&request.slot.kind, request.context.as_deref(), &request.slot.prompt);
227        
228        let temperature = request.slot.temperature.or(config.temperature);
229        let api_request = GeminiRequest {
230            contents: vec![Content {
231                role: "user".to_string(),
232                parts: vec![Part { text: full_prompt }],
233            }],
234            generation_config: Some(GenerationConfig {
235                temperature,
236                max_output_tokens: request.max_tokens.or(config.max_tokens),
237            }),
238        };
239
240        let stream = async_stream::stream! {
241            let api_key = match config.resolve_api_key().await {
242                Ok(k) => k,
243                Err(e) => {
244                    yield Err(e);
245                    return;
246                }
247            };
248
249            let model = request.model.clone().unwrap_or_else(|| config.model.clone());
250            let url = format!(
251                "{}/{}:streamGenerateContent?alt=sse&key={}",
252                GEMINI_API_BASE, model, api_key
253            );
254
255            let response = client
256                .post(&url)
257                .header("Content-Type", "application/json")
258                .json(&api_request)
259                .send()
260                .await
261                .map_err(|e| aether_core::AetherError::NetworkError(e.to_string()));
262
263            let response = match response {
264                Ok(r) => r,
265                Err(e) => {
266                    yield Err(e);
267                    return;
268                }
269            };
270
271            if !response.status().is_success() {
272                let status = response.status();
273                let body = response.text().await.unwrap_or_default();
274                yield Err(aether_core::AetherError::ProviderError(format!(
275                    "API error {}: {}",
276                    status, body
277                )));
278                return;
279            }
280
281            let mut stream = response.bytes_stream();
282            
283            while let Some(chunk_result) = stream.next().await {
284                let chunk = match chunk_result {
285                    Ok(c) => c,
286                    Err(e) => {
287                        yield Err(aether_core::AetherError::NetworkError(e.to_string()));
288                        break;
289                    }
290                };
291
292                let text = String::from_utf8_lossy(&chunk);
293                for line in text.lines() {
294                    let line = line.trim();
295                    if line.is_empty() { continue; }
296                    
297                    if let Some(event_data) = line.strip_prefix("data: ") {
298                        if let Ok(gemini_resp) = serde_json::from_str::<GeminiResponse>(event_data) {
299                            if let Some(candidate) = gemini_resp.candidates.as_ref().and_then(|c| c.first()) {
300                                if let Some(part) = candidate.content.parts.first() {
301                                    yield Ok(StreamResponse {
302                                        delta: part.text.clone(),
303                                        metadata: None,
304                                    });
305                                }
306                            }
307                        }
308                    }
309                }
310            }
311        };
312
313        Box::pin(stream)
314    }
315
316    async fn health_check(&self) -> Result<bool> {
317        let api_key = self.config.resolve_api_key().await?;
318        // Minimal check - try to get model info
319         let url = format!(
320            "{}/{}?key={}",
321            GEMINI_API_BASE, self.config.model, api_key
322        );
323
324        let response = self
325            .client
326            .get(&url)
327            .send()
328            .await
329            .map_err(|e| AetherError::NetworkError(e.to_string()))?;
330
331        Ok(response.status().is_success())
332    }
333}