use crate::codex_vendored::apply_patch;
use crate::providers::{
model_types::{ModelProviderInfo, Prompt},
GenerationRequest, GenerationResponse, LLMProvider,
};
use crate::utils::config::CodexConfig;
use crate::utils::error::OpenCratesError;
use crate::utils::metrics::TokenUsage;
use crate::utils::openai_agents::Usage;
use async_trait::async_trait;
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::io::{self, Write};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::time::timeout;
use tracing::{debug, error, info, warn};
/// Top-level response body returned by the `/chat/completions` endpoint.
///
/// Most metadata fields are `Option` so that parsing stays lenient toward
/// OpenAI-compatible servers that omit them.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OpenAIResponse {
    /// Server-assigned response identifier, if provided.
    id: Option<String>,
    /// Object type discriminator, if provided.
    object: Option<String>,
    /// Creation timestamp (Unix seconds), if provided.
    created: Option<u64>,
    /// Model that produced the completion.
    model: String,
    /// One entry per requested completion; the first choice is the one used.
    choices: Vec<Choice>,
    /// Token accounting; missing usage is treated as zero by the caller.
    usage: Option<OpenAIUsage>,
}
/// A single completion choice within an [`OpenAIResponse`].
///
/// Non-streaming responses populate `message`; streaming chunks populate
/// `delta` instead (both are optional so one struct covers both shapes).
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Choice {
    /// Position of this choice in the response's `choices` array.
    index: u32,
    /// Full assistant message (non-streaming responses).
    message: Option<Message>,
    /// Incremental content fragment (streaming responses).
    delta: Option<Delta>,
    /// Why generation stopped (e.g. length limit or natural stop), if reported.
    finish_reason: Option<String>,
}
/// A chat message, used both for outgoing requests and parsed responses.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Message {
    /// Speaker role; this file sends "system" and "user" messages.
    role: String,
    /// Message text.
    content: String,
}
/// Incremental message fragment from a streaming response chunk.
///
/// NOTE(review): this provider currently sends `stream: false`, so `Delta`
/// is parsed but never populated by the requests made in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Delta {
    /// Role, typically present only on the first chunk.
    role: Option<String>,
    /// Content fragment for this chunk.
    content: Option<String>,
}
/// Token accounting as reported by the API; converted to the crate's
/// `TokenUsage` (usize fields) in `get_chat_completion`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OpenAIUsage {
    /// Tokens consumed by the prompt (system + user messages).
    prompt_tokens: u32,
    /// Tokens generated in the completion.
    completion_tokens: u32,
    /// Sum of prompt and completion tokens.
    total_tokens: u32,
}
/// Request body for `POST {api_base}/chat/completions`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OpenAIRequest {
    /// Model name, taken from `CodexConfig::model`.
    model: String,
    /// Conversation so far; this file sends an optional system message
    /// followed by a single user message.
    messages: Vec<Message>,
    /// Upper bound on completion length, from `CodexConfig::max_tokens`.
    max_tokens: Option<u32>,
    /// Sampling temperature, from `CodexConfig::temperature`.
    temperature: Option<f32>,
    /// Always `false` in this provider — responses are read in one shot.
    stream: bool,
}
/// LLM provider backed by an OpenAI-compatible chat-completions API.
///
/// Cloning is cheap in spirit: `reqwest::Client` is internally reference-counted,
/// so clones share the same connection pool.
#[derive(Clone)]
pub struct CodexProvider {
    /// Model, credentials, and generation parameters.
    config: CodexConfig,
    /// HTTP client pre-loaded with JSON content-type and (when available)
    /// a bearer-token `Authorization` default header.
    client: reqwest::Client,
    /// Base URL cached out of `config` for request construction.
    api_base: String,
}
// Inherent API: construction, the shared chat-completion call, and
// convenience wrappers (code generation, analysis, documentation, patching).
impl CodexProvider {
    /// Build a provider from `config`.
    ///
    /// Constructs a `reqwest` client with a JSON content-type default header,
    /// a 120-second request timeout, and — when available — a bearer-token
    /// `Authorization` header. The key comes from `config.api_key`, falling
    /// back to the `OPENAI_API_KEY` environment variable; if neither exists,
    /// the client is built without auth (requests will then rely on the
    /// server not requiring it).
    ///
    /// # Errors
    /// * `OpenCratesError::Config` — the API key contains bytes that are not
    ///   valid in an HTTP header.
    /// * `OpenCratesError::internal` — the HTTP client could not be built.
    ///
    /// A failed connectivity probe is only logged, never returned as an error.
    pub async fn new(config: CodexConfig) -> Result<Self, OpenCratesError> {
        info!("Initializing CodexProvider with model: {}", config.model);
        let mut headers = HeaderMap::new();
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
        if let Some(ref api_key) = config.api_key {
            // Reject keys with non-header-safe characters up front rather
            // than failing on every request.
            let auth_value = HeaderValue::from_str(&format!("Bearer {api_key}"))
                .map_err(|e| OpenCratesError::Config(format!("Invalid API key format: {e}")))?;
            headers.insert(AUTHORIZATION, auth_value);
        } else {
            warn!("No API key provided - using environment variable fallback");
            // Environment fallback; if the variable is also absent we proceed
            // without an Authorization header.
            if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
                let auth_value =
                    HeaderValue::from_str(&format!("Bearer {api_key}")).map_err(|e| {
                        OpenCratesError::Config(format!(
                            "Invalid API key format from environment: {e}"
                        ))
                    })?;
                headers.insert(AUTHORIZATION, auth_value);
            }
        }
        let client = reqwest::Client::builder()
            .default_headers(headers)
            // Client-level timeout; get_chat_completion additionally wraps
            // sends in a tokio timeout of the same duration.
            .timeout(Duration::from_secs(120))
            .build()
            .map_err(|e| OpenCratesError::internal(format!("Failed to create HTTP client: {e}")))?;
        let api_base = config.api_base.clone();
        let provider = Self {
            config,
            client,
            api_base,
        };
        // Best-effort probe. NOTE(review): verify_connection always returns
        // Ok (failures become Ok(false)), so this Err branch is currently dead.
        if let Err(e) = provider.verify_connection().await {
            warn!("Initial connection validation failed: {}", e);
        }
        Ok(provider)
    }

    /// Build a provider from `CodexConfig::default()`, letting the
    /// `OPENAI_API_KEY`, `OPENAI_API_BASE`, and `OPENAI_MODEL` environment
    /// variables override the corresponding defaults.
    pub async fn new_default() -> Result<Self, OpenCratesError> {
        let mut config = CodexConfig::default();
        if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
            config.api_key = Some(api_key);
        }
        if let Ok(api_base) = std::env::var("OPENAI_API_BASE") {
            config.api_base = api_base;
        }
        if let Ok(model) = std::env::var("OPENAI_MODEL") {
            config.model = model;
        }
        Self::new(config).await
    }

    /// Perform one non-streaming chat completion.
    ///
    /// Sends an optional system message plus `prompt_text` as the user
    /// message to `{api_base}/chat/completions`, using the configured model,
    /// `max_tokens`, and `temperature`.
    ///
    /// Returns the first choice's message content (empty string if the
    /// response has no choices or no message) together with token usage
    /// (zeros if the response carried no `usage` object).
    ///
    /// # Errors
    /// * `OpenCratesError::Network` — timeout, send failure, or unreadable body.
    /// * `OpenCratesError::external` — non-2xx HTTP status from the API.
    /// * `OpenCratesError::Serialization` — body was not valid `OpenAIResponse` JSON.
    async fn get_chat_completion(
        &self,
        prompt_text: &str,
        system_message: Option<&str>,
    ) -> Result<(String, TokenUsage), OpenCratesError> {
        let start_time = Instant::now();
        // `{:.100}` truncates the logged prompt to its first 100 characters.
        debug!(
            "Starting chat completion for prompt: {:.100}...",
            prompt_text
        );
        let mut messages = Vec::new();
        if let Some(system_msg) = system_message {
            messages.push(Message {
                role: "system".to_string(),
                content: system_msg.to_string(),
            });
        }
        messages.push(Message {
            role: "user".to_string(),
            content: prompt_text.to_string(),
        });
        let request = OpenAIRequest {
            model: self.config.model.clone(),
            messages,
            max_tokens: Some(self.config.max_tokens),
            temperature: Some(self.config.temperature),
            stream: false,
        };
        let url = format!("{}/chat/completions", self.api_base);
        // Outer tokio timeout guards the whole send; the client also has its
        // own 120s timeout (see `new`), so this is a belt-and-suspenders cap.
        let response = timeout(
            Duration::from_secs(120),
            self.client.post(&url).json(&request).send(),
        )
        .await
        .map_err(|e| OpenCratesError::Network(format!("Request timed out: {e}")))?
        .map_err(|e| OpenCratesError::Network(format!("Failed to send request: {e}")))?;
        let status = response.status();
        // Read the body before checking status so error responses can be
        // included in the error message verbatim.
        let response_text = response
            .text()
            .await
            .map_err(|e| OpenCratesError::Network(format!("Failed to read response body: {e}")))?;
        if !status.is_success() {
            error!("OpenAI API error: {} - {}", status, response_text);
            return Err(OpenCratesError::external(format!(
                "OpenAI API error: {status} - {response_text}"
            )));
        }
        let openai_response: OpenAIResponse =
            serde_json::from_str(&response_text).map_err(OpenCratesError::Serialization)?;
        // First choice only; a choice-less or message-less response yields "".
        let content = openai_response
            .choices
            .first()
            .and_then(|choice| choice.message.as_ref())
            .map(|msg| msg.content.clone())
            .unwrap_or_default();
        // Absent usage degrades to zeros rather than an error.
        let usage = openai_response.usage.map_or(
            TokenUsage {
                prompt_tokens: 0,
                completion_tokens: 0,
                total_tokens: 0,
            },
            |u| TokenUsage {
                prompt_tokens: u.prompt_tokens as usize,
                completion_tokens: u.completion_tokens as usize,
                total_tokens: u.total_tokens as usize,
            },
        );
        let duration = start_time.elapsed();
        info!(
            "Chat completion completed in {:.2}s, tokens: {}",
            duration.as_secs_f64(),
            usage.total_tokens
        );
        Ok((content, usage))
    }

    /// Apply a patch via the vendored `apply_patch` routine, streaming its
    /// output to this process's stdout/stderr.
    ///
    /// NOTE(review): `file_path` is used only in the log and error messages
    /// here — the vendored `apply_patch` receives just `patch_content`, and
    /// presumably derives target files from the patch body itself; confirm
    /// against the vendored implementation.
    ///
    /// # Errors
    /// `OpenCratesError::internal` wrapping the vendored routine's failure.
    pub fn apply_patch(
        &self,
        file_path: &std::path::Path,
        patch_content: &str,
    ) -> Result<(), OpenCratesError> {
        info!("Applying patch to file: {}", file_path.display());
        debug!("Patch content: {}", patch_content);
        let mut stdout = io::stdout();
        let mut stderr = io::stderr();
        apply_patch(patch_content, &mut stdout, &mut stderr).map_err(|e| {
            OpenCratesError::internal(format!(
                "Failed to apply patch to {}: {}",
                file_path.display(),
                e
            ))
        })
    }

    /// Generate Rust code for `instruction`, optionally prefixed with
    /// free-form `context`, using a Rust-specific system prompt.
    ///
    /// # Errors
    /// Propagates any error from [`Self::get_chat_completion`].
    pub async fn generate_code(
        &self,
        instruction: &str,
        context: Option<&str>,
    ) -> Result<String, OpenCratesError> {
        let system_message = "You are an expert Rust programmer. Generate clean, efficient, and well-documented code. Include appropriate error handling and follow Rust best practices.";
        let prompt = if let Some(ctx) = context {
            format!(
                "Context:\n{ctx}\n\nInstruction:\n{instruction}\n\nGenerate the requested Rust code:"
            )
        } else {
            format!("Instruction:\n{instruction}\n\nGenerate the requested Rust code:")
        };
        let (content, _usage) = self
            .get_chat_completion(&prompt, Some(system_message))
            .await?;
        Ok(content)
    }

    /// Probe the API with a minimal round-trip request.
    ///
    /// Returns `Ok(true)` when the request succeeds and `Ok(false)` when it
    /// fails — this method never returns `Err`; failures are downgraded to
    /// a warning plus `false`.
    pub async fn verify_connection(&self) -> Result<bool, OpenCratesError> {
        debug!("Verifying OpenAI API connection");
        match self
            .get_chat_completion("Test", Some("Respond with 'OK'"))
            .await
        {
            Ok((response, _)) => {
                debug!("Connection verification successful: {}", response);
                Ok(true)
            }
            Err(e) => {
                warn!("Connection verification failed: {}", e);
                Ok(false)
            }
        }
    }

    /// Describe this provider's model and capabilities.
    ///
    /// The context window is looked up from a hard-coded table of known
    /// OpenAI model names; unknown models conservatively report 4096.
    #[must_use]
    pub fn get_model_info(&self) -> ModelProviderInfo {
        ModelProviderInfo {
            base_url: self.api_base.clone(),
            api_key: self.config.api_key.clone(),
            name: self.config.model.clone(),
            provider: "OpenAI".to_string(),
            max_tokens: self.config.max_tokens,
            supports_streaming: true,
            supports_functions: true,
            context_window: match self.config.model.as_str() {
                "gpt-4" => 8192,
                "gpt-4-32k" => 32768,
                "gpt-4-turbo" => 128_000,
                "gpt-3.5-turbo" => 4096,
                "gpt-3.5-turbo-16k" => 16384,
                // Unknown model: assume the smallest window in the table.
                _ => 4096,
            },
        }
    }

    /// Review `code` with a language-specific code-review system prompt
    /// (quality, bugs, performance, security, docs). `language` defaults
    /// to "rust" and is interpolated into both prompts.
    ///
    /// # Errors
    /// Propagates any error from [`Self::get_chat_completion`].
    pub async fn analyze_code(
        &self,
        code: &str,
        language: Option<&str>,
    ) -> Result<String, OpenCratesError> {
        let lang = language.unwrap_or("rust");
        let system_message = format!(
            "You are an expert {lang} code reviewer. Analyze the provided code for:\n\
             - Code quality and best practices\n\
             - Potential bugs or issues\n\
             - Performance improvements\n\
             - Security considerations\n\
             - Documentation suggestions\n\
             Provide specific, actionable feedback."
        );
        let prompt = format!("Analyze this {lang} code:\n\n```{lang}\n{code}\n```");
        let (analysis, _usage) = self
            .get_chat_completion(&prompt, Some(&system_message))
            .await?;
        Ok(analysis)
    }

    /// Generate documentation for `code` in the requested `style`
    /// (defaults to "rustdoc"). The code block in the prompt is always
    /// fenced as ```rust regardless of `style`.
    ///
    /// # Errors
    /// Propagates any error from [`Self::get_chat_completion`].
    pub async fn generate_documentation(
        &self,
        code: &str,
        style: Option<&str>,
    ) -> Result<String, OpenCratesError> {
        let doc_style = style.unwrap_or("rustdoc");
        let system_message = format!(
            "You are an expert technical writer specializing in {doc_style} documentation. \
             Generate comprehensive, clear, and accurate documentation that includes:\n\
             - Purpose and functionality overview\n\
             - Parameter descriptions\n\
             - Return value documentation\n\
             - Usage examples\n\
             - Error conditions\n\
             - Performance considerations"
        );
        let prompt =
            format!("Generate {doc_style} documentation for this code:\n\n```rust\n{code}\n```");
        let (docs, _usage) = self
            .get_chat_completion(&prompt, Some(&system_message))
            .await?;
        Ok(docs)
    }

    /// Stable provider identifier (mirrors the `LLMProvider` impl).
    #[must_use]
    pub fn name(&self) -> &'static str {
        "codex"
    }

    /// Health check — delegates to [`Self::verify_connection`], so it also
    /// never returns `Err`.
    pub async fn health_check(&self) -> Result<bool, OpenCratesError> {
        self.verify_connection().await
    }

    /// Downcasting hook (mirrors the `LLMProvider` impl).
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
#[async_trait]
impl LLMProvider for CodexProvider {
    /// Fulfil a generation request via one non-streaming chat completion.
    ///
    /// The user text is `request.prompt` when present, otherwise the spec's
    /// description; an optional context block is prepended to it.
    ///
    /// # Errors
    /// Propagates any error from the underlying chat-completion call.
    async fn generate(
        &self,
        request: &GenerationRequest,
    ) -> Result<GenerationResponse, OpenCratesError> {
        const SYSTEM_MESSAGE: &str = "You are an expert programmer. Generate high-quality code based on the user's requirements.";
        // Explicit prompt wins; the spec description is the fallback.
        let user_text = request.prompt.as_ref().unwrap_or(&request.spec.description);
        let full_prompt = match &request.context {
            Some(ctx) => format!("Context:\n{}\n\nRequest:\n{}", ctx, user_text),
            None => user_text.clone(),
        };
        let (preview, token_usage) = self
            .get_chat_completion(&full_prompt, Some(SYSTEM_MESSAGE))
            .await?;
        Ok(GenerationResponse {
            preview,
            // Bridge the crate's TokenUsage into the agents-facing Usage type.
            metrics: Usage {
                prompt_tokens: token_usage.prompt_tokens,
                completion_tokens: token_usage.completion_tokens,
                total_tokens: token_usage.total_tokens,
            },
            finish_reason: Some("stop".to_string()),
        })
    }

    /// Health check — a lightweight round-trip against the API; failures
    /// surface as `Ok(false)` rather than `Err`.
    async fn health_check(&self) -> Result<bool, OpenCratesError> {
        self.verify_connection().await
    }

    /// Stable provider identifier.
    fn name(&self) -> &'static str {
        "codex"
    }

    /// Downcasting hook for callers that need the concrete provider type.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}