use anyhow::{Context, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::{debug, error, info, instrument, warn};
use crate::providers::openai::OpenAIProvider;
use crate::providers::{GenerationRequest, LegacyLLMProvider};
use crate::utils::cache::CacheManager;
pub use crate::utils::config::OpenAIConfig;
use crate::utils::error::OpenCratesError;
use crate::utils::metrics::OpenCratesMetrics;
use std::future::Future;
use std::pin::Pin;
/// Configuration for the low-level OpenAI HTTP client.
///
/// Construct via [`Default`] to pick up the API key from the environment,
/// then override fields as needed.
#[derive(Debug, Clone)]
pub struct OpenAIClientConfig {
/// Bearer token sent in the `Authorization` header.
pub api_key: String,
/// Base URL of the OpenAI-compatible endpoint (no trailing slash).
pub base_url: String,
/// Model identifier used for chat completions (e.g. `gpt-4o`).
pub model: String,
/// Optional cap on generated tokens per completion.
pub max_tokens: Option<u32>,
/// Optional sampling temperature.
pub temperature: Option<f32>,
/// Per-request timeout applied to the underlying HTTP client.
pub timeout: Duration,
/// Maximum retry count. NOTE(review): not consulted anywhere in this
/// file — confirm whether retry logic was intended but never wired up.
pub max_retries: u32,
}
impl Default for OpenAIClientConfig {
fn default() -> Self {
Self {
api_key: std::env::var("OPENAI_API_KEY").unwrap_or_default(),
base_url: "https://api.openai.com/v1".to_string(),
model: "gpt-4o".to_string(),
max_tokens: Some(2048),
temperature: Some(0.7),
timeout: Duration::from_secs(30),
max_retries: 3,
}
}
}
/// Role of a chat message, serialized in lowercase on the wire
/// (`system`, `user`, `assistant`, `tool`) per the OpenAI chat API.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
System,
User,
Assistant,
Tool,
}
/// A single message in a chat conversation, in OpenAI wire format.
///
/// NOTE(review): `content` is a plain `String`; the OpenAI API can return
/// `"content": null` on assistant messages that carry only tool calls, which
/// would fail to deserialize here — confirm against real API responses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
pub role: MessageRole,
pub content: String,
// Optional participant name attached to the message.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
// Tool invocations requested by an assistant message.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
// For `tool` role messages: the id of the call being answered.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
impl ChatMessage {
#[must_use]
pub fn system(content: String) -> Self {
Self {
role: MessageRole::System,
content,
name: None,
tool_calls: None,
tool_call_id: None,
}
}
#[must_use]
pub fn user(content: String) -> Self {
Self {
role: MessageRole::User,
content,
name: None,
tool_calls: None,
tool_call_id: None,
}
}
#[must_use]
pub fn assistant(content: String) -> Self {
Self {
role: MessageRole::Assistant,
content,
name: None,
tool_calls: None,
tool_call_id: None,
}
}
#[must_use]
pub fn with_tool_calls(mut self, tool_calls: Vec<ToolCall>) -> Self {
self.tool_calls = Some(tool_calls);
self
}
}
/// A callable tool: JSON-schema metadata plus the async callback that
/// implements it. The callback is excluded from serialization.
///
/// NOTE(review): serializing this type yields flat `name`/`description`/
/// `parameters` fields; the OpenAI tools API expects each entry wrapped as
/// `{"type": "function", "function": {...}}` — confirm the wire format the
/// target endpoint accepts.
#[derive(Serialize)]
pub struct FunctionTool {
pub name: String,
pub description: String,
// JSON schema describing the callback's expected arguments.
pub parameters: Value,
// Async callback; shared via Arc so the tool is cheaply cloneable.
#[serde(skip)]
pub function:
Arc<dyn Fn(Value) -> Pin<Box<dyn Future<Output = Result<Value>> + Send>> + Send + Sync>,
}
impl Clone for FunctionTool {
fn clone(&self) -> Self {
Self {
name: self.name.clone(),
description: self.description.clone(),
parameters: self.parameters.clone(),
function: self.function.clone(),
}
}
}
impl std::fmt::Debug for FunctionTool {
    /// Manual `Debug` impl: the stored closure has no `Debug`, so it is
    /// rendered as the placeholder `"<function>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("FunctionTool");
        dbg.field("name", &self.name);
        dbg.field("description", &self.description);
        dbg.field("parameters", &self.parameters);
        dbg.field("function", &"<function>");
        dbg.finish()
    }
}
impl FunctionTool {
/// Builds a tool from schema metadata and an async callback.
/// `FunctionTool::new` boxes and pins the future returned by `func`, so
/// callers can pass a plain async closure.
pub fn new<F, Fut>(name: &str, description: &str, parameters: Value, func: F) -> Self
where
F: Fn(Value) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<Value>> + Send + 'static,
{
Self {
name: name.to_string(),
description: description.to_string(),
parameters,
function: Arc::new(move |args| Box::pin(func(args))),
}
}
/// Synchronously drives the async callback to completion.
///
/// NOTE(review): `block_in_place` panics on a current-thread tokio runtime
/// and `Handle::current()` panics outside any runtime — this is only safe
/// to call from a multi-threaded tokio runtime. Async callers should
/// invoke `(self.function)(args).await` directly instead of blocking here.
pub fn call(&self, args: Value) -> Result<Value> {
let future = (self.function)(args);
tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(future))
}
}
/// Serializable function metadata (name, description, JSON-schema
/// parameters) without an attached callback.
/// NOTE(review): not referenced elsewhere in this file — confirm callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDefinition {
pub name: String,
pub description: String,
pub parameters: serde_json::Value,
}
/// A tool invocation requested by the model in an assistant message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
/// Unique id echoed back in the corresponding `tool` role reply.
pub id: String,
/// Kind of tool; only `"function"` is handled by this client.
#[serde(rename = "type")]
pub tool_type: String,
pub function: FunctionCall,
}
/// The function name plus its arguments as a raw JSON string, exactly as
/// the OpenAI API delivers them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCall {
pub name: String,
/// JSON-encoded arguments; parse with `serde_json::from_str` before use.
pub arguments: String,
}
/// Request payload for `POST /chat/completions`.
#[derive(Debug, Serialize)]
pub struct ChatCompletionRequest {
pub model: String,
pub messages: Vec<ChatMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
// NOTE(review): serializes each FunctionTool's fields flat; the OpenAI
// API expects entries shaped as {"type":"function","function":{...}} —
// verify against the target endpoint.
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<FunctionTool>>,
/// Tool selection strategy; this client sends `"auto"` when tools are on.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<String>,
/// `true` requests server-sent-event streaming.
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>,
}
/// Non-streaming response body from `POST /chat/completions`.
#[derive(Debug, Deserialize)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String,
/// Unix timestamp (seconds) of creation.
pub created: u64,
pub model: String,
/// One entry per requested completion; this client reads `choices[0]`.
pub choices: Vec<ChatChoice>,
/// Token accounting; absent on some responses.
pub usage: Option<Usage>,
}
/// One completion alternative within a [`ChatCompletionResponse`].
#[derive(Debug, Deserialize)]
pub struct ChatChoice {
pub index: u32,
pub message: ChatMessage,
/// Why generation stopped (e.g. `stop`, `length`, `tool_calls`), if given.
pub finish_reason: Option<String>,
}
/// Token usage accounting as reported by the API.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Usage {
pub prompt_tokens: usize,
pub completion_tokens: usize,
pub total_tokens: usize,
}
/// One server-sent-event chunk of a streaming chat completion.
#[derive(Debug, Deserialize)]
pub struct ChatCompletionChunk {
pub id: String,
pub object: String,
/// Unix timestamp (seconds) of creation.
pub created: u64,
pub model: String,
pub choices: Vec<ChunkChoice>,
}
/// One choice's incremental update within a streaming chunk.
#[derive(Debug, Deserialize)]
pub struct ChunkChoice {
pub index: u32,
/// The delta to apply to the accumulated message.
pub delta: ChunkDelta,
pub finish_reason: Option<String>,
}
/// Incremental message fragment in a streaming response; each field is
/// present only when that part of the message changed in this chunk.
#[derive(Debug, Deserialize)]
pub struct ChunkDelta {
pub role: Option<MessageRole>,
pub content: Option<String>,
pub tool_calls: Option<Vec<ToolCall>>,
}
/// Trait-object style interface for callable tools.
///
/// NOTE(review): this is a parallel abstraction to [`FunctionTool`]; the
/// registry and client in this file store `FunctionTool` values, and no code
/// here registers `Function` implementors — confirm which abstraction is
/// canonical.
#[async_trait]
pub trait Function: Send + Sync {
fn name(&self) -> String;
fn description(&self) -> String;
/// JSON schema describing the expected `args` for [`Function::call`].
fn parameters(&self) -> Value;
async fn call(&self, args: Value) -> Result<Value>;
}
/// Tool that searches the crate registry by keyword.
/// NOTE(review): the current implementation returns canned example data —
/// no registry access happens here yet.
pub struct CrateSearchFunction {
}
#[async_trait]
impl Function for CrateSearchFunction {
    fn name(&self) -> String {
        String::from("search_crates")
    }

