Skip to main content

rusty_commit/providers/
ollama.rs

1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5
6use super::{build_prompt, AIProvider};
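// Note: Ollama is sent the single combined prompt produced by `build_prompt` in the `prompt` field of `/api/generate`; that endpoint also accepts a separate `system` field, which this provider does not use.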
8use crate::config::Config;
9use crate::utils::retry::retry_async;
10
11pub struct OllamaProvider {
12    client: Client,
13    api_url: String,
14    model: String,
15}
16
17#[derive(Serialize)]
18struct OllamaRequest {
19    model: String,
20    prompt: String,
21    stream: bool,
22    options: OllamaOptions,
23}
24
25#[derive(Serialize)]
26struct OllamaOptions {
27    temperature: f32,
28    num_predict: i32,
29}
30
31#[derive(Deserialize)]
32struct OllamaResponse {
33    response: String,
34}
35
36impl OllamaProvider {
37    pub fn new(config: &Config) -> Result<Self> {
38        let client = Client::new();
39        let api_url = config
40            .api_url
41            .as_deref()
42            .unwrap_or("http://localhost:11434")
43            .to_string();
44        let model = config.model.as_deref().unwrap_or("mistral").to_string();
45
46        Ok(Self {
47            client,
48            api_url,
49            model,
50        })
51    }
52
53    /// Create provider from account configuration
54    #[allow(dead_code)]
55    pub fn from_account(
56        account: &crate::config::accounts::AccountConfig,
57        _api_key: &str,
58        config: &Config,
59    ) -> Result<Self> {
60        let client = Client::new();
61        let api_url = account
62            .api_url
63            .as_deref()
64            .or(config.api_url.as_deref())
65            .unwrap_or("http://localhost:11434")
66            .to_string();
67        let model = account
68            .model
69            .as_deref()
70            .or(config.model.as_deref())
71            .unwrap_or("mistral")
72            .to_string();
73
74        Ok(Self {
75            client,
76            api_url,
77            model,
78        })
79    }
80}
81
82#[async_trait]
83impl AIProvider for OllamaProvider {
84    async fn generate_commit_message(
85        &self,
86        diff: &str,
87        context: Option<&str>,
88        full_gitmoji: bool,
89        config: &Config,
90    ) -> Result<String> {
91        let prompt = build_prompt(diff, context, config, full_gitmoji);
92
93        let request = OllamaRequest {
94            model: self.model.clone(),
95            prompt,
96            stream: false,
97            options: OllamaOptions {
98                temperature: 0.7,
99                num_predict: config.tokens_max_output.unwrap_or(500) as i32,
100            },
101        };
102
103        let ollama_response: OllamaResponse = retry_async(|| async {
104            let url = format!("{}/api/generate", self.api_url);
105            let response = self
106                .client
107                .post(&url)
108                .json(&request)
109                .send()
110                .await
111                .context("Failed to connect to Ollama")?;
112
113            if !response.status().is_success() {
114                let error_text = response.text().await?;
115                return Err(anyhow::anyhow!("Ollama API error: {}", error_text));
116            }
117
118            let ollama_response: OllamaResponse = response
119                .json()
120                .await
121                .context("Failed to parse Ollama response")?;
122
123            Ok(ollama_response)
124        })
125        .await
126        .context("Failed to generate commit message from Ollama after retries")?;
127
128        Ok(ollama_response.response.trim().to_string())
129    }
130}
131
132/// ProviderBuilder for Ollama
133pub struct OllamaProviderBuilder;
134
135impl super::registry::ProviderBuilder for OllamaProviderBuilder {
136    fn name(&self) -> &'static str {
137        "ollama"
138    }
139
140    fn category(&self) -> super::registry::ProviderCategory {
141        super::registry::ProviderCategory::Local
142    }
143
144    fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
145        Ok(Box::new(OllamaProvider::new(config)?))
146    }
147
148    fn requires_api_key(&self) -> bool {
149        false
150    }
151
152    fn default_model(&self) -> Option<&'static str> {
153        Some("llama3.1")
154    }
155}