use std::sync::mpsc::Sender;
use futures::StreamExt;
use reqwest::Client;
use serde::Serialize;
use tokio_util::sync::CancellationToken;
use super::AiError;
use super::sse::{OpenAiEventParser, SseParser};
use crate::ai::ai_state::AiResponse;
/// Default endpoint for the OpenAI chat-completions API.
const OPENAI_API_URL: &str = "https://api.openai.com/v1/chat/completions";

/// Streaming client for the OpenAI chat-completions API, or any
/// OpenAI-compatible endpoint supplied via a custom base URL.
#[derive(Clone)]
pub struct AsyncOpenAiClient {
    // Underlying HTTP client used for all requests.
    client: Client,
    // Bearer token sent with every request; redacted from `Debug` output.
    api_key: String,
    // Model identifier placed in the request body.
    model: String,
    // Fully resolved chat-completions URL (see `build_api_url`).
    api_url: String,
}

// Manual `Debug` instead of `#[derive(Debug)]`: the derived impl would
// print `api_key` verbatim, leaking the secret into logs and panic output.
impl std::fmt::Debug for AsyncOpenAiClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AsyncOpenAiClient")
            .field("api_key", &"[redacted]")
            .field("model", &self.model)
            .field("api_url", &self.api_url)
            .finish_non_exhaustive()
    }
}
impl AsyncOpenAiClient {
/// Creates a client for `model`, authenticating with `api_key`.
///
/// When `base_url` is `None` the official OpenAI endpoint is used;
/// otherwise the URL is normalized via [`Self::build_api_url`].
pub fn new(api_key: String, model: String, base_url: Option<String>) -> Self {
    Self {
        client: Client::new(),
        api_key,
        model,
        api_url: Self::build_api_url(base_url),
    }
}
/// Resolves the chat-completions URL from an optional base override.
///
/// `None` selects the official OpenAI endpoint. A custom base URL has
/// trailing slashes stripped and `/chat/completions` appended unless
/// the path is already present.
fn build_api_url(base_url: Option<String>) -> String {
    let Some(raw) = base_url else {
        return OPENAI_API_URL.to_string();
    };
    let base = raw.trim_end_matches('/');
    if base.ends_with("/chat/completions") {
        base.to_string()
    } else {
        format!("{}/chat/completions", base)
    }
}
/// Returns `true` when the client targets something other than the
/// official `api.openai.com` host.
pub fn is_custom_endpoint(&self) -> bool {
    let is_official = self.api_url.contains("api.openai.com");
    !is_official
}
/// Serializes the JSON body for a single-user-message streaming
/// chat-completion request.
///
/// # Errors
///
/// Returns [`AiError::Parse`] if JSON serialization fails.
fn build_request_body(&self, prompt: &str) -> Result<String, AiError> {
    // Wire-format structs: field names and declaration order are part of
    // the serialized output and must match the OpenAI API.
    #[derive(Serialize)]
    struct ChatMessage {
        role: String,
        content: String,
    }
    #[derive(Serialize)]
    struct ChatRequest {
        model: String,
        messages: Vec<ChatMessage>,
        stream: bool,
    }

    let user_message = ChatMessage {
        role: "user".to_string(),
        content: prompt.to_string(),
    };
    let request = ChatRequest {
        model: self.model.clone(),
        messages: vec![user_message],
        stream: true,
    };
    serde_json::to_string(&request).map_err(|e| AiError::Parse {
        provider: "OpenAI".to_string(),
        message: format!("Failed to serialize request body: {}", e),
    })
}
/// Streams a chat completion for `prompt`, forwarding each text chunk to
/// `response_tx` tagged with `request_id`.
///
/// Cancellation via `cancel_token` is honored before the request is
/// issued, while waiting for the response headers, and between stream
/// chunks. Returns `Ok(())` early if the receiving end of `response_tx`
/// has been dropped (the consumer is gone, so there is nothing to do).
///
/// # Errors
///
/// - [`AiError::Cancelled`] if the token fires at any point.
/// - [`AiError::Network`] on connection or stream failures.
/// - [`AiError::Api`] when the server replies with a non-success status.
pub async fn stream_with_cancel(
    &self,
    prompt: &str,
    request_id: u64,
    cancel_token: CancellationToken,
    response_tx: Sender<AiResponse>,
) -> Result<(), AiError> {
    if cancel_token.is_cancelled() {
        return Err(AiError::Cancelled);
    }
    let body = self.build_request_body(prompt)?;
    let send_future = self
        .client
        .post(&self.api_url)
        .header("Authorization", format!("Bearer {}", self.api_key))
        .header("Content-Type", "application/json")
        .body(body)
        .send();
    // Fix: previously the send itself was not cancellable, so a slow
    // connect / time-to-first-byte ignored the token until the body
    // started streaming. Racing the send against the token (dropping the
    // future aborts the in-flight request) closes that gap.
    let response = tokio::select! {
        biased;
        _ = cancel_token.cancelled() => return Err(AiError::Cancelled),
        result = send_future => result.map_err(|e| AiError::Network {
            provider: "OpenAI".to_string(),
            message: e.to_string(),
        })?,
    };
    if !response.status().is_success() {
        let code = response.status().as_u16();
        // Best-effort read of the error body; fall back to a placeholder.
        let message = response
            .text()
            .await
            .unwrap_or_else(|_| "Unknown error".to_string());
        return Err(AiError::Api {
            provider: "OpenAI".to_string(),
            code,
            message,
        });
    }
    let mut stream = response.bytes_stream();
    let mut sse_parser = SseParser::new(OpenAiEventParser);
    loop {
        tokio::select! {
            // `biased` checks cancellation before polling for more data.
            biased;
            _ = cancel_token.cancelled() => {
                return Err(AiError::Cancelled);
            }
            chunk = stream.next() => {
                match chunk {
                    Some(Ok(bytes)) => {
                        for text in sse_parser.parse_chunk(&bytes) {
                            // A send error means the receiver was dropped;
                            // stop quietly rather than reporting an error.
                            if response_tx
                                .send(AiResponse::Chunk {
                                    text,
                                    request_id,
                                })
                                .is_err()
                            {
                                return Ok(());
                            }
                        }
                    }
                    Some(Err(e)) => {
                        return Err(AiError::Network {
                            provider: "OpenAI".to_string(),
                            message: e.to_string(),
                        });
                    }
                    // End of stream: server finished the response.
                    None => break,
                }
            }
        }
    }
    Ok(())
}
}
#[cfg(test)]
#[path = "async_openai_tests.rs"]
mod async_openai_tests;