1use std::{
2 collections::{BTreeMap, HashMap},
3 sync::Arc,
4 time::Duration,
5};
6
7use bytes::{BufMut, BytesMut};
8use dashmap::DashMap;
9use hive_router_config::{
10 override_subgraph_urls::UrlOrExpression, traffic_shaping::DurationOrExpression,
11 HiveRouterConfig,
12};
13use http::Uri;
14use hyper_tls::HttpsConnector;
15use hyper_util::{
16 client::legacy::Client,
17 rt::{TokioExecutor, TokioTimer},
18};
19use tokio::sync::{OnceCell, Semaphore};
20use tracing::error;
21use vrl::compiler::Program as VrlProgram;
22use vrl::core::Value as VrlValue;
23
24use crate::{
25 execution::client_request_details::ClientRequestDetails,
26 executors::{
27 common::{
28 HttpExecutionResponse, SubgraphExecutionRequest, SubgraphExecutor,
29 SubgraphExecutorBoxedArc,
30 },
31 dedupe::{ABuildHasher, SharedResponse},
32 error::SubgraphExecutorError,
33 http::{HTTPSubgraphExecutor, HttpClient},
34 },
35 expressions::{CompileExpression, DurationOrProgram, ExecutableProgram},
36 response::graphql_error::GraphQLError,
37};
38
39type SubgraphName = String;
40type SubgraphEndpoint = String;
41type ExecutorsBySubgraphMap =
42 DashMap<SubgraphName, DashMap<SubgraphEndpoint, SubgraphExecutorBoxedArc>>;
43type StaticEndpointsBySubgraphMap = DashMap<SubgraphName, SubgraphEndpoint>;
44type ExpressionEndpointsBySubgraphMap = HashMap<SubgraphName, VrlProgram>;
45type TimeoutsBySubgraph = DashMap<SubgraphName, DurationOrProgram>;
46
47struct ResolvedSubgraphConfig<'a> {
48 client: Arc<HttpClient>,
49 timeout_config: &'a DurationOrExpression,
50 dedupe_enabled: bool,
51}
52
53pub struct SubgraphExecutorMap {
54 executors_by_subgraph: ExecutorsBySubgraphMap,
55 static_endpoints_by_subgraph: StaticEndpointsBySubgraphMap,
58 expression_endpoints_by_subgraph: ExpressionEndpointsBySubgraphMap,
61 timeouts_by_subgraph: TimeoutsBySubgraph,
62 global_timeout: DurationOrProgram,
63 config: Arc<HiveRouterConfig>,
64 client: Arc<HttpClient>,
65 semaphores_by_origin: DashMap<String, Arc<Semaphore>>,
66 max_connections_per_host: usize,
67 in_flight_requests: Arc<DashMap<u64, Arc<OnceCell<SharedResponse>>, ABuildHasher>>,
68}
69
70impl SubgraphExecutorMap {
71 pub fn new(config: Arc<HiveRouterConfig>, global_timeout: DurationOrProgram) -> Self {
72 let https = HttpsConnector::new();
73 let client: HttpClient = Client::builder(TokioExecutor::new())
74 .pool_timer(TokioTimer::new())
75 .pool_idle_timeout(config.traffic_shaping.all.pool_idle_timeout)
76 .pool_max_idle_per_host(config.traffic_shaping.max_connections_per_host)
77 .build(https);
78
79 let max_connections_per_host = config.traffic_shaping.max_connections_per_host;
80
81 SubgraphExecutorMap {
82 executors_by_subgraph: Default::default(),
83 static_endpoints_by_subgraph: Default::default(),
84 expression_endpoints_by_subgraph: Default::default(),
85 config,
86 client: Arc::new(client),
87 semaphores_by_origin: Default::default(),
88 max_connections_per_host,
89 in_flight_requests: Arc::new(DashMap::with_hasher(ABuildHasher::default())),
90 timeouts_by_subgraph: Default::default(),
91 global_timeout,
92 }
93 }
94
95 pub fn from_http_endpoint_map(
96 subgraph_endpoint_map: HashMap<SubgraphName, String>,
97 config: Arc<HiveRouterConfig>,
98 ) -> Result<Self, SubgraphExecutorError> {
99 let global_timeout = DurationOrProgram::compile(
100 &config.traffic_shaping.all.request_timeout,
101 None,
102 )
103 .map_err(|err| {
104 SubgraphExecutorError::RequestTimeoutExpressionBuild("all".to_string(), err.diagnostics)
105 })?;
106 let mut subgraph_executor_map = SubgraphExecutorMap::new(config.clone(), global_timeout);
107
108 for (subgraph_name, original_endpoint_str) in subgraph_endpoint_map.into_iter() {
109 let endpoint_config = config
110 .override_subgraph_urls
111 .get_subgraph_url(&subgraph_name);
112
113 let endpoint_str = match endpoint_config {
114 Some(UrlOrExpression::Url(url)) => url.clone(),
115 Some(UrlOrExpression::Expression { expression }) => {
116 subgraph_executor_map
117 .register_endpoint_expression(&subgraph_name, expression)?;
118 original_endpoint_str.clone()
119 }
120 None => original_endpoint_str.clone(),
121 };
122
123 subgraph_executor_map.register_static_endpoint(&subgraph_name, &endpoint_str);
124 subgraph_executor_map.register_executor(&subgraph_name, &endpoint_str)?;
125 subgraph_executor_map.register_subgraph_timeout(&subgraph_name)?;
126 }
127
128 Ok(subgraph_executor_map)
129 }
130
131 pub async fn execute<'a, 'req>(
132 &self,
133 subgraph_name: &str,
134 execution_request: SubgraphExecutionRequest<'a>,
135 client_request: &ClientRequestDetails<'a, 'req>,
136 ) -> HttpExecutionResponse {
137 match self.get_or_create_executor(subgraph_name, client_request) {
138 Ok(executor) => {
139 let timeout = self
140 .timeouts_by_subgraph
141 .get(subgraph_name)
142 .map(|t| {
143 let global_timeout_duration =
144 resolve_timeout(&self.global_timeout, client_request, None, "all")?;
145 resolve_timeout(
146 t.value(),
147 client_request,
148 Some(global_timeout_duration),
149 subgraph_name,
150 )
151 })
152 .transpose();
153
154 match timeout {
155 Ok(timeout) => executor.execute(execution_request, timeout).await,
156 Err(err) => {
157 error!(
158 "Failed to resolve timeout for subgraph '{}': {}",
159 subgraph_name, err,
160 );
161 self.internal_server_error_response(err.into(), subgraph_name)
162 }
163 }
164 }
165 Err(err) => {
166 error!(
167 "Subgraph executor error for subgraph '{}': {}",
168 subgraph_name, err,
169 );
170 self.internal_server_error_response(err.into(), subgraph_name)
171 }
172 }
173 }
174
175 fn internal_server_error_response(
176 &self,
177 graphql_error: GraphQLError,
178 subgraph_name: &str,
179 ) -> HttpExecutionResponse {
180 let errors = vec![graphql_error.add_subgraph_name(subgraph_name)];
181 let errors_bytes = sonic_rs::to_vec(&errors).unwrap();
182 let mut buffer = BytesMut::new();
183 buffer.put_slice(b"{\"errors\":");
184 buffer.put_slice(&errors_bytes);
185 buffer.put_slice(b"}");
186
187 HttpExecutionResponse {
188 body: buffer.freeze(),
189 headers: Default::default(),
190 }
191 }
192
193 fn get_or_create_executor(
197 &self,
198 subgraph_name: &str,
199 client_request: &ClientRequestDetails<'_, '_>,
200 ) -> Result<SubgraphExecutorBoxedArc, SubgraphExecutorError> {
201 self.expression_endpoints_by_subgraph
202 .get(subgraph_name)
203 .map(|expression| {
204 self.get_or_create_executor_from_expression(
205 subgraph_name,
206 expression,
207 client_request,
208 )
209 })
210 .unwrap_or_else(|| {
211 self.get_executor_from_static_endpoint(subgraph_name)
212 .ok_or_else(|| {
213 SubgraphExecutorError::StaticEndpointNotFound(subgraph_name.to_string())
214 })
215 })
216 }
217
218 fn get_or_create_executor_from_expression(
223 &self,
224 subgraph_name: &str,
225 expression: &VrlProgram,
226 client_request: &ClientRequestDetails<'_, '_>,
227 ) -> Result<SubgraphExecutorBoxedArc, SubgraphExecutorError> {
228 let original_url_value = VrlValue::Bytes(
229 self.static_endpoints_by_subgraph
230 .get(subgraph_name)
231 .map(|endpoint| endpoint.value().clone())
232 .ok_or_else(|| {
233 SubgraphExecutorError::StaticEndpointNotFound(subgraph_name.to_string())
234 })?
235 .into(),
236 );
237
238 let value = VrlValue::Object(BTreeMap::from([
239 ("request".into(), client_request.into()),
240 ("default".into(), original_url_value),
241 ]));
242
243 let endpoint_result = expression.execute(value).map_err(|err| {
245 SubgraphExecutorError::EndpointExpressionResolutionFailure(
246 subgraph_name.to_string(),
247 err.to_string(),
248 )
249 })?;
250
251 let endpoint_str = match endpoint_result.as_str() {
252 Some(s) => Ok(s.to_string()),
253 None => Err(SubgraphExecutorError::EndpointExpressionWrongType(
254 subgraph_name.to_string(),
255 )),
256 }?;
257
258 if let Some(executor) = self.get_executor_from_endpoint(subgraph_name, &endpoint_str) {
260 return Ok(executor);
261 }
262
263 self.register_executor(subgraph_name, &endpoint_str)
265 }
266
267 fn get_executor_from_static_endpoint(
269 &self,
270 subgraph_name: &str,
271 ) -> Option<SubgraphExecutorBoxedArc> {
272 let endpoint_ref = self.static_endpoints_by_subgraph.get(subgraph_name)?;
273 let endpoint_str = endpoint_ref.value();
274 self.get_executor_from_endpoint(subgraph_name, endpoint_str)
275 }
276
277 #[inline]
279 fn get_executor_from_endpoint(
280 &self,
281 subgraph_name: &str,
282 endpoint_str: &str,
283 ) -> Option<SubgraphExecutorBoxedArc> {
284 self.executors_by_subgraph
285 .get(subgraph_name)
286 .and_then(|endpoints| endpoints.get(endpoint_str).map(|e| e.clone()))
287 }
288
289 fn register_endpoint_expression(
292 &mut self,
293 subgraph_name: &str,
294 expression: &str,
295 ) -> Result<(), SubgraphExecutorError> {
296 let program = expression.compile_expression(None).map_err(|err| {
297 SubgraphExecutorError::EndpointExpressionBuild(
298 subgraph_name.to_string(),
299 err.diagnostics,
300 )
301 })?;
302 self.expression_endpoints_by_subgraph
303 .insert(subgraph_name.to_string(), program);
304
305 Ok(())
306 }
307
308 fn register_static_endpoint(&self, subgraph_name: &str, endpoint_str: &str) {
312 self.static_endpoints_by_subgraph
313 .insert(subgraph_name.to_string(), endpoint_str.to_string());
314 }
315
316 fn register_executor(
319 &self,
320 subgraph_name: &str,
321 endpoint_str: &str,
322 ) -> Result<SubgraphExecutorBoxedArc, SubgraphExecutorError> {
323 let endpoint_uri = endpoint_str.parse::<Uri>().map_err(|e| {
324 SubgraphExecutorError::EndpointParseFailure(endpoint_str.to_string(), e.to_string())
325 })?;
326
327 let origin = format!(
328 "{}://{}:{}",
329 endpoint_uri.scheme_str().unwrap_or("http"),
330 endpoint_uri.host().unwrap_or(""),
331 endpoint_uri.port_u16().unwrap_or_else(|| {
332 if endpoint_uri.scheme_str() == Some("https") {
333 443
334 } else {
335 80
336 }
337 })
338 );
339
340 let semaphore = self
341 .semaphores_by_origin
342 .entry(origin)
343 .or_insert_with(|| Arc::new(Semaphore::new(self.max_connections_per_host)))
344 .clone();
345
346 let subgraph_config = self.resolve_subgraph_config(subgraph_name);
347
348 let executor = HTTPSubgraphExecutor::new(
349 subgraph_name.to_string(),
350 endpoint_uri,
351 subgraph_config.client,
352 semaphore,
353 subgraph_config.dedupe_enabled,
354 self.in_flight_requests.clone(),
355 );
356
357 let executor_arc = executor.to_boxed_arc();
358
359 self.executors_by_subgraph
360 .entry(subgraph_name.to_string())
361 .or_default()
362 .insert(endpoint_str.to_string(), executor_arc.clone());
363
364 Ok(executor_arc)
365 }
366
367 fn resolve_subgraph_config<'a>(&'a self, subgraph_name: &'a str) -> ResolvedSubgraphConfig<'a> {
370 let mut config = ResolvedSubgraphConfig {
371 client: self.client.clone(),
372 timeout_config: &self.config.traffic_shaping.all.request_timeout,
373 dedupe_enabled: self.config.traffic_shaping.all.dedupe_enabled,
374 };
375
376 let Some(subgraph_config) = self.config.traffic_shaping.subgraphs.get(subgraph_name) else {
377 return config;
378 };
379
380 if let Some(pool_idle_timeout) = subgraph_config.pool_idle_timeout {
382 if pool_idle_timeout != self.config.traffic_shaping.all.pool_idle_timeout {
384 config.client = Arc::new(
385 Client::builder(TokioExecutor::new())
386 .pool_timer(TokioTimer::new())
387 .pool_idle_timeout(pool_idle_timeout)
388 .pool_max_idle_per_host(self.max_connections_per_host)
389 .build(HttpsConnector::new()),
390 );
391 }
392 }
393
394 if let Some(dedupe_enabled) = subgraph_config.dedupe_enabled {
396 config.dedupe_enabled = dedupe_enabled;
397 }
398
399 if let Some(custom_timeout) = &subgraph_config.request_timeout {
400 config.timeout_config = custom_timeout;
401 }
402
403 config
404 }
405
406 fn register_subgraph_timeout(&self, subgraph_name: &str) -> Result<(), SubgraphExecutorError> {
410 if self.timeouts_by_subgraph.contains_key(subgraph_name) {
412 return Ok(());
413 }
414
415 let timeout_config = self
417 .config
418 .traffic_shaping
419 .subgraphs
420 .get(subgraph_name)
421 .and_then(|s| s.request_timeout.as_ref())
422 .unwrap_or(&self.config.traffic_shaping.all.request_timeout);
423
424 let timeout_prog = DurationOrProgram::compile(timeout_config, None).map_err(|err| {
426 SubgraphExecutorError::RequestTimeoutExpressionBuild(
427 subgraph_name.to_string(),
428 err.diagnostics,
429 )
430 })?;
431
432 self.timeouts_by_subgraph
434 .insert(subgraph_name.to_string(), timeout_prog);
435
436 Ok(())
437 }
438}
439
440fn resolve_timeout(
443 duration_or_program: &DurationOrProgram,
444 client_request: &ClientRequestDetails<'_, '_>,
445 default_timeout: Option<Duration>,
446 timeout_name: &str,
447) -> Result<Duration, SubgraphExecutorError> {
448 let mut context_map = BTreeMap::new();
449 context_map.insert("request".into(), client_request.into());
450
451 if let Some(default) = default_timeout {
452 context_map.insert(
453 "default".into(),
454 VrlValue::Integer(default.as_millis() as i64),
455 );
456 }
457
458 let context = VrlValue::Object(context_map);
459
460 duration_or_program.resolve(context).map_err(|err| {
461 SubgraphExecutorError::TimeoutExpressionResolution(
462 timeout_name.to_string(),
463 err.to_string(),
464 )
465 })
466}