use http::Method;
use ntex::channel::oneshot;
use ntex::http::{header::HeaderName, header::HeaderValue, HeaderMap};
use ntex::router::Path;
use ntex::service::{fn_factory_with_config, fn_service, fn_shutdown, Service};
use ntex::web::{self, ws, Error, HttpRequest, HttpResponse};
use ntex::{chain, rt};
use sonic_rs::{JsonContainerTrait, JsonValueTrait, Value};
use std::cell::RefCell;
use std::collections::HashMap;
use std::io;
use std::rc::Rc;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::mpsc;
use tracing::{debug, error, trace, warn, Instrument};
use hive_router_internal::telemetry::traces::spans::graphql::GraphQLOperationSpan;
use hive_router_plan_executor::executors::graphql_transport_ws::{
ClientMessage, CloseCode, ConnectionInitPayload, ServerMessage, WS_SUBPROTOCOL,
};
use hive_router_plan_executor::executors::websocket_common::{
handshake_timeout, heartbeat, parse_frame_to_text, FrameNotParsedToText, WsState,
};
use hive_router_plan_executor::hooks::on_graphql_params::GraphQLParams;
use hive_router_plan_executor::plugin_context::{
PluginContext, PluginRequestState, RouterHttpRequest,
};
use hive_router_plan_executor::response::graphql_error::{GraphQLError, GraphQLErrorExtensions};
use hive_router_query_planner::state::supergraph_state::OperationKind;
use crate::jwt::errors::JwtError;
use crate::pipeline::active_subscriptions::SubscriptionEvent;
use crate::pipeline::error::PipelineError;
use crate::pipeline::execute_planned_request;
use crate::pipeline::header::{ResponseMode, SingleContentType, StreamContentType};
use crate::pipeline::{
hash_graphql_extensions, hash_graphql_variables, inbound_request_fingerprint,
normalize::normalize_request_with_cache, parser::parse_operation_with_cache, usage_reporting,
validation::validate_operation_with_cache,
};
use crate::schema_state::SchemaState;
use crate::shared_state::{RouterSharedState, SharedRouterResponse};
type WsStateRef = Rc<RefCell<WsState<tokio::sync::mpsc::Sender<()>>>>;
pub async fn ws_index(
req: HttpRequest,
schema_state: web::types::State<Arc<SchemaState>>,
shared_state: web::types::State<Arc<RouterSharedState>>,
) -> Result<HttpResponse, Error> {
let schema_state = schema_state.get_ref().clone();
let shared_state = shared_state.get_ref().clone();
let accepted_subprotocol = ws::subprotocols(&req)
.find(|p| *p == WS_SUBPROTOCOL)
.map(|_| WS_SUBPROTOCOL);
let plugin_context = req.extensions().get::<Arc<PluginContext>>().cloned();
ws::start(
req,
accepted_subprotocol,
fn_factory_with_config(move |sink: ws::WsSink| {
let schema_state = schema_state.clone();
let shared_state = shared_state.clone();
let plugin_context = plugin_context.clone();
async move {
ws_service(
accepted_subprotocol.is_some(),
sink,
schema_state,
shared_state,
plugin_context,
)
.await
}
}),
)
.await
}
async fn ws_service(
has_accepted_subprotocol: bool,
sink: ws::WsSink,
schema_state: Arc<SchemaState>,
shared_state: Arc<RouterSharedState>,
plugin_context: Option<Arc<PluginContext>>,
) -> Result<impl Service<ws::Frame, Response = Option<ws::Message>, Error = io::Error>, web::Error>
{
if !has_accepted_subprotocol {
debug!("WebSocket connection rejecting due to unacceptable subprotocol");
let _ = sink.send(CloseCode::SubprotocolNotAcceptable.into()).await;
} else {
debug!("WebSocket connection accepted");
}
let ws_uri: Rc<http::Uri> = Rc::new(
shared_state
.router_config
.websocket_path()
.expect("websocket path must exist because the websocket handler wouldn't have been mounted otherwise")
.parse()
.unwrap_or_else(|_| http::Uri::from_static("/graphql")),
);
let ws_path: Rc<Path<http::Uri>> = Rc::new(Path::new((*ws_uri).clone()));
let (heartbeat_tx, heartbeat_rx) = oneshot::channel();
let (acknowledged_tx, acknowledged_rx) = oneshot::channel();
let state: WsStateRef = Rc::new(RefCell::new(WsState::new(acknowledged_tx)));
rt::spawn(heartbeat(state.clone(), sink.clone(), heartbeat_rx));
rt::spawn(handshake_timeout(
state.clone(),
sink.clone(),
acknowledged_rx,
CloseCode::ConnectionInitTimeout,
));
let state_for_service = state.clone();
let service = fn_service(move |frame| {
let sink = sink.clone();
let state = state_for_service.clone();
let schema_state = schema_state.clone();
let shared_state = shared_state.clone();
let plugin_context = plugin_context.clone();
let ws_uri = ws_uri.clone();
let ws_path = ws_path.clone();
async move {
match parse_frame_to_text(frame, &state) {
Ok(text) => Ok(handle_text_frame(
text,
sink,
state,
&schema_state,
&shared_state,
plugin_context,
&ws_uri,
&ws_path,
)
.await),
Err(FrameNotParsedToText::Message(msg)) => Ok(Some(msg)),
Err(FrameNotParsedToText::Closed) => {
Ok(None)
}
Err(FrameNotParsedToText::None) => Ok(None),
}
}
});
let on_shutdown = fn_shutdown(async move || {
let _ = heartbeat_tx.send(());
if let Some(tx) = state.borrow_mut().acknowledged_tx.take() {
let _ = tx.send(());
}
state.borrow_mut().subscriptions.clear();
});
Ok(chain(service).and_then(on_shutdown))
}
struct SubscriptionGuard {
state: WsStateRef,
id: String,
}
impl Drop for SubscriptionGuard {
fn drop(&mut self) {
self.state.borrow_mut().subscriptions.remove(&self.id);
trace!(id = %self.id, "Subscription removed from active subscriptions");
}
}
lazy_static::lazy_static! {
static ref WS_URI_PATH: http::Uri = http::Uri::from_static("/graphql");
}
#[allow(clippy::too_many_arguments)]
async fn handle_text_frame(
text: String,
sink: ws::WsSink,
state: WsStateRef,
schema_state: &Arc<SchemaState>,
shared_state: &Arc<RouterSharedState>,
plugin_context: Option<Arc<PluginContext>>,
ws_uri: &http::Uri,
ws_path: &Path<http::Uri>,
) -> Option<ws::Message> {
let client_msg: ClientMessage = match sonic_rs::from_str(&text) {
Ok(msg) => msg,
Err(e) => {
error!("Failed to parse client message to JSON: {}", e);
return Some(CloseCode::BadRequest("Invalid message received from client").into());
}
};
trace!("type" = client_msg.as_ref(), "Received client message");
match client_msg {
ClientMessage::ConnectionInit { payload } => {
if state.borrow().handshake_received {
return Some(CloseCode::TooManyInitialisationRequests.into());
}
state.borrow_mut().handshake_received = true;
state.borrow_mut().init_payload = payload;
state.borrow_mut().complete_handshake();
let _ = sink.send(ServerMessage::ack()).await;
debug!("Connection acknowledged");
let header_map =
parse_headers_from_connection_init_payload(state.borrow().init_payload.as_ref());
if !header_map.is_empty() {
trace!("Connection init message contains headers in the payload");
} else {
trace!("Connection init message does not contain headers in the payload");
}
None
}
ClientMessage::Subscribe { id, payload } => {
if let Some(msg) = state.borrow().check_acknowledged() {
return Some(msg);
}
if state.borrow().subscriptions.contains_key(&id) {
return Some(CloseCode::SubscriberAlreadyExists(id).into());
}
let started_at = Instant::now();
let operation_span = GraphQLOperationSpan::new();
let span_clone = operation_span.clone();
let result = async {
let maybe_supergraph = schema_state.current_supergraph();
let supergraph = match maybe_supergraph.as_ref() {
Some(supergraph) => supergraph,
None => {
warn!(
"No supergraph available yet, unable to process client subscribe message"
);
return Some(ServerMessage::error(
&id,
&[GraphQLError::from_message_and_extensions(
"No supergraph available yet".to_string(),
GraphQLErrorExtensions::new_from_code("SERVICE_UNAVAILABLE"),
)],
));
}
};
let config = &shared_state.router_config.websocket;
let connection_init_headers = if config.headers.accepts_connection_headers() {
let state_borrow = state.borrow();
parse_headers_from_connection_init_payload(state_borrow.init_payload.as_ref())
} else {
HeaderMap::new()
};
let extensions_headers = if config.headers.accepts_operation_headers() {
parse_headers_from_extensions(payload.extensions.as_ref())
} else {
HeaderMap::new()
};
let mut headers = connection_init_headers;
for (key, value) in extensions_headers.iter() {
headers.insert(key.clone(), value.clone());
}
if config.headers.persist {
if let Some(ref mut init_payload) = state.borrow_mut().init_payload {
for (key, value) in headers.iter() {
if let Ok(val_str) = value.to_str() {
init_payload
.fields
.insert(key.to_string(), Value::from(val_str));
}
}
}
}
let payload = GraphQLParams {
query: Some(payload.query),
operation_name: payload.operation_name,
variables: payload.variables.unwrap_or_default(),
extensions: payload.extensions,
};
let plugin_req_state = if let (Some(plugins), Some(ref plugin_context)) = (
shared_state.plugins.as_ref(),
plugin_context,
) {
Some(PluginRequestState {
plugins: plugins.clone(),
router_http_request: RouterHttpRequest {
uri: ws_uri,
method: &Method::POST,
version: http::Version::HTTP_11,
headers: &headers,
path: ws_uri.path(),
query_string: ws_uri.query().unwrap_or(""),
match_info: ws_path,
},
context: plugin_context.clone(),
})
} else {
None
};
let client_name = headers
.get(
&shared_state
.router_config
.telemetry
.client_identification
.name_header,
)
.and_then(|v| v.to_str().ok());
let client_version = headers
.get(
&shared_state
.router_config
.telemetry
.client_identification
.version_header,
)
.and_then(|v| v.to_str().ok());
let parser_result =
match parse_operation_with_cache(shared_state, &payload, &plugin_req_state).await {
Ok(result) => result,
Err(err) => return Some(err.into_server_message(&id)),
};
let parser_payload = match parser_result {
crate::pipeline::parser::ParseResult::Payload(payload) => payload,
crate::pipeline::parser::ParseResult::EarlyResponse(_) => {
return Some(ServerMessage::error(
&id,
&[GraphQLError::from_message_and_code(
"Unexpected early response during parse",
"INTERNAL_SERVER_ERROR",
)],
));
}
};
operation_span.record_details(
&parser_payload.minified_document,
(&parser_payload).into(),
client_name,
client_version,
&parser_payload.hive_operation_hash,
);
match validate_operation_with_cache(
supergraph,
schema_state,
shared_state,
&parser_payload,
&plugin_req_state,
)
.await
{
Ok(Some(_)) => {
return Some(ServerMessage::error(
&id,
&[GraphQLError::from_message_and_code(
"Unexpected early response during validation",
"INTERNAL_SERVER_ERROR",
)],
));
}
Ok(None) => {}
Err(err) => return Some(err.into_server_message(&id)),
}
let normalize_payload = match normalize_request_with_cache(
supergraph,
schema_state,
&payload,
&parser_payload,
)
.await
{
Ok(payload) => payload,
Err(err) => return Some(err.into_server_message(&id)),
};
let is_subscription = matches!(
normalize_payload.operation_for_plan.operation_kind,
Some(OperationKind::Subscription)
);
if is_subscription && !shared_state.router_config.subscriptions.enabled {
return Some(PipelineError::SubscriptionsNotSupported.into_server_message(&id));
}
let request_dedupe_enabled =
shared_state.router_config.traffic_shaping.router.dedupe.enabled;
let fingerprint = if request_dedupe_enabled
&& matches!(
normalize_payload.operation_for_plan.operation_kind,
None | Some(OperationKind::Query) | Some(OperationKind::Subscription)
) {
let variables_hash = hash_graphql_variables(&payload.variables);
let extensions_hash = payload
.extensions
.as_ref()
.map_or(0, hash_graphql_extensions);
Some(inbound_request_fingerprint(
&Method::POST,
ws_uri.path(),
&headers,
&shared_state.in_flight_requests_header_policy,
supergraph.schema_checksum(),
normalize_payload.normalized_operation_hash,
variables_hash,
extensions_hash,
))
} else {
None
};
let response_mode = ResponseMode::Dual(
SingleContentType::default(),
StreamContentType::default(),
);
let exec = |guard| execute_planned_request(
&Method::POST,
ws_uri,
&headers,
payload,
&normalize_payload,
supergraph,
shared_state,
schema_state,
operation_span,
plugin_req_state,
&response_mode,
guard,
);
let shared_response = if let Some(fp) = fingerprint {
let result = if is_subscription {
shared_state
.in_flight_requests
.claim(fp)
.get_or_try_init_with_guard(|guard| exec(Some(guard)))
.await
} else {
shared_state
.in_flight_requests
.claim(fp)
.get_or_try_init(|| exec(None))
.await
};
let (shared_response, _role) = match result {
Ok(result) => result,
Err(PipelineError::JwtError(err)) => {
let _ = sink.send(err.clone().into_server_message(&id)).await;
return Some(err.into_close_message());
},
Err(err) => return Some(err.into_server_message(&id)),
};
Arc::unwrap_or_clone(shared_response)
} else {
match exec(None).await {
Ok(result) => result,
Err(PipelineError::JwtError(err)) => {
let _ = sink.send(err.clone().into_server_message(&id)).await;
return Some(err.into_close_message());
},
Err(err) => return Some(err.into_server_message(&id)),
}
};
if let Some(hive_usage_agent) = &shared_state.hive_usage_agent {
usage_reporting::collect_usage_report(
supergraph.supergraph_schema.clone(),
started_at.elapsed(),
client_name,
client_version,
normalize_payload.operation_for_plan.name.as_deref(),
&parser_payload.minified_document,
hive_usage_agent,
shared_state
.router_config
.telemetry
.hive
.as_ref()
.map(|c| &c.usage_reporting)
.expect("Expected Usage Reporting options to be present when Hive Usage Agent is initialized"),
shared_response.error_count(),
)
.await;
}
match shared_response {
SharedRouterResponse::Single(response) => {
let _ = sink.send(ServerMessage::next(&id, &response.body)).await;
Some(ServerMessage::complete(&id))
}
SharedRouterResponse::Stream(response) => {
let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1);
state
.borrow_mut()
.subscriptions
.insert(id.clone(), cancel_tx);
let guard = SubscriptionGuard {
state: state.clone(),
id: id.clone(),
};
let mut receiver = response
.receiver
.unwrap_or_else(|| response.body.subscribe());
trace!(id = %id, "Subscription started");
let sink = sink.clone();
let id_for_loop = id.clone();
rt::spawn(async move {
let _guard = guard;
let mut cancelled = false;
loop {
tokio::select! {
maybe_item = receiver.recv() => {
match maybe_item {
Ok(SubscriptionEvent::Raw(data)) => {
let _ = sink.send(ServerMessage::next(&id_for_loop, &data)).await;
}
Ok(SubscriptionEvent::Error(errors)) => {
let _ = sink.send(ServerMessage::error(&id_for_loop, &errors)).await;
break;
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
trace!(id = %id_for_loop, lagged = n, "broadcast receiver lagged, skipping missed messages");
continue;
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
break;
}
}
}
_ = cancel_rx.recv() => {
cancelled = true;
break;
}
}
}
if cancelled {
trace!(id = %id_for_loop, "Subscription cancelled");
} else {
trace!(id = %id_for_loop, "Subscription completed");
let _ = sink.send(ServerMessage::complete(&id_for_loop)).await;
}
});
None
}
}
}
.instrument(span_clone)
.await;
result
}
ClientMessage::Complete { id } => {
if let Some(msg) = state.borrow().check_acknowledged() {
return Some(msg);
}
if let Some(cancel_tx) = state.borrow_mut().subscriptions.remove(&id) {
trace!(id = %id, "Client requested subscription cancellation");
let _ = cancel_tx.try_send(());
}
None
}
ClientMessage::Ping {} => {
Some(ServerMessage::pong())
}
ClientMessage::Pong {} => None,
}
}
fn parse_headers_from_object(headers_obj: &sonic_rs::Object) -> HeaderMap {
let mut header_map = HeaderMap::new();
for (key, value) in headers_obj.iter() {
let Ok(name) = HeaderName::try_from(key as &str) else {
continue;
};
let header_value = if let Some(s) = value.as_str() {
HeaderValue::try_from(s).ok()
} else if let Some(b) = value.as_bool() {
Some(HeaderValue::from_static(if b { "true" } else { "false" }))
} else if let Some(n) = value.as_i64() {
Some(HeaderValue::from(n))
} else if let Some(n) = value.as_u64() {
Some(HeaderValue::from(n))
} else if let Some(f) = value.as_f64() {
HeaderValue::try_from(f.to_string()).ok()
} else {
None };
if let Some(val) = header_value {
header_map.insert(name, val);
}
}
header_map
}
fn parse_headers_from_connection_init_payload(
payload: Option<&ConnectionInitPayload>,
) -> HeaderMap {
let mut header_map = HeaderMap::new();
if let Some(payload) = payload {
if let Some(headers_prop) = payload.fields.get("headers") {
if let Some(headers_obj) = headers_prop.as_object() {
header_map = parse_headers_from_object(headers_obj);
return header_map;
}
}
let mut obj = sonic_rs::Object::new();
for (k, v) in payload.fields.iter() {
obj.insert(k, v.clone());
}
header_map = parse_headers_from_object(&obj);
}
header_map
}
fn parse_headers_from_extensions(extensions: Option<&HashMap<String, Value>>) -> HeaderMap {
let mut header_map = HeaderMap::new();
if let Some(ext) = extensions {
if let Some(headers_value) = ext.get("headers") {
if let Some(headers_obj) = headers_value.as_object() {
header_map = parse_headers_from_object(headers_obj);
}
}
}
header_map
}
impl PipelineError {
fn into_server_message(self, id: &str) -> ws::Message {
let code = self.graphql_error_code();
let message = self.graphql_error_message();
let graphql_error = GraphQLError::from_message_and_extensions(
message,
GraphQLErrorExtensions::new_from_code(code),
);
ServerMessage::error(id, &[graphql_error])
}
}
impl JwtError {
fn into_server_message(self, id: &str) -> ws::Message {
ServerMessage::error(
id,
&[GraphQLError::from_message_and_code(
self.to_string(),
self.error_code(),
)],
)
}
fn into_close_message(self) -> ws::Message {
CloseCode::Forbidden(self.error_code().to_string()).into()
}
}
#[cfg(test)]
mod tests {
use sonic_rs::json;
use super::*;
#[test]
fn should_parse_headers_from_object() {
let headers_json = json!({
"authorization": "Bearer token123",
"x-custom-header": "custom-value",
"x-number": 42,
"x-bool": true,
"x-float": 3.14
});
let headers_obj = headers_json.as_object().expect("Failed to get object");
let headers = parse_headers_from_object(&headers_obj);
assert_eq!(
headers.get("authorization").unwrap().to_str().unwrap(),
"Bearer token123"
);
assert_eq!(
headers.get("x-custom-header").unwrap().to_str().unwrap(),
"custom-value"
);
assert_eq!(headers.get("x-number").unwrap().to_str().unwrap(), "42");
assert_eq!(headers.get("x-bool").unwrap().to_str().unwrap(), "true");
assert_eq!(headers.get("x-float").unwrap().to_str().unwrap(), "3.14");
}
}