use anyhow::{Context, Result};
use futures::StreamExt;
use genai::adapter::AdapterKind;
use genai::chat::{ChatMessage, ChatOptions, ChatRequest, ChatStreamEvent};
use genai::Client;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::{mpsc, RwLock};

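/// Sentinel string sent over the streaming channel to signal the end of a response.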
pub const EOT_SIGNAL: &str = "<|EOT|>";

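/// Usage counters shared across all clones of the provider.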
#[derive(Default)]
struct SharedState {
    total_tokens_used: usize,
    request_count: usize,
}

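/// Thin, cloneable wrapper around a `genai::Client` that tracks request and
/// token usage. Cloning is cheap because both fields are reference-counted.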
#[derive(Clone)]
pub struct GenAIProvider {
    client: Arc<Client>,
    shared: Arc<RwLock<SharedState>>,
}

impl GenAIProvider {
    pub fn new() -> Result<Self> {
        let client = Client::default();
        Ok(Self {
            client: Arc::new(client),
            shared: Arc::new(RwLock::new(SharedState::default())),
        })
    }

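    /// Builds a provider, exporting the supplied API key as the environment
    /// variable the genai client reads for that provider. Ollama needs no key,
    /// and unknown provider names fall back to default construction.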
    pub fn new_with_config(provider_type: Option<&str>, api_key: Option<&str>) -> Result<Self> {
        if let (Some(provider), Some(key)) = (provider_type, api_key) {
            let env_var = match provider {
                "openai" => "OPENAI_API_KEY",
                "anthropic" => "ANTHROPIC_API_KEY",
                "gemini" => "GEMINI_API_KEY",
                "groq" => "GROQ_API_KEY",
                "cohere" => "COHERE_API_KEY",
                "xai" => "XAI_API_KEY",
                "deepseek" => "DEEPSEEK_API_KEY",
                "ollama" => {
                    log::info!("Ollama provider detected - no API key required for local setup");
                    return Self::new();
                }
                _ => {
                    log::warn!("Unknown provider type for API key: {provider}");
                    return Self::new();
                }
            };

            log::info!("Setting {env_var} environment variable for genai client");
            std::env::set_var(env_var, key);
        }

        Self::new()
    }

    pub async fn get_total_tokens_used(&self) -> usize {
        self.shared.read().await.total_tokens_used
    }

    pub async fn get_request_count(&self) -> usize {
        self.shared.read().await.request_count
    }

    async fn increment_request(&self) {
        let mut state = self.shared.write().await;
        state.request_count += 1;
    }

    pub async fn add_tokens(&self, count: usize) {
        let mut state = self.shared.write().await;
        state.total_tokens_used += count;
    }

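    /// Lists the model names the genai client reports for the given provider.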
    pub async fn get_available_models(&self, provider: &str) -> Result<Vec<String>> {
        let adapter_kind = str_to_adapter_kind(provider)?;

        let models = self
            .client
            .all_model_names(adapter_kind)
            .await
            .context(format!("Failed to get models for provider: {provider}"))?;

        Ok(models)
    }

    pub async fn generate_response_simple(&self, model: &str, prompt: &str) -> Result<String> {
        self.generate_response_with_retry(model, prompt, 3).await
    }

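    /// Sends a single-prompt chat request, retrying retryable failures (rate
    /// limits, 5xx responses, timeouts, connection errors) with exponential
    /// backoff capped at 16 seconds.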
    pub async fn generate_response_with_retry(
        &self,
        model: &str,
        prompt: &str,
        max_retries: usize,
    ) -> Result<String> {
        self.increment_request().await;

        let chat_req = ChatRequest::default().append_message(ChatMessage::user(prompt));

        log::debug!(
            "Sending chat request to model: {model} with prompt length: {} chars",
            prompt.len()
        );

        let mut last_error: Option<anyhow::Error> = None;
        let mut retry_count = 0;

        while retry_count <= max_retries {
            if retry_count > 0 {
                let delay_secs = std::cmp::min(1u64 << (retry_count - 1), 16);
                log::warn!(
                    "Retry {}/{} for model {} after {}s delay (previous error: {:?})",
                    retry_count,
                    max_retries,
                    model,
                    delay_secs,
                    last_error.as_ref().map(|e| e.to_string())
                );
                println!(
                    " ⏳ Rate limited, retrying in {}s (attempt {}/{})",
                    delay_secs, retry_count, max_retries
                );
                tokio::time::sleep(tokio::time::Duration::from_secs(delay_secs)).await;
            }

            match self.client.exec_chat(model, chat_req.clone(), None).await {
                Ok(chat_res) => {
                    let content = chat_res
                        .first_text()
                        .context("No text content in response")?;
                    log::debug!("Received response with {} characters", content.len());
                    return Ok(content.to_string());
                }
                Err(e) => {
                    let err_str = e.to_string();

                    let is_retryable = err_str.contains("429")
                        || err_str.contains("rate limit")
                        || err_str.contains("Rate limit")
                        || err_str.contains("RESOURCE_EXHAUSTED")
                        || err_str.contains("500")
                        || err_str.contains("502")
                        || err_str.contains("503")
                        || err_str.contains("504")
                        || err_str.contains("timeout")
                        || err_str.contains("connection");

                    if is_retryable && retry_count < max_retries {
                        log::warn!("Retryable error for model {}: {}", model, err_str);
                        last_error = Some(anyhow::anyhow!("{}", err_str));
                        retry_count += 1;
                        continue;
                    } else {
                        return Err(anyhow::anyhow!(
                            "Failed to execute chat request for model {}: {}",
                            model,
                            err_str
                        ));
                    }
                }
            }
        }

        Err(last_error
            .unwrap_or_else(|| anyhow::anyhow!("Unknown error after {} retries", max_retries)))
    }

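    /// Streams a chat response chunk by chunk into `tx`, then sends
    /// [`EOT_SIGNAL`] so the receiver knows the stream is finished. Token usage
    /// is updated with a rough estimate of one token per four characters.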
    pub async fn generate_response_stream_to_channel(
        &self,
        model: &str,
        prompt: &str,
        tx: mpsc::UnboundedSender<String>,
    ) -> Result<()> {
        self.increment_request().await;

        let chat_req = ChatRequest::default().append_message(ChatMessage::user(prompt));

        log::debug!("Sending streaming chat request to model: {model} with prompt: {prompt}");

        let chat_res_stream = self
            .client
            .exec_chat_stream(model, chat_req, None)
            .await
            .context(format!(
                "Failed to execute streaming chat request for model: {model}"
            ))?;

        let mut stream = chat_res_stream.stream;
        let mut chunk_count = 0;
        let mut total_content_length = 0;
        let mut stream_ended_explicitly = false;
        let start_time = Instant::now();

        log::info!(
            "=== STREAM START === Model: {}, Prompt length: {} chars",
            model,
            prompt.len()
        );

        while let Some(chunk_result) = stream.next().await {
            let elapsed = start_time.elapsed();

            match chunk_result {
                Ok(ChatStreamEvent::Start) => {
                    log::info!(">>> STREAM STARTED for model: {model} at {elapsed:?}");
                }
                Ok(ChatStreamEvent::Chunk(chunk)) => {
                    chunk_count += 1;
                    total_content_length += chunk.content.len();

                    if chunk_count % 10 == 0 || chunk.content.len() > 100 {
                        log::info!(
                            "CHUNK #{}: {} chars, total: {} chars, elapsed: {:?}",
                            chunk_count,
                            chunk.content.len(),
                            total_content_length,
                            elapsed
                        );
                    }

                    if !chunk.content.is_empty() && tx.send(chunk.content.clone()).is_err() {
                        log::error!(
                            "!!! CHANNEL SEND FAILED for chunk #{chunk_count} - STOPPING STREAM !!!"
                        );
                        break;
                    }
                }
                Ok(ChatStreamEvent::ReasoningChunk(chunk)) => {
                    log::info!(
                        "REASONING CHUNK: {} chars at {:?}",
                        chunk.content.len(),
                        elapsed
                    );
                }
                Ok(ChatStreamEvent::End(_)) => {
                    log::info!(">>> STREAM ENDED EXPLICITLY for model: {model} after {chunk_count} chunks, {total_content_length} chars, {elapsed:?} elapsed");
                    stream_ended_explicitly = true;
                    break;
                }
                Ok(ChatStreamEvent::ToolCallChunk(_)) => {
                    log::debug!("Tool call chunk received (ignored)");
                }
                Err(e) => {
                    log::error!(
                        "!!! STREAM ERROR after {chunk_count} chunks at {elapsed:?}: {e} !!!"
                    );
                    let error_msg = format!("Stream error: {e}");
                    let _ = tx.send(error_msg);
                    return Err(e.into());
                }
            }
        }

        let final_elapsed = start_time.elapsed();
        if !stream_ended_explicitly {
            log::warn!("!!! STREAM ENDED IMPLICITLY (exhausted) for model: {model} after {chunk_count} chunks, {total_content_length} chars, {final_elapsed:?} elapsed !!!");
        }

        log::info!(
            "=== STREAM COMPLETE === Model: {model}, Final: {chunk_count} chunks, {total_content_length} chars, {final_elapsed:?} elapsed"
        );

        self.add_tokens(total_content_length / 4).await;

        if tx.send(EOT_SIGNAL.to_string()).is_err() {
            log::error!("!!! FAILED TO SEND EOT SIGNAL - channel may be closed !!!");
            return Err(anyhow::anyhow!("Channel closed during EOT signal send"));
        }

        log::info!(">>> EOT SIGNAL SENT for model: {model} <<<");
        Ok(())
    }

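    /// Sends a full conversation history as a single chat request and returns
    /// the model's reply.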
    pub async fn generate_response_with_history(
        &self,
        model: &str,
        messages: Vec<ChatMessage>,
    ) -> Result<String> {
        self.increment_request().await;

        let chat_req = ChatRequest::new(messages);

        log::debug!("Sending chat request to model: {model} with conversation history");

        let chat_res = self
            .client
            .exec_chat(model, chat_req, None)
            .await
            .context(format!("Failed to execute chat request for model: {model}"))?;

        let content = chat_res
            .first_text()
            .context("No text content in response")?;

        log::debug!("Received response with {} characters", content.len());
        Ok(content.to_string())
    }

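    /// Sends a single prompt with caller-supplied `ChatOptions`; unlike
    /// [`Self::generate_response_simple`], no retries are attempted.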
    pub async fn generate_response_with_options(
        &self,
        model: &str,
        prompt: &str,
        options: ChatOptions,
    ) -> Result<String> {
        self.increment_request().await;

        let chat_req = ChatRequest::default().append_message(ChatMessage::user(prompt));

        log::debug!("Sending chat request to model: {model} with custom options");

        let chat_res = self
            .client
            .exec_chat(model, chat_req, Some(&options))
            .await
            .context(format!("Failed to execute chat request for model: {model}"))?;

        let content = chat_res
            .first_text()
            .context("No text content in response")?;

        log::debug!("Received response with {} characters", content.len());
        Ok(content.to_string())
    }

    pub fn get_supported_providers() -> Vec<&'static str> {
        vec![
            "openai",
            "anthropic",
            "gemini",
            "groq",
            "cohere",
            "ollama",
            "xai",
            "deepseek",
        ]
    }

    pub async fn get_available_providers(&self) -> Result<Vec<String>> {
        Ok(Self::get_supported_providers()
            .iter()
            .map(|s| s.to_string())
            .collect())
    }

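    /// Probes a model with a trivial "Hello" prompt; returns `Ok(false)` rather
    /// than an error when the model does not respond.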
    pub async fn test_model(&self, model: &str) -> Result<bool> {
        match self.generate_response_simple(model, "Hello").await {
            Ok(_) => {
                log::info!("Model {model} is available and working");
                Ok(true)
            }
            Err(e) => {
                log::warn!("Model {model} test failed: {e}");
                Ok(false)
            }
        }
    }

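    /// Returns `model` if it responds to a test prompt; otherwise falls back to
    /// the provider's first listed model, or proceeds with the original name if
    /// no alternative can be found.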
    pub async fn validate_model(&self, model: &str, provider_type: Option<&str>) -> Result<String> {
        if self.test_model(model).await? {
            return Ok(model.to_string());
        }

        if let Some(provider) = provider_type {
            if let Ok(models) = self.get_available_models(provider).await {
                if !models.is_empty() {
                    log::info!("Model {} not available, using {} instead", model, models[0]);
                    return Ok(models[0].clone());
                }
            }
        }

        log::warn!("Could not validate model {model}, proceeding anyway");
        Ok(model.to_string())
    }
}

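/// Maps a provider name (case-insensitive) to the corresponding genai
/// `AdapterKind`; both "gemini" and "google" resolve to Gemini.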
fn str_to_adapter_kind(provider: &str) -> Result<AdapterKind> {
    match provider.to_lowercase().as_str() {
        "openai" => Ok(AdapterKind::OpenAI),
        "anthropic" => Ok(AdapterKind::Anthropic),
        "gemini" | "google" => Ok(AdapterKind::Gemini),
        "groq" => Ok(AdapterKind::Groq),
        "cohere" => Ok(AdapterKind::Cohere),
        "ollama" => Ok(AdapterKind::Ollama),
        "xai" => Ok(AdapterKind::Xai),
        "deepseek" => Ok(AdapterKind::DeepSeek),
        _ => Err(anyhow::anyhow!("Unsupported provider: {}", provider)),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_str_to_adapter_kind() {
        assert!(str_to_adapter_kind("openai").is_ok());
        assert!(str_to_adapter_kind("anthropic").is_ok());
        assert!(str_to_adapter_kind("gemini").is_ok());
        assert!(str_to_adapter_kind("google").is_ok());
        assert!(str_to_adapter_kind("groq").is_ok());
        assert!(str_to_adapter_kind("cohere").is_ok());
        assert!(str_to_adapter_kind("ollama").is_ok());
        assert!(str_to_adapter_kind("xai").is_ok());
        assert!(str_to_adapter_kind("deepseek").is_ok());
        assert!(str_to_adapter_kind("invalid").is_err());
    }

    #[tokio::test]
    async fn test_provider_creation() {
        let provider = GenAIProvider::new();
        assert!(provider.is_ok());
    }

    #[tokio::test]
    async fn test_provider_is_clonable() {
        let provider = GenAIProvider::new().unwrap();
        let _clone1 = provider.clone();
        let _clone2 = provider.clone();
    }
}