    fn description(&self) -> String {
        String::from("Search for Rust crates by name, description, or keywords")
    }

    /// JSON schema: a required `query` string and an optional integer `limit`
    /// (default 10).
    fn parameters(&self) -> Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Search query for crates"
                },
                "limit": {
                    "type": "integer",
                    "description": "Maximum number of results to return",
                    "default": 10
                }
            },
            "required": ["query"]
        })
    }

    /// Validates the arguments and returns stubbed search results, truncated
    /// to `limit`. The `query` value is required but not yet consulted.
    async fn call(&self, args: Value) -> Result<Value> {
        // Reject calls without a string `query`, mirroring the schema above.
        if args.get("query").and_then(|v| v.as_str()).is_none() {
            return Err(
                OpenCratesError::validation("Missing query parameter".to_string()).into(),
            );
        }
        let limit = args
            .get("limit")
            .and_then(serde_json::Value::as_u64)
            .unwrap_or(10) as usize;
        // Canned catalog standing in for a real registry lookup.
        let catalog = vec![
            serde_json::json!({
                "name": "serde",
                "version": "1.0.193",
                "description": "A generic serialization/deserialization framework",
                "downloads": 500_000_000,
                "repository": "https://github.com/serde-rs/serde"
            }),
            serde_json::json!({
                "name": "tokio",
                "version": "1.35.1",
                "description": "An event-driven, non-blocking I/O platform",
                "downloads": 300_000_000,
                "repository": "https://github.com/tokio-rs/tokio"
            }),
        ];
        let total = catalog.len();
        let page: Vec<_> = catalog.into_iter().take(limit).collect();
        Ok(serde_json::json!({
            "results": page,
            "total": total
        }))
    }
}
/// Tool that recommends crates for a given project type.
/// NOTE(review): recommendations are hard-coded per project type — no
/// registry or scoring backend is consulted here yet.
pub struct CrateRecommendationFunction {
}
#[async_trait]
impl Function for CrateRecommendationFunction {
    fn name(&self) -> String {
        String::from("recommend_crates")
    }

