use std::{
cell::RefCell,
collections::HashSet,
future::Future,
num::NonZeroU32,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::{Duration, Instant},
};
use governor::{
clock::DefaultClock,
middleware::NoOpMiddleware,
state::{InMemoryState, NotKeyed},
Quota, RateLimiter,
};
use once_cell::sync::OnceCell;
use pin_project_lite::pin_project;
use rand::{rngs::SmallRng, RngCore, SeedableRng};
use tokio::sync::watch;
use tracing::instrument::{Instrument, Instrumented};
use vise::{
Buckets, Counter, EncodeLabelSet, EncodeLabelValue, Family, GaugeGuard, Histogram, Metrics,
};
use zksync_web3_decl::jsonrpsee::{
server::middleware::rpc::{layer::ResponseFuture, RpcServiceT},
types::{error::ErrorCode, ErrorObject, Request},
MethodResponse,
};
use super::metadata::{MethodCall, MethodTracer};
use crate::web3::metrics::{ObservedRpcParams, API_METRICS};
// Transport over which JSON-RPC calls are received; used as the `transport` label of the
// `LimitMiddlewareMetrics` families. Only WebSocket (`Ws`) is currently represented, and
// `LimitMiddleware::new` always uses it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, EncodeLabelValue, EncodeLabelSet)]
#[metrics(label = "transport", rename_all = "snake_case")]
pub(crate) enum Transport {
Ws,
}
// Metrics reported by the limiting middleware; every family is labeled by `Transport`.
// NOTE: plain `//` comments are used on purpose — `vise` turns `///` doc comments on metric
// fields into exported help text, which would change the metrics output.
#[derive(Debug, Metrics)]
#[metrics(prefix = "api_jsonrpc_backend_batch")]
struct LimitMiddlewareMetrics {
// Incremented by `LimitMiddleware` each time a call is rejected because the rate limit is exhausted.
rate_limited: Family<Transport, Counter>,
// Exponential buckets 1, 2, 4, ..., 512. Not observed in this part of the file — presumably the
// number of requests in a batch; TODO confirm with the code that records it.
#[metrics(buckets = Buckets::exponential(1.0..=512.0, 2.0))]
size: Family<Transport, Histogram<usize>>,
// Not incremented in this part of the file. NOTE(review): confirm what it counts.
rejected: Family<Transport, Counter>,
}
// Process-wide instance of the metrics above, registered with `vise` via `#[vise::register]`.
#[vise::register]
static METRICS: vise::Global<LimitMiddlewareMetrics> = vise::Global::new();
/// Middleware that optionally rate-limits calls forwarded to `inner`.
///
/// Constructing it increments the `ws_open_sessions` gauge through the stored guard (the guard
/// presumably undoes the increment when the middleware is dropped).
pub(crate) struct LimitMiddleware<S> {
inner: S,
// `None` disables rate limiting; otherwise one cell of the quota is consumed per call.
rate_limiter: Option<RateLimiter<NotKeyed, InMemoryState, DefaultClock, NoOpMiddleware>>,
// Label used when reporting `METRICS` for this middleware.
transport: Transport,
// Held only for its side effect on the `ws_open_sessions` gauge; never read.
_guard: GaugeGuard,
}
impl<S> LimitMiddleware<S> {
    /// Wraps `inner`, allowing at most `requests_per_minute_limit` calls per minute.
    ///
    /// Passing `None` disables rate limiting. Creating the middleware increments the
    /// `ws_open_sessions` gauge via a guard that lives as long as the returned value.
    pub(crate) fn new(inner: S, requests_per_minute_limit: Option<NonZeroU32>) -> Self {
        let rate_limiter = requests_per_minute_limit
            .map(Quota::per_minute)
            .map(RateLimiter::direct);
        let session_guard = API_METRICS.ws_open_sessions.inc_guard(1);
        Self {
            inner,
            rate_limiter,
            transport: Transport::Ws,
            _guard: session_guard,
        }
    }
}
impl<'a, S> RpcServiceT<'a> for LimitMiddleware<S>
where
    S: Send + Sync + RpcServiceT<'a>,
{
    type Future = ResponseFuture<S::Future>;

    /// Forwards `request` to the inner service unless the rate limit is exhausted.
    ///
    /// When the limit is hit, the `rate_limited` metric is incremented and an immediate error
    /// response (code 429, "Too many requests") is returned without calling the inner service.
    fn call(&self, request: Request<'a>) -> Self::Future {
        if let Some(rate_limiter) = &self.rate_limiter {
            // Each call consumes exactly one cell of the quota. `check()` is used instead of
            // `check_n(1)`: in newer `governor` versions `check_n` returns a nested `Result`,
            // where an outer `.is_err()` only reports `InsufficientCapacity` and would never
            // detect an exceeded quota. `check()` returns a flat `Result` in all versions.
            if rate_limiter.check().is_err() {
                METRICS.rate_limited[&self.transport].inc();
                let rp = MethodResponse::error(
                    request.id,
                    ErrorObject::borrowed(
                        ErrorCode::ServerError(http::StatusCode::TOO_MANY_REQUESTS.as_u16().into())
                            .code(),
                        "Too many requests",
                        None,
                    ),
                );
                return ResponseFuture::ready(rp);
            }
        }
        ResponseFuture::future(self.inner.call(request))
    }
}
/// Middleware that starts a `MethodCall` for every request via `method_tracer` and wraps the
/// inner response future into a `WithMethodCall`.
///
/// When `TRACE_PARAMS` is `true`, request params are passed to `ObservedRpcParams::new`;
/// otherwise they are reported as `ObservedRpcParams::Unknown`.
#[derive(Debug)]
pub(crate) struct MetadataMiddleware<S, const TRACE_PARAMS: bool> {
inner: S,
// Known method names; requests for methods not in this set are traced under the empty name.
registered_method_names: Arc<HashSet<&'static str>>,
// Tracer used to create a `MethodCall` for each request.
method_tracer: Arc<MethodTracer>,
}
impl<'a, S, const TRACE_PARAMS: bool> RpcServiceT<'a> for MetadataMiddleware<S, TRACE_PARAMS>
where
    S: Send + Sync + RpcServiceT<'a>,
{
    type Future = WithMethodCall<'a, S::Future>;

    /// Starts a traced method call for `request` and returns the inner future wrapped so that
    /// the call is current while it is polled and observes the final response.
    fn call(&self, request: Request<'a>) -> Self::Future {
        // Resolve the requested name to its `'static` registered counterpart; unregistered
        // names collapse into the empty name.
        let method_name: &'static str = self
            .registered_method_names
            .get(request.method_name())
            .copied()
            .unwrap_or_default();
        let observed_params = if !TRACE_PARAMS {
            ObservedRpcParams::Unknown
        } else {
            ObservedRpcParams::new(request.params.as_ref())
        };
        let call = self.method_tracer.new_call(method_name, observed_params);
        let inner_future = self.inner.call(request);
        WithMethodCall::new(inner_future, call)
    }
}
pin_project! {
/// Future returned by `MetadataMiddleware`: polls the inner response future while `call` is set
/// as the current method call, and records the response into `call` once it is ready.
#[derive(Debug)]
pub(crate) struct WithMethodCall<'a, F> {
// Inner response future; structurally pinned.
#[pin]
inner: F,
// Metadata of the call being served; not pinned, so the projection yields `&mut MethodCall`.
call: MethodCall<'a>,
}
}
impl<'a, F> WithMethodCall<'a, F> {
fn new(inner: F, call: MethodCall<'a>) -> Self {
Self { inner, call }
}
}
impl<F: Future<Output = MethodResponse>> Future for WithMethodCall<'_, F> {
type Output = MethodResponse;
/// Polls the inner future with `call` installed as the current method call; once the inner
/// future completes, the response is recorded via `MethodCall::observe_response`.
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let projection = self.project();
// Make `call` current only for the duration of this poll, so code inside the inner future can
// access it (the tests in this file read it through `MethodTracer::meta()`).
let guard = projection.call.set_as_current();
match projection.inner.poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(response) => {
// Release the current-call guard before recording the response. NOTE(review): presumably
// required because the guard borrows `call` and/or so the observation happens outside the
// current-call context — confirm against `MethodCall::set_as_current`.
drop(guard);
projection.call.observe_response(&response);
Poll::Ready(response)
}
}
}
}
/// `tower` layer producing `MetadataMiddleware`, sharing the registered method names and the
/// method tracer among all produced middleware instances.
///
/// `TRACE_PARAMS` is `false` for layers created with `new()` and becomes `true` after
/// `with_param_tracing()`.
#[derive(Debug, Clone)]
pub(crate) struct MetadataLayer<const TRACE_PARAMS: bool> {
registered_method_names: Arc<HashSet<&'static str>>,
method_tracer: Arc<MethodTracer>,
}
impl MetadataLayer<false> {
    /// Creates a layer that traces method calls without passing request params to the tracer.
    pub fn new(
        registered_method_names: Arc<HashSet<&'static str>>,
        method_tracer: Arc<MethodTracer>,
    ) -> Self {
        MetadataLayer {
            method_tracer,
            registered_method_names,
        }
    }

