use std::process::Command;
use clap::Parser;
use anyhow::{Result, anyhow};
use reqwest;
use serde_json::{json, Value};
use git2::{Repository, Status};
/// Diff length (in bytes, as measured by `str::len`) above which
/// `generate_commit_message` first asks the model to summarize the diff instead of
/// embedding it verbatim in the commit-message prompt.
const DIFF_SIZE_THRESHOLD: usize = 1000;
// Command-line arguments, parsed with clap's derive API.
// NOTE(review): plain `//` comments are used here on purpose; clap's derive reads `///`
// doc comments on a `Parser` struct and its fields as help text, so adding them would
// change the program's `--help` output.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
// Preferred Ollama model name; when absent (or not installed) `select_model` picks one.
#[arg(short, long)]
model: Option<String>,
// Positional paths passed to `git add` by `stage_files` before the staged diff is read.
// When none are given, nothing is staged and whatever is already in the index is used.
// NOTE(review): `default_value = None` looks redundant for an `Option` field — confirm.
#[arg(default_value = None)]
files: Option<Vec<String>>,
}
/// Asks the local Ollama CLI which models are installed.
///
/// Runs `ollama list`, skips the first line of its output (presumably the column
/// header), and takes the first whitespace-delimited token of every remaining line as a
/// model name. Lines without any token are ignored.
///
/// # Errors
///
/// Fails if the `ollama` process cannot be launched, exits unsuccessfully, or prints
/// output that is not valid UTF-8.
async fn get_available_models() -> Result<Vec<String>> {
    let listing = Command::new("ollama").arg("list").output()?;
    if !listing.status.success() {
        return Err(anyhow!("Failed to get model list from Ollama"));
    }

    let text = String::from_utf8(listing.stdout)?;
    let mut names = Vec::new();
    for row in text.lines().skip(1) {
        if let Some(name) = row.split_whitespace().next() {
            names.push(name.to_owned());
        }
    }
    Ok(names)
}
/// Picks the model to use for generation.
///
/// Resolution order:
/// 1. `preferred`, if it appears in `available` exactly as given;
/// 2. `preferred` with `:latest` appended, if `preferred` carries no tag and that tagged
///    name appears in `available`;
/// 3. `qwen2.5:7b`, if installed;
/// 4. the first entry of `available`;
/// 5. `neural-chat` when `available` is empty.
///
/// A warning is printed to stderr when `preferred` was given but could not be resolved.
///
/// # Parameters
///
/// * `preferred` - model name requested by the user, if any.
/// * `available` - names of the installed models, as listed by `ollama list`.
///
/// # Returns
///
/// The name of the model to use. When rule 2 applies, the tagged name from `available`
/// is returned so that later exact-match lookups against the model list succeed.
fn select_model(preferred: Option<String>, available: &[String]) -> String {
    const DEFAULT_MODEL: &str = "qwen2.5:7b";
    const LAST_RESORT_MODEL: &str = "neural-chat";

    if let Some(model) = preferred {
        if available.iter().any(|name| *name == model) {
            return model;
        }
        // `ollama list` prints names as `name:tag`, and Ollama treats a bare `name` as
        // `name:latest`. Without this step an installed model requested without its tag
        // would be reported as missing and silently replaced by a different model.
        if !model.contains(':') {
            let tagged = format!("{}:latest", model);
            if available.contains(&tagged) {
                return tagged;
            }
        }
        eprintln!("Warning: Specified model '{}' not found. Falling back to available model.", model);
    }

    if available.iter().any(|name| name == DEFAULT_MODEL) {
        return String::from(DEFAULT_MODEL);
    }

    available
        .first()
        .cloned()
        .unwrap_or_else(|| String::from(LAST_RESORT_MODEL))
}
/// Makes sure an Ollama server is answering on `localhost:11434`, launching one if not.
///
/// The version endpoint is probed once. If the request cannot be sent successfully,
/// `ollama serve` is spawned and the endpoint is polled up to 30 times, sleeping one
/// second after each failed attempt. Only the success of the request itself is checked,
/// not the HTTP status of the reply.
///
/// # Errors
///
/// Fails if `ollama serve` cannot be spawned, or if the endpoint is still unreachable
/// after all polling attempts.
async fn ensure_ollama_running() -> Result<()> {
    const VERSION_URL: &str = "http://localhost:11434/api/version";

    let http = reqwest::Client::new();
    if http.get(VERSION_URL).send().await.is_ok() {
        return Ok(());
    }
    println!("Ollama is not running. Starting Ollama...");

    // The child handle is kept only for the duration of this function; the server
    // process is neither waited on nor killed here.
    let _server = Command::new("ollama").arg("serve").spawn()?;

    let mut attempts_left = 30;
    while attempts_left > 0 {
        if http.get(VERSION_URL).send().await.is_ok() {
            return Ok(());
        }
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
        attempts_left -= 1;
    }

    Err(anyhow!("Failed to start Ollama after 30 seconds"))
}
/// Ensures `model` is installed locally, pulling it with `ollama pull` when it is not.
///
/// The installed models are taken from `get_available_models`; `model` must match one of
/// those names exactly to be considered present.
///
/// # Errors
///
/// Fails if the model list cannot be obtained, if `ollama pull` cannot be run, or if the
/// pull exits unsuccessfully.
async fn ensure_model_available(model: &str) -> Result<()> {
    let installed = get_available_models().await?;
    if installed.iter().any(|name| name == model) {
        return Ok(());
    }

    println!("Model {} not found. Pulling...", model);
    // `status()` spawns the child with inherited stdio and waits for it to finish.
    let pull_status = Command::new("ollama").arg("pull").arg(model).status()?;
    if pull_status.success() {
        Ok(())
    } else {
        Err(anyhow!("Failed to pull model {}", model))
    }
}
/// Stages each requested path with `git add`, one invocation per path.
///
/// When `files` is `None` nothing is staged. Staging stops at the first failure.
///
/// # Errors
///
/// Fails if `git` cannot be run or if `git add` exits unsuccessfully for a path. The
/// path `.` gets its own error message ("Failed to stage all files"); every other path
/// is named in the error.
fn stage_files(files: &Option<Vec<String>>) -> Result<()> {
    for path in files.iter().flatten() {
        let outcome = Command::new("git").arg("add").arg(path).status()?;
        if outcome.success() {
            continue;
        }
        return Err(if path == "." {
            anyhow!("Failed to stage all files")
        } else {
            anyhow!("Failed to stage file: {}", path)
        });
    }
    Ok(())
}
/// Returns the staged changes as text, i.e. the output of `git diff --cached`.
///
/// The output is decoded leniently: byte sequences that are not valid UTF-8 are replaced
/// with U+FFFD. The diff is only used as prompt material, so a staged file in another
/// text encoding should degrade the prompt slightly rather than abort the whole run.
///
/// # Errors
///
/// Fails if `git` cannot be run or if `git diff` exits unsuccessfully.
fn get_git_diff() -> Result<String> {
    let output = Command::new("git")
        .args(["diff", "--cached"])
        .output()?;
    if !output.status.success() {
        return Err(anyhow!("Failed to get git diff"));
    }
    Ok(String::from_utf8_lossy(&output.stdout).into_owned())
}
/// Normalizes a model reply into plain commit-message text.
///
/// Surrounding whitespace is always trimmed. If the trimmed reply both starts and ends
/// with a Markdown code fence (three backticks), the fences are removed, a language tag
/// on the opening fence line (e.g. `bash`, `text`) is dropped, every remaining line is
/// trimmed, and empty lines are discarded.
///
/// # Parameters
///
/// * `message` - raw text returned by the model.
///
/// # Returns
///
/// The cleaned message; it may be empty if the reply contained nothing but fences.
fn clean_commit_message(message: &str) -> String {
    /// Whether the text on the opening fence line looks like a language tag.
    fn is_language_tag(info: &str) -> bool {
        let info = info.trim();
        !info.is_empty()
            && info
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '+' | '.' | '#'))
    }

    let trimmed = message.trim();
    if !(trimmed.starts_with("```") && trimmed.ends_with("```")) {
        return trimmed.to_string();
    }

    let after_open = trimmed.strip_prefix("```").unwrap_or(trimmed);
    // Only a single bare token directly after the opening fence, followed by a newline,
    // is treated as a language tag. A one-line reply such as "```feat: x```" keeps its
    // content, and so does any first line containing spaces or punctuation like ':'.
    let body = match after_open.split_once('\n') {
        Some((info, rest)) if is_language_tag(info) => rest,
        _ => after_open,
    };

    body.replace("```", "")
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty())
        .collect::<Vec<&str>>()
        .join("\n")
}
/// Asks the model for a short factual summary of `diff`.
///
/// Sends a single non-streaming request to the local Ollama `/api/generate` endpoint and
/// returns the trimmed `response` string from the JSON reply.
///
/// # Parameters
///
/// * `model` - name of the Ollama model to query.
/// * `diff` - the diff text to summarize; it is embedded verbatim in the prompt.
///
/// # Errors
///
/// Fails if the request cannot be sent, the reply is not valid JSON, or the reply has no
/// string `response` field.
async fn summarize_diff(model: &str, diff: &str) -> Result<String> {
    // The size printed here is `str::len`, i.e. a byte count.
    println!("Diff is large ({} characters). Summarizing diff...", diff.len());

    let summary_prompt = format!(
        "Summarize the following git diff in a concise and factual manner. Focus on key changes, file names, and types of modifications. \
         Do not use code blocks, markdown, or extra commentary.\n\nDiff:\n{}",
        diff
    );
    let payload = json!({
        "model": model,
        "prompt": summary_prompt,
        "stream": false
    });

    let http = reqwest::Client::new();
    let reply: Value = http
        .post("http://localhost:11434/api/generate")
        .json(&payload)
        .send()
        .await?
        .json()
        .await?;

    match reply["response"].as_str() {
        Some(text) => Ok(text.trim().to_string()),
        None => Err(anyhow!("Invalid summary response format")),
    }
}
/// Produces a conventional-commit style message for `diff` using the given model.
///
/// A diff longer than `DIFF_SIZE_THRESHOLD` bytes is first condensed with
/// `summarize_diff`; shorter diffs are placed in the prompt unchanged. The prompt is sent
/// as one non-streaming request to the local Ollama `/api/generate` endpoint, and the
/// reply's `response` string is passed through `clean_commit_message`.
///
/// # Parameters
///
/// * `model` - name of the Ollama model to query.
/// * `diff` - the staged diff text.
///
/// # Errors
///
/// Fails if summarizing fails, the request cannot be sent, the reply is not valid JSON,
/// or the reply has no string `response` field.
async fn generate_commit_message(model: &str, diff: &str) -> Result<String> {
    println!("Generating commit message for changes using {}...", model);

    let effective_diff = if diff.len() <= DIFF_SIZE_THRESHOLD {
        diff.to_string()
    } else {
        summarize_diff(model, diff).await?
    };

    let http = reqwest::Client::new();
    let prompt = format!(
        "Generate a conventional commit message for the following git diff summary. \
         IMPORTANT RULES:\n\
         1. Use the format: <type>(<scope>): <description>\n\
         2. Keep it concise and clear\n\
         3. DO NOT use any markdown formatting\n\
         4. DO NOT wrap the message in code blocks\n\
         5. DO NOT add any extra formatting or annotations\n\
         6. The message should be ready to use with 'git commit -m'\n\n\
         Diff Summary:\n{}",
        effective_diff
    );
    let payload = json!({
        "model": model,
        "prompt": prompt,
        "stream": false
    });

    println!("Sending request to Ollama...");
    let response = http
        .post("http://localhost:11434/api/generate")
        .json(&payload)
        .send()
        .await?;

    println!("Got response from Ollama, parsing...");
    let reply: Value = response.json().await?;
    match reply["response"].as_str() {
        Some(raw) => Ok(clean_commit_message(raw)),
        None => Err(anyhow!("Invalid response format")),
    }
}
#[tokio::main]
async fn main() -> Result<()> {
let args = Args::parse();
stage_files(&args.files)?;
ensure_ollama_running().await?;
let available_models = get_available_models().await?;
if available_models.is_empty() {
return Err(anyhow!("No models available in Ollama"));
}
let selected_model = select_model(args.model, &available_models);
println!("Using model: {}", selected_model);
ensure_model_available(&selected_model).await?;
let diff = get_git_diff()?;
if diff.is_empty() {
return Err(anyhow!("No changes to commit"));
}
let commit_message = generate_commit_message(&selected_model, &diff).await?;
let status = Command::new("git")
.args(["commit", "-m", &commit_message])
.status()?;
if status.success() {
println!("Successfully created commit: {}", commit_message);
Ok(())
} else {
Err(anyhow!("Failed to create commit"))
}
}