    fn description(&self) -> String {
        String::from("Get personalized crate recommendations based on project requirements")
    }

    /// JSON schema: a required `project_type` enum plus optional category and
    /// existing-crate arrays.
    fn parameters(&self) -> Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "project_type": {
                    "type": "string",
                    "enum": ["web", "cli", "library", "game", "embedded"],
                    "description": "Type of Rust project"
                },
                "categories": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Relevant categories or use cases"
                },
                "existing_crates": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Crates already in use"
                }
            },
            "required": ["project_type"]
        })
    }

    /// Returns a canned recommendation list keyed on `project_type`; any
    /// unrecognized type falls back to a generic suggestion.
    async fn call(&self, args: Value) -> Result<Value> {
        let Some(kind) = args.get("project_type").and_then(|v| v.as_str()) else {
            return Err(
                OpenCratesError::validation("Missing project_type parameter".to_string()).into(),
            );
        };
        let picks = match kind {
            "web" => vec![
                serde_json::json!({
                    "name": "axum",
                    "reason": "Modern, ergonomic web framework with excellent performance",
                    "confidence": 0.95
                }),
                serde_json::json!({
                    "name": "sqlx",
                    "reason": "Async SQL toolkit with compile-time checked queries",
                    "confidence": 0.90
                }),
            ],
            "cli" => vec![
                serde_json::json!({
                    "name": "clap",
                    "reason": "Feature-rich command line argument parser",
                    "confidence": 0.98
                }),
                serde_json::json!({
                    "name": "colored",
                    "reason": "Simple library for colored terminal output",
                    "confidence": 0.85
                }),
            ],
            _ => vec![serde_json::json!({
                "name": "anyhow",
                "reason": "Flexible error handling library",
                "confidence": 0.92
            })],
        };
        Ok(serde_json::json!({
            "recommendations": picks,
            "project_type": kind
        }))
    }
}
/// Ordered collection of [`FunctionTool`]s exposed to the model when tool
/// use is enabled. Duplicate names are not rejected.
pub struct FunctionRegistry {
functions: Vec<FunctionTool>,
}
impl Default for FunctionRegistry {
fn default() -> Self {
Self::new()
}
}
impl FunctionRegistry {
    /// Creates a registry with no functions registered.
    #[must_use]
    pub fn new() -> Self {
        Self { functions: vec![] }
    }

