Skip to main content

rusty_commit/providers/
ollama.rs

1use anyhow::{Context, Result};
2use async_trait::async_trait;
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5
6use super::prompt::build_prompt;
7use super::AIProvider;
// Note: Ollama is sent the combined prompt in the single `prompt` field of `/api/generate`;
// this provider does not send a separate system message.
9use crate::config::Config;
10use crate::utils::retry::retry_async;
11
12pub struct OllamaProvider {
13    client: Client,
14    api_url: String,
15    model: String,
16}
17
18#[derive(Serialize)]
19struct OllamaRequest {
20    model: String,
21    prompt: String,
22    stream: bool,
23    options: OllamaOptions,
24}
25
26#[derive(Serialize)]
27struct OllamaOptions {
28    temperature: f32,
29    num_predict: i32,
30}
31
32#[derive(Deserialize)]
33struct OllamaResponse {
34    response: String,
35}
36
37impl OllamaProvider {
38    pub fn new(config: &Config) -> Result<Self> {
39        let client = Client::new();
40        let api_url = config
41            .api_url
42            .as_deref()
43            .unwrap_or("http://localhost:11434")
44            .to_string();
45        let model = config.model.as_deref().unwrap_or("mistral").to_string();
46
47        Ok(Self {
48            client,
49            api_url,
50            model,
51        })
52    }
53
54    /// Create provider from account configuration
55    #[allow(dead_code)]
56    pub fn from_account(
57        account: &crate::config::accounts::AccountConfig,
58        _api_key: &str,
59        config: &Config,
60    ) -> Result<Self> {
61        let client = Client::new();
62        let api_url = account
63            .api_url
64            .as_deref()
65            .or(config.api_url.as_deref())
66            .unwrap_or("http://localhost:11434")
67            .to_string();
68        let model = account
69            .model
70            .as_deref()
71            .or(config.model.as_deref())
72            .unwrap_or("mistral")
73            .to_string();
74
75        Ok(Self {
76            client,
77            api_url,
78            model,
79        })
80    }
81}
82
83#[async_trait]
84impl AIProvider for OllamaProvider {
85    async fn generate_commit_message(
86        &self,
87        diff: &str,
88        context: Option<&str>,
89        full_gitmoji: bool,
90        config: &Config,
91    ) -> Result<String> {
92        let prompt = build_prompt(diff, context, config, full_gitmoji);
93
94        let request = OllamaRequest {
95            model: self.model.clone(),
96            prompt,
97            stream: false,
98            options: OllamaOptions {
99                temperature: 0.7,
100                num_predict: config.tokens_max_output.unwrap_or(500) as i32,
101            },
102        };
103
104        let ollama_response: OllamaResponse = retry_async(|| async {
105            let url = format!("{}/api/generate", self.api_url);
106            let response = self
107                .client
108                .post(&url)
109                .json(&request)
110                .send()
111                .await
112                .context("Failed to connect to Ollama")?;
113
114            if !response.status().is_success() {
115                let error_text = response.text().await?;
116                return Err(anyhow::anyhow!("Ollama API error: {}", error_text));
117            }
118
119            let ollama_response: OllamaResponse = response
120                .json()
121                .await
122                .context("Failed to parse Ollama response")?;
123
124            Ok(ollama_response)
125        })
126        .await
127        .context("Failed to generate commit message from Ollama after retries")?;
128
129        Ok(ollama_response.response.trim().to_string())
130    }
131}
132
133/// ProviderBuilder for Ollama
134pub struct OllamaProviderBuilder;
135
136impl super::registry::ProviderBuilder for OllamaProviderBuilder {
137    fn name(&self) -> &'static str {
138        "ollama"
139    }
140
141    fn category(&self) -> super::registry::ProviderCategory {
142        super::registry::ProviderCategory::Local
143    }
144
145    fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
146        Ok(Box::new(OllamaProvider::new(config)?))
147    }
148
149    fn requires_api_key(&self) -> bool {
150        false
151    }
152
153    fn default_model(&self) -> Option<&'static str> {
154        Some("llama3.1")
155    }
156}