1use std::error::Error;
2use std::process::Command;
3
4use async_openai::{
5 Client,
6 types::{ChatCompletionRequestMessageArgs, CreateChatCompletionRequestArgs, Role},
7};
8use cli_clipboard::{ClipboardContext, ClipboardProvider};
9
10use crate::data::Model;
11
12pub mod data;
13
/// Returns the staged changes of the current repository (`git diff --cached`).
///
/// `git` is invoked directly with explicit arguments instead of going through
/// `sh -c` / `cmd /C`, so one code path serves every platform and no shell is
/// involved.
///
/// # Errors
///
/// Returns an error if `git` cannot be spawned (for example, it is not
/// installed or not on `PATH`), or if it exits with a non-zero status (for
/// example, the current directory is not inside a git repository).
fn get_diff() -> Result<String, Box<dyn Error>> {
    let output = Command::new("git")
        .args(["diff", "--cached"])
        .output()
        .map_err(|err| format!("Could not run git: {err}"))?;

    if !output.status.success() {
        return Err(
            "Could not run git diff. Please make sure you are in a valid git repository.".into(),
        );
    }

    // A diff can contain bytes that are not valid UTF-8 (e.g. files stored in a
    // legacy encoding). Replace them with U+FFFD instead of panicking so a
    // commit message can still be generated.
    Ok(String::from_utf8_lossy(&output.stdout).into_owned())
}
31
32async fn generate_commit_message(
33 diff: String,
34 prompt: String,
35 model: Model,
36 max_tokens: u16,
37) -> Result<String, Box<dyn Error>> {
38 let client = Client::new();
39
40 let prompt = format!("{} {}", prompt, diff);
41
42 check_prompt_length(&prompt, &model)?;
43
44 let request = CreateChatCompletionRequestArgs::default()
45 .max_tokens(max_tokens)
46 .model(model.model_str())
47 .messages([ChatCompletionRequestMessageArgs::default()
48 .role(Role::User)
49 .content(prompt)
50 .build()?])
51 .build()?;
52
53 let response = client.chat().create(request).await?;
54
55 let result = response.choices[0].message.content.clone().unwrap();
56
57 Ok(result)
58}
59
60fn check_prompt_length(prompt: &String, model: &Model) -> Result<(), Box<dyn Error>> {
63 let max_tokens = model.max_tokens();
64 let prompt_token_size = prompt.len() / 4;
65
66 if prompt_token_size > max_tokens {
67 Err(format!("Prompt token size (prompt + diff length) {} is bigger than {} max token size of {}. Please try with a smaller diff.", prompt_token_size, model.model_str(), max_tokens))?
68 }
69
70 Ok(())
71}
72
73fn copy_message_to_clipboard(message: &String) -> Result<(), Box<dyn Error>> {
74 let mut ctx = ClipboardContext::new()?;
75 ctx.set_contents(message.to_owned())?;
76
77 Ok(())
78}
79
80pub async fn run(
81 prompt: String,
82 model: Model,
83 max_tokens: u16,
84) -> Result<Option<String>, Box<dyn Error>> {
85 let diff = get_diff()?;
86 if diff == "" {
87 Ok(None)
88 } else {
89 match generate_commit_message(diff, prompt, model, max_tokens).await {
90 Ok(x) => {
91 copy_message_to_clipboard(&x).unwrap_or_else(|err| {
92 eprintln!("Could not copy commit message to clipboard: {err}")
93 });
94 Ok(Some(x))
95 }
96 Err(e) => Err(e),
97 }
98 }
99}