ngdp_cache/
cached_ribbit_client.rs

1//! Cached wrapper for RibbitClient
2//!
3//! This module provides a caching layer for RibbitClient that stores responses
4//! using the Blizzard MIME filename convention: command-argument(s)-sequencenumber.bmime
5//!
6//! # Example
7//!
8//! ```no_run
9//! use ngdp_cache::cached_ribbit_client::CachedRibbitClient;
10//! use ribbit_client::Region;
11//!
12//! # #[tokio::main]
13//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
14//! // Create a cached client
15//! let client = CachedRibbitClient::new(Region::US).await?;
16//!
17//! // Use it exactly like RibbitClient - caching is transparent
18//! let summary = client.get_summary().await?;
19//! println!("Found {} products", summary.products.len());
20//!
21//! // Subsequent calls use cache (5 minute TTL for regular endpoints)
22//! let summary2 = client.get_summary().await?;  // This will be from cache!
23//! # Ok(())
24//! # }
25//! ```
26
27use std::path::PathBuf;
28use std::time::Duration;
29use tracing::{debug, trace};
30
31use ribbit_client::{Endpoint, ProtocolVersion, Region, RibbitClient, TypedResponse};
32
33use crate::{Result, ensure_dir, get_cache_dir};
34
/// Lifetime of cached certificate and OCSP responses: 30 days.
const CERT_CACHE_TTL: Duration = Duration::from_secs(60 * 60 * 24 * 30);

