1use aether_core::{
6 AetherError, AiProvider, ProviderConfig, Result,
7 provider::{GenerationRequest, GenerationResponse},
8 SlotKind,
9};
10use async_trait::async_trait;
11use reqwest::Client;
12use serde::{Deserialize, Serialize};
13use tracing::{debug, instrument};
14use aether_core::provider::StreamResponse;
15use futures::stream::{BoxStream, StreamExt};
16
/// Base URL of the Gemini REST `models` collection; the provider appends
/// `/{model}` (and, for generation, a `:method` suffix) to it.
const GEMINI_API_BASE: &str = concat!(
    "https://generativelanguage.googleapis.com",
    "/v1beta/models"
);
18
19#[derive(Debug, Clone)]
21pub struct GeminiProvider {
22 client: Client,
23 config: ProviderConfig,
24}
25
26#[derive(Debug, Serialize)]
28struct GeminiRequest {
29 contents: Vec<Content>,
30 #[serde(skip_serializing_if = "Option::is_none")]
31 generation_config: Option<GenerationConfig>,
32}
33
34#[derive(Debug, Serialize)]
35struct Content {
36 parts: Vec<Part>,
37 role: String,
38}
39
40#[derive(Debug, Serialize)]
41struct Part {
42 text: String,
43}
44
45#[derive(Debug, Serialize)]
46#[serde(rename_all = "camelCase")]
47struct GenerationConfig {
48 #[serde(skip_serializing_if = "Option::is_none")]
49 temperature: Option<f32>,
50 #[serde(skip_serializing_if = "Option::is_none")]
51 max_output_tokens: Option<u32>,
52}
53
54#[derive(Debug, Deserialize)]
56struct GeminiResponse {
57 candidates: Option<Vec<Candidate>>,
58 usage_metadata: Option<UsageMetadata>,
59}
60
61#[derive(Debug, Deserialize)]
62struct Candidate {
63 content: ContentResponse,
64}
65
66#[derive(Debug, Deserialize)]
67struct ContentResponse {
68 parts: Vec<PartResponse>,
69}
70
71#[derive(Debug, Deserialize)]
72struct PartResponse {
73 text: String,
74}
75
76#[derive(Debug, Deserialize)]
77#[serde(rename_all = "camelCase")]
78struct UsageMetadata {
79 total_token_count: u32,
80}
81
82impl GeminiProvider {
83 pub fn new(config: ProviderConfig) -> Result<Self> {
85 let timeout = config.timeout_seconds.unwrap_or(60);
86 let client = Client::builder()
87 .timeout(std::time::Duration::from_secs(timeout))
88 .build()
89 .map_err(|e| AetherError::NetworkError(e.to_string()))?;
90
91 Ok(Self { client, config })
92 }
93
94 pub fn from_env() -> Result<Self> {
98 let api_key = std::env::var("GOOGLE_API_KEY")
99 .map_err(|_| AetherError::ConfigError("GOOGLE_API_KEY not set".to_string()))?;
100
101 let model = std::env::var("GEMINI_MODEL").unwrap_or_else(|_| "gemini-1.5-pro".to_string());
102
103 let config = ProviderConfig::new(api_key, model);
106 Self::new(config)
107 }
108
109 fn build_prompt(&self, kind: &SlotKind, context: Option<&str>, user_prompt: &str) -> String {
111 let base_instructions = match kind {
112 SlotKind::Html => "Generate valid HTML5 markup.",
113 SlotKind::Css => "Generate valid CSS styles.",
114 SlotKind::JavaScript => "Generate valid JavaScript code.",
115 SlotKind::Function => "Generate a complete function definition.",
116 SlotKind::Class => "Generate a complete class/struct definition.",
117 SlotKind::Component => "Generate a complete component with HTML, CSS, and JavaScript as needed.",
118 _ => "Generate code based on the request.",
119 };
120
121 let context_str = context
122 .map(|c| format!("\nContext:\n{}", c))
123 .unwrap_or_default();
124
125 format!(
126 "Role: Code Generator. Task: {}\n{}\nRequest: {}\nOutput only raw code, no markdown.",
127 base_instructions, context_str, user_prompt
128 )
129 }
130}
131
132#[async_trait]
133impl AiProvider for GeminiProvider {
134 fn name(&self) -> &str {
135 "gemini"
136 }
137
138 #[instrument(skip(self, request), fields(slot = %request.slot.name))]
139 async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse> {
140 debug!("Generating code with Gemini for slot: {}", request.slot.name);
141
142 let api_key = self.config.resolve_api_key().await?;
143
144 let full_prompt = self.build_prompt(&request.slot.kind, request.context.as_deref(), &request.slot.prompt);
147
148 let contents = vec![Content {
149 role: "user".to_string(),
150 parts: vec![Part { text: full_prompt }],
151 }];
152
153 let temperature = request.slot.temperature.or(self.config.temperature);
154 let api_request = GeminiRequest {
155 contents,
156 generation_config: Some(GenerationConfig {
157 temperature,
158 max_output_tokens: request.max_tokens.or(self.config.max_tokens),
159 }),
160 };
161
162 let model = request.model.clone().unwrap_or_else(|| self.config.model.clone());
163 let url = format!(
164 "{}/{}:generateContent?key={}",
165 GEMINI_API_BASE, model, api_key
166 );
167
168 let response = self
169 .client
170 .post(&url)
171 .header("Content-Type", "application/json")
172 .json(&api_request)
173 .send()
174 .await
175 .map_err(|e| AetherError::NetworkError(e.to_string()))?;
176
177 if !response.status().is_success() {
178 let status = response.status();
179 let body = response.text().await.unwrap_or_default();
180 return Err(AetherError::ProviderError(format!(
181 "API error {}: {}",
182 status, body
183 )));
184 }
185
186 let gemini_response: GeminiResponse = response
187 .json()
188 .await
189 .map_err(|e| AetherError::ProviderError(e.to_string()))?;
190
191 let code = gemini_response
193 .candidates
194 .as_ref()
195 .and_then(|c| c.first())
196 .and_then(|c| c.content.parts.first())
197 .map(|p| p.text.clone())
198 .ok_or_else(|| AetherError::ProviderError("No content generated".to_string()))?;
199
200 let code = code.trim().trim_start_matches("```").trim_end_matches("```");
202 let code = if let Some(newline_idx) = code.find('\n') {
204 if code[..newline_idx].chars().all(char::is_alphanumeric) {
205 &code[newline_idx + 1..]
206 } else {
207 code
208 }
209 } else {
210 code
211 };
212
213 Ok(GenerationResponse {
214 code: code.to_string(),
215 tokens_used: gemini_response.usage_metadata.map(|u| u.total_token_count),
216 metadata: None,
217 })
218 }
219
220 fn generate_stream(
221 &self,
222 request: GenerationRequest,
223 ) -> BoxStream<'static, Result<StreamResponse>> {
224 let client = self.client.clone();
225 let config = self.config.clone();
226 let full_prompt = self.build_prompt(&request.slot.kind, request.context.as_deref(), &request.slot.prompt);
227
228 let temperature = request.slot.temperature.or(config.temperature);
229 let api_request = GeminiRequest {
230 contents: vec![Content {
231 role: "user".to_string(),
232 parts: vec![Part { text: full_prompt }],
233 }],
234 generation_config: Some(GenerationConfig {
235 temperature,
236 max_output_tokens: request.max_tokens.or(config.max_tokens),
237 }),
238 };
239
240 let stream = async_stream::stream! {
241 let api_key = match config.resolve_api_key().await {
242 Ok(k) => k,
243 Err(e) => {
244 yield Err(e);
245 return;
246 }
247 };
248
249 let model = request.model.clone().unwrap_or_else(|| config.model.clone());
250 let url = format!(
251 "{}/{}:streamGenerateContent?alt=sse&key={}",
252 GEMINI_API_BASE, model, api_key
253 );
254
255 let response = client
256 .post(&url)
257 .header("Content-Type", "application/json")
258 .json(&api_request)
259 .send()
260 .await
261 .map_err(|e| aether_core::AetherError::NetworkError(e.to_string()));
262
263 let response = match response {
264 Ok(r) => r,
265 Err(e) => {
266 yield Err(e);
267 return;
268 }
269 };
270
271 if !response.status().is_success() {
272 let status = response.status();
273 let body = response.text().await.unwrap_or_default();
274 yield Err(aether_core::AetherError::ProviderError(format!(
275 "API error {}: {}",
276 status, body
277 )));
278 return;
279 }
280
281 let mut stream = response.bytes_stream();
282
283 while let Some(chunk_result) = stream.next().await {
284 let chunk = match chunk_result {
285 Ok(c) => c,
286 Err(e) => {
287 yield Err(aether_core::AetherError::NetworkError(e.to_string()));
288 break;
289 }
290 };
291
292 let text = String::from_utf8_lossy(&chunk);
293 for line in text.lines() {
294 let line = line.trim();
295 if line.is_empty() { continue; }
296
297 if let Some(event_data) = line.strip_prefix("data: ") {
298 if let Ok(gemini_resp) = serde_json::from_str::<GeminiResponse>(event_data) {
299 if let Some(candidate) = gemini_resp.candidates.as_ref().and_then(|c| c.first()) {
300 if let Some(part) = candidate.content.parts.first() {
301 yield Ok(StreamResponse {
302 delta: part.text.clone(),
303 metadata: None,
304 });
305 }
306 }
307 }
308 }
309 }
310 }
311 };
312
313 Box::pin(stream)
314 }
315
316 async fn health_check(&self) -> Result<bool> {
317 let api_key = self.config.resolve_api_key().await?;
318 let url = format!(
320 "{}/{}?key={}",
321 GEMINI_API_BASE, self.config.model, api_key
322 );
323
324 let response = self
325 .client
326 .get(&url)
327 .send()
328 .await
329 .map_err(|e| AetherError::NetworkError(e.to_string()))?;
330
331 Ok(response.status().is_success())
332 }
333}