    /// Converts this layer into one whose middleware also passes request params to the tracer.
    pub fn with_param_tracing(self) -> MetadataLayer<true> {
        let Self {
            registered_method_names,
            method_tracer,
        } = self;
        MetadataLayer {
            registered_method_names,
            method_tracer,
        }
    }
}
impl<Svc, const TRACE_PARAMS: bool> tower::Layer<Svc> for MetadataLayer<TRACE_PARAMS> {
type Service = MetadataMiddleware<Svc, TRACE_PARAMS>;
fn layer(&self, inner: Svc) -> Self::Service {
MetadataMiddleware {
inner,
registered_method_names: self.registered_method_names.clone(),
method_tracer: self.method_tracer.clone(),
}
}
}
/// Middleware that runs each call inside a `rpc_call` debug span carrying the method name and a
/// pseudo-random per-call `correlation_id`.
#[derive(Debug)]
pub(crate) struct CorrelationMiddleware<S> {
inner: S,
}
impl<S> CorrelationMiddleware<S> {
pub fn new(inner: S) -> Self {
Self { inner }
}
}
impl<'a, S> RpcServiceT<'a> for CorrelationMiddleware<S>
where
    S: RpcServiceT<'a>,
{
    type Future = Instrumented<S::Future>;

    /// Runs the inner call inside a `rpc_call` debug span tagged with the method name and a
    /// pseudo-random correlation ID.
    fn call(&self, request: Request<'a>) -> Self::Future {
        // One RNG per thread, seeded via `SmallRng::from_entropy` on first use; `RefCell`
        // suffices because the value is never shared across threads.
        thread_local! {
            static CORRELATION_ID_RNG: RefCell<SmallRng> = RefCell::new(SmallRng::from_entropy());
        }

        let method_name = request.method_name();
        let correlation_id = CORRELATION_ID_RNG.with(|cell| {
            let mut rng = cell.borrow_mut();
            rng.next_u64()
        });
        let span = tracing::debug_span!("rpc_call", method = method_name, correlation_id);
        let inner_future = self.inner.call(request);
        inner_future.instrument(span)
    }
}
/// Tracks when the most recent request was seen, so that a caller can wait until traffic has been
/// quiet for a while (presumably during server shutdown, given `ShutdownMiddleware`).
///
/// All clones share the same state. The underlying `watch` channel is created lazily by the first
/// `wait_for_no_requests()` call; until then, `reset()` calls are ignored.
#[derive(Debug, Clone, Default)]
pub(crate) struct TrafficTracker {
last_call_sender: Arc<OnceCell<watch::Sender<Instant>>>,
}
impl TrafficTracker {
    /// Records "now" as the moment the most recent request was seen.
    ///
    /// Until some caller has started [`Self::wait_for_no_requests()`] (which lazily creates the
    /// underlying channel), there is nobody to notify and this is a no-op.
    fn reset(&self) {
        let Some(sender) = self.last_call_sender.get() else {
            return;
        };
        sender.send_replace(Instant::now());
    }

