1use std::path::PathBuf;
37use std::time::Duration;
38use tracing::{debug, trace};
39
40use tact_client::{CdnEntry, HttpClient, ProtocolVersion, Region, VersionEntry};
41
42use crate::{Result, ensure_dir, get_cache_dir};
43
/// Freshness window for cached `versions` responses: five minutes.
const VERSIONS_CACHE_TTL: Duration = Duration::from_secs(300);

/// Freshness window shared by cached `cdns` and `bgdl` responses: thirty minutes.
const CDN_CACHE_TTL: Duration = Duration::from_secs(1800);
49
50#[derive(Debug, serde::Serialize, serde::Deserialize)]
52struct CacheMetadata {
53 timestamp: u64,
55 ttl_seconds: u64,
57 region: String,
59 protocol: String,
61 product: String,
63 endpoint: String,
65 sequence: Option<u64>,
67 response_size: usize,
69}
70
/// The TACT endpoints whose responses this module caches on disk.
#[derive(Debug, Clone, Copy, PartialEq)]
enum TactEndpoint {
    /// The `versions` endpoint; parsed into `VersionEntry` values.
    Versions,
    /// The `cdns` endpoint; parsed into `CdnEntry` values.
    Cdns,
    /// The `bgdl` endpoint; parsed into `BgdlEntry` values.
    Bgdl,
}
78
79impl TactEndpoint {
80 fn as_str(&self) -> &'static str {
81 match self {
82 Self::Versions => "versions",
83 Self::Cdns => "cdns",
84 Self::Bgdl => "bgdl",
85 }
86 }
87
88 fn ttl(&self) -> Duration {
89 match self {
90 Self::Versions => VERSIONS_CACHE_TTL,
91 Self::Cdns | Self::Bgdl => CDN_CACHE_TTL,
92 }
93 }
94}
95
96pub struct CachedTactClient {
98 client: HttpClient,
100 cache_dir: PathBuf,
102 enabled: bool,
104}
105
106impl CachedTactClient {
107 pub async fn new(region: Region, protocol: ProtocolVersion) -> Result<Self> {
109 let client = HttpClient::new(region, protocol)?;
110 let cache_dir = get_cache_dir()?.join("tact");
111 ensure_dir(&cache_dir).await?;
112
113 debug!(
114 "Initialized cached TACT client for region {:?}, protocol {:?}",
115 region, protocol
116 );
117
118 Ok(Self {
119 client,
120 cache_dir,
121 enabled: true,
122 })
123 }
124
125 pub async fn with_cache_dir(
127 region: Region,
128 protocol: ProtocolVersion,
129 cache_dir: PathBuf,
130 ) -> Result<Self> {
131 let client = HttpClient::new(region, protocol)?;
132 ensure_dir(&cache_dir).await?;
133
134 Ok(Self {
135 client,
136 cache_dir,
137 enabled: true,
138 })
139 }
140
141 pub async fn with_client(client: HttpClient) -> Result<Self> {
143 let cache_dir = get_cache_dir()?.join("tact");
144 ensure_dir(&cache_dir).await?;
145
146 Ok(Self {
147 client,
148 cache_dir,
149 enabled: true,
150 })
151 }
152
153 pub fn set_caching_enabled(&mut self, enabled: bool) {
155 self.enabled = enabled;
156 }
157
158 fn get_cache_path(
160 &self,
161 product: &str,
162 endpoint: TactEndpoint,
163 sequence: Option<u64>,
164 ) -> PathBuf {
165 let region = self.client.region().to_string();
166 let protocol = match self.client.version() {
167 ProtocolVersion::V1 => "v1",
168 ProtocolVersion::V2 => "v2",
169 };
170
171 let seq = sequence.unwrap_or(0);
172 let filename = format!("{}-{}.bpsv", endpoint.as_str(), seq);
173
174 self.cache_dir
175 .join(region)
176 .join(protocol)
177 .join(product)
178 .join(filename)
179 }
180
181 fn get_metadata_path(
183 &self,
184 product: &str,
185 endpoint: TactEndpoint,
186 sequence: Option<u64>,
187 ) -> PathBuf {
188 let mut path = self.get_cache_path(product, endpoint, sequence);
189 path.set_extension("meta");
190 path
191 }
192
193 fn extract_sequence_number(&self, data: &str) -> Option<u64> {
195 for line in data.lines() {
197 if line.starts_with("## seqn = ") {
198 if let Some(seqn_str) = line.strip_prefix("## seqn = ") {
199 if let Ok(seqn) = seqn_str.trim().parse::<u64>() {
200 return Some(seqn);
201 }
202 }
203 }
204 }
205 None
206 }
207
208 async fn find_cached_file(
210 &self,
211 product: &str,
212 endpoint: TactEndpoint,
213 ) -> Option<(PathBuf, u64)> {
214 if !self.enabled {
215 return None;
216 }
217
218 let region = self.client.region().to_string();
219 let protocol = match self.client.version() {
220 ProtocolVersion::V1 => "v1",
221 ProtocolVersion::V2 => "v2",
222 };
223
224 let cache_subdir = self.cache_dir.join(®ion).join(protocol).join(product);
225 if tokio::fs::metadata(&cache_subdir).await.is_err() {
226 return None;
227 }
228
229 let prefix = format!("{}-", endpoint.as_str());
230 let ttl = endpoint.ttl();
231 let now = std::time::SystemTime::now()
232 .duration_since(std::time::UNIX_EPOCH)
233 .unwrap()
234 .as_secs();
235
236 let mut best_file: Option<(PathBuf, u64)> = None;
237 let mut best_seqn: u64 = 0;
238
239 if let Ok(mut entries) = tokio::fs::read_dir(&cache_subdir).await {
241 while let Some(entry) = entries.next_entry().await.ok()? {
242 let path = entry.path();
243 if let Some(filename) = path.file_name().and_then(|n| n.to_str()) {
244 if filename.starts_with(&prefix) && filename.ends_with(".bpsv") {
246 if let Some(seqn_part) = filename
248 .strip_prefix(&prefix)
249 .and_then(|s| s.strip_suffix(".bpsv"))
250 {
251 if let Ok(seqn) = seqn_part.parse::<u64>() {
252 let meta_path = path.with_extension("meta");
254 if let Ok(metadata_str) =
255 tokio::fs::read_to_string(&meta_path).await
256 {
257 if let Ok(metadata) =
258 serde_json::from_str::<CacheMetadata>(&metadata_str)
259 {
260 if now.saturating_sub(metadata.timestamp) < ttl.as_secs()
261 && seqn > best_seqn
262 {
263 best_file = Some((path.clone(), seqn));
264 best_seqn = seqn;
265 }
266 }
267 }
268 }
269 }
270 }
271 }
272 }
273 }
274
275 best_file
276 }
277
278 async fn write_to_cache(
280 &self,
281 product: &str,
282 endpoint: TactEndpoint,
283 data: &str,
284 ) -> Result<()> {
285 if !self.enabled {
286 return Ok(());
287 }
288
289 let sequence = self.extract_sequence_number(data);
291
292 let cache_path = self.get_cache_path(product, endpoint, sequence);
293 let meta_path = self.get_metadata_path(product, endpoint, sequence);
294
295 if let Some(parent) = cache_path.parent() {
297 ensure_dir(parent).await?;
298 }
299
300 trace!(
302 "Writing {} bytes to TACT cache: {:?}",
303 data.len(),
304 cache_path
305 );
306 tokio::fs::write(&cache_path, data).await?;
307
308 let metadata = CacheMetadata {
310 timestamp: std::time::SystemTime::now()
311 .duration_since(std::time::UNIX_EPOCH)
312 .unwrap()
313 .as_secs(),
314 ttl_seconds: endpoint.ttl().as_secs(),
315 region: self.client.region().to_string(),
316 protocol: match self.client.version() {
317 ProtocolVersion::V1 => "v1".to_string(),
318 ProtocolVersion::V2 => "v2".to_string(),
319 },
320 product: product.to_string(),
321 endpoint: endpoint.as_str().to_string(),
322 sequence,
323 response_size: data.len(),
324 };
325
326 let metadata_json = serde_json::to_string_pretty(&metadata)?;
327 tokio::fs::write(&meta_path, metadata_json).await?;
328
329 Ok(())
330 }
331
332 async fn read_from_cache(&self, product: &str, endpoint: TactEndpoint) -> Result<String> {
334 if let Some((cache_path, _seqn)) = self.find_cached_file(product, endpoint).await {
335 trace!("Reading from TACT cache: {:?}", cache_path);
336 Ok(tokio::fs::read_to_string(&cache_path).await?)
337 } else {
338 Err(crate::Error::CacheEntryNotFound(format!(
339 "No valid cache for {}/{}/{}",
340 self.client.region(),
341 product,
342 endpoint.as_str()
343 )))
344 }
345 }
346
347 pub async fn get_versions(&self, product: &str) -> Result<reqwest::Response> {
349 Ok(self.client.get_versions(product).await?)
352 }
353
354 pub async fn get_versions_parsed(&self, product: &str) -> Result<Vec<VersionEntry>> {
356 let endpoint = TactEndpoint::Versions;
357
358 if self.enabled {
360 if let Ok(cached_data) = self.read_from_cache(product, endpoint).await {
361 debug!("Cache hit for TACT {}/{}", product, endpoint.as_str());
362 return Ok(tact_client::parse_versions(&cached_data)?);
364 }
365 }
366
367 debug!(
369 "Cache miss for TACT {}/{}, fetching from server",
370 product,
371 endpoint.as_str()
372 );
373 let response = self.client.get_versions(product).await?;
374 let text = response.text().await?;
375
376 if let Err(e) = self.write_to_cache(product, endpoint, &text).await {
378 debug!("Failed to write to TACT cache: {}", e);
379 }
380
381 Ok(tact_client::parse_versions(&text)?)
383 }
384
385 pub async fn get_cdns(&self, product: &str) -> Result<reqwest::Response> {
395 Ok(self.client.get_cdns(product).await?)
397 }
398
399 pub async fn get_cdns_parsed(&self, product: &str) -> Result<Vec<CdnEntry>> {
404 let endpoint = TactEndpoint::Cdns;
405
406 if self.enabled {
408 if let Ok(cached_data) = self.read_from_cache(product, endpoint).await {
409 debug!("Cache hit for TACT {}/{}", product, endpoint.as_str());
410 return Ok(tact_client::parse_cdns(&cached_data)?);
412 }
413 }
414
415 debug!(
417 "Cache miss for TACT {}/{}, fetching from server",
418 product,
419 endpoint.as_str()
420 );
421 let response = self.client.get_cdns(product).await?;
422 let text = response.text().await?;
423
424 if let Err(e) = self.write_to_cache(product, endpoint, &text).await {
426 debug!("Failed to write to TACT cache: {}", e);
427 }
428
429 Ok(tact_client::parse_cdns(&text)?)
431 }
432
433 pub async fn get_bgdl(&self, product: &str) -> Result<reqwest::Response> {
435 Ok(self.client.get_bgdl(product).await?)
437 }
438
439 pub async fn get_bgdl_parsed(
441 &self,
442 product: &str,
443 ) -> Result<Vec<tact_client::response_types::BgdlEntry>> {
444 let endpoint = TactEndpoint::Bgdl;
445
446 if self.enabled {
448 if let Ok(cached_data) = self.read_from_cache(product, endpoint).await {
449 debug!("Cache hit for TACT {}/{}", product, endpoint.as_str());
450 return Ok(tact_client::response_types::parse_bgdl(&cached_data)?);
452 }
453 }
454
455 debug!(
457 "Cache miss for TACT {}/{}, fetching from server",
458 product,
459 endpoint.as_str()
460 );
461 let response = self.client.get_bgdl(product).await?;
462 let text = response.text().await?;
463
464 if let Err(e) = self.write_to_cache(product, endpoint, &text).await {
466 debug!("Failed to write to TACT cache: {}", e);
467 }
468
469 Ok(tact_client::response_types::parse_bgdl(&text)?)
471 }
472
473 pub async fn get(&self, path: &str) -> Result<reqwest::Response> {
475 Ok(self.client.get(path).await?)
477 }
478
479 pub async fn download_file(
486 &self,
487 cdn_host: &str,
488 path: &str,
489 hash: &str,
490 ) -> Result<reqwest::Response> {
491 Ok(self.client.download_file(cdn_host, path, hash).await?)
492 }
493
494 pub fn inner(&self) -> &HttpClient {
496 &self.client
497 }
498
499 pub fn inner_mut(&mut self) -> &mut HttpClient {
501 &mut self.client
502 }
503
504 pub async fn clear_cache(&self) -> Result<()> {
506 debug!("Clearing all cached TACT responses");
507
508 let region = self.client.region().to_string();
509 let protocol = match self.client.version() {
510 ProtocolVersion::V1 => "v1",
511 ProtocolVersion::V2 => "v2",
512 };
513
514 let cache_subdir = self.cache_dir.join(region).join(protocol);
515 if tokio::fs::metadata(&cache_subdir).await.is_ok() {
516 clear_directory_recursively(&cache_subdir).await?;
517 }
518
519 Ok(())
520 }
521
522 pub async fn clear_expired(&self) -> Result<()> {
524 debug!("Clearing expired TACT cache entries");
525
526 let region = self.client.region().to_string();
527 let protocol = match self.client.version() {
528 ProtocolVersion::V1 => "v1",
529 ProtocolVersion::V2 => "v2",
530 };
531
532 let cache_subdir = self.cache_dir.join(region).join(protocol);
533 if tokio::fs::metadata(&cache_subdir).await.is_ok() {
534 clear_expired_in_directory(&cache_subdir).await?;
535 }
536
537 Ok(())
538 }
539}
540
541fn clear_directory_recursively(
543 dir: &PathBuf,
544) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send + '_>> {
545 Box::pin(async move {
546 let mut entries = tokio::fs::read_dir(dir).await?;
547 while let Some(entry) = entries.next_entry().await? {
548 let path = entry.path();
549 if let Ok(metadata) = tokio::fs::metadata(&path).await {
550 if metadata.is_dir() {
551 clear_directory_recursively(&path).await?;
552 } else {
553 tokio::fs::remove_file(&path).await?;
554 }
555 }
556 }
557 Ok(())
558 })
559}
560
561fn clear_expired_in_directory(
563 dir: &PathBuf,
564) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send + '_>> {
565 Box::pin(async move {
566 let mut entries = tokio::fs::read_dir(dir).await?;
567 let now = std::time::SystemTime::now()
568 .duration_since(std::time::UNIX_EPOCH)
569 .unwrap()
570 .as_secs();
571
572 while let Some(entry) = entries.next_entry().await? {
573 let path = entry.path();
574
575 if path.is_dir() {
576 clear_expired_in_directory(&path).await?;
577 } else if path.extension().and_then(|s| s.to_str()) == Some("meta") {
578 if let Ok(metadata_str) = tokio::fs::read_to_string(&path).await {
580 if let Ok(metadata) = serde_json::from_str::<CacheMetadata>(&metadata_str) {
581 if now.saturating_sub(metadata.timestamp) >= metadata.ttl_seconds {
582 let data_path = path.with_extension("bpsv");
584 let _ = tokio::fs::remove_file(&data_path).await;
585 let _ = tokio::fs::remove_file(&path).await;
586 trace!("Removed expired TACT cache entry: {:?}", data_path);
587 }
588 }
589 }
590 }
591 }
592
593 Ok(())
594 })
595}
596
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives `fut` to completion on a fresh Tokio runtime.
    fn block_on<F: std::future::Future>(fut: F) -> F::Output {
        tokio::runtime::Runtime::new().unwrap().block_on(fut)
    }

    /// Builds a client whose cache root lives in the system temp directory. The tests then
    /// neither depend on the user's cache directory being resolvable nor create anything
    /// inside it.
    async fn temp_client(region: Region, protocol: ProtocolVersion) -> CachedTactClient {
        let cache_dir = std::env::temp_dir().join("cached_tact_client_tests");
        CachedTactClient::with_cache_dir(region, protocol, cache_dir)
            .await
            .unwrap()
    }

    #[test]
    fn test_endpoint_properties() {
        assert_eq!(TactEndpoint::Versions.as_str(), "versions");
        assert_eq!(TactEndpoint::Cdns.as_str(), "cdns");
        assert_eq!(TactEndpoint::Bgdl.as_str(), "bgdl");

        assert_eq!(TactEndpoint::Versions.ttl(), VERSIONS_CACHE_TTL);
        assert_eq!(TactEndpoint::Cdns.ttl(), CDN_CACHE_TTL);
        assert_eq!(TactEndpoint::Bgdl.ttl(), CDN_CACHE_TTL);
    }

    #[test]
    fn test_sequence_number_extraction() {
        block_on(async {
            let client = temp_client(Region::US, ProtocolVersion::V1).await;

            let data_with_seqn = "Product!STRING:0|Seqn!DEC:4\n## seqn = 3020098\nwow|12345";
            assert_eq!(
                client.extract_sequence_number(data_with_seqn),
                Some(3020098)
            );

            let data_no_seqn = "Product!STRING:0|Seqn!DEC:4\nwow|12345";
            assert_eq!(client.extract_sequence_number(data_no_seqn), None);

            let data_bad_seqn = "## seqn = not_a_number\nwow|12345";
            assert_eq!(client.extract_sequence_number(data_bad_seqn), None);

            // A malformed seqn line is skipped in favour of a later valid one.
            let data_retry = "## seqn = oops\n## seqn = 7\nwow|12345";
            assert_eq!(client.extract_sequence_number(data_retry), Some(7));
        });
    }

    #[test]
    fn test_cache_path_generation() {
        block_on(async {
            let client = temp_client(Region::US, ProtocolVersion::V1).await;

            let path = client.get_cache_path("wow", TactEndpoint::Versions, Some(12345));
            assert!(path.ends_with("us/v1/wow/versions-12345.bpsv"));

            let path_no_seq = client.get_cache_path("d3", TactEndpoint::Cdns, None);
            assert!(path_no_seq.ends_with("us/v1/d3/cdns-0.bpsv"));
        });
    }

    #[test]
    fn test_api_methods_compile() {
        block_on(async {
            let client = temp_client(Region::EU, ProtocolVersion::V2).await;

            assert_eq!(client.inner().region(), Region::EU);
            assert_eq!(client.inner().version(), ProtocolVersion::V2);
        });
    }
}