1use std::{
2 collections::{BTreeMap, HashMap},
3 sync::Arc,
4 time::Duration,
5};
6
7use dashmap::DashMap;
8use hive_router_config::{
9 override_subgraph_urls::UrlOrExpression, traffic_shaping::DurationOrExpression,
10 HiveRouterConfig,
11};
12use hive_router_internal::expressions::vrl::core::Value as VrlValue;
13use hive_router_internal::expressions::{CompileExpression, DurationOrProgram, ExecutableProgram};
14use hive_router_internal::{
15 expressions::vrl::compiler::Program as VrlProgram, telemetry::TelemetryContext,
16};
17use http::Uri;
18use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
19use hyper_util::{
20 client::legacy::{connect::HttpConnector, Client},
21 rt::{TokioExecutor, TokioTimer},
22};
23use tokio::sync::{OnceCell, Semaphore};
24
25use crate::{
26 execution::client_request_details::ClientRequestDetails,
27 executors::{
28 common::{SubgraphExecutionRequest, SubgraphExecutor, SubgraphExecutorBoxedArc},
29 dedupe::ABuildHasher,
30 error::SubgraphExecutorError,
31 http::{HTTPSubgraphExecutor, HttpClient, SubgraphHttpResponse},
32 },
33 hooks::on_subgraph_execute::{
34 OnSubgraphExecuteEndHookPayload, OnSubgraphExecuteStartHookPayload,
35 },
36 plugin_context::PluginRequestState,
37 plugin_trait::{EndControlFlow, StartControlFlow},
38 response::subgraph_response::SubgraphResponse,
39};
40
/// Name of a subgraph; the key of the endpoint map passed to
/// `SubgraphExecutorMap::from_http_endpoint_map`.
type SubgraphName = String;
/// Endpoint URL of a subgraph, stored as the raw string used for executor lookups.
type SubgraphEndpoint = String;
/// Cached executors per subgraph, keyed by the endpoint string each one targets.
/// A subgraph can have several entries when its endpoint comes from an expression.
type ExecutorsBySubgraphMap =
    DashMap<SubgraphName, DashMap<SubgraphEndpoint, SubgraphExecutorBoxedArc>>;
/// Endpoint registered for each subgraph at construction time.
type StaticEndpointsBySubgraphMap = DashMap<SubgraphName, SubgraphEndpoint>;
/// Compiled VRL programs that compute a subgraph's endpoint per client request.
type ExpressionEndpointsBySubgraphMap = HashMap<SubgraphName, VrlProgram>;
/// Compiled request timeout (fixed duration or program) for each subgraph.
type TimeoutsBySubgraph = DashMap<SubgraphName, DurationOrProgram>;
48
/// Effective settings for one subgraph: the `traffic_shaping.all` values with
/// any `traffic_shaping.subgraphs.<name>` overrides applied.
struct ResolvedSubgraphConfig<'a> {
    /// HTTP client to use. It is the shared client unless the subgraph overrides
    /// `pool_idle_timeout` with a different value, in which case a dedicated
    /// client is built.
    client: Arc<HttpClient>,
    /// Request timeout configuration (subgraph override, else the global one).
    // NOTE(review): not read anywhere in the visible code; per-subgraph timeouts
    // are compiled separately by `register_subgraph_timeout`.
    timeout_config: &'a DurationOrExpression,
    /// Whether request deduplication is enabled; forwarded to
    /// `HTTPSubgraphExecutor::new`.
    dedupe_enabled: bool,
}
54
/// Shared map of in-flight subgraph requests, handed to every executor.
///
/// Keys are `u64` values (presumably a request fingerprint — confirm in the
/// dedupe module). Each value is a `OnceCell` holding the HTTP response paired
/// with another `u64`. The map uses the `ABuildHasher` hasher.
pub type InflightRequestsMap =
    Arc<DashMap<u64, Arc<OnceCell<(SubgraphHttpResponse, u64)>>, ABuildHasher>>;
