// rusty_commit/providers/mlx.rs
//! MLX Provider - Apple's ML framework for Apple Silicon
//!
//! MLX is Apple's machine learning framework optimized for Apple Silicon.
//! This provider connects to an MLX HTTP server running locally.
//!
//! Setup:
//! 1. Install mlx-lm: `pip install mlx-lm`
//! 2. Start server: `python -m mlx_lm.server --model mlx-community/Llama-3.2-3B-Instruct-4bit`
//! 3. Configure rco: `rco config set RCO_AI_PROVIDER=mlx RCO_API_URL=http://localhost:8080`

use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};

use super::{build_prompt, AIProvider};
use crate::config::Config;
use crate::utils::retry::retry_async;
/// Provider that generates commit messages via a locally running MLX server.
pub struct MlxProvider {
    client: Client,  // reusable HTTP client for all requests
    api_url: String, // base URL of the MLX server, e.g. http://localhost:8080
    model: String,   // model identifier included in each request body
}
/// Request payload for the MLX server's OpenAI-compatible
/// `/v1/chat/completions` endpoint.
#[derive(Serialize)]
struct MlxRequest {
    model: String,
    messages: Vec<MlxMessage>,
    max_tokens: i32,
    temperature: f32,
    stream: bool, // always false here; this provider does not consume streamed chunks
}
/// One chat message in OpenAI-compatible format; used both for the request
/// (`messages`) and the response (`MlxChoice::message`).
#[derive(Serialize, Deserialize, Clone)]
struct MlxMessage {
    role: String,    // "system" or "user" on requests; assistant role on responses
    content: String, // message text
}
/// Top-level response from the MLX chat-completions endpoint.
/// Only the fields this provider reads are modeled.
#[derive(Deserialize)]
struct MlxResponse {
    choices: Vec<MlxChoice>,
}
/// A single completion candidate; the provider uses only the first one.
#[derive(Deserialize)]
struct MlxChoice {
    message: MlxMessage,
}
51impl MlxProvider {
52    pub fn new(config: &Config) -> Result<Self> {
53        let client = Client::new();
54        let api_url = config
55            .api_url
56            .as_deref()
57            .unwrap_or("http://localhost:8080")
58            .to_string();
59        let model = config.model.as_deref().unwrap_or("default").to_string();
60
61        Ok(Self {
62            client,
63            api_url,
64            model,
65        })
66    }
67
68    /// Create provider from account configuration
69    #[allow(dead_code)]
70    pub fn from_account(
71        account: &crate::config::accounts::AccountConfig,
72        _api_key: &str,
73        config: &Config,
74    ) -> Result<Self> {
75        let client = Client::new();
76        let api_url = account
77            .api_url
78            .as_deref()
79            .or(config.api_url.as_deref())
80            .unwrap_or("http://localhost:8080")
81            .to_string();
82        let model = account
83            .model
84            .as_deref()
85            .or(config.model.as_deref())
86            .unwrap_or("default")
87            .to_string();
88
89        Ok(Self {
90            client,
91            api_url,
92            model,
93        })
94    }
95}
96
97#[async_trait]
98impl AIProvider for MlxProvider {
99    async fn generate_commit_message(
100        &self,
101        diff: &str,
102        context: Option<&str>,
103        full_gitmoji: bool,
104        config: &Config,
105    ) -> Result<String> {
106        let prompt = build_prompt(diff, context, config, full_gitmoji);
107
108        // MLX uses OpenAI-compatible chat format
109        let messages = vec![
110            MlxMessage {
111                role: "system".to_string(),
112                content: "You are an expert at writing clear, concise git commit messages."
113                    .to_string(),
114            },
115            MlxMessage {
116                role: "user".to_string(),
117                content: prompt,
118            },
119        ];
120
121        let request = MlxRequest {
122            model: self.model.clone(),
123            messages,
124            max_tokens: config.tokens_max_output.unwrap_or(500) as i32,
125            temperature: 0.7,
126            stream: false,
127        };
128
129        let mlx_response: MlxResponse = retry_async(|| async {
130            let url = format!("{}/v1/chat/completions", self.api_url);
131            let response = self
132                .client
133                .post(&url)
134                .json(&request)
135                .send()
136                .await
137                .context("Failed to connect to MLX server")?;
138
139            if !response.status().is_success() {
140                let error_text = response.text().await?;
141                return Err(anyhow::anyhow!("MLX API error: {}", error_text));
142            }
143
144            let mlx_response: MlxResponse = response
145                .json()
146                .await
147                .context("Failed to parse MLX response")?;
148
149            Ok(mlx_response)
150        })
151        .await
152        .context("Failed to generate commit message from MLX after retries")?;
153
154        let message = mlx_response
155            .choices
156            .first()
157            .map(|choice| choice.message.content.trim().to_string())
158            .context("MLX returned an empty response")?;
159
160        Ok(message)
161    }
162}
163
/// ProviderBuilder for MLX; registered with the provider registry so the
/// provider can be constructed by name.
pub struct MlxProviderBuilder;
167impl super::registry::ProviderBuilder for MlxProviderBuilder {
168    fn name(&self) -> &'static str {
169        "mlx"
170    }
171
172    fn aliases(&self) -> Vec<&'static str> {
173        vec!["mlx-lm", "apple-mlx"]
174    }
175
176    fn category(&self) -> super::registry::ProviderCategory {
177        super::registry::ProviderCategory::Local
178    }
179
180    fn create(&self, config: &Config) -> Result<Box<dyn super::AIProvider>> {
181        Ok(Box::new(MlxProvider::new(config)?))
182    }
183
184    fn requires_api_key(&self) -> bool {
185        false
186    }
187
188    fn default_model(&self) -> Option<&'static str> {
189        Some("mlx-community/Llama-3.2-3B-Instruct-4bit")
190    }
191}