use crate::built_info;
use crate::utils::env;
use askama::Template;
use async_openai::config::OPENAI_API_BASE;
use async_openai::error::OpenAIError;
use async_openai::{
Client,
config::OpenAIConfig,
types::{ChatCompletionRequestMessage, CreateChatCompletionRequestArgs},
};
use log::trace;
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::{ClientBuilder, Proxy};
use std::error::Error;
use std::time::Duration;
use tracing::debug;
/// Askama template (rendered from `templates/user.txt`) that produces the
/// user prompt sent to the model. Both fields are pre-joined, newline-separated
/// text blobs supplied by [`OpenAI::prompt`].
#[derive(Template)]
#[template(path = "user.txt")]
struct PromptTemplate<'a> {
    /// Recent commit log entries, one per line.
    logs: &'a str,
    /// Pending diff content, one hunk line per line.
    diffs: &'a str,
}
/// Thin wrapper around an `async_openai` client pre-configured from
/// environment variables (see [`OpenAI::new`]).
pub struct OpenAI {
    /// Underlying OpenAI-compatible API client with a custom reqwest HTTP client.
    client: Client<OpenAIConfig>,
}
impl Default for OpenAI {
fn default() -> Self {
Self::new()
}
}
impl OpenAI {
    /// Builds a client configured entirely from environment variables:
    ///
    /// * `OPENAI_API_TOKEN` — API key (empty string if unset).
    /// * `OPENAI_API_BASE` — API base URL, defaulting to the official endpoint.
    /// * `OPENAI_API_PROXY` — optional proxy URL for all HTTP traffic.
    /// * `OPENAI_REQUEST_TIMEOUT` — optional per-request timeout in milliseconds.
    ///
    /// # Panics
    ///
    /// Panics if `OPENAI_API_PROXY` is not a valid proxy URL or if the
    /// underlying reqwest client cannot be constructed — both are
    /// unrecoverable configuration/TLS-setup errors.
    pub fn new() -> Self {
        let ai_config: OpenAIConfig = OpenAIConfig::new()
            .with_api_key(env::get("OPENAI_API_TOKEN", ""))
            .with_api_base(env::get("OPENAI_API_BASE", OPENAI_API_BASE))
            .with_org_id(built_info::PKG_NAME);
        let mut http_client_builder = ClientBuilder::new()
            .user_agent(format!(
                "{} ({})",
                built_info::PKG_NAME,
                built_info::PKG_DESCRIPTION
            ))
            .default_headers({
                // Identify this CLI to OpenAI-compatible gateways (e.g. OpenRouter
                // reads HTTP-Referer / X-Title for attribution).
                let mut headers = HeaderMap::new();
                headers.insert(
                    "HTTP-Referer",
                    HeaderValue::from_static(built_info::PKG_HOMEPAGE),
                );
                headers.insert("X-Title", HeaderValue::from_static(built_info::PKG_NAME));
                headers.insert("X-Client-Type", HeaderValue::from_static("CLI"));
                headers
            });
        let proxy_addr = env::get("OPENAI_API_PROXY", "");
        if !proxy_addr.is_empty() {
            trace!("Using proxy: {proxy_addr}");
            // expect() instead of a bare unwrap(): a malformed proxy URL is a user
            // configuration error, and the panic message should say which knob is wrong.
            let proxy = Proxy::all(proxy_addr.as_str())
                .expect("OPENAI_API_PROXY is not a valid proxy URL");
            http_client_builder = http_client_builder.proxy(proxy);
        }
        let request_timeout = env::get("OPENAI_REQUEST_TIMEOUT", "");
        if !request_timeout.is_empty()
            && let Ok(timeout) = request_timeout.parse::<u64>()
        {
            trace!("Setting request timeout to: {request_timeout}ms");
            http_client_builder = http_client_builder.timeout(Duration::from_millis(timeout));
        }
        let http_client = http_client_builder
            .build()
            .expect("failed to build HTTP client");
        let client = Client::with_config(ai_config).with_http_client(http_client);
        OpenAI { client }
    }

