use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tracing::{debug, error, warn};

use crate::error::ProviderError;
use crate::models::{Capability, ChatRequest, ChatResponse, FinishReason, ModelInfo, TokenUsage};
use crate::provider::Provider;
use crate::token_counter::TokenCounter;

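/// Provider backed by the OpenCode Zen API, with cached model listings and
/// cached health-check results to avoid redundant network calls.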
pub struct ZenProvider {
    api_key: String,
    client: Arc<Client>,
    base_url: String,
    token_counter: Arc<TokenCounter>,
    models_cache: Arc<tokio::sync::Mutex<ModelCache>>,
    health_check_cache: Arc<tokio::sync::Mutex<HealthCheckCache>>,
}

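/// Time-based cache for the model list fetched from the Zen API.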
struct ModelCache {
    models: Option<Vec<ModelInfo>>,
    cached_at: Option<SystemTime>,
    ttl: Duration,
}

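/// Time-based cache for the most recent health-check result.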
struct HealthCheckCache {
    result: Option<bool>,
    cached_at: Option<SystemTime>,
    ttl: Duration,
}

impl ModelCache {
    fn new() -> Self {
        Self {
            models: None,
            cached_at: None,
            ttl: Duration::from_secs(300),
        }
    }

    fn is_valid(&self) -> bool {
        if let (Some(cached_at), Some(_)) = (self.cached_at, &self.models) {
            if let Ok(elapsed) = cached_at.elapsed() {
                return elapsed < self.ttl;
            }
        }
        false
    }

    fn get(&self) -> Option<Vec<ModelInfo>> {
        if self.is_valid() {
            self.models.clone()
        } else {
            None
        }
    }

    fn set(&mut self, models: Vec<ModelInfo>) {
        self.models = Some(models);
        self.cached_at = Some(SystemTime::now());
    }

    #[allow(dead_code)]
    fn invalidate(&mut self) {
        self.models = None;
        self.cached_at = None;
    }
}

impl HealthCheckCache {
    fn new() -> Self {
        Self {
            result: None,
            cached_at: None,
            ttl: Duration::from_secs(60),
        }
    }

    fn is_valid(&self) -> bool {
        if let (Some(cached_at), Some(_)) = (self.cached_at, self.result) {
            if let Ok(elapsed) = cached_at.elapsed() {
                return elapsed < self.ttl;
            }
        }
        false
    }

    fn get(&self) -> Option<bool> {
        if self.is_valid() {
            self.result
        } else {
            None
        }
    }

    fn set(&mut self, result: bool) {
        self.result = Some(result);
        self.cached_at = Some(SystemTime::now());
    }

    #[allow(dead_code)]
    fn invalidate(&mut self) {
        self.result = None;
        self.cached_at = None;
    }
}

impl ZenProvider {
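    /// Creates a provider that talks to the default Zen API endpoint.
    /// Returns a `ConfigError` if the API key is empty.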
    pub fn new(api_key: String) -> Result<Self, ProviderError> {
        if api_key.is_empty() {
            return Err(ProviderError::ConfigError(
                "Zen API key is required".to_string(),
            ));
        }

        Ok(Self {
            api_key,
            client: Arc::new(Client::new()),
            base_url: "https://api.opencode.ai/v1".to_string(),
            token_counter: Arc::new(TokenCounter::new()),
            models_cache: Arc::new(tokio::sync::Mutex::new(ModelCache::new())),
            health_check_cache: Arc::new(tokio::sync::Mutex::new(HealthCheckCache::new())),
        })
    }

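    /// Creates a provider that targets a custom base URL (e.g. a mock server).
    /// Returns a `ConfigError` if the API key is empty.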
    pub fn with_base_url(api_key: String, base_url: String) -> Result<Self, ProviderError> {
        if api_key.is_empty() {
            return Err(ProviderError::ConfigError(
                "Zen API key is required".to_string(),
            ));
        }

        Ok(Self {
            api_key,
            client: Arc::new(Client::new()),
            base_url,
            token_counter: Arc::new(TokenCounter::new()),
            models_cache: Arc::new(tokio::sync::Mutex::new(ModelCache::new())),
            health_check_cache: Arc::new(tokio::sync::Mutex::new(HealthCheckCache::new())),
        })
    }

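    /// Builds the `Authorization` header value for API requests.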
    fn get_auth_header(&self) -> String {
        format!("Bearer {}", self.api_key)
    }

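    /// Maps a raw Zen chat completion response onto the provider-agnostic
    /// `ChatResponse`, extracting content, token usage, and finish reason.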
    fn convert_response(
        response: ZenChatResponse,
        model: String,
    ) -> Result<ChatResponse, ProviderError> {
        let content = response
            .choices
            .first()
            .and_then(|c| c.message.as_ref())
            .map(|m| m.content.clone())
            .ok_or_else(|| ProviderError::ProviderError("No content in response".to_string()))?;

        let finish_reason = match response
            .choices
            .first()
            .and_then(|c| c.finish_reason.as_deref())
        {
            Some("stop") => FinishReason::Stop,
            Some("length") => FinishReason::Length,
            Some("error") => FinishReason::Error,
            _ => FinishReason::Stop,
        };

        Ok(ChatResponse {
            content,
            model,
            usage: TokenUsage {
                prompt_tokens: response.usage.prompt_tokens,
                completion_tokens: response.usage.completion_tokens,
                total_tokens: response.usage.total_tokens,
            },
            finish_reason,
        })
    }

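    /// Fetches the model list from the Zen API, retrying up to three times with
    /// exponential backoff on rate limits and transport errors.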
    async fn fetch_models_from_api(&self) -> Result<Vec<ModelInfo>, ProviderError> {
        let mut retries = 0;
        let max_retries = 3;

        loop {
            debug!("Fetching models from Zen API (attempt {})", retries + 1);

            let response = self
                .client
                .get(format!("{}/models", self.base_url))
                .header("Authorization", self.get_auth_header())
                .timeout(Duration::from_secs(30))
                .send()
                .await;

            match response {
                Ok(resp) => {
                    let status = resp.status();
                    if !status.is_success() {
                        let error_text = resp.text().await.unwrap_or_default();
                        error!("Zen API error ({}): {}", status, error_text);

                        return match status.as_u16() {
                            401 => Err(ProviderError::AuthError),
                            429 => {
                                if retries < max_retries {
                                    let backoff = Duration::from_secs(2_u64.pow(retries as u32));
                                    warn!("Rate limited, retrying after {:?}", backoff);
                                    tokio::time::sleep(backoff).await;
                                    retries += 1;
                                    continue;
                                }
                                Err(ProviderError::RateLimited(60))
                            }
                            _ => Err(ProviderError::ProviderError(format!(
                                "Zen API error: {}",
                                status
                            ))),
                        };
                    }

                    let zen_response: ZenListModelsResponse = resp.json().await?;
                    return Ok(zen_response
                        .models
                        .into_iter()
                        .map(|m| ModelInfo {
                            id: m.id,
                            name: m.name,
                            provider: "zen".to_string(),
                            context_window: m.context_window,
                            capabilities: m.capabilities,
                            pricing: m.pricing.map(|p| crate::models::Pricing {
                                input_per_1k_tokens: p.input_cost_per_1k,
                                output_per_1k_tokens: p.output_cost_per_1k,
                            }),
                        })
                        .collect());
                }
                Err(e) => {
                    error!("Zen API request failed: {}", e);

                    if retries < max_retries {
                        let backoff = Duration::from_secs(2_u64.pow(retries as u32));
                        warn!("Request failed, retrying after {:?}", backoff);
                        tokio::time::sleep(backoff).await;
                        retries += 1;
                        continue;
                    }

                    return Err(ProviderError::from(e));
                }
            }
        }
    }

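    /// Counts tokens via the Zen token-count endpoint, falling back to the local
    /// character-based approximation if the endpoint is unavailable.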
    #[allow(dead_code)]
    async fn count_tokens_from_api(
        &self,
        content: &str,
        model: &str,
    ) -> Result<usize, ProviderError> {
        debug!(
            "Counting tokens for model: {} (content length: {})",
            model,
            content.len()
        );

        let request = ZenTokenCountRequest {
            model: model.to_string(),
            content: content.to_string(),
        };

        let mut retries = 0;
        let max_retries = 3;

        loop {
            let response = self
                .client
                .post(format!("{}/token/count", self.base_url))
                .header("Authorization", self.get_auth_header())
                .header("Content-Type", "application/json")
                .json(&request)
                .timeout(Duration::from_secs(30))
                .send()
                .await;

            match response {
                Ok(resp) => {
                    let status = resp.status();
                    if !status.is_success() {
                        let error_text = resp.text().await.unwrap_or_default();
                        error!("Zen token count API error ({}): {}", status, error_text);

                        return match status.as_u16() {
                            401 => Err(ProviderError::AuthError),
                            429 => {
                                if retries < max_retries {
                                    let backoff = Duration::from_secs(2_u64.pow(retries as u32));
                                    warn!("Rate limited, retrying after {:?}", backoff);
                                    tokio::time::sleep(backoff).await;
                                    retries += 1;
                                    continue;
                                }
                                Err(ProviderError::RateLimited(60))
                            }
                            _ => {
                                warn!("Token count API failed, using fallback approximation");
                                return Ok(self.estimate_tokens(content));
                            }
                        };
                    }

                    let zen_response: ZenTokenCountResponse = resp.json().await?;
                    debug!("Token count from API: {}", zen_response.token_count);
                    return Ok(zen_response.token_count);
                }
                Err(e) => {
                    error!("Zen token count API request failed: {}", e);

                    if retries < max_retries {
                        let backoff = Duration::from_secs(2_u64.pow(retries as u32));
                        warn!("Request failed, retrying after {:?}", backoff);
                        tokio::time::sleep(backoff).await;
                        retries += 1;
                        continue;
                    }

                    warn!("Token count API unavailable, using fallback approximation");
                    return Ok(self.estimate_tokens(content));
                }
            }
        }
    }

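    /// Rough local fallback: assumes roughly four characters per token, rounded up.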
    #[allow(dead_code)]
    fn estimate_tokens(&self, content: &str) -> usize {
        content.len().div_ceil(4)
    }
}

#[async_trait]
impl Provider for ZenProvider {
    fn id(&self) -> &str {
        "zen"
    }

    fn name(&self) -> &str {
        "OpenCode Zen"
    }

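    /// Built-in catalogue of known Zen models. `chat` validates against the live,
    /// cached model list fetched from the API; this static list backs synchronous
    /// callers such as `count_tokens`.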
    fn models(&self) -> Vec<ModelInfo> {
        vec![
            ModelInfo {
                id: "zen-gpt4".to_string(),
                name: "Zen GPT-4".to_string(),
                provider: "zen".to_string(),
                context_window: 8192,
                capabilities: vec![Capability::Chat, Capability::Code, Capability::Streaming],
                pricing: Some(crate::models::Pricing {
                    input_per_1k_tokens: 0.03,
                    output_per_1k_tokens: 0.06,
                }),
            },
            ModelInfo {
                id: "zen-gpt4-turbo".to_string(),
                name: "Zen GPT-4 Turbo".to_string(),
                provider: "zen".to_string(),
                context_window: 128000,
                capabilities: vec![
                    Capability::Chat,
                    Capability::Code,
                    Capability::Vision,
                    Capability::Streaming,
                ],
                pricing: Some(crate::models::Pricing {
                    input_per_1k_tokens: 0.01,
                    output_per_1k_tokens: 0.03,
                }),
            },
        ]
    }

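    /// Sends a chat completion request. The target model is validated against the
    /// cached live model list before the request is issued; rate limits and
    /// transport errors are retried with exponential backoff.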
    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, ProviderError> {
        let models = {
            let mut cache = self.models_cache.lock().await;
            if let Some(models) = cache.get() {
                debug!("Using cached models");
                models
            } else {
                debug!("Cache miss, fetching models from API");
                let models = self.fetch_models_from_api().await?;
                cache.set(models.clone());
                models
            }
        };

        let model_id = &request.model;
        if !models.iter().any(|m| m.id == *model_id) {
            return Err(ProviderError::InvalidModel(model_id.clone()));
        }

        let zen_request = ZenChatRequest {
            model: request.model.clone(),
            messages: request
                .messages
                .iter()
                .map(|m| ZenMessage {
                    role: m.role.clone(),
                    content: m.content.clone(),
                })
                .collect(),
            temperature: request.temperature,
            max_tokens: request.max_tokens,
            stream: false,
        };

        debug!("Sending chat request to Zen for model: {}", request.model);

        let mut retries = 0;
        let max_retries = 3;

        loop {
            let response = self
                .client
                .post(format!("{}/chat/completions", self.base_url))
                .header("Authorization", self.get_auth_header())
                .header("Content-Type", "application/json")
                .json(&zen_request)
                .timeout(Duration::from_secs(30))
                .send()
                .await;

            match response {
                Ok(resp) => {
                    let status = resp.status();
                    if !status.is_success() {
                        let error_text = resp.text().await.unwrap_or_default();
                        error!("Zen API error ({}): {}", status, error_text);

                        return match status.as_u16() {
                            401 => Err(ProviderError::AuthError),
                            429 => {
                                if retries < max_retries {
                                    let backoff = Duration::from_secs(2_u64.pow(retries as u32));
                                    warn!("Rate limited, retrying after {:?}", backoff);
                                    tokio::time::sleep(backoff).await;
                                    retries += 1;
                                    continue;
                                }
                                Err(ProviderError::RateLimited(60))
                            }
                            _ => Err(ProviderError::ProviderError(format!(
                                "Zen API error: {}",
                                status
                            ))),
                        };
                    }

                    let zen_response: ZenChatResponse = resp.json().await?;
                    return Self::convert_response(zen_response, request.model);
                }
                Err(e) => {
                    error!("Zen API request failed: {}", e);

                    if retries < max_retries {
                        let backoff = Duration::from_secs(2_u64.pow(retries as u32));
                        warn!("Request failed, retrying after {:?}", backoff);
                        tokio::time::sleep(backoff).await;
                        retries += 1;
                        continue;
                    }

                    return Err(ProviderError::from(e));
                }
            }
        }
    }

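    /// Streaming is not yet implemented for Zen; always returns a provider error.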
    async fn chat_stream(
        &self,
        _request: ChatRequest,
    ) -> Result<crate::provider::ChatStream, ProviderError> {
        Err(ProviderError::ProviderError(
            "Streaming not yet implemented for Zen".to_string(),
        ))
    }

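    /// Counts tokens locally with the OpenAI-style counter after validating the
    /// model against the built-in catalogue.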
    fn count_tokens(&self, content: &str, model: &str) -> Result<usize, ProviderError> {
        if !self.models().iter().any(|m| m.id == model) {
            return Err(ProviderError::InvalidModel(model.to_string()));
        }

        let tokens = self.token_counter.count_tokens_openai(content, model);
        Ok(tokens)
    }

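    /// Checks API reachability by listing models, caching the result for one
    /// minute. An authentication failure is surfaced as an error rather than
    /// being cached as unhealthy.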
    async fn health_check(&self) -> Result<bool, ProviderError> {
        debug!("Performing health check for Zen provider");

        {
            let cache = self.health_check_cache.lock().await;
            if let Some(result) = cache.get() {
                debug!("Using cached health check result: {}", result);
                return Ok(result);
            }
        }

        let response = self
            .client
            .get(format!("{}/models", self.base_url))
            .header("Authorization", self.get_auth_header())
            .timeout(Duration::from_secs(5))
            .send()
            .await;

        let result = match response {
            Ok(resp) => match resp.status().as_u16() {
                200 => {
                    debug!("Zen health check passed");
                    true
                }
                401 => {
                    error!("Zen health check failed: authentication error");
                    return Err(ProviderError::AuthError);
                }
                _ => {
                    warn!("Zen health check failed with status: {}", resp.status());
                    false
                }
            },
            Err(e) => {
                warn!("Zen health check failed: {}", e);
                false
            }
        };

        {
            let mut cache = self.health_check_cache.lock().await;
            cache.set(result);
        }

        Ok(result)
    }
}

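// Wire-format types for the Zen API request and response bodies.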
#[derive(Debug, Serialize)]
struct ZenChatRequest {
    model: String,
    messages: Vec<ZenMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<usize>,
    stream: bool,
}

#[derive(Debug, Serialize, Deserialize)]
struct ZenMessage {
    role: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct ZenChatResponse {
    choices: Vec<ZenChoice>,
    usage: ZenUsage,
}

#[derive(Debug, Deserialize)]
struct ZenChoice {
    message: Option<ZenMessage>,
    finish_reason: Option<String>,
}

#[derive(Debug, Deserialize)]
struct ZenUsage {
    prompt_tokens: usize,
    completion_tokens: usize,
    total_tokens: usize,
}

#[derive(Debug, Deserialize)]
struct ZenListModelsResponse {
    models: Vec<ZenModel>,
}

#[derive(Debug, Deserialize, Clone)]
struct ZenModel {
    id: String,
    name: String,
    context_window: usize,
    capabilities: Vec<Capability>,
    pricing: Option<ZenPricing>,
}

#[derive(Debug, Deserialize, Clone)]
struct ZenPricing {
    input_cost_per_1k: f64,
    output_cost_per_1k: f64,
}

#[derive(Debug, Serialize)]
struct ZenTokenCountRequest {
    model: String,
    content: String,
}

#[derive(Debug, Deserialize)]
struct ZenTokenCountResponse {
    token_count: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_zen_provider_creation() {
        let provider = ZenProvider::new("test-key".to_string());
        assert!(provider.is_ok());
    }

    #[test]
    fn test_zen_provider_creation_empty_key() {
        let provider = ZenProvider::new("".to_string());
        assert!(provider.is_err());
    }

    #[test]
    fn test_zen_provider_id() {
        let provider = ZenProvider::new("test-key".to_string()).unwrap();
        assert_eq!(provider.id(), "zen");
    }

    #[test]
    fn test_zen_provider_name() {
        let provider = ZenProvider::new("test-key".to_string()).unwrap();
        assert_eq!(provider.name(), "OpenCode Zen");
    }

    #[test]
    fn test_zen_models() {
        let provider = ZenProvider::new("test-key".to_string()).unwrap();
        let models = provider.models();
        assert_eq!(models.len(), 2);
        assert!(models.iter().any(|m| m.id == "zen-gpt4"));
        assert!(models.iter().any(|m| m.id == "zen-gpt4-turbo"));
    }

    #[test]
    fn test_token_counting() {
        let provider = ZenProvider::new("test-key".to_string()).unwrap();
        let tokens = provider.count_tokens("Hello, world!", "zen-gpt4").unwrap();
        assert!(tokens > 0);
    }

    #[test]
    fn test_token_counting_invalid_model() {
        let provider = ZenProvider::new("test-key".to_string()).unwrap();
        let result = provider.count_tokens("Hello, world!", "invalid-model");
        assert!(result.is_err());
    }

    #[test]
    fn test_model_cache_creation() {
        let cache = ModelCache::new();
        assert!(cache.get().is_none());
    }

    #[test]
    fn test_model_cache_set_and_get() {
        let mut cache = ModelCache::new();
        let models = vec![ModelInfo {
            id: "test".to_string(),
            name: "Test".to_string(),
            provider: "zen".to_string(),
            context_window: 1000,
            capabilities: vec![],
            pricing: None,
        }];
        cache.set(models.clone());
        let cached = cache.get();
        assert!(cached.is_some());
        let cached_models = cached.unwrap();
        assert_eq!(cached_models.len(), 1);
        assert_eq!(cached_models[0].id, "test");
    }

    #[test]
    fn test_health_check_cache_creation() {
        let cache = HealthCheckCache::new();
        assert!(cache.get().is_none());
    }

    #[test]
    fn test_health_check_cache_set_and_get() {
        let mut cache = HealthCheckCache::new();
        cache.set(true);
        assert_eq!(cache.get(), Some(true));
    }

    #[test]
    fn test_estimate_tokens() {
        let provider = ZenProvider::new("test-key".to_string()).unwrap();
        let tokens = provider.estimate_tokens("Hello, world!");
        assert_eq!(tokens, 4);
    }

    #[test]
    fn test_estimate_tokens_empty() {
        let provider = ZenProvider::new("test-key".to_string()).unwrap();
        let tokens = provider.estimate_tokens("");
        assert_eq!(tokens, 0);
    }

    #[test]
    fn test_estimate_tokens_single_char() {
        let provider = ZenProvider::new("test-key".to_string()).unwrap();
        let tokens = provider.estimate_tokens("a");
        assert_eq!(tokens, 1);
    }
}