1use std::{
2 collections::{BTreeMap, HashMap},
3 sync::Arc,
4 time::Duration,
5};
6
7use dashmap::DashMap;
8use hive_router_config::{
9 override_subgraph_urls::UrlOrExpression, traffic_shaping::DurationOrExpression,
10 HiveRouterConfig,
11};
12use hive_router_internal::expressions::vrl::core::Value as VrlValue;
13use hive_router_internal::expressions::{CompileExpression, DurationOrProgram, ExecutableProgram};
14use hive_router_internal::{
15 expressions::vrl::compiler::Program as VrlProgram, telemetry::TelemetryContext,
16};
17use http::Uri;
18use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
19use hyper_util::{
20 client::legacy::{connect::HttpConnector, Client},
21 rt::{TokioExecutor, TokioTimer},
22};
23use tokio::sync::{OnceCell, Semaphore};
24
25use crate::{
26 execution::client_request_details::ClientRequestDetails,
27 executors::{
28 common::{SubgraphExecutionRequest, SubgraphExecutor, SubgraphExecutorBoxedArc},
29 dedupe::ABuildHasher,
30 error::SubgraphExecutorError,
31 http::{HTTPSubgraphExecutor, HttpClient, HttpResponse},
32 },
33 response::subgraph_response::SubgraphResponse,
34};
35
/// Name of a subgraph; the key of every per-subgraph map in this file.
type SubgraphName = String;
/// Subgraph endpoint URL in string form (parsed into a `Uri` when an executor is registered).
type SubgraphEndpoint = String;
/// subgraph name -> (endpoint URL -> executor).
///
/// A subgraph can own several executors when its endpoint is computed per request by an
/// expression: one executor is cached for every distinct endpoint the expression yields.
type ExecutorsBySubgraphMap =
    DashMap<SubgraphName, DashMap<SubgraphEndpoint, SubgraphExecutorBoxedArc>>;