    /// Verifies the API is reachable and that `model_name` is in the server's
    /// model list.
    ///
    /// # Errors
    ///
    /// Returns the API error if the listing request fails, or a descriptive
    /// error if the model is not present in the returned list.
    pub async fn check_model(&self, model_name: &str) -> Result<(), Box<dyn Error>> {
        // `?` converts OpenAIError into Box<dyn Error> directly; no manual match needed.
        let list = self.client.models().list().await?;
        debug!(
            "Available models: {:?}",
            list.data.iter().map(|m| &m.id).collect::<Vec<_>>()
        );
        if list.data.iter().any(|model| model.id == model_name) {
            debug!("OpenAI API is reachable and model {model_name} is available");
            Ok(())
        } else {
            Err(format!("Model {model_name} not found").into())
        }
    }

    /// Sends a chat-completion request and returns all choice contents joined
    /// with newlines.
    ///
    /// # Errors
    ///
    /// Propagates any request-building or API error from `async_openai`.
    pub async fn chat(
        &self,
        model_name: &str,
        message: Vec<ChatCompletionRequestMessage>,
    ) -> Result<String, OpenAIError> {
        let request = CreateChatCompletionRequestArgs::default()
            .model(model_name)
            .messages(message)
            .build()?;
        trace!("✨ Using model: {}", model_name);
        let response = self.client.chat().create(request).await?;
        // `content` is Optional (e.g. absent on tool-call responses); skip empty
        // choices instead of panicking on unwrap().
        let result: Vec<String> = response
            .choices
            .iter()
            .filter_map(|choice| choice.message.content.clone())
            .collect();
        if let Some(usage) = response.usage {
            debug!(
                "usage: completion_tokens: {}, prompt_tokens: {}, total_tokens: {}",
                usage.completion_tokens, usage.prompt_tokens, usage.total_tokens
            );
        }
        Ok(result.join("\n"))
    }

    /// Renders the user prompt template from commit `logs` and `diff` lines,
    /// each joined with newlines before substitution.
    ///
    /// # Errors
    ///
    /// Returns the askama rendering error if template expansion fails.
    pub fn prompt(logs: &[String], diff: &[String]) -> Result<String, Box<dyn Error>> {
        let template = PromptTemplate {
            logs: &logs.join("\n"),
            diffs: &diff.join("\n"),
        };
        // askama::Error implements std::error::Error, so `?` boxes it for us.
        Ok(template.render()?)
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::git::repository::Repository;
    use tracing::error;

    /// Opens the repository named by the `TEST_REPO_PATH` environment variable.
    ///
    /// # Errors
    ///
    /// Fails when the variable is unset or empty, or when the path is not a
    /// usable git repository.
    fn setup_repo() -> Result<Repository, Box<dyn Error>> {
        let repo_path = std::env::var("TEST_REPO_PATH")
            .map_err(|_| "TEST_REPO_PATH environment variable not set")?;
        if repo_path.is_empty() {
            return Err("Please specify the repository path".into());
        }
        Repository::new(&repo_path)
    }

    #[test]
    fn test_prompt() {
        // Skip (not fail) when TEST_REPO_PATH is unavailable, so the suite can
        // run without a fixture repo. The original also had a dead
        // `assert!(repo.is_ok())` after the early return; removed.
        let Ok(repo) = setup_repo() else {
            error!("Please specify the repository path");
            return;
        };
        let diff_content = repo.get_diff().expect("diff should be readable");
        assert!(!diff_content.is_empty());
        let logs_content = repo.get_logs(5).expect("logs should be readable");
        assert!(!logs_content.is_empty());
        let result = OpenAI::prompt(&logs_content, &diff_content).unwrap();
        assert!(!result.is_empty());
    }
}