57
/// Registry of subgraph executors, their endpoints, and per-subgraph timeouts.
///
/// Executors are cached per `(subgraph, endpoint)`. Subgraphs whose URL is
/// computed by an expression get a new executor the first time a given
/// endpoint string is produced.
pub struct SubgraphExecutorMap {
    /// Cached executors, keyed by subgraph name and then by endpoint string.
    executors_by_subgraph: ExecutorsBySubgraphMap,
    /// Endpoint registered per subgraph at construction time. It is used for
    /// the non-expression executor lookup and passed as `default` to endpoint
    /// expressions.
    static_endpoints_by_subgraph: StaticEndpointsBySubgraphMap,
    /// Compiled endpoint programs for subgraphs whose URL override is an expression.
    expression_endpoints_by_subgraph: ExpressionEndpointsBySubgraphMap,
    /// Compiled request timeout per subgraph (override, else the global config).
    timeouts_by_subgraph: TimeoutsBySubgraph,
    /// Compiled global request timeout. Its resolved value is exposed to
    /// subgraph timeout programs as `default`.
    global_timeout: DurationOrProgram,
    /// Router configuration.
    config: Arc<HiveRouterConfig>,
    /// HTTP client shared by subgraphs that do not override `pool_idle_timeout`
    /// with a different value.
    client: Arc<HttpClient>,
    /// One semaphore per origin (`scheme://host:port`), shared by all executors
    /// that target that origin.
    semaphores_by_origin: DashMap<String, Arc<Semaphore>>,
    /// Permit count of each origin semaphore; also the clients' idle-pool size per host.
    max_connections_per_host: usize,
    /// In-flight request map shared with every executor.
    in_flight_requests: InflightRequestsMap,
    /// Telemetry context handed to every executor.
    telemetry_context: Arc<TelemetryContext>,
}
75
76fn build_https_executor() -> Result<HttpsConnector<HttpConnector>, SubgraphExecutorError> {
77 HttpsConnectorBuilder::new()
78 .with_native_roots()
79 .map_err(SubgraphExecutorError::NativeTlsCertificatesError)
80 .map(|b| b.https_or_http().enable_http1().enable_http2().build())
81}
82
83impl SubgraphExecutorMap {
84 pub fn new(
85 config: Arc<HiveRouterConfig>,
86 global_timeout: DurationOrProgram,
87 telemetry_context: Arc<TelemetryContext>,
88 ) -> Result<Self, SubgraphExecutorError> {
89 let client: HttpClient = Client::builder(TokioExecutor::new())
90 .pool_timer(TokioTimer::new())
91 .pool_idle_timeout(config.traffic_shaping.all.pool_idle_timeout)
92 .pool_max_idle_per_host(config.traffic_shaping.max_connections_per_host)
93 .build(build_https_executor()?);
94
95 let max_connections_per_host = config.traffic_shaping.max_connections_per_host;
96
97 Ok(SubgraphExecutorMap {
98 executors_by_subgraph: Default::default(),
99 static_endpoints_by_subgraph: Default::default(),
100 expression_endpoints_by_subgraph: Default::default(),
101 config,
102 client: Arc::new(client),
103 semaphores_by_origin: Default::default(),
104 max_connections_per_host,
105 in_flight_requests: Arc::new(DashMap::with_hasher(ABuildHasher::default())),
106 timeouts_by_subgraph: Default::default(),
107 global_timeout,
108 telemetry_context,
109 })
110 }
111
112 pub fn from_http_endpoint_map(
113 subgraph_endpoint_map: &HashMap<SubgraphName, String>,
114 config: Arc<HiveRouterConfig>,
115 telemetry_context: Arc<TelemetryContext>,
116 ) -> Result<Self, SubgraphExecutorError> {
117 let global_timeout = DurationOrProgram::compile(
118 &config.traffic_shaping.all.request_timeout,
119 None,
120 )
121 .map_err(|err| {
122 SubgraphExecutorError::RequestTimeoutExpressionBuild("all".to_string(), err.diagnostics)
123 })?;
124 let mut subgraph_executor_map =
125 SubgraphExecutorMap::new(config.clone(), global_timeout, telemetry_context)?;
126
127 for (subgraph_name, original_endpoint_str) in subgraph_endpoint_map.iter() {
128 let endpoint_config = config
129 .override_subgraph_urls
130 .get_subgraph_url(subgraph_name);
131
132 let endpoint_str = match endpoint_config {
133 Some(UrlOrExpression::Url(url)) => url.clone(),
134 Some(UrlOrExpression::Expression { expression }) => {
135 subgraph_executor_map
136 .register_endpoint_expression(subgraph_name, expression)?;
137 original_endpoint_str.clone()
138 }
139 None => original_endpoint_str.clone(),
140 };
141
142 subgraph_executor_map.register_static_endpoint(subgraph_name, &endpoint_str);
143 subgraph_executor_map.register_executor(subgraph_name, &endpoint_str)?;
144 subgraph_executor_map.register_subgraph_timeout(subgraph_name)?;
145 }
146
147 Ok(subgraph_executor_map)
148 }
149
150 pub async fn execute<'exec>(
151 &self,
152 subgraph_name: &'exec str,
153 mut execution_request: SubgraphExecutionRequest<'exec>,
154 client_request: &ClientRequestDetails<'exec>,
155 plugin_req_state: &'exec Option<PluginRequestState<'exec>>,
156 ) -> Result<SubgraphResponse<'exec>, SubgraphExecutorError> {
157 let mut executor = self.get_or_create_executor(subgraph_name, client_request)?;
158
159 let timeout = self
160 .timeouts_by_subgraph
161 .get(subgraph_name)
162 .map(|t| {
163 let global_timeout_duration =
164 resolve_timeout(&self.global_timeout, client_request, None)?;
165 resolve_timeout(t.value(), client_request, Some(global_timeout_duration))
166 })
167 .transpose()?;
168
169 let mut on_end_callbacks = vec![];
170
171 let mut execution_result: Option<SubgraphResponse<'exec>> = None;
172 if let Some(plugin_req_state) = plugin_req_state.as_ref() {
173 let mut start_payload = OnSubgraphExecuteStartHookPayload {
174 router_http_request: &plugin_req_state.router_http_request,
175 context: &plugin_req_state.context,
176 subgraph_name,
177 executor,
178 execution_request,
179 };
180 for plugin in plugin_req_state.plugins.as_ref() {
181 let result = plugin.on_subgraph_execute(start_payload).await;
182 start_payload = result.payload;
183 match result.control_flow {
184 StartControlFlow::Proceed => {
185 }
187 StartControlFlow::EndWithResponse(response) => {
188 execution_result = Some(response);
189 break;
190 }
191 StartControlFlow::OnEnd(callback) => {
192 on_end_callbacks.push(callback);
193 }
194 }
195 }
196 execution_request = start_payload.execution_request;
198 executor = start_payload.executor;
199 }
200
201 let mut execution_result = match execution_result {
202 Some(execution_result) => execution_result,
203 None => {
204 executor
205 .execute(execution_request, timeout, plugin_req_state)
206 .await?
207 }
208 };
209
210 if !on_end_callbacks.is_empty() {
211 if let Some(plugin_req_state) = plugin_req_state.as_ref() {
212 let mut end_payload = OnSubgraphExecuteEndHookPayload {
213 context: &plugin_req_state.context,
214 execution_result,
215 };
216
217 for callback in on_end_callbacks {
218 let result = callback(end_payload);
219 end_payload = result.payload;
220 match result.control_flow {
221 EndControlFlow::Proceed => {
222 }
224 EndControlFlow::EndWithResponse(response) => {
225 end_payload.execution_result = response;
226 }
227 }
228 }
229
230 execution_result = end_payload.execution_result;
232 }
233 }
234
235 Ok(execution_result)
236 }
237
238 fn get_or_create_executor(
242 &self,
243 subgraph_name: &str,
244 client_request: &ClientRequestDetails<'_>,
245 ) -> Result<SubgraphExecutorBoxedArc, SubgraphExecutorError> {
246 self.expression_endpoints_by_subgraph
247 .get(subgraph_name)
248 .map(|expression| {
249 self.get_or_create_executor_from_expression(
250 subgraph_name,
251 expression,
252 client_request,
253 )
254 })
255 .unwrap_or_else(|| {
256 self.get_executor_from_static_endpoint(subgraph_name)
257 .ok_or(SubgraphExecutorError::StaticEndpointNotFound)
258 })
259 }
260
261 fn get_or_create_executor_from_expression(
266 &self,
267 subgraph_name: &str,
268 expression: &VrlProgram,
269 client_request: &ClientRequestDetails<'_>,
270 ) -> Result<SubgraphExecutorBoxedArc, SubgraphExecutorError> {
271 let original_url_value = VrlValue::Bytes(
272 self.static_endpoints_by_subgraph
273 .get(subgraph_name)
274 .map(|endpoint| endpoint.value().clone())
275 .ok_or_else(|| SubgraphExecutorError::StaticEndpointNotFound)?
276 .into(),
277 );
278
279 let value = VrlValue::Object(BTreeMap::from([
280 ("request".into(), client_request.into()),
281 ("default".into(), original_url_value),
282 ]));
283
284 let endpoint_result = expression.execute(value).map_err(|err| {
286 SubgraphExecutorError::EndpointExpressionResolutionFailure(err.to_string())
287 })?;
288
289 let endpoint_str = match endpoint_result.as_str() {
290 Some(s) => Ok(s.to_string()),
291 None => Err(SubgraphExecutorError::EndpointExpressionWrongType),
292 }?;
293
294 if let Some(executor) = self.get_executor_from_endpoint(subgraph_name, &endpoint_str) {
296 return Ok(executor);
297 }
298
299 self.register_executor(subgraph_name, &endpoint_str)
301 }
302
303 fn get_executor_from_static_endpoint(
305 &self,
306 subgraph_name: &str,
307 ) -> Option<SubgraphExecutorBoxedArc> {
308 let endpoint_ref = self.static_endpoints_by_subgraph.get(subgraph_name)?;
309 let endpoint_str = endpoint_ref.value();
310 self.get_executor_from_endpoint(subgraph_name, endpoint_str)
311 }
312
313 #[inline]
315 fn get_executor_from_endpoint(
316 &self,
317 subgraph_name: &str,
318 endpoint_str: &str,
319 ) -> Option<SubgraphExecutorBoxedArc> {
320 self.executors_by_subgraph
321 .get(subgraph_name)
322 .and_then(|endpoints| endpoints.get(endpoint_str).map(|e| e.clone()))
323 }
324
325 fn register_endpoint_expression(
328 &mut self,
329 subgraph_name: &str,
330 expression: &str,
331 ) -> Result<(), SubgraphExecutorError> {
332 let program = expression.compile_expression(None).map_err(|err| {
333 SubgraphExecutorError::EndpointExpressionBuild(
334 subgraph_name.to_string(),
335 err.diagnostics,
336 )
337 })?;
338 self.expression_endpoints_by_subgraph
339 .insert(subgraph_name.to_string(), program);
340
341 Ok(())
342 }
343
344 fn register_static_endpoint(&self, subgraph_name: &str, endpoint_str: &str) {
348 self.static_endpoints_by_subgraph
349 .insert(subgraph_name.to_string(), endpoint_str.to_string());
350 }
351
352 fn register_executor(
355 &self,
356 subgraph_name: &str,
357 endpoint_str: &str,
358 ) -> Result<SubgraphExecutorBoxedArc, SubgraphExecutorError> {
359 let endpoint_uri = endpoint_str.parse::<Uri>().map_err(|e| {
360 SubgraphExecutorError::EndpointParseFailure(endpoint_str.to_string(), e)
361 })?;
362
363 let origin = format!(
364 "{}://{}:{}",
365 endpoint_uri.scheme_str().unwrap_or("http"),
366 endpoint_uri.host().unwrap_or(""),
367 endpoint_uri.port_u16().unwrap_or_else(|| {
368 if endpoint_uri.scheme_str() == Some("https") {
369 443
370 } else {
371 80
372 }
373 })
374 );
375
376 let semaphore = self
377 .semaphores_by_origin
378 .entry(origin)
379 .or_insert_with(|| Arc::new(Semaphore::new(self.max_connections_per_host)))
380 .clone();
381
382 let subgraph_config = self.resolve_subgraph_config(subgraph_name)?;
383
384 let executor = HTTPSubgraphExecutor::new(
385 subgraph_name.to_string(),
386 endpoint_uri,
387 subgraph_config.client,
388 semaphore,
389 subgraph_config.dedupe_enabled,
390 self.in_flight_requests.clone(),
391 self.telemetry_context.clone(),
392 self.config.clone(),
393 );
394
395 let executor_arc = executor.to_boxed_arc();
396
397 self.executors_by_subgraph
398 .entry(subgraph_name.to_string())
399 .or_default()
400 .insert(endpoint_str.to_string(), executor_arc.clone());
401
402 Ok(executor_arc)
403 }
404
405 fn resolve_subgraph_config<'a>(
408 &'a self,
409 subgraph_name: &'a str,
410 ) -> Result<ResolvedSubgraphConfig<'a>, SubgraphExecutorError> {
411 let mut config = ResolvedSubgraphConfig {
412 client: self.client.clone(),
413 timeout_config: &self.config.traffic_shaping.all.request_timeout,
414 dedupe_enabled: self.config.traffic_shaping.all.dedupe_enabled,
415 };
416
417 let Some(subgraph_config) = self.config.traffic_shaping.subgraphs.get(subgraph_name) else {
418 return Ok(config);
419 };
420
421 if let Some(pool_idle_timeout) = subgraph_config.pool_idle_timeout {
423 if pool_idle_timeout != self.config.traffic_shaping.all.pool_idle_timeout {
425 config.client = Arc::new(
426 Client::builder(TokioExecutor::new())
427 .pool_timer(TokioTimer::new())
428 .pool_idle_timeout(pool_idle_timeout)
429 .pool_max_idle_per_host(self.max_connections_per_host)
430 .build(build_https_executor()?),
431 );
432 }
433 }
434
435 if let Some(dedupe_enabled) = subgraph_config.dedupe_enabled {
437 config.dedupe_enabled = dedupe_enabled;
438 }
439
440 if let Some(custom_timeout) = &subgraph_config.request_timeout {
441 config.timeout_config = custom_timeout;
442 }
443
444 Ok(config)
445 }
446
447 fn register_subgraph_timeout(&self, subgraph_name: &str) -> Result<(), SubgraphExecutorError> {
451 if self.timeouts_by_subgraph.contains_key(subgraph_name) {
453 return Ok(());
454 }
455
456 let timeout_config = self
458 .config
459 .traffic_shaping
460 .subgraphs
461 .get(subgraph_name)
462 .and_then(|s| s.request_timeout.as_ref())
463 .unwrap_or(&self.config.traffic_shaping.all.request_timeout);
464
465 let timeout_prog = DurationOrProgram::compile(timeout_config, None).map_err(|err| {
467 SubgraphExecutorError::RequestTimeoutExpressionBuild(
468 subgraph_name.to_string(),
469 err.diagnostics,
470 )
471 })?;
472
473 self.timeouts_by_subgraph
475 .insert(subgraph_name.to_string(), timeout_prog);
476
477 Ok(())
478 }
479}
480
481fn resolve_timeout(
484 duration_or_program: &DurationOrProgram,
485 client_request: &ClientRequestDetails<'_>,
486 default_timeout: Option<Duration>,
487) -> Result<Duration, SubgraphExecutorError> {
488 duration_or_program
489 .resolve(|| {
490 let mut context_map = BTreeMap::new();
491 context_map.insert("request".into(), client_request.into());
492
493 if let Some(default) = default_timeout {
494 context_map.insert(
495 "default".into(),
496 VrlValue::Integer(default.as_millis() as i64),
497 );
498 }
499
500 VrlValue::Object(context_map)
501 })
502 .map_err(|err| SubgraphExecutorError::TimeoutExpressionResolution(err.to_string()))
503}