1use aether_core::{
6 AetherError, AiProvider, Result,
7 provider::{GenerationRequest, GenerationResponse},
8 SlotKind,
9};
10use async_trait::async_trait;
11use reqwest::Client;
12use serde::{Deserialize, Serialize};
13use aether_core::provider::StreamResponse;
14use futures::stream::{BoxStream, StreamExt};
15use tracing::{debug, instrument};
16
/// Endpoint used by `OllamaProvider::new` and by `from_env` when `OLLAMA_URL` is unset:
/// the `/api/generate` route of an Ollama server on localhost port 11434.
const DEFAULT_OLLAMA_URL: &str = "http://localhost:11434/api/generate";
18
/// `AiProvider` implementation that generates code through an Ollama server's HTTP API.
#[derive(Debug, Clone)]
pub struct OllamaProvider {
    /// HTTP client used for all requests (built with a 300-second timeout in `with_options`).
    client: Client,
    /// Model name used when a `GenerationRequest` does not specify one.
    model: String,
    /// URL that generate requests are POSTed to (defaults to `DEFAULT_OLLAMA_URL`).
    base_url: String,
}
26
27#[derive(Debug, Serialize)]
29struct GenerateRequest {
30 model: String,
31 prompt: String,
32 system: Option<String>,
33 stream: bool,
34 options: Option<GenerateOptions>,
35}
36
/// Sampling options sent as the `options` object of a generate request.
#[derive(Debug, Serialize)]
struct GenerateOptions {
    /// Sampling temperature; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Upper bound on generated tokens (filled from `GenerationRequest::max_tokens`);
    /// omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    num_predict: Option<u32>,
}
44
/// A response object from `/api/generate`: the entire reply for a non-streaming request,
/// or one newline-delimited JSON line of a streamed reply.
#[derive(Debug, Deserialize)]
// NOTE(review): every field appears to be read in this file (`done` by the streaming path,
// `eval_count` by `generate`), so this allow may be stale — confirm before removing.
#[allow(dead_code)]
struct GenerateResponse {
    /// Generated text: the full output, or one incremental piece when streaming.
    response: String,
    /// Marks the last object of a reply; the streaming path stops reading once it is seen.
    done: bool,
    /// Token count forwarded to callers as `tokens_used`; a missing field becomes `None`.
    #[serde(default)]
    eval_count: Option<u32>,
}
54
55impl OllamaProvider {
56 pub fn new(model: impl Into<String>) -> Self {
58 Self::with_options(model, DEFAULT_OLLAMA_URL)
59 }
60
61 pub fn with_options(model: impl Into<String>, base_url: impl Into<String>) -> Self {
63 let client = Client::builder()
64 .timeout(std::time::Duration::from_secs(300)) .build()
66 .expect("Failed to create HTTP client");
67
68 Self {
69 client,
70 model: model.into(),
71 base_url: base_url.into(),
72 }
73 }
74
75 pub fn from_env() -> Self {
79 let model = std::env::var("OLLAMA_MODEL").unwrap_or_else(|_| "codellama".to_string());
80 let url = std::env::var("OLLAMA_URL").unwrap_or_else(|_| DEFAULT_OLLAMA_URL.to_string());
81 Self::with_options(model, url)
82 }
83
84 fn build_system_prompt(&self, kind: &SlotKind, context: Option<&str>) -> String {
86 let base = "You are a code generation assistant. Generate only the requested code without explanations or markdown code blocks. Output raw code only.";
87
88 let kind_specific = match kind {
89 SlotKind::Html => "\nGenerate valid HTML5 markup.",
90 SlotKind::Css => "\nGenerate valid CSS styles.",
91 SlotKind::JavaScript => "\nGenerate valid JavaScript code.",
92 SlotKind::Function => "\nGenerate a complete function definition.",
93 SlotKind::Class => "\nGenerate a complete class/struct definition.",
94 SlotKind::Component => "\nGenerate a complete component with HTML, CSS, and JavaScript as needed.",
95 _ => "",
96 };
97
98 let context_part = context
99 .filter(|c| !c.is_empty())
100 .map(|c| format!("\n\nContext:\n{}", c))
101 .unwrap_or_default();
102
103 format!("{}{}{}", base, kind_specific, context_part)
104 }
105}
106
107#[async_trait]
108impl AiProvider for OllamaProvider {
109 fn name(&self) -> &str {
110 "ollama"
111 }
112
113 #[instrument(skip(self, request), fields(slot = %request.slot.name))]
114 async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse> {
115 debug!("Generating code with Ollama for slot: {}", request.slot.name);
116
117 let system = Some(request.system_prompt.unwrap_or_else(|| {
118 self.build_system_prompt(&request.slot.kind, request.context.as_deref())
119 }));
120
121 let temperature = request.slot.temperature.unwrap_or(0.7);
122 let api_request = GenerateRequest {
123 model: request.model.clone().unwrap_or_else(|| self.model.clone()),
124 prompt: request.slot.prompt.clone(),
125 system,
126 stream: false,
127 options: Some(GenerateOptions {
128 temperature: Some(temperature),
129 num_predict: Some(request.max_tokens.unwrap_or(2048)),
130 }),
131 };
132
133 let response = self
134 .client
135 .post(&self.base_url)
136 .json(&api_request)
137 .send()
138 .await
139 .map_err(|e| AetherError::NetworkError(e.to_string()))?;
140
141 if !response.status().is_success() {
142 let status = response.status();
143 let body = response.text().await.unwrap_or_default();
144 return Err(AetherError::ProviderError(format!(
145 "Ollama error {}: {}",
146 status, body
147 )));
148 }
149
150 let gen_response: GenerateResponse = response
151 .json()
152 .await
153 .map_err(|e| AetherError::ProviderError(e.to_string()))?;
154
155 let code = strip_code_blocks(&gen_response.response);
156
157 Ok(GenerationResponse {
158 code,
159 tokens_used: gen_response.eval_count,
160 metadata: None,
161 })
162 }
163
164 fn generate_stream(
165 &self,
166 request: GenerationRequest,
167 ) -> BoxStream<'static, Result<StreamResponse>> {
168 let client = self.client.clone();
169 let model = self.model.clone();
170 let base_url = self.base_url.clone();
171
172 let system = Some(request.system_prompt.unwrap_or_else(|| {
173 self.build_system_prompt(&request.slot.kind, request.context.as_deref())
174 }));
175
176 let temperature = request.slot.temperature.unwrap_or(0.7);
177 let api_request = GenerateRequest {
178 model: request.model.clone().unwrap_or_else(|| model.clone()),
179 prompt: request.slot.prompt.clone(),
180 system,
181 stream: true,
182 options: Some(GenerateOptions {
183 temperature: Some(temperature),
184 num_predict: Some(request.max_tokens.unwrap_or(2048)),
185 }),
186 };
187
188 let stream = async_stream::stream! {
189 let response = client
190 .post(&base_url)
191 .json(&api_request)
192 .send()
193 .await
194 .map_err(|e| aether_core::AetherError::NetworkError(e.to_string()));
195
196 let response = match response {
197 Ok(r) => r,
198 Err(e) => {
199 yield Err(e);
200 return;
201 }
202 };
203
204 if !response.status().is_success() {
205 let status = response.status();
206 let body = response.text().await.unwrap_or_default();
207 yield Err(aether_core::AetherError::ProviderError(format!(
208 "Ollama error {}: {}",
209 status, body
210 )));
211 return;
212 }
213
214 let mut stream = response.bytes_stream();
215
216 while let Some(chunk_result) = stream.next().await {
217 let chunk = match chunk_result {
218 Ok(c) => c,
219 Err(e) => {
220 yield Err(aether_core::AetherError::NetworkError(e.to_string()));
221 break;
222 }
223 };
224
225 let text = String::from_utf8_lossy(&chunk);
226 for line in text.lines() {
227 let line = line.trim();
228 if line.is_empty() { continue; }
229
230 if let Ok(gen_resp) = serde_json::from_str::<GenerateResponse>(line) {
231 yield Ok(StreamResponse {
232 delta: gen_resp.response,
233 metadata: None,
234 });
235 if gen_resp.done { break; }
236 }
237 }
238 }
239 };
240
241 Box::pin(stream)
242 }
243
244 async fn health_check(&self) -> Result<bool> {
245 let response = self
247 .client
248 .get("http://localhost:11434/api/tags")
249 .send()
250 .await
251 .map_err(|e| AetherError::NetworkError(e.to_string()))?;
252
253 Ok(response.status().is_success())
254 }
255}
256
/// Removes one Markdown code fence wrapped around a model response.
///
/// The input is trimmed first. If it then starts and ends with three backticks and spans at
/// least two lines, the opening fence line (including any language tag) and the closing fence
/// are removed. The remaining lines are joined with `\n`. Code written on the same line as the
/// closing fence (e.g. "foo```") is kept rather than discarded. Any other input is returned
/// trimmed but otherwise unchanged.
fn strip_code_blocks(code: &str) -> String {
    let code = code.trim();

    if code.starts_with("```") && code.ends_with("```") {
        let lines: Vec<&str> = code.lines().collect();
        if let [_opening, body @ .., closing] = lines.as_slice() {
            let mut kept: Vec<&str> = body.to_vec();
            // The last line ends with the fence (because `code` does); anything before the
            // fence on that line is still code and must not be dropped.
            let tail = closing.strip_suffix("```").unwrap_or(*closing).trim_end();
            if !tail.is_empty() {
                kept.push(tail);
            }
            return kept.join("\n");
        }
    }

    code.to_string()
}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` keeps the given model and targets the default endpoint.
    #[test]
    fn test_provider_creation() {
        let provider = OllamaProvider::new("codellama");
        assert_eq!(provider.model, "codellama");
        assert_eq!(provider.base_url, DEFAULT_OLLAMA_URL);
    }

    /// `with_options` stores both the model and the custom endpoint verbatim.
    #[test]
    fn test_with_options_uses_given_url() {
        let provider = OllamaProvider::with_options("llama3", "http://example:1234/api/generate");
        assert_eq!(provider.model, "llama3");
        assert_eq!(provider.base_url, "http://example:1234/api/generate");
    }

    /// Fenced responses lose their opening (with language tag) and closing fence lines.
    #[test]
    fn test_strip_code_blocks_removes_fences() {
        assert_eq!(strip_code_blocks("```rust\nfn main() {}\n```"), "fn main() {}");
        assert_eq!(strip_code_blocks("```\nline1\nline2\n```"), "line1\nline2");
    }

    /// Unfenced or single-line input is only trimmed.
    #[test]
    fn test_strip_code_blocks_leaves_unfenced_code() {
        assert_eq!(strip_code_blocks("  let x = 1;  "), "let x = 1;");
        assert_eq!(strip_code_blocks("```inline```"), "```inline```");
    }
}