use std::{
rc::Rc,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
task::{Context, Poll},
};
use http::{header, StatusCode};
use ntex::{
http::body::{Body, BodySize, MessageBody},
service::{Service, ServiceCtx},
util::Bytes,
web::{self, DefaultError},
Middleware, SharedCfg,
};
use crate::RouterSharedState;
/// Middleware factory that caps how many long-lived client requests
/// (WebSocket upgrades and streaming responses) may be in flight at once.
#[derive(Clone)]
pub struct LongLivedClientLimitService {
    /// Whether the cap is enforced; when `false` the produced middleware forwards
    /// every request untouched.
    enabled: bool,
}

impl LongLivedClientLimitService {
    /// Builds the factory from the router configuration.
    ///
    /// Enforcement is switched on only when
    /// `traffic_shaping.router.max_long_lived_clients` is non-zero *and* at least one of
    /// `subscriptions.enabled` / `websocket.enabled` is set.
    pub fn new(router_config: &hive_router_config::HiveRouterConfig) -> Self {
        let cap_configured = router_config.traffic_shaping.router.max_long_lived_clients > 0;
        let serves_long_lived =
            router_config.subscriptions.enabled || router_config.websocket.enabled;
        Self {
            enabled: cap_configured && serves_long_lived,
        }
    }
}
impl<S> Middleware<S, SharedCfg> for LongLivedClientLimitService {
    type Service = LongLivedClientLimitMiddleware<S>;

    /// Wraps `service` in the limiter, carrying over this factory's `enabled` flag.
    /// The shared config argument is not used.
    fn create(&self, service: S, _cfg: SharedCfg) -> Self::Service {
        let enabled = self.enabled;
        LongLivedClientLimitMiddleware { enabled, service }
    }
}
/// Service produced by `LongLivedClientLimitService`: admits at most
/// `max_long_lived_clients` concurrent long-lived requests and forwards all others.
pub struct LongLivedClientLimitMiddleware<S> {
    /// The wrapped inner service.
    service: S,
    /// Copied from the factory; when `false` every request bypasses the limiter.
    enabled: bool,
}

impl<S> Service<web::WebRequest<DefaultError>> for LongLivedClientLimitMiddleware<S>
where
    S: Service<web::WebRequest<DefaultError>, Response = web::WebResponse, Error = web::Error>,
{
    type Response = web::WebResponse;
    type Error = S::Error;

    ntex::forward_ready!(service);

    /// Admits or rejects a single request.
    ///
    /// The request is forwarded unchanged when limiting is disabled, when
    /// `is_long_lived_request` says it is not long-lived, or when no
    /// `Arc<RouterSharedState>` app state is registered. Otherwise a slot is reserved on
    /// the shared counter. If the counter already equals the cap, the request is answered
    /// with `503 Service Unavailable` plus `Retry-After: 5` and never reaches the inner
    /// service.
    ///
    /// A reserved slot is released when the inner call fails, when this future is
    /// dropped, or (on success) when the response body is dropped.
    async fn call(
        &self,
        req: web::WebRequest<DefaultError>,
        ctx: ServiceCtx<'_, Self>,
    ) -> Result<Self::Response, Self::Error> {
        // Only long-lived requests are counted, and only while the limiter is active.
        if !(self.enabled && is_long_lived_request(req.headers())) {
            return ctx.call(&self.service, req).await;
        }

        // Without shared router state there is no counter to enforce against.
        let Some(state) = req.app_state::<Arc<RouterSharedState>>() else {
            return ctx.call(&self.service, req).await;
        };
        let max_clients = state.router_config.traffic_shaping.router.max_long_lived_clients;
        let active_clients = Arc::clone(&state.long_lived_client_count);

        // Atomically take a slot, but only while strictly below the cap.
        let slot_reserved = active_clients
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                (current < max_clients).then(|| current + 1)
            })
            .is_ok();

        if !slot_reserved {
            let busy = web::HttpResponse::build(StatusCode::SERVICE_UNAVAILABLE)
                .header(header::RETRY_AFTER, "5")
                .body("Too many long-lived clients");
            return Ok(req.into_response(busy));
        }

        // Dropping the guard releases the slot. It is dropped here if the inner call errors
        // or this future is cancelled. On success it moves into the body, so the slot is
        // held for as long as the response body exists.
        // NOTE(review): for a WebSocket upgrade the slot lasts only as long as the HTTP
        // response body; confirm that body stays alive for the socket's lifetime, otherwise
        // upgraded connections stop being counted once the 101 response is sent.
        let slot = LongLivedClientGuard(active_clients);
        let response = ctx.call(&self.service, req).await?;
        Ok(response.map_body(move |_head, body| {
            Body::from_message(GuardedBody {
                inner: body.into_body().into(),
                _guard: slot,
            })
            .into()
        }))
    }
}
/// Holds one slot on the shared long-lived client counter.
///
/// Creating a guard does not modify the counter. Dropping it decrements the counter by
/// exactly one.
struct LongLivedClientGuard(Arc<AtomicUsize>);

