1use serde::{Deserialize, Serialize};
2use crate::{Error, Result, AiConfig, is_conventional_commit};
3use reqwest::StatusCode;
4use std::{time::Duration};
5
/// Together AI chat-completions endpoint every request is POSTed to.
const API_URL: &str = "https://api.together.xyz/v1/chat/completions";

/// Total number of attempts (the first try included) before giving up.
const MAX_RETRIES: u32 = 3;

/// Delay before the first retry, in milliseconds; doubled for every later retry.
const INITIAL_RETRY_DELAY_MS: u64 = 1_000;
9
/// Heuristic summary of a unified `git diff`, used to give the model a hint
/// about what kind of commit it is describing.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
struct CommitContext {
    /// Human-readable classification such as "documentation change"; mapped to
    /// a conventional-commit type by `get_suggested_type`.
    commit_type: String,
    /// Distinct file extensions touched by the diff, in first-seen order.
    file_types: Vec<String>,
    /// Paths of files the diff creates (sections with a `new file` header).
    new_files: Vec<String>,
    /// Paths of files the diff neither creates nor deletes.
    modified_files: Vec<String>,
    /// Number of added (`+`) lines.
    total_additions: usize,
    /// Number of removed (`-`) lines.
    total_deletions: usize,
}

impl CommitContext {
    /// Scans a unified diff (as produced by `git diff`) and summarizes it.
    ///
    /// Each `diff --git` header starts a new file section. Header lines of a
    /// section (`new file mode`, `deleted file mode`, `--- a/...`, `+++ b/...`)
    /// are only interpreted before the section's first `@@` hunk marker; inside
    /// hunks every `+`/`-` line is counted, so added or removed content that
    /// itself begins with `++`/`--` (e.g. SQL or Lua comments) is not lost.
    /// Input without hunk markers falls back to counting every `+`/`-` line
    /// that is not a `+++`/`---` header.
    fn from_diff(diff: &str) -> Self {
        // How a file section of the diff affects its file.
        enum Status {
            New,
            Deleted,
            Modified,
        }

        let mut context = Self::default();
        let mut files: Vec<(String, Status)> = Vec::new();
        // Stays true only while every file seen so far is documentation.
        let mut only_docs = true;
        let mut in_hunk = false;

        for line in diff.lines() {
            if let Some(paths) = line.strip_prefix("diff --git ") {
                let path = Self::path_from_header(paths);
                let ext = Self::extension_of(&path);
                only_docs &= matches!(ext, Some("md" | "txt"));
                if let Some(ext) = ext {
                    if !context.file_types.iter().any(|known| known == ext) {
                        context.file_types.push(ext.to_string());
                    }
                }
                // Git has no "modified" header: a section is a modification
                // unless a `new file` / `deleted file` header says otherwise.
                files.push((path, Status::Modified));
                in_hunk = false;
            } else if line.starts_with("@@") {
                in_hunk = true;
            } else if !in_hunk && line.starts_with("new file") {
                if let Some((_, status)) = files.last_mut() {
                    *status = Status::New;
                }
            } else if !in_hunk && line.starts_with("deleted file") {
                if let Some((_, status)) = files.last_mut() {
                    *status = Status::Deleted;
                }
            } else if line.starts_with('+') && (in_hunk || !line.starts_with("+++")) {
                context.total_additions += 1;
            } else if line.starts_with('-') && (in_hunk || !line.starts_with("---")) {
                context.total_deletions += 1;
            }
        }

        let saw_files = !files.is_empty();
        for (path, status) in files {
            match status {
                Status::New => context.new_files.push(path),
                Status::Modified => context.modified_files.push(path),
                // Deleted files still contribute their deletions and file type.
                Status::Deleted => {}
            }
        }

        context.commit_type = context.classify(saw_files && only_docs).to_string();
        context
    }

