hive_router_plan_executor/executors/
map.rs1use std::{collections::HashMap, sync::Arc, time::Duration};
2
3use bytes::{BufMut, Bytes, BytesMut};
4use dashmap::DashMap;
5use hive_router_config::traffic_shaping::TrafficShapingExecutorConfig;
6use http::Uri;
7use hyper_tls::HttpsConnector;
8use hyper_util::{
9 client::legacy::Client,
10 rt::{TokioExecutor, TokioTimer},
11};
12use tokio::sync::{OnceCell, Semaphore};
13
14use crate::{
15 executors::{
16 common::{HttpExecutionRequest, SubgraphExecutor, SubgraphExecutorBoxedArc},
17 dedupe::{ABuildHasher, RequestFingerprint, SharedResponse},
18 error::SubgraphExecutorError,
19 http::HTTPSubgraphExecutor,
20 },
21 response::graphql_error::GraphQLError,
22};
23
24pub struct SubgraphExecutorMap {
25 inner: HashMap<String, SubgraphExecutorBoxedArc>,
26}
27
28impl Default for SubgraphExecutorMap {
29 fn default() -> Self {
30 Self::new()
31 }
32}
33
34impl SubgraphExecutorMap {
35 pub fn new() -> Self {
36 SubgraphExecutorMap {
37 inner: HashMap::new(),
38 }
39 }
40
41 pub async fn execute<'a>(
42 &self,
43 subgraph_name: &str,
44 execution_request: HttpExecutionRequest<'a>,
45 ) -> Bytes {
46 match self.inner.get(subgraph_name) {
47 Some(executor) => executor.execute(execution_request).await,
48 None => {
49 let graphql_error: GraphQLError = format!(
50 "Subgraph executor not found for subgraph: {}",
51 subgraph_name
52 )
53 .into();
54 let errors = vec![graphql_error];
55 let errors_bytes = sonic_rs::to_vec(&errors).unwrap();
56 let mut buffer = BytesMut::new();
57 buffer.put_slice(b"{\"errors\":");
58 buffer.put_slice(&errors_bytes);
59 buffer.put_slice(b"}");
60 buffer.freeze()
61 }
62 }
63 }
64
65 pub fn insert_boxed_arc(&mut self, subgraph_name: String, boxed_arc: SubgraphExecutorBoxedArc) {
66 self.inner.insert(subgraph_name, boxed_arc);
67 }
68
69 pub fn from_http_endpoint_map(
70 subgraph_endpoint_map: HashMap<String, String>,
71 config: TrafficShapingExecutorConfig,
72 ) -> Result<Self, SubgraphExecutorError> {
73 let https = HttpsConnector::new();
74 let client = Client::builder(TokioExecutor::new())
75 .pool_timer(TokioTimer::new())
76 .pool_idle_timeout(Duration::from_secs(config.pool_idle_timeout_seconds))
77 .pool_max_idle_per_host(config.max_connections_per_host)
78 .build(https);
79
80 let client_arc = Arc::new(client);
81 let semaphores_by_origin: DashMap<String, Arc<Semaphore>> = DashMap::new();
82 let max_connections_per_host = config.max_connections_per_host;
83 let config_arc = Arc::new(config);
84 let in_flight_requests: Arc<
85 DashMap<RequestFingerprint, Arc<OnceCell<SharedResponse>>, ABuildHasher>,
86 > = Arc::new(DashMap::with_hasher(ABuildHasher::default()));
87
88 let executor_map = subgraph_endpoint_map
89 .into_iter()
90 .map(|(subgraph_name, endpoint_str)| {
91 let endpoint_uri = endpoint_str.parse::<Uri>().map_err(|e| {
92 SubgraphExecutorError::EndpointParseFailure(endpoint_str.clone(), e.to_string())
93 })?;
94
95 let origin = format!(
96 "{}://{}:{}",
97 endpoint_uri.scheme_str().unwrap_or("http"),
98 endpoint_uri.host().unwrap_or(""),
99 endpoint_uri.port_u16().unwrap_or_else(|| {
100 if endpoint_uri.scheme_str() == Some("https") {
101 443
102 } else {
103 80
104 }
105 })
106 );
107
108 let semaphore = semaphores_by_origin
109 .entry(origin)
110 .or_insert_with(|| Arc::new(Semaphore::new(max_connections_per_host)))
111 .clone();
112
113 let executor = HTTPSubgraphExecutor::new(
114 endpoint_uri,
115 client_arc.clone(),
116 semaphore,
117 config_arc.clone(),
118 in_flight_requests.clone(),
119 );
120
121 Ok((subgraph_name, executor.to_boxed_arc()))
122 })
123 .collect::<Result<HashMap<_, _>, SubgraphExecutorError>>()?;
124
125 Ok(SubgraphExecutorMap {
126 inner: executor_map,
127 })
128 }
129}