1use std::{
2 collections::{BTreeMap, HashMap},
3 sync::Arc,
4 time::Duration,
5};
6
7use dashmap::DashMap;
8use hive_router_config::{
9 override_subgraph_urls::UrlOrExpression, traffic_shaping::DurationOrExpression,
10 HiveRouterConfig,
11};
12use hive_router_internal::expressions::vrl::compiler::Program as VrlProgram;
13use hive_router_internal::expressions::vrl::core::Value as VrlValue;
14use hive_router_internal::expressions::{CompileExpression, DurationOrProgram, ExecutableProgram};
15use http::Uri;
16use hyper_tls::HttpsConnector;
17use hyper_util::{
18 client::legacy::Client,
19 rt::{TokioExecutor, TokioTimer},
20};
21use tokio::sync::{OnceCell, Semaphore};
22use tracing::error;
23
24use crate::{
25 execution::client_request_details::ClientRequestDetails,
26 executors::{
27 common::{SubgraphExecutionRequest, SubgraphExecutor, SubgraphExecutorBoxedArc},
28 dedupe::ABuildHasher,
29 error::SubgraphExecutorError,
30 http::{HTTPSubgraphExecutor, HttpClient, HttpResponse},
31 },
32 response::{graphql_error::GraphQLError, subgraph_response::SubgraphResponse},
33};
34
35type SubgraphName = String;
36type SubgraphEndpoint = String;
37type ExecutorsBySubgraphMap =
38 DashMap<SubgraphName, DashMap<SubgraphEndpoint, SubgraphExecutorBoxedArc>>;
39type StaticEndpointsBySubgraphMap = DashMap<SubgraphName, SubgraphEndpoint>;
40type ExpressionEndpointsBySubgraphMap = HashMap<SubgraphName, VrlProgram>;
41type TimeoutsBySubgraph = DashMap<SubgraphName, DurationOrProgram>;
42
43struct ResolvedSubgraphConfig<'a> {
44 client: Arc<HttpClient>,
45 timeout_config: &'a DurationOrExpression,
46 dedupe_enabled: bool,
47}
48
49pub type InflightRequestsMap = Arc<DashMap<u64, Arc<OnceCell<HttpResponse>>, ABuildHasher>>;
50
51pub struct SubgraphExecutorMap {
52 executors_by_subgraph: ExecutorsBySubgraphMap,
53 static_endpoints_by_subgraph: StaticEndpointsBySubgraphMap,
56 expression_endpoints_by_subgraph: ExpressionEndpointsBySubgraphMap,
59 timeouts_by_subgraph: TimeoutsBySubgraph,
60 global_timeout: DurationOrProgram,
61 config: Arc<HiveRouterConfig>,
62 client: Arc<HttpClient>,
63 semaphores_by_origin: DashMap<String, Arc<Semaphore>>,
64 max_connections_per_host: usize,
65 in_flight_requests: InflightRequestsMap,
66}
67
68impl SubgraphExecutorMap {
69 pub fn new(config: Arc<HiveRouterConfig>, global_timeout: DurationOrProgram) -> Self {
70 let https = HttpsConnector::new();
71 let client: HttpClient = Client::builder(TokioExecutor::new())
72 .pool_timer(TokioTimer::new())
73 .pool_idle_timeout(config.traffic_shaping.all.pool_idle_timeout)
74 .pool_max_idle_per_host(config.traffic_shaping.max_connections_per_host)
75 .build(https);
76
77 let max_connections_per_host = config.traffic_shaping.max_connections_per_host;
78
79 SubgraphExecutorMap {
80 executors_by_subgraph: Default::default(),
81 static_endpoints_by_subgraph: Default::default(),
82 expression_endpoints_by_subgraph: Default::default(),
83 config,
84 client: Arc::new(client),
85 semaphores_by_origin: Default::default(),
86 max_connections_per_host,
87 in_flight_requests: Arc::new(DashMap::with_hasher(ABuildHasher::default())),
88 timeouts_by_subgraph: Default::default(),
89 global_timeout,
90 }
91 }
92
93 pub fn from_http_endpoint_map(
94 subgraph_endpoint_map: &HashMap<SubgraphName, String>,
95 config: Arc<HiveRouterConfig>,
96 ) -> Result<Self, SubgraphExecutorError> {
97 let global_timeout = DurationOrProgram::compile(
98 &config.traffic_shaping.all.request_timeout,
99 None,
100 )
101 .map_err(|err| {
102 SubgraphExecutorError::RequestTimeoutExpressionBuild("all".to_string(), err.diagnostics)
103 })?;
104 let mut subgraph_executor_map = SubgraphExecutorMap::new(config.clone(), global_timeout);
105
106 for (subgraph_name, original_endpoint_str) in subgraph_endpoint_map.iter() {
107 let endpoint_config = config
108 .override_subgraph_urls
109 .get_subgraph_url(subgraph_name);
110
111 let endpoint_str = match endpoint_config {
112 Some(UrlOrExpression::Url(url)) => url.clone(),
113 Some(UrlOrExpression::Expression { expression }) => {
114 subgraph_executor_map
115 .register_endpoint_expression(subgraph_name, expression)?;
116 original_endpoint_str.clone()
117 }
118 None => original_endpoint_str.clone(),
119 };
120
121 subgraph_executor_map.register_static_endpoint(subgraph_name, &endpoint_str);
122 subgraph_executor_map.register_executor(subgraph_name, &endpoint_str)?;
123 subgraph_executor_map.register_subgraph_timeout(subgraph_name)?;
124 }
125
126 Ok(subgraph_executor_map)
127 }
128
129 pub async fn execute<'exec>(
130 &self,
131 subgraph_name: &'exec str,
132 execution_request: SubgraphExecutionRequest<'exec>,
133 client_request: &ClientRequestDetails<'exec>,
134 ) -> SubgraphResponse<'exec> {
135 match self.get_or_create_executor(subgraph_name, client_request) {
136 Ok(executor) => {
137 let timeout = self
138 .timeouts_by_subgraph
139 .get(subgraph_name)
140 .map(|t| {
141 let global_timeout_duration =
142 resolve_timeout(&self.global_timeout, client_request, None, "all")?;
143 resolve_timeout(
144 t.value(),
145 client_request,
146 Some(global_timeout_duration),
147 subgraph_name,
148 )
149 })
150 .transpose();
151
152 match timeout {
153 Ok(timeout) => executor.execute(execution_request, timeout).await,
154 Err(err) => {
155 error!(
156 "Failed to resolve timeout for subgraph '{}': {}",
157 subgraph_name, err,
158 );
159 self.internal_server_error_response(err.into(), subgraph_name)
160 }
161 }
162 }
163 Err(err) => {
164 error!(
165 "Subgraph executor error for subgraph '{}': {}",
166 subgraph_name, err,
167 );
168 self.internal_server_error_response(err.into(), subgraph_name)
169 }
170 }
171 }
172
173 fn internal_server_error_response<'exec>(
174 &self,
175 graphql_error: GraphQLError,
176 subgraph_name: &'exec str,
177 ) -> SubgraphResponse<'exec> {
178 let error_with_subgraph_name = graphql_error.add_subgraph_name(subgraph_name);
179 SubgraphResponse {
180 errors: Some(vec![error_with_subgraph_name]),
181 ..Default::default()
182 }
183 }
184
185 fn get_or_create_executor(
189 &self,
190 subgraph_name: &str,
191 client_request: &ClientRequestDetails<'_>,
192 ) -> Result<SubgraphExecutorBoxedArc, SubgraphExecutorError> {
193 self.expression_endpoints_by_subgraph
194 .get(subgraph_name)
195 .map(|expression| {
196 self.get_or_create_executor_from_expression(
197 subgraph_name,
198 expression,
199 client_request,
200 )
201 })
202 .unwrap_or_else(|| {
203 self.get_executor_from_static_endpoint(subgraph_name)
204 .ok_or_else(|| {
205 SubgraphExecutorError::StaticEndpointNotFound(subgraph_name.to_string())
206 })
207 })
208 }
209
210 fn get_or_create_executor_from_expression(
215 &self,
216 subgraph_name: &str,
217 expression: &VrlProgram,
218 client_request: &ClientRequestDetails<'_>,
219 ) -> Result<SubgraphExecutorBoxedArc, SubgraphExecutorError> {
220 let original_url_value = VrlValue::Bytes(
221 self.static_endpoints_by_subgraph
222 .get(subgraph_name)
223 .map(|endpoint| endpoint.value().clone())
224 .ok_or_else(|| {
225 SubgraphExecutorError::StaticEndpointNotFound(subgraph_name.to_string())
226 })?
227 .into(),
228 );
229
230 let value = VrlValue::Object(BTreeMap::from([
231 ("request".into(), client_request.into()),
232 ("default".into(), original_url_value),
233 ]));
234
235 let endpoint_result = expression.execute(value).map_err(|err| {
237 SubgraphExecutorError::EndpointExpressionResolutionFailure(
238 subgraph_name.to_string(),
239 err.to_string(),
240 )
241 })?;
242
243 let endpoint_str = match endpoint_result.as_str() {
244 Some(s) => Ok(s.to_string()),
245 None => Err(SubgraphExecutorError::EndpointExpressionWrongType(
246 subgraph_name.to_string(),
247 )),
248 }?;
249
250 if let Some(executor) = self.get_executor_from_endpoint(subgraph_name, &endpoint_str) {
252 return Ok(executor);
253 }
254
255 self.register_executor(subgraph_name, &endpoint_str)
257 }
258
259 fn get_executor_from_static_endpoint(
261 &self,
262 subgraph_name: &str,
263 ) -> Option<SubgraphExecutorBoxedArc> {
264 let endpoint_ref = self.static_endpoints_by_subgraph.get(subgraph_name)?;
265 let endpoint_str = endpoint_ref.value();
266 self.get_executor_from_endpoint(subgraph_name, endpoint_str)
267 }
268
269 #[inline]
271 fn get_executor_from_endpoint(
272 &self,
273 subgraph_name: &str,
274 endpoint_str: &str,
275 ) -> Option<SubgraphExecutorBoxedArc> {
276 self.executors_by_subgraph
277 .get(subgraph_name)
278 .and_then(|endpoints| endpoints.get(endpoint_str).map(|e| e.clone()))
279 }
280
281 fn register_endpoint_expression(
284 &mut self,
285 subgraph_name: &str,
286 expression: &str,
287 ) -> Result<(), SubgraphExecutorError> {
288 let program = expression.compile_expression(None).map_err(|err| {
289 SubgraphExecutorError::EndpointExpressionBuild(
290 subgraph_name.to_string(),
291 err.diagnostics,
292 )
293 })?;
294 self.expression_endpoints_by_subgraph
295 .insert(subgraph_name.to_string(), program);
296
297 Ok(())
298 }
299
300 fn register_static_endpoint(&self, subgraph_name: &str, endpoint_str: &str) {
304 self.static_endpoints_by_subgraph
305 .insert(subgraph_name.to_string(), endpoint_str.to_string());
306 }
307
308 fn register_executor(
311 &self,
312 subgraph_name: &str,
313 endpoint_str: &str,
314 ) -> Result<SubgraphExecutorBoxedArc, SubgraphExecutorError> {
315 let endpoint_uri = endpoint_str.parse::<Uri>().map_err(|e| {
316 SubgraphExecutorError::EndpointParseFailure(endpoint_str.to_string(), e.to_string())
317 })?;
318
319 let origin = format!(
320 "{}://{}:{}",
321 endpoint_uri.scheme_str().unwrap_or("http"),
322 endpoint_uri.host().unwrap_or(""),
323 endpoint_uri.port_u16().unwrap_or_else(|| {
324 if endpoint_uri.scheme_str() == Some("https") {
325 443
326 } else {
327 80
328 }
329 })
330 );
331
332 let semaphore = self
333 .semaphores_by_origin
334 .entry(origin)
335 .or_insert_with(|| Arc::new(Semaphore::new(self.max_connections_per_host)))
336 .clone();
337
338 let subgraph_config = self.resolve_subgraph_config(subgraph_name);
339
340 let executor = HTTPSubgraphExecutor::new(
341 subgraph_name.to_string(),
342 endpoint_uri,
343 subgraph_config.client,
344 semaphore,
345 subgraph_config.dedupe_enabled,
346 self.in_flight_requests.clone(),
347 );
348
349 let executor_arc = executor.to_boxed_arc();
350
351 self.executors_by_subgraph
352 .entry(subgraph_name.to_string())
353 .or_default()
354 .insert(endpoint_str.to_string(), executor_arc.clone());
355
356 Ok(executor_arc)
357 }
358
359 fn resolve_subgraph_config<'a>(&'a self, subgraph_name: &'a str) -> ResolvedSubgraphConfig<'a> {
362 let mut config = ResolvedSubgraphConfig {
363 client: self.client.clone(),
364 timeout_config: &self.config.traffic_shaping.all.request_timeout,
365 dedupe_enabled: self.config.traffic_shaping.all.dedupe_enabled,
366 };
367
368 let Some(subgraph_config) = self.config.traffic_shaping.subgraphs.get(subgraph_name) else {
369 return config;
370 };
371
372 if let Some(pool_idle_timeout) = subgraph_config.pool_idle_timeout {
374 if pool_idle_timeout != self.config.traffic_shaping.all.pool_idle_timeout {
376 config.client = Arc::new(
377 Client::builder(TokioExecutor::new())
378 .pool_timer(TokioTimer::new())
379 .pool_idle_timeout(pool_idle_timeout)
380 .pool_max_idle_per_host(self.max_connections_per_host)
381 .build(HttpsConnector::new()),
382 );
383 }
384 }
385
386 if let Some(dedupe_enabled) = subgraph_config.dedupe_enabled {
388 config.dedupe_enabled = dedupe_enabled;
389 }
390
391 if let Some(custom_timeout) = &subgraph_config.request_timeout {
392 config.timeout_config = custom_timeout;
393 }
394
395 config
396 }
397
398 fn register_subgraph_timeout(&self, subgraph_name: &str) -> Result<(), SubgraphExecutorError> {
402 if self.timeouts_by_subgraph.contains_key(subgraph_name) {
404 return Ok(());
405 }
406
407 let timeout_config = self
409 .config
410 .traffic_shaping
411 .subgraphs
412 .get(subgraph_name)
413 .and_then(|s| s.request_timeout.as_ref())
414 .unwrap_or(&self.config.traffic_shaping.all.request_timeout);
415
416 let timeout_prog = DurationOrProgram::compile(timeout_config, None).map_err(|err| {
418 SubgraphExecutorError::RequestTimeoutExpressionBuild(
419 subgraph_name.to_string(),
420 err.diagnostics,
421 )
422 })?;
423
424 self.timeouts_by_subgraph
426 .insert(subgraph_name.to_string(), timeout_prog);
427
428 Ok(())
429 }
430}
431
432fn resolve_timeout(
435 duration_or_program: &DurationOrProgram,
436 client_request: &ClientRequestDetails<'_>,
437 default_timeout: Option<Duration>,
438 timeout_name: &str,
439) -> Result<Duration, SubgraphExecutorError> {
440 duration_or_program
441 .resolve(|| {
442 let mut context_map = BTreeMap::new();
443 context_map.insert("request".into(), client_request.into());
444
445 if let Some(default) = default_timeout {
446 context_map.insert(
447 "default".into(),
448 VrlValue::Integer(default.as_millis() as i64),
449 );
450 }
451
452 VrlValue::Object(context_map)
453 })
454 .map_err(|err| {
455 SubgraphExecutorError::TimeoutExpressionResolution(
456 timeout_name.to_string(),
457 err.to_string(),
458 )
459 })
460}