use anyhow::{Context, Result};
use futures::StreamExt;
use genai::adapter::AdapterKind;
use genai::chat::{ChatMessage, ChatOptions, ChatRequest, ChatStreamEvent};
use genai::Client;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::{mpsc, RwLock};

/// Sentinel sent over the streaming channel to mark end-of-transmission.
pub const EOT_SIGNAL: &str = "<|EOT|>";

/// Usage counters shared across all clones of [`GenAIProvider`].
#[derive(Default)]
struct SharedState {
    total_tokens_used: usize,
    request_count: usize,
}

/// Thin wrapper around [`genai::Client`] that adds retries, streaming, and
/// usage accounting. Cloning is cheap: clones share the same client and state.
#[derive(Clone)]
pub struct GenAIProvider {
    client: Arc<Client>,
    shared: Arc<RwLock<SharedState>>,
}

impl GenAIProvider {
    pub fn new() -> Result<Self> {
        let client = Client::default();
        Ok(Self {
            client: Arc::new(client),
            shared: Arc::new(RwLock::new(SharedState::default())),
        })
    }

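    /// Builds a provider after exporting the API key for the chosen backend
    /// via the matching environment variable. Minimal usage sketch (`ignore`d:
    /// needs a real key; the key and model name below are placeholders):
    ///
    /// ```ignore
    /// let provider = GenAIProvider::new_with_config(Some("openai"), Some("sk-..."))?;
    /// let reply = provider.generate_response_simple("gpt-4o-mini", "ping").await?;
    /// ```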
    pub fn new_with_config(provider_type: Option<&str>, api_key: Option<&str>) -> Result<Self> {
        if let (Some(provider), Some(key)) = (provider_type, api_key) {
            let env_var = match provider {
                "openai" => "OPENAI_API_KEY",
                "anthropic" => "ANTHROPIC_API_KEY",
                "gemini" => "GEMINI_API_KEY",
                "groq" => "GROQ_API_KEY",
                "cohere" => "COHERE_API_KEY",
                "xai" => "XAI_API_KEY",
                "deepseek" => "DEEPSEEK_API_KEY",
                "ollama" => {
                    log::info!("Ollama provider detected - no API key required for local setup");
                    return Self::new();
                }
                _ => {
                    log::warn!("Unknown provider type for API key: {provider}");
                    return Self::new();
                }
            };

            log::info!("Setting {env_var} environment variable for genai client");
            std::env::set_var(env_var, key);
        }

        Self::new()
    }

    /// Returns the running estimate of tokens consumed across all requests.
    pub async fn get_total_tokens_used(&self) -> usize {
        self.shared.read().await.total_tokens_used
    }

    pub async fn get_request_count(&self) -> usize {
        self.shared.read().await.request_count
    }

    async fn increment_request(&self) {
        let mut state = self.shared.write().await;
        state.request_count += 1;
    }

    /// Adds to the shared token counter (for callers that estimate usage).
    pub async fn add_tokens(&self, count: usize) {
        let mut state = self.shared.write().await;
        state.total_tokens_used += count;
    }

    pub async fn get_available_models(&self, provider: &str) -> Result<Vec<String>> {
        let adapter_kind = str_to_adapter_kind(provider)?;

        let models = self
            .client
            .all_model_names(adapter_kind)
            .await
            .context(format!("Failed to get models for provider: {provider}"))?;

        Ok(models)
    }

    /// Convenience wrapper: single user prompt with up to 3 retries.
    pub async fn generate_response_simple(&self, model: &str, prompt: &str) -> Result<String> {
        self.generate_response_with_retry(model, prompt, 3).await
    }

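    /// Sends a single-prompt chat request, retrying transient failures with
    /// exponential backoff: with `max_retries = 3` the delays are 1s, 2s, 4s,
    /// doubling per attempt and capped at 16s. Usage sketch (`ignore`d:
    /// requires a live provider; the model name is a placeholder):
    ///
    /// ```ignore
    /// let text = provider
    ///     .generate_response_with_retry("gpt-4o-mini", "Summarize this.", 3)
    ///     .await?;
    /// ```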
    pub async fn generate_response_with_retry(
        &self,
        model: &str,
        prompt: &str,
        max_retries: usize,
    ) -> Result<String> {
        self.increment_request().await;

        let chat_req = ChatRequest::default().append_message(ChatMessage::user(prompt));

        log::debug!(
            "Sending chat request to model: {model} with prompt length: {} chars",
            prompt.len()
        );

        let start_time = Instant::now();
        let mut last_error: Option<anyhow::Error> = None;
        let mut retry_count = 0;

        while retry_count <= max_retries {
            if retry_count > 0 {
                // Exponential backoff, capped at 16 seconds.
                let delay_secs = std::cmp::min(1u64 << (retry_count - 1), 16);
                log::warn!(
                    "Retry {}/{} for model {} after {}s delay (previous error: {:?})",
                    retry_count,
                    max_retries,
                    model,
                    delay_secs,
                    last_error.as_ref().map(|e| e.to_string())
                );
                println!(
                    " ⏳ Rate limited, retrying in {}s (attempt {}/{})",
                    delay_secs, retry_count, max_retries
                );
                tokio::time::sleep(tokio::time::Duration::from_secs(delay_secs)).await;
            }

            match self.client.exec_chat(model, chat_req.clone(), None).await {
                Ok(chat_res) => {
                    let content = chat_res
                        .first_text()
                        .context("No text content in response")?;
                    log::debug!(
                        "Received response with {} characters in {}ms",
                        content.len(),
                        start_time.elapsed().as_millis()
                    );

                    return Ok(content.to_string());
                }
                Err(e) => {
                    let err_str = e.to_string();

                    // Heuristic: treat rate limits, 5xx responses, and network
                    // failures as retryable; everything else fails fast.
                    let is_retryable = err_str.contains("429")
                        || err_str.contains("rate limit")
                        || err_str.contains("Rate limit")
                        || err_str.contains("RESOURCE_EXHAUSTED")
                        || err_str.contains("500")
                        || err_str.contains("502")
                        || err_str.contains("503")
                        || err_str.contains("504")
                        || err_str.contains("timeout")
                        || err_str.contains("connection");

                    if is_retryable && retry_count < max_retries {
                        log::warn!("Retryable error for model {}: {}", model, err_str);
                        last_error = Some(anyhow::anyhow!("{}", err_str));
                        retry_count += 1;
                        continue;
                    } else {
                        return Err(anyhow::anyhow!(
                            "Failed to execute chat request for model {}: {}",
                            model,
                            err_str
                        ));
                    }
                }
            }
        }

        // Unreachable in practice (every iteration either returns or retries),
        // but required to satisfy the compiler.
        Err(last_error
            .unwrap_or_else(|| anyhow::anyhow!("Unknown error after {} retries", max_retries)))
    }

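    /// Streams response chunks into `tx`, then terminates the stream with
    /// [`EOT_SIGNAL`]. Consumer sketch (`ignore`d: needs a live provider; the
    /// model name is a placeholder):
    ///
    /// ```ignore
    /// let provider = GenAIProvider::new()?;
    /// let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
    /// let worker = provider.clone();
    /// tokio::spawn(async move {
    ///     worker
    ///         .generate_response_stream_to_channel("gpt-4o-mini", "Hi", tx)
    ///         .await
    /// });
    /// while let Some(chunk) = rx.recv().await {
    ///     if chunk == EOT_SIGNAL {
    ///         break; // EOT_SIGNAL marks the end of the stream
    ///     }
    ///     print!("{chunk}");
    /// }
    /// ```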
    pub async fn generate_response_stream_to_channel(
        &self,
        model: &str,
        prompt: &str,
        tx: mpsc::UnboundedSender<String>,
    ) -> Result<()> {
        self.increment_request().await;

        let chat_req = ChatRequest::default().append_message(ChatMessage::user(prompt));

        log::debug!("Sending streaming chat request to model: {model} with prompt: {prompt}");

        let chat_res_stream = self
            .client
            .exec_chat_stream(model, chat_req, None)
            .await
            .context(format!(
                "Failed to execute streaming chat request for model: {model}"
            ))?;

        let mut stream = chat_res_stream.stream;
        let mut chunk_count = 0;
        let mut total_content_length = 0;
        let mut stream_ended_explicitly = false;
        let start_time = Instant::now();

        log::info!(
            "=== STREAM START === Model: {}, Prompt length: {} chars",
            model,
            prompt.len()
        );

        while let Some(chunk_result) = stream.next().await {
            let elapsed = start_time.elapsed();

            match chunk_result {
                Ok(ChatStreamEvent::Start) => {
                    log::info!(">>> STREAM STARTED for model: {model} at {elapsed:?}");
                }
                Ok(ChatStreamEvent::Chunk(chunk)) => {
                    chunk_count += 1;
                    total_content_length += chunk.content.len();

                    // Log every 10th chunk (or any unusually large one) to
                    // keep the log volume manageable.
                    if chunk_count % 10 == 0 || chunk.content.len() > 100 {
                        log::info!(
                            "CHUNK #{}: {} chars, total: {} chars, elapsed: {:?}",
                            chunk_count,
                            chunk.content.len(),
                            total_content_length,
                            elapsed
                        );
                    }

                    if !chunk.content.is_empty() && tx.send(chunk.content.clone()).is_err() {
                        log::error!(
                            "!!! CHANNEL SEND FAILED for chunk #{chunk_count} - STOPPING STREAM !!!"
                        );
                        break;
                    }
                }
                Ok(ChatStreamEvent::ReasoningChunk(chunk)) => {
                    log::info!(
                        "REASONING CHUNK: {} chars at {:?}",
                        chunk.content.len(),
                        elapsed
                    );
                }
                Ok(ChatStreamEvent::End(_)) => {
                    log::info!(">>> STREAM ENDED EXPLICITLY for model: {model} after {chunk_count} chunks, {total_content_length} chars, {elapsed:?} elapsed");
                    stream_ended_explicitly = true;
                    break;
                }
                Ok(ChatStreamEvent::ToolCallChunk(_)) => {
                    log::debug!("Tool call chunk received (ignored)");
                }
                Ok(ChatStreamEvent::ThoughtSignatureChunk(_)) => {
                    log::debug!("Thought signature chunk received (ignored)");
                }
                Err(e) => {
                    log::error!(
                        "!!! STREAM ERROR after {chunk_count} chunks at {elapsed:?}: {e} !!!"
                    );
                    let error_msg = format!("Stream error: {e}");
                    let _ = tx.send(error_msg);
                    return Err(e.into());
                }
            }
        }

        let final_elapsed = start_time.elapsed();
        if !stream_ended_explicitly {
            log::warn!("!!! STREAM ENDED IMPLICITLY (exhausted) for model: {model} after {chunk_count} chunks, {total_content_length} chars, {final_elapsed:?} elapsed !!!");
        }

        log::info!(
            "=== STREAM COMPLETE === Model: {model}, Final: {chunk_count} chunks, {total_content_length} chars, {final_elapsed:?} elapsed"
        );

        // Rough token estimate: ~4 characters per token.
        self.add_tokens(total_content_length / 4).await;

        if tx.send(EOT_SIGNAL.to_string()).is_err() {
            log::error!("!!! FAILED TO SEND EOT SIGNAL - channel may be closed !!!");
            return Err(anyhow::anyhow!("Channel closed during EOT signal send"));
        }

        log::info!(">>> EOT SIGNAL SENT for model: {model} <<<");
        Ok(())
    }

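    /// Runs a chat request over a full conversation history. Multi-turn sketch
    /// (`ignore`d: needs a live provider; the message constructors are the
    /// standard `genai` ones, and the model name is a placeholder):
    ///
    /// ```ignore
    /// let messages = vec![
    ///     ChatMessage::system("You are terse."),
    ///     ChatMessage::user("What is Rust?"),
    /// ];
    /// let reply = provider
    ///     .generate_response_with_history("gpt-4o-mini", messages)
    ///     .await?;
    /// ```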
    pub async fn generate_response_with_history(
        &self,
        model: &str,
        messages: Vec<ChatMessage>,
    ) -> Result<String> {
        self.increment_request().await;

        let chat_req = ChatRequest::new(messages);

        log::debug!("Sending chat request to model: {model} with conversation history");

        let chat_res = self
            .client
            .exec_chat(model, chat_req, None)
            .await
            .context(format!("Failed to execute chat request for model: {model}"))?;

        let content = chat_res
            .first_text()
            .context("No text content in response")?;

        log::debug!("Received response with {} characters", content.len());
        Ok(content.to_string())
    }

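    /// Runs a single-prompt request with custom [`ChatOptions`]. Sketch
    /// (`ignore`d: assumes the `genai` builder-style setters
    /// `with_temperature` / `with_max_tokens`; the model name is a
    /// placeholder):
    ///
    /// ```ignore
    /// let options = ChatOptions::default()
    ///     .with_temperature(0.2)
    ///     .with_max_tokens(256);
    /// let reply = provider
    ///     .generate_response_with_options("gpt-4o-mini", "Be brief.", options)
    ///     .await?;
    /// ```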
    pub async fn generate_response_with_options(
        &self,
        model: &str,
        prompt: &str,
        options: ChatOptions,
    ) -> Result<String> {
        self.increment_request().await;

        let chat_req = ChatRequest::default().append_message(ChatMessage::user(prompt));

        log::debug!("Sending chat request to model: {model} with custom options");

        let chat_res = self
            .client
            .exec_chat(model, chat_req, Some(&options))
            .await
            .context(format!("Failed to execute chat request for model: {model}"))?;

        let content = chat_res
            .first_text()
            .context("No text content in response")?;

        log::debug!("Received response with {} characters", content.len());
        Ok(content.to_string())
    }

    pub fn get_supported_providers() -> Vec<&'static str> {
        vec![
            "openai",
            "anthropic",
            "gemini",
            "groq",
            "cohere",
            "ollama",
            "xai",
            "deepseek",
        ]
    }

    pub async fn get_available_providers(&self) -> Result<Vec<String>> {
        Ok(Self::get_supported_providers()
            .iter()
            .map(|s| s.to_string())
            .collect())
    }

    /// Probes a model with a trivial prompt; returns `Ok(false)` rather than
    /// an error when the model is unusable.
    pub async fn test_model(&self, model: &str) -> Result<bool> {
        match self.generate_response_simple(model, "Hello").await {
            Ok(_) => {
                log::info!("Model {model} is available and working");
                Ok(true)
            }
            Err(e) => {
                log::warn!("Model {model} test failed: {e}");
                Ok(false)
            }
        }
    }

    /// Validates a model name, falling back to the provider's first available
    /// model when the requested one fails the probe.
    pub async fn validate_model(&self, model: &str, provider_type: Option<&str>) -> Result<String> {
        if self.test_model(model).await? {
            return Ok(model.to_string());
        }

        if let Some(provider) = provider_type {
            if let Ok(models) = self.get_available_models(provider).await {
                if !models.is_empty() {
                    log::info!("Model {} not available, using {} instead", model, models[0]);
                    return Ok(models[0].clone());
                }
            }
        }

        log::warn!("Could not validate model {model}, proceeding anyway");
        Ok(model.to_string())
    }
}

/// Maps a lowercase provider name to its genai [`AdapterKind`].
fn str_to_adapter_kind(provider: &str) -> Result<AdapterKind> {
    match provider.to_lowercase().as_str() {
        "openai" => Ok(AdapterKind::OpenAI),
        "anthropic" => Ok(AdapterKind::Anthropic),
        "gemini" | "google" => Ok(AdapterKind::Gemini),
        "groq" => Ok(AdapterKind::Groq),
        "cohere" => Ok(AdapterKind::Cohere),
        "ollama" => Ok(AdapterKind::Ollama),
        "xai" => Ok(AdapterKind::Xai),
        "deepseek" => Ok(AdapterKind::DeepSeek),
        _ => Err(anyhow::anyhow!("Unsupported provider: {}", provider)),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_str_to_adapter_kind() {
        assert!(str_to_adapter_kind("openai").is_ok());
        assert!(str_to_adapter_kind("anthropic").is_ok());
        assert!(str_to_adapter_kind("gemini").is_ok());
        assert!(str_to_adapter_kind("google").is_ok());
        assert!(str_to_adapter_kind("groq").is_ok());
        assert!(str_to_adapter_kind("cohere").is_ok());
        assert!(str_to_adapter_kind("ollama").is_ok());
        assert!(str_to_adapter_kind("xai").is_ok());
        assert!(str_to_adapter_kind("deepseek").is_ok());
        assert!(str_to_adapter_kind("invalid").is_err());
    }

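    // Added check: every advertised provider name resolves to an AdapterKind,
    // so get_supported_providers() and str_to_adapter_kind() stay in sync.
    #[test]
    fn test_supported_providers_resolve() {
        for provider in GenAIProvider::get_supported_providers() {
            assert!(str_to_adapter_kind(provider).is_ok());
        }
    }
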
    #[tokio::test]
    async fn test_provider_creation() {
        let provider = GenAIProvider::new();
        assert!(provider.is_ok());
    }

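    // Added check: usage counters start at zero and accumulate via add_tokens.
    #[tokio::test]
    async fn test_token_accounting() {
        let provider = GenAIProvider::new().unwrap();
        assert_eq!(provider.get_total_tokens_used().await, 0);
        assert_eq!(provider.get_request_count().await, 0);
        provider.add_tokens(100).await;
        provider.add_tokens(50).await;
        assert_eq!(provider.get_total_tokens_used().await, 150);
    }
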
    #[tokio::test]
    async fn test_provider_is_clonable() {
        let provider = GenAIProvider::new().unwrap();
        let _clone1 = provider.clone();
        let _clone2 = provider.clone();
    }
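
    // Added check: clones share usage state through the same Arc<RwLock>,
    // so tokens recorded on a clone are visible from the original.
    #[tokio::test]
    async fn test_clones_share_state() {
        let provider = GenAIProvider::new().unwrap();
        let clone = provider.clone();
        clone.add_tokens(42).await;
        assert_eq!(provider.get_total_tokens_used().await, 42);
    }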
}