impl Drop for LongLivedClientGuard {
    fn drop(&mut self) {
        let Self(active_clients) = self;
        active_clients.fetch_sub(1, Ordering::AcqRel);
    }
}
struct GuardedBody {
inner: Body,
_guard: LongLivedClientGuard,
}
impl MessageBody for GuardedBody {
fn size(&self) -> BodySize {
self.inner.size()
}
fn poll_next_chunk(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Bytes, Rc<dyn std::error::Error>>>> {
self.inner.poll_next_chunk(cx)
}
}
/// Reports whether a request is expected to keep its connection open for a long time.
///
/// A request qualifies when either:
/// - it is a WebSocket handshake: the `Upgrade` header lists the `websocket` token and the
///   `Connection` header lists the `upgrade` token; or
/// - its `Accept` header is non-empty, passes the cheap `looks_like_streaming_accept`
///   pre-filter, and then negotiates to one of `StreamContentType::media_types()`.
///
/// Header list tokens are compared ASCII case-insensitively after trimming surrounding
/// whitespace, as RFC 6455 prescribes for the handshake headers. Header values that
/// `to_str` rejects are treated as absent.
#[inline]
fn is_long_lived_request(headers: &ntex::http::HeaderMap) -> bool {
    // Exact token match within a comma-separated header list. Unlike a substring search,
    // this does not fire on tokens that merely contain the word, and it needs no
    // lowercase copy of the header value.
    let lists_token = |value: Option<&str>, token: &str| {
        value.is_some_and(|v| v.split(',').any(|item| item.trim().eq_ignore_ascii_case(token)))
    };

    let upgrade = headers.get(header::UPGRADE).and_then(|v| v.to_str().ok());
    if lists_token(upgrade, "websocket")
        && lists_token(
            headers.get(header::CONNECTION).and_then(|v| v.to_str().ok()),
            "upgrade",
        )
    {
        return true;
    }

    let accept = match headers.get(header::ACCEPT).and_then(|v| v.to_str().ok()) {
        Some(v) if !v.is_empty() => v,
        _ => return false,
    };
    // Skip full Accept parsing for the common non-streaming case.
    if !looks_like_streaming_accept(accept) {
        return false;
    }

    use crate::pipeline::header::StreamContentType;
    use headers_accept::Accept;
    use std::str::FromStr;

    Accept::from_str(accept)
        .ok()
        .and_then(|a| a.negotiate(StreamContentType::media_types().iter()))
        .is_some()
}
/// Cheap pre-filter for `Accept` values that might negotiate to a streaming response.
///
/// Returns `true` if `accept` contains `multipart` or `event-stream`, compared ASCII
/// case-insensitively. Media types are case-insensitive (RFC 9110 §8.3.1), so e.g.
/// `Accept: Text/Event-Stream` must not be filtered out here, or it would skip the
/// long-lived client limit. False positives are harmless: `is_long_lived_request` confirms
/// with full content negotiation afterwards.
#[inline]
fn looks_like_streaming_accept(accept: &str) -> bool {
    // Substring search that ignores ASCII case without allocating a lowercase copy.
    // `needle` is never empty at the call sites below (`windows(0)` would panic).
    fn contains_ignore_ascii_case(haystack: &str, needle: &str) -> bool {
        haystack
            .as_bytes()
            .windows(needle.len())
            .any(|window| window.eq_ignore_ascii_case(needle.as_bytes()))
    }

    contains_ignore_ascii_case(accept, "multipart")
        || contains_ignore_ascii_case(accept, "event-stream")
}