#![allow(dead_code)]
use super::{AIClient, AIConfig, AIResponse};
use anyhow::Result;
use futures_util::stream::StreamExt;
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
use serde::{Deserialize, Serialize};
use tracing::debug;
/// Client for the OpenAI chat-completions HTTP API (or any endpoint that
/// speaks the same protocol, selected via the base URL).
pub struct OpenAIClient {
    // Shared reqwest client; cheap to clone (internally reference-counted).
    client: reqwest::Client,
    // Bearer token placed in the `Authorization` header of every request.
    api_key: String,
    // API root, e.g. "https://api.openai.com/v1"; paths are appended to it.
    base_url: String,
    // Model identifier sent with each request (e.g. "gpt-4o").
    model: String,
    // Upper bound on generated tokens per completion.
    max_tokens: u32,
    // Sampling temperature forwarded verbatim to the API.
    temperature: f32,
}
/// JSON body of a `POST /chat/completions` request.
#[derive(Debug, Serialize)]
struct ChatCompletionRequest {
    model: String,
    messages: Vec<Message>,
    max_tokens: u32,
    temperature: f32,
    // `true` requests server-sent-event streaming instead of a single body.
    stream: bool,
}
/// A single chat message; used both in requests and in response choices.
#[derive(Debug, Serialize, Deserialize)]
struct Message {
    // Conversation role: "system", "user", or "assistant" in this file.
    role: String,
    content: String,
}
/// Response payload for `/chat/completions`.
///
/// The same shape is reused for the streaming case: each SSE `data:` line is
/// deserialized into this struct, where choices carry `delta` rather than
/// `message` (see [`Choice`]).
#[derive(Debug, Deserialize)]
struct ChatCompletionResponse {
    id: String,
    object: String,
    // Creation timestamp as reported by the API (presumably Unix seconds —
    // not relied upon here).
    created: i64,
    model: String,
    choices: Vec<Choice>,
    // Token accounting; absent on streaming chunks.
    usage: Option<Usage>,
}
/// One completion candidate. Non-streaming responses populate `message`;
/// streaming chunks populate `delta` instead.
#[derive(Debug, Deserialize)]
struct Choice {
    index: i32,
    message: Option<Message>,
    delta: Option<Delta>,
    // e.g. "stop" or "length"; `None` on intermediate streaming chunks.
    finish_reason: Option<String>,
}
/// Incremental content fragment carried by one streaming chunk.
#[derive(Debug, Deserialize)]
struct Delta {
    // `None` for chunks that carry only role/metadata changes.
    content: Option<String>,
}
/// Token usage accounting attached to a completed (non-streaming) response.
#[derive(Debug, Deserialize)]
struct Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}
impl OpenAIClient {
    /// Creates a client against the official OpenAI endpoint, unless the
    /// `OPENAI_BASE_URL` environment variable is set to a non-empty value.
    ///
    /// # Errors
    /// Fails if the underlying HTTP client cannot be built.
    pub fn new(config: &AIConfig) -> Result<Self> {
        let base = std::env::var("OPENAI_BASE_URL")
            .ok()
            .filter(|s| !s.is_empty())
            .unwrap_or_else(|| "https://api.openai.com/v1".to_string());
        Self::with_base_url(config, &base)
    }

    /// Creates a client against an arbitrary OpenAI-compatible base URL.
    ///
    /// Trailing `/` characters on `base_url` are stripped so that later path
    /// concatenation (`{base}/chat/completions`) never yields a double slash.
    ///
    /// # Errors
    /// Fails if the underlying HTTP client cannot be built.
    pub fn with_base_url(config: &AIConfig, base_url: &str) -> Result<Self> {
        let client = reqwest::Client::builder()
            // Overall per-request timeout; long enough for slow completions.
            .timeout(std::time::Duration::from_secs(60))
            .build()?;
        Ok(Self {
            client,
            api_key: config.api_key.clone(),
            // Normalize so request URLs are always `{base}/{path}`.
            base_url: base_url.trim_end_matches('/').to_string(),
            model: config.model.clone(),
            max_tokens: config.max_tokens,
            temperature: config.temperature,
        })
    }

