use crate::utils::error::OpenCratesError;
use async_openai::{
    config::OpenAIConfig as AsyncOpenAIConfig,
    types::{
        AudioResponseFormat, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
        ChatCompletionRequestUserMessage, ChatCompletionToolArgs, CreateChatCompletionRequestArgs,
        CreateEmbeddingRequestArgs, CreateImageRequestArgs, CreateSpeechRequestArgs,
        CreateTranscriptionRequestArgs, FunctionObjectArgs, ImageModel, ImageResponseFormat,
        ImageSize, ResponseFormat, ResponseFormatJsonSchema, SpeechModel, Voice,
    },
    Client,
};
use async_trait::async_trait;
use futures_util::StreamExt;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{error, instrument};
use super::{GenerationRequest, GenerationResponse, LLMProvider};
use crate::utils::config::OpenAIConfig;
use crate::utils::openai_agents::Usage;
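/// Configuration for [`EnhancedOpenAIProvider`]: API credentials plus the
/// default model for each modality (chat, embeddings, images, transcription,
/// and speech).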
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnhancedOpenAIConfig {
pub api_key: String,
pub chat_model: String,
pub embedding_model: String,
pub image_model: String,
pub audio_model: String,
pub speech_model: String,
pub max_tokens: u16,
pub temperature: f32,
pub timeout_seconds: u64,
pub base_url: Option<String>,
pub organization: Option<String>,
}
impl Default for EnhancedOpenAIConfig {
fn default() -> Self {
Self {
api_key: std::env::var("OPENAI_API_KEY").unwrap_or_default(),
chat_model: "gpt-4o".to_string(),
embedding_model: "text-embedding-3-large".to_string(),
image_model: "dall-e-3".to_string(),
audio_model: "whisper-1".to_string(),
speech_model: "tts-1-hd".to_string(),
max_tokens: 4096,
temperature: 0.7,
timeout_seconds: 60,
base_url: None,
organization: None,
}
}
}
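/// Multi-modal OpenAI provider built on `async-openai`: chat completions,
/// embeddings, image generation, audio transcription, and text-to-speech.
///
/// A minimal usage sketch (marked `ignore`: it needs network access and a
/// valid `OPENAI_API_KEY`, and assumes the provider types are in scope):
///
/// ```rust,ignore
/// let provider = EnhancedOpenAIProvider::new().await?;
/// assert!(provider.health_check().await?);
/// ```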
#[derive(Debug, Clone)]
pub struct EnhancedOpenAIProvider {
client: Client<AsyncOpenAIConfig>,
config: Arc<RwLock<EnhancedOpenAIConfig>>,
usage_stats: Arc<RwLock<UsageStats>>,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct UsageStats {
pub total_requests: u64,
pub total_tokens: u64,
pub successful_requests: u64,
pub failed_requests: u64,
pub average_response_time_ms: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingRequest {
pub input: Vec<String>,
pub model: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingResponse {
pub embeddings: Vec<Vec<f32>>,
pub usage: Usage,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageGenerationRequest {
pub prompt: String,
pub n: Option<u8>,
pub size: Option<String>,
pub quality: Option<String>,
pub response_format: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageGenerationResponse {
pub images: Vec<GeneratedImage>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneratedImage {
pub url: Option<String>,
pub b64_json: Option<String>,
pub revised_prompt: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioTranscriptionRequest {
pub file_path: String,
pub model: Option<String>,
pub language: Option<String>,
pub response_format: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioTranscriptionResponse {
pub text: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpeechRequest {
pub input: String,
pub model: Option<String>,
pub voice: Option<String>,
}
impl EnhancedOpenAIProvider {
pub async fn new() -> Result<Self, OpenCratesError> {
let config = EnhancedOpenAIConfig::default();
Self::new_with_config(config).await
}
pub async fn new_with_config(config: EnhancedOpenAIConfig) -> Result<Self, OpenCratesError> {
let mut client_config = AsyncOpenAIConfig::new().with_api_key(&config.api_key);
if let Some(ref base_url) = config.base_url {
client_config = client_config.with_api_base(base_url);
}
if let Some(ref org) = config.organization {
client_config = client_config.with_org_id(org);
}
let client = Client::with_config(client_config);
Ok(Self {
client,
config: Arc::new(RwLock::new(config)),
usage_stats: Arc::new(RwLock::new(UsageStats::default())),
})
}
pub async fn from_opencrates_config(config: &OpenAIConfig) -> Result<Self, OpenCratesError> {
let enhanced_config = EnhancedOpenAIConfig {
api_key: config
.api_key
.clone()
.unwrap_or_else(|| "test-key".to_string()),
chat_model: config.model.clone(),
embedding_model: "text-embedding-3-large".to_string(),
image_model: "dall-e-3".to_string(),
audio_model: "whisper-1".to_string(),
speech_model: "tts-1-hd".to_string(),
max_tokens: config.max_tokens as u16,
temperature: config.temperature,
timeout_seconds: 60,
base_url: config.base_url.clone(),
organization: config.organization.clone(),
};
Self::new_with_config(enhanced_config).await
}
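    /// Runs a chat completion, optionally streaming; streamed deltas are
    /// accumulated into a single response before returning.
    ///
    /// A hedged sketch of a non-streaming call (`ignore`: requires a live API
    /// key; `provider` is an already-constructed `EnhancedOpenAIProvider`):
    ///
    /// ```rust,ignore
    /// use async_openai::types::{
    ///     ChatCompletionRequestMessage, ChatCompletionRequestUserMessage,
    ///     ChatCompletionRequestUserMessageContent,
    /// };
    ///
    /// let messages = vec![ChatCompletionRequestMessage::User(
    ///     ChatCompletionRequestUserMessage {
    ///         content: ChatCompletionRequestUserMessageContent::Text(
    ///             "Say hello".to_string(),
    ///         ),
    ///         name: None,
    ///     },
    /// )];
    /// let response = provider
    ///     .chat_completion(messages, None, Some(64), Some(0.2), false)
    ///     .await?;
    /// println!("{}", response.preview);
    /// ```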
#[instrument(skip(self, messages))]
pub async fn chat_completion(
&self,
messages: Vec<ChatCompletionRequestMessage>,
model: Option<String>,
max_tokens: Option<u16>,
temperature: Option<f32>,
stream: bool,
) -> Result<GenerationResponse, OpenCratesError> {
let config = self.config.read().await;
let model = model.unwrap_or_else(|| config.chat_model.clone());
let request = CreateChatCompletionRequestArgs::default()
.model(&model)
.messages(messages)
.max_tokens(max_tokens.unwrap_or(config.max_tokens))
.temperature(temperature.unwrap_or(config.temperature))
.stream(stream)
.build()
.map_err(|e| OpenCratesError::internal(e.to_string()))?;
if stream {
let mut stream = self
.client
.chat()
.create_stream(request)
.await
.map_err(|e| OpenCratesError::external(e.to_string()))?;
            let mut content = String::new();
            let mut finish_reason: Option<String> = None;
            let mut final_response: Option<
                async_openai::types::CreateChatCompletionStreamResponse,
            > = None;
            while let Some(result) = stream.next().await {
                match result {
                    Ok(response) => {
                        for choice in &response.choices {
                            if let Some(delta_content) = &choice.delta.content {
                                content.push_str(delta_content);
                            }
                            // The final chunk carries the actual finish reason;
                            // record it rather than hard-coding "stop" below.
                            if let Some(reason) = &choice.finish_reason {
                                finish_reason = Some(format!("{reason:?}"));
                            }
                        }
                        final_response = Some(response);
                    }
                    Err(e) => {
                        error!("Stream error: {}", e);
                        break;
                    }
                }
            }
            // Usage is only reported on the final chunk, and only when the
            // request opts in via stream options; otherwise this falls back
            // to zeroed counters.
            let usage = final_response.and_then(|r| r.usage).unwrap_or(
                async_openai::types::CompletionUsage {
                    prompt_tokens: 0,
                    completion_tokens: 0,
                    total_tokens: 0,
                    completion_tokens_details: None,
                    prompt_tokens_details: None,
                },
            );
            Ok(GenerationResponse {
                preview: content,
                metrics: Usage {
                    prompt_tokens: usage.prompt_tokens as usize,
                    completion_tokens: usage.completion_tokens as usize,
                    total_tokens: usage.total_tokens as usize,
                },
                finish_reason,
            })
} else {
let response = self
.client
.chat()
.create(request)
.await
.map_err(|e| OpenCratesError::external(e.to_string()))?;
            let choice = response
                .choices
                .first()
                .ok_or_else(|| OpenCratesError::external("No choices in response"))?;
let content = choice.message.content.clone().unwrap_or_default();
let usage = response
.usage
.unwrap_or(async_openai::types::CompletionUsage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
completion_tokens_details: None,
prompt_tokens_details: None,
});
Ok(GenerationResponse {
preview: content,
metrics: Usage {
prompt_tokens: usage.prompt_tokens as usize,
completion_tokens: usage.completion_tokens as usize,
total_tokens: usage.total_tokens as usize,
},
finish_reason: choice.finish_reason.as_ref().map(|r| format!("{r:?}")),
})
}
}
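    /// Embeds a batch of input strings with the configured embedding model.
    ///
    /// A minimal sketch (`ignore`: needs a live API key):
    ///
    /// ```rust,ignore
    /// let response = provider
    ///     .generate_embeddings(EmbeddingRequest {
    ///         input: vec!["hello world".to_string()],
    ///         model: None, // falls back to config.embedding_model
    ///     })
    ///     .await?;
    /// assert_eq!(response.embeddings.len(), 1);
    /// ```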
#[instrument(skip(self, request))]
pub async fn generate_embeddings(
&self,
request: EmbeddingRequest,
) -> Result<EmbeddingResponse, OpenCratesError> {
let config = self.config.read().await;
let model = request
.model
.unwrap_or_else(|| config.embedding_model.clone());
let response = self
.client
.embeddings()
.create(
CreateEmbeddingRequestArgs::default()
.model(model)
.input(request.input)
.build()
.map_err(|e| OpenCratesError::internal(e.to_string()))?,
)
.await
.map_err(|e| OpenCratesError::external(e.to_string()))?;
let embeddings = response.data.into_iter().map(|d| d.embedding).collect();
let usage = Usage {
prompt_tokens: response.usage.prompt_tokens as usize,
completion_tokens: 0,
total_tokens: response.usage.total_tokens as usize,
};
Ok(EmbeddingResponse { embeddings, usage })
}
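    /// Generates images with DALL-E 3, returning URLs or base64 payloads
    /// depending on `response_format`.
    ///
    /// A hedged sketch (`ignore`: needs a live API key):
    ///
    /// ```rust,ignore
    /// let response = provider
    ///     .generate_image(ImageGenerationRequest {
    ///         prompt: "a watercolor crab reading a book".to_string(),
    ///         n: Some(1),
    ///         size: Some("1024x1024".to_string()),
    ///         quality: None,
    ///         response_format: Some("url".to_string()),
    ///     })
    ///     .await?;
    /// for image in &response.images {
    ///     println!("{:?}", image.url);
    /// }
    /// ```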
#[instrument(skip(self, request))]
pub async fn generate_image(
&self,
request: ImageGenerationRequest,
) -> Result<ImageGenerationResponse, OpenCratesError> {
let size = request.size.as_deref().unwrap_or("1024x1024");
let image_size = match size {
"256x256" => ImageSize::S256x256,
"512x512" => ImageSize::S512x512,
"1024x1024" => ImageSize::S1024x1024,
"1792x1024" => ImageSize::S1792x1024,
"1024x1792" => ImageSize::S1024x1792,
_ => ImageSize::S1024x1024,
};
let response_format = match request.response_format.as_deref() {
Some("b64_json") => ImageResponseFormat::B64Json,
_ => ImageResponseFormat::Url,
};
let image_request = CreateImageRequestArgs::default()
.prompt(&request.prompt)
            .model(ImageModel::DallE3)
            // Note: the DALL-E 3 API currently accepts only n = 1; larger
            // values are rejected upstream.
            .n(request.n.unwrap_or(1))
.size(image_size)
.response_format(response_format)
.build()
.map_err(|e| OpenCratesError::internal(e.to_string()))?;
let response = self
.client
.images()
.create(image_request)
.await
.map_err(|e| OpenCratesError::external(e.to_string()))?;
let images = response
.data
.into_iter()
.map(|image| GeneratedImage {
url: match &*image {
async_openai::types::Image::Url { url, .. } => Some(url.clone()),
_ => None,
},
b64_json: match &*image {
async_openai::types::Image::B64Json { b64_json, .. } => {
Some(b64_json.to_string())
}
_ => None,
},
revised_prompt: match &*image {
async_openai::types::Image::Url { revised_prompt, .. }
| async_openai::types::Image::B64Json { revised_prompt, .. } => {
revised_prompt.clone()
}
},
})
.collect();
Ok(ImageGenerationResponse { images })
}
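    /// Transcribes an audio file from disk with the configured Whisper model.
    ///
    /// A minimal sketch (`ignore`: needs a live API key and a local file):
    ///
    /// ```rust,ignore
    /// let response = provider
    ///     .transcribe_audio(AudioTranscriptionRequest {
    ///         file_path: "recording.mp3".to_string(),
    ///         model: None,
    ///         language: Some("en".to_string()),
    ///         response_format: None, // defaults to JSON
    ///     })
    ///     .await?;
    /// println!("{}", response.text);
    /// ```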
#[instrument(skip(self, request))]
pub async fn transcribe_audio(
&self,
request: AudioTranscriptionRequest,
) -> Result<AudioTranscriptionResponse, OpenCratesError> {
let config = self.config.read().await;
let model = request.model.unwrap_or_else(|| config.audio_model.clone());
let response_format =
request
.response_format
.map_or(AudioResponseFormat::Json, |f| match f.as_str() {
"text" => AudioResponseFormat::Text,
"verbose_json" => AudioResponseFormat::VerboseJson,
"srt" => AudioResponseFormat::Srt,
"vtt" => AudioResponseFormat::Vtt,
_ => AudioResponseFormat::Json,
});
let audio_data = std::fs::read(&request.file_path).map_err(OpenCratesError::Io)?;
let filename = std::path::Path::new(&request.file_path)
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("audio.mp3");
        let mut builder = CreateTranscriptionRequestArgs::default();
        builder
            .file(async_openai::types::AudioInput::from_vec_u8(
                filename.to_string(),
                audio_data,
            ))
            .model(model)
            .response_format(response_format);
        // Only set a language when one was supplied; an empty string is not a
        // valid ISO-639-1 code and may be rejected by the API.
        if let Some(language) = request.language {
            builder.language(language);
        }
        let transcription_request = builder
            .build()
            .map_err(|e| OpenCratesError::internal(e.to_string()))?;
let response = self
.client
.audio()
.transcribe(transcription_request)
.await
.map_err(|e| OpenCratesError::external(e.to_string()))?;
Ok(AudioTranscriptionResponse {
text: response.text,
})
}
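    /// Synthesizes speech from text, returning the raw audio bytes.
    ///
    /// A minimal sketch (`ignore`: needs a live API key):
    ///
    /// ```rust,ignore
    /// let audio = provider
    ///     .generate_speech(SpeechRequest {
    ///         input: "Hello from OpenCrates".to_string(),
    ///         model: Some("tts-1-hd".to_string()),
    ///         voice: Some("nova".to_string()),
    ///     })
    ///     .await?;
    /// std::fs::write("speech.mp3", audio)?;
    /// ```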
#[instrument(skip(self, request))]
pub async fn generate_speech(
&self,
request: SpeechRequest,
) -> Result<Vec<u8>, OpenCratesError> {
let model = match request.model.as_deref() {
Some("tts-1") => SpeechModel::Tts1,
Some("tts-1-hd") => SpeechModel::Tts1Hd,
_ => SpeechModel::Tts1Hd,
};
let voice = match request.voice.as_deref() {
Some("alloy") => Voice::Alloy,
Some("echo") => Voice::Echo,
Some("fable") => Voice::Fable,
Some("onyx") => Voice::Onyx,
Some("nova") => Voice::Nova,
Some("shimmer") => Voice::Shimmer,
_ => Voice::Alloy,
};
let speech_request = CreateSpeechRequestArgs::default()
.input(&request.input)
.model(model)
.voice(voice)
.build()
.map_err(|e| OpenCratesError::internal(e.to_string()))?;
let response = self
.client
.audio()
.speech(speech_request)
.await
.map_err(|e| OpenCratesError::external(e.to_string()))?;
Ok(response.bytes.to_vec())
}
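    /// Requests a completion constrained to a JSON schema (strict mode) and
    /// parses the result into a `serde_json::Value`.
    ///
    /// A hedged sketch (`ignore`: needs a live API key; `messages` is an
    /// already-built message vector and the schema is illustrative):
    ///
    /// ```rust,ignore
    /// use serde_json::json;
    ///
    /// let schema = json!({
    ///     "type": "object",
    ///     "properties": { "summary": { "type": "string" } },
    ///     "required": ["summary"],
    ///     "additionalProperties": false
    /// });
    /// let value = provider
    ///     .structured_completion(messages, schema, "summary_schema".to_string())
    ///     .await?;
    /// println!("{}", value["summary"]);
    /// ```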
#[instrument(skip(self, messages, schema))]
pub async fn structured_completion(
&self,
messages: Vec<ChatCompletionRequestMessage>,
schema: Value,
schema_name: String,
) -> Result<Value, OpenCratesError> {
let config = self.config.read().await;
let response_format = ResponseFormat::JsonSchema {
json_schema: ResponseFormatJsonSchema {
name: schema_name,
description: None,
schema: Some(schema),
strict: Some(true),
},
};
let request = CreateChatCompletionRequestArgs::default()
.model(&config.chat_model)
.messages(messages)
.max_tokens(config.max_tokens)
.temperature(config.temperature)
.response_format(response_format)
.build()
.map_err(|e| OpenCratesError::internal(e.to_string()))?;
let response = self
.client
.chat()
.create(request)
.await
.map_err(|e| OpenCratesError::external(e.to_string()))?;
let choice = response
.choices
.first()
.ok_or_else(|| OpenCratesError::external("No choices in response"))?;
let content = choice
.message
.content
.as_ref()
.ok_or_else(|| OpenCratesError::external("No content in response"))?;
Ok(serde_json::from_str(content)?)
}
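    /// Runs a chat completion with tool definitions, letting the model decide
    /// which tool to call (`tool_choice: auto`).
    ///
    /// A hedged sketch (`ignore`: needs a live API key; the tool value shows
    /// the loose `{name, description, parameters}` shape this method expects,
    /// not the full OpenAI tool schema):
    ///
    /// ```rust,ignore
    /// use serde_json::json;
    ///
    /// let tools = vec![json!({
    ///     "name": "get_weather",
    ///     "description": "Look up current weather for a city",
    ///     "parameters": {
    ///         "type": "object",
    ///         "properties": { "city": { "type": "string" } },
    ///         "required": ["city"]
    ///     }
    /// })];
    /// let response = provider.function_calling(messages, tools).await?;
    /// ```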
#[instrument(skip(self, messages, tools))]
pub async fn function_calling(
&self,
messages: Vec<ChatCompletionRequestMessage>,
tools: Vec<Value>,
) -> Result<GenerationResponse, OpenCratesError> {
let config = self.config.read().await;
let request = CreateChatCompletionRequestArgs::default()
.model(&config.chat_model)
.messages(messages)
.tools(
tools
.into_iter()
.map(|t| {
let function = FunctionObjectArgs::default()
.name(t.get("name").and_then(|v| v.as_str()).unwrap_or("unknown"))
.description(
t.get("description").and_then(|v| v.as_str()).unwrap_or(""),
)
.parameters(t.get("parameters").cloned().unwrap_or(json!({})))
.build()
.map_err(|e| OpenCratesError::internal(e.to_string()))?;
ChatCompletionToolArgs::default()
.function(function)
.build()
.map_err(|e| OpenCratesError::internal(e.to_string()))
})
.collect::<Result<Vec<_>, _>>()?,
)
.tool_choice(async_openai::types::ChatCompletionToolChoiceOption::Auto)
.build()
.map_err(|e| OpenCratesError::internal(e.to_string()))?;
let response = self
.client
.chat()
.create(request)
.await
.map_err(|e| OpenCratesError::external(e.to_string()))?;
        let choice = response
            .choices
            .first()
            .ok_or_else(|| OpenCratesError::external("No choices in response"))?;
        // Note: only the text content is surfaced here; any tool_calls the
        // model returned are not extracted by this method.
        let content = choice.message.content.clone().unwrap_or_default();
let usage = response
.usage
.unwrap_or(async_openai::types::CompletionUsage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
completion_tokens_details: None,
prompt_tokens_details: None,
});
let metrics = Usage {
prompt_tokens: usage.prompt_tokens as usize,
completion_tokens: usage.completion_tokens as usize,
total_tokens: usage.total_tokens as usize,
};
Ok(GenerationResponse {
preview: content,
metrics,
finish_reason: choice.finish_reason.as_ref().map(|r| format!("{r:?}")),
})
}
pub async fn get_usage_stats(&self) -> UsageStats {
self.usage_stats.read().await.clone()
}
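    /// Folds one request's outcome into the rolling usage statistics,
    /// maintaining an incremental average of response times. The request
    /// methods above do not call this yet; callers that want live stats must
    /// invoke it after each API call.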
async fn update_stats(&self, tokens: u64, success: bool, response_time_ms: u64) {
let mut stats = self.usage_stats.write().await;
stats.total_requests += 1;
stats.total_tokens += tokens;
if success {
stats.successful_requests += 1;
} else {
stats.failed_requests += 1;
}
let total_response_time =
stats.average_response_time_ms * (stats.total_requests - 1) as f64;
stats.average_response_time_ms =
(total_response_time + response_time_ms as f64) / stats.total_requests as f64;
}
}
#[async_trait]
impl LLMProvider for EnhancedOpenAIProvider {
async fn generate(
&self,
request: &GenerationRequest,
) -> Result<GenerationResponse, OpenCratesError> {
let prompt_text = request.prompt.as_ref().unwrap_or(&request.spec.description);
let messages = vec![
ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
content: async_openai::types::ChatCompletionRequestSystemMessageContent::Text(
"You are an expert Rust developer and crate creator.".to_string(),
),
name: None,
}),
ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
content: async_openai::types::ChatCompletionRequestUserMessageContent::Text(
prompt_text.clone(),
),
name: None,
}),
];
self.chat_completion(
messages,
request.model.clone(),
request.max_tokens.map(|t| t.try_into().unwrap_or(4096)),
request.temperature,
false,
)
.await
}
async fn health_check(&self) -> Result<bool, OpenCratesError> {
match self.client.models().list().await {
Ok(_) => Ok(true),
Err(e) => {
error!("Enhanced OpenAI health check failed: {}", e);
Ok(false)
}
}
}
fn name(&self) -> &'static str {
"enhanced_openai"
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}