    /// Appends `f` to the registry. Names are not deduplicated; lookups take
    /// the first match in registration order.
    pub fn register_function(&mut self, f: FunctionTool) {
        self.functions.push(f);
    }

    /// Returns every registered function, in registration order.
    #[must_use]
    pub fn list_functions(&self) -> &[FunctionTool] {
        self.functions.as_slice()
    }
}
/// Builds the stubbed `crate_search` tool.
///
/// The callback ignores its arguments and resolves to an empty JSON object;
/// real search wiring is expected to replace it.
#[must_use]
pub fn create_crate_search_function() -> FunctionTool {
    FunctionTool::new(
        "crate_search",
        "Search crates by keyword",
        serde_json::json!({}),
        // `FunctionTool::new` already boxes and pins the returned future, so
        // a plain async block suffices; the previous `Box::pin(async { .. })`
        // double-boxed every invocation's future.
        |_| async { Ok(serde_json::json!({})) },
    )
}
/// HTTP client for the OpenAI chat-completions API with optional tool use.
///
/// Cheap to clone: the HTTP client, metrics, cache, and registry are all
/// internally reference-counted or shared via `Arc`.
#[derive(Clone)]
pub struct OpenAIClient {
config: OpenAIClientConfig,
http_client: reqwest::Client,
metrics: Arc<OpenCratesMetrics>,
cache: Arc<CacheManager>,
// Tools shared across clones; guarded by an async RwLock so registration
// can happen concurrently with completions.
function_registry: Arc<RwLock<FunctionRegistry>>,
}
impl OpenAIClient {
#[must_use]
pub fn new(
config: OpenAIClientConfig,
metrics: Arc<OpenCratesMetrics>,
cache: Arc<CacheManager>,
) -> Self {
let http_client = reqwest::Client::builder()
.timeout(config.timeout)
.build()
.expect("Failed to create HTTP client");
Self {
config,
http_client,
metrics,
cache,
function_registry: Arc::new(RwLock::new(FunctionRegistry::new())),
}
}
#[must_use]
pub fn cache(&self) -> &Arc<CacheManager> {
&self.cache
}
pub async fn register_function(&self, function: FunctionTool) {
let mut registry = self.function_registry.write().await;
registry.register_function(function);
}
#[instrument(skip(self, messages))]
pub async fn chat_completion(
&self,
messages: Vec<ChatMessage>,
use_tools: bool,
) -> Result<ChatCompletionResponse> {
let start = Instant::now();
let mut request = ChatCompletionRequest {
model: self.config.model.clone(),
messages,
max_tokens: self.config.max_tokens,
temperature: self.config.temperature,
tools: None,
tool_choice: None,
stream: Some(false),
};
if use_tools {
let registry = self.function_registry.read().await;
request.tools = Some(registry.list_functions().to_vec());
request.tool_choice = Some("auto".to_string());
}
let response = self
.http_client
.post(format!("{}/chat/completions", self.config.base_url))
.header("Authorization", format!("Bearer {}", self.config.api_key))
.header("Content-Type", "application/json")
.json(&request)
.send()
.await
.context("Failed to send chat completion request")?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(
OpenCratesError::internal(format!("OpenAI API error: {error_text}")).into(),
);
}
let completion: ChatCompletionResponse = response
.json()
.await
.context("Failed to parse chat completion response")?;
let duration = start.elapsed();
self.metrics.record_api_request().await.unwrap_or(());
info!("Chat completion completed in {}ms", duration.as_millis());
Ok(completion)
}
#[instrument(skip(self, tool_calls))]
pub async fn handle_tool_calls(&self, tool_calls: Vec<ToolCall>) -> Result<Vec<ChatMessage>> {
let mut tool_messages = Vec::new();
let registry = self.function_registry.read().await;
for tool_call in tool_calls {
if tool_call.tool_type != "function" {
continue;
}
let function_name = &tool_call.function.name;
let function_args: serde_json::Value =
serde_json::from_str(&tool_call.function.arguments)
.context("Failed to parse function arguments")?;
debug!(
"Calling function: {} with args: {}",
function_name, function_args
);
let result = if let Some(function) = registry
.list_functions()
.iter()
.find(|f| f.name == *function_name)
{
match function.call(function_args) {
Ok(result) => result,
Err(e) => {
error!("Function {} failed: {}", function_name, e);
serde_json::json!({
"error": e.to_string()
})
}
}
} else {
warn!("Unknown function called: {}", function_name);
serde_json::json!({
"error": format!("Unknown function: {}", function_name)
})
};
tool_messages.push(ChatMessage {
role: MessageRole::Tool,
content: serde_json::to_string(&result)?,
name: Some(function_name.clone()),
tool_calls: None,
tool_call_id: Some(tool_call.id),
});
}
Ok(tool_messages)
}
#[instrument(skip(self, messages))]
pub async fn complete_conversation(&self, messages: Vec<ChatMessage>) -> Result<String> {
let mut conversation = messages;
let max_iterations = 5;
for iteration in 0..max_iterations {
debug!("Conversation iteration {}", iteration + 1);
let response = self.chat_completion(conversation.clone(), true).await?;
if let Some(choice) = response.choices.first() {
let assistant_message = choice.message.clone();
conversation.push(assistant_message.clone());
if let Some(tool_calls) = &assistant_message.tool_calls {
let tool_messages = self.handle_tool_calls(tool_calls.clone()).await?;
conversation.extend(tool_messages);
continue;
}
return Ok(assistant_message.content);
}
}
Err(OpenCratesError::internal("Max conversation iterations reached".to_string()).into())
}
#[instrument(skip(self, messages))]
pub async fn chat_completion_stream(
&self,
messages: Vec<ChatMessage>,
use_tools: bool,
) -> Result<impl futures::Stream<Item = Result<ChatCompletionChunk>>> {
let mut request = ChatCompletionRequest {
model: self.config.model.clone(),
messages,
max_tokens: self.config.max_tokens,
temperature: self.config.temperature,
tools: None,
tool_choice: None,
stream: Some(true),
};
if use_tools {
let registry = self.function_registry.read().await;
request.tools = Some(registry.list_functions().to_vec());
request.tool_choice = Some("auto".to_string());
}
let response = self
.http_client
.post(format!("{}/chat/completions", self.config.base_url))
.header("Authorization", format!("Bearer {}", self.config.api_key))
.header("Content-Type", "application/json")
.json(&request)
.send()
.await
.context("Failed to send streaming chat completion request")?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
return Err(
OpenCratesError::internal(format!("OpenAI API error: {error_text}")).into(),
);
}
let response_text = response.text().await?;
let mut chunks = Vec::new();
for line in response_text.lines() {
if let Some(data) = line.strip_prefix("data: ") {
if data == "[DONE]" {
break;
}
match serde_json::from_str::<ChatCompletionChunk>(data) {
Ok(parsed_chunk) => {
chunks.push(Ok(parsed_chunk));
}
Err(e) => {
chunks.push(Err(OpenCratesError::internal(format!(
"Failed to parse chunk: {e}"
))
.into()));
}
}
}
}
Ok(futures::stream::iter(chunks))
}
}
/// High-level assistant that wraps [`OpenAIClient`] with a fixed crate-curator
/// system prompt.
pub struct CrateAssistant {
openai_client: OpenAIClient,
// NOTE(review): this is a fresh CacheManager created in `new`, not the
// client's cache — confirm whether they were meant to be shared.
cache: Arc<CacheManager>,
system_prompt: String,
}
/// Named collection of [`CrateAssistant`]s sharing one [`OpenAIClient`].
///
/// NOTE(review): `Clone` copies the assistants map by value, so assistants
/// created on a clone are not visible on the original — confirm intended.
#[derive(Clone)]
pub struct AgentManager {
openai_client: Arc<OpenAIClient>,
assistants: HashMap<String, Arc<CrateAssistant>>,
}
impl AgentManager {
    /// Creates a manager backed by a default-configured [`OpenAIClient`].
    /// The `_openai_provider` parameter is currently unused but kept for
    /// signature compatibility.
    ///
    /// # Errors
    /// Returns an error when metrics initialization fails. (The previous
    /// version `unwrap`ped here, panicking inside a constructor that already
    /// returns `Result`.)
    pub async fn new(_openai_provider: OpenAIProvider) -> Result<Self> {
        let config = OpenAIClientConfig::default();
        let metrics = Arc::new(
            OpenCratesMetrics::new()
                .await
                .context("Failed to initialize metrics")?,
        );
        let cache = Arc::new(CacheManager::new());
        let openai_client = Arc::new(OpenAIClient::new(config, metrics, cache));
        Ok(Self {
            openai_client,
            assistants: HashMap::new(),
        })
    }

