1use anyhow::{bail, Context, Result};
20use clap::Args;
21use crossterm::{
22 cursor,
23 event::{self, Event, KeyCode},
24 execute, style,
25 terminal::{self, ClearType},
26};
27use regex::Regex;
28use std::io::{self, IsTerminal, Write};
29use std::process::Command;
30
31#[derive(Args, Debug)]
33pub struct CommitArgs {
34 #[arg(short = 'n', long, default_value = "3")]
36 count: u8,
37
38 #[arg(short = 'm', long, default_value = "2000")]
40 max_diff: usize,
41
42 #[arg(short, long)]
44 scope: Option<String>,
45
46 #[arg(long)]
48 yes: bool,
49
50 #[arg(long)]
52 dry_run: bool,
53
54 #[arg(long)]
56 provider: Option<String>,
57
58 #[arg(long)]
60 model: Option<String>,
61
62 #[arg(long, default_value = "30")]
64 timeout: u64,
65
66 #[arg(long)]
68 verbose: bool,
69}
70
71fn git(args: &[&str]) -> Result<String> {
76 let output = Command::new("git")
77 .args(args)
78 .output()
79 .context("Failed to run git")?;
80 if output.status.success() {
81 Ok(String::from_utf8_lossy(&output.stdout).to_string())
82 } else {
83 let stderr = String::from_utf8_lossy(&output.stderr);
84 bail!("git {}: {stderr}", args.join(" "));
85 }
86}
87
88fn get_staged_diff() -> Result<String> {
89 let stat = git(&["diff", "--staged", "--stat"])?;
90 if stat.trim().is_empty() {
91 bail!(
92 "No staged changes found.\n\
93 Stage files first with: git add <files>\n\
94 Then run: resq commit"
95 );
96 }
97 git(&["diff", "--staged"])
98}
99
100fn check_unstaged_warning() {
101 if let Ok(diff) = git(&["diff", "--stat"]) {
102 if !diff.trim().is_empty() {
103 eprintln!(
104 "\x1b[33mWarning:\x1b[0m You have unstaged changes. \
105 Only staged changes will be included in the commit."
106 );
107 }
108 }
109}
110
111fn get_recent_commits() -> String {
112 git(&["log", "--oneline", "-10"]).unwrap_or_default()
113}
114
115const CC_PATTERN: &str =
121 r"^(feat|fix|docs|style|refactor|perf|test|build|ci|chore|revert)(\(.+\))?(!)?: .+$";
122
123fn validate_conventional_commit(msg: &str) -> bool {
124 Regex::new(CC_PATTERN)
125 .map(|re| re.is_match(msg.lines().next().unwrap_or("")))
126 .unwrap_or(false)
127}
128
/// Builds the `(system, user)` prompt pair sent to the model.
///
/// The system prompt contains:
/// - the Conventional Commits format and the allowed types;
/// - an optional scope hint;
/// - the style rules;
/// - the recent commit history for style matching;
/// - a request for exactly `count` candidates as a JSON array of strings.
///
/// The user prompt carries `diff`.
///
/// `recent_commits` is trimmed. When it is empty, a placeholder is shown so the model is not
/// pointed at an empty "recent commits" section.
fn build_prompt(
    diff: &str,
    recent_commits: &str,
    scope: Option<&str>,
    count: u8,
) -> (String, String) {
    // Contributes a blank line plus the hint only when a scope was requested. Without a scope,
    // nothing is inserted, so no stray empty line appears.
    let scope_hint = scope
        .map(|s| format!("\nSuggested scope: {s}\n"))
        .unwrap_or_default();
    // `git log` output ends with a newline; trimming avoids an extra blank line in the prompt.
    let recent_commits = match recent_commits.trim() {
        "" => "(no previous commits)",
        history => history,
    };

    let system = format!(
        "You are a commit message generator for a project that uses Conventional Commits.\n\
         \n\
         Format: <type>(<scope>): <description>\n\
         \n\
         Allowed types: feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert\n\
         {scope_hint}\
         \n\
         Rules:\n\
         - Subject line must be under 72 characters\n\
         - Use imperative mood (\"add\" not \"added\")\n\
         - Be specific about what changed and why\n\
         - Match the style of recent commits shown below\n\
         \n\
         Recent commits for style reference:\n\
         {recent_commits}\n\
         \n\
         Generate exactly {count} commit message candidates. \
         Return them as a JSON array of strings. No markdown fences."
    );

    let user = format!("Staged diff:\n\n{diff}");
    (system, user)
}
163
/// Removes a surrounding Markdown code fence from `text` and returns the trimmed inner content.
///
/// The opening fence (three backticks) may be followed on the same line by a language tag
/// such as `json` or `JSON`, which is dropped. The closing fence is optional. Text that does
/// not start with a fence is returned trimmed but otherwise unchanged.
fn strip_code_fences(text: &str) -> &str {
    let trimmed = text.trim();
    let Some(after_open) = trimmed.strip_prefix("```") else {
        return trimmed;
    };

    let body = match after_open.split_once('\n') {
        // The rest of the opening line is only an info string (possibly empty): skip that line.
        Some((info, rest)) if info.trim().chars().all(|c| c.is_ascii_alphanumeric()) => rest,
        // Content starts on the fence line itself, e.g. "```json[...]```" or "```[...]```".
        _ => after_open.strip_prefix("json").unwrap_or(after_open),
    };

    let body = body.trim();
    body.strip_suffix("```").map_or(body, str::trim)
}
178
179fn parse_candidates(response: &str) -> Result<Vec<String>> {
180 let cleaned = strip_code_fences(response);
181 let candidates: Vec<String> =
182 serde_json::from_str(cleaned).context("Failed to parse LLM response as JSON array")?;
183 Ok(candidates)
184}
185
186struct RawModeGuard;
192
193impl RawModeGuard {
194 fn enable() -> Result<Self> {
195 terminal::enable_raw_mode()?;
196 Ok(Self)
197 }
198}
199
200impl Drop for RawModeGuard {
201 fn drop(&mut self) {
202 let _ = terminal::disable_raw_mode();
203 }
204}
205
206fn select_candidate(candidates: &[String]) -> Result<Option<usize>> {
207 if !io::stdout().is_terminal() {
208 bail!("Interactive selection requires a TTY. Use --yes or --dry-run in non-interactive contexts.");
209 }
210
211 let _guard = RawModeGuard::enable()?;
212 let mut stdout = io::stdout();
213 let mut selected: usize = 0;
214 let total = candidates.len();
215
216 render_selector(&mut stdout, candidates, selected)?;
218
219 loop {
220 if let Event::Key(key) = event::read()? {
221 match key.code {
222 KeyCode::Up | KeyCode::Char('k') if selected > 0 => selected -= 1,
223 KeyCode::Down | KeyCode::Char('j') if selected < total - 1 => selected += 1,
224 KeyCode::Enter => {
225 clear_selector(&mut stdout, total)?;
226 return Ok(Some(selected));
227 }
228 KeyCode::Esc | KeyCode::Char('q') => {
229 clear_selector(&mut stdout, total)?;
230 return Ok(None);
231 }
232 _ => continue,
233 }
234 execute!(stdout, cursor::MoveUp((total + 2) as u16))?;
236 render_selector(&mut stdout, candidates, selected)?;
237 }
238 }
239}
240
241fn render_selector(stdout: &mut io::Stdout, candidates: &[String], selected: usize) -> Result<()> {
242 for (i, candidate) in candidates.iter().enumerate() {
243 if i == selected {
244 execute!(stdout, style::SetForegroundColor(style::Color::Cyan))?;
245 write!(stdout, "\r > {}. {}\n", i + 1, candidate)?;
246 execute!(stdout, style::ResetColor)?;
247 } else {
248 write!(stdout, "\r {}. {}\n", i + 1, candidate)?;
249 }
250 }
251 write!(
252 stdout,
253 "\r\n \x1b[2m[↑/↓/j/k] Navigate [Enter] Select [Esc/q] Cancel\x1b[0m"
254 )?;
255 stdout.flush()?;
256 Ok(())
257}
258
259fn clear_selector(stdout: &mut io::Stdout, total: usize) -> Result<()> {
260 execute!(
261 stdout,
262 cursor::MoveUp((total + 2) as u16),
263 terminal::Clear(ClearType::FromCursorDown),
264 )?;
265 Ok(())
266}
267
268pub async fn run(args: CommitArgs) -> Result<()> {
274 let diff = get_staged_diff()?;
276 check_unstaged_warning();
277
278 let truncated = resq_ai::truncate_to_budget(&diff, args.max_diff);
280 let recent = get_recent_commits();
281
282 let mut config = resq_ai::load_config()?;
284 if let Some(ref p) = args.provider {
285 config.provider = match p.to_lowercase().as_str() {
286 "anthropic" => resq_ai::Provider::Anthropic,
287 "openai" => resq_ai::Provider::OpenAi,
288 "gemini" => resq_ai::Provider::Gemini,
289 other => bail!("Unknown provider: {other}. Use: anthropic, openai, gemini"),
290 };
291 }
292 if let Some(ref m) = args.model {
293 config.model = m.clone();
294 }
295 config.timeout_secs = args.timeout;
296
297 let (system, user_prompt) = build_prompt(truncated, &recent, args.scope.as_deref(), args.count);
299
300 if args.verbose {
301 eprintln!("--- System prompt ---\n{system}\n---");
302 eprintln!(
303 "Estimated tokens: {}",
304 resq_ai::estimate_tokens(&user_prompt)
305 );
306 }
307
308 eprintln!("Generating commit messages...");
310 let response = resq_ai::complete(&config, &system, &user_prompt).await?;
311
312 if args.verbose {
313 eprintln!("--- Raw response ---\n{response}\n---");
314 }
315
316 let candidates = parse_candidates(&response)?;
318 let valid: Vec<String> = candidates
319 .into_iter()
320 .filter(|c| validate_conventional_commit(c))
321 .collect();
322
323 if valid.is_empty() {
324 bail!("LLM returned no valid Conventional Commit messages. Try again or write manually.");
325 }
326
327 if args.dry_run {
329 for (i, c) in valid.iter().enumerate() {
330 println!("{}. {c}", i + 1);
331 }
332 return Ok(());
333 }
334
335 let message = if args.yes {
336 valid.into_iter().next().unwrap()
337 } else {
338 let idx = select_candidate(&valid)?;
339 match idx {
340 Some(i) => valid.into_iter().nth(i).unwrap(),
341 None => {
342 eprintln!("Cancelled.");
343 return Ok(());
344 }
345 }
346 };
347
348 git(&["commit", "-m", &message])?;
350 eprintln!("Committed: {message}");
351 Ok(())
352}
353
#[cfg(test)]
mod tests {
    use super::*;

