use std::time::Duration;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use tokio_stream::StreamExt;
use tokio_util::sync::CancellationToken;
use secrecy::{ExposeSecret, SecretString};
use crate::config::Config;
use crate::error::{Error, Result};
use super::MAX_RESPONSE_BYTES;
/// Fallback API endpoint used when `Config::openai_base_url` is not set.
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
/// Chat-completions provider for the OpenAI-compatible HTTP API.
///
/// Built from [`Config`] via [`OpenAiProvider::new`]; supports a connection
/// check and streamed generation.
pub struct OpenAiProvider {
/// HTTP client configured with the request timeout from `Config`.
client: Client,
/// API base URL with any trailing `/` stripped (see `new`).
base_url: String,
/// Model identifier sent with every request.
model: String,
/// Bearer token; wrapped in `SecretString` so it is not exposed in Debug
/// output or logs.
api_key: SecretString,
/// Sampling temperature forwarded verbatim to the API.
temperature: f32,
/// Completion-length cap forwarded as `max_tokens`.
max_tokens: u32,
}
/// JSON body for `POST /chat/completions`; field names map directly to the
/// API's JSON keys via serde.
#[derive(Serialize)]
struct ChatRequest {
model: String,
messages: Vec<Message>,
temperature: f32,
max_tokens: u32,
/// Always `true` here — responses are consumed as an SSE stream.
stream: bool,
}
/// One chat message in the request payload.
#[derive(Serialize)]
struct Message {
/// `"system"` or `"user"` as used by `generate`.
role: String,
content: String,
}
/// One decoded SSE `data:` payload from the streaming response. Only the
/// fields this module reads are declared; serde ignores the rest.
#[derive(Deserialize)]
struct ChatChunk {
choices: Vec<ChunkChoice>,
}
/// A single choice inside a streamed chunk.
#[derive(Deserialize)]
struct ChunkChoice {
delta: Delta,
/// Non-`None` on the final chunk of a choice; `generate` stops on it.
finish_reason: Option<String>,
}
/// Incremental content carried by a streamed choice; `None` for chunks that
/// carry only metadata (e.g. the role announcement or the final chunk).
#[derive(Deserialize)]
struct Delta {
content: Option<String>,
}
impl OpenAiProvider {
pub fn new(config: &Config) -> Result<Self> {
let client = Client::builder()
.timeout(Duration::from_secs(config.timeout_secs))
.build()
.map_err(|e| Error::Provider {
provider: "openai".into(),
message: format!("failed to build HTTP client: {e}"),
})?;
Ok(Self {
client,
base_url: config
.openai_base_url
.clone()
.unwrap_or_else(|| DEFAULT_BASE_URL.to_string())
.trim_end_matches('/')
.to_string(),
model: config.model.clone(),
api_key: config.api_key.clone().unwrap_or_default(),
temperature: config.temperature,
max_tokens: config.num_predict,
})
}
pub async fn verify_connection(&self) -> Result<()> {
let url = format!("{}/models", self.base_url);
let response = self
.client
.get(&url)
.header(
"Authorization",
format!("Bearer {}", self.api_key.expose_secret()),
)
.send()
.await
.map_err(|e| Error::Provider {
provider: "openai".into(),
message: e.without_url().to_string(),
})?;
if response.status() == reqwest::StatusCode::UNAUTHORIZED {
return Err(Error::Provider {
provider: "openai".into(),
message: "invalid API key".into(),
});
}
Ok(())
}
pub async fn generate(
&self,
prompt: &str,
system_prompt: &str,
token_tx: mpsc::Sender<String>,
cancel: CancellationToken,
) -> Result<String> {
let url = format!("{}/chat/completions", self.base_url);
let response = self
.client
.post(&url)
.header(
"Authorization",
format!("Bearer {}", self.api_key.expose_secret()),
)
.json(&ChatRequest {
model: self.model.clone(),
messages: vec![
Message {
role: "system".into(),
content: system_prompt.into(),
},
Message {
role: "user".into(),
content: prompt.to_string(),
},
],
temperature: self.temperature,
max_tokens: self.max_tokens,
stream: true,
})
.send()
.await
.map_err(|e| {
if e.is_timeout() {
Error::Provider {
provider: "openai".into(),
message: "request timed out".into(),
}
} else {
Error::Provider {
provider: "openai".into(),
message: e.without_url().to_string(),
}
}
})?;
if !response.status().is_success() {
let status = response.status();
let body = response
.text()
.await
.unwrap_or_else(|e| format!("(failed to read body: {e})"));
return Err(Error::Provider {
provider: "openai".into(),
message: format!("HTTP {status}: {body}"),
});
}
let mut stream = response.bytes_stream();
let mut full_response = String::new();
let mut line_buffer = String::new();
loop {
tokio::select! {
_ = cancel.cancelled() => {
return Err(Error::Cancelled);
}
chunk = stream.next() => {
let Some(chunk) = chunk else { break };
let chunk = chunk.map_err(|e| Error::Provider {
provider: "openai".into(),
message: e.without_url().to_string(),
})?;
line_buffer.push_str(&String::from_utf8_lossy(&chunk));
if line_buffer.len() > MAX_RESPONSE_BYTES {
return Err(Error::Provider {
provider: "openai".into(),
message: "line buffer exceeded 1 MB limit".into(),
});
}
while let Some(newline_pos) = line_buffer.find('\n') {
let result = {
let line = line_buffer[..newline_pos].trim();
if line.is_empty() || line == "data: [DONE]" {
None
} else if let Some(data) = line.strip_prefix("data: ") {
serde_json::from_str::<ChatChunk>(data).ok()
} else {
None
}
};
line_buffer.drain(..=newline_pos);
if let Some(chunk) = result {
for choice in &chunk.choices {
if let Some(ref content) = choice.delta.content {
let _ = token_tx.send(content.clone()).await;
full_response.push_str(content);
}
if full_response.len() > MAX_RESPONSE_BYTES {
return Err(Error::Provider {
provider: "openai".into(),
message: "response exceeded 1 MB limit".into(),
});
}
if choice.finish_reason.is_some() {
return Ok(full_response.trim().to_string());
}
}
}
}
}
}
}
if !line_buffer.is_empty() {
let line = line_buffer.trim();
if !line.is_empty()
&& line != "data: [DONE]"
&& let Some(data) = line.strip_prefix("data: ")
&& let Ok(chunk) = serde_json::from_str::<ChatChunk>(data)
{
for choice in &chunk.choices {
if let Some(ref content) = choice.delta.content {
full_response.push_str(content);
}
}
}
}
Ok(full_response.trim().to_string())
}
pub fn name(&self) -> &str {
"openai"
}
}