    /// Looks up an assistant by name, returning a shared handle when present.
    pub async fn get_assistant(&self, name: &str) -> Option<Arc<CrateAssistant>> {
        self.assistants.get(name).cloned()
    }

    /// Creates and registers a new assistant under `name`, replacing any
    /// previous assistant with that name.
    ///
    /// # Errors
    /// Propagates failures from [`CrateAssistant::new`].
    pub async fn create_assistant(&mut self, name: String) -> Result<Arc<CrateAssistant>> {
        let assistant = Arc::new(CrateAssistant::new(self.openai_client.clone()).await?);
        self.assistants.insert(name.clone(), assistant.clone());
        Ok(assistant)
    }

    /// Lists the names of all registered assistants (unordered).
    pub async fn list_agents(&self) -> Vec<String> {
        self.assistants.keys().cloned().collect()
    }

    /// Runs `task_description` through the named assistant.
    ///
    /// # Errors
    /// Returns a not-found error for unknown agents, or propagates the
    /// assistant's failure.
    pub async fn execute_agent(&self, agent_name: &str, task_description: &str) -> Result<String> {
        if let Some(assistant) = self.get_assistant(agent_name).await {
            assistant.assist(task_description.to_string()).await
        } else {
            Err(OpenCratesError::not_found("Agent", agent_name).into())
        }
    }
}
impl CrateAssistant {
    /// Builds an assistant around `client` with the standard curator prompt.
    ///
    /// If this is the sole owner of the `Arc`, the client is moved out;
    /// otherwise it is cloned (cheap — the client is internally
    /// reference-counted).
    pub async fn new(client: Arc<OpenAIClient>) -> Result<Self> {
        let openai_client = Arc::try_unwrap(client).unwrap_or_else(|arc| (*arc).clone());
        // Raw string; inner lines are intentionally unindented so the prompt
        // text carries no leading whitespace.
        let system_prompt = r"
You are an expert Rust developer and crate curator assistant for OpenCrates, a comprehensive Rust crate registry.
Your role is to help developers:
1. Find the best crates for their specific needs
2. Compare different crates and their trade-offs
3. Provide implementation guidance and best practices
4. Suggest compatible crates that work well together
5. Help with dependency management and version selection
You have access to functions that can search the crate registry and provide personalized recommendations.
Always provide accurate, helpful, and up-to-date information about Rust crates and ecosystem.
When recommending crates, consider:
- Maintenance status and community support
- Performance characteristics
- API ergonomics and ease of use
- Documentation quality
- License compatibility
- Security track record
".trim().to_string();
        Ok(Self {
            openai_client,
            cache: Arc::new(CacheManager::new()),
            system_prompt,
        })
    }

    /// Returns this assistant's cache manager.
    #[must_use]
    pub fn cache(&self) -> &Arc<CacheManager> {
        &self.cache
    }

    /// Answers `user_query` by running a tool-enabled conversation seeded
    /// with the curator system prompt.
    #[instrument(skip(self))]
    pub async fn assist(&self, user_query: String) -> Result<String> {
        let conversation = vec![
            ChatMessage::system(self.system_prompt.clone()),
            ChatMessage::user(user_query),
        ];
        self.openai_client.complete_conversation(conversation).await
    }

    /// Convenience wrapper: phrases a recommendation request for the given
    /// project type and requirements, then delegates to [`Self::assist`].
    #[instrument(skip(self))]
    pub async fn recommend_crates(
        &self,
        project_type: String,
        requirements: String,
    ) -> Result<String> {
        // Continuation line stays unindented so no extra spaces enter the
        // prompt text.
        let prompt = format!(
            "I'm working on a {project_type} project with the following requirements: {requirements}. \
Can you recommend the best crates and explain why they're good choices?"
        );
        self.assist(prompt).await
    }
}