ricecoder_providers/providers/
zen.rs

//! Zen provider implementation
//!
//! Supports OpenCode's curated set of AI models via the Zen API.
4
5use async_trait::async_trait;
6use reqwest::Client;
7use serde::{Deserialize, Serialize};
8use std::sync::Arc;
9use std::time::{Duration, SystemTime};
10use tracing::{debug, error, warn};
11
12use crate::error::ProviderError;
13use crate::models::{Capability, ChatRequest, ChatResponse, FinishReason, ModelInfo, TokenUsage};
14use crate::provider::Provider;
15use crate::token_counter::TokenCounter;
16
17/// Zen provider implementation
18pub struct ZenProvider {
19    api_key: String,
20    client: Arc<Client>,
21    base_url: String,
22    token_counter: Arc<TokenCounter>,
23    models_cache: Arc<tokio::sync::Mutex<ModelCache>>,
24    health_check_cache: Arc<tokio::sync::Mutex<HealthCheckCache>>,
25}
26
27/// Cache for models with TTL
28struct ModelCache {
29    models: Option<Vec<ModelInfo>>,
30    cached_at: Option<SystemTime>,
31    ttl: Duration,
32}
33
/// Time-limited cache holding the most recent health check outcome.
struct HealthCheckCache {
    /// Cached outcome, or `None` when nothing has been stored.
    result: Option<bool>,
    /// Wall-clock time at which `result` was last stored.
    cached_at: Option<SystemTime>,
    /// Maximum age at which the cached outcome is still served.
    ttl: Duration,
}
40
41impl ModelCache {
42    fn new() -> Self {
43        Self {
44            models: None,
45            cached_at: None,
46            ttl: Duration::from_secs(300), // 5 minutes
47        }
48    }
49
50    fn is_valid(&self) -> bool {
51        if let (Some(cached_at), Some(_)) = (self.cached_at, &self.models) {
52            if let Ok(elapsed) = cached_at.elapsed() {
53                return elapsed < self.ttl;
54            }
55        }
56        false
57    }
58
59    fn get(&self) -> Option<Vec<ModelInfo>> {
60        if self.is_valid() {
61            self.models.clone()
62        } else {
63            None
64        }
65    }
66
67    fn set(&mut self, models: Vec<ModelInfo>) {
68        self.models = Some(models);
69        self.cached_at = Some(SystemTime::now());
70    }
71
72    #[allow(dead_code)]
73    fn invalidate(&mut self) {
74        self.models = None;
75        self.cached_at = None;
76    }
77}
78
79impl HealthCheckCache {
80    fn new() -> Self {
81        Self {
82            result: None,
83            cached_at: None,
84            ttl: Duration::from_secs(60), // 1 minute
85        }
86    }
87
88    fn is_valid(&self) -> bool {
89        if let (Some(cached_at), Some(_)) = (self.cached_at, self.result) {
90            if let Ok(elapsed) = cached_at.elapsed() {
91                return elapsed < self.ttl;
92            }
93        }
94        false
95    }
96
97    fn get(&self) -> Option<bool> {
98        if self.is_valid() {
99            self.result
100        } else {
101            None
102        }
103    }
104
105    fn set(&mut self, result: bool) {
106        self.result = Some(result);
107        self.cached_at = Some(SystemTime::now());
108    }
109
110    #[allow(dead_code)]
111    fn invalidate(&mut self) {
112        self.result = None;
113        self.cached_at = None;
114    }
115}
116
117impl ZenProvider {
118    /// Create a new Zen provider instance
119    pub fn new(api_key: String) -> Result<Self, ProviderError> {
120        if api_key.is_empty() {
121            return Err(ProviderError::ConfigError(
122                "Zen API key is required".to_string(),
123            ));
124        }
125
126        Ok(Self {
127            api_key,
128            client: Arc::new(Client::new()),
129            base_url: "https://api.opencode.ai/v1".to_string(),
130            token_counter: Arc::new(TokenCounter::new()),
131            models_cache: Arc::new(tokio::sync::Mutex::new(ModelCache::new())),
132            health_check_cache: Arc::new(tokio::sync::Mutex::new(HealthCheckCache::new())),
133        })
134    }
135
136    /// Create a new Zen provider with a custom base URL
137    pub fn with_base_url(api_key: String, base_url: String) -> Result<Self, ProviderError> {
138        if api_key.is_empty() {
139            return Err(ProviderError::ConfigError(
140                "Zen API key is required".to_string(),
141            ));
142        }
143
144        Ok(Self {
145            api_key,
146            client: Arc::new(Client::new()),
147            base_url,
148            token_counter: Arc::new(TokenCounter::new()),
149            models_cache: Arc::new(tokio::sync::Mutex::new(ModelCache::new())),
150            health_check_cache: Arc::new(tokio::sync::Mutex::new(HealthCheckCache::new())),
151        })
152    }
153
154    /// Get the authorization header value (redacted for logging)
155    fn get_auth_header(&self) -> String {
156        format!("Bearer {}", self.api_key)
157    }
158
159    /// Convert Zen API response to our ChatResponse
160    fn convert_response(
161        response: ZenChatResponse,
162        model: String,
163    ) -> Result<ChatResponse, ProviderError> {
164        let content = response
165            .choices
166            .first()
167            .and_then(|c| c.message.as_ref())
168            .map(|m| m.content.clone())
169            .ok_or_else(|| ProviderError::ProviderError("No content in response".to_string()))?;
170
171        let finish_reason = match response
172            .choices
173            .first()
174            .and_then(|c| c.finish_reason.as_deref())
175        {
176            Some("stop") => FinishReason::Stop,
177            Some("length") => FinishReason::Length,
178            Some("error") => FinishReason::Error,
179            _ => FinishReason::Stop,
180        };
181
182        Ok(ChatResponse {
183            content,
184            model,
185            usage: TokenUsage {
186                prompt_tokens: response.usage.prompt_tokens,
187                completion_tokens: response.usage.completion_tokens,
188                total_tokens: response.usage.total_tokens,
189            },
190            finish_reason,
191        })
192    }
193
194    /// Fetch models from Zen API with retry logic
195    async fn fetch_models_from_api(&self) -> Result<Vec<ModelInfo>, ProviderError> {
196        let mut retries = 0;
197        let max_retries = 3;
198
199        loop {
200            debug!("Fetching models from Zen API (attempt {})", retries + 1);
201
202            let response = self
203                .client
204                .get(format!("{}/models", self.base_url))
205                .header("Authorization", self.get_auth_header())
206                .timeout(Duration::from_secs(30))
207                .send()
208                .await;
209
210            match response {
211                Ok(resp) => {
212                    let status = resp.status();
213                    if !status.is_success() {
214                        let error_text = resp.text().await.unwrap_or_default();
215                        error!("Zen API error ({}): {}", status, error_text);
216
217                        return match status.as_u16() {
218                            401 => Err(ProviderError::AuthError),
219                            429 => {
220                                if retries < max_retries {
221                                    let backoff = Duration::from_secs(2_u64.pow(retries as u32));
222                                    warn!("Rate limited, retrying after {:?}", backoff);
223                                    tokio::time::sleep(backoff).await;
224                                    retries += 1;
225                                    continue;
226                                }
227                                Err(ProviderError::RateLimited(60))
228                            }
229                            _ => Err(ProviderError::ProviderError(format!(
230                                "Zen API error: {}",
231                                status
232                            ))),
233                        };
234                    }
235
236                    let zen_response: ZenListModelsResponse = resp.json().await?;
237                    return Ok(zen_response
238                        .models
239                        .into_iter()
240                        .map(|m| ModelInfo {
241                            id: m.id,
242                            name: m.name,
243                            provider: "zen".to_string(),
244                            context_window: m.context_window,
245                            capabilities: m.capabilities,
246                            pricing: m.pricing.map(|p| crate::models::Pricing {
247                                input_per_1k_tokens: p.input_cost_per_1k,
248                                output_per_1k_tokens: p.output_cost_per_1k,
249                            }),
250                        })
251                        .collect());
252                }
253                Err(e) => {
254                    error!("Zen API request failed: {}", e);
255
256                    if retries < max_retries {
257                        let backoff = Duration::from_secs(2_u64.pow(retries as u32));
258                        warn!("Request failed, retrying after {:?}", backoff);
259                        tokio::time::sleep(backoff).await;
260                        retries += 1;
261                        continue;
262                    }
263
264                    return Err(ProviderError::from(e));
265                }
266            }
267        }
268    }
269
270    /// Count tokens using Zen API (with fallback to local approximation)
271    #[allow(dead_code)]
272    async fn count_tokens_from_api(
273        &self,
274        content: &str,
275        model: &str,
276    ) -> Result<usize, ProviderError> {
277        debug!(
278            "Counting tokens for model: {} (content length: {})",
279            model,
280            content.len()
281        );
282
283        let request = ZenTokenCountRequest {
284            model: model.to_string(),
285            content: content.to_string(),
286        };
287
288        let mut retries = 0;
289        let max_retries = 3;
290
291        loop {
292            let response = self
293                .client
294                .post(format!("{}/token/count", self.base_url))
295                .header("Authorization", self.get_auth_header())
296                .header("Content-Type", "application/json")
297                .json(&request)
298                .timeout(Duration::from_secs(30))
299                .send()
300                .await;
301
302            match response {
303                Ok(resp) => {
304                    let status = resp.status();
305                    if !status.is_success() {
306                        let error_text = resp.text().await.unwrap_or_default();
307                        error!("Zen token count API error ({}): {}", status, error_text);
308
309                        return match status.as_u16() {
310                            401 => Err(ProviderError::AuthError),
311                            429 => {
312                                if retries < max_retries {
313                                    let backoff = Duration::from_secs(2_u64.pow(retries as u32));
314                                    warn!("Rate limited, retrying after {:?}", backoff);
315                                    tokio::time::sleep(backoff).await;
316                                    retries += 1;
317                                    continue;
318                                }
319                                Err(ProviderError::RateLimited(60))
320                            }
321                            _ => {
322                                warn!("Token count API failed, using fallback approximation");
323                                return Ok(self.estimate_tokens(content));
324                            }
325                        };
326                    }
327
328                    let zen_response: ZenTokenCountResponse = resp.json().await?;
329                    debug!("Token count from API: {}", zen_response.token_count);
330                    return Ok(zen_response.token_count);
331                }
332                Err(e) => {
333                    error!("Zen token count API request failed: {}", e);
334
335                    if retries < max_retries {
336                        let backoff = Duration::from_secs(2_u64.pow(retries as u32));
337                        warn!("Request failed, retrying after {:?}", backoff);
338                        tokio::time::sleep(backoff).await;
339                        retries += 1;
340                        continue;
341                    }
342
343                    warn!("Token count API unavailable, using fallback approximation");
344                    return Ok(self.estimate_tokens(content));
345                }
346            }
347        }
348    }
349
350    /// Estimate tokens using simple approximation (fallback when API is unavailable)
351    #[allow(dead_code)]
352    fn estimate_tokens(&self, content: &str) -> usize {
353        // Approximation: 4 characters ≈ 1 token
354        content.len().div_ceil(4)
355    }
356}
357
358#[async_trait]
359impl Provider for ZenProvider {
360    fn id(&self) -> &str {
361        "zen"
362    }
363
364    fn name(&self) -> &str {
365        "OpenCode Zen"
366    }
367
368    fn models(&self) -> Vec<ModelInfo> {
369        // This is a blocking call, so we return a default set
370        // The async version is used in chat() and other async methods
371        vec![
372            ModelInfo {
373                id: "zen-gpt4".to_string(),
374                name: "Zen GPT-4".to_string(),
375                provider: "zen".to_string(),
376                context_window: 8192,
377                capabilities: vec![Capability::Chat, Capability::Code, Capability::Streaming],
378                pricing: Some(crate::models::Pricing {
379                    input_per_1k_tokens: 0.03,
380                    output_per_1k_tokens: 0.06,
381                }),
382            },
383            ModelInfo {
384                id: "zen-gpt4-turbo".to_string(),
385                name: "Zen GPT-4 Turbo".to_string(),
386                provider: "zen".to_string(),
387                context_window: 128000,
388                capabilities: vec![
389                    Capability::Chat,
390                    Capability::Code,
391                    Capability::Vision,
392                    Capability::Streaming,
393                ],
394                pricing: Some(crate::models::Pricing {
395                    input_per_1k_tokens: 0.01,
396                    output_per_1k_tokens: 0.03,
397                }),
398            },
399        ]
400    }
401
402    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, ProviderError> {
403        // Get models from cache or fetch from API
404        let models = {
405            let mut cache = self.models_cache.lock().await;
406            if let Some(models) = cache.get() {
407                debug!("Using cached models");
408                models
409            } else {
410                debug!("Cache miss, fetching models from API");
411                let models = self.fetch_models_from_api().await?;
412                cache.set(models.clone());
413                models
414            }
415        };
416
417        // Validate model
418        let model_id = &request.model;
419        if !models.iter().any(|m| m.id == *model_id) {
420            return Err(ProviderError::InvalidModel(model_id.clone()));
421        }
422
423        let zen_request = ZenChatRequest {
424            model: request.model.clone(),
425            messages: request
426                .messages
427                .iter()
428                .map(|m| ZenMessage {
429                    role: m.role.clone(),
430                    content: m.content.clone(),
431                })
432                .collect(),
433            temperature: request.temperature,
434            max_tokens: request.max_tokens,
435            stream: false,
436        };
437
438        debug!("Sending chat request to Zen for model: {}", request.model);
439
440        let mut retries = 0;
441        let max_retries = 3;
442
443        loop {
444            let response = self
445                .client
446                .post(format!("{}/chat/completions", self.base_url))
447                .header("Authorization", self.get_auth_header())
448                .header("Content-Type", "application/json")
449                .json(&zen_request)
450                .timeout(Duration::from_secs(30))
451                .send()
452                .await;
453
454            match response {
455                Ok(resp) => {
456                    let status = resp.status();
457                    if !status.is_success() {
458                        let error_text = resp.text().await.unwrap_or_default();
459                        error!("Zen API error ({}): {}", status, error_text);
460
461                        return match status.as_u16() {
462                            401 => Err(ProviderError::AuthError),
463                            429 => {
464                                if retries < max_retries {
465                                    let backoff = Duration::from_secs(2_u64.pow(retries as u32));
466                                    warn!("Rate limited, retrying after {:?}", backoff);
467                                    tokio::time::sleep(backoff).await;
468                                    retries += 1;
469                                    continue;
470                                }
471                                Err(ProviderError::RateLimited(60))
472                            }
473                            _ => Err(ProviderError::ProviderError(format!(
474                                "Zen API error: {}",
475                                status
476                            ))),
477                        };
478                    }
479
480                    let zen_response: ZenChatResponse = resp.json().await?;
481                    return Self::convert_response(zen_response, request.model);
482                }
483                Err(e) => {
484                    error!("Zen API request failed: {}", e);
485
486                    if retries < max_retries {
487                        let backoff = Duration::from_secs(2_u64.pow(retries as u32));
488                        warn!("Request failed, retrying after {:?}", backoff);
489                        tokio::time::sleep(backoff).await;
490                        retries += 1;
491                        continue;
492                    }
493
494                    return Err(ProviderError::from(e));
495                }
496            }
497        }
498    }
499
500    async fn chat_stream(
501        &self,
502        _request: ChatRequest,
503    ) -> Result<crate::provider::ChatStream, ProviderError> {
504        // Streaming support will be implemented in a future iteration
505        Err(ProviderError::ProviderError(
506            "Streaming not yet implemented for Zen".to_string(),
507        ))
508    }
509
510    fn count_tokens(&self, content: &str, model: &str) -> Result<usize, ProviderError> {
511        // Validate model
512        if !self.models().iter().any(|m| m.id == model) {
513            return Err(ProviderError::InvalidModel(model.to_string()));
514        }
515
516        // Use token counter with caching for performance
517        // In production, this would call the Zen API token counting endpoint
518        // For now, we use a local approximation as fallback
519        let tokens = self.token_counter.count_tokens_openai(content, model);
520        Ok(tokens)
521    }
522
523    async fn health_check(&self) -> Result<bool, ProviderError> {
524        debug!("Performing health check for Zen provider");
525
526        // Check cache first
527        {
528            let cache = self.health_check_cache.lock().await;
529            if let Some(result) = cache.get() {
530                debug!("Using cached health check result: {}", result);
531                return Ok(result);
532            }
533        }
534
535        // Try to list models as a health check
536        let response = self
537            .client
538            .get(format!("{}/models", self.base_url))
539            .header("Authorization", self.get_auth_header())
540            .timeout(Duration::from_secs(5))
541            .send()
542            .await;
543
544        let result = match response {
545            Ok(resp) => match resp.status().as_u16() {
546                200 => {
547                    debug!("Zen health check passed");
548                    true
549                }
550                401 => {
551                    error!("Zen health check failed: authentication error");
552                    return Err(ProviderError::AuthError);
553                }
554                _ => {
555                    warn!("Zen health check failed with status: {}", resp.status());
556                    false
557                }
558            },
559            Err(e) => {
560                warn!("Zen health check failed: {}", e);
561                false
562            }
563        };
564
565        // Cache the result
566        {
567            let mut cache = self.health_check_cache.lock().await;
568            cache.set(result);
569        }
570
571        Ok(result)
572    }
573}
574
575/// Zen API request format
576#[derive(Debug, Serialize)]
577struct ZenChatRequest {
578    model: String,
579    messages: Vec<ZenMessage>,
580    #[serde(skip_serializing_if = "Option::is_none")]
581    temperature: Option<f32>,
582    #[serde(skip_serializing_if = "Option::is_none")]
583    max_tokens: Option<usize>,
584    stream: bool,
585}
586
587/// Zen API message format
588#[derive(Debug, Serialize, Deserialize)]
589struct ZenMessage {
590    role: String,
591    content: String,
592}
593
594/// Zen API response format
595#[derive(Debug, Deserialize)]
596struct ZenChatResponse {
597    choices: Vec<ZenChoice>,
598    usage: ZenUsage,
599}
600
601/// Zen API choice format
602#[derive(Debug, Deserialize)]
603struct ZenChoice {
604    message: Option<ZenMessage>,
605    finish_reason: Option<String>,
606}
607
608/// Zen API usage format
609#[derive(Debug, Deserialize)]
610struct ZenUsage {
611    prompt_tokens: usize,
612    completion_tokens: usize,
613    total_tokens: usize,
614}
615
616/// Zen API models list response
617#[derive(Debug, Deserialize)]
618struct ZenListModelsResponse {
619    models: Vec<ZenModel>,
620}
621
622/// Zen API model info
623#[derive(Debug, Deserialize, Clone)]
624struct ZenModel {
625    id: String,
626    name: String,
627    context_window: usize,
628    capabilities: Vec<Capability>,
629    #[serde(skip_serializing_if = "Option::is_none")]
630    pricing: Option<ZenPricing>,
631}
632
633/// Zen API pricing info
634#[derive(Debug, Deserialize, Clone)]
635struct ZenPricing {
636    input_cost_per_1k: f64,
637    output_cost_per_1k: f64,
638}
639
640/// Zen API token count request
641#[derive(Debug, Serialize)]
642struct ZenTokenCountRequest {
643    model: String,
644    content: String,
645}
646
647/// Zen API token count response
648#[derive(Debug, Deserialize)]
649struct ZenTokenCountResponse {
650    token_count: usize,
651}
652
#[cfg(test)]
mod tests {
    use super::*;

