1use std::fmt::Display;
2use std::sync::RwLock;
3use std::time::Duration;
4use std::time::SystemTime;
5
6use reqwest::header::HeaderMap;
7use reqwest::header::HeaderValue;
8use reqwest::header::InvalidHeaderValue;
9use reqwest::header::IF_NONE_MATCH;
10use reqwest_middleware::ClientBuilder;
11use reqwest_middleware::ClientWithMiddleware;
12use reqwest_retry::policies::ExponentialBackoff;
13use reqwest_retry::RetryDecision;
14use reqwest_retry::RetryPolicy;
15use reqwest_retry::RetryTransientMiddleware;
16
17#[derive(Debug)]
18pub struct SupergraphFetcher<AsyncOrSync> {
19 client: SupergraphFetcherAsyncOrSyncClient,
20 endpoint: String,
21 etag: RwLock<Option<HeaderValue>>,
22 state: std::marker::PhantomData<AsyncOrSync>,
23}
24
/// Type-level marker for a fetcher whose `fetch_supergraph` is an `async fn`.
#[derive(Debug)]
pub struct SupergraphFetcherAsyncState;

/// Type-level marker for a fetcher whose `fetch_supergraph` blocks the calling thread.
#[derive(Debug)]
pub struct SupergraphFetcherSyncState;
29
30#[derive(Debug)]
31enum SupergraphFetcherAsyncOrSyncClient {
32 Async {
33 reqwest_client: ClientWithMiddleware,
34 },
35 Sync {
36 reqwest_client: reqwest::blocking::Client,
37 retry_policy: ExponentialBackoff,
38 },
39}
40
41pub enum SupergraphFetcherError {
42 FetcherCreationError(reqwest::Error),
43 NetworkError(reqwest_middleware::Error),
44 NetworkResponseError(reqwest::Error),
45 Lock(String),
46 InvalidKey(InvalidHeaderValue),
47}
48
49impl Display for SupergraphFetcherError {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 match self {
52 SupergraphFetcherError::FetcherCreationError(e) => {
53 write!(f, "Creating fetcher failed: {}", e)
54 }
55 SupergraphFetcherError::NetworkError(e) => write!(f, "Network error: {}", e),
56 SupergraphFetcherError::NetworkResponseError(e) => {
57 write!(f, "Network response error: {}", e)
58 }
59 SupergraphFetcherError::Lock(e) => write!(f, "Lock error: {}", e),
60 SupergraphFetcherError::InvalidKey(e) => write!(f, "Invalid CDN key: {}", e),
61 }
62 }
63}
64
65fn prepare_client_config(
66 mut endpoint: String,
67 key: &str,
68 retry_count: u32,
69) -> Result<(String, HeaderMap, ExponentialBackoff), SupergraphFetcherError> {
70 if !endpoint.ends_with("/supergraph") {
71 if endpoint.ends_with("/") {
72 endpoint.push_str("supergraph");
73 } else {
74 endpoint.push_str("/supergraph");
75 }
76 }
77
78 let mut headers = HeaderMap::new();
79 let mut cdn_key_header =
80 HeaderValue::from_str(key).map_err(SupergraphFetcherError::InvalidKey)?;
81 cdn_key_header.set_sensitive(true);
82 headers.insert("X-Hive-CDN-Key", cdn_key_header);
83
84 let retry_policy = ExponentialBackoff::builder().build_with_max_retries(retry_count);
85
86 Ok((endpoint, headers, retry_policy))
87}
88
89impl SupergraphFetcher<SupergraphFetcherSyncState> {
90 #[allow(clippy::too_many_arguments)]
91 pub fn try_new_sync(
92 endpoint: String,
93 key: &str,
94 user_agent: String,
95 connect_timeout: Duration,
96 request_timeout: Duration,
97 accept_invalid_certs: bool,
98 retry_count: u32,
99 ) -> Result<Self, SupergraphFetcherError> {
100 let (endpoint, headers, retry_policy) = prepare_client_config(endpoint, key, retry_count)?;
101
102 Ok(Self {
103 client: SupergraphFetcherAsyncOrSyncClient::Sync {
104 reqwest_client: reqwest::blocking::Client::builder()
105 .danger_accept_invalid_certs(accept_invalid_certs)
106 .connect_timeout(connect_timeout)
107 .timeout(request_timeout)
108 .user_agent(user_agent)
109 .default_headers(headers)
110 .build()
111 .map_err(SupergraphFetcherError::FetcherCreationError)?,
112 retry_policy,
113 },
114 endpoint,
115 etag: RwLock::new(None),
116 state: std::marker::PhantomData,
117 })
118 }
119
120 pub fn fetch_supergraph(&self) -> Result<Option<String>, SupergraphFetcherError> {
121 let request_start_time = SystemTime::now();
122 let mut n_past_retries = 0;
124 let (reqwest_client, retry_policy) = match &self.client {
125 SupergraphFetcherAsyncOrSyncClient::Sync {
126 reqwest_client,
127 retry_policy,
128 } => (reqwest_client, retry_policy),
129 _ => unreachable!(),
130 };
131 let resp = loop {
132 let mut req = reqwest_client.get(&self.endpoint);
133 let etag = self.get_latest_etag()?;
134 if let Some(etag) = etag {
135 req = req.header(IF_NONE_MATCH, etag);
136 }
137 let response = req.send();
138
139 match response {
140 Ok(resp) => break resp,
141 Err(e) => match retry_policy.should_retry(request_start_time, n_past_retries) {
142 RetryDecision::DoNotRetry => {
143 return Err(SupergraphFetcherError::NetworkError(
144 reqwest_middleware::Error::Reqwest(e),
145 ));
146 }
147 RetryDecision::Retry { execute_after } => {
148 n_past_retries += 1;
149 if let Ok(duration) = execute_after.elapsed() {
150 std::thread::sleep(duration);
151 }
152 }
153 },
154 }
155 };
156
157 if resp.status().as_u16() == 304 {
158 return Ok(None);
159 }
160
161 let etag = resp.headers().get("etag");
162 self.update_latest_etag(etag)?;
163
164 let text = resp
165 .text()
166 .map_err(SupergraphFetcherError::NetworkResponseError)?;
167
168 Ok(Some(text))
169 }
170}
171
172impl SupergraphFetcher<SupergraphFetcherAsyncState> {
173 #[allow(clippy::too_many_arguments)]
174 pub fn try_new_async(
175 endpoint: String,
176 key: &str,
177 user_agent: String,
178 connect_timeout: Duration,
179 request_timeout: Duration,
180 accept_invalid_certs: bool,
181 retry_count: u32,
182 ) -> Result<Self, SupergraphFetcherError> {
183 let (endpoint, headers, retry_policy) = prepare_client_config(endpoint, key, retry_count)?;
184
185 let reqwest_agent = reqwest::Client::builder()
186 .danger_accept_invalid_certs(accept_invalid_certs)
187 .connect_timeout(connect_timeout)
188 .timeout(request_timeout)
189 .default_headers(headers)
190 .user_agent(user_agent)
191 .build()
192 .map_err(SupergraphFetcherError::FetcherCreationError)?;
193 let reqwest_client = ClientBuilder::new(reqwest_agent)
194 .with(RetryTransientMiddleware::new_with_policy(retry_policy))
195 .build();
196
197 Ok(Self {
198 client: SupergraphFetcherAsyncOrSyncClient::Async { reqwest_client },
199 endpoint,
200 etag: RwLock::new(None),
201 state: std::marker::PhantomData,
202 })
203 }
204 pub async fn fetch_supergraph(&self) -> Result<Option<String>, SupergraphFetcherError> {
205 let reqwest_client = match &self.client {
206 SupergraphFetcherAsyncOrSyncClient::Async { reqwest_client } => reqwest_client,
207 _ => unreachable!(),
208 };
209 let mut req = reqwest_client.get(&self.endpoint);
210 let etag = self.get_latest_etag()?;
211 if let Some(etag) = etag {
212 req = req.header(IF_NONE_MATCH, etag);
213 }
214
215 let resp = req
216 .send()
217 .await
218 .map_err(SupergraphFetcherError::NetworkError)?;
219
220 if resp.status().as_u16() == 304 {
221 return Ok(None);
222 }
223
224 let etag = resp.headers().get("etag");
225 self.update_latest_etag(etag)?;
226
227 let text = resp
228 .text()
229 .await
230 .map_err(SupergraphFetcherError::NetworkResponseError)?;
231
232 Ok(Some(text))
233 }
234}
235
236impl<AsyncOrSync> SupergraphFetcher<AsyncOrSync> {
237 fn get_latest_etag(&self) -> Result<Option<HeaderValue>, SupergraphFetcherError> {
238 let guard: std::sync::RwLockReadGuard<'_, Option<HeaderValue>> =
239 self.etag.try_read().map_err(|e| {
240 SupergraphFetcherError::Lock(format!("Failed to read the etag record: {:?}", e))
241 })?;
242
243 Ok(guard.clone())
244 }
245
246 fn update_latest_etag(&self, etag: Option<&HeaderValue>) -> Result<(), SupergraphFetcherError> {
247 let mut guard: std::sync::RwLockWriteGuard<'_, Option<HeaderValue>> =
248 self.etag.try_write().map_err(|e| {
249 SupergraphFetcherError::Lock(format!("Failed to update the etag record: {:?}", e))
250 })?;
251
252 if let Some(etag_value) = etag {
253 *guard = Some(etag_value.clone());
254 } else {
255 *guard = None;
256 }
257
258 Ok(())
259 }
260}