    /// JSON payload shared by the fence-stripping and parsing tests.
    const ARRAY: &str = "[\"feat: add thing\"]";

    #[test]
    fn strip_fences_json() {
        let fenced = format!("```json\n{ARRAY}\n```");
        assert_eq!(strip_code_fences(&fenced), ARRAY);
    }

    #[test]
    fn strip_fences_plain() {
        let fenced = format!("```\n{ARRAY}\n```");
        assert_eq!(strip_code_fences(&fenced), ARRAY);
    }

    #[test]
    fn strip_fences_none() {
        assert_eq!(strip_code_fences(ARRAY), ARRAY);
    }

    #[test]
    fn validate_cc_valid() {
        for msg in [
            "feat: add new feature",
            "fix(ui): correct button color",
            "feat!: remove deprecated API",
            "chore(deps): bump serde to 1.0.200",
        ] {
            assert!(validate_conventional_commit(msg), "should accept {msg:?}");
        }
    }

    #[test]
    fn validate_cc_invalid() {
        for msg in [
            "Add new feature",
            "FEAT: uppercase type",
            "feat:missing space",
            "",
        ] {
            assert!(!validate_conventional_commit(msg), "should reject {msg:?}");
        }
    }

    #[test]
    fn parse_candidates_valid_json() {
        let parsed = parse_candidates(r#"["feat: add thing", "fix: repair thing"]"#).unwrap();
        assert_eq!(parsed.len(), 2);
        assert_eq!(parsed[0], "feat: add thing");
    }

    #[test]
    fn parse_candidates_with_fences() {
        let parsed = parse_candidates(&format!("```json\n{ARRAY}\n```")).unwrap();
        assert_eq!(parsed.len(), 1);
    }

    #[test]
    fn build_prompt_includes_scope() {
        let (system, _) = build_prompt("diff content", "recent commits", Some("auth"), 3);
        assert!(system.contains("auth"));
    }

    #[test]
    fn build_prompt_without_scope() {
        let (system, _) = build_prompt("diff content", "recent commits", None, 3);
        assert!(!system.contains("Suggested scope"));
    }

    #[test]
    fn build_prompt_includes_count() {
        let (system, _) = build_prompt("diff", "commits", None, 5);
        assert!(system.contains("exactly 5"));
    }
}