use anyhow::{Context, Result};
use futures::StreamExt;
use genai::adapter::AdapterKind;
use genai::chat::{ChatMessage, ChatOptions, ChatRequest, ChatStreamEvent};
use genai::Client;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::{mpsc, RwLock};

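/// End-of-transmission marker sent over the streaming channel after the last
/// content chunk, so receivers can tell a finished stream from a dropped one.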
pub const EOT_SIGNAL: &str = "<|EOT|>";

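/// A completed (non-streaming) model response, with token usage when the
/// backend reports it.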
#[derive(Debug, Clone)]
pub struct LlmResponse {
    pub text: String,
    pub tokens_in: Option<i32>,
    pub tokens_out: Option<i32>,
}

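/// Usage counters shared by every clone of [`GenAIProvider`].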
#[derive(Default)]
struct SharedState {
    total_tokens_used: usize,
    request_count: usize,
}

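/// Thin wrapper around the `genai` [`Client`] that adds request counting and a
/// running token total. Cloning is cheap: clones share the same client and
/// counters through [`Arc`]s.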
#[derive(Clone)]
pub struct GenAIProvider {
    client: Arc<Client>,
    shared: Arc<RwLock<SharedState>>,
}

impl GenAIProvider {
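    /// Creates a provider with a default `genai` client and fresh usage counters.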
    pub fn new() -> Result<Self> {
        let client = Client::default();
        Ok(Self {
            client: Arc::new(client),
            shared: Arc::new(RwLock::new(SharedState::default())),
        })
    }

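    /// Creates a provider and, when both a provider type and an API key are given,
    /// exports the matching `*_API_KEY` environment variable so the `genai` client
    /// can pick it up. Ollama and unrecognised providers fall back to [`Self::new`].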
    pub fn new_with_config(provider_type: Option<&str>, api_key: Option<&str>) -> Result<Self> {
        if let (Some(provider), Some(key)) = (provider_type, api_key) {
            let env_var = match provider {
                "openai" => "OPENAI_API_KEY",
                "anthropic" => "ANTHROPIC_API_KEY",
                "gemini" => "GEMINI_API_KEY",
                "groq" => "GROQ_API_KEY",
                "cohere" => "COHERE_API_KEY",
                "xai" => "XAI_API_KEY",
                "deepseek" => "DEEPSEEK_API_KEY",
                "ollama" => {
                    log::info!("Ollama provider detected - no API key required for local setup");
                    return Self::new();
                }
                _ => {
                    log::warn!("Unknown provider type for API key: {provider}");
                    return Self::new();
                }
            };

            log::info!("Setting {env_var} environment variable for genai client");
            std::env::set_var(env_var, key);
        }

        Self::new()
    }

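    /// Returns the running total of tokens recorded across all requests.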
    pub async fn get_total_tokens_used(&self) -> usize {
        self.shared.read().await.total_tokens_used
    }

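    /// Returns how many chat requests this provider (and its clones) have issued.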
    pub async fn get_request_count(&self) -> usize {
        self.shared.read().await.request_count
    }

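    /// Bumps the shared request counter; called at the start of each request.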
    async fn increment_request(&self) {
        let mut state = self.shared.write().await;
        state.request_count += 1;
    }

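    /// Adds `count` tokens to the shared running total.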
    pub async fn add_tokens(&self, count: usize) {
        let mut state = self.shared.write().await;
        state.total_tokens_used += count;
    }

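    /// Lists the model names that `genai` reports for the given provider.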
    pub async fn get_available_models(&self, provider: &str) -> Result<Vec<String>> {
        let adapter_kind = str_to_adapter_kind(provider)?;

        let models = self
            .client
            .all_model_names(adapter_kind)
            .await
            .context(format!("Failed to get models for provider: {provider}"))?;

        Ok(models)
    }

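    /// Convenience wrapper around [`Self::generate_response_with_retry`] with up to
    /// three retries.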
    pub async fn generate_response_simple(&self, model: &str, prompt: &str) -> Result<LlmResponse> {
        self.generate_response_with_retry(model, prompt, 3).await
    }

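    /// Sends a single-prompt chat request, retrying transient failures (rate limits,
    /// 5xx responses, timeouts) with exponential backoff capped at 16 seconds.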
    pub async fn generate_response_with_retry(
        &self,
        model: &str,
        prompt: &str,
        max_retries: usize,
    ) -> Result<LlmResponse> {
        self.increment_request().await;

        let chat_req = ChatRequest::default().append_message(ChatMessage::user(prompt));

        log::debug!(
            "Sending chat request to model: {model} with prompt length: {} chars",
            prompt.len()
        );

        let start_time = Instant::now();
        let mut last_error: Option<anyhow::Error> = None;
        let mut retry_count = 0;

        while retry_count <= max_retries {
            if retry_count > 0 {
                let delay_secs = std::cmp::min(1u64 << (retry_count - 1), 16);
                log::warn!(
                    "Retry {}/{} for model {} after {}s delay (previous error: {:?})",
                    retry_count,
                    max_retries,
                    model,
                    delay_secs,
                    last_error.as_ref().map(|e| e.to_string())
                );
                println!(
                    " ⏳ Rate limited, retrying in {}s (attempt {}/{})",
                    delay_secs, retry_count, max_retries
                );
                tokio::time::sleep(tokio::time::Duration::from_secs(delay_secs)).await;
            }

            match self.client.exec_chat(model, chat_req.clone(), None).await {
                Ok(chat_res) => {
                    let tokens_in = chat_res.usage.prompt_tokens;
                    let tokens_out = chat_res.usage.completion_tokens;
                    let content = chat_res
                        .first_text()
                        .context("No text content in response")?;
                    log::debug!(
                        "Received response with {} characters in {}ms (tokens: in={:?}, out={:?})",
                        content.len(),
                        start_time.elapsed().as_millis(),
                        tokens_in,
                        tokens_out,
                    );

                    let total = tokens_in.unwrap_or(0) + tokens_out.unwrap_or(0);
                    if total > 0 {
                        self.add_tokens(total as usize).await;
                    }

                    return Ok(LlmResponse {
                        text: content.to_string(),
                        tokens_in,
                        tokens_out,
                    });
                }
                Err(e) => {
                    let err_str = e.to_string();

                    let is_retryable = err_str.contains("429")
                        || err_str.contains("rate limit")
                        || err_str.contains("Rate limit")
                        || err_str.contains("RESOURCE_EXHAUSTED")
                        || err_str.contains("500")
                        || err_str.contains("502")
                        || err_str.contains("503")
                        || err_str.contains("504")
                        || err_str.contains("timeout")
                        || err_str.contains("connection");

                    if is_retryable && retry_count < max_retries {
                        log::warn!("Retryable error for model {}: {}", model, err_str);
                        last_error = Some(anyhow::anyhow!("{}", err_str));
                        retry_count += 1;
                        continue;
                    } else {
                        return Err(anyhow::anyhow!(
                            "Failed to execute chat request for model {}: {}",
                            model,
                            err_str
                        ));
                    }
                }
            }
        }

        Err(last_error
            .unwrap_or_else(|| anyhow::anyhow!("Unknown error after {} retries", max_retries)))
    }

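    /// Streams a completion chunk by chunk into `tx`, then sends [`EOT_SIGNAL`] as a
    /// terminator. Token usage is estimated from the streamed character count.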
    pub async fn generate_response_stream_to_channel(
        &self,
        model: &str,
        prompt: &str,
        tx: mpsc::UnboundedSender<String>,
    ) -> Result<()> {
        self.increment_request().await;

        let chat_req = ChatRequest::default().append_message(ChatMessage::user(prompt));

        log::debug!("Sending streaming chat request to model: {model} with prompt: {prompt}");

        let chat_res_stream = self
            .client
            .exec_chat_stream(model, chat_req, None)
            .await
            .context(format!(
                "Failed to execute streaming chat request for model: {model}"
            ))?;

        let mut stream = chat_res_stream.stream;
        let mut chunk_count = 0;
        let mut total_content_length = 0;
        let mut stream_ended_explicitly = false;
        let start_time = Instant::now();

        log::info!(
            "=== STREAM START === Model: {}, Prompt length: {} chars",
            model,
            prompt.len()
        );

        while let Some(chunk_result) = stream.next().await {
            let elapsed = start_time.elapsed();

            match chunk_result {
                Ok(ChatStreamEvent::Start) => {
                    log::info!(">>> STREAM STARTED for model: {model} at {elapsed:?}");
                }
                Ok(ChatStreamEvent::Chunk(chunk)) => {
                    chunk_count += 1;
                    total_content_length += chunk.content.len();

                    if chunk_count % 10 == 0 || chunk.content.len() > 100 {
                        log::info!(
                            "CHUNK #{}: {} chars, total: {} chars, elapsed: {:?}",
                            chunk_count,
                            chunk.content.len(),
                            total_content_length,
                            elapsed
                        );
                    }

                    if !chunk.content.is_empty() && tx.send(chunk.content.clone()).is_err() {
                        log::error!(
                            "!!! CHANNEL SEND FAILED for chunk #{chunk_count} - STOPPING STREAM !!!"
                        );
                        break;
                    }
                }
                Ok(ChatStreamEvent::ReasoningChunk(chunk)) => {
                    log::info!(
                        "REASONING CHUNK: {} chars at {:?}",
                        chunk.content.len(),
                        elapsed
                    );
                }
                Ok(ChatStreamEvent::End(_)) => {
                    log::info!(">>> STREAM ENDED EXPLICITLY for model: {model} after {chunk_count} chunks, {total_content_length} chars, {elapsed:?} elapsed");
                    stream_ended_explicitly = true;
                    break;
                }
                Ok(ChatStreamEvent::ToolCallChunk(_)) => {
                    log::debug!("Tool call chunk received (ignored)");
                }
                Ok(ChatStreamEvent::ThoughtSignatureChunk(_)) => {
                    log::debug!("Thought signature chunk received (ignored)");
                }
                Err(e) => {
                    log::error!(
                        "!!! STREAM ERROR after {chunk_count} chunks at {elapsed:?}: {e} !!!"
                    );
                    let error_msg = format!("Stream error: {e}");
                    let _ = tx.send(error_msg);
                    return Err(e.into());
                }
            }
        }

        let final_elapsed = start_time.elapsed();
        if !stream_ended_explicitly {
            log::warn!("!!! STREAM ENDED IMPLICITLY (exhausted) for model: {model} after {chunk_count} chunks, {total_content_length} chars, {final_elapsed:?} elapsed !!!");
        }

        log::info!(
            "=== STREAM COMPLETE === Model: {model}, Final: {chunk_count} chunks, {total_content_length} chars, {final_elapsed:?} elapsed"
        );
        // Rough token estimate: assume ~4 characters per token for accounting purposes.
        self.add_tokens(total_content_length / 4).await;

        if tx.send(EOT_SIGNAL.to_string()).is_err() {
            log::error!("!!! FAILED TO SEND EOT SIGNAL - channel may be closed !!!");
            return Err(anyhow::anyhow!("Channel closed during EOT signal send"));
        }

        log::info!(">>> EOT SIGNAL SENT for model: {model} <<<");
        Ok(())
    }

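    /// Sends a full conversation history and returns the first text of the reply.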
    pub async fn generate_response_with_history(
        &self,
        model: &str,
        messages: Vec<ChatMessage>,
    ) -> Result<String> {
        self.increment_request().await;

        let chat_req = ChatRequest::new(messages);

        log::debug!("Sending chat request to model: {model} with conversation history");

        let chat_res = self
            .client
            .exec_chat(model, chat_req, None)
            .await
            .context(format!("Failed to execute chat request for model: {model}"))?;

        let content = chat_res
            .first_text()
            .context("No text content in response")?;

        log::debug!("Received response with {} characters", content.len());
        Ok(content.to_string())
    }

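    /// Sends a single prompt using the supplied [`ChatOptions`] and returns the first
    /// text of the reply.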
    pub async fn generate_response_with_options(
        &self,
        model: &str,
        prompt: &str,
        options: ChatOptions,
    ) -> Result<String> {
        self.increment_request().await;

        let chat_req = ChatRequest::default().append_message(ChatMessage::user(prompt));

        log::debug!("Sending chat request to model: {model} with custom options");

        let chat_res = self
            .client
            .exec_chat(model, chat_req, Some(&options))
            .await
            .context(format!("Failed to execute chat request for model: {model}"))?;

        let content = chat_res
            .first_text()
            .context("No text content in response")?;

        log::debug!("Received response with {} characters", content.len());
        Ok(content.to_string())
    }

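    /// Static list of provider identifiers this wrapper can map to `genai` adapters.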
    pub fn get_supported_providers() -> Vec<&'static str> {
        vec![
            "openai",
            "anthropic",
            "gemini",
            "groq",
            "cohere",
            "ollama",
            "xai",
            "deepseek",
        ]
    }

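    /// Same list as [`Self::get_supported_providers`], returned as owned `String`s.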
    pub async fn get_available_providers(&self) -> Result<Vec<String>> {
        Ok(Self::get_supported_providers()
            .iter()
            .map(|s| s.to_string())
            .collect())
    }

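    /// Probes a model with a short "Hello" prompt; returns `Ok(false)` rather than an
    /// error when the model does not respond.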
    pub async fn test_model(&self, model: &str) -> Result<bool> {
        match self.generate_response_simple(model, "Hello").await {
            Ok(_) => {
                log::info!("Model {model} is available and working");
                Ok(true)
            }
            Err(e) => {
                log::warn!("Model {model} test failed: {e}");
                Ok(false)
            }
        }
    }

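    /// Checks that `model` responds; if not, falls back to the first model the
    /// provider advertises, otherwise returns the requested name unchanged.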
    pub async fn validate_model(&self, model: &str, provider_type: Option<&str>) -> Result<String> {
        if self.test_model(model).await? {
            return Ok(model.to_string());
        }

        if let Some(provider) = provider_type {
            if let Ok(models) = self.get_available_models(provider).await {
                if !models.is_empty() {
                    log::info!("Model {} not available, using {} instead", model, models[0]);
                    return Ok(models[0].clone());
                }
            }
        }

        log::warn!("Could not validate model {model}, proceeding anyway");
        Ok(model.to_string())
    }
}

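/// Maps a provider identifier (case-insensitive) to the corresponding `genai`
/// [`AdapterKind`].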
fn str_to_adapter_kind(provider: &str) -> Result<AdapterKind> {
    match provider.to_lowercase().as_str() {
        "openai" => Ok(AdapterKind::OpenAI),
        "anthropic" => Ok(AdapterKind::Anthropic),
        "gemini" | "google" => Ok(AdapterKind::Gemini),
        "groq" => Ok(AdapterKind::Groq),
        "cohere" => Ok(AdapterKind::Cohere),
        "ollama" => Ok(AdapterKind::Ollama),
        "xai" => Ok(AdapterKind::Xai),
        "deepseek" => Ok(AdapterKind::DeepSeek),
        _ => Err(anyhow::anyhow!("Unsupported provider: {}", provider)),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_str_to_adapter_kind() {
        assert!(str_to_adapter_kind("openai").is_ok());
        assert!(str_to_adapter_kind("anthropic").is_ok());
        assert!(str_to_adapter_kind("gemini").is_ok());
        assert!(str_to_adapter_kind("google").is_ok());
        assert!(str_to_adapter_kind("groq").is_ok());
        assert!(str_to_adapter_kind("cohere").is_ok());
        assert!(str_to_adapter_kind("ollama").is_ok());
        assert!(str_to_adapter_kind("xai").is_ok());
        assert!(str_to_adapter_kind("deepseek").is_ok());
        assert!(str_to_adapter_kind("invalid").is_err());
    }

    #[tokio::test]
    async fn test_provider_creation() {
        let provider = GenAIProvider::new();
        assert!(provider.is_ok());
    }

    #[tokio::test]
    async fn test_provider_is_clonable() {
        let provider = GenAIProvider::new().unwrap();
        let _clone1 = provider.clone();
        let _clone2 = provider.clone();
    }
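
    // Illustrative sketch: exercises the shared token/request accounting without any
    // network calls. The counters live behind an `Arc`, so clones see the same totals.
    #[tokio::test]
    async fn test_token_accounting_shared_across_clones() {
        let provider = GenAIProvider::new().unwrap();
        assert_eq!(provider.get_total_tokens_used().await, 0);
        assert_eq!(provider.get_request_count().await, 0);

        provider.add_tokens(42).await;
        let clone = provider.clone();
        clone.add_tokens(8).await;

        assert_eq!(provider.get_total_tokens_used().await, 50);
        assert_eq!(clone.get_total_tokens_used().await, 50);
    }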
}