    /// Picks the human-readable change category. `only_docs` is true when at
    /// least one file was touched and every touched file has an `md` or `txt`
    /// extension. Rules are checked in order; the first match wins.
    fn classify(&self, only_docs: bool) -> &'static str {
        if self.new_files.iter().any(|f| f.contains("Cargo.toml")) && self.new_files.len() > 5 {
            "initial project setup"
        } else if only_docs {
            "documentation change"
        } else if self.new_files.iter().any(|f| f.contains("test") || f.contains("spec")) {
            "test addition"
        } else if self.total_additions > 100 || self.new_files.len() > 5 {
            "large feature implementation"
        } else if self.total_deletions > self.total_additions * 2 {
            "major refactoring"
        } else {
            "standard change"
        }
    }

    /// Extracts the post-image path from the text following `diff --git `,
    /// e.g. `src/main.rs` from `a/src/main.rs b/src/main.rs`. Falls back to the
    /// last space-separated token when there is no ` b/` prefix (as with
    /// `git diff --no-prefix`).
    fn path_from_header(paths: &str) -> String {
        match paths.rsplit_once(" b/") {
            Some((_, path)) => path.to_string(),
            None => paths.rsplit(' ').next().unwrap_or(paths).to_string(),
        }
    }

    /// Returns the extension of `path` without the dot, or `None` for files
    /// such as `Makefile` that have none.
    fn extension_of(path: &str) -> Option<&str> {
        std::path::Path::new(path).extension().and_then(|ext| ext.to_str())
    }

    /// Maps `commit_type` to the conventional-commit type the model is nudged towards.
    fn get_suggested_type(&self) -> &'static str {
        match self.commit_type.as_str() {
            "initial project setup" => "feat",
            "documentation change" => "docs",
            "test addition" => "test",
            "large feature implementation" => "feat",
            "major refactoring" => "refactor",
            _ => "feat"
        }
    }

    /// Renders the summary as one sentence for the user prompt.
    fn to_prompt_context(&self) -> String {
        format!(
            "{} (suggested type: {}) with {} new files and {} modified files. \
            Changes include {} additions and {} deletions across file types: {}",
            self.commit_type,
            self.get_suggested_type(),
            self.new_files.len(),
            self.modified_files.len(),
            self.total_additions,
            self.total_deletions,
            self.file_types.join(", ")
        )
    }
}
94
/// One message of the chat transcript sent to the completions endpoint.
#[derive(Debug, Serialize, Clone)]
struct ChatMessage {
    /// Speaker role; this file sends "system" and "user" messages.
    role: String,
    /// Message text.
    content: String,
}
100
/// JSON body POSTed to `API_URL`; field names are serialized unchanged
/// (no serde renames), so they must match the API's parameter names.
/// `Clone` lets the retry path derive a variant with a lower temperature.
#[derive(Debug, Serialize, Clone)]
struct ChatRequest {
    /// Model identifier, taken from the client configuration.
    model: String,
    /// Conversation to complete: a system prompt followed by the user prompt.
    messages: Vec<ChatMessage>,
    /// Sampling temperature.
    temperature: f32,
    /// Upper bound on generated tokens.
    max_tokens: u32,
    /// Stop sequences, taken from the client configuration.
    stop: Vec<String>,
}
109
/// The subset of the completions response this client reads; serde ignores
/// any other fields because `deny_unknown_fields` is not set.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    /// Completion candidates; only the first one is used.
    choices: Vec<ChatChoice>,
}
114
/// A single completion candidate.
#[derive(Debug, Deserialize)]
struct ChatChoice {
    /// The generated assistant message.
    message: ChatResponseMessage,
}
119
/// Assistant message inside a choice.
#[derive(Debug, Deserialize)]
struct ChatResponseMessage {
    /// Generated text; trimmed before use.
    // NOTE(review): a `null` content in the response would fail to deserialize
    // into `String` — confirm the API always returns a string here.
    content: String,
}
124
125pub struct AiClient {
126 client: reqwest::Client,
127 api_key: String,
128 config: AiConfig,
129}
130
131impl AiClient {
132 pub fn new(api_key: String, config: AiConfig) -> Self {
133 Self {
134 client: reqwest::Client::new(),
135 api_key,
136 config,
137 }
138 }
139
140 pub async fn generate_commit_message(&self, diff: &str) -> Result<String> {
141 let context = CommitContext::from_diff(diff);
142
143 let request = ChatRequest {
144 model: self.config.model.clone(),
145 messages: vec![
146 ChatMessage {
147 role: "system".to_string(),
148 content: self.config.system_prompt.clone(),
149 },
150 ChatMessage {
151 role: "user".to_string(),
152 content: self.config.user_prompt_template
153 .replace("{}", &context.to_prompt_context())
154 .replace("{}", diff),
155 },
156 ],
157 temperature: self.config.temperature,
158 max_tokens: self.config.max_tokens,
159 stop: self.config.stop_sequences.clone(),
160 };
161
162 let mut last_error = None;
163 for retry in 0..MAX_RETRIES {
164 if retry > 0 {
165 tokio::time::sleep(Duration::from_millis(
166 INITIAL_RETRY_DELAY_MS * (2_u64.pow(retry - 1))
167 )).await;
168 }
169
170 match self.try_generate_message(&request).await {
171 Ok(message) => {
172 if !is_conventional_commit(&message) {
174 continue; }
176 let msg_type = message.split(':').next().unwrap_or("")
178 .split('(').next().unwrap_or("");
179 if msg_type == context.get_suggested_type() {
180 return Ok(message);
181 }
182 if retry < MAX_RETRIES - 1 {
185 let mut new_request = request.clone();
186 new_request.temperature *= 0.8;
187 if let Ok(new_message) = self.try_generate_message(&new_request).await {
188 if is_conventional_commit(&new_message) {
189 return Ok(new_message);
190 }
191 }
192 }
193 return Ok(message); },
195 Err(e) => {
196 if let Error::Request(ref req_err) = e {
197 if let Some(status) = req_err.status() {
198 if status == StatusCode::SERVICE_UNAVAILABLE
199 || status == StatusCode::TOO_MANY_REQUESTS {
200 last_error = Some(e);
201 continue;
202 }
203 }
204 }
205 return Err(e);
206 }
207 }
208 }
209
210 Err(last_error.unwrap_or_else(|| Error::CommitMessageGeneration(
211 "Maximum retries exceeded".to_string()
212 )))
213 }
214
215 async fn try_generate_message(&self, request: &ChatRequest) -> Result<String> {
216 let response = self
217 .client
218 .post(API_URL)
219 .header("Authorization", format!("Bearer {}", self.api_key))
220 .json(request)
221 .send()
222 .await?
223 .error_for_status()?
224 .json::<ChatResponse>()
225 .await?;
226
227 response
228 .choices
229 .first()
230 .map(|choice| choice.message.content.trim().to_string())
231 .ok_or_else(|| Error::CommitMessageGeneration("No response from API".to_string()))
232 }
233}