/// subgraph name -> static endpoint (the override URL if one is configured, otherwise the
/// original endpoint).
type StaticEndpointsBySubgraphMap = DashMap<SubgraphName, SubgraphEndpoint>;
/// subgraph name -> compiled VRL program that computes the endpoint for each request.
type ExpressionEndpointsBySubgraphMap = HashMap<SubgraphName, VrlProgram>;
/// subgraph name -> compiled request timeout (static duration or VRL program).
type TimeoutsBySubgraph = DashMap<SubgraphName, DurationOrProgram>;
43
/// Effective per-subgraph settings.
///
/// Values start from `traffic_shaping.all` and are then overridden by whatever
/// `traffic_shaping.subgraphs.<name>` sets for the subgraph.
struct ResolvedSubgraphConfig<'a> {
    /// HTTP client to use. This is the shared client unless the subgraph configures a
    /// `pool_idle_timeout` different from the global one, in which case a dedicated
    /// client is built for it.
    client: Arc<HttpClient>,
    /// Request timeout configuration that applies to the subgraph.
    // NOTE(review): this is populated when the config is resolved, but `register_executor`
    // never reads it in this file. Timeouts are compiled separately in
    // `register_subgraph_timeout`. Confirm whether the field is still needed, since it will
    // likely trigger a `dead_code` warning.
    timeout_config: &'a DurationOrExpression,
    /// Passed to `HTTPSubgraphExecutor::new`; presumably toggles request deduplication
    /// for the subgraph.
    dedupe_enabled: bool,
}
49
/// In-flight request map shared by every `HTTPSubgraphExecutor` created by
/// `SubgraphExecutorMap`.
///
/// Keys are `u64`s (presumably a hash of the outgoing request; confirm in the dedupe code).
/// Each `OnceCell` is filled once with an `(HttpResponse, u64)` pair, whose second element
/// is not interpreted in this file. The map uses `ABuildHasher` instead of the default
/// hasher.
pub type InflightRequestsMap = Arc<DashMap<u64, Arc<OnceCell<(HttpResponse, u64)>>, ABuildHasher>>;
51
/// Registry of HTTP executors for every subgraph.
///
/// Besides the executors themselves, it holds the per-subgraph state needed to route a
/// request: endpoints, endpoint expressions, timeouts, and per-origin connection limits.
pub struct SubgraphExecutorMap {
    /// Cached executors, keyed by subgraph name and then by endpoint URL.
    executors_by_subgraph: ExecutorsBySubgraphMap,
    /// Static endpoint per subgraph. It is also passed as `default` to endpoint expressions.
    static_endpoints_by_subgraph: StaticEndpointsBySubgraphMap,
    /// Endpoint expressions for subgraphs whose URL is computed per request.
    expression_endpoints_by_subgraph: ExpressionEndpointsBySubgraphMap,
    /// Compiled request timeout per subgraph.
    timeouts_by_subgraph: TimeoutsBySubgraph,
    /// Global request timeout. `from_http_endpoint_map` compiles it from
    /// `traffic_shaping.all.request_timeout`. Its resolved value is passed as `default`
    /// to subgraph timeout expressions.
    global_timeout: DurationOrProgram,
    /// Router configuration; traffic-shaping settings are read from here.
    config: Arc<HiveRouterConfig>,
    /// Shared HTTP client, used unless a subgraph needs its own pool settings.
    client: Arc<HttpClient>,
    /// One semaphore per `scheme://host:port` origin, each created with
    /// `max_connections_per_host` permits.
    semaphores_by_origin: DashMap<String, Arc<Semaphore>>,
    /// Copied from `traffic_shaping.max_connections_per_host`.
    max_connections_per_host: usize,
    /// In-flight request map handed to every executor.
    in_flight_requests: InflightRequestsMap,
    /// Telemetry context handed to every executor.
    telemetry_context: Arc<TelemetryContext>,
}
69
70fn build_https_executor() -> Result<HttpsConnector<HttpConnector>, SubgraphExecutorError> {
71 HttpsConnectorBuilder::new()
72 .with_native_roots()
73 .map_err(|e| SubgraphExecutorError::NativeTlsCertificatesError(e.to_string()))
74 .map(|b| b.https_or_http().enable_http1().enable_http2().build())
75}
76
77impl SubgraphExecutorMap {
78 pub fn new(
79 config: Arc<HiveRouterConfig>,
80 global_timeout: DurationOrProgram,
81 telemetry_context: Arc<TelemetryContext>,
82 ) -> Result<Self, SubgraphExecutorError> {
83 let client: HttpClient = Client::builder(TokioExecutor::new())
84 .pool_timer(TokioTimer::new())
85 .pool_idle_timeout(config.traffic_shaping.all.pool_idle_timeout)
86 .pool_max_idle_per_host(config.traffic_shaping.max_connections_per_host)
87 .build(build_https_executor()?);
88
89 let max_connections_per_host = config.traffic_shaping.max_connections_per_host;
90
91 Ok(SubgraphExecutorMap {
92 executors_by_subgraph: Default::default(),
93 static_endpoints_by_subgraph: Default::default(),
94 expression_endpoints_by_subgraph: Default::default(),
95 config,
96 client: Arc::new(client),
97 semaphores_by_origin: Default::default(),
98 max_connections_per_host,
99 in_flight_requests: Arc::new(DashMap::with_hasher(ABuildHasher::default())),
100 timeouts_by_subgraph: Default::default(),
101 global_timeout,
102 telemetry_context,
103 })
104 }
105
106 pub fn from_http_endpoint_map(
107 subgraph_endpoint_map: &HashMap<SubgraphName, String>,
108 config: Arc<HiveRouterConfig>,
109 telemetry_context: Arc<TelemetryContext>,
110 ) -> Result<Self, SubgraphExecutorError> {
111 let global_timeout = DurationOrProgram::compile(
112 &config.traffic_shaping.all.request_timeout,
113 None,
114 )
115 .map_err(|err| {
116 SubgraphExecutorError::RequestTimeoutExpressionBuild("all".to_string(), err.diagnostics)
117 })?;
118 let mut subgraph_executor_map =
119 SubgraphExecutorMap::new(config.clone(), global_timeout, telemetry_context)?;
120
121 for (subgraph_name, original_endpoint_str) in subgraph_endpoint_map.iter() {
122 let endpoint_config = config
123 .override_subgraph_urls
124 .get_subgraph_url(subgraph_name);
125
126 let endpoint_str = match endpoint_config {
127 Some(UrlOrExpression::Url(url)) => url.clone(),
128 Some(UrlOrExpression::Expression { expression }) => {
129 subgraph_executor_map
130 .register_endpoint_expression(subgraph_name, expression)?;
131 original_endpoint_str.clone()
132 }
133 None => original_endpoint_str.clone(),
134 };
135
136 subgraph_executor_map.register_static_endpoint(subgraph_name, &endpoint_str);
137 subgraph_executor_map.register_executor(subgraph_name, &endpoint_str)?;
138 subgraph_executor_map.register_subgraph_timeout(subgraph_name)?;
139 }
140
141 Ok(subgraph_executor_map)
142 }
143
144 pub async fn execute<'exec>(
145 &self,
146 subgraph_name: &str,
147 execution_request: SubgraphExecutionRequest<'exec>,
148 client_request: &ClientRequestDetails<'exec>,
149 ) -> Result<SubgraphResponse<'exec>, SubgraphExecutorError> {
150 let executor = self.get_or_create_executor(subgraph_name, client_request)?;
151
152 let timeout = self
153 .timeouts_by_subgraph
154 .get(subgraph_name)
155 .map(|t| {
156 let global_timeout_duration =
157 resolve_timeout(&self.global_timeout, client_request, None, "all")?;
158 resolve_timeout(
159 t.value(),
160 client_request,
161 Some(global_timeout_duration),
162 subgraph_name,
163 )
164 })
165 .transpose()?;
166
167 executor.execute(execution_request, timeout).await
168 }
169
170 fn get_or_create_executor(
174 &self,
175 subgraph_name: &str,
176 client_request: &ClientRequestDetails<'_>,
177 ) -> Result<SubgraphExecutorBoxedArc, SubgraphExecutorError> {
178 self.expression_endpoints_by_subgraph
179 .get(subgraph_name)
180 .map(|expression| {
181 self.get_or_create_executor_from_expression(
182 subgraph_name,
183 expression,
184 client_request,
185 )
186 })
187 .unwrap_or_else(|| {
188 self.get_executor_from_static_endpoint(subgraph_name)
189 .ok_or_else(|| {
190 SubgraphExecutorError::StaticEndpointNotFound(subgraph_name.to_string())
191 })
192 })
193 }
194
195 fn get_or_create_executor_from_expression(
200 &self,
201 subgraph_name: &str,
202 expression: &VrlProgram,
203 client_request: &ClientRequestDetails<'_>,
204 ) -> Result<SubgraphExecutorBoxedArc, SubgraphExecutorError> {
205 let original_url_value = VrlValue::Bytes(
206 self.static_endpoints_by_subgraph
207 .get(subgraph_name)
208 .map(|endpoint| endpoint.value().clone())
209 .ok_or_else(|| {
210 SubgraphExecutorError::StaticEndpointNotFound(subgraph_name.to_string())
211 })?
212 .into(),
213 );
214
215 let value = VrlValue::Object(BTreeMap::from([
216 ("request".into(), client_request.into()),
217 ("default".into(), original_url_value),
218 ]));
219
220 let endpoint_result = expression.execute(value).map_err(|err| {
222 SubgraphExecutorError::EndpointExpressionResolutionFailure(
223 subgraph_name.to_string(),
224 err.to_string(),
225 )
226 })?;
227
228 let endpoint_str = match endpoint_result.as_str() {
229 Some(s) => Ok(s.to_string()),
230 None => Err(SubgraphExecutorError::EndpointExpressionWrongType(
231 subgraph_name.to_string(),
232 )),
233 }?;
234
235 if let Some(executor) = self.get_executor_from_endpoint(subgraph_name, &endpoint_str) {
237 return Ok(executor);
238 }
239
240 self.register_executor(subgraph_name, &endpoint_str)
242 }
243
244 fn get_executor_from_static_endpoint(
246 &self,
247 subgraph_name: &str,
248 ) -> Option<SubgraphExecutorBoxedArc> {
249 let endpoint_ref = self.static_endpoints_by_subgraph.get(subgraph_name)?;
250 let endpoint_str = endpoint_ref.value();
251 self.get_executor_from_endpoint(subgraph_name, endpoint_str)
252 }
253
254 #[inline]
256 fn get_executor_from_endpoint(
257 &self,
258 subgraph_name: &str,
259 endpoint_str: &str,
260 ) -> Option<SubgraphExecutorBoxedArc> {
261 self.executors_by_subgraph
262 .get(subgraph_name)
263 .and_then(|endpoints| endpoints.get(endpoint_str).map(|e| e.clone()))
264 }
265
266 fn register_endpoint_expression(
269 &mut self,
270 subgraph_name: &str,
271 expression: &str,
272 ) -> Result<(), SubgraphExecutorError> {
273 let program = expression.compile_expression(None).map_err(|err| {
274 SubgraphExecutorError::EndpointExpressionBuild(
275 subgraph_name.to_string(),
276 err.diagnostics,
277 )
278 })?;
279 self.expression_endpoints_by_subgraph
280 .insert(subgraph_name.to_string(), program);
281
282 Ok(())
283 }
284
285 fn register_static_endpoint(&self, subgraph_name: &str, endpoint_str: &str) {
289 self.static_endpoints_by_subgraph
290 .insert(subgraph_name.to_string(), endpoint_str.to_string());
291 }
292
293 fn register_executor(
296 &self,
297 subgraph_name: &str,
298 endpoint_str: &str,
299 ) -> Result<SubgraphExecutorBoxedArc, SubgraphExecutorError> {
300 let endpoint_uri = endpoint_str.parse::<Uri>().map_err(|e| {
301 SubgraphExecutorError::EndpointParseFailure(endpoint_str.to_string(), e.to_string())
302 })?;
303
304 let origin = format!(
305 "{}://{}:{}",
306 endpoint_uri.scheme_str().unwrap_or("http"),
307 endpoint_uri.host().unwrap_or(""),
308 endpoint_uri.port_u16().unwrap_or_else(|| {
309 if endpoint_uri.scheme_str() == Some("https") {
310 443
311 } else {
312 80
313 }
314 })
315 );
316
317 let semaphore = self
318 .semaphores_by_origin
319 .entry(origin)
320 .or_insert_with(|| Arc::new(Semaphore::new(self.max_connections_per_host)))
321 .clone();
322
323 let subgraph_config = self.resolve_subgraph_config(subgraph_name)?;
324
325 let executor = HTTPSubgraphExecutor::new(
326 subgraph_name.to_string(),
327 endpoint_uri,
328 subgraph_config.client,
329 semaphore,
330 subgraph_config.dedupe_enabled,
331 self.in_flight_requests.clone(),
332 self.telemetry_context.clone(),
333 );
334
335 let executor_arc = executor.to_boxed_arc();
336
337 self.executors_by_subgraph
338 .entry(subgraph_name.to_string())
339 .or_default()
340 .insert(endpoint_str.to_string(), executor_arc.clone());
341
342 Ok(executor_arc)
343 }
344
345 fn resolve_subgraph_config<'a>(
348 &'a self,
349 subgraph_name: &'a str,
350 ) -> Result<ResolvedSubgraphConfig<'a>, SubgraphExecutorError> {
351 let mut config = ResolvedSubgraphConfig {
352 client: self.client.clone(),
353 timeout_config: &self.config.traffic_shaping.all.request_timeout,
354 dedupe_enabled: self.config.traffic_shaping.all.dedupe_enabled,
355 };
356
357 let Some(subgraph_config) = self.config.traffic_shaping.subgraphs.get(subgraph_name) else {
358 return Ok(config);
359 };
360
361 if let Some(pool_idle_timeout) = subgraph_config.pool_idle_timeout {
363 if pool_idle_timeout != self.config.traffic_shaping.all.pool_idle_timeout {
365 config.client = Arc::new(
366 Client::builder(TokioExecutor::new())
367 .pool_timer(TokioTimer::new())
368 .pool_idle_timeout(pool_idle_timeout)
369 .pool_max_idle_per_host(self.max_connections_per_host)
370 .build(build_https_executor()?),
371 );
372 }
373 }
374
375 if let Some(dedupe_enabled) = subgraph_config.dedupe_enabled {
377 config.dedupe_enabled = dedupe_enabled;
378 }
379
380 if let Some(custom_timeout) = &subgraph_config.request_timeout {
381 config.timeout_config = custom_timeout;
382 }
383
384 Ok(config)
385 }
386
387 fn register_subgraph_timeout(&self, subgraph_name: &str) -> Result<(), SubgraphExecutorError> {
391 if self.timeouts_by_subgraph.contains_key(subgraph_name) {
393 return Ok(());
394 }
395
396 let timeout_config = self
398 .config
399 .traffic_shaping
400 .subgraphs
401 .get(subgraph_name)
402 .and_then(|s| s.request_timeout.as_ref())
403 .unwrap_or(&self.config.traffic_shaping.all.request_timeout);
404
405 let timeout_prog = DurationOrProgram::compile(timeout_config, None).map_err(|err| {
407 SubgraphExecutorError::RequestTimeoutExpressionBuild(
408 subgraph_name.to_string(),
409 err.diagnostics,
410 )
411 })?;
412
413 self.timeouts_by_subgraph
415 .insert(subgraph_name.to_string(), timeout_prog);
416
417 Ok(())
418 }
419}
420
421fn resolve_timeout(
424 duration_or_program: &DurationOrProgram,
425 client_request: &ClientRequestDetails<'_>,
426 default_timeout: Option<Duration>,
427 timeout_name: &str,
428) -> Result<Duration, SubgraphExecutorError> {
429 duration_or_program
430 .resolve(|| {
431 let mut context_map = BTreeMap::new();
432 context_map.insert("request".into(), client_request.into());
433
434 if let Some(default) = default_timeout {
435 context_map.insert(
436 "default".into(),
437 VrlValue::Integer(default.as_millis() as i64),
438 );
439 }
440
441 VrlValue::Object(context_map)
442 })
443 .map_err(|err| {
444 SubgraphExecutorError::TimeoutExpressionResolution(
445 timeout_name.to_string(),
446 err.to_string(),
447 )
448 })
449}