use anyhow::{ensure, Context, Result};
use async_openai::types::{
ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestUserMessageArgs,
ChatCompletionResponseMessage, CreateChatCompletionRequestArgs,
};
use clap::command;
use dirs::home_dir;
use inquire::{required, CustomType, Password, PasswordDisplayMode, Text};
use regex::Regex;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use spinoff::{spinners, Spinner};
use std::{path::Path, time::Duration};
use tokio::{
fs::{create_dir_all, read_to_string, write},
process::Command,
};
use which::which;
/// Settings for the commit-message generator, persisted as TOML in the user's config file.
#[derive(Serialize, Deserialize)]
struct Config {
    /// Base URL of the OpenAI-compatible API; `/chat/completions` is appended to it.
    api_base_url: String,
    /// Secret sent as a bearer token with every request.
    api_key: String,
    /// Model identifier placed in the request payload.
    model_name: String,
    /// Instructions sent to the model as the system message.
    system_prompt: String,
    /// User-message template; each `{}` occurrence is replaced with the staged diff.
    user_prompt: String,
    /// `max_tokens` value sent with the request, bounding the generated message length.
    max_tokens: u16,
    /// HTTP request timeout, in seconds.
    request_timeout: u64,
}
/// A single element of the `choices` array in a chat-completion response body.
#[derive(Deserialize)]
struct CommitMessageCandidate {
    /// Assistant message whose `content` carries the suggested commit message.
    message: ChatCompletionResponseMessage,
}
/// The part of a chat-completion response body this tool reads: just the `choices` list.
#[derive(Deserialize)]
struct CommitMessageCandidates {
    /// Candidate completions returned by the provider; only the first is used.
    choices: Vec<CommitMessageCandidate>,
}
async fn load_config_file(config_file: &Path) -> Result<Config> {
let config = if config_file.exists() {
toml::from_str(
&read_to_string(config_file)
.await
.context("Failed to read the configuration file")?,
)
.context("Failed to parse the configuration file")?
} else {
let api_base_url = Text::new("Enter API base URL:")
.with_default("https://api.together.xyz/v1")
.prompt()?;
let api_key = Password::new("Enter your API key:")
.with_display_toggle_enabled()
.with_display_mode(PasswordDisplayMode::Masked)
.with_validator(required!("API key is required"))
.without_confirmation()
.prompt()?;
let model_name = Text::new("Enter model name:")
.with_default("mistralai/Mixtral-8x7B-Instruct-v0.1")
.with_validator(required!("Model name is required"))
.prompt()?;
let system_prompt = Text::new("Enter system prompt:")
.with_default("You are required to write a meaningful commit message for the given code changes. The commit message must have the format: `type(scope): description`. The `type` must be one of the following: feat, fix, docs, style, refactor, perf, test, build, ci, chore, or revert. The `scope` indicates the area of the codebase that the changes affect. The `description` must be concise and written in a single sentence without a period at the end.")
.with_validator(required!("System prompt is required"))
.with_help_message("Press Enter to use the default system prompt")
.prompt()?;
let user_prompt = Text::new("Enter user prompt:")
.with_default("The output of the git diff command:\n```\n{}\n```")
.with_validator(required!("User prompt is required"))
.with_help_message("Press Enter to use the default user prompt")
.prompt()?;
let max_tokens = CustomType::<u16>::new("Enter max tokens of generated commit messages:")
.with_default(128)
.with_help_message("Press Enter to use the default max tokens")
.prompt()?;
let request_timeout = CustomType::<u64>::new("Enter request timeout (in seconds):")
.with_default(30)
.with_help_message("Press Enter to use the default request timeout")
.prompt()?;
let config = Config {
api_base_url: api_base_url.trim().to_string(),
api_key: api_key.trim().to_string(),
model_name: model_name.trim().to_string(),
system_prompt: system_prompt.trim().to_string(),
user_prompt: user_prompt.trim().to_string(),
max_tokens,
request_timeout,
};
create_dir_all(
config_file
.parent()
.context("Failed to retrieve the configuration directory")?,
)
.await
.context("Failed to create config directory")?;
write(
config_file,
toml::to_string(&config).context("Failed to serialize the configuration")?,
)
.await
.context("Failed to write config to file")?;
println!("Config file created successfully: {:?}", config_file);
config
};
Ok(config)
}
async fn run_git_command(args: &[&str]) -> Result<String> {
let res = Command::new("git")
.args(args)
.output()
.await
.context("Failed to execute the Git command")?;
ensure!(
res.status.success(),
"{}",
String::from_utf8_lossy(&res.stderr) );
Ok(String::from_utf8(res.stdout).context("Failed to decode the output of the Git command")?)
}
async fn generate_commit_message(
http_client: &Client,
config: &Config,
git_diffs: &str,
) -> Result<String> {
let payload = CreateChatCompletionRequestArgs::default()
.max_tokens(config.max_tokens)
.model(&config.model_name)
.messages([
ChatCompletionRequestSystemMessageArgs::default()
.content(&config.system_prompt)
.build()?
.into(),
ChatCompletionRequestUserMessageArgs::default()
.content(config.user_prompt.replace("{}", git_diffs))
.build()?
.into(),
])
.build()
.context("Failed to construct the request payload")?;
let response = http_client
.post(format!("{}/chat/completions", &config.api_base_url))
.bearer_auth(&config.api_key)
.json(&payload)
.send()
.await
.context("Failed to send the request to the Inference API provider")?
.error_for_status()?
.json::<CommitMessageCandidates>()
.await
.context("Failed to parse the response from the Inference API provider")?;
let commit_message = response
.choices
.first() .context("No commit messages generated")?
.message
.content
.as_ref()
.context("No commit messages generated")?;
let regex_matches = Regex::new(r"(?m)^\s*(?:`\s*(.+?)\s*`|(.+?))\s*$")?
.captures(&commit_message)
.context("Failed to post-process the generated commit message")?;
let commit_message = regex_matches
.get(1)
.or(regex_matches.get(2))
.context("Failed to post-process the generated commit message")?
.as_str()
.to_string();
Ok(commit_message)
}
/// Program entry point.
///
/// The steps, in order:
/// 1. Check that `git` is on `PATH` and the current directory is inside a work tree.
/// 2. Collect the staged diff, excluding `*.lock` files.
/// 3. Load or create `~/.acm/config.toml`.
/// 4. Generate a commit message behind a spinner.
/// 5. Let the user edit the message.
/// 6. Run `git commit -m` with the trimmed result and print git's output.
#[tokio::main]
async fn main() -> Result<()> {
    // `:!` is git's exclude pathspec magic: skip lock files when collecting the diff.
    const DIFF_ARGS: [&str; 8] = [
        "--no-pager",
        "diff",
        "--staged",
        "--minimal",
        "--no-color",
        "--no-ext-diff",
        "--",
        ":!*.lock",
    ];

    // No arguments are declared; this only handles the built-in flags clap adds.
    command!().get_matches();

    which("git").context(
        "Git not found, please install it first or check your PATH environment variable",
    )?;
    run_git_command(&["rev-parse", "--is-inside-work-tree"])
        .await
        .context("The current directory is not a Git repository")?;

    let staged_diff = run_git_command(&DIFF_ARGS).await?;
    let git_diffs = staged_diff.trim();
    ensure!(!git_diffs.is_empty(), "No staged changes to commit");

    let home = home_dir().context("Failed to retrieve the user's home directory")?;
    let config = load_config_file(&home.join(".acm/config.toml")).await?;

    let http_client = Client::builder()
        .timeout(Duration::from_secs(config.request_timeout))
        .build()?;

    // Stop the spinner before propagating any error so the terminal is left clean.
    let mut spinner = Spinner::new(spinners::Dots, "Generating commit message", None);
    let generated = generate_commit_message(&http_client, &config, git_diffs).await;
    spinner.stop_with_message("");
    let suggestion = generated?;

    let final_message = Text::new("Your generated commit message:")
        .with_initial_value(&suggestion)
        .with_validator(required!("Please provide a commit message to create a commit"))
        .with_help_message(
            "Press Enter to create a new commit with the current message or ESC to cancel",
        )
        .prompt()?;

    let commit_output = run_git_command(&["commit", "-m", final_message.trim()]).await?;
    println!("{}", commit_output);

    Ok(())
}