git_commit_sage/
ai.rs

1use serde::{Deserialize, Serialize};
2use crate::{Error, Result, AiConfig, is_conventional_commit};
3use reqwest::StatusCode;
4use std::{time::Duration};
5
/// Together AI chat-completions endpoint every request is posted to.
const API_URL: &str = "https://api.together.xyz/v1/chat/completions";
/// Number of attempts `generate_commit_message` makes before giving up.
const MAX_RETRIES: u32 = 3;
/// Delay before the second attempt, in milliseconds; doubled for each later attempt.
const INITIAL_RETRY_DELAY_MS: u64 = 1_000;
9
/// Heuristic summary of a diff, used to steer the model towards a commit type.
#[derive(Debug)]
struct CommitContext {
    /// Human-readable category label chosen by `from_diff`
    /// (e.g. "documentation change", "standard change").
    commit_type: String,
    /// File-type labels (extensions) collected from the diff's file headers.
    file_types: Vec<String>,
    /// Paths of files the diff marks as newly created.
    new_files: Vec<String>,
    /// Paths of files recorded as modified.
    modified_files: Vec<String>,
    /// Number of added lines counted in the diff.
    total_additions: usize,
    /// Number of removed lines counted in the diff.
    total_deletions: usize,
}
19
20impl CommitContext {
21    fn from_diff(diff: &str) -> Self {
22        let mut context = CommitContext {
23            commit_type: String::new(),
24            file_types: Vec::new(),
25            new_files: Vec::new(),
26            modified_files: Vec::new(),
27            total_additions: 0,
28            total_deletions: 0,
29        };
30
31        let mut current_file = String::new();
32        for line in diff.lines() {
33            if line.starts_with("diff --git") {
34                current_file = line.split(' ').last().unwrap_or("").trim_start_matches('b').to_string();
35                if let Some(ext) = current_file.split('.').last() {
36                    context.file_types.push(ext.to_string());
37                }
38            } else if line.starts_with("new file") {
39                context.new_files.push(current_file.clone());
40            } else if line.starts_with("modified") {
41                context.modified_files.push(current_file.clone());
42            } else if line.starts_with('+') && !line.starts_with("+++") {
43                context.total_additions += 1;
44            } else if line.starts_with('-') && !line.starts_with("---") {
45                context.total_deletions += 1;
46            }
47        }
48
49        // Determine commit type based on context
50        context.commit_type = if context.new_files.iter().any(|f| f.contains("Cargo.toml")) 
51            && context.new_files.len() > 5 {
52            "initial project setup".to_string()
53        } else if context.file_types.iter().any(|t| t == "md" || t == "txt") 
54            && context.file_types.len() == 1 {
55            "documentation change".to_string()
56        } else if context.new_files.iter().any(|f| f.contains("test") || f.contains("spec")) {
57            "test addition".to_string()
58        } else if context.total_additions > 100 || context.new_files.len() > 5 {
59            "large feature implementation".to_string()
60        } else if context.total_deletions > context.total_additions * 2 {
61            "major refactoring".to_string()
62        } else {
63            "standard change".to_string()
64        };
65
66        context
67    }
68
69    fn get_suggested_type(&self) -> &'static str {
70        match self.commit_type.as_str() {
71            "initial project setup" => "feat",
72            "documentation change" => "docs",
73            "test addition" => "test",
74            "large feature implementation" => "feat",
75            "major refactoring" => "refactor",
76            _ => "feat"
77        }
78    }
79
80    fn to_prompt_context(&self) -> String {
81        format!(
82            "{} (suggested type: {}) with {} new files and {} modified files. \
83            Changes include {} additions and {} deletions across file types: {}",
84            self.commit_type,
85            self.get_suggested_type(),
86            self.new_files.len(),
87            self.modified_files.len(),
88            self.total_additions,
89            self.total_deletions,
90            self.file_types.join(", ")
91        )
92    }
93}
94
95#[derive(Debug, Serialize, Clone)]
96struct ChatMessage {
97    role: String,
98    content: String,
99}
100
101#[derive(Debug, Serialize, Clone)]
102struct ChatRequest {
103    model: String,
104    messages: Vec<ChatMessage>,
105    temperature: f32,
106    max_tokens: u32,
107    stop: Vec<String>,
108}
109
110#[derive(Debug, Deserialize)]
111struct ChatResponse {
112    choices: Vec<ChatChoice>,
113}
114
115#[derive(Debug, Deserialize)]
116struct ChatChoice {
117    message: ChatResponseMessage,
118}
119
120#[derive(Debug, Deserialize)]
121struct ChatResponseMessage {
122    content: String,
123}
124
125pub struct AiClient {
126    client: reqwest::Client,
127    api_key: String,
128    config: AiConfig,
129}
130
131impl AiClient {
132    pub fn new(api_key: String, config: AiConfig) -> Self {
133        Self {
134            client: reqwest::Client::new(),
135            api_key,
136            config,
137        }
138    }
139
140    pub async fn generate_commit_message(&self, diff: &str) -> Result<String> {
141        let context = CommitContext::from_diff(diff);
142        
143        let request = ChatRequest {
144            model: self.config.model.clone(),
145            messages: vec![
146                ChatMessage {
147                    role: "system".to_string(),
148                    content: self.config.system_prompt.clone(),
149                },
150                ChatMessage {
151                    role: "user".to_string(),
152                    content: self.config.user_prompt_template
153                        .replace("{}", &context.to_prompt_context())
154                        .replace("{}", diff),
155                },
156            ],
157            temperature: self.config.temperature,
158            max_tokens: self.config.max_tokens,
159            stop: self.config.stop_sequences.clone(),
160        };
161
162        let mut last_error = None;
163        for retry in 0..MAX_RETRIES {
164            if retry > 0 {
165                tokio::time::sleep(Duration::from_millis(
166                    INITIAL_RETRY_DELAY_MS * (2_u64.pow(retry - 1))
167                )).await;
168            }
169
170            match self.try_generate_message(&request).await {
171                Ok(message) => {
172                    // Pre-validate the message
173                    if !is_conventional_commit(&message) {
174                        continue; // Try again if format is invalid
175                    }
176                    // Validate the type matches the context
177                    let msg_type = message.split(':').next().unwrap_or("")
178                        .split('(').next().unwrap_or("");
179                    if msg_type == context.get_suggested_type() {
180                        return Ok(message);
181                    }
182                    // If we get here, the message is valid but doesn't match context
183                    // Try again with a lower temperature
184                    if retry < MAX_RETRIES - 1 {
185                        let mut new_request = request.clone();
186                        new_request.temperature *= 0.8;
187                        if let Ok(new_message) = self.try_generate_message(&new_request).await {
188                            if is_conventional_commit(&new_message) {
189                                return Ok(new_message);
190                            }
191                        }
192                    }
193                    return Ok(message); // Use the original message if retries fail
194                },
195                Err(e) => {
196                    if let Error::Request(ref req_err) = e {
197                        if let Some(status) = req_err.status() {
198                            if status == StatusCode::SERVICE_UNAVAILABLE 
199                               || status == StatusCode::TOO_MANY_REQUESTS {
200                                last_error = Some(e);
201                                continue;
202                            }
203                        }
204                    }
205                    return Err(e);
206                }
207            }
208        }
209
210        Err(last_error.unwrap_or_else(|| Error::CommitMessageGeneration(
211            "Maximum retries exceeded".to_string()
212        )))
213    }
214
215    async fn try_generate_message(&self, request: &ChatRequest) -> Result<String> {
216        let response = self
217            .client
218            .post(API_URL)
219            .header("Authorization", format!("Bearer {}", self.api_key))
220            .json(request)
221            .send()
222            .await?
223            .error_for_status()?
224            .json::<ChatResponse>()
225            .await?;
226
227        response
228            .choices
229            .first()
230            .map(|choice| choice.message.content.trim().to_string())
231            .ok_or_else(|| Error::CommitMessageGeneration("No response from API".to_string()))
232    }
233}