use async_trait::async_trait;
use chrono::Utc;
use futures::{Stream, StreamExt};
use paladin_core::platform::container::content::{ContentItem, ContentType};
use paladin_core::platform::container::prompt::{PromptItem, PromptRole, PromptType};
use paladin_ports::output::llm_port::{
FinishReason, LlmError, LlmPort, LlmRequest, LlmResponse, ProviderCapabilities,
StreamingResponse, TokenUsage,
};
use rand::Rng;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::env;
use std::pin::Pin;
use std::time::Duration;
use uuid::Uuid;
/// Connection settings for [`OpenAIAdapter`].
///
/// `Debug` is implemented by hand so the API key never ends up in logs.
#[derive(Clone)]
pub struct OpenAIConfig {
    /// Secret API key, sent as `Authorization: Bearer <key>`.
    pub api_key: String,
    /// API root such as `https://api.openai.com/v1`; endpoint paths are appended to it.
    pub base_url: String,
    /// Optional value for the `OpenAI-Organization` header.
    pub organization: Option<String>,
    /// Timeout applied to the HTTP client, in seconds. Must be non-zero.
    pub timeout_seconds: u64,
    /// How many times a retryable failure is retried after the first attempt.
    pub max_retries: u32,
}

impl std::fmt::Debug for OpenAIConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIConfig")
            // Never print the secret itself.
            .field("api_key", &"<redacted>")
            .field("base_url", &self.base_url)
            .field("organization", &self.organization)
            .field("timeout_seconds", &self.timeout_seconds)
            .field("max_retries", &self.max_retries)
            .finish()
    }
}

impl OpenAIConfig {
    /// Reads the configuration from the environment.
    ///
    /// - `OPENAI_API_KEY` (required)
    /// - `OPENAI_BASE_URL` (default `https://api.openai.com/v1`)
    /// - `OPENAI_ORGANIZATION` (optional)
    /// - `OPENAI_TIMEOUT_SECONDS` (default `300`)
    /// - `OPENAI_MAX_RETRIES` (default `3`)
    ///
    /// The result is not validated here; call [`OpenAIConfig::validate`] (the adapter
    /// constructor does).
    ///
    /// # Errors
    ///
    /// Returns a message when the API key is missing or a numeric variable does not parse.
    pub fn from_env() -> Result<Self, String> {
        let api_key = env::var("OPENAI_API_KEY")
            .map_err(|_| "OPENAI_API_KEY environment variable not set")?;
        let base_url =
            env::var("OPENAI_BASE_URL").unwrap_or_else(|_| "https://api.openai.com/v1".to_string());
        let organization = env::var("OPENAI_ORGANIZATION").ok();
        let timeout_seconds = env::var("OPENAI_TIMEOUT_SECONDS")
            .unwrap_or_else(|_| "300".to_string())
            .parse()
            .map_err(|_| "Invalid OPENAI_TIMEOUT_SECONDS value")?;
        let max_retries = env::var("OPENAI_MAX_RETRIES")
            .unwrap_or_else(|_| "3".to_string())
            .parse()
            .map_err(|_| "Invalid OPENAI_MAX_RETRIES value")?;
        Ok(Self {
            api_key,
            base_url,
            organization,
            timeout_seconds,
            max_retries,
        })
    }

    /// Creates a configuration for the public OpenAI endpoint.
    ///
    /// Defaults: no organization, a 300-second timeout and 3 retries.
    pub fn new(api_key: String) -> Self {
        Self {
            api_key,
            base_url: "https://api.openai.com/v1".to_string(),
            organization: None,
            timeout_seconds: 300,
            max_retries: 3,
        }
    }