    /// Resolves once no [`Self::reset()`] has been observed for `interval_without_requests`.
    ///
    /// The quiet interval is measured from the latest timestamp in the shared channel: either the
    /// moment the channel was created here (if this is the first waiter) or the most recent reset.
    /// The future also resolves early if every remaining sender handle (held by other clones of
    /// this tracker) is dropped.
    pub async fn wait_for_no_requests(self, interval_without_requests: Duration) {
        let mut receiver = self
            .last_call_sender
            .get_or_init(|| watch::channel(Instant::now()).0)
            .subscribe();
        // Give up this clone's handle so the channel closes (ending the loop below) once all
        // other clones of the tracker have been dropped.
        drop(self);

        let deadline_after = move |last_call: Instant| last_call + interval_without_requests;
        let timer = tokio::time::sleep_until(deadline_after(*receiver.borrow()).into());
        tokio::pin!(timer);

        loop {
            tokio::select! {
                // The quiet interval elapsed without another reset.
                () = timer.as_mut() => break,
                changed = receiver.changed() => {
                    if changed.is_err() {
                        // All senders are gone; no further resets can arrive.
                        break;
                    }
                    let last_call = *receiver.borrow();
                    timer.as_mut().reset(deadline_after(last_call).into());
                }
            }
        }
    }
}
/// Middleware that reports every incoming call to a `TrafficTracker` (via `reset()`) before
/// delegating to `inner`.
#[derive(Debug)]
pub(crate) struct ShutdownMiddleware<S> {
inner: S,
traffic_tracker: TrafficTracker,
}
impl<S> ShutdownMiddleware<S> {
pub fn new(inner: S, traffic_tracker: TrafficTracker) -> Self {
Self {
inner,
traffic_tracker,
}
}
}
impl<'a, S> RpcServiceT<'a> for ShutdownMiddleware<S>
where
    S: Send + Sync + RpcServiceT<'a>,
{
    type Future = S::Future;

    /// Marks the tracker as having just seen traffic, then delegates to the inner service.
    fn call(&self, request: Request<'a>) -> Self::Future {
        let Self {
            inner,
            traffic_tracker,
        } = self;
        traffic_tracker.reset();
        inner.call(request)
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use rand::{thread_rng, Rng};
    use test_casing::{test_casing, Product};
    use zksync_types::api;
    use zksync_web3_decl::jsonrpsee::{types::Id, ResponsePayload};

    use super::*;

    // Checks that each `WithMethodCall` sees its own metadata as the current call across await
    // points, both when futures are spawned onto the runtime and when joined on a single task.
    #[test_casing(4, Product(([false, true], [false, true])))]
    #[tokio::test(flavor = "multi_thread")]
    async fn metadata_middleware_basics(spawn_tasks: bool, sleep: bool) {
        let method_tracer = Arc::new(MethodTracer::default());

        let tasks = (0_u64..100).map(|i| {
            let current_method = method_tracer.clone();
            let inner = async move {
                assert_eq!(current_method.meta().unwrap().block_id, None);
                current_method.set_block_id(api::BlockId::Number(i.into()));

                for diff in 0_u32..10 {
                    let meta = current_method.meta().unwrap();
                    assert_eq!(meta.block_id, Some(api::BlockId::Number(i.into())));
                    assert_eq!(meta.block_diff, diff.checked_sub(1));
                    current_method.set_block_diff(diff);

                    if sleep {
                        let delay = thread_rng().gen_range(1..=5);
                        tokio::time::sleep(Duration::from_millis(delay)).await;
                    } else {
                        tokio::task::yield_now().await;
                    }
                }

                MethodResponse::response(
                    Id::Number(1),
                    ResponsePayload::success("{}".to_string()),
                    usize::MAX,
                )
            };

            WithMethodCall::new(
                inner,
                method_tracer.new_call("test", ObservedRpcParams::None),
            )
        });

        if spawn_tasks {
            let tasks: Vec<_> = tasks.map(tokio::spawn).collect();
            for task in tasks {
                task.await.unwrap();
            }
        } else {
            futures::future::join_all(tasks).await;
        }

        let calls = method_tracer.recorded_calls().take();
        assert_eq!(calls.len(), 100);
        for call in &calls {
            assert_eq!(call.metadata.name, "test");
            assert!(call.metadata.block_id.is_some());
            assert_eq!(call.metadata.block_diff, Some(9));
            assert!(call.error_code.is_none());
        }
    }

    // Checks that `reset()` postpones the moment `wait_for_no_requests()` resolves.
    #[tokio::test]
    async fn traffic_tracker_basics() {
        let traffic_tracker = TrafficTracker::default();
        let now = Instant::now();
        // Spawn the waiter so it is polled (creating and subscribing to the channel) while this
        // task sleeps. An un-spawned `async fn` future does nothing until awaited, so `reset()`
        // below would be a no-op and the assertion would pass without exercising it.
        let wait = tokio::spawn(
            traffic_tracker
                .clone()
                .wait_for_no_requests(Duration::from_millis(10)),
        );
        tokio::time::sleep(Duration::from_millis(5)).await;
        traffic_tracker.reset();
        wait.await.unwrap();

        // Without the reset the waiter would finish ~10ms after `now`; the reset at >= 5ms moves
        // the deadline to >= 15ms.
        let elapsed = now.elapsed();
        assert!(elapsed >= Duration::from_millis(15), "{elapsed:?}");
    }
}