    /// Key used by every test; no test in this module performs a request.
    const TEST_KEY: &str = "test-key";

    /// Builds a provider with the dummy key, panicking if construction fails.
    fn provider() -> ZenProvider {
        ZenProvider::new(TEST_KEY.to_string()).unwrap()
    }

    #[test]
    fn test_zen_provider_creation() {
        assert!(ZenProvider::new(TEST_KEY.to_string()).is_ok());
    }

    #[test]
    fn test_zen_provider_creation_empty_key() {
        assert!(ZenProvider::new(String::new()).is_err());
    }

    #[test]
    fn test_zen_provider_id() {
        assert_eq!(provider().id(), "zen");
    }

    #[test]
    fn test_zen_provider_name() {
        assert_eq!(provider().name(), "OpenCode Zen");
    }

    #[test]
    fn test_zen_models() {
        let models = provider().models();
        assert_eq!(models.len(), 2);
        for expected in ["zen-gpt4", "zen-gpt4-turbo"] {
            assert!(models.iter().any(|m| m.id == expected));
        }
    }

    #[test]
    fn test_token_counting() {
        let tokens = provider()
            .count_tokens("Hello, world!", "zen-gpt4")
            .unwrap();
        assert!(tokens > 0);
    }

    #[test]
    fn test_token_counting_invalid_model() {
        let outcome = provider().count_tokens("Hello, world!", "invalid-model");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_model_cache_creation() {
        assert!(ModelCache::new().get().is_none());
    }

    #[test]
    fn test_model_cache_set_and_get() {
        let entry = ModelInfo {
            id: "test".to_string(),
            name: "Test".to_string(),
            provider: "zen".to_string(),
            context_window: 1000,
            capabilities: vec![],
            pricing: None,
        };

        let mut cache = ModelCache::new();
        cache.set(vec![entry]);

        let cached = cache.get();
        assert!(cached.is_some());
        let cached_models = cached.unwrap();
        assert_eq!(cached_models.len(), 1);
        assert_eq!(cached_models[0].id, "test");
    }

    #[test]
    fn test_health_check_cache_creation() {
        assert!(HealthCheckCache::new().get().is_none());
    }

    #[test]
    fn test_health_check_cache_set_and_get() {
        let mut cache = HealthCheckCache::new();
        cache.set(true);
        assert_eq!(cache.get(), Some(true));
    }

    #[test]
    fn test_estimate_tokens() {
        // "Hello, world!" is 13 characters, so (13 + 3) / 4 = 4 tokens
        assert_eq!(provider().estimate_tokens("Hello, world!"), 4);
    }

    #[test]
    fn test_estimate_tokens_empty() {
        assert_eq!(provider().estimate_tokens(""), 0);
    }

    #[test]
    fn test_estimate_tokens_single_char() {
        // (1 + 3) / 4 = 1 token
        assert_eq!(provider().estimate_tokens("a"), 1);
    }
}