    /// Checks that the configuration can be used to build an adapter.
    ///
    /// # Errors
    ///
    /// Returns a message when any of these hold:
    /// - the API key is empty or only whitespace;
    /// - the base URL is empty or does not use an `http://` / `https://` scheme;
    /// - the timeout is zero (an HTTP client with a zero timeout fails every request).
    pub fn validate(&self) -> Result<(), String> {
        if self.api_key.trim().is_empty() {
            return Err("API key cannot be empty".to_string());
        }
        if self.base_url.is_empty() {
            return Err("Base URL cannot be empty".to_string());
        }
        // A bare `starts_with("http")` would also accept strings like "httpfoo".
        if !(self.base_url.starts_with("http://") || self.base_url.starts_with("https://")) {
            return Err("Base URL must start with http:// or https://".to_string());
        }
        if self.timeout_seconds == 0 {
            return Err("Timeout must be greater than zero seconds".to_string());
        }
        Ok(())
    }
}
#[derive(Debug, Serialize)]
struct OpenAIRequest {
model: String,
messages: Vec<OpenAIMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
top_p: Option<f32>,
stream: bool,
}
#[derive(Debug, Serialize, Deserialize)]
struct OpenAIMessage {
role: String,
content: String,
}
#[derive(Debug, Deserialize)]
struct OpenAIResponse {
#[allow(dead_code)]
id: String,
model: String,
choices: Vec<OpenAIChoice>,
usage: OpenAIUsage,
}
#[derive(Debug, Deserialize)]
struct OpenAIChoice {
#[allow(dead_code)]
index: u32,
message: OpenAIMessage,
finish_reason: Option<String>,
}
#[derive(Debug, Deserialize)]
struct OpenAIUsage {
prompt_tokens: u32,
completion_tokens: u32,
total_tokens: u32,
}
#[derive(Debug, Deserialize)]
struct OpenAIStreamChunk {
#[allow(dead_code)]
id: String,
choices: Vec<OpenAIStreamChoice>,
}
#[derive(Debug, Deserialize)]
struct OpenAIStreamChoice {
#[allow(dead_code)]
index: u32,
delta: OpenAIStreamDelta,
finish_reason: Option<String>,
}
#[derive(Debug, Deserialize)]
struct OpenAIStreamDelta {
#[allow(dead_code)]
role: Option<String>,
content: Option<String>,
}
pub struct OpenAIAdapter {
pub(crate) config: OpenAIConfig,
pub(crate) client: Client,
}
impl OpenAIAdapter {
pub fn new(config: OpenAIConfig) -> Result<Self, String> {
config.validate()?;
let client = Client::builder()
.timeout(Duration::from_secs(config.timeout_seconds))
.build()
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
Ok(Self { config, client })
}
pub fn from_env() -> Result<Self, String> {
Self::new(OpenAIConfig::from_env()?)
}
fn convert_to_messages(
&self,
prompt: &PromptItem,
attachments: &[ContentItem],
) -> Result<Vec<OpenAIMessage>, LlmError> {
let mut messages = Vec::new();
match prompt.prompt_type() {
PromptType::System(system_prompt) => {
let mut content = system_prompt.instructions.clone();
if let Some(constraints) = &system_prompt.constraints
&& !constraints.is_empty()
{
content.push_str("\n\nConstraints:\n");
for constraint in constraints {
content.push_str(&format!("- {}\n", constraint));
}
}
messages.push(OpenAIMessage {
role: "system".to_string(),
content,
});
}
PromptType::User(user_prompt) => {
messages.push(OpenAIMessage {
role: "user".to_string(),
content: user_prompt.context.clone().unwrap_or_default(),
});
}
PromptType::Assistant(assistant_prompt) => {
let mut content = assistant_prompt.response.clone();
if let Some(reasoning) = &assistant_prompt.reasoning {
content.push_str(&format!("\n\nReasoning: {}", reasoning));
}
messages.push(OpenAIMessage {
role: "assistant".to_string(),
content,
});
}
PromptType::Text(text_prompt) => {
let role = match text_prompt.role {
PromptRole::System => "system",
PromptRole::User => "user",
PromptRole::Assistant => "assistant",
PromptRole::Function => "function",
};
messages.push(OpenAIMessage {
role: role.to_string(),
content: text_prompt.content.clone(),
});
}
PromptType::Function(function_prompt) => {
messages.push(OpenAIMessage {
role: "function".to_string(),
content: function_prompt.function_name.clone(),
});
}
}
for content in attachments {
if let Ok(content_text) = self.convert_content_to_text(content)
&& !content_text.is_empty()
{
messages.push(OpenAIMessage {
role: "user".to_string(),
content: format!("Content to analyze:\n{}", content_text),
});
}
}
Ok(messages)
}
fn convert_content_to_text(&self, content: &ContentItem) -> Result<String, LlmError> {
match content.content() {
ContentType::Text(text_content) => {
Ok(text_content.content.as_deref().unwrap_or("").to_string())
}
ContentType::Video(video_content) => Ok(format!(
"Video: {} (Duration: {}s)",
content.title().unwrap_or(&"Untitled".to_string()),
video_content.duration
)),
ContentType::Audio(audio_content) => Ok(format!(
"Audio: {} (Duration: {}s)",
content.title().unwrap_or(&"Untitled".to_string()),
audio_content.duration
)),
ContentType::Image(image_content) => Ok(format!(
"Image: {} ({}x{})",
content.title().unwrap_or(&"Untitled".to_string()),
image_content.resolution.0,
image_content.resolution.1
)),
}
}
fn convert_finish_reason(&self, reason: Option<String>) -> FinishReason {
match reason.as_deref() {
Some("stop") => FinishReason::Stop,
Some("length") => FinishReason::Length,
Some("content_filter") => FinishReason::ContentFilter,
Some("function_call") => FinishReason::FunctionCall,
Some(other) => FinishReason::Error(format!("Unknown: {}", other)),
None => FinishReason::Stop,
}
}
async fn make_request_with_retries(
&self,
request: &OpenAIRequest,
) -> Result<OpenAIResponse, LlmError> {
let mut last_error = None;
for attempt in 0..=self.config.max_retries {
match self.make_single_request(request).await {
Ok(response) => return Ok(response),
Err(e) => {
last_error = Some(e.clone());
if matches!(e, LlmError::AuthenticationError(_)) {
return Err(e);
}
if attempt < self.config.max_retries {
let base_delay = Duration::from_secs(1);
let exponential_delay = base_delay * 2_u32.pow(attempt);
let max_delay = Duration::from_secs(10);
let delay = exponential_delay.min(max_delay);
let jitter_ms = {
let mut rng = rand::thread_rng();
rng.gen_range(0..=(delay.as_millis() / 5)) as u64
};
let total_delay = delay + Duration::from_millis(jitter_ms);
tokio::time::sleep(total_delay).await;
}
}
}
}
Err(last_error
.unwrap_or_else(|| LlmError::ProcessingError("Maximum retries exceeded".to_string())))
}
async fn make_single_request(
&self,
request: &OpenAIRequest,
) -> Result<OpenAIResponse, LlmError> {
let url = format!("{}/chat/completions", self.config.base_url);
let mut req = self
.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.config.api_key))
.header("Content-Type", "application/json");
if let Some(org) = &self.config.organization {
req = req.header("OpenAI-Organization", org);
}
let response = req
.json(request)
.send()
.await
.map_err(|e| LlmError::NetworkError(format!("Request failed: {}", e)))?;
let status = response.status();
let response_text = response
.text()
.await
.map_err(|e| LlmError::ProcessingError(format!("Failed to read response: {}", e)))?;
if !status.is_success() {
return match status.as_u16() {
401 => Err(LlmError::AuthenticationError(
"Invalid OpenAI API key".to_string(),
)),
429 => Err(LlmError::RateLimitExceeded),
400 => {
if response_text.contains("maximum context length") {
Err(LlmError::TokenLimitExceeded)
} else {
Err(LlmError::InvalidPrompt(response_text))
}
}
500..=599 => Err(LlmError::ProcessingError(format!(
"OpenAI server error: {}",
response_text
))),
_ => Err(LlmError::ProcessingError(format!(
"HTTP {}: {}",
status, response_text
))),
};
}
serde_json::from_str::<OpenAIResponse>(&response_text)
.map_err(|e| LlmError::ProcessingError(format!("Failed to parse response: {}", e)))
}
async fn make_streaming_request(
&self,
request: &OpenAIRequest,
) -> Result<Pin<Box<dyn Stream<Item = Result<StreamingResponse, LlmError>> + Send>>, LlmError>
{
let url = format!("{}/chat/completions", self.config.base_url);
let mut req = self
.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.config.api_key))
.header("Content-Type", "application/json");
if let Some(org) = &self.config.organization {
req = req.header("OpenAI-Organization", org);
}
let response = req
.json(request)
.send()
.await
.map_err(|e| LlmError::NetworkError(format!("Request failed: {}", e)))?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
return Err(match status.as_u16() {
401 => LlmError::AuthenticationError("Invalid OpenAI API key".to_string()),
429 => LlmError::RateLimitExceeded,
400 => LlmError::InvalidPrompt(error_text),
_ => LlmError::ProcessingError(format!("HTTP {}: {}", status, error_text)),
});
}
let stream = response.bytes_stream().map(|chunk_result| {
chunk_result
.map_err(|e| LlmError::NetworkError(format!("Stream error: {}", e)))
.and_then(|chunk| {
let chunk_str = String::from_utf8_lossy(&chunk);
for line in chunk_str.lines() {
if let Some(data) = line.strip_prefix("data: ") {
if data == "[DONE]" {
return Ok(StreamingResponse {
id: Uuid::new_v4(),
delta: String::new(),
finish_reason: Some(FinishReason::Stop),
});
}
match serde_json::from_str::<OpenAIStreamChunk>(data) {
Ok(chunk) => {
if let Some(choice) = chunk.choices.first() {
let delta =
choice.delta.content.clone().unwrap_or_default();
let finish_reason =
choice.finish_reason.as_ref().map(|r| {
match r.as_str() {
"stop" => FinishReason::Stop,
"length" => FinishReason::Length,
"content_filter" => FinishReason::ContentFilter,
"function_call" => FinishReason::FunctionCall,
other => FinishReason::Error(format!(
"Unknown: {}",
other
)),
}
});
return Ok(StreamingResponse {
id: Uuid::new_v4(),
delta,
finish_reason,
});
}
}
Err(e) => {
return Err(LlmError::ProcessingError(format!(
"Failed to parse stream chunk: {}",
e
)));
}
}
}
}
Ok(StreamingResponse {
id: Uuid::new_v4(),
delta: String::new(),
finish_reason: None,
})
})
});
Ok(Box::pin(stream))
}
}
#[async_trait]
impl LlmPort for OpenAIAdapter {
async fn generate(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
let messages = self.convert_to_messages(&request.prompt, &request.attachments)?;
let temperature = request
.prompt
.node
.node
.parameters
.temperature
.unwrap_or(0.7);
let max_tokens = request
.prompt
.node
.node
.parameters
.max_tokens
.unwrap_or(4096);
let openai_request = OpenAIRequest {
model: request.model.clone(),
messages,
temperature: Some(temperature),
max_tokens: Some(max_tokens),
top_p: Some(1.0),
stream: false,
};
let response = self.make_request_with_retries(&openai_request).await?;
if response.choices.is_empty() {
return Err(LlmError::ProcessingError(
"No choices in response".to_string(),
));
}
let choice = &response.choices[0];
let finish_reason = self.convert_finish_reason(choice.finish_reason.clone());
Ok(LlmResponse {
id: Uuid::new_v4(),
request_id: request.id,
model: response.model,
content: choice.message.content.clone(),
finish_reason,
usage: TokenUsage {
prompt_tokens: response.usage.prompt_tokens,
completion_tokens: response.usage.completion_tokens,
total_tokens: response.usage.total_tokens,
},
created_at: Utc::now(),
metadata: HashMap::new(),
function_call: None,
})
}
async fn generate_stream(
&self,
request: LlmRequest,
) -> Result<Box<dyn Stream<Item = Result<StreamingResponse, LlmError>> + Send>, LlmError> {
let messages = self.convert_to_messages(&request.prompt, &request.attachments)?;
let temperature = request
.prompt
.node
.node
.parameters
.temperature
.unwrap_or(0.7);
let max_tokens = request
.prompt
.node
.node
.parameters
.max_tokens
.unwrap_or(4096);
let openai_request = OpenAIRequest {
model: request.model.clone(),
messages,
temperature: Some(temperature),
max_tokens: Some(max_tokens),
top_p: Some(1.0),
stream: true,
};
let stream = self.make_streaming_request(&openai_request).await?;
Ok(Box::new(stream))
}
async fn validate_model(&self, model: &str) -> Result<bool, LlmError> {
let available_models = self.get_available_models().await?;
Ok(available_models.contains(&model.to_string()))
}
async fn get_available_models(&self) -> Result<Vec<String>, LlmError> {
let url = format!("{}/models", self.config.base_url);
let mut req = self
.client
.get(&url)
.header("Authorization", format!("Bearer {}", self.config.api_key));
if let Some(org) = &self.config.organization {
req = req.header("OpenAI-Organization", org);
}
let response = req
.send()
.await
.map_err(|e| LlmError::NetworkError(format!("Failed to fetch models: {}", e)))?;
if !response.status().is_success() {
return Err(LlmError::ProcessingError(format!(
"HTTP {}",
response.status()
)));
}
let response_text = response
.text()
.await
.map_err(|e| LlmError::ProcessingError(format!("Failed to read response: {}", e)))?;
let models_response: serde_json::Value = serde_json::from_str(&response_text)
.map_err(|e| LlmError::ProcessingError(format!("Failed to parse response: {}", e)))?;
let models = models_response["data"]
.as_array()
.ok_or_else(|| LlmError::ProcessingError("Invalid models response format".to_string()))?
.iter()
.filter_map(|model| model["id"].as_str().map(String::from))
.collect();
Ok(models)
}
fn get_provider_name(&self) -> &'static str {
"openai"
}
fn get_capabilities(&self) -> ProviderCapabilities {
ProviderCapabilities {
supports_streaming: true,
supports_tool_calling: true,
supports_function_calling: true,
supports_vision: true,
max_context_tokens: Some(128000),
supports_embeddings: true,
supports_system_messages: true,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default configuration for `api_key`, as produced by `OpenAIConfig::new`.
    fn config(api_key: &str) -> OpenAIConfig {
        OpenAIConfig::new(api_key.to_string())
    }

    /// Adapter over the default test configuration.
    fn adapter() -> OpenAIAdapter {
        OpenAIAdapter::new(config("test-key")).expect("default test config is valid")
    }

    #[test]
    fn test_config_creation() {
        let OpenAIConfig {
            api_key,
            base_url,
            timeout_seconds,
            max_retries,
            ..
        } = config("test-key");
        assert_eq!(api_key, "test-key");
        assert_eq!(base_url, "https://api.openai.com/v1");
        assert_eq!(timeout_seconds, 300);
        assert_eq!(max_retries, 3);
    }

    #[test]
    fn test_config_validation() {
        assert!(config("test-key").validate().is_ok());
        let missing_key = OpenAIConfig {
            api_key: String::new(),
            ..config("test-key")
        };
        assert!(missing_key.validate().is_err());
    }

    #[test]
    fn test_adapter_creation() {
        assert!(OpenAIAdapter::new(config("test-key")).is_ok());
    }

    #[test]
    fn test_get_provider_name() {
        assert_eq!(adapter().get_provider_name(), "openai");
    }

    #[test]
    fn test_get_capabilities() {
        let capabilities = adapter().get_capabilities();
        assert!(capabilities.supports_streaming);
        assert!(capabilities.supports_tool_calling);
        assert!(capabilities.supports_vision);
        assert_eq!(capabilities.max_context_tokens, Some(128000));
    }

    #[test]
    fn test_config_with_organization() {
        let with_org = OpenAIConfig {
            organization: Some("org-123".to_string()),
            ..config("test-key")
        };
        assert_eq!(with_org.organization.as_deref(), Some("org-123"));
    }

    #[test]
    fn test_config_validation_empty_base_url() {
        let no_base_url = OpenAIConfig {
            base_url: String::new(),
            ..config("test-key")
        };
        assert!(no_base_url.validate().is_err());
    }
}