/// Lifetime of every other cached response: 5 minutes.
const DEFAULT_CACHE_TTL: Duration = Duration::from_secs(60 * 5);
40
41/// A caching wrapper around RibbitClient for raw responses
42pub struct CachedRibbitClient {
43    /// The underlying RibbitClient
44    client: RibbitClient,
45    /// Base directory for cache
46    cache_dir: PathBuf,
47    /// Region for this client
48    region: Region,
49    /// Whether caching is enabled
50    enabled: bool,
51}
52
53impl CachedRibbitClient {
54    /// Create a new cached Ribbit client
55    pub async fn new(region: Region) -> Result<Self> {
56        let client = RibbitClient::new(region);
57        let cache_dir = get_cache_dir()?.join("ribbit");
58        ensure_dir(&cache_dir).await?;
59
60        debug!("Initialized cached Ribbit client for region {:?}", region);
61
62        Ok(Self {
63            client,
64            cache_dir,
65            region,
66            enabled: true,
67        })
68    }
69
70    /// Create a new cached client with custom cache directory
71    pub async fn with_cache_dir(region: Region, cache_dir: PathBuf) -> Result<Self> {
72        let client = RibbitClient::new(region);
73        ensure_dir(&cache_dir).await?;
74
75        Ok(Self {
76            client,
77            cache_dir,
78            region,
79            enabled: true,
80        })
81    }
82
83    /// Enable or disable caching
84    pub fn set_caching_enabled(&mut self, enabled: bool) {
85        self.enabled = enabled;
86    }
87
88    /// Generate cache filename following Blizzard convention:
89    /// command-argument(s)-sequencenumber.bmime
90    fn generate_cache_filename(&self, endpoint: &Endpoint, sequence_number: Option<u64>) -> String {
91        let (command, arguments) = match endpoint {
92            Endpoint::Summary => ("summary", "#".to_string()),
93            Endpoint::ProductVersions(product) => ("versions", product.clone()),
94            Endpoint::ProductCdns(product) => ("cdns", product.clone()),
95            Endpoint::ProductBgdl(product) => ("bgdl", product.clone()),
96            Endpoint::Cert(hash) => ("certs", hash.clone()),
97            Endpoint::Ocsp(hash) => ("ocsp", hash.clone()),
98            Endpoint::Custom(path) => {
99                // Try to extract command and argument from custom path
100                let parts: Vec<&str> = path.split('/').collect();
101                match parts.as_slice() {
102                    [cmd] => (*cmd, "#".to_string()),
103                    [cmd, arg] => (*cmd, arg.to_string()),
104                    [cmd, arg, ..] => (*cmd, arg.to_string()),
105                    _ => ("custom", path.replace('/', "_")),
106                }
107            }
108        };
109
110        let seq = sequence_number.unwrap_or(0);
111        format!("{command}-{arguments}-{seq}.bmime")
112    }
113
114    /// Get the cache path for an endpoint
115    fn get_cache_path(&self, endpoint: &Endpoint, sequence_number: Option<u64>) -> PathBuf {
116        let filename = self.generate_cache_filename(endpoint, sequence_number);
117        self.cache_dir.join(self.region.to_string()).join(filename)
118    }
119
120    /// Get the metadata path for an endpoint
121    fn get_metadata_path(&self, endpoint: &Endpoint, sequence_number: Option<u64>) -> PathBuf {
122        let mut path = self.get_cache_path(endpoint, sequence_number);
123        path.set_extension("meta");
124        path
125    }
126
127    /// Determine TTL based on endpoint type
128    fn get_ttl_for_endpoint(&self, endpoint: &Endpoint) -> Duration {
129        match endpoint {
130            Endpoint::Cert(_) | Endpoint::Ocsp(_) => CERT_CACHE_TTL,
131            _ => DEFAULT_CACHE_TTL,
132        }
133    }
134
135    /// Extract sequence number from raw response data
136    fn extract_sequence_number(&self, raw_data: &[u8]) -> Option<u64> {
137        let data_str = String::from_utf8_lossy(raw_data);
138
139        // Look for the sequence number in the format "## seqn = 12345"
140        for line in data_str.lines() {
141            if line.starts_with("## seqn = ") {
142                if let Some(seqn_str) = line.strip_prefix("## seqn = ") {
143                    if let Ok(seqn) = seqn_str.trim().parse::<u64>() {
144                        return Some(seqn);
145                    }
146                }
147            }
148        }
149
150        None
151    }
152
153    /// Find the most recent valid cached file for an endpoint
154    async fn find_cached_file(&self, endpoint: &Endpoint) -> Option<(PathBuf, u64)> {
155        if !self.enabled {
156            return None;
157        }
158
159        let region_dir = self.cache_dir.join(self.region.to_string());
160        if tokio::fs::metadata(&region_dir).await.is_err() {
161            return None;
162        }
163
164        // Generate pattern to match files for this endpoint
165        let base_filename = self.generate_cache_filename(endpoint, Some(0));
166        let prefix = base_filename.trim_end_matches("-0.bmime");
167
168        let ttl = self.get_ttl_for_endpoint(endpoint);
169        let now = std::time::SystemTime::now()
170            .duration_since(std::time::UNIX_EPOCH)
171            .unwrap()
172            .as_secs();
173
174        let mut best_file: Option<(PathBuf, u64)> = None;
175        let mut best_seqn: u64 = 0;
176
177        // Read directory and find matching files
178        if let Ok(mut entries) = tokio::fs::read_dir(&region_dir).await {
179            while let Some(entry) = entries.next_entry().await.ok()? {
180                let path = entry.path();
181                if let Some(filename) = path.file_name().and_then(|n| n.to_str()) {
182                    // Check if this file matches our endpoint pattern
183                    if filename.starts_with(&format!("{prefix}-")) && filename.ends_with(".bmime") {
184                        // Extract sequence number from filename
185                        if let Some(seqn_part) = filename
186                            .strip_prefix(&format!("{prefix}-"))
187                            .and_then(|s| s.strip_suffix(".bmime"))
188                        {
189                            if let Ok(seqn) = seqn_part.parse::<u64>() {
190                                // Check if this file is still valid
191                                let meta_path = path.with_extension("meta");
192                                if let Ok(metadata) = tokio::fs::read_to_string(&meta_path).await {
193                                    if let Ok(timestamp) = metadata.trim().parse::<u64>() {
194                                        if now.saturating_sub(timestamp) < ttl.as_secs()
195                                            && seqn > best_seqn
196                                        {
197                                            best_file = Some((path.clone(), seqn));
198                                            best_seqn = seqn;
199                                        }
200                                    }
201                                }
202                            }
203                        }
204                    }
205                }
206            }
207        }
208
209        best_file
210    }
211
212    /// Check if a cached response is still valid
213    async fn is_cache_valid(&self, endpoint: &Endpoint) -> bool {
214        self.find_cached_file(endpoint).await.is_some()
215    }
216
217    /// Write raw response to cache
218    async fn write_to_cache(&self, endpoint: &Endpoint, raw_data: &[u8]) -> Result<()> {
219        if !self.enabled {
220            return Ok(());
221        }
222
223        // Extract sequence number from the response data
224        let sequence_number = self.extract_sequence_number(raw_data);
225
226        let cache_path = self.get_cache_path(endpoint, sequence_number);
227        let meta_path = self.get_metadata_path(endpoint, sequence_number);
228
229        // Ensure parent directory exists
230        if let Some(parent) = cache_path.parent() {
231            ensure_dir(parent).await?;
232        }
233
234        // Write the raw response data
235        trace!(
236            "Writing {} bytes to cache: {:?}",
237            raw_data.len(),
238            cache_path
239        );
240        tokio::fs::write(&cache_path, raw_data).await?;
241
242        // Write timestamp metadata
243        let timestamp = std::time::SystemTime::now()
244            .duration_since(std::time::UNIX_EPOCH)
245            .unwrap()
246            .as_secs();
247        tokio::fs::write(&meta_path, timestamp.to_string()).await?;
248
249        Ok(())
250    }
251
252    /// Read response from cache
253    async fn read_from_cache(&self, endpoint: &Endpoint) -> Result<Vec<u8>> {
254        if let Some((cache_path, _seqn)) = self.find_cached_file(endpoint).await {
255            trace!("Reading from cache: {:?}", cache_path);
256            Ok(tokio::fs::read(&cache_path).await?)
257        } else {
258            Err(crate::Error::CacheEntryNotFound(format!(
259                "No valid cache for endpoint: {endpoint:?}"
260            )))
261        }
262    }
263
264    /// Make a request with caching
265    ///
266    /// This method caches the raw response and reconstructs the Response object
267    /// when serving from cache.
268    pub async fn request(&self, endpoint: &Endpoint) -> Result<ribbit_client::Response> {
269        // Check cache first
270        if self.enabled && self.is_cache_valid(endpoint).await {
271            debug!("Cache hit for endpoint: {:?}", endpoint);
272            if let Ok(cached_data) = self.read_from_cache(endpoint).await {
273                // Reconstruct Response based on protocol version
274                // We need to extract the data from the raw bytes
275                let response = match self.client.protocol_version() {
276                    ribbit_client::ProtocolVersion::V2 => {
277                        // V2 is simple - just raw data as string
278                        ribbit_client::Response {
279                            raw: cached_data.clone(),
280                            data: Some(String::from_utf8_lossy(&cached_data).to_string()),
281                            mime_parts: None,
282                        }
283                    }
284                    _ => {
285                        // V1 - try to extract data from MIME structure
286                        // Look for the main data content in the MIME message
287                        let data_str = String::from_utf8_lossy(&cached_data);
288                        let mut data_content = None;
289
290                        // Simple MIME parsing to extract the data part
291                        if let Some(boundary_start) = data_str.find("boundary=\"") {
292                            if let Some(boundary_end) = data_str[boundary_start + 10..].find('"') {
293                                let boundary = &data_str
294                                    [boundary_start + 10..boundary_start + 10 + boundary_end];
295                                let delimiter = format!("--{boundary}");
296
297                                // Find the data part (usually first part after content type)
298                                let parts: Vec<&str> = data_str.split(&delimiter).collect();
299                                for part in parts {
300                                    if part.contains("Content-Disposition:")
301                                        && !part.contains("Content-Type: application/cms")
302                                    {
303                                        // Extract the body after headers - try both \r\n\r\n and \n\n
304                                        let body_start = part
305                                            .find("\r\n\r\n")
306                                            .map(|pos| (pos, 4))
307                                            .or_else(|| part.find("\n\n").map(|pos| (pos, 2)));
308
309                                        if let Some((start, offset)) = body_start {
310                                            let body = &part[start + offset..];
311                                            // Remove any trailing boundary markers
312                                            if let Some(end) = body
313                                                .find(&format!("\r\n--{boundary}"))
314                                                .or_else(|| body.find(&format!("\n--{boundary}")))
315                                            {
316                                                data_content = Some(body[..end].trim().to_string());
317                                            } else {
318                                                data_content = Some(body.trim().to_string());
319                                            }
320                                            break;
321                                        }
322                                    }
323                                }
324                            }
325                        }
326
327                        ribbit_client::Response {
328                            raw: cached_data,
329                            data: data_content,
330                            mime_parts: None, // Cannot fully reconstruct
331                        }
332                    }
333                };
334                return Ok(response);
335            }
336        }
337
338        // Cache miss or error - make actual request
339        debug!(
340            "Cache miss for endpoint: {:?}, fetching from server",
341            endpoint
342        );
343
344        // For certificate and OCSP endpoints, we need to use V1 protocol
345        let response = match endpoint {
346            Endpoint::Cert(_) | Endpoint::Ocsp(_) => {
347                // Create a V1 client for these endpoints
348                let mut v1_client = self.client.clone();
349                v1_client = v1_client.with_protocol_version(ProtocolVersion::V1);
350                v1_client.request(endpoint).await?
351            }
352            _ => {
353                // Use the default client for other endpoints
354                self.client.request(endpoint).await?
355            }
356        };
357
358        // Cache the successful response
359        if let Err(e) = self.write_to_cache(endpoint, &response.raw).await {
360            debug!("Failed to write to cache: {}", e);
361        }
362
363        Ok(response)
364    }
365
366    /// Make a raw request with caching
367    ///
368    /// This is a convenience method that returns just the raw bytes.
369    pub async fn request_raw(&self, endpoint: &Endpoint) -> Result<Vec<u8>> {
370        // Check cache first
371        if self.enabled && self.is_cache_valid(endpoint).await {
372            debug!("Cache hit for raw endpoint: {:?}", endpoint);
373            if let Ok(cached_data) = self.read_from_cache(endpoint).await {
374                return Ok(cached_data);
375            }
376        }
377
378        // Cache miss or error - make actual request
379        debug!(
380            "Cache miss for raw endpoint: {:?}, fetching from server",
381            endpoint
382        );
383
384        // For certificate and OCSP endpoints, we need to use V1 protocol
385        let raw_data = match endpoint {
386            Endpoint::Cert(_) | Endpoint::Ocsp(_) => {
387                // Create a V1 client for these endpoints
388                let mut v1_client = self.client.clone();
389                v1_client = v1_client.with_protocol_version(ProtocolVersion::V1);
390                v1_client.request_raw(endpoint).await?
391            }
392            _ => {
393                // Use the default client for other endpoints
394                self.client.request_raw(endpoint).await?
395            }
396        };
397
398        // Cache the successful response
399        if let Err(e) = self.write_to_cache(endpoint, &raw_data).await {
400            debug!("Failed to write to cache: {}", e);
401        }
402
403        Ok(raw_data)
404    }
405
406    /// Request with automatic type parsing
407    ///
408    /// This method caches the raw response and parses it into the appropriate typed structure.
409    /// It's a drop-in replacement for RibbitClient::request_typed.
410    pub async fn request_typed<T: TypedResponse>(&self, endpoint: &Endpoint) -> Result<T> {
411        let response = self.request(endpoint).await?;
412        T::from_response(&response).map_err(|e| e.into())
413    }
414
415    /// Request product versions with typed response
416    ///
417    /// Convenience method with caching for requesting product version information.
418    pub async fn get_product_versions(
419        &self,
420        product: &str,
421    ) -> Result<ribbit_client::ProductVersionsResponse> {
422        self.request_typed(&Endpoint::ProductVersions(product.to_string()))
423            .await
424    }
425
426    /// Request product CDNs with typed response
427    ///
428    /// Convenience method with caching for requesting CDN server information.
429    pub async fn get_product_cdns(
430        &self,
431        product: &str,
432    ) -> Result<ribbit_client::ProductCdnsResponse> {
433        self.request_typed(&Endpoint::ProductCdns(product.to_string()))
434            .await
435    }
436
437    /// Request product background download config with typed response
438    ///
439    /// Convenience method with caching for requesting background download configuration.
440    pub async fn get_product_bgdl(
441        &self,
442        product: &str,
443    ) -> Result<ribbit_client::ProductBgdlResponse> {
444        self.request_typed(&Endpoint::ProductBgdl(product.to_string()))
445            .await
446    }
447
448    /// Request summary of all products with typed response
449    ///
450    /// Convenience method with caching for requesting the summary of all available products.
451    pub async fn get_summary(&self) -> Result<ribbit_client::SummaryResponse> {
452        self.request_typed(&Endpoint::Summary).await
453    }
454
455    /// Get the underlying RibbitClient
456    pub fn inner(&self) -> &RibbitClient {
457        &self.client
458    }
459
460    /// Get mutable access to the underlying RibbitClient
461    pub fn inner_mut(&mut self) -> &mut RibbitClient {
462        &mut self.client
463    }
464
465    /// Clear all cached responses
466    pub async fn clear_cache(&self) -> Result<()> {
467        debug!("Clearing all cached responses");
468
469        let region_dir = self.cache_dir.join(self.region.to_string());
470        if tokio::fs::metadata(&region_dir).await.is_ok() {
471            let mut entries = tokio::fs::read_dir(&region_dir).await?;
472            while let Some(entry) = entries.next_entry().await? {
473                let path = entry.path();
474                if path.extension().and_then(|s| s.to_str()) == Some("bmime")
475                    || path.extension().and_then(|s| s.to_str()) == Some("meta")
476                {
477                    tokio::fs::remove_file(&path).await?;
478                }
479            }
480        }
481
482        Ok(())
483    }
484
485    /// Clear expired cache entries
486    pub async fn clear_expired(&self) -> Result<()> {
487        debug!("Clearing expired cache entries");
488
489        let region_dir = self.cache_dir.join(self.region.to_string());
490        if tokio::fs::metadata(&region_dir).await.is_err() {
491            return Ok(());
492        }
493
494        let mut entries = tokio::fs::read_dir(&region_dir).await?;
495        while let Some(entry) = entries.next_entry().await? {
496            let path = entry.path();
497
498            if path.extension().and_then(|s| s.to_str()) == Some("bmime") {
499                // Check if this cache file is expired
500                let meta_path = path.with_extension("meta");
501
502                if let Ok(metadata) = tokio::fs::read_to_string(&meta_path).await {
503                    if let Ok(timestamp) = metadata.trim().parse::<u64>() {
504                        let now = std::time::SystemTime::now()
505                            .duration_since(std::time::UNIX_EPOCH)
506                            .unwrap()
507                            .as_secs();
508
509                        // Determine TTL based on filename
510                        let filename = path.file_name().unwrap().to_string_lossy();
511                        let ttl = if filename.starts_with("certs-") || filename.starts_with("ocsp-")
512                        {
513                            CERT_CACHE_TTL
514                        } else {
515                            DEFAULT_CACHE_TTL
516                        };
517
518                        if now.saturating_sub(timestamp) >= ttl.as_secs() {
519                            // Remove both files
520                            let _ = tokio::fs::remove_file(&path).await;
521                            let _ = tokio::fs::remove_file(&meta_path).await;
522                            trace!("Removed expired cache file: {:?}", path);
523                        }
524                    }
525                }
526            }
527        }
528
529        Ok(())
530    }
531}
532
#[cfg(test)]
mod tests {
    use super::*;

