use crate::built_info;
use crate::utils::env;
use askama::Template;
use async_openai::config::OPENAI_API_BASE;
use async_openai::error::OpenAIError;
use async_openai::{
Client,
config::OpenAIConfig,
types::chat::{ChatCompletionRequestMessage, CreateChatCompletionRequestArgs},
};
use log::trace;
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::{ClientBuilder, Proxy};
use std::error::Error;
use std::time::Duration;
use tracing::{debug, warn};
/// Askama context for the `user.txt` template, rendered by `OpenAI::prompt`.
#[derive(Template)]
#[template(path = "user.txt")]
struct PromptTemplate<'a> {
    /// Log lines, already joined with `\n` by `OpenAI::prompt`.
    logs: &'a str,
    /// Diff lines, already joined with `\n` by `OpenAI::prompt`.
    diffs: &'a str,
}
pub struct OpenAI {
client: Client<OpenAIConfig>,
}
impl OpenAI {
pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
let ai_config = OpenAIConfig::new()
.with_api_key(env::get("OPENAI_API_TOKEN", ""))
.with_api_base(env::get("OPENAI_API_BASE", OPENAI_API_BASE))
.with_org_id(built_info::PKG_NAME);
let mut http_client_builder = Self::create_http_client_builder();
if let Some(proxy_addr) = Self::get_proxy_config() {
trace!("Using proxy: {proxy_addr}");
if let Ok(proxy) = Proxy::all(&proxy_addr) {
http_client_builder = http_client_builder.proxy(proxy);
}
}
if let Some(timeout) = Self::get_timeout_config() {
trace!("Setting request timeout to: {}s", timeout.as_secs());
http_client_builder = http_client_builder.timeout(timeout);
} else {
trace!("No request timeout configured; using HTTP client default");
}
let http_client = http_client_builder.build().map_err(|e| {
Box::new(std::io::Error::other(format!(
"failed to build HTTP client: {e}"
))) as Box<dyn std::error::Error>
})?;
let client = Client::with_config(ai_config).with_http_client(http_client);
Ok(Self { client })
}
#[inline]
fn create_http_client_builder() -> ClientBuilder {
let mut headers = HeaderMap::new();
headers.insert(
"HTTP-Referer",
HeaderValue::from_static(built_info::PKG_HOMEPAGE),
);
headers.insert("X-Title", HeaderValue::from_static(built_info::PKG_NAME));
headers.insert("X-Client-Type", HeaderValue::from_static("CLI"));
ClientBuilder::new()
.user_agent(format!(
"{} ({})",
built_info::PKG_NAME,
built_info::PKG_DESCRIPTION
))
.default_headers(headers)
}
#[inline]
fn get_proxy_config() -> Option<String> {
let proxy_addr = env::get("OPENAI_API_PROXY", "");
(!proxy_addr.is_empty()).then_some(proxy_addr)
}
#[inline]
fn get_timeout_config() -> Option<Duration> {
for key in ["OPENAI_API_TIMEOUT", "OPENAI_REQUEST_TIMEOUT"] {
let raw = env::get(key, "");
if raw.is_empty() {
continue;
}
match raw.trim().parse::<u64>() {
Ok(0) => {
trace!("{key}=0 interpreted as no timeout");
return None;
}
Ok(secs) => return Some(Duration::from_secs(secs)),
Err(e) => {
warn!("ignoring invalid {key}={raw:?}: {e}");
}
}
}
None
}
pub async fn check_model(&self, model_name: &str) -> Result<(), Box<dyn Error>> {
let list = self.client.models().list().await?;
debug!(
"Available models: {:?}",
list.data.iter().map(|m| &m.id).collect::<Vec<_>>()
);
if list.data.iter().any(|model| model.id == model_name) {
debug!("OpenAI API is reachable and model {model_name} is available");
Ok(())
} else {
Err(format!("Model {model_name} not found").into())
}
}
pub async fn chat(
&self,
model_name: &str,
message: Vec<ChatCompletionRequestMessage>,
) -> Result<String, OpenAIError> {
let request = CreateChatCompletionRequestArgs::default()
.model(model_name)
.messages(message)
.build()?;
trace!("✨ Using model: {}", model_name);
let response = self.client.chat().create(request).await?;
let result: Vec<String> = response
.choices
.iter()
.filter_map(|choice| choice.message.content.as_ref().map(ToString::to_string))
.collect();
if let Some(usage) = response.usage {
debug!(
"usage: completion_tokens: {}, prompt_tokens: {}, total_tokens: {}",
usage.completion_tokens, usage.prompt_tokens, usage.total_tokens
);
}
Ok(result.join("\n"))
}
pub fn prompt(logs: &[String], diff: &[String]) -> Result<String, Box<dyn Error>> {
let template = PromptTemplate {
logs: &logs.join("\n"),
diffs: &diff.join("\n"),
};
Ok(template.render()?)
}
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::git::repository::Repository;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serializes the tests in this module that read or mutate process environment
    /// variables.
    ///
    /// The default test harness runs tests on parallel threads. `set_var`/`remove_var`
    /// are `unsafe` precisely because another thread (including native code calling
    /// `getenv`) may read the environment at the same time.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`]. Poisoning is ignored so that one failing test does not
    /// cascade into spurious failures in the others.
    fn env_lock() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Sets `key` to `value`, or removes it when `value` is `None`.
    ///
    /// The guard parameter proves at compile time that the caller holds [`ENV_LOCK`].
    fn set_env(_lock: &MutexGuard<'static, ()>, key: &str, value: Option<&str>) {
        match value {
            // SAFETY: the caller holds ENV_LOCK, so no other test in this module reads or
            // writes the environment concurrently. Tests in other modules are not covered
            // by this lock.
            Some(value) => unsafe { std::env::set_var(key, value) },
            // SAFETY: same invariant as the `set_var` arm.
            None => unsafe { std::env::remove_var(key) },
        }
    }

    /// Opens the repository named by `TEST_REPO_PATH`.
    fn setup_repo() -> Result<Repository, Box<dyn Error>> {
        let repo_path = std::env::var("TEST_REPO_PATH")
            .map_err(|_| "TEST_REPO_PATH environment variable not set")?;
        if repo_path.is_empty() {
            return Err("Please specify the repository path".into());
        }
        Repository::new(&repo_path)
    }

    #[test]
    fn test_prompt() {
        // Hold the lock so repository setup never reads the environment while another
        // test is mutating it.
        let _env = env_lock();
        let repo = match setup_repo() {
            Ok(repo) => repo,
            Err(e) => {
                // Skipped rather than failed when no test repository is configured.
                // `eprintln!` is used because the harness installs no tracing subscriber.
                eprintln!("skipping test_prompt: {e}");
                return;
            }
        };
        let diff_content = repo.get_diff().expect("failed to read repository diff");
        assert!(!diff_content.is_empty());
        let logs_content = repo.get_logs(5).expect("failed to read repository logs");
        assert!(!logs_content.is_empty());
        let result =
            OpenAI::prompt(&logs_content, &diff_content).expect("failed to render prompt");
        assert!(!result.is_empty());
    }

    #[test]
    fn timeout_config_handles_unset_zero_and_invalid() {
        let env = env_lock();
        set_env(&env, "OPENAI_API_TIMEOUT", None);
        set_env(&env, "OPENAI_REQUEST_TIMEOUT", None);
        assert!(OpenAI::get_timeout_config().is_none(), "unset => None");

        set_env(&env, "OPENAI_API_TIMEOUT", Some("0"));
        assert!(
            OpenAI::get_timeout_config().is_none(),
            "0 should disable the timeout, not apply 0s"
        );

        set_env(&env, "OPENAI_API_TIMEOUT", Some("not-a-number"));
        assert!(
            OpenAI::get_timeout_config().is_none(),
            "invalid value => None (default)"
        );

        set_env(&env, "OPENAI_API_TIMEOUT", Some(" 45 "));
        assert_eq!(
            OpenAI::get_timeout_config(),
            Some(Duration::from_secs(45)),
            "valid seconds value should be honored"
        );

        set_env(&env, "OPENAI_API_TIMEOUT", Some("not-a-number"));
        set_env(&env, "OPENAI_REQUEST_TIMEOUT", Some("30"));
        assert_eq!(
            OpenAI::get_timeout_config(),
            Some(Duration::from_secs(30)),
            "an invalid primary key should fall back to OPENAI_REQUEST_TIMEOUT"
        );

        set_env(&env, "OPENAI_API_TIMEOUT", Some("10"));
        assert_eq!(
            OpenAI::get_timeout_config(),
            Some(Duration::from_secs(10)),
            "OPENAI_API_TIMEOUT takes precedence over OPENAI_REQUEST_TIMEOUT"
        );

        set_env(&env, "OPENAI_API_TIMEOUT", None);
        set_env(&env, "OPENAI_REQUEST_TIMEOUT", None);
    }
}