use crate::utils::error::OpenCratesError;
use async_openai::{
config::OpenAIConfig as AsyncOpenAIConfig,
types::{
ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestSystemMessageContent,
ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageArgs,
ChatCompletionRequestUserMessageContent, ChatCompletionResponseStream,
CreateChatCompletionRequestArgs,
},
Client,
};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::fs;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, error, info, instrument, warn};
use super::{GenerationRequest, GenerationResponse, LLMProvider, LegacyLLMProvider};
use crate::utils::config::OpenAIConfig;
use crate::utils::metrics::TokenUsage;
use crate::utils::openai_agents::Usage;
use crate::utils::project::ProjectAnalysis;
use crate::utils::templates::CrateSpec;
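/// Runtime configuration for [`OpenAIProvider`].
///
/// Note: `api_key` is held in plain text and the derived `Debug` impl will
/// print it, so avoid logging this struct verbatim in production.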
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIProviderConfig {
pub api_key: String,
pub model: String,
pub max_tokens: u16,
pub temperature: f32,
pub timeout_seconds: u64,
}
impl Default for OpenAIProviderConfig {
fn default() -> Self {
Self {
api_key: std::env::var("OPENAI_API_KEY").unwrap_or_default(),
model: "gpt-4".to_string(),
max_tokens: 2048,
temperature: 0.7,
timeout_seconds: 60,
}
}
}
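/// LLM provider backed by the OpenAI chat completions API via `async-openai`.
///
/// `model` caches the default model chosen at construction time; the full
/// configuration lives behind an `Arc<RwLock<_>>` so it can be updated at
/// runtime (e.g. by `set_model`).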
#[derive(Debug, Clone)]
pub struct OpenAIProvider {
client: Client<AsyncOpenAIConfig>,
model: String,
config: Arc<RwLock<OpenAIProviderConfig>>,
}
impl OpenAIProvider {
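    /// Creates a provider from the `OPENAI_API_KEY` environment variable,
    /// using the same defaults as [`OpenAIProviderConfig::default`].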
pub async fn new() -> Result<Self, OpenCratesError> {
let api_key = std::env::var("OPENAI_API_KEY").map_err(|_| {
OpenCratesError::Config("OPENAI_API_KEY not found in environment".to_string())
})?;
let client_config = AsyncOpenAIConfig::new().with_api_key(api_key.clone());
let client = Client::with_config(client_config);
        let provider_config = OpenAIProviderConfig {
            api_key,
            ..OpenAIProviderConfig::default()
        };
Ok(Self {
client,
model: provider_config.model.clone(),
config: Arc::new(RwLock::new(provider_config)),
})
}
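    /// Creates a provider from an application-level [`OpenAIConfig`].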
pub async fn new_with_config(config: &OpenAIConfig) -> Result<Self, OpenCratesError> {
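        // Fall back to a sentinel key so tests can construct a provider
        // without credentials; `generate` short-circuits on "test-key".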
let api_key = config
.api_key
.clone()
.unwrap_or_else(|| "test-key".to_string());
let client_config = AsyncOpenAIConfig::new().with_api_key(api_key.clone());
let client = Client::with_config(client_config);
let provider_config = OpenAIProviderConfig {
api_key,
model: config.model.clone(),
            // Clamp instead of silently truncating values larger than u16::MAX.
            max_tokens: config.max_tokens.try_into().unwrap_or(u16::MAX),
temperature: config.temperature,
timeout_seconds: 60,
};
Ok(Self {
client,
model: provider_config.model.clone(),
config: Arc::new(RwLock::new(provider_config)),
})
}
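    /// Starts a streaming chat completion; the caller consumes the returned
    /// stream chunk by chunk instead of waiting for the full response.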
pub async fn stream_generate(
&self,
request: &GenerationRequest,
) -> Result<ChatCompletionResponseStream, OpenCratesError> {
let chat_request = CreateChatCompletionRequestArgs::default()
.model(request.model.as_deref().unwrap_or(&self.model))
.messages(vec![
ChatCompletionRequestSystemMessageArgs::default()
.content("You are an expert Rust developer and crate creator.")
.build()
.map_err(|e| OpenCratesError::internal(e.to_string()))?
.into(),
ChatCompletionRequestUserMessageArgs::default()
.content(ChatCompletionRequestUserMessageContent::Text(
request
.prompt
.as_ref()
.unwrap_or(&request.spec.description)
.clone(),
))
.build()
.map_err(|e| OpenCratesError::internal(e.to_string()))?
.into(),
])
.stream(true)
        .max_tokens(request.max_tokens.unwrap_or(4096))
.temperature(request.temperature.unwrap_or(0.7))
.build()
.map_err(|e| OpenCratesError::internal(e.to_string()))?;
self.client
.chat()
.create_stream(chat_request)
.await
.map_err(|e| OpenCratesError::external(e.to_string()))
}
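    /// Asks the model for optimization suggestions based on a project
    /// analysis; each non-empty line of the reply becomes one suggestion.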
pub async fn suggest_optimizations(
&self,
analysis: &ProjectAnalysis,
) -> Result<Vec<String>, OpenCratesError> {
let prompt = format!(
"Analyze this Rust project and suggest optimizations:\n\
Name: {}\n\
Dependencies: {:?}\n\
             Metrics: {:?}\n\n\
Please provide 5-10 specific optimization suggestions for performance, \
security, and maintainability.",
analysis.name, analysis.dependencies, analysis.metrics
);
let request = GenerationRequest {
spec: CrateSpec::default(),
prompt: Some(prompt),
max_tokens: Some(1024),
model: Some(self.model.clone()),
temperature: Some(0.3),
context: None,
};
let response = <Self as LLMProvider>::generate(self, &request).await?;
let suggestions: Vec<String> = response
.preview
.lines()
.filter(|line| !line.trim().is_empty())
.map(|line| line.trim().to_string())
.collect();
Ok(suggestions)
}
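    /// Builds a [`crate::stages::CrateContext`] directly from the given name,
    /// description, and features; no model call is made.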
pub async fn generate_crate_context(
&self,
name: &str,
description: &str,
features: &[String],
) -> Result<crate::stages::CrateContext, OpenCratesError> {
let mut context = crate::stages::CrateContext::new(description, None);
context.crate_name = name.to_string();
for feature in features {
context.add_feature(feature.clone());
}
context.set_metadata("author".to_string(), "Generated by OpenCrates".to_string());
context.set_metadata("license".to_string(), "MIT OR Apache-2.0".to_string());
Ok(context)
}
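    /// One-shot conversational helper: sends `prompt` to `model` with a fixed
    /// Rust-assistant system message and returns the first choice's text.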
pub async fn chat(&self, model: &str, prompt: &str) -> Result<String, OpenCratesError> {
let messages = vec![
ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(
"You are an expert Rust developer assistant. Provide helpful, \
accurate, and concise responses about Rust programming, crate development, \
and best practices."
.to_string(),
),
name: None,
}),
ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Text(prompt.to_string()),
name: None,
}),
];
let request = CreateChatCompletionRequestArgs::default()
.model(model)
.messages(messages)
.max_tokens(2048u32)
.temperature(0.7)
.build()
.map_err(|e| OpenCratesError::internal(e.to_string()))?;
let response = self
.client
.chat()
.create(request)
.await
.map_err(|e| OpenCratesError::external(e.to_string()))?;
let content = response
.choices
.first()
.and_then(|choice| choice.message.content.clone())
.unwrap_or_else(|| "No response generated".to_string());
Ok(content)
}
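    /// Assembles a [`crate::stages::CrateContext`] from the spec. The metadata
    /// parameter is accepted for API symmetry but is currently unused.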
fn create_context(
&self,
spec: &CrateSpec,
_metadata: crate::core::CrateMetadata,
) -> crate::stages::CrateContext {
let mut context = crate::stages::CrateContext::new(&spec.description, None);
context.crate_name = spec.name.clone();
context.version = spec.version.clone();
for (dep_name, dep_version) in &spec.dependencies {
context.add_dependency(format!("{dep_name} = \"{dep_version}\""));
}
for feature in &spec.features {
context.add_feature(feature.clone());
}
context.set_metadata(
"author".to_string(),
spec.authors.first().cloned().unwrap_or_default(),
);
context.set_metadata(
"license".to_string(),
spec.license.clone().unwrap_or_default(),
);
context.set_metadata(
"homepage".to_string(),
spec.homepage.clone().unwrap_or_default(),
);
context.set_metadata(
"repository".to_string(),
spec.repository.clone().unwrap_or_default(),
);
context
}
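    /// Generates a crate context for `spec`. The model is invoked once to
    /// validate that generation works; the context itself is assembled
    /// deterministically from the spec.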
pub async fn generate_crate(
&self,
spec: &CrateSpec,
) -> Result<crate::stages::CrateContext, OpenCratesError> {
let config = self.config.read().await;
let request = GenerationRequest {
spec: spec.clone(),
prompt: Some(format!("Generate a Rust crate: {}", spec.description)),
max_tokens: Some(u32::from(config.max_tokens)),
model: Some(config.model.clone()),
temperature: Some(config.temperature),
context: None,
};
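        // The generated text is currently discarded; the call mainly verifies
        // that the provider can produce a response for this spec.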
let _ = <Self as LLMProvider>::generate(self, &request).await?;
let metadata = crate::core::CrateMetadata {
name: spec.name.clone(),
description: spec.description.clone(),
version: spec.version.clone(),
authors: spec.authors.clone(),
license: spec.license.clone(),
crate_type: spec.crate_type,
dependencies: spec.dependencies.clone(),
dev_dependencies: spec.dev_dependencies.clone(),
features: spec.features.clone(),
keywords: spec.keywords.clone(),
categories: spec.categories.clone(),
repository: spec.repository.clone(),
homepage: spec.homepage.clone(),
documentation: spec.documentation.clone(),
readme: spec.readme.clone(),
rust_version: spec.rust_version.clone(),
edition: spec.edition.clone(),
publish: spec.publish,
author: spec.author.clone(),
template: None,
};
let context = self.create_context(spec, metadata);
Ok(context)
}
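    /// Like [`LLMProvider::health_check`], but logs failures at `warn` level
    /// and uses a slightly larger test completion.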
pub async fn verify_connection(&self) -> Result<bool, OpenCratesError> {
let config = self.config.read().await;
let test_request = GenerationRequest {
spec: crate::utils::templates::CrateSpec::default(),
prompt: Some("Test connection".to_string()),
max_tokens: Some(10),
model: Some(config.model.clone()),
temperature: Some(0.5),
context: None,
};
match <Self as LLMProvider>::generate(self, &test_request).await {
Ok(_) => Ok(true),
Err(e) => {
warn!("OpenAI connection failed: {e}");
Ok(false)
}
}
}
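    /// Applies a unified-diff patch to `file` in place using the `diffy` crate.
    ///
    /// A minimal sketch of the expected call shape (file name and hunk are
    /// illustrative):
    ///
    /// ```ignore
    /// let patch = "--- original\n+++ modified\n@@ -1 +1 @@\n-old line\n+new line\n";
    /// provider.apply_patch(std::path::Path::new("src/lib.rs"), patch)?;
    /// ```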
pub fn apply_patch(
&self,
file: &std::path::Path,
patch_str: &str,
) -> Result<(), OpenCratesError> {
info!("Applying patch to {:?}", file);
let original_content = fs::read_to_string(file).map_err(OpenCratesError::Io)?;
let patch = diffy::Patch::from_str(patch_str)
.map_err(|e| OpenCratesError::internal(e.to_string()))?;
let patched_content = diffy::apply(&original_content, &patch)
.map_err(|e| OpenCratesError::internal(e.to_string()))?;
fs::write(file, patched_content).map_err(OpenCratesError::Io)?;
info!("Successfully applied patch to {:?}", file);
Ok(())
}
}
#[async_trait]
impl LLMProvider for OpenAIProvider {
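    /// Sends one chat completion request and maps the first choice into a
    /// [`GenerationResponse`]. The sentinel key "test-key" (set by
    /// `new_with_config` when no key is given) short-circuits to a mock reply.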
#[instrument(skip(self, request))]
async fn generate(
&self,
request: &GenerationRequest,
) -> Result<GenerationResponse, OpenCratesError> {
let config = self.config.read().await;
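        // Short-circuit on the sentinel key so tests never hit the network.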
if config.api_key == "test-key" {
let metrics = Usage {
prompt_tokens: 10,
completion_tokens: 20,
total_tokens: 30,
};
return Ok(GenerationResponse {
preview: format!("Mock response for: {}", request.spec.description),
metrics,
finish_reason: Some("stop".to_string()),
});
}
let prompt_text = request.prompt.as_ref().unwrap_or(&request.spec.description);
let messages = vec![
ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
                content: ChatCompletionRequestSystemMessageContent::Text(
                    "You are an expert Rust developer and crate creator. You generate \
                     high-quality, idiomatic Rust code with comprehensive documentation \
                     and tests, following best practices. Always ensure memory safety, \
                     proper error handling, and performance optimization."
                        .to_string(),
                ),
name: None,
}),
ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage {
content: ChatCompletionRequestUserMessageContent::Text(prompt_text.clone()),
name: None,
}),
];
        // Widen the configured `u16` default instead of narrowing the request
        // value, so a large `request.max_tokens` cannot panic in `try_into().unwrap()`.
        let max_tokens = request
            .max_tokens
            .unwrap_or_else(|| u32::from(config.max_tokens));
let request_builder = CreateChatCompletionRequestArgs::default()
            // Honor a per-request model override, falling back to the configured
            // model (not the cached `self.model`) so `set_model` takes effect.
            .model(request.model.as_deref().unwrap_or(&config.model))
.messages(messages)
.max_tokens(max_tokens)
.temperature(config.temperature)
.build()
.map_err(|e| OpenCratesError::internal(e.to_string()))?;
let response = self
.client
.chat()
.create(request_builder)
.await
.map_err(|e| OpenCratesError::external(e.to_string()))?;
let choice = response
.choices
.first()
.ok_or_else(|| OpenCratesError::external("No response from OpenAI"))?;
let preview = choice.message.content.clone().unwrap_or_default();
debug!(
"Generated {} tokens from OpenAI",
response.usage.as_ref().map_or(0, |u| u.total_tokens)
);
let usage = response.usage.as_ref().map(|u| TokenUsage {
prompt_tokens: u.prompt_tokens as usize,
completion_tokens: u.completion_tokens as usize,
total_tokens: u.total_tokens as usize,
});
let metrics = Usage {
prompt_tokens: usage.as_ref().map_or(0, |u| u.prompt_tokens),
completion_tokens: usage.as_ref().map_or(0, |u| u.completion_tokens),
total_tokens: usage.as_ref().map_or(0, |u| u.total_tokens),
};
Ok(GenerationResponse {
preview,
metrics,
finish_reason: choice.finish_reason.as_ref().map(|r| format!("{r:?}")),
})
}
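    /// Cheap liveness probe: issues a tiny deterministic completion and maps
    /// any error to `Ok(false)` after logging it.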
async fn health_check(&self) -> Result<bool, OpenCratesError> {
let test_request = GenerationRequest {
spec: CrateSpec::default(),
prompt: Some("health check".to_string()),
max_tokens: Some(5),
model: Some(self.model.clone()),
temperature: Some(0.0),
context: None,
};
match <Self as LLMProvider>::generate(self, &test_request).await {
Ok(_) => Ok(true),
Err(e) => {
error!("Health check failed: {}", e);
Ok(false)
}
}
}
fn name(&self) -> &'static str {
"openai"
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}
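// Compatibility shim for the older trait: takes requests by value and adds
// runtime model switching plus API-key validation.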
#[async_trait]
impl LegacyLLMProvider for OpenAIProvider {
async fn generate(
&self,
request: GenerationRequest,
) -> Result<GenerationResponse, OpenCratesError> {
<Self as LLMProvider>::generate(self, &request).await
}
async fn set_model(&self, model: &str) -> Result<(), OpenCratesError> {
let mut config = self.config.write().await;
config.model = model.to_string();
info!("Model changed to: {}", model);
Ok(())
}
async fn validate_api_key(&self) -> Result<bool, OpenCratesError> {
let test_request = GenerationRequest {
spec: crate::utils::templates::CrateSpec::default(),
prompt: Some("Say 'test' and nothing else.".to_string()),
max_tokens: Some(5),
temperature: Some(0.0),
model: Some(self.model.clone()),
context: None,
};
match <Self as LegacyLLMProvider>::generate(self, test_request).await {
Ok(_) => Ok(true),
Err(e) => {
error!("API key validation failed: {}", e);
Ok(false)
}
}
}
}