    /// Create a US-region client on a fresh Tokio runtime and pass it to
    /// `check`. None of the tests below touch the network.
    fn with_us_client(check: impl FnOnce(&CachedRibbitClient)) {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let client = runtime
            .block_on(CachedRibbitClient::new(Region::US))
            .unwrap();
        check(&client);
    }

    /// Filenames follow `command-argument(s)-sequencenumber.bmime`.
    #[test]
    fn test_cache_filename_generation() {
        with_us_client(|client| {
            let summary = client.generate_cache_filename(&Endpoint::Summary, None);
            assert_eq!(summary, "summary-#-0.bmime");

            let versions = client
                .generate_cache_filename(&Endpoint::ProductVersions("wow".to_string()), None);
            assert_eq!(versions, "versions-wow-0.bmime");

            let cert = client
                .generate_cache_filename(&Endpoint::Cert("abc123".to_string()), Some(12345));
            assert_eq!(cert, "certs-abc123-12345.bmime");

            // Only the first two path segments of a custom endpoint are used.
            let custom = client.generate_cache_filename(
                &Endpoint::Custom("products/wow/config".to_string()),
                None,
            );
            assert_eq!(custom, "products-wow-0.bmime");
        });
    }

    /// Certificate and OCSP endpoints get the long TTL, others the default.
    #[test]
    fn test_ttl_selection() {
        with_us_client(|client| {
            let summary_ttl = client.get_ttl_for_endpoint(&Endpoint::Summary);
            assert_eq!(summary_ttl, DEFAULT_CACHE_TTL);

            let cert_ttl = client.get_ttl_for_endpoint(&Endpoint::Cert("test".to_string()));
            assert_eq!(cert_ttl, CERT_CACHE_TTL);

            let ocsp_ttl = client.get_ttl_for_endpoint(&Endpoint::Ocsp("test".to_string()));
            assert_eq!(ocsp_ttl, CERT_CACHE_TTL);
        });
    }

    /// Construction smoke test; the request methods themselves are not
    /// called so that the test suite makes no network calls.
    #[test]
    fn test_api_methods_compile() {
        with_us_client(|client| {
            assert_eq!(client.inner().region(), Region::US);
        });
    }

    /// The `## seqn = N` line is found in plain and MIME-wrapped payloads.
    #[test]
    fn test_extract_sequence_number() {
        with_us_client(|client| {
            // BPSV payload carrying a sequence number
            let data_with_seqn = b"Product!STRING:0|Seqn!DEC:4\n## seqn = 12345\nwow|67890";
            assert_eq!(client.extract_sequence_number(data_with_seqn), Some(12345));

            // The same marker inside a MIME envelope
            let mime_data = b"Subject: test\nFrom: Test/1.0\n\n--boundary\nContent-Disposition: test\n\nProduct!STRING:0\n## seqn = 67890\ndata\n--boundary--";
            assert_eq!(client.extract_sequence_number(mime_data), Some(67890));

            // No marker at all
            let data_no_seqn = b"Product!STRING:0|Seqn!DEC:4\nwow|12345";
            assert_eq!(client.extract_sequence_number(data_no_seqn), None);
        });
    }
}