hive_router_plan_executor/executors/map.rs

1use std::{collections::HashMap, sync::Arc, time::Duration};
2
3use bytes::{BufMut, Bytes, BytesMut};
4use dashmap::DashMap;
5use hive_router_config::traffic_shaping::TrafficShapingExecutorConfig;
6use http::Uri;
7use hyper_util::{
8    client::legacy::Client,
9    rt::{TokioExecutor, TokioTimer},
10};
11use tokio::sync::{OnceCell, Semaphore};
12
13use crate::{
14    executors::{
15        common::{HttpExecutionRequest, SubgraphExecutor, SubgraphExecutorBoxedArc},
16        dedupe::{ABuildHasher, RequestFingerprint, SharedResponse},
17        error::SubgraphExecutorError,
18        http::HTTPSubgraphExecutor,
19    },
20    response::graphql_error::GraphQLError,
21};
22
23pub struct SubgraphExecutorMap {
24    inner: HashMap<String, SubgraphExecutorBoxedArc>,
25}
26
27impl Default for SubgraphExecutorMap {
28    fn default() -> Self {
29        Self::new()
30    }
31}
32
33impl SubgraphExecutorMap {
34    pub fn new() -> Self {
35        SubgraphExecutorMap {
36            inner: HashMap::new(),
37        }
38    }
39
40    pub async fn execute<'a>(
41        &self,
42        subgraph_name: &str,
43        execution_request: HttpExecutionRequest<'a>,
44    ) -> Bytes {
45        match self.inner.get(subgraph_name) {
46            Some(executor) => executor.execute(execution_request).await,
47            None => {
48                let graphql_error: GraphQLError = format!(
49                    "Subgraph executor not found for subgraph: {}",
50                    subgraph_name
51                )
52                .into();
53                let errors = vec![graphql_error];
54                let errors_bytes = sonic_rs::to_vec(&errors).unwrap();
55                let mut buffer = BytesMut::new();
56                buffer.put_slice(b"{\"errors\":");
57                buffer.put_slice(&errors_bytes);
58                buffer.put_slice(b"}");
59                buffer.freeze()
60            }
61        }
62    }
63
64    pub fn insert_boxed_arc(&mut self, subgraph_name: String, boxed_arc: SubgraphExecutorBoxedArc) {
65        self.inner.insert(subgraph_name, boxed_arc);
66    }
67
68    pub fn from_http_endpoint_map(
69        subgraph_endpoint_map: HashMap<String, String>,
70        config: TrafficShapingExecutorConfig,
71    ) -> Result<Self, SubgraphExecutorError> {
72        let client = Client::builder(TokioExecutor::new())
73            .pool_timer(TokioTimer::new())
74            .pool_idle_timeout(Duration::from_secs(config.pool_idle_timeout_seconds))
75            .pool_max_idle_per_host(config.max_connections_per_host)
76            .build_http();
77
78        let client_arc = Arc::new(client);
79        let semaphores_by_origin: DashMap<String, Arc<Semaphore>> = DashMap::new();
80        let max_connections_per_host = config.max_connections_per_host;
81        let config_arc = Arc::new(config);
82        let in_flight_requests: Arc<
83            DashMap<RequestFingerprint, Arc<OnceCell<SharedResponse>>, ABuildHasher>,
84        > = Arc::new(DashMap::with_hasher(ABuildHasher::default()));
85
86        let executor_map = subgraph_endpoint_map
87            .into_iter()
88            .map(|(subgraph_name, endpoint_str)| {
89                let endpoint_uri = endpoint_str.parse::<Uri>().map_err(|e| {
90                    SubgraphExecutorError::EndpointParseFailure(endpoint_str.clone(), e.to_string())
91                })?;
92
93                let origin = format!(
94                    "{}://{}:{}",
95                    endpoint_uri.scheme_str().unwrap_or("http"),
96                    endpoint_uri.host().unwrap_or(""),
97                    endpoint_uri.port_u16().unwrap_or_else(|| {
98                        if endpoint_uri.scheme_str() == Some("https") {
99                            443
100                        } else {
101                            80
102                        }
103                    })
104                );
105
106                let semaphore = semaphores_by_origin
107                    .entry(origin)
108                    .or_insert_with(|| Arc::new(Semaphore::new(max_connections_per_host)))
109                    .clone();
110
111                let executor = HTTPSubgraphExecutor::new(
112                    endpoint_uri,
113                    client_arc.clone(),
114                    semaphore,
115                    config_arc.clone(),
116                    in_flight_requests.clone(),
117                );
118
119                Ok((subgraph_name, executor.to_boxed_arc()))
120            })
121            .collect::<Result<HashMap<_, _>, SubgraphExecutorError>>()?;
122
123        Ok(SubgraphExecutorMap {
124            inner: executor_map,
125        })
126    }
127}