1use std::path::PathBuf;
28use std::time::Duration;
29use tracing::{debug, trace};
30
31use ribbit_client::{Endpoint, ProtocolVersion, Region, RibbitClient, TypedResponse};
32
33use crate::{Result, ensure_dir, get_cache_dir};
34
/// Freshness window for certificate and OCSP responses: 30 days.
const CERT_CACHE_TTL: Duration = Duration::from_secs(60 * 60 * 24 * 30);

/// Freshness window for every other Ribbit response: 5 minutes.
const DEFAULT_CACHE_TTL: Duration = Duration::from_secs(60 * 5);
40
41pub struct CachedRibbitClient {
43 client: RibbitClient,
45 cache_dir: PathBuf,
47 region: Region,
49 enabled: bool,
51}
52
53impl CachedRibbitClient {
54 pub async fn new(region: Region) -> Result<Self> {
56 let client = RibbitClient::new(region);
57 let cache_dir = get_cache_dir()?.join("ribbit");
58 ensure_dir(&cache_dir).await?;
59
60 debug!("Initialized cached Ribbit client for region {:?}", region);
61
62 Ok(Self {
63 client,
64 cache_dir,
65 region,
66 enabled: true,
67 })
68 }
69
70 pub async fn with_cache_dir(region: Region, cache_dir: PathBuf) -> Result<Self> {
72 let client = RibbitClient::new(region);
73 ensure_dir(&cache_dir).await?;
74
75 Ok(Self {
76 client,
77 cache_dir,
78 region,
79 enabled: true,
80 })
81 }
82
83 pub fn set_caching_enabled(&mut self, enabled: bool) {
85 self.enabled = enabled;
86 }
87
88 fn generate_cache_filename(&self, endpoint: &Endpoint, sequence_number: Option<u64>) -> String {
91 let (command, arguments) = match endpoint {
92 Endpoint::Summary => ("summary", "#".to_string()),
93 Endpoint::ProductVersions(product) => ("versions", product.clone()),
94 Endpoint::ProductCdns(product) => ("cdns", product.clone()),
95 Endpoint::ProductBgdl(product) => ("bgdl", product.clone()),
96 Endpoint::Cert(hash) => ("certs", hash.clone()),
97 Endpoint::Ocsp(hash) => ("ocsp", hash.clone()),
98 Endpoint::Custom(path) => {
99 let parts: Vec<&str> = path.split('/').collect();
101 match parts.as_slice() {
102 [cmd] => (*cmd, "#".to_string()),
103 [cmd, arg] => (*cmd, arg.to_string()),
104 [cmd, arg, ..] => (*cmd, arg.to_string()),
105 _ => ("custom", path.replace('/', "_")),
106 }
107 }
108 };
109
110 let seq = sequence_number.unwrap_or(0);
111 format!("{command}-{arguments}-{seq}.bmime")
112 }
113
114 fn get_cache_path(&self, endpoint: &Endpoint, sequence_number: Option<u64>) -> PathBuf {
116 let filename = self.generate_cache_filename(endpoint, sequence_number);
117 self.cache_dir.join(self.region.to_string()).join(filename)
118 }
119
120 fn get_metadata_path(&self, endpoint: &Endpoint, sequence_number: Option<u64>) -> PathBuf {
122 let mut path = self.get_cache_path(endpoint, sequence_number);
123 path.set_extension("meta");
124 path
125 }
126
127 fn get_ttl_for_endpoint(&self, endpoint: &Endpoint) -> Duration {
129 match endpoint {
130 Endpoint::Cert(_) | Endpoint::Ocsp(_) => CERT_CACHE_TTL,
131 _ => DEFAULT_CACHE_TTL,
132 }
133 }
134
135 fn extract_sequence_number(&self, raw_data: &[u8]) -> Option<u64> {
137 let data_str = String::from_utf8_lossy(raw_data);
138
139 for line in data_str.lines() {
141 if line.starts_with("## seqn = ") {
142 if let Some(seqn_str) = line.strip_prefix("## seqn = ") {
143 if let Ok(seqn) = seqn_str.trim().parse::<u64>() {
144 return Some(seqn);
145 }
146 }
147 }
148 }
149
150 None
151 }
152
153 async fn find_cached_file(&self, endpoint: &Endpoint) -> Option<(PathBuf, u64)> {
155 if !self.enabled {
156 return None;
157 }
158
159 let region_dir = self.cache_dir.join(self.region.to_string());
160 if tokio::fs::metadata(®ion_dir).await.is_err() {
161 return None;
162 }
163
164 let base_filename = self.generate_cache_filename(endpoint, Some(0));
166 let prefix = base_filename.trim_end_matches("-0.bmime");
167
168 let ttl = self.get_ttl_for_endpoint(endpoint);
169 let now = std::time::SystemTime::now()
170 .duration_since(std::time::UNIX_EPOCH)
171 .unwrap()
172 .as_secs();
173
174 let mut best_file: Option<(PathBuf, u64)> = None;
175 let mut best_seqn: u64 = 0;
176
177 if let Ok(mut entries) = tokio::fs::read_dir(®ion_dir).await {
179 while let Some(entry) = entries.next_entry().await.ok()? {
180 let path = entry.path();
181 if let Some(filename) = path.file_name().and_then(|n| n.to_str()) {
182 if filename.starts_with(&format!("{prefix}-")) && filename.ends_with(".bmime") {
184 if let Some(seqn_part) = filename
186 .strip_prefix(&format!("{prefix}-"))
187 .and_then(|s| s.strip_suffix(".bmime"))
188 {
189 if let Ok(seqn) = seqn_part.parse::<u64>() {
190 let meta_path = path.with_extension("meta");
192 if let Ok(metadata) = tokio::fs::read_to_string(&meta_path).await {
193 if let Ok(timestamp) = metadata.trim().parse::<u64>() {
194 if now.saturating_sub(timestamp) < ttl.as_secs()
195 && seqn > best_seqn
196 {
197 best_file = Some((path.clone(), seqn));
198 best_seqn = seqn;
199 }
200 }
201 }
202 }
203 }
204 }
205 }
206 }
207 }
208
209 best_file
210 }
211
212 async fn is_cache_valid(&self, endpoint: &Endpoint) -> bool {
214 self.find_cached_file(endpoint).await.is_some()
215 }
216
217 async fn write_to_cache(&self, endpoint: &Endpoint, raw_data: &[u8]) -> Result<()> {
219 if !self.enabled {
220 return Ok(());
221 }
222
223 let sequence_number = self.extract_sequence_number(raw_data);
225
226 let cache_path = self.get_cache_path(endpoint, sequence_number);
227 let meta_path = self.get_metadata_path(endpoint, sequence_number);
228
229 if let Some(parent) = cache_path.parent() {
231 ensure_dir(parent).await?;
232 }
233
234 trace!(
236 "Writing {} bytes to cache: {:?}",
237 raw_data.len(),
238 cache_path
239 );
240 tokio::fs::write(&cache_path, raw_data).await?;
241
242 let timestamp = std::time::SystemTime::now()
244 .duration_since(std::time::UNIX_EPOCH)
245 .unwrap()
246 .as_secs();
247 tokio::fs::write(&meta_path, timestamp.to_string()).await?;
248
249 Ok(())
250 }
251
252 async fn read_from_cache(&self, endpoint: &Endpoint) -> Result<Vec<u8>> {
254 if let Some((cache_path, _seqn)) = self.find_cached_file(endpoint).await {
255 trace!("Reading from cache: {:?}", cache_path);
256 Ok(tokio::fs::read(&cache_path).await?)
257 } else {
258 Err(crate::Error::CacheEntryNotFound(format!(
259 "No valid cache for endpoint: {endpoint:?}"
260 )))
261 }
262 }
263
264 pub async fn request(&self, endpoint: &Endpoint) -> Result<ribbit_client::Response> {
269 if self.enabled && self.is_cache_valid(endpoint).await {
271 debug!("Cache hit for endpoint: {:?}", endpoint);
272 if let Ok(cached_data) = self.read_from_cache(endpoint).await {
273 let response = match self.client.protocol_version() {
276 ribbit_client::ProtocolVersion::V2 => {
277 ribbit_client::Response {
279 raw: cached_data.clone(),
280 data: Some(String::from_utf8_lossy(&cached_data).to_string()),
281 mime_parts: None,
282 }
283 }
284 _ => {
285 let data_str = String::from_utf8_lossy(&cached_data);
288 let mut data_content = None;
289
290 if let Some(boundary_start) = data_str.find("boundary=\"") {
292 if let Some(boundary_end) = data_str[boundary_start + 10..].find('"') {
293 let boundary = &data_str
294 [boundary_start + 10..boundary_start + 10 + boundary_end];
295 let delimiter = format!("--{boundary}");
296
297 let parts: Vec<&str> = data_str.split(&delimiter).collect();
299 for part in parts {
300 if part.contains("Content-Disposition:")
301 && !part.contains("Content-Type: application/cms")
302 {
303 let body_start = part
305 .find("\r\n\r\n")
306 .map(|pos| (pos, 4))
307 .or_else(|| part.find("\n\n").map(|pos| (pos, 2)));
308
309 if let Some((start, offset)) = body_start {
310 let body = &part[start + offset..];
311 if let Some(end) = body
313 .find(&format!("\r\n--{boundary}"))
314 .or_else(|| body.find(&format!("\n--{boundary}")))
315 {
316 data_content = Some(body[..end].trim().to_string());
317 } else {
318 data_content = Some(body.trim().to_string());
319 }
320 break;
321 }
322 }
323 }
324 }
325 }
326
327 ribbit_client::Response {
328 raw: cached_data,
329 data: data_content,
330 mime_parts: None, }
332 }
333 };
334 return Ok(response);
335 }
336 }
337
338 debug!(
340 "Cache miss for endpoint: {:?}, fetching from server",
341 endpoint
342 );
343
344 let response = match endpoint {
346 Endpoint::Cert(_) | Endpoint::Ocsp(_) => {
347 let mut v1_client = self.client.clone();
349 v1_client = v1_client.with_protocol_version(ProtocolVersion::V1);
350 v1_client.request(endpoint).await?
351 }
352 _ => {
353 self.client.request(endpoint).await?
355 }
356 };
357
358 if let Err(e) = self.write_to_cache(endpoint, &response.raw).await {
360 debug!("Failed to write to cache: {}", e);
361 }
362
363 Ok(response)
364 }
365
366 pub async fn request_raw(&self, endpoint: &Endpoint) -> Result<Vec<u8>> {
370 if self.enabled && self.is_cache_valid(endpoint).await {
372 debug!("Cache hit for raw endpoint: {:?}", endpoint);
373 if let Ok(cached_data) = self.read_from_cache(endpoint).await {
374 return Ok(cached_data);
375 }
376 }
377
378 debug!(
380 "Cache miss for raw endpoint: {:?}, fetching from server",
381 endpoint
382 );
383
384 let raw_data = match endpoint {
386 Endpoint::Cert(_) | Endpoint::Ocsp(_) => {
387 let mut v1_client = self.client.clone();
389 v1_client = v1_client.with_protocol_version(ProtocolVersion::V1);
390 v1_client.request_raw(endpoint).await?
391 }
392 _ => {
393 self.client.request_raw(endpoint).await?
395 }
396 };
397
398 if let Err(e) = self.write_to_cache(endpoint, &raw_data).await {
400 debug!("Failed to write to cache: {}", e);
401 }
402
403 Ok(raw_data)
404 }
405
406 pub async fn request_typed<T: TypedResponse>(&self, endpoint: &Endpoint) -> Result<T> {
411 let response = self.request(endpoint).await?;
412 T::from_response(&response).map_err(|e| e.into())
413 }
414
415 pub async fn get_product_versions(
419 &self,
420 product: &str,
421 ) -> Result<ribbit_client::ProductVersionsResponse> {
422 self.request_typed(&Endpoint::ProductVersions(product.to_string()))
423 .await
424 }
425
426 pub async fn get_product_cdns(
430 &self,
431 product: &str,
432 ) -> Result<ribbit_client::ProductCdnsResponse> {
433 self.request_typed(&Endpoint::ProductCdns(product.to_string()))
434 .await
435 }
436
437 pub async fn get_product_bgdl(
441 &self,
442 product: &str,
443 ) -> Result<ribbit_client::ProductBgdlResponse> {
444 self.request_typed(&Endpoint::ProductBgdl(product.to_string()))
445 .await
446 }
447
448 pub async fn get_summary(&self) -> Result<ribbit_client::SummaryResponse> {
452 self.request_typed(&Endpoint::Summary).await
453 }
454
455 pub fn inner(&self) -> &RibbitClient {
457 &self.client
458 }
459
460 pub fn inner_mut(&mut self) -> &mut RibbitClient {
462 &mut self.client
463 }
464
465 pub async fn clear_cache(&self) -> Result<()> {
467 debug!("Clearing all cached responses");
468
469 let region_dir = self.cache_dir.join(self.region.to_string());
470 if tokio::fs::metadata(®ion_dir).await.is_ok() {
471 let mut entries = tokio::fs::read_dir(®ion_dir).await?;
472 while let Some(entry) = entries.next_entry().await? {
473 let path = entry.path();
474 if path.extension().and_then(|s| s.to_str()) == Some("bmime")
475 || path.extension().and_then(|s| s.to_str()) == Some("meta")
476 {
477 tokio::fs::remove_file(&path).await?;
478 }
479 }
480 }
481
482 Ok(())
483 }
484
485 pub async fn clear_expired(&self) -> Result<()> {
487 debug!("Clearing expired cache entries");
488
489 let region_dir = self.cache_dir.join(self.region.to_string());
490 if tokio::fs::metadata(®ion_dir).await.is_err() {
491 return Ok(());
492 }
493
494 let mut entries = tokio::fs::read_dir(®ion_dir).await?;
495 while let Some(entry) = entries.next_entry().await? {
496 let path = entry.path();
497
498 if path.extension().and_then(|s| s.to_str()) == Some("bmime") {
499 let meta_path = path.with_extension("meta");
501
502 if let Ok(metadata) = tokio::fs::read_to_string(&meta_path).await {
503 if let Ok(timestamp) = metadata.trim().parse::<u64>() {
504 let now = std::time::SystemTime::now()
505 .duration_since(std::time::UNIX_EPOCH)
506 .unwrap()
507 .as_secs();
508
509 let filename = path.file_name().unwrap().to_string_lossy();
511 let ttl = if filename.starts_with("certs-") || filename.starts_with("ocsp-")
512 {
513 CERT_CACHE_TTL
514 } else {
515 DEFAULT_CACHE_TTL
516 };
517
518 if now.saturating_sub(timestamp) >= ttl.as_secs() {
519 let _ = tokio::fs::remove_file(&path).await;
521 let _ = tokio::fs::remove_file(&meta_path).await;
522 trace!("Removed expired cache file: {:?}", path);
523 }
524 }
525 }
526 }
527 }
528
529 Ok(())
530 }
531}
532
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a US-region client on a fresh Tokio runtime and hands it to `check`.
    fn with_us_client(check: impl FnOnce(&CachedRibbitClient)) {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let client = runtime
            .block_on(CachedRibbitClient::new(Region::US))
            .unwrap();
        check(&client);
    }

    #[test]
    fn test_cache_filename_generation() {
        with_us_client(|client| {
            let cases = [
                (Endpoint::Summary, None, "summary-#-0.bmime"),
                (
                    Endpoint::ProductVersions("wow".to_string()),
                    None,
                    "versions-wow-0.bmime",
                ),
                (
                    Endpoint::Cert("abc123".to_string()),
                    Some(12345),
                    "certs-abc123-12345.bmime",
                ),
                (
                    Endpoint::Custom("products/wow/config".to_string()),
                    None,
                    "products-wow-0.bmime",
                ),
            ];

            for (endpoint, seqn, expected) in cases {
                assert_eq!(client.generate_cache_filename(&endpoint, seqn), expected);
            }
        });
    }

    #[test]
    fn test_ttl_selection() {
        with_us_client(|client| {
            assert_eq!(
                client.get_ttl_for_endpoint(&Endpoint::Summary),
                DEFAULT_CACHE_TTL
            );
            assert_eq!(
                client.get_ttl_for_endpoint(&Endpoint::Cert("test".to_string())),
                CERT_CACHE_TTL
            );
            assert_eq!(
                client.get_ttl_for_endpoint(&Endpoint::Ocsp("test".to_string())),
                CERT_CACHE_TTL
            );
        });
    }

    #[test]
    fn test_api_methods_compile() {
        with_us_client(|client| {
            assert_eq!(client.inner().region(), Region::US);
        });
    }

    #[test]
    fn test_extract_sequence_number() {
        with_us_client(|client| {
            let data_with_seqn = b"Product!STRING:0|Seqn!DEC:4\n## seqn = 12345\nwow|67890";
            let mime_data = b"Subject: test\nFrom: Test/1.0\n\n--boundary\nContent-Disposition: test\n\nProduct!STRING:0\n## seqn = 67890\ndata\n--boundary--";
            let data_no_seqn = b"Product!STRING:0|Seqn!DEC:4\nwow|12345";

            assert_eq!(client.extract_sequence_number(data_with_seqn), Some(12345));
            assert_eq!(client.extract_sequence_number(mime_data), Some(67890));
            assert_eq!(client.extract_sequence_number(data_no_seqn), None);
        });
    }
}