use super::LlmClient;
use super::prompt_builder;
use super::stream::read_stream_to_string;
use crate::FileChange;
use crate::git::{PrItem, PrSummaryMode};
use anyhow::{anyhow, Context, Result};
use reqwest::blocking::Client;
use serde::{Deserialize, Serialize};
use std::io::BufReader;
use std::time::Duration;
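
/// Request body for the chat completions endpoint.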
#[derive(Serialize)]
struct ChatRequest {
    model: String,
    messages: Vec<ChatMessage>,
    stream: bool,
}

#[derive(Serialize)]
struct ChatMessage {
    role: String,
    content: String,
}
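
/// Non-streaming response body; `usage` is optional since servers may
/// omit it.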
#[derive(Deserialize)]
struct ChatResponse {
    choices: Vec<ChatChoice>,
    usage: Option<ChatUsage>,
}

#[derive(Deserialize)]
struct ChatChoice {
    message: ChatMessageResponse,
}

#[derive(Deserialize)]
struct ChatMessageResponse {
    content: String,
}

#[derive(Deserialize)]
struct ChatUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}
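
/// One decoded chunk of the server-sent-events stream; assistant text
/// arrives incrementally in `delta.content`.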
#[derive(Deserialize)]
struct StreamResponse {
    choices: Vec<StreamChoice>,
}

#[derive(Deserialize)]
struct StreamChoice {
    delta: StreamDelta,
}

#[derive(Deserialize)]
struct StreamDelta {
    content: Option<String>,
}
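
/// Blocking client for the OpenAI chat completions API, also usable
/// against OpenAI-compatible servers that expose the same endpoint.
///
/// A minimal construction sketch (values are placeholders; marked
/// `ignore` so rustdoc does not compile it):
///
/// ```ignore
/// let client = OpenAiClient::new(
///     "sk-...".to_string(),                 // API key (placeholder)
///     "gpt-4o-mini".to_string(),            // example model name
///     "https://api.openai.com".to_string(), // trailing "/v1" optional
///     false,                                // disable streaming
/// );
/// ```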
pub struct OpenAiClient {
    client: Client,
    api_key: String,
    model: String,
    api_base_url: String,
    stream: bool,
}

impl OpenAiClient {
    pub fn new(api_key: String, model: String, api_base_url: String, stream: bool) -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(90))
            .build()
            .expect("failed to build HTTP client");
        OpenAiClient {
            client,
            api_key,
            model,
            api_base_url: api_base_url.trim_end_matches('/').to_string(),
            stream,
        }
    }
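
    /// Builds the chat completions URL, accepting base URLs given with
    /// or without a trailing `/v1` segment.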
    fn chat_url(&self) -> String {
        if self.api_base_url.ends_with("/v1") {
            format!("{}/chat/completions", self.api_base_url)
        } else {
            format!("{}/v1/chat/completions", self.api_base_url)
        }
    }
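
    /// Sends a chat request and returns the first choice's content,
    /// delegating to the streaming path when `req.stream` is set.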
    fn call_chat(&self, req: &ChatRequest) -> Result<String> {
        if req.stream {
            return self.call_chat_streaming(req);
        }
        let url = self.chat_url();
        log::info!("Calling OpenAI model {:?}", &req.model);
        let resp = self
            .client
            .post(url)
            .bearer_auth(&self.api_key)
            .json(req)
            .send()
            .context("failed to send request to OpenAI")?;
        if !resp.status().is_success() {
            let status = resp.status();
            let text = resp.text().unwrap_or_default();
            return Err(anyhow!(
                "OpenAI API error: HTTP {} - {}",
                status.as_u16(),
                text
            ));
        }
        let chat_resp: ChatResponse = resp.json().context("failed to parse OpenAI response")?;
        let content = chat_resp
            .choices
            .first()
            .map(|c| c.message.content.clone())
            .ok_or_else(|| anyhow!("no choices returned from OpenAI"))?;
        if let Some(usage) = &chat_resp.usage {
            // Routine usage stats are informational, not a warning.
            log::info!(
                "Token usage: prompt={}, completion={}, total={}",
                usage.prompt_tokens,
                usage.completion_tokens,
                usage.total_tokens
            );
        }
        Ok(content)
    }
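
    /// Sends the request and folds the server-sent-event deltas into a
    /// single string as they arrive.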
    fn call_chat_streaming(&self, req: &ChatRequest) -> Result<String> {
        let url = self.chat_url();
        log::info!("Streaming OpenAI model {:?}", &req.model);
        let resp = self
            .client
            .post(url)
            .bearer_auth(&self.api_key)
            .json(req)
            .send()
            .context("failed to send streaming request to OpenAI")?;
        if !resp.status().is_success() {
            let status = resp.status();
            let text = resp.text().unwrap_or_default();
            return Err(anyhow!(
                "OpenAI API error: HTTP {} - {}",
                status.as_u16(),
                text
            ));
        }
        let reader = BufReader::new(resp);
        read_stream_to_string(reader, parse_stream_line)
    }
}
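
/// Parses a single line of the SSE stream. A typical content line looks
/// like `data: {"choices":[{"delta":{"content":"Hi"}}]}`; comment lines,
/// keep-alives, and the final `data: [DONE]` sentinel all yield `Ok(None)`.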
fn parse_stream_line(line: &str) -> Result<Option<String>> {
    let line = line.trim_start();
    // Strip exactly one `data:` marker; `trim_start_matches` would also
    // eat a payload that happens to begin with the same text.
    let Some(data) = line.strip_prefix("data:") else {
        return Ok(None);
    };
    let data = data.trim();
    if data == "[DONE]" {
        return Ok(None);
    }
    let chunk: StreamResponse =
        serde_json::from_str(data).context("failed to parse OpenAI streaming chunk")?;
    Ok(chunk.choices.first().and_then(|c| c.delta.content.clone()))
}

impl LlmClient for OpenAiClient {
    fn summarize_file(
        &self,
        branch: &str,
        file: &FileChange,
        file_index: usize,
        total_files: usize,
        ticket_summary: Option<&str>,
    ) -> Result<String> {
        let prompts = prompt_builder::file_summary_prompt(
            branch,
            file,
            file_index,
            total_files,
            ticket_summary,
        );
        log::info!(
            "Per-file summarize prompt for {} ({:?}) [truncated]:\n{}",
            file.path,
            file.category,
            truncate(&prompts.user, 1000)
        );
        log::debug!(
            "Per-file summarize prompt for {} ({:?}) [full]:\n--- SYSTEM ---\n{}\n--- USER ---\n{}",
            file.path,
            file.category,
            prompts.system,
            prompts.user
        );
        let req = ChatRequest {
            model: self.model.clone(),
            messages: vec![
                ChatMessage {
                    role: "system".into(),
                    content: prompts.system,
                },
                ChatMessage {
                    role: "user".into(),
                    content: prompts.user,
                },
            ],
            // Per-file summaries are never streamed; only the final
            // message generation honors the client's `stream` flag.
            stream: false,
        };
        self.call_chat(&req)
    }

    fn generate_commit_message(
        &self,
        branch: &str,
        files: &[FileChange],
        ticket_summary: Option<&str>,
    ) -> Result<String> {
        let prompts = prompt_builder::commit_message_prompt(branch, files, ticket_summary);
        log::info!(
            "Final commit-message prompt [truncated]:\n{}",
            truncate(&prompts.user, 1000)
        );
        log::debug!(
            "Final commit-message prompt [full]:\n--- SYSTEM ---\n{}\n--- USER ---\n{}",
            prompts.system,
            prompts.user
        );
        let req = ChatRequest {
            model: self.model.clone(),
            messages: vec![
                ChatMessage {
                    role: "system".into(),
                    content: prompts.system,
                },
                ChatMessage {
                    role: "user".into(),
                    content: prompts.user,
                },
            ],
            stream: self.stream,
        };
        self.call_chat(&req)
    }

    fn generate_pr_message(
        &self,
        base_branch: &str,
        from_branch: &str,
        mode: PrSummaryMode,
        items: &[PrItem],
        ticket_summary: Option<&str>,
    ) -> Result<String> {
        let prompts =
            prompt_builder::pr_message_prompt(base_branch, from_branch, mode, items, ticket_summary);
        log::info!(
            "PR description prompt [truncated]:\n{}",
            truncate(&prompts.user, 1000)
        );
        log::debug!(
            "PR description prompt [full]:\n--- SYSTEM ---\n{}\n--- USER ---\n{}",
            prompts.system,
            prompts.user
        );
        let req = ChatRequest {
            model: self.model.clone(),
            messages: vec![
                ChatMessage {
                    role: "system".into(),
                    content: prompts.system,
                },
                ChatMessage {
                    role: "user".into(),
                    content: prompts.user,
                },
            ],
            stream: self.stream,
        };
        self.call_chat(&req)
    }
}
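
/// Truncates `s` for log output, backing the cut up to the nearest char
/// boundary so multi-byte UTF-8 input cannot split mid-character.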
fn truncate(s: &str, max_len: usize) -> String {
    if s.len() <= max_len {
        return s.to_string();
    }
    // Walk back from `max_len` until the cut lands on a char boundary;
    // slicing mid-character would panic on non-ASCII input.
    let mut end = max_len;
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}...\n[truncated {} chars]", &s[..end], s.len() - end)
}
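
// A few illustrative unit tests for the pure helpers above; the SSE
// payloads are hand-written examples of the OpenAI stream format, not
// captured server output.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn stream_line_with_content_delta_is_extracted() {
        let line = r#"data: {"choices":[{"delta":{"content":"Hi"}}]}"#;
        assert_eq!(parse_stream_line(line).unwrap(), Some("Hi".to_string()));
    }

    #[test]
    fn done_sentinel_and_non_data_lines_yield_none() {
        assert_eq!(parse_stream_line("data: [DONE]").unwrap(), None);
        assert_eq!(parse_stream_line(": keep-alive").unwrap(), None);
        assert_eq!(parse_stream_line("").unwrap(), None);
    }

    #[test]
    fn truncate_respects_char_boundaries() {
        // Multi-byte input must not panic even when `max_len` falls
        // inside a UTF-8 sequence ('é' is two bytes).
        let out = truncate("héllo wörld", 2);
        assert!(out.starts_with('h'));
    }
}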