hive_router_plan_executor/executors/
map.rs1use std::{collections::HashMap, sync::Arc, time::Duration};
2
3use bytes::{BufMut, Bytes, BytesMut};
4use dashmap::DashMap;
5use hive_router_config::traffic_shaping::TrafficShapingExecutorConfig;
6use http::Uri;
7use hyper_util::{
8 client::legacy::Client,
9 rt::{TokioExecutor, TokioTimer},
10};
11use tokio::sync::{OnceCell, Semaphore};
12
13use crate::{
14 executors::{
15 common::{HttpExecutionRequest, SubgraphExecutor, SubgraphExecutorBoxedArc},
16 dedupe::{ABuildHasher, RequestFingerprint, SharedResponse},
17 error::SubgraphExecutorError,
18 http::HTTPSubgraphExecutor,
19 },
20 response::graphql_error::GraphQLError,
21};
22
23pub struct SubgraphExecutorMap {
24 inner: HashMap<String, SubgraphExecutorBoxedArc>,
25}
26
27impl Default for SubgraphExecutorMap {
28 fn default() -> Self {
29 Self::new()
30 }
31}
32
33impl SubgraphExecutorMap {
34 pub fn new() -> Self {
35 SubgraphExecutorMap {
36 inner: HashMap::new(),
37 }
38 }
39
40 pub async fn execute<'a>(
41 &self,
42 subgraph_name: &str,
43 execution_request: HttpExecutionRequest<'a>,
44 ) -> Bytes {
45 match self.inner.get(subgraph_name) {
46 Some(executor) => executor.execute(execution_request).await,
47 None => {
48 let graphql_error: GraphQLError = format!(
49 "Subgraph executor not found for subgraph: {}",
50 subgraph_name
51 )
52 .into();
53 let errors = vec![graphql_error];
54 let errors_bytes = sonic_rs::to_vec(&errors).unwrap();
55 let mut buffer = BytesMut::new();
56 buffer.put_slice(b"{\"errors\":");
57 buffer.put_slice(&errors_bytes);
58 buffer.put_slice(b"}");
59 buffer.freeze()
60 }
61 }
62 }
63
64 pub fn insert_boxed_arc(&mut self, subgraph_name: String, boxed_arc: SubgraphExecutorBoxedArc) {
65 self.inner.insert(subgraph_name, boxed_arc);
66 }
67
68 pub fn from_http_endpoint_map(
69 subgraph_endpoint_map: HashMap<String, String>,
70 config: TrafficShapingExecutorConfig,
71 ) -> Result<Self, SubgraphExecutorError> {
72 let client = Client::builder(TokioExecutor::new())
73 .pool_timer(TokioTimer::new())
74 .pool_idle_timeout(Duration::from_secs(config.pool_idle_timeout_seconds))
75 .pool_max_idle_per_host(config.max_connections_per_host)
76 .build_http();
77
78 let client_arc = Arc::new(client);
79 let semaphores_by_origin: DashMap<String, Arc<Semaphore>> = DashMap::new();
80 let max_connections_per_host = config.max_connections_per_host;
81 let config_arc = Arc::new(config);
82 let in_flight_requests: Arc<
83 DashMap<RequestFingerprint, Arc<OnceCell<SharedResponse>>, ABuildHasher>,
84 > = Arc::new(DashMap::with_hasher(ABuildHasher::default()));
85
86 let executor_map = subgraph_endpoint_map
87 .into_iter()
88 .map(|(subgraph_name, endpoint_str)| {
89 let endpoint_uri = endpoint_str.parse::<Uri>().map_err(|e| {
90 SubgraphExecutorError::EndpointParseFailure(endpoint_str.clone(), e.to_string())
91 })?;
92
93 let origin = format!(
94 "{}://{}:{}",
95 endpoint_uri.scheme_str().unwrap_or("http"),
96 endpoint_uri.host().unwrap_or(""),
97 endpoint_uri.port_u16().unwrap_or_else(|| {
98 if endpoint_uri.scheme_str() == Some("https") {
99 443
100 } else {
101 80
102 }
103 })
104 );
105
106 let semaphore = semaphores_by_origin
107 .entry(origin)
108 .or_insert_with(|| Arc::new(Semaphore::new(max_connections_per_host)))
109 .clone();
110
111 let executor = HTTPSubgraphExecutor::new(
112 endpoint_uri,
113 client_arc.clone(),
114 semaphore,
115 config_arc.clone(),
116 in_flight_requests.clone(),
117 );
118
119 Ok((subgraph_name, executor.to_boxed_arc()))
120 })
121 .collect::<Result<HashMap<_, _>, SubgraphExecutorError>>()?;
122
123 Ok(SubgraphExecutorMap {
124 inner: executor_map,
125 })
126 }
127}