use anyhow::{Context, Result};
use futures::StreamExt;
use genai::adapter::AdapterKind;
use genai::chat::{ChatMessage, ChatOptions, ChatRequest, ChatStreamEvent};
use genai::Client;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::{mpsc, RwLock};
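/// In-band sentinel sent over the streaming channel to mark the end of a turn.
/// Consumers should stop reading when they receive it and must not render it.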
pub const EOT_SIGNAL: &str = "<|EOT|>";
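/// A completed (non-streaming) model response: the generated text plus token
/// counts when the provider reports usage (`None` otherwise).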
#[derive(Debug, Clone)]
pub struct LlmResponse {
pub text: String,
pub tokens_in: Option<i32>,
pub tokens_out: Option<i32>,
}
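/// Usage counters shared by every clone of [`GenAIProvider`] behind an `RwLock`.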
#[derive(Default)]
struct SharedState {
total_tokens_used: usize,
request_count: usize,
}
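/// Thin wrapper around [`genai::Client`] adding retries, streaming to a channel,
/// and usage accounting. Cloning is cheap: clones share the same client and counters.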
#[derive(Clone)]
pub struct GenAIProvider {
client: Arc<Client>,
shared: Arc<RwLock<SharedState>>,
}
impl GenAIProvider {
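    /// Creates a provider with a default genai client; API keys are resolved from
    /// the environment at request time.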
pub fn new() -> Result<Self> {
let client = Client::default();
Ok(Self {
client: Arc::new(client),
shared: Arc::new(RwLock::new(SharedState::default())),
})
}
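    /// Like [`Self::new`], but if both a provider name and an API key are given the
    /// key is exported to the provider-specific environment variable the genai
    /// client reads (e.g. `OPENAI_API_KEY`). `std::env::set_var` mutates process-wide
    /// state, so call this before spawning threads that read the environment.
    ///
    /// A minimal sketch (the key is a placeholder):
    ///
    /// ```ignore
    /// let provider = GenAIProvider::new_with_config(Some("anthropic"), Some("sk-ant-..."))?;
    /// ```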
pub fn new_with_config(provider_type: Option<&str>, api_key: Option<&str>) -> Result<Self> {
if let (Some(provider), Some(key)) = (provider_type, api_key) {
let env_var = match provider {
"openai" => "OPENAI_API_KEY",
"anthropic" => "ANTHROPIC_API_KEY",
"gemini" => "GEMINI_API_KEY",
"groq" => "GROQ_API_KEY",
"cohere" => "COHERE_API_KEY",
"xai" => "XAI_API_KEY",
"deepseek" => "DEEPSEEK_API_KEY",
"ollama" => {
log::info!("Ollama provider detected - no API key required for local setup");
return Self::new();
}
_ => {
log::warn!("Unknown provider type for API key: {provider}");
return Self::new();
}
};
log::info!("Setting {env_var} environment variable for genai client");
std::env::set_var(env_var, key);
}
Self::new()
}
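    /// Running token total across every clone. Non-streaming requests add the usage
    /// reported by the provider; streaming requests add only a rough character-based
    /// estimate.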
pub async fn get_total_tokens_used(&self) -> usize {
self.shared.read().await.total_tokens_used
}
pub async fn get_request_count(&self) -> usize {
self.shared.read().await.request_count
}
async fn increment_request(&self) {
let mut state = self.shared.write().await;
state.request_count += 1;
}
pub async fn add_tokens(&self, count: usize) {
let mut state = self.shared.write().await;
state.total_tokens_used += count;
}
pub async fn get_available_models(&self, provider: &str) -> Result<Vec<String>> {
let adapter_kind = str_to_adapter_kind(provider)?;
let models = self
.client
.all_model_names(adapter_kind)
.await
.context(format!("Failed to get models for provider: {provider}"))?;
Ok(models)
}
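    /// Convenience wrapper around [`Self::generate_response_with_retry`] with up to
    /// three retries.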
pub async fn generate_response_simple(&self, model: &str, prompt: &str) -> Result<LlmResponse> {
self.generate_response_with_retry(model, prompt, 3).await
}
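    /// Sends a single-prompt chat request, retrying transient failures (429 / rate
    /// limit, 5xx, timeouts, connection errors, matched by substring on the error
    /// text) with exponential backoff: 1s, 2s, 4s, 8s, capped at 16s.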
pub async fn generate_response_with_retry(
&self,
model: &str,
prompt: &str,
max_retries: usize,
) -> Result<LlmResponse> {
self.increment_request().await;
let chat_req = ChatRequest::default().append_message(ChatMessage::user(prompt));
log::debug!(
"Sending chat request to model: {model} with prompt length: {} chars",
prompt.len()
);
let start_time = Instant::now();
let mut last_error: Option<anyhow::Error> = None;
let mut retry_count = 0;
while retry_count <= max_retries {
if retry_count > 0 {
let delay_secs = std::cmp::min(1u64 << (retry_count - 1), 16);
log::warn!(
"Retry {}/{} for model {} after {}s delay (previous error: {:?})",
retry_count,
max_retries,
model,
delay_secs,
last_error.as_ref().map(|e| e.to_string())
);
                println!(
                    " ⏳ Transient error, retrying in {delay_secs}s (attempt {retry_count}/{max_retries})"
                );
tokio::time::sleep(tokio::time::Duration::from_secs(delay_secs)).await;
}
match self.client.exec_chat(model, chat_req.clone(), None).await {
Ok(chat_res) => {
let tokens_in = chat_res.usage.prompt_tokens;
let tokens_out = chat_res.usage.completion_tokens;
let content = chat_res
.first_text()
.context("No text content in response")?;
log::debug!(
"Received response with {} characters in {}ms (tokens: in={:?}, out={:?})",
content.len(),
start_time.elapsed().as_millis(),
tokens_in,
tokens_out,
);
let total = tokens_in.unwrap_or(0) + tokens_out.unwrap_or(0);
if total > 0 {
self.add_tokens(total as usize).await;
}
return Ok(LlmResponse {
text: content.to_string(),
tokens_in,
tokens_out,
});
}
Err(e) => {
let err_str = e.to_string();
let is_retryable = err_str.contains("429")
|| err_str.contains("rate limit")
|| err_str.contains("Rate limit")
|| err_str.contains("RESOURCE_EXHAUSTED")
|| err_str.contains("500")
|| err_str.contains("502")
|| err_str.contains("503")
|| err_str.contains("504")
|| err_str.contains("timeout")
|| err_str.contains("connection");
if is_retryable && retry_count < max_retries {
log::warn!("Retryable error for model {}: {}", model, err_str);
last_error = Some(anyhow::anyhow!("{}", err_str));
retry_count += 1;
continue;
} else {
return Err(anyhow::anyhow!(
"Failed to execute chat request for model {}: {}",
model,
err_str
));
}
}
}
}
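        // Unreachable in practice (every loop path returns), kept as a safety net.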
Err(last_error
.unwrap_or_else(|| anyhow::anyhow!("Unknown error after {} retries", max_retries)))
}
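    /// Streams a response chunk by chunk into `tx`, then sends [`EOT_SIGNAL`] as an
    /// in-band terminator. Token usage is not reported on this path, so a rough
    /// character-based estimate is added to the shared counter instead.
    ///
    /// A minimal consumption sketch (the model id and prompt are placeholders):
    ///
    /// ```ignore
    /// let provider = GenAIProvider::new()?;
    /// let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
    ///
    /// let streamer = {
    ///     let provider = provider.clone();
    ///     tokio::spawn(async move {
    ///         provider
    ///             .generate_response_stream_to_channel("gpt-4o-mini", "Say hello", tx)
    ///             .await
    ///     })
    /// };
    ///
    /// // Drain chunks until the end-of-turn sentinel arrives.
    /// while let Some(chunk) = rx.recv().await {
    ///     if chunk == EOT_SIGNAL {
    ///         break;
    ///     }
    ///     print!("{chunk}");
    /// }
    /// streamer.await??;
    /// ```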
pub async fn generate_response_stream_to_channel(
&self,
model: &str,
prompt: &str,
tx: mpsc::UnboundedSender<String>,
) -> Result<()> {
self.increment_request().await;
let chat_req = ChatRequest::default().append_message(ChatMessage::user(prompt));
        log::debug!(
            "Sending streaming chat request to model: {model} with prompt length: {} chars",
            prompt.len()
        );
let chat_res_stream = self
.client
.exec_chat_stream(model, chat_req, None)
.await
.context(format!(
"Failed to execute streaming chat request for model: {model}"
))?;
let mut stream = chat_res_stream.stream;
let mut chunk_count = 0;
let mut total_content_length = 0;
let mut stream_ended_explicitly = false;
let start_time = Instant::now();
log::info!(
"=== STREAM START === Model: {}, Prompt length: {} chars",
model,
prompt.len()
);
while let Some(chunk_result) = stream.next().await {
let elapsed = start_time.elapsed();
match chunk_result {
Ok(ChatStreamEvent::Start) => {
log::info!(">>> STREAM STARTED for model: {model} at {elapsed:?}");
}
Ok(ChatStreamEvent::Chunk(chunk)) => {
chunk_count += 1;
total_content_length += chunk.content.len();
if chunk_count % 10 == 0 || chunk.content.len() > 100 {
log::info!(
"CHUNK #{}: {} chars, total: {} chars, elapsed: {:?}",
chunk_count,
chunk.content.len(),
total_content_length,
elapsed
);
}
if !chunk.content.is_empty() && tx.send(chunk.content.clone()).is_err() {
log::error!(
"!!! CHANNEL SEND FAILED for chunk #{chunk_count} - STOPPING STREAM !!!"
);
break;
}
}
Ok(ChatStreamEvent::ReasoningChunk(chunk)) => {
log::info!(
"REASONING CHUNK: {} chars at {:?}",
chunk.content.len(),
elapsed
);
}
Ok(ChatStreamEvent::End(_)) => {
log::info!(">>> STREAM ENDED EXPLICITLY for model: {model} after {chunk_count} chunks, {total_content_length} chars, {elapsed:?} elapsed");
stream_ended_explicitly = true;
break;
}
Ok(ChatStreamEvent::ToolCallChunk(_)) => {
log::debug!("Tool call chunk received (ignored)");
}
Ok(ChatStreamEvent::ThoughtSignatureChunk(_)) => {
log::debug!("Thought signature chunk received (ignored)");
}
Err(e) => {
log::error!(
"!!! STREAM ERROR after {chunk_count} chunks at {elapsed:?}: {e} !!!"
);
let error_msg = format!("Stream error: {e}");
let _ = tx.send(error_msg);
return Err(e.into());
}
}
}
let final_elapsed = start_time.elapsed();
if !stream_ended_explicitly {
log::warn!("!!! STREAM ENDED IMPLICITLY (exhausted) for model: {model} after {chunk_count} chunks, {total_content_length} chars, {final_elapsed:?} elapsed !!!");
}
log::info!(
"=== STREAM COMPLETE === Model: {model}, Final: {chunk_count} chunks, {total_content_length} chars, {final_elapsed:?} elapsed"
);
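        // The streaming API does not surface usage here; approximate tokens as
        // roughly four characters each.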
self.add_tokens(total_content_length / 4).await;
if tx.send(EOT_SIGNAL.to_string()).is_err() {
log::error!("!!! FAILED TO SEND EOT SIGNAL - channel may be closed !!!");
return Err(anyhow::anyhow!("Channel closed during EOT signal send"));
}
log::info!(">>> EOT SIGNAL SENT for model: {model} <<<");
Ok(())
}
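    /// Sends a full conversation (system / user / assistant turns) and returns the
    /// reply text, without retries. A minimal sketch (the model id is a placeholder):
    ///
    /// ```ignore
    /// let messages = vec![
    ///     ChatMessage::system("You are a terse assistant."),
    ///     ChatMessage::user("Summarize Rust's ownership model in one sentence."),
    /// ];
    /// let reply = provider
    ///     .generate_response_with_history("gpt-4o-mini", messages)
    ///     .await?;
    /// ```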
pub async fn generate_response_with_history(
&self,
model: &str,
messages: Vec<ChatMessage>,
) -> Result<String> {
self.increment_request().await;
let chat_req = ChatRequest::new(messages);
log::debug!("Sending chat request to model: {model} with conversation history");
let chat_res = self
.client
.exec_chat(model, chat_req, None)
.await
.context(format!("Failed to execute chat request for model: {model}"))?;
let content = chat_res
.first_text()
.context("No text content in response")?;
log::debug!("Received response with {} characters", content.len());
Ok(content.to_string())
}
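    /// Single-prompt request with caller-supplied [`ChatOptions`], without retries.
    /// A sketch, assuming the builder methods shown exist in the genai version in use:
    ///
    /// ```ignore
    /// let options = ChatOptions::default()
    ///     .with_temperature(0.2)
    ///     .with_max_tokens(256);
    /// let reply = provider
    ///     .generate_response_with_options("gpt-4o-mini", "List three Rust crates", options)
    ///     .await?;
    /// ```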
pub async fn generate_response_with_options(
&self,
model: &str,
prompt: &str,
options: ChatOptions,
) -> Result<String> {
self.increment_request().await;
let chat_req = ChatRequest::default().append_message(ChatMessage::user(prompt));
log::debug!("Sending chat request to model: {model} with custom options");
let chat_res = self
.client
.exec_chat(model, chat_req, Some(&options))
.await
.context(format!("Failed to execute chat request for model: {model}"))?;
let content = chat_res
.first_text()
.context("No text content in response")?;
log::debug!("Received response with {} characters", content.len());
Ok(content.to_string())
}
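    /// Provider identifiers this wrapper knows how to configure and map to an
    /// [`AdapterKind`].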
pub fn get_supported_providers() -> Vec<&'static str> {
vec![
"openai",
"anthropic",
"gemini",
"groq",
"cohere",
"ollama",
"xai",
"deepseek",
]
}
pub async fn get_available_providers(&self) -> Result<Vec<String>> {
Ok(Self::get_supported_providers()
.iter()
.map(|s| s.to_string())
.collect())
}
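    /// Probes a model with a one-word prompt. Failures are reported as `Ok(false)`
    /// rather than `Err`, so callers can treat the result as an availability flag.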
pub async fn test_model(&self, model: &str) -> Result<bool> {
match self.generate_response_simple(model, "Hello").await {
Ok(_) => {
log::info!("Model {model} is available and working");
Ok(true)
}
Err(e) => {
log::warn!("Model {model} test failed: {e}");
Ok(false)
}
}
}
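    /// Returns `model` if it responds to a probe; otherwise falls back to the first
    /// model listed for `provider_type`, and as a last resort returns `model` unchanged.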
pub async fn validate_model(&self, model: &str, provider_type: Option<&str>) -> Result<String> {
if self.test_model(model).await? {
return Ok(model.to_string());
}
if let Some(provider) = provider_type {
if let Ok(models) = self.get_available_models(provider).await {
if !models.is_empty() {
                    log::info!("Model {model} not available, using {} instead", models[0]);
return Ok(models[0].clone());
}
}
}
log::warn!("Could not validate model {model}, proceeding anyway");
Ok(model.to_string())
}
}
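/// Maps a provider name (case-insensitive) to the corresponding genai [`AdapterKind`].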
fn str_to_adapter_kind(provider: &str) -> Result<AdapterKind> {
match provider.to_lowercase().as_str() {
"openai" => Ok(AdapterKind::OpenAI),
"anthropic" => Ok(AdapterKind::Anthropic),
"gemini" | "google" => Ok(AdapterKind::Gemini),
"groq" => Ok(AdapterKind::Groq),
"cohere" => Ok(AdapterKind::Cohere),
"ollama" => Ok(AdapterKind::Ollama),
"xai" => Ok(AdapterKind::Xai),
"deepseek" => Ok(AdapterKind::DeepSeek),
_ => Err(anyhow::anyhow!("Unsupported provider: {}", provider)),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_str_to_adapter_kind() {
assert!(str_to_adapter_kind("openai").is_ok());
assert!(str_to_adapter_kind("anthropic").is_ok());
assert!(str_to_adapter_kind("gemini").is_ok());
assert!(str_to_adapter_kind("google").is_ok());
assert!(str_to_adapter_kind("groq").is_ok());
assert!(str_to_adapter_kind("cohere").is_ok());
assert!(str_to_adapter_kind("ollama").is_ok());
assert!(str_to_adapter_kind("xai").is_ok());
assert!(str_to_adapter_kind("deepseek").is_ok());
assert!(str_to_adapter_kind("invalid").is_err());
}
#[tokio::test]
async fn test_provider_creation() {
let provider = GenAIProvider::new();
assert!(provider.is_ok());
}
#[tokio::test]
async fn test_provider_is_clonable() {
let provider = GenAIProvider::new().unwrap();
let _clone1 = provider.clone();
let _clone2 = provider.clone();
}
}