Skip to main content

aether_ai/
ollama.rs

1//! Ollama local provider implementation.
2//!
3//! Supports local LLM models through Ollama.
4
5use aether_core::{
6    AetherError, AiProvider, Result,
7    provider::{GenerationRequest, GenerationResponse},
8    SlotKind,
9};
10use async_trait::async_trait;
11use reqwest::Client;
12use serde::{Deserialize, Serialize};
13use aether_core::provider::StreamResponse;
14use futures::stream::{BoxStream, StreamExt};
15use tracing::{debug, instrument};
16
/// Endpoint used when no URL is supplied: the `/api/generate` route of an
/// Ollama server on `localhost:11434`.
///
/// Note that this is the full generate endpoint, not just the server root.
const DEFAULT_OLLAMA_URL: &str = "http://localhost:11434/api/generate";
18
19/// Ollama provider for local code generation.
20#[derive(Debug, Clone)]
21pub struct OllamaProvider {
22    client: Client,
23    model: String,
24    base_url: String,
25}
26
27/// Ollama generate request.
28#[derive(Debug, Serialize)]
29struct GenerateRequest {
30    model: String,
31    prompt: String,
32    system: Option<String>,
33    stream: bool,
34    options: Option<GenerateOptions>,
35}
36
37#[derive(Debug, Serialize)]
38struct GenerateOptions {
39    #[serde(skip_serializing_if = "Option::is_none")]
40    temperature: Option<f32>,
41    #[serde(skip_serializing_if = "Option::is_none")]
42    num_predict: Option<u32>,
43}
44
45/// Ollama generate response.
46#[derive(Debug, Deserialize)]
47#[allow(dead_code)]
48struct GenerateResponse {
49    response: String,
50    done: bool,
51    #[serde(default)]
52    eval_count: Option<u32>,
53}
54
55impl OllamaProvider {
56    /// Create a new Ollama provider with the given model.
57    pub fn new(model: impl Into<String>) -> Self {
58        Self::with_options(model, DEFAULT_OLLAMA_URL)
59    }
60
61    /// Create a provider with a custom URL.
62    pub fn with_options(model: impl Into<String>, base_url: impl Into<String>) -> Self {
63        let client = Client::builder()
64            .timeout(std::time::Duration::from_secs(300)) // Local models can be slow
65            .build()
66            .expect("Failed to create HTTP client");
67
68        Self {
69            client,
70            model: model.into(),
71            base_url: base_url.into(),
72        }
73    }
74
75    /// Create from environment variables.
76    ///
77    /// Reads `OLLAMA_MODEL` and optionally `OLLAMA_URL`.
78    pub fn from_env() -> Self {
79        let model = std::env::var("OLLAMA_MODEL").unwrap_or_else(|_| "codellama".to_string());
80        let url = std::env::var("OLLAMA_URL").unwrap_or_else(|_| DEFAULT_OLLAMA_URL.to_string());
81        Self::with_options(model, url)
82    }
83
84    /// Build the system prompt for code generation.
85    fn build_system_prompt(&self, kind: &SlotKind, context: Option<&str>) -> String {
86        let base = "You are a code generation assistant. Generate only the requested code without explanations or markdown code blocks. Output raw code only.";
87
88        let kind_specific = match kind {
89            SlotKind::Html => "\nGenerate valid HTML5 markup.",
90            SlotKind::Css => "\nGenerate valid CSS styles.",
91            SlotKind::JavaScript => "\nGenerate valid JavaScript code.",
92            SlotKind::Function => "\nGenerate a complete function definition.",
93            SlotKind::Class => "\nGenerate a complete class/struct definition.",
94            SlotKind::Component => "\nGenerate a complete component with HTML, CSS, and JavaScript as needed.",
95            _ => "",
96        };
97
98        let context_part = context
99            .filter(|c| !c.is_empty())
100            .map(|c| format!("\n\nContext:\n{}", c))
101            .unwrap_or_default();
102
103        format!("{}{}{}", base, kind_specific, context_part)
104    }
105}
106
107#[async_trait]
108impl AiProvider for OllamaProvider {
109    fn name(&self) -> &str {
110        "ollama"
111    }
112
113    #[instrument(skip(self, request), fields(slot = %request.slot.name))]
114    async fn generate(&self, request: GenerationRequest) -> Result<GenerationResponse> {
115        debug!("Generating code with Ollama for slot: {}", request.slot.name);
116
117        let system = Some(request.system_prompt.unwrap_or_else(|| {
118            self.build_system_prompt(&request.slot.kind, request.context.as_deref())
119        }));
120
121        let temperature = request.slot.temperature.unwrap_or(0.7);
122        let api_request = GenerateRequest {
123            model: request.model.clone().unwrap_or_else(|| self.model.clone()),
124            prompt: request.slot.prompt.clone(),
125            system,
126            stream: false,
127            options: Some(GenerateOptions {
128                temperature: Some(temperature),
129                num_predict: Some(request.max_tokens.unwrap_or(2048)),
130            }),
131        };
132
133        let response = self
134            .client
135            .post(&self.base_url)
136            .json(&api_request)
137            .send()
138            .await
139            .map_err(|e| AetherError::NetworkError(e.to_string()))?;
140
141        if !response.status().is_success() {
142            let status = response.status();
143            let body = response.text().await.unwrap_or_default();
144            return Err(AetherError::ProviderError(format!(
145                "Ollama error {}: {}",
146                status, body
147            )));
148        }
149
150        let gen_response: GenerateResponse = response
151            .json()
152            .await
153            .map_err(|e| AetherError::ProviderError(e.to_string()))?;
154
155        let code = strip_code_blocks(&gen_response.response);
156
157        Ok(GenerationResponse {
158            code,
159            tokens_used: gen_response.eval_count,
160            metadata: None,
161        })
162    }
163
164    fn generate_stream(
165        &self,
166        request: GenerationRequest,
167    ) -> BoxStream<'static, Result<StreamResponse>> {
168        let client = self.client.clone();
169        let model = self.model.clone();
170        let base_url = self.base_url.clone();
171
172        let system = Some(request.system_prompt.unwrap_or_else(|| {
173            self.build_system_prompt(&request.slot.kind, request.context.as_deref())
174        }));
175
176        let temperature = request.slot.temperature.unwrap_or(0.7);
177        let api_request = GenerateRequest {
178            model: request.model.clone().unwrap_or_else(|| model.clone()),
179            prompt: request.slot.prompt.clone(),
180            system,
181            stream: true,
182            options: Some(GenerateOptions {
183                temperature: Some(temperature),
184                num_predict: Some(request.max_tokens.unwrap_or(2048)),
185            }),
186        };
187
188        let stream = async_stream::stream! {
189            let response = client
190                .post(&base_url)
191                .json(&api_request)
192                .send()
193                .await
194                .map_err(|e| aether_core::AetherError::NetworkError(e.to_string()));
195
196            let response = match response {
197                Ok(r) => r,
198                Err(e) => {
199                    yield Err(e);
200                    return;
201                }
202            };
203
204            if !response.status().is_success() {
205                let status = response.status();
206                let body = response.text().await.unwrap_or_default();
207                yield Err(aether_core::AetherError::ProviderError(format!(
208                    "Ollama error {}: {}",
209                    status, body
210                )));
211                return;
212            }
213
214            let mut stream = response.bytes_stream();
215            
216            while let Some(chunk_result) = stream.next().await {
217                let chunk = match chunk_result {
218                    Ok(c) => c,
219                    Err(e) => {
220                        yield Err(aether_core::AetherError::NetworkError(e.to_string()));
221                        break;
222                    }
223                };
224
225                let text = String::from_utf8_lossy(&chunk);
226                for line in text.lines() {
227                    let line = line.trim();
228                    if line.is_empty() { continue; }
229                    
230                    if let Ok(gen_resp) = serde_json::from_str::<GenerateResponse>(line) {
231                        yield Ok(StreamResponse {
232                            delta: gen_resp.response,
233                            metadata: None,
234                        });
235                        if gen_resp.done { break; }
236                    }
237                }
238            }
239        };
240
241        Box::pin(stream)
242    }
243
244    async fn health_check(&self) -> Result<bool> {
245        // Check if Ollama is running
246        let response = self
247            .client
248            .get("http://localhost:11434/api/tags")
249            .send()
250            .await
251            .map_err(|e| AetherError::NetworkError(e.to_string()))?;
252
253        Ok(response.status().is_success())
254    }
255}
256
/// Strip a surrounding markdown code fence from generated code.
///
/// The input is trimmed first. If it then starts and ends with ```` ``` ````
/// and spans more than one line, the opening fence line (including any
/// language tag) and the closing fence are removed and the lines in between
/// are returned joined with `\n`. Anything else is returned trimmed but
/// otherwise unchanged.
///
/// Only the closing fence itself is removed: code that shares a line with it
/// (e.g. ```` "```rust\nfn x() {}```" ````) is kept rather than discarded with
/// the fence.
fn strip_code_blocks(code: &str) -> String {
    let code = code.trim();

    // Everything after the opening fence line, with the closing fence removed.
    let body = match code
        .strip_prefix("```")
        .and_then(|rest| rest.strip_suffix("```"))
        .and_then(|inner| inner.split_once('\n'))
    {
        Some((_info, body)) => body,
        // Not fenced, or a single-line snippet: leave it alone.
        None => return code.to_string(),
    };

    // The last line is whatever preceded the closing fence on its own line.
    // Drop it when it is blank (the usual case); keep it when it holds code.
    let kept = match body.rsplit_once('\n') {
        Some((head, tail)) if tail.trim().is_empty() => head,
        Some(_) => body,
        None if body.trim().is_empty() => return String::new(),
        None => body,
    };

    kept.split('\n')
        .map(|line| line.strip_suffix('\r').unwrap_or(line))
        .collect::<Vec<_>>()
        .join("\n")
}
270
#[cfg(test)]
mod tests {
    use super::OllamaProvider;

    /// `new` stores the model name it was given.
    #[test]
    fn test_provider_creation() {
        let model_name = "codellama";
        let provider = OllamaProvider::new(model_name);
        assert_eq!(provider.model, model_name);
    }
}