// commit_gpt/lib.rs

1use std::error::Error;
2use std::process::Command;
3
4use async_openai::{
5    Client,
6    types::{ChatCompletionRequestMessageArgs, CreateChatCompletionRequestArgs, Role},
7};
8use cli_clipboard::{ClipboardContext, ClipboardProvider};
9
10use crate::data::Model;
11
12pub mod data;
13
/// Returns the staged diff (`git diff --cached`) for the current repository.
///
/// # Errors
/// Fails when the shell/command cannot be spawned, when `git diff` exits
/// non-zero (e.g. not inside a git repository), or when the diff output is
/// not valid UTF-8 (e.g. a staged binary file).
fn get_diff() -> Result<String, Box<dyn Error>> {
    // Route through the platform shell so `git` is resolved from PATH the
    // same way it would be in an interactive session.
    let output = if cfg!(target_os = "windows") {
        Command::new("cmd")
            .args(["/C", "git diff --cached"])
            .output()?
    } else {
        Command::new("sh")
            .arg("-c")
            .arg("git diff --cached")
            .output()?
    };

    if !output.status.success() {
        Err("Could not run git diff. Please make sure you are in a valid git repository.")?
    }
    // Propagate invalid UTF-8 as an error instead of panicking:
    // `FromUtf8Error` implements `Error`, so `?` boxes it for us.
    Ok(String::from_utf8(output.stdout)?)
}
31
32async fn generate_commit_message(
33    diff: String,
34    prompt: String,
35    model: Model,
36    max_tokens: u16,
37) -> Result<String, Box<dyn Error>> {
38    let client = Client::new();
39
40    let prompt = format!("{} {}", prompt, diff);
41
42    check_prompt_length(&prompt, &model)?;
43
44    let request = CreateChatCompletionRequestArgs::default()
45        .max_tokens(max_tokens)
46        .model(model.model_str())
47        .messages([ChatCompletionRequestMessageArgs::default()
48            .role(Role::User)
49            .content(prompt)
50            .build()?])
51        .build()?;
52
53    let response = client.chat().create(request).await?;
54
55    let result = response.choices[0].message.content.clone().unwrap();
56
57    Ok(result)
58}
59
60/// Roughly calculates the token size of the prompt, taken from [OpenAI help page regarding tokens](https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them).
61/// As stated in the the above link, > 1 token ~= 4 chars in English
62fn check_prompt_length(prompt: &String, model: &Model) -> Result<(), Box<dyn Error>> {
63    let max_tokens = model.max_tokens();
64    let prompt_token_size = prompt.len() / 4;
65
66    if prompt_token_size > max_tokens {
67        Err(format!("Prompt token size (prompt + diff length) {} is bigger than {} max token size of {}. Please try with a smaller diff.", prompt_token_size, model.model_str(), max_tokens))?
68    }
69
70    Ok(())
71}
72
73fn copy_message_to_clipboard(message: &String) -> Result<(), Box<dyn Error>> {
74    let mut ctx = ClipboardContext::new()?;
75    ctx.set_contents(message.to_owned())?;
76
77    Ok(())
78}
79
80pub async fn run(
81    prompt: String,
82    model: Model,
83    max_tokens: u16,
84) -> Result<Option<String>, Box<dyn Error>> {
85    let diff = get_diff()?;
86    if diff == "" {
87        Ok(None)
88    } else {
89        match generate_commit_message(diff, prompt, model, max_tokens).await {
90            Ok(x) => {
91                copy_message_to_clipboard(&x).unwrap_or_else(|err| {
92                    eprintln!("Could not copy commit message to clipboard: {err}")
93                });
94                Ok(Some(x))
95            }
96            Err(e) => Err(e),
97        }
98    }
99}