    /// Builds the JSON content-type and bearer-authorization headers sent
    /// with every API request.
    ///
    /// # Errors
    /// Fails if the API key contains bytes that are not valid in an HTTP
    /// header value (e.g. control characters).
    fn build_headers(&self) -> Result<HeaderMap> {
        let mut headers = HeaderMap::new();
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
        let auth_header = format!("Bearer {}", self.api_key);
        headers.insert(AUTHORIZATION, HeaderValue::from_str(&auth_header)?);
        Ok(headers)
    }
}
#[async_trait::async_trait]
impl AIClient for OpenAIClient {
async fn complete(
&self,
prompt: &str,
_context: Option<&super::AIContext>,
) -> Result<AIResponse> {
debug!("Sending completion request to OpenAI");
let url = format!("{}/chat/completions", self.base_url);
let messages = vec![
Message {
role: "system".to_string(),
content: "You are a helpful AI assistant for developers.".to_string(),
},
Message {
role: "user".to_string(),
content: prompt.to_string(),
},
];
let request = ChatCompletionRequest {
model: self.model.clone(),
messages,
max_tokens: self.max_tokens,
temperature: self.temperature,
stream: false,
};
let response = self
.client
.post(&url)
.headers(self.build_headers()?)
.json(&request)
.send()
.await?;
if !response.status().is_success() {
let error_text = response.text().await?;
anyhow::bail!("OpenAI API error: {}", error_text);
}
let completion: ChatCompletionResponse = response.json().await?;
let content = completion
.choices
.get(0)
.and_then(|c| c.message.as_ref())
.map(|m| m.content.clone())
.unwrap_or_default();
let tokens_used = completion.usage.as_ref().map(|u| u.total_tokens);
Ok(AIResponse {
content,
tokens_used,
model: completion.model,
finish_reason: completion.choices.get(0).and_then(|c| c.finish_reason.clone()),
cost_estimate: None,
})
}
fn model_info(&self) -> super::ModelConfig {
super::ModelConfig {
name: self.model.clone(),
context_window: 128_000,
supports_functions: true,
supports_vision: false,
cost_per_1k_input: 0.03,
cost_per_1k_output: 0.06,
}
}
async fn complete_stream(
&self,
prompt: &str,
_context: Option<&super::AIContext>,
) -> Result<tokio::sync::mpsc::Receiver<Result<String>>> {
let (tx, rx) = tokio::sync::mpsc::channel(100);
let url = format!("{}/chat/completions", self.base_url);
let api_key = self.api_key.clone();
let model = self.model.clone();
let max_tokens = self.max_tokens;
let temperature = self.temperature;
let prompt = prompt.to_string();
let client = self.client.clone();
tokio::spawn(async move {
let messages = vec![
Message {
role: "system".to_string(),
content: "You are a helpful AI assistant for developers.".to_string(),
},
Message {
role: "user".to_string(),
content: prompt,
},
];
let request = ChatCompletionRequest {
model,
messages,
max_tokens,
temperature,
stream: true,
};
let mut headers = HeaderMap::new();
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
headers.insert(
AUTHORIZATION,
HeaderValue::from_str(&format!("Bearer {}", api_key)).unwrap(),
);
let response = match client
.post(&url)
.headers(headers)
.json(&request)
.send()
.await
{
Ok(r) => r,
Err(e) => {
let _ = tx.send(Err(e.into())).await;
return;
}
};
if !response.status().is_success() {
let error_text = match response.text().await {
Ok(t) => t,
Err(e) => format!("Failed to read error: {}", e),
};
let _ = tx.send(Err(anyhow::anyhow!("API error: {}", error_text))).await;
return;
}
let mut stream: futures_util::stream::BoxStream<'_, Result<bytes::Bytes, reqwest::Error>> = response.bytes_stream().boxed();
while let Some(chunk) = stream.next().await {
match chunk {
Ok(bytes) => {
let text = String::from_utf8_lossy(&bytes);
for line in text.lines() {
if line.starts_with("data: ") {
let data = &line[6..];
if data == "[DONE]" {
return;
}
match serde_json::from_str::<ChatCompletionResponse>(data) {
Ok(resp) => {
if let Some(choice) = resp.choices.get(0) {
if let Some(delta) = &choice.delta {
if let Some(content) = &delta.content {
if tx.send(Ok(content.clone())).await.is_err() {
return;
}
}
}
}
}
Err(_) => continue,
}
}
}
}
Err(e) => {
let _ = tx.send(Err::<String, _>(anyhow::anyhow!("Stream error: {}", e))).await;
return;
}
}
}
});
Ok(rx)
}
fn name(&self) -> &str {
"OpenAI"
}
}