#![allow(dead_code)]
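//! Provider-agnostic AI layer: configuration, a common [`AIClient`] trait,
//! a factory for concrete clients, and helpers for RAG prompting and cost
//! estimation.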
use anyhow::Result;
use serde::{Deserialize, Serialize};
use tracing::{debug, info};

pub mod openai;
pub mod claude;
pub mod gemini;
pub mod litellm;
pub mod cost;
pub mod context;
pub mod prompts;

pub use openai::OpenAIClient;
pub use claude::ClaudeClient;
pub use gemini::GeminiClient;
pub use litellm::LiteLLMClient;
pub use cost::CostTracker;
pub use context::AIContext;
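
/// The AI backends this module can talk to. The `serde` names (e.g.
/// `"openai"`) double as the accepted values of `ISELF_AI_PROVIDER`.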
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum AIProvider {
#[serde(rename = "openai")]
OpenAI,
#[serde(rename = "claude")]
Claude,
#[serde(rename = "gemini")]
Gemini,
#[serde(rename = "litellm")]
LiteLLM,
#[serde(rename = "openrouter")]
OpenRouter,
#[serde(rename = "azure")]
AzureOpenAI,
#[serde(rename = "cohere")]
Cohere,
#[serde(rename = "ollama")]
Ollama,
#[serde(rename = "groq")]
Groq,
}
impl Default for AIProvider {
    fn default() -> Self {
        AIProvider::OpenAI
    }
}

impl std::fmt::Display for AIProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AIProvider::OpenAI => write!(f, "OpenAI"),
            AIProvider::Claude => write!(f, "Claude"),
            AIProvider::Gemini => write!(f, "Gemini"),
            AIProvider::LiteLLM => write!(f, "LiteLLM"),
            AIProvider::OpenRouter => write!(f, "OpenRouter"),
            AIProvider::AzureOpenAI => write!(f, "Azure OpenAI"),
            AIProvider::Cohere => write!(f, "Cohere"),
            AIProvider::Ollama => write!(f, "Ollama"),
            AIProvider::Groq => write!(f, "Groq"),
        }
    }
}
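
/// Static, per-model metadata; the pricing fields feed cost estimation.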
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    pub name: String,
    pub context_window: usize,
    pub supports_functions: bool,
    pub supports_vision: bool,
    /// USD per 1,000 input tokens.
    pub cost_per_1k_input: f64,
    /// USD per 1,000 output tokens.
    pub cost_per_1k_output: f64,
}
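
/// Runtime configuration for a single AI client, typically built via
/// [`AIConfig::from_env`] or [`Default`].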
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AIConfig {
    pub provider: AIProvider,
    pub api_key: String,
    pub api_base: Option<String>,
    pub model: String,
    pub max_tokens: u32,
    pub temperature: f32,
    pub system_prompt: Option<String>,
    pub timeout_seconds: u64,
    pub retry_attempts: u32,
}

impl Default for AIConfig {
    fn default() -> Self {
        Self {
            provider: AIProvider::OpenAI,
            api_key: String::new(),
            api_base: None,
            model: "gpt-4".to_string(),
            max_tokens: 2000,
            temperature: 0.7,
            system_prompt: None,
            timeout_seconds: 60,
            retry_attempts: 3,
        }
    }
}

impl AIConfig {
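    /// Builds a config from environment variables: `ISELF_AI_PROVIDER` selects
    /// the backend, the provider's own key variable (e.g. `OPENAI_API_KEY`)
    /// supplies credentials, and `ISELF_AI_BASE_URL`, `ISELF_AI_MODEL`,
    /// `ISELF_AI_MAX_TOKENS`, `ISELF_AI_TEMPERATURE`, and
    /// `ISELF_AI_SYSTEM_PROMPT` override the remaining defaults.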
    pub fn from_env() -> Self {
        let provider = match std::env::var("ISELF_AI_PROVIDER").as_deref() {
            Ok("openai") => AIProvider::OpenAI,
            Ok("claude") => AIProvider::Claude,
            Ok("gemini") => AIProvider::Gemini,
            Ok("litellm") => AIProvider::LiteLLM,
            Ok("openrouter") => AIProvider::OpenRouter,
            Ok("azure") => AIProvider::AzureOpenAI,
            Ok("cohere") => AIProvider::Cohere,
            Ok("ollama") => AIProvider::Ollama,
            Ok("groq") => AIProvider::Groq,
            // Unset or unrecognized values fall back to OpenAI.
            _ => AIProvider::OpenAI,
        };

        let api_key = match provider {
            AIProvider::OpenAI => std::env::var("OPENAI_API_KEY"),
            AIProvider::Claude => std::env::var("ANTHROPIC_API_KEY"),
            AIProvider::Gemini => std::env::var("GEMINI_API_KEY"),
            AIProvider::LiteLLM => std::env::var("LITELLM_API_KEY"),
            AIProvider::OpenRouter => std::env::var("OPENROUTER_API_KEY"),
            AIProvider::AzureOpenAI => std::env::var("AZURE_OPENAI_API_KEY"),
            AIProvider::Cohere => std::env::var("COHERE_API_KEY"),
            // Ollama is a local server and needs no API key.
            AIProvider::Ollama => Ok(String::new()),
            AIProvider::Groq => std::env::var("GROQ_API_KEY"),
        }
        .unwrap_or_default();

        let api_base = std::env::var("ISELF_AI_BASE_URL").ok();
        let model = std::env::var("ISELF_AI_MODEL").unwrap_or_else(|_| match provider {
            AIProvider::OpenAI => "gpt-4".to_string(),
            AIProvider::Claude => "claude-3-sonnet-20240229".to_string(),
            AIProvider::Gemini => "gemini-pro".to_string(),
            AIProvider::LiteLLM => "gpt-4".to_string(),
            AIProvider::OpenRouter => "openai/gpt-4".to_string(),
            AIProvider::AzureOpenAI => "gpt-4".to_string(),
            AIProvider::Cohere => "command-r".to_string(),
            AIProvider::Ollama => "llama2".to_string(),
            AIProvider::Groq => "llama3-70b-8192".to_string(),
        });

        Self {
            provider,
            api_key,
            api_base,
            model,
            max_tokens: std::env::var("ISELF_AI_MAX_TOKENS")
                .ok()
                .and_then(|s| s.parse().ok())
                .unwrap_or(2000),
            temperature: std::env::var("ISELF_AI_TEMPERATURE")
                .ok()
                .and_then(|s| s.parse().ok())
                .unwrap_or(0.7),
            system_prompt: std::env::var("ISELF_AI_SYSTEM_PROMPT").ok(),
            timeout_seconds: 60,
            retry_attempts: 3,
        }
    }
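
    /// Returns the explicit `api_base` override if set, otherwise a default
    /// base URL for the configured provider.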
    pub fn get_api_base(&self) -> String {
        self.api_base.clone().unwrap_or_else(|| match self.provider {
            AIProvider::OpenAI => "https://api.openai.com/v1".to_string(),
            AIProvider::Claude => "https://api.anthropic.com/v1".to_string(),
            AIProvider::Gemini => "https://generativelanguage.googleapis.com/v1".to_string(),
            AIProvider::LiteLLM => "http://localhost:4000".to_string(),
            AIProvider::OpenRouter => "https://openrouter.ai/api/v1".to_string(),
            // AZURE_OPENAI_ENDPOINT must hold just the resource name (not a
            // full URL), since it is interpolated into the hostname here.
            AIProvider::AzureOpenAI => format!(
                "https://{}.openai.azure.com/openai/deployments/{}",
                std::env::var("AZURE_OPENAI_ENDPOINT").unwrap_or_default(),
                self.model
            ),
            AIProvider::Cohere => "https://api.cohere.ai/v1".to_string(),
            AIProvider::Ollama => "http://localhost:11434".to_string(),
            AIProvider::Groq => "https://api.groq.com/openai/v1".to_string(),
        })
    }
}
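
/// The interface every provider client implements. `complete` returns the
/// full response in one shot; `complete_stream` yields chunks over a channel.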
#[async_trait::async_trait]
pub trait AIClient: Send + Sync {
    /// Sends a prompt (plus optional context) and returns the full response.
    async fn complete(&self, prompt: &str, context: Option<&AIContext>) -> Result<AIResponse>;
    /// Like [`AIClient::complete`], but streams response chunks over a channel.
    async fn complete_stream(
        &self,
        prompt: &str,
        context: Option<&AIContext>,
    ) -> Result<tokio::sync::mpsc::Receiver<Result<String>>>;
    /// Human-readable client name.
    fn name(&self) -> &str;
    /// Static metadata for the configured model.
    fn model_info(&self) -> ModelConfig;
}
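
/// A completed response plus whatever usage metadata the provider reported.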
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AIResponse {
    pub content: String,
    pub tokens_used: Option<u32>,
    pub model: String,
    pub finish_reason: Option<String>,
    pub cost_estimate: Option<f64>,
}
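
/// Input for retrieval-augmented generation: the user's question plus the
/// retrieved code chunks and, optionally, the developer's profile.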
pub struct RAGQuery {
    pub query: String,
    pub context_chunks: Vec<String>,
    pub developer_profile: Option<crate::storage::DeveloperProfile>,
}
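
/// Assembles the RAG prompt: an optional "Developer Profile" section, a
/// "Relevant Code Context" section with fenced snippets, then the question.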
pub fn build_rag_prompt(query: &RAGQuery) -> String {
    let mut prompt = String::new();

    if let Some(profile) = &query.developer_profile {
        prompt.push_str("## Developer Profile\n");
        prompt.push_str(&format!(
            "Name: {}\n",
            profile.identity.name.as_deref().unwrap_or("Unknown")
        ));
        prompt.push_str(&format!(
            "Primary Languages: {}\n",
            profile
                .github
                .primary_languages
                .iter()
                .take(5)
                .map(|(l, _)| l.as_str())
                .collect::<Vec<_>>()
                .join(", ")
        ));
        prompt.push('\n');
    }

    if !query.context_chunks.is_empty() {
        prompt.push_str("## Relevant Code Context\n\n");
        for (i, chunk) in query.context_chunks.iter().enumerate() {
            prompt.push_str(&format!("### Code Snippet {}\n```\n{}\n```\n\n", i + 1, chunk));
        }
    }

    prompt.push_str("## Question\n");
    prompt.push_str(&query.query);
    prompt.push_str("\n\n## Answer\n");
    prompt
}
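
/// High-level facade over a boxed [`AIClient`]: ask questions, explain code,
/// request improvements, generate code, and run RAG queries.
///
/// A minimal usage sketch (illustrative; assumes the relevant API key is set
/// in the environment and an async runtime is running):
///
/// ```ignore
/// let config = AIConfig::from_env();
/// let assistant = CodeAssistant::new(config)?;
/// let answer = assistant.ask("Summarize my recent refactors", None).await?;
/// println!("{answer}");
/// ```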
pub struct CodeAssistant {
    client: Box<dyn AIClient>,
    config: AIConfig,
    cost_tracker: Option<CostTracker>,
}

impl CodeAssistant {
    pub fn new(config: AIConfig) -> Result<Self> {
        let client = AIClientFactory::create(&config)?;
        Ok(Self {
            client,
            config,
            cost_tracker: None,
        })
    }

    pub fn with_cost_tracker(mut self, tracker: CostTracker) -> Self {
        self.cost_tracker = Some(tracker);
        self
    }

    pub async fn ask(&self, question: &str, context: Option<&AIContext>) -> Result<String> {
        info!("Asking AI: {}", question);
        let prompt = format!(
            "You are an AI assistant that helps a developer understand their own code and development patterns. \
             Answer based on the context provided. Be concise but informative.\n\nQuestion: {}\n\nAnswer:",
            question
        );
        let response = self.client.complete(&prompt, context).await?;
        // When a tracker is attached, the estimated cost is currently only logged.
        if let Some(_tracker) = &self.cost_tracker {
            debug!("Request cost: ${:.4}", response.cost_estimate.unwrap_or(0.0));
        }
        Ok(response.content)
    }

    pub async fn explain_code(&self, code: &str, language: &str) -> Result<String> {
        let prompt = format!(
            "Explain the following {} code in simple terms. Focus on what it does and why:\n\n```{}\n{}\n```\n\nExplanation:",
            language, language, code
        );
        let response = self.client.complete(&prompt, None).await?;
        Ok(response.content)
    }

    pub async fn suggest_improvements(&self, code: &str, language: &str) -> Result<Vec<String>> {
        let prompt = format!(
            "Review this {} code and suggest specific improvements. Focus on:\n\
             1. Performance optimizations\n\
             2. Code readability\n\
             3. Best practices\n\
             4. Potential bugs\n\n\
             Code:\n```{}\n{}\n```\n\nSuggestions (one per line, numbered):",
            language, language, code
        );
        let response = self.client.complete(&prompt, None).await?;
        let suggestions: Vec<String> = response
            .content
            .lines()
            .filter(|line| !line.trim().is_empty())
            .map(|line| line.trim().to_string())
            .collect();
        Ok(suggestions)
    }

    pub async fn generate_code(
        &self,
        description: &str,
        language: &str,
        context: Option<&AIContext>,
    ) -> Result<String> {
        let prompt = format!(
            "Generate {} code for the following description. \
             Follow best practices and make it clean and well-documented.\n\n\
             Description: {}\n\n\
             Code:",
            language, description
        );
        let response = self.client.complete(&prompt, context).await?;
        Ok(response.content)
    }

    pub async fn ask_with_rag(&self, query: &RAGQuery) -> Result<String> {
        let prompt = build_rag_prompt(query);
        debug!("RAG prompt length: {} chars", prompt.len());
        let response = self.client.complete(&prompt, None).await?;
        Ok(response.content)
    }

    pub fn estimate_cost(&self, input_tokens: u32, output_tokens: u32) -> f64 {
        let model_info = self.client.model_info();
        (input_tokens as f64 / 1000.0) * model_info.cost_per_1k_input
            + (output_tokens as f64 / 1000.0) * model_info.cost_per_1k_output
    }
}
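
/// Builds the concrete [`AIClient`] for a provider. Backends that expose an
/// OpenAI-compatible API (OpenRouter, Azure, Cohere, Ollama, Groq) reuse
/// [`OpenAIClient`] pointed at a different base URL.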
pub struct AIClientFactory;
impl AIClientFactory {
    pub fn create(config: &AIConfig) -> Result<Box<dyn AIClient>> {
        match config.provider {
            AIProvider::OpenAI => Ok(Box::new(OpenAIClient::new(config)?)),
            AIProvider::Claude => Ok(Box::new(ClaudeClient::new(config)?)),
            AIProvider::Gemini => Ok(Box::new(GeminiClient::new(config)?)),
            AIProvider::LiteLLM => Ok(Box::new(LiteLLMClient::new(config)?)),
            AIProvider::OpenRouter => Ok(Box::new(OpenAIClient::with_base_url(
                config,
                "https://openrouter.ai/api/v1",
            )?)),
            AIProvider::AzureOpenAI => Ok(Box::new(OpenAIClient::with_base_url(
                config,
                &config.get_api_base(),
            )?)),
            // NOTE: Cohere is routed through the OpenAI-compatible client here,
            // which assumes an OpenAI-style endpoint at this URL; Cohere's
            // native v1 API uses a different request schema.
            AIProvider::Cohere => Ok(Box::new(OpenAIClient::with_base_url(
                config,
                "https://api.cohere.ai/v1",
            )?)),
            AIProvider::Ollama => Ok(Box::new(OpenAIClient::with_base_url(
                config,
                "http://localhost:11434/v1",
            )?)),
            AIProvider::Groq => Ok(Box::new(OpenAIClient::with_base_url(
                config,
                "https://api.groq.com/openai/v1",
            )?)),
        }
    }

    pub fn list_providers() -> Vec<(AIProvider, &'static str, Vec<&'static str>)> {
        vec![
            (AIProvider::OpenAI, "OpenAI", vec!["gpt-4", "gpt-4-turbo", "gpt-3.5-turbo"]),
            (AIProvider::Claude, "Anthropic Claude", vec!["claude-3-opus", "claude-3-sonnet", "claude-3-haiku"]),
            (AIProvider::Gemini, "Google Gemini", vec!["gemini-pro", "gemini-pro-vision"]),
            (AIProvider::LiteLLM, "LiteLLM Proxy", vec!["any model"]),
            (AIProvider::OpenRouter, "OpenRouter", vec!["openai/gpt-4", "anthropic/claude-3", "google/gemini"]),
            (AIProvider::AzureOpenAI, "Azure OpenAI", vec!["gpt-4", "gpt-35-turbo"]),
            (AIProvider::Cohere, "Cohere", vec!["command-r", "command-r-plus"]),
            (AIProvider::Ollama, "Ollama (Local)", vec!["llama2", "llama3", "mistral", "codellama"]),
            (AIProvider::Groq, "Groq", vec!["llama3-70b", "mixtral-8x7b", "gemma-7b"]),
        ]
    }
}
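
// Smoke tests for the pure helpers above. These avoid network access and any
// extra dev-dependencies; they only exercise prompt assembly and Display.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn build_rag_prompt_includes_context_and_question() {
        let query = RAGQuery {
            query: "How is retry handled?".to_string(),
            context_chunks: vec!["fn retry() {}".to_string()],
            developer_profile: None,
        };
        let prompt = build_rag_prompt(&query);
        assert!(prompt.contains("## Relevant Code Context"));
        assert!(prompt.contains("fn retry() {}"));
        assert!(prompt.ends_with("## Answer\n"));
    }

    #[test]
    fn provider_display_names() {
        assert_eq!(AIProvider::AzureOpenAI.to_string(), "Azure OpenAI");
        assert_eq!(AIProvider::default().to_string(), "OpenAI");
    }
}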