// rusty_commit/providers/ollama.rs
use anyhow::{Context, Result};
2use async_trait::async_trait;
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5
6use super::prompt::build_prompt;
7use super::AIProvider;
8use crate::config::Config;
10use crate::utils::retry::retry_async;
11
/// AI provider backed by a locally running Ollama server.
pub struct OllamaProvider {
    /// Reusable HTTP client shared across requests.
    client: Client,
    /// Base URL of the Ollama server (defaults to `http://localhost:11434`).
    api_url: String,
    /// Model name sent with each generate request (defaults to `mistral`).
    model: String,
}
17
/// JSON request body for Ollama's `POST /api/generate` endpoint.
#[derive(Serialize)]
struct OllamaRequest {
    model: String,
    prompt: String,
    /// Set to `false` so the server returns one complete response
    /// instead of a token stream.
    stream: bool,
    options: OllamaOptions,
}
25
/// Generation options forwarded to Ollama under the `options` key.
#[derive(Serialize)]
struct OllamaOptions {
    /// Sampling temperature.
    temperature: f32,
    /// Maximum number of tokens to generate.
    num_predict: i32,
}
31
/// The subset of Ollama's generate response this module consumes.
#[derive(Deserialize)]
struct OllamaResponse {
    /// The generated completion text.
    response: String,
}
36
37impl OllamaProvider {
38 pub fn new(config: &Config) -> Result<Self> {
39 let client = Client::new();
40 let api_url = config
41 .api_url
42 .as_deref()
43 .unwrap_or("http://localhost:11434")
44 .to_string();
45 let model = config.model.as_deref().unwrap_or("mistral").to_string();
46
47 Ok(Self {
48 client,
49 api_url,
50 model,
51 })
52 }
53
54 #[allow(dead_code)]
56 pub fn from_account(
57 account: &crate::config::accounts::AccountConfig,
58 _api_key: &str,
59 config: &Config,
60 ) -> Result<Self> {
61 let client = Client::new();
62 let api_url = account
63 .api_url
64 .as_deref()
65 .or(config.api_url.as_deref())
66 .unwrap_or("http://localhost:11434")
67 .to_string();
68 let model = account
69 .model
70 .as_deref()
71 .or(config.model.as_deref())
72 .unwrap_or("mistral")
73 .to_string();
74
75 Ok(Self {
76 client,
77 api_url,
78 model,
79 })
80 }
81}
82
83#[async_trait]
84impl AIProvider for OllamaProvider {
85 async fn generate_commit_message(
86 &self,
87 diff: &str,
88 context: Option<&str>,
89 full_gitmoji: bool,
90 config: &Config,
91 ) -> Result<String> {
92 let prompt = build_prompt(diff, context, config, full_gitmoji);
93
94 let request = OllamaRequest {
95 model: self.model.clone(),
96 prompt,
97 stream: false,
98 options: OllamaOptions {
99 temperature: 0.7,
100 num_predict: config.tokens_max_output.unwrap_or(500) as i32,
101 },
102 };
103
104 let ollama_response: OllamaResponse = retry_async(|| async {
105 let url = format!("{}/api/generate", self.api_url);
106 let response = self
107 .client
108 .post(&url)
109 .json(&request)
110 .send()
111 .await
112 .context("Failed to connect to Ollama")?;
113
114 if !response.status().is_success() {
115 let error_text = response.text().await?;
116 return Err(anyhow::anyhow!("Ollama API error: {}", error_text));
117 }
118
119 let ollama_response: OllamaResponse = response
120 .json()
121 .await
122 .context("Failed to parse Ollama response")?;
123
124 Ok(ollama_response)
125 })
126 .await
127 .context("Failed to generate commit message from Ollama after retries")?;
128
129 Ok(ollama_response.response.trim().to_string())
130 }
131}
132
133pub struct OllamaProviderBuilder;
135
136impl super::registry::ProviderBuilder for OllamaProviderBuilder {
137 fn name(&self) -> &'static str {
138 "ollama"
139 }
140
141 fn category(&self) -> super::registry::ProviderCategory {
142 super::registry::ProviderCategory::Local
143 }
144
145 fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
146 Ok(Box::new(OllamaProvider::new(config)?))
147 }
148
149 fn requires_api_key(&self) -> bool {
150 false
151 }
152
153 fn default_model(&self) -> Option<&'static str> {
154 Some("llama3.1")
155 }
156}