Skip to main content

hive_router_plan_executor/executors/
map.rs

1use std::{
2    collections::{BTreeMap, HashMap},
3    sync::Arc,
4    time::Duration,
5};
6
7use dashmap::DashMap;
8use hive_router_config::{
9    override_subgraph_urls::UrlOrExpression, traffic_shaping::DurationOrExpression,
10    HiveRouterConfig,
11};
12use hive_router_internal::expressions::vrl::core::Value as VrlValue;
13use hive_router_internal::expressions::{CompileExpression, DurationOrProgram, ExecutableProgram};
14use hive_router_internal::{
15    expressions::vrl::compiler::Program as VrlProgram, telemetry::TelemetryContext,
16};
17use http::Uri;
18use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
19use hyper_util::{
20    client::legacy::{connect::HttpConnector, Client},
21    rt::{TokioExecutor, TokioTimer},
22};
23use tokio::sync::{OnceCell, Semaphore};
24
25use crate::{
26    execution::client_request_details::ClientRequestDetails,
27    executors::{
28        common::{SubgraphExecutionRequest, SubgraphExecutor, SubgraphExecutorBoxedArc},
29        dedupe::ABuildHasher,
30        error::SubgraphExecutorError,
31        http::{HTTPSubgraphExecutor, HttpClient, HttpResponse},
32    },
33    response::subgraph_response::SubgraphResponse,
34};
35
36type SubgraphName = String;
37type SubgraphEndpoint = String;
38type ExecutorsBySubgraphMap =
39    DashMap<SubgraphName, DashMap<SubgraphEndpoint, SubgraphExecutorBoxedArc>>;
40type StaticEndpointsBySubgraphMap = DashMap<SubgraphName, SubgraphEndpoint>;
41type ExpressionEndpointsBySubgraphMap = HashMap<SubgraphName, VrlProgram>;
42type TimeoutsBySubgraph = DashMap<SubgraphName, DurationOrProgram>;
43
44struct ResolvedSubgraphConfig<'a> {
45    client: Arc<HttpClient>,
46    timeout_config: &'a DurationOrExpression,
47    dedupe_enabled: bool,
48}
49
50pub type InflightRequestsMap = Arc<DashMap<u64, Arc<OnceCell<(HttpResponse, u64)>>, ABuildHasher>>;
51
52pub struct SubgraphExecutorMap {
53    executors_by_subgraph: ExecutorsBySubgraphMap,
54    /// Mapping from subgraph name to static endpoint for quick lookup
55    /// based on subgraph SDL and static overrides from router's config.
56    static_endpoints_by_subgraph: StaticEndpointsBySubgraphMap,
57    /// Mapping from subgraph name to VRL expression program
58    /// Only contains subgraphs with expression-based endpoint overrides
59    expression_endpoints_by_subgraph: ExpressionEndpointsBySubgraphMap,
60    timeouts_by_subgraph: TimeoutsBySubgraph,
61    global_timeout: DurationOrProgram,
62    config: Arc<HiveRouterConfig>,
63    client: Arc<HttpClient>,
64    semaphores_by_origin: DashMap<String, Arc<Semaphore>>,
65    max_connections_per_host: usize,
66    in_flight_requests: InflightRequestsMap,
67    telemetry_context: Arc<TelemetryContext>,
68}
69
70fn build_https_executor() -> Result<HttpsConnector<HttpConnector>, SubgraphExecutorError> {
71    HttpsConnectorBuilder::new()
72        .with_native_roots()
73        .map_err(|e| SubgraphExecutorError::NativeTlsCertificatesError(e.to_string()))
74        .map(|b| b.https_or_http().enable_http1().enable_http2().build())
75}
76
77impl SubgraphExecutorMap {
78    pub fn new(
79        config: Arc<HiveRouterConfig>,
80        global_timeout: DurationOrProgram,
81        telemetry_context: Arc<TelemetryContext>,
82    ) -> Result<Self, SubgraphExecutorError> {
83        let client: HttpClient = Client::builder(TokioExecutor::new())
84            .pool_timer(TokioTimer::new())
85            .pool_idle_timeout(config.traffic_shaping.all.pool_idle_timeout)
86            .pool_max_idle_per_host(config.traffic_shaping.max_connections_per_host)
87            .build(build_https_executor()?);
88
89        let max_connections_per_host = config.traffic_shaping.max_connections_per_host;
90
91        Ok(SubgraphExecutorMap {
92            executors_by_subgraph: Default::default(),
93            static_endpoints_by_subgraph: Default::default(),
94            expression_endpoints_by_subgraph: Default::default(),
95            config,
96            client: Arc::new(client),
97            semaphores_by_origin: Default::default(),
98            max_connections_per_host,
99            in_flight_requests: Arc::new(DashMap::with_hasher(ABuildHasher::default())),
100            timeouts_by_subgraph: Default::default(),
101            global_timeout,
102            telemetry_context,
103        })
104    }
105
106    pub fn from_http_endpoint_map(
107        subgraph_endpoint_map: &HashMap<SubgraphName, String>,
108        config: Arc<HiveRouterConfig>,
109        telemetry_context: Arc<TelemetryContext>,
110    ) -> Result<Self, SubgraphExecutorError> {
111        let global_timeout = DurationOrProgram::compile(
112            &config.traffic_shaping.all.request_timeout,
113            None,
114        )
115        .map_err(|err| {
116            SubgraphExecutorError::RequestTimeoutExpressionBuild("all".to_string(), err.diagnostics)
117        })?;
118        let mut subgraph_executor_map =
119            SubgraphExecutorMap::new(config.clone(), global_timeout, telemetry_context)?;
120
121        for (subgraph_name, original_endpoint_str) in subgraph_endpoint_map.iter() {
122            let endpoint_config = config
123                .override_subgraph_urls
124                .get_subgraph_url(subgraph_name);
125
126            let endpoint_str = match endpoint_config {
127                Some(UrlOrExpression::Url(url)) => url.clone(),
128                Some(UrlOrExpression::Expression { expression }) => {
129                    subgraph_executor_map
130                        .register_endpoint_expression(subgraph_name, expression)?;
131                    original_endpoint_str.clone()
132                }
133                None => original_endpoint_str.clone(),
134            };
135
136            subgraph_executor_map.register_static_endpoint(subgraph_name, &endpoint_str);
137            subgraph_executor_map.register_executor(subgraph_name, &endpoint_str)?;
138            subgraph_executor_map.register_subgraph_timeout(subgraph_name)?;
139        }
140
141        Ok(subgraph_executor_map)
142    }
143
144    pub async fn execute<'exec>(
145        &self,
146        subgraph_name: &str,
147        execution_request: SubgraphExecutionRequest<'exec>,
148        client_request: &ClientRequestDetails<'exec>,
149    ) -> Result<SubgraphResponse<'exec>, SubgraphExecutorError> {
150        let executor = self.get_or_create_executor(subgraph_name, client_request)?;
151
152        let timeout = self
153            .timeouts_by_subgraph
154            .get(subgraph_name)
155            .map(|t| {
156                let global_timeout_duration =
157                    resolve_timeout(&self.global_timeout, client_request, None, "all")?;
158                resolve_timeout(
159                    t.value(),
160                    client_request,
161                    Some(global_timeout_duration),
162                    subgraph_name,
163                )
164            })
165            .transpose()?;
166
167        executor.execute(execution_request, timeout).await
168    }
169
170    /// Looks up a subgraph executor based on the subgraph name.
171    /// Looks for an expression first, falling back to a static endpoint.
172    /// If nothing is found, returns an error.
173    fn get_or_create_executor(
174        &self,
175        subgraph_name: &str,
176        client_request: &ClientRequestDetails<'_>,
177    ) -> Result<SubgraphExecutorBoxedArc, SubgraphExecutorError> {
178        self.expression_endpoints_by_subgraph
179            .get(subgraph_name)
180            .map(|expression| {
181                self.get_or_create_executor_from_expression(
182                    subgraph_name,
183                    expression,
184                    client_request,
185                )
186            })
187            .unwrap_or_else(|| {
188                self.get_executor_from_static_endpoint(subgraph_name)
189                    .ok_or_else(|| {
190                        SubgraphExecutorError::StaticEndpointNotFound(subgraph_name.to_string())
191                    })
192            })
193    }
194
195    /// Looks up a subgraph executor,
196    /// or creates one if a VRL expression is defined for the subgraph.
197    /// The expression is resolved to get the endpoint URL,
198    /// and a new executor is created and stored for future requests.
199    fn get_or_create_executor_from_expression(
200        &self,
201        subgraph_name: &str,
202        expression: &VrlProgram,
203        client_request: &ClientRequestDetails<'_>,
204    ) -> Result<SubgraphExecutorBoxedArc, SubgraphExecutorError> {
205        let original_url_value = VrlValue::Bytes(
206            self.static_endpoints_by_subgraph
207                .get(subgraph_name)
208                .map(|endpoint| endpoint.value().clone())
209                .ok_or_else(|| {
210                    SubgraphExecutorError::StaticEndpointNotFound(subgraph_name.to_string())
211                })?
212                .into(),
213        );
214
215        let value = VrlValue::Object(BTreeMap::from([
216            ("request".into(), client_request.into()),
217            ("default".into(), original_url_value),
218        ]));
219
220        // Resolve the expression to get an endpoint URL.
221        let endpoint_result = expression.execute(value).map_err(|err| {
222            SubgraphExecutorError::EndpointExpressionResolutionFailure(
223                subgraph_name.to_string(),
224                err.to_string(),
225            )
226        })?;
227
228        let endpoint_str = match endpoint_result.as_str() {
229            Some(s) => Ok(s.to_string()),
230            None => Err(SubgraphExecutorError::EndpointExpressionWrongType(
231                subgraph_name.to_string(),
232            )),
233        }?;
234
235        // Check if an executor for this endpoint already exists.
236        if let Some(executor) = self.get_executor_from_endpoint(subgraph_name, &endpoint_str) {
237            return Ok(executor);
238        }
239
240        // If not, create and register a new one.
241        self.register_executor(subgraph_name, &endpoint_str)
242    }
243
244    /// Looks up a subgraph executor based on a static endpoint URL.
245    fn get_executor_from_static_endpoint(
246        &self,
247        subgraph_name: &str,
248    ) -> Option<SubgraphExecutorBoxedArc> {
249        let endpoint_ref = self.static_endpoints_by_subgraph.get(subgraph_name)?;
250        let endpoint_str = endpoint_ref.value();
251        self.get_executor_from_endpoint(subgraph_name, endpoint_str)
252    }
253
254    /// Looks up a subgraph executor for a given endpoint URL.
255    #[inline]
256    fn get_executor_from_endpoint(
257        &self,
258        subgraph_name: &str,
259        endpoint_str: &str,
260    ) -> Option<SubgraphExecutorBoxedArc> {
261        self.executors_by_subgraph
262            .get(subgraph_name)
263            .and_then(|endpoints| endpoints.get(endpoint_str).map(|e| e.clone()))
264    }
265
266    /// Registers a new HTTP subgraph executor for the given subgraph name and endpoint URL.
267    /// It makes it availble for future requests.
268    fn register_endpoint_expression(
269        &mut self,
270        subgraph_name: &str,
271        expression: &str,
272    ) -> Result<(), SubgraphExecutorError> {
273        let program = expression.compile_expression(None).map_err(|err| {
274            SubgraphExecutorError::EndpointExpressionBuild(
275                subgraph_name.to_string(),
276                err.diagnostics,
277            )
278        })?;
279        self.expression_endpoints_by_subgraph
280            .insert(subgraph_name.to_string(), program);
281
282        Ok(())
283    }
284
285    /// Registers a static endpoint for the given subgraph name.
286    /// This is used for quick lookup when no expression is defined
287    /// or when resolving the expression (to have the original URL available there).
288    fn register_static_endpoint(&self, subgraph_name: &str, endpoint_str: &str) {
289        self.static_endpoints_by_subgraph
290            .insert(subgraph_name.to_string(), endpoint_str.to_string());
291    }
292
293    /// Registers a new HTTP subgraph executor for the given subgraph name and endpoint URL.
294    /// It makes it available for future requests.
295    fn register_executor(
296        &self,
297        subgraph_name: &str,
298        endpoint_str: &str,
299    ) -> Result<SubgraphExecutorBoxedArc, SubgraphExecutorError> {
300        let endpoint_uri = endpoint_str.parse::<Uri>().map_err(|e| {
301            SubgraphExecutorError::EndpointParseFailure(endpoint_str.to_string(), e.to_string())
302        })?;
303
304        let origin = format!(
305            "{}://{}:{}",
306            endpoint_uri.scheme_str().unwrap_or("http"),
307            endpoint_uri.host().unwrap_or(""),
308            endpoint_uri.port_u16().unwrap_or_else(|| {
309                if endpoint_uri.scheme_str() == Some("https") {
310                    443
311                } else {
312                    80
313                }
314            })
315        );
316
317        let semaphore = self
318            .semaphores_by_origin
319            .entry(origin)
320            .or_insert_with(|| Arc::new(Semaphore::new(self.max_connections_per_host)))
321            .clone();
322
323        let subgraph_config = self.resolve_subgraph_config(subgraph_name)?;
324
325        let executor = HTTPSubgraphExecutor::new(
326            subgraph_name.to_string(),
327            endpoint_uri,
328            subgraph_config.client,
329            semaphore,
330            subgraph_config.dedupe_enabled,
331            self.in_flight_requests.clone(),
332            self.telemetry_context.clone(),
333        );
334
335        let executor_arc = executor.to_boxed_arc();
336
337        self.executors_by_subgraph
338            .entry(subgraph_name.to_string())
339            .or_default()
340            .insert(endpoint_str.to_string(), executor_arc.clone());
341
342        Ok(executor_arc)
343    }
344
345    /// Resolves traffic shaping configuration for a specific subgraph, applying subgraph-specific
346    /// overrides on top of global settings
347    fn resolve_subgraph_config<'a>(
348        &'a self,
349        subgraph_name: &'a str,
350    ) -> Result<ResolvedSubgraphConfig<'a>, SubgraphExecutorError> {
351        let mut config = ResolvedSubgraphConfig {
352            client: self.client.clone(),
353            timeout_config: &self.config.traffic_shaping.all.request_timeout,
354            dedupe_enabled: self.config.traffic_shaping.all.dedupe_enabled,
355        };
356
357        let Some(subgraph_config) = self.config.traffic_shaping.subgraphs.get(subgraph_name) else {
358            return Ok(config);
359        };
360
361        // Override client only if pool idle timeout is customized
362        if let Some(pool_idle_timeout) = subgraph_config.pool_idle_timeout {
363            // Only override if it's different from the global setting
364            if pool_idle_timeout != self.config.traffic_shaping.all.pool_idle_timeout {
365                config.client = Arc::new(
366                    Client::builder(TokioExecutor::new())
367                        .pool_timer(TokioTimer::new())
368                        .pool_idle_timeout(pool_idle_timeout)
369                        .pool_max_idle_per_host(self.max_connections_per_host)
370                        .build(build_https_executor()?),
371                );
372            }
373        }
374
375        // Apply other subgraph-specific overrides
376        if let Some(dedupe_enabled) = subgraph_config.dedupe_enabled {
377            config.dedupe_enabled = dedupe_enabled;
378        }
379
380        if let Some(custom_timeout) = &subgraph_config.request_timeout {
381            config.timeout_config = custom_timeout;
382        }
383
384        Ok(config)
385    }
386
387    /// Compiles and registers a timeout for a specific subgraph.
388    /// If the subgraph has a custom timeout configuration, it will be used.
389    /// Otherwise, the global timeout configuration will be used.
390    fn register_subgraph_timeout(&self, subgraph_name: &str) -> Result<(), SubgraphExecutorError> {
391        // Check if this subgraph already has a timeout registered
392        if self.timeouts_by_subgraph.contains_key(subgraph_name) {
393            return Ok(());
394        }
395
396        // Get the timeout configuration for this subgraph, or fall back to global
397        let timeout_config = self
398            .config
399            .traffic_shaping
400            .subgraphs
401            .get(subgraph_name)
402            .and_then(|s| s.request_timeout.as_ref())
403            .unwrap_or(&self.config.traffic_shaping.all.request_timeout);
404
405        // Compile the timeout configuration into a DurationOrProgram
406        let timeout_prog = DurationOrProgram::compile(timeout_config, None).map_err(|err| {
407            SubgraphExecutorError::RequestTimeoutExpressionBuild(
408                subgraph_name.to_string(),
409                err.diagnostics,
410            )
411        })?;
412
413        // Register the compiled timeout
414        self.timeouts_by_subgraph
415            .insert(subgraph_name.to_string(), timeout_prog);
416
417        Ok(())
418    }
419}
420
421/// Resolves a timeout DurationOrProgram to a concrete Duration.
422/// Optionally includes a default timeout value in the VRL context.
423fn resolve_timeout(
424    duration_or_program: &DurationOrProgram,
425    client_request: &ClientRequestDetails<'_>,
426    default_timeout: Option<Duration>,
427    timeout_name: &str,
428) -> Result<Duration, SubgraphExecutorError> {
429    duration_or_program
430        .resolve(|| {
431            let mut context_map = BTreeMap::new();
432            context_map.insert("request".into(), client_request.into());
433
434            if let Some(default) = default_timeout {
435                context_map.insert(
436                    "default".into(),
437                    VrlValue::Integer(default.as_millis() as i64),
438                );
439            }
440
441            VrlValue::Object(context_map)
442        })
443        .map_err(|err| {
444            SubgraphExecutorError::TimeoutExpressionResolution(
445                timeout_name.to_string(),
446                err.to_string(),
447            )
448        })
449}