use ngdp_cache::{cached_ribbit_client::CachedRibbitClient, cached_tact_client::CachedTactClient};
use ribbit_client::{Endpoint, Region};
use std::fmt;
use tact_client::error::Error as TactError;
use thiserror::Error;
use tracing::{debug, warn};
#[derive(Error, Debug)]
pub enum FallbackError {
#[error("Both Ribbit and TACT failed: Ribbit: {ribbit_error}, TACT: {tact_error}")]
BothFailed {
ribbit_error: String,
tact_error: String,
},
#[error("Failed to create clients: {0}")]
ClientCreation(String),
}
/// Client that queries Ribbit first and, for endpoints that have a TACT
/// equivalent, retries over TACT when the Ribbit request fails.
pub struct FallbackClient {
    /// Cached Ribbit client; tried first for every endpoint.
    ribbit_client: CachedRibbitClient,
    /// Cached TACT client; used only as the fallback for TACT-mappable endpoints.
    tact_client: CachedTactClient,
    /// Region this client was created for.
    region: Region,
    /// Last value passed to `set_caching_enabled`; `true` after construction.
    caching_enabled: bool,
}
impl FallbackClient {
    /// Creates a fallback client for `region`.
    ///
    /// A cached Ribbit client is built for `region`, and a cached TACT client
    /// (protocol `V2`) is built for the corresponding TACT region. The returned
    /// client reports `caching_enabled = true`; the underlying clients keep
    /// whatever caching default their own constructors use.
    ///
    /// # Errors
    ///
    /// Returns [`FallbackError::ClientCreation`] if either underlying client
    /// cannot be constructed. The message is prefixed with `Ribbit:` or `TACT:`.
    pub async fn new(region: Region) -> Result<Self, FallbackError> {
        let ribbit_client = CachedRibbitClient::new(region)
            .await
            .map_err(|e| FallbackError::ClientCreation(format!("Ribbit: {e}")))?;

        let tact_region = match region {
            Region::US => tact_client::Region::US,
            Region::EU => tact_client::Region::EU,
            Region::CN => tact_client::Region::CN,
            Region::KR => tact_client::Region::KR,
            Region::TW => tact_client::Region::TW,
            // SG is served through the US TACT endpoint.
            // NOTE(review): presumably because tact_client::Region has no SG
            // variant — confirm against tact_client.
            Region::SG => tact_client::Region::US,
        };

        let tact_client = CachedTactClient::new(tact_region, tact_client::ProtocolVersion::V2)
            .await
            .map_err(|e| FallbackError::ClientCreation(format!("TACT: {e}")))?;

        Ok(Self {
            ribbit_client,
            tact_client,
            region,
            caching_enabled: true,
        })
    }

    /// Enables or disables caching on this client and both underlying clients.
    pub fn set_caching_enabled(&mut self, enabled: bool) {
        self.caching_enabled = enabled;
        self.ribbit_client.set_caching_enabled(enabled);
        self.tact_client.set_caching_enabled(enabled);
    }

    /// Fetches `endpoint`, trying Ribbit first and TACT second.
    ///
    /// Behaviour by endpoint:
    /// - `Summary`, `Cert` and `Ocsp` have no TACT equivalent and are served by
    ///   Ribbit only.
    /// - Product endpoints and `Custom` paths fall back to TACT if Ribbit fails.
    ///
    /// A TACT result is wrapped in a `ribbit_client::Response` whose `raw`
    /// holds the UTF-8 bytes of the body, `data` holds the body text, and
    /// `mime_parts` is `None`.
    ///
    /// # Errors
    ///
    /// Returns [`FallbackError::BothFailed`] when every applicable source failed.
    pub async fn request(
        &self,
        endpoint: &Endpoint,
    ) -> Result<ribbit_client::Response, FallbackError> {
        let tact_endpoint = match Self::tact_path(endpoint) {
            Some(path) => path,
            None => return self.ribbit_request(endpoint).await,
        };

        match self.ribbit_client.request(endpoint).await {
            Ok(response) => {
                debug!("Successfully retrieved data from Ribbit for {:?}", endpoint);
                Ok(response)
            }
            Err(ribbit_err) => {
                warn!(
                    "Ribbit request failed for {:?}: {}, trying TACT fallback",
                    endpoint, ribbit_err
                );
                match self.tact_request(&tact_endpoint).await {
                    Ok(data) => {
                        debug!(
                            "Successfully retrieved data from TACT for {}",
                            tact_endpoint
                        );
                        Ok(ribbit_client::Response {
                            raw: data.as_bytes().to_vec(),
                            data: Some(data),
                            mime_parts: None,
                        })
                    }
                    Err(tact_err) => {
                        warn!(
                            "TACT request also failed for {}: {}",
                            tact_endpoint, tact_err
                        );
                        Err(FallbackError::BothFailed {
                            ribbit_error: ribbit_err.to_string(),
                            tact_error: tact_err.to_string(),
                        })
                    }
                }
            }
        }
    }

