Skip to main content

resq_cli/commands/
commit.rs

1/*
2 * Copyright 2026 ResQ
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17//! AI-powered commit message generation from staged diffs.
18
19use anyhow::{bail, Context, Result};
20use clap::Args;
21use crossterm::{
22    cursor,
23    event::{self, Event, KeyCode},
24    execute, style,
25    terminal::{self, ClearType},
26};
27use regex::Regex;
28use std::io::{self, IsTerminal, Write};
29use std::process::Command;
30
31/// Arguments for the `resq commit` command.
32#[derive(Args, Debug)]
33pub struct CommitArgs {
34    /// Number of candidate messages to generate
35    #[arg(short = 'n', long, default_value = "3")]
36    count: u8,
37
38    /// Max diff size in estimated tokens before truncation
39    #[arg(short = 'm', long, default_value = "2000")]
40    max_diff: usize,
41
42    /// Hint for commit scope (e.g., "auth", "ui")
43    #[arg(short, long)]
44    scope: Option<String>,
45
46    /// Auto-pick first candidate without prompting
47    #[arg(long)]
48    yes: bool,
49
50    /// Print message(s) but don't commit
51    #[arg(long)]
52    dry_run: bool,
53
54    /// Override AI provider (anthropic, openai, gemini)
55    #[arg(long)]
56    provider: Option<String>,
57
58    /// Override AI model
59    #[arg(long)]
60    model: Option<String>,
61
62    /// API request timeout in seconds
63    #[arg(long, default_value = "30")]
64    timeout: u64,
65
66    /// Show prompt, token count, and raw LLM response
67    #[arg(long)]
68    verbose: bool,
69}
70
71// -------------------------------------------------------------------------
72// Git helpers
73// -------------------------------------------------------------------------
74
75fn git(args: &[&str]) -> Result<String> {
76    let output = Command::new("git")
77        .args(args)
78        .output()
79        .context("Failed to run git")?;
80    if output.status.success() {
81        Ok(String::from_utf8_lossy(&output.stdout).to_string())
82    } else {
83        let stderr = String::from_utf8_lossy(&output.stderr);
84        bail!("git {}: {stderr}", args.join(" "));
85    }
86}
87
88fn get_staged_diff() -> Result<String> {
89    let stat = git(&["diff", "--staged", "--stat"])?;
90    if stat.trim().is_empty() {
91        bail!(
92            "No staged changes found.\n\
93             Stage files first with: git add <files>\n\
94             Then run: resq commit"
95        );
96    }
97    git(&["diff", "--staged"])
98}
99
100fn check_unstaged_warning() {
101    if let Ok(diff) = git(&["diff", "--stat"]) {
102        if !diff.trim().is_empty() {
103            eprintln!(
104                "\x1b[33mWarning:\x1b[0m You have unstaged changes. \
105                 Only staged changes will be included in the commit."
106            );
107        }
108    }
109}
110
111fn get_recent_commits() -> String {
112    git(&["log", "--oneline", "-10"]).unwrap_or_default()
113}
114
115// -------------------------------------------------------------------------
116// Prompt building + response parsing
117// -------------------------------------------------------------------------
118
119/// Conventional Commits regex (from templates/git-hooks/commit-msg).
120const CC_PATTERN: &str =
121    r"^(feat|fix|docs|style|refactor|perf|test|build|ci|chore|revert)(\(.+\))?(!)?: .+$";
122
123fn validate_conventional_commit(msg: &str) -> bool {
124    Regex::new(CC_PATTERN)
125        .map(|re| re.is_match(msg.lines().next().unwrap_or("")))
126        .unwrap_or(false)
127}
128
/// Build the `(system, user)` prompt pair for commit-message generation.
///
/// The system prompt contains:
/// - the Conventional Commits format and allowed types;
/// - an optional `Suggested scope:` line when `scope` is given;
/// - the style rules and the recent commit log as a style reference;
/// - a request for exactly `count` candidates as a JSON array of strings.
///
/// The user prompt carries the staged diff.
fn build_prompt(
    diff: &str,
    recent_commits: &str,
    scope: Option<&str>,
    count: u8,
) -> (String, String) {
    // Leading "\n" makes the hint sit on its own line after the blank that
    // follows the allowed-types line; without a scope it contributes nothing.
    let hint = match scope {
        Some(name) => format!("\nSuggested scope: {name}"),
        None => String::new(),
    };

    let system_prompt = format!(
        "You are a commit message generator for a project that uses Conventional Commits.\n\
         \n\
         Format: <type>(<scope>): <description>\n\
         \n\
         Allowed types: feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert\n\
         {scope_hint}\n\
         \n\
         Rules:\n\
         - Subject line must be under 72 characters\n\
         - Use imperative mood (\"add\" not \"added\")\n\
         - Be specific about what changed and why\n\
         - Match the style of recent commits shown below\n\
         \n\
         Recent commits for style reference:\n\
         {recent_commits}\n\
         \n\
         Generate exactly {count} commit message candidates. \
         Return them as a JSON array of strings. No markdown fences.",
        scope_hint = hint,
    );
    let user_prompt = format!("Staged diff:\n\n{diff}");

    (system_prompt, user_prompt)
}
163
/// Strip a surrounding Markdown code fence from `text`, if present.
///
/// The input is trimmed first. If it does not start with a fence, the trimmed
/// text is returned unchanged. When it does, the opening fence is removed along
/// with an optional info string on the same line. The info string can be any
/// tag made of ASCII letters/digits or `-_+.`, such as `json`, `JSON` or
/// `javascript`, so it is not limited to `json`.
///
/// A fence with no newline after it (e.g. a single-line ```` ```json [..]``` ````)
/// only has a literal `json` prefix removed, as before. A trailing closing fence
/// is removed if present. The returned slice is trimmed and borrows from `text`.
fn strip_code_fences(text: &str) -> &str {
    let trimmed = text.trim();
    let Some(after_open) = trimmed.strip_prefix("```") else {
        return trimmed;
    };

    let is_info_char = |c: char| c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '+' | '.');
    let body = match after_open.split_once('\n') {
        // Opening line holds only an (optionally empty) info string: drop it.
        Some((info, rest)) if info.trim().chars().all(is_info_char) => rest,
        // Single-line fence or unusual first line: only a literal `json` tag is removed.
        _ => after_open.strip_prefix("json").unwrap_or(after_open),
    };

    let body = body.trim();
    body.strip_suffix("```").map_or(body, str::trim)
}
178
179fn parse_candidates(response: &str) -> Result<Vec<String>> {
180    let cleaned = strip_code_fences(response);
181    let candidates: Vec<String> =
182        serde_json::from_str(cleaned).context("Failed to parse LLM response as JSON array")?;
183    Ok(candidates)
184}
185
186// -------------------------------------------------------------------------
187// Interactive selector (crossterm-based)
188// -------------------------------------------------------------------------
189
190/// RAII guard to restore terminal raw mode on drop.
191struct RawModeGuard;
192
193impl RawModeGuard {
194    fn enable() -> Result<Self> {
195        terminal::enable_raw_mode()?;
196        Ok(Self)
197    }
198}
199
200impl Drop for RawModeGuard {
201    fn drop(&mut self) {
202        let _ = terminal::disable_raw_mode();
203    }
204}
205
206fn select_candidate(candidates: &[String]) -> Result<Option<usize>> {
207    if !io::stdout().is_terminal() {
208        bail!("Interactive selection requires a TTY. Use --yes or --dry-run in non-interactive contexts.");
209    }
210
211    let _guard = RawModeGuard::enable()?;
212    let mut stdout = io::stdout();
213    let mut selected: usize = 0;
214    let total = candidates.len();
215
216    // Initial render
217    render_selector(&mut stdout, candidates, selected)?;
218
219    loop {
220        if let Event::Key(key) = event::read()? {
221            match key.code {
222                KeyCode::Up | KeyCode::Char('k') if selected > 0 => selected -= 1,
223                KeyCode::Down | KeyCode::Char('j') if selected < total - 1 => selected += 1,
224                KeyCode::Enter => {
225                    clear_selector(&mut stdout, total)?;
226                    return Ok(Some(selected));
227                }
228                KeyCode::Esc | KeyCode::Char('q') => {
229                    clear_selector(&mut stdout, total)?;
230                    return Ok(None);
231                }
232                _ => continue,
233            }
234            // Move up and re-render
235            execute!(stdout, cursor::MoveUp((total + 2) as u16))?;
236            render_selector(&mut stdout, candidates, selected)?;
237        }
238    }
239}
240
241fn render_selector(stdout: &mut io::Stdout, candidates: &[String], selected: usize) -> Result<()> {
242    for (i, candidate) in candidates.iter().enumerate() {
243        if i == selected {
244            execute!(stdout, style::SetForegroundColor(style::Color::Cyan))?;
245            write!(stdout, "\r  > {}. {}\n", i + 1, candidate)?;
246            execute!(stdout, style::ResetColor)?;
247        } else {
248            write!(stdout, "\r    {}. {}\n", i + 1, candidate)?;
249        }
250    }
251    write!(
252        stdout,
253        "\r\n  \x1b[2m[↑/↓/j/k] Navigate  [Enter] Select  [Esc/q] Cancel\x1b[0m"
254    )?;
255    stdout.flush()?;
256    Ok(())
257}
258
259fn clear_selector(stdout: &mut io::Stdout, total: usize) -> Result<()> {
260    execute!(
261        stdout,
262        cursor::MoveUp((total + 2) as u16),
263        terminal::Clear(ClearType::FromCursorDown),
264    )?;
265    Ok(())
266}
267
268// -------------------------------------------------------------------------
269// Main entry point
270// -------------------------------------------------------------------------
271
272/// Run the commit command.
273pub async fn run(args: CommitArgs) -> Result<()> {
274    // 1. Check staged changes
275    let diff = get_staged_diff()?;
276    check_unstaged_warning();
277
278    // 2. Truncate diff to token budget
279    let truncated = resq_ai::truncate_to_budget(&diff, args.max_diff);
280    let recent = get_recent_commits();
281
282    // 3. Load AI config with CLI overrides
283    let mut config = resq_ai::load_config()?;
284    if let Some(ref p) = args.provider {
285        config.provider = match p.to_lowercase().as_str() {
286            "anthropic" => resq_ai::Provider::Anthropic,
287            "openai" => resq_ai::Provider::OpenAi,
288            "gemini" => resq_ai::Provider::Gemini,
289            other => bail!("Unknown provider: {other}. Use: anthropic, openai, gemini"),
290        };
291    }
292    if let Some(ref m) = args.model {
293        config.model = m.clone();
294    }
295    config.timeout_secs = args.timeout;
296
297    // 4. Build prompt
298    let (system, user_prompt) = build_prompt(truncated, &recent, args.scope.as_deref(), args.count);
299
300    if args.verbose {
301        eprintln!("--- System prompt ---\n{system}\n---");
302        eprintln!(
303            "Estimated tokens: {}",
304            resq_ai::estimate_tokens(&user_prompt)
305        );
306    }
307
308    // 5. Call LLM
309    eprintln!("Generating commit messages...");
310    let response = resq_ai::complete(&config, &system, &user_prompt).await?;
311
312    if args.verbose {
313        eprintln!("--- Raw response ---\n{response}\n---");
314    }
315
316    // 6. Parse + validate candidates
317    let candidates = parse_candidates(&response)?;
318    let valid: Vec<String> = candidates
319        .into_iter()
320        .filter(|c| validate_conventional_commit(c))
321        .collect();
322
323    if valid.is_empty() {
324        bail!("LLM returned no valid Conventional Commit messages. Try again or write manually.");
325    }
326
327    // 7. Select
328    if args.dry_run {
329        for (i, c) in valid.iter().enumerate() {
330            println!("{}. {c}", i + 1);
331        }
332        return Ok(());
333    }
334
335    let message = if args.yes {
336        valid.into_iter().next().unwrap()
337    } else {
338        let idx = select_candidate(&valid)?;
339        match idx {
340            Some(i) => valid.into_iter().nth(i).unwrap(),
341            None => {
342                eprintln!("Cancelled.");
343                return Ok(());
344            }
345        }
346    };
347
348    // 8. Commit
349    git(&["commit", "-m", &message])?;
350    eprintln!("Committed: {message}");
351    Ok(())
352}
353
354// -------------------------------------------------------------------------
355// Tests
356// -------------------------------------------------------------------------
357
#[cfg(test)]
mod tests {
    use super::*;

    /// JSON payload shared by the fence-stripping and parsing tests.
    const PAYLOAD: &str = "[\"feat: add thing\"]";

    #[test]
    fn strip_fences_json() {
        let fenced = format!("```json\n{PAYLOAD}\n```");
        assert_eq!(strip_code_fences(&fenced), PAYLOAD);
    }

    #[test]
    fn strip_fences_plain() {
        let fenced = format!("```\n{PAYLOAD}\n```");
        assert_eq!(strip_code_fences(&fenced), PAYLOAD);
    }

    #[test]
    fn strip_fences_none() {
        assert_eq!(strip_code_fences(PAYLOAD), PAYLOAD);
    }

    #[test]
    fn validate_cc_valid() {
        let accepted = [
            "feat: add new feature",
            "fix(ui): correct button color",
            "feat!: remove deprecated API",
            "chore(deps): bump serde to 1.0.200",
        ];
        for msg in accepted {
            assert!(validate_conventional_commit(msg), "should accept {msg:?}");
        }
    }

    #[test]
    fn validate_cc_invalid() {
        let rejected = ["Add new feature", "FEAT: uppercase type", "feat:missing space", ""];
        for msg in rejected {
            assert!(!validate_conventional_commit(msg), "should reject {msg:?}");
        }
    }

    #[test]
    fn parse_candidates_valid_json() {
        let parsed = parse_candidates(r#"["feat: add thing", "fix: repair thing"]"#)
            .expect("a plain JSON array should parse");
        assert_eq!(parsed.len(), 2);
        assert_eq!(parsed[0], "feat: add thing");
    }

    #[test]
    fn parse_candidates_with_fences() {
        let fenced = format!("```json\n{PAYLOAD}\n```");
        let parsed = parse_candidates(&fenced).expect("fenced JSON should parse");
        assert_eq!(parsed.len(), 1);
    }

    #[test]
    fn build_prompt_includes_scope() {
        let (system, _) = build_prompt("diff content", "recent commits", Some("auth"), 3);
        assert!(system.contains("auth"));
    }

    #[test]
    fn build_prompt_without_scope() {
        let (system, _) = build_prompt("diff content", "recent commits", None, 3);
        assert!(!system.contains("Suggested scope"));
    }

    #[test]
    fn build_prompt_includes_count() {
        let (system, _) = build_prompt("diff", "commits", None, 5);
        assert!(system.contains("exactly 5"));
    }
}