ngdp_client/
fallback_client.rs1use ngdp_cache::{cached_ribbit_client::CachedRibbitClient, cached_tact_client::CachedTactClient};
8use ribbit_client::{Endpoint, Region};
9use std::fmt;
10use tact_client::error::Error as TactError;
11use thiserror::Error;
12use tracing::{debug, warn};
13
14#[derive(Error, Debug)]
16pub enum FallbackError {
17 #[error("Both Ribbit and TACT failed: Ribbit: {ribbit_error}, TACT: {tact_error}")]
19 BothFailed {
20 ribbit_error: String,
21 tact_error: String,
22 },
23 #[error("Failed to create clients: {0}")]
25 ClientCreation(String),
26}
27
28pub struct FallbackClient {
30 ribbit_client: CachedRibbitClient,
31 tact_client: CachedTactClient,
32 region: Region,
33 caching_enabled: bool,
34}
35
36impl FallbackClient {
37 pub async fn new(region: Region) -> Result<Self, FallbackError> {
39 let ribbit_client = CachedRibbitClient::new(region)
40 .await
41 .map_err(|e| FallbackError::ClientCreation(format!("Ribbit: {e}")))?;
42
43 let tact_region = match region {
45 Region::US => tact_client::Region::US,
46 Region::EU => tact_client::Region::EU,
47 Region::CN => tact_client::Region::CN,
48 Region::KR => tact_client::Region::KR,
49 Region::TW => tact_client::Region::TW,
50 Region::SG => {
51 tact_client::Region::US
53 }
54 };
55
56 let tact_client = CachedTactClient::new(tact_region, tact_client::ProtocolVersion::V2)
57 .await
58 .map_err(|e| FallbackError::ClientCreation(format!("TACT: {e}")))?;
59
60 Ok(Self {
61 ribbit_client,
62 tact_client,
63 region,
64 caching_enabled: true,
65 })
66 }
67
68 pub fn set_caching_enabled(&mut self, enabled: bool) {
70 self.caching_enabled = enabled;
71 self.ribbit_client.set_caching_enabled(enabled);
72 self.tact_client.set_caching_enabled(enabled);
73 }
74
75 pub async fn request(
79 &self,
80 endpoint: &Endpoint,
81 ) -> Result<ribbit_client::Response, FallbackError> {
82 let tact_endpoint = match endpoint {
84 Endpoint::Summary => {
85 return self.ribbit_request(endpoint).await;
87 }
88 Endpoint::ProductVersions(product) => format!("{product}/versions"),
89 Endpoint::ProductCdns(product) => format!("{product}/cdns"),
90 Endpoint::ProductBgdl(product) => format!("{product}/bgdl"),
91 Endpoint::Cert(_) | Endpoint::Ocsp(_) => {
92 return self.ribbit_request(endpoint).await;
94 }
95 Endpoint::Custom(path) => path.clone(),
96 };
97
98 match self.ribbit_client.request(endpoint).await {
100 Ok(response) => {
101 debug!("Successfully retrieved data from Ribbit for {:?}", endpoint);
102 Ok(response)
103 }
104 Err(ribbit_err) => {
105 warn!(
106 "Ribbit request failed for {:?}: {}, trying TACT fallback",
107 endpoint, ribbit_err
108 );
109
110 match self.tact_request(&tact_endpoint).await {
112 Ok(data) => {
113 debug!(
114 "Successfully retrieved data from TACT for {}",
115 tact_endpoint
116 );
117 Ok(ribbit_client::Response {
119 raw: data.as_bytes().to_vec(),
120 data: Some(data),
121 mime_parts: None,
122 })
123 }
124 Err(tact_err) => {
125 warn!(
126 "TACT request also failed for {}: {}",
127 tact_endpoint, tact_err
128 );
129 Err(FallbackError::BothFailed {
130 ribbit_error: ribbit_err.to_string(),
131 tact_error: tact_err.to_string(),
132 })
133 }
134 }
135 }
136 }
137 }
138
139 pub async fn request_typed<T: ribbit_client::TypedResponse>(
141 &self,
142 endpoint: &Endpoint,
143 ) -> Result<T, FallbackError> {
144 let response = self.request(endpoint).await?;
145 T::from_response(&response).map_err(|e| FallbackError::BothFailed {
146 ribbit_error: format!("Failed to parse response: {e}"),
147 tact_error: "Not attempted".to_string(),
148 })
149 }
150
151 async fn ribbit_request(
153 &self,
154 endpoint: &Endpoint,
155 ) -> Result<ribbit_client::Response, FallbackError> {
156 self.ribbit_client
157 .request(endpoint)
158 .await
159 .map_err(|e| FallbackError::BothFailed {
160 ribbit_error: e.to_string(),
161 tact_error: "Not applicable for this endpoint".to_string(),
162 })
163 }
164
165 async fn tact_request(&self, endpoint: &str) -> Result<String, Box<dyn std::error::Error>> {
167 let parts: Vec<&str> = endpoint.split('/').collect();
169 if parts.len() != 2 {
170 return Err(Box::new(TactError::InvalidManifest {
171 line: 0,
172 reason: format!("Invalid endpoint format: {endpoint}"),
173 }));
174 }
175
176 let product = parts[0];
177 let endpoint_type = parts[1];
178
179 let response = match endpoint_type {
181 "versions" => self.tact_client.get_versions(product).await?,
182 "cdns" => self.tact_client.get_cdns(product).await?,
183 "bgdl" => self.tact_client.get_bgdl(product).await?,
184 _ => {
185 return Err(Box::new(TactError::InvalidManifest {
186 line: 0,
187 reason: format!("Unknown endpoint type: {endpoint_type}"),
188 }));
189 }
190 };
191
192 Ok(response.text().await?)
193 }
194
195 pub async fn clear_expired(&self) -> Result<(), Box<dyn std::error::Error>> {
197 self.ribbit_client.clear_expired().await?;
198 self.tact_client.clear_expired().await?;
199 Ok(())
200 }
201
202 pub async fn clear_cache(&self) -> Result<(), Box<dyn std::error::Error>> {
204 self.ribbit_client.clear_cache().await?;
205 self.tact_client.clear_cache().await?;
206 Ok(())
207 }
208}
209
210impl fmt::Debug for FallbackClient {
211 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
212 f.debug_struct("FallbackClient")
213 .field("region", &self.region)
214 .field("caching_enabled", &self.caching_enabled)
215 .finish()
216 }
217}