    /// Fetches `endpoint` via [`Self::request`] and parses it into `T`.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Self::request`]. A parse failure is reported as
    /// [`FallbackError::BothFailed`] with the parse error in `ribbit_error` and
    /// `tact_error` set to `"Not attempted"`.
    // NOTE(review): the parse failure may concern data that actually came from
    // TACT; a dedicated error variant would describe this more accurately.
    pub async fn request_typed<T: ribbit_client::TypedResponse>(
        &self,
        endpoint: &Endpoint,
    ) -> Result<T, FallbackError> {
        let response = self.request(endpoint).await?;
        T::from_response(&response).map_err(|e| FallbackError::BothFailed {
            ribbit_error: format!("Failed to parse response: {e}"),
            tact_error: "Not attempted".to_string(),
        })
    }

    /// Maps `endpoint` to the `product/kind` path understood by
    /// [`Self::tact_request`], or `None` for endpoints served only by Ribbit.
    fn tact_path(endpoint: &Endpoint) -> Option<String> {
        match endpoint {
            Endpoint::Summary | Endpoint::Cert(_) | Endpoint::Ocsp(_) => None,
            Endpoint::ProductVersions(product) => Some(format!("{product}/versions")),
            Endpoint::ProductCdns(product) => Some(format!("{product}/cdns")),
            Endpoint::ProductBgdl(product) => Some(format!("{product}/bgdl")),
            Endpoint::Custom(path) => Some(path.clone()),
        }
    }

    /// Ribbit-only request, for endpoints that have no TACT fallback.
    async fn ribbit_request(
        &self,
        endpoint: &Endpoint,
    ) -> Result<ribbit_client::Response, FallbackError> {
        self.ribbit_client
            .request(endpoint)
            .await
            .map_err(|e| FallbackError::BothFailed {
                ribbit_error: e.to_string(),
                tact_error: "Not applicable for this endpoint".to_string(),
            })
    }

    /// Fetches a `product/kind` path over TACT and returns the response body text.
    ///
    /// `kind` must be one of `versions`, `cdns` or `bgdl`. Any other shape is
    /// rejected with `TactError::InvalidManifest`, including a path with zero
    /// or more than one `/`.
    async fn tact_request(&self, endpoint: &str) -> Result<String, Box<dyn std::error::Error>> {
        // Exactly one '/' is accepted, which matches the old "split into two parts" check.
        let (product, endpoint_type) = match endpoint.split_once('/') {
            Some((product, kind)) if !kind.contains('/') => (product, kind),
            _ => {
                return Err(Box::new(TactError::InvalidManifest {
                    line: 0,
                    reason: format!("Invalid endpoint format: {endpoint}"),
                }));
            }
        };

        let response = match endpoint_type {
            "versions" => self.tact_client.get_versions(product).await?,
            "cdns" => self.tact_client.get_cdns(product).await?,
            "bgdl" => self.tact_client.get_bgdl(product).await?,
            _ => {
                return Err(Box::new(TactError::InvalidManifest {
                    line: 0,
                    reason: format!("Unknown endpoint type: {endpoint_type}"),
                }));
            }
        };
        Ok(response.text().await?)
    }

    /// Removes expired entries from both the Ribbit and TACT caches.
    ///
    /// Both caches are always attempted, so a failure in one does not leave
    /// the other untouched.
    ///
    /// # Errors
    ///
    /// Returns the Ribbit error if the Ribbit clear failed, otherwise the TACT
    /// error. If both fail, the TACT error is logged.
    pub async fn clear_expired(&self) -> Result<(), Box<dyn std::error::Error>> {
        let ribbit_result = self.ribbit_client.clear_expired().await;
        let tact_result = self.tact_client.clear_expired().await;
        if let (Err(_), Err(tact_err)) = (&ribbit_result, &tact_result) {
            warn!("TACT cache expiry cleanup also failed: {tact_err}");
        }
        ribbit_result?;
        tact_result?;
        Ok(())
    }

    /// Clears both the Ribbit and TACT caches.
    ///
    /// Both caches are always attempted, so a failure in one does not leave
    /// the other populated.
    ///
    /// # Errors
    ///
    /// Returns the Ribbit error if the Ribbit clear failed, otherwise the TACT
    /// error. If both fail, the TACT error is logged.
    pub async fn clear_cache(&self) -> Result<(), Box<dyn std::error::Error>> {
        let ribbit_result = self.ribbit_client.clear_cache().await;
        let tact_result = self.tact_client.clear_cache().await;
        if let (Err(_), Err(tact_err)) = (&ribbit_result, &tact_result) {
            warn!("TACT cache clear also failed: {tact_err}");
        }
        ribbit_result?;
        tact_result?;
        Ok(())
    }
}
/// Debug output lists only the region and caching flag; the wrapped Ribbit
/// and TACT clients are intentionally left out.
impl fmt::Debug for FallbackClient {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            region,
            caching_enabled,
            ..
        } = self;
        let mut builder = f.debug_struct("FallbackClient");
        builder.field("region", region);
        builder.field("caching_enabled", caching_enabled);
        builder.finish()
    }
}