use std::{
collections::VecDeque,
sync::{Arc, LazyLock},
time::{Duration, Instant},
};
use async_trait::async_trait;
use bytes::Bytes;
use metrics::{
QUEUE_LATENCY, QUEUE_SIZE, RATE_LIMIT_VIOLATIONS, RATE_LIMITED_REQUESTS, REQUEST_COUNT,
REQUEST_FAILURE_COUNT, REQUEST_LATENCY, REQUEST_RETRIES, WORKER_BUSY, WORKER_COUNT,
WORKER_REQUEST_COUNT,
};
use miden_remote_prover::{COMPONENT, api::ProofType, error::RemoteProverError};
use pingora::{
http::ResponseHeader,
prelude::*,
protocols::Digest,
upstreams::peer::{ALPN, Peer},
};
use pingora_core::{Result, upstreams::peer::HttpPeer};
use pingora_limits::rate::Rate;
use pingora_proxy::{FailToProxy, ProxyHttp, Session};
use tokio::sync::RwLock;
use tracing::{Span, debug, error, info, info_span, warn};
use uuid::Uuid;
use worker::Worker;
use crate::{
commands::{
ProxyConfig,
update_workers::{Action, UpdateWorkers},
},
utils::{
create_queue_full_response, create_response_with_error_message,
create_too_many_requests_response,
},
};
mod health_check;
pub mod metrics;
pub(crate) mod status;
pub(crate) mod update_workers;
pub(crate) mod worker;
/// Shared state of the proxy load balancer.
///
/// Holds the worker pool plus the timeout, rate-limit, and queue-limit settings
/// read from `ProxyConfig` at startup.
#[derive(Debug)]
pub struct LoadBalancerState {
    // Worker pool; availability flags are toggled as requests are assigned and released.
    workers: Arc<RwLock<Vec<Worker>>>,
    // Total timeout applied to upstream reads, writes, idle time, and the whole connection.
    timeout: Duration,
    // Timeout for establishing a connection to a worker.
    connection_timeout: Duration,
    // Maximum number of requests allowed to wait in the queue before rejecting.
    max_queue_items: usize,
    // Retry budget per request when connecting to a worker fails.
    max_retries_per_request: usize,
    // Per-client request budget per one-second sliding window.
    max_req_per_sec: isize,
    // Sleep interval between polls while waiting for a worker to free up.
    available_workers_polling_interval: Duration,
    // Interval between worker health checks — presumably consumed by the
    // `health_check` module; not referenced in this file's visible code.
    health_check_interval: Duration,
    // Proof type this proxy supports — presumably enforced elsewhere (e.g. status
    // or worker validation); not referenced in this file's visible code.
    supported_proof_type: ProofType,
}
impl LoadBalancerState {
    /// Creates the load-balancer state from the initial worker addresses and proxy config.
    ///
    /// Workers that fail to construct are logged and skipped rather than aborting
    /// startup (best-effort: the proxy serves with whatever workers succeeded).
    /// Records the initial worker count and resets per-run rate-limit and retry metrics.
    ///
    /// # Errors
    /// Currently always returns `Ok`; the `Result` signature leaves room for
    /// fallible initialization.
    #[tracing::instrument(target = COMPONENT, name = "proxy.new_load_balancer", skip(initial_workers))]
    pub(crate) async fn new(
        initial_workers: Vec<String>,
        config: &ProxyConfig,
    ) -> core::result::Result<Self, RemoteProverError> {
        let mut workers: Vec<Worker> = Vec::with_capacity(initial_workers.len());
        let connection_timeout = config.connection_timeout;
        let total_timeout = config.timeout;
        for worker_addr in initial_workers {
            match Worker::new(worker_addr, connection_timeout, total_timeout).await {
                Ok(w) => workers.push(w),
                Err(e) => {
                    // A single bad address must not prevent the proxy from starting.
                    error!("Failed to create worker: {}", e);
                },
            }
        }
        info!("Workers created: {:?}", workers);
        WORKER_COUNT.set(i64::try_from(workers.len()).expect("worker count greater than i64::MAX"));
        RATE_LIMIT_VIOLATIONS.reset();
        RATE_LIMITED_REQUESTS.reset();
        REQUEST_RETRIES.reset();
        Ok(Self {
            workers: Arc::new(RwLock::new(workers)),
            timeout: total_timeout,
            connection_timeout,
            max_queue_items: config.max_queue_items,
            max_retries_per_request: config.max_retries_per_request,
            max_req_per_sec: config.max_req_per_sec,
            available_workers_polling_interval: config.available_workers_polling_interval,
            health_check_interval: config.health_check_interval,
            supported_proof_type: config.proof_type,
        })
    }

    /// Pops the first available worker from the pool, marking it busy and bumping
    /// the busy-worker gauge. Returns `None` when every worker is busy.
    pub async fn pop_available_worker(&self) -> Option<Worker> {
        let mut available_workers = self.workers.write().await;
        available_workers.iter_mut().find(|w| w.is_available()).map(|w| {
            w.set_availability(false);
            WORKER_BUSY.inc();
            w.clone()
        })
    }

    /// Returns a worker to the pool, marking it available again.
    ///
    /// The worker is moved to the back of the list so that `pop_available_worker`
    /// (which scans from the front) rotates work across the pool round-robin style.
    /// A worker removed from the pool while it was serving is silently dropped.
    pub async fn add_available_worker(&self, worker: Worker) {
        let mut workers = self.workers.write().await;
        if let Some(pos) = workers.iter().position(|w| *w == worker) {
            let mut w = workers.remove(pos);
            w.set_availability(true);
            workers.push(w);
        }
    }

    /// Applies an add/remove update to the worker pool and refreshes the
    /// worker-count gauge.
    ///
    /// # Errors
    /// Returns an error if any of the supplied worker addresses fails to
    /// construct a [`Worker`]; in that case the pool is left unchanged.
    pub async fn update_workers(
        &self,
        update_workers: UpdateWorkers,
    ) -> std::result::Result<(), RemoteProverError> {
        let mut workers = self.workers.write().await;
        info!("Current workers: {:?}", workers);

        // Build every new worker before touching the pool so a mid-list failure
        // cannot leave a partial update behind.
        let mut native_workers = Vec::with_capacity(update_workers.workers.len());
        for worker_addr in update_workers.workers {
            native_workers
                .push(Worker::new(worker_addr, self.connection_timeout, self.timeout).await?);
        }

        match update_workers.action {
            Action::Add => {
                // Deduplicate: only add workers not already in the pool.
                for worker in native_workers {
                    if !workers.contains(&worker) {
                        workers.push(worker);
                    }
                }
            },
            Action::Remove => {
                // Single retain pass instead of one O(n) compaction per removed
                // worker; removal lists are small so the linear `contains` is fine.
                workers.retain(|w| !native_workers.contains(w));
            },
        }

        info!("Workers updated: {:?}", workers);
        WORKER_COUNT.set(i64::try_from(workers.len()).expect("worker count greater than i64::MAX"));
        Ok(())
    }

    /// Total number of workers currently in the pool.
    pub async fn num_workers(&self) -> usize {
        self.workers.read().await.len()
    }

    /// Number of workers currently marked busy.
    pub async fn num_busy_workers(&self) -> usize {
        self.workers.read().await.iter().filter(|w| !w.is_available()).count()
    }
}
// Global per-client rate limiter over a one-second sliding window; keyed by
// client address in `request_filter`.
static RATE_LIMITER: LazyLock<Rate> = LazyLock::new(|| Rate::new(Duration::from_secs(1)));
/// FIFO queue of pending request ids, each stamped with its enqueue time so
/// that queue latency can be observed on dequeue.
pub struct RequestQueue {
    // (request id, instant the request entered the queue)
    queue: RwLock<VecDeque<(Uuid, Instant)>>,
}
impl RequestQueue {
    /// Builds an empty queue and zeroes the queue-size gauge.
    #[allow(clippy::new_without_default)]
    pub fn new() -> Self {
        QUEUE_SIZE.set(0);
        Self { queue: RwLock::new(VecDeque::new()) }
    }

    /// Number of requests currently waiting in the queue.
    #[allow(clippy::len_without_is_empty)]
    pub async fn len(&self) -> usize {
        self.queue.read().await.len()
    }

    /// Appends a request id to the back of the queue, stamping its enqueue time
    /// and bumping the queue-size gauge.
    pub async fn enqueue(&self, request_id: Uuid) {
        QUEUE_SIZE.inc();
        self.queue.write().await.push_back((request_id, Instant::now()));
    }

    /// Removes and returns the request id at the front of the queue, recording
    /// how long it waited. Returns `None` when the queue is empty.
    pub async fn dequeue(&self) -> Option<Uuid> {
        let popped = self.queue.write().await.pop_front();
        popped.map(|(request_id, queued_time)| {
            QUEUE_SIZE.dec();
            QUEUE_LATENCY.observe(queued_time.elapsed().as_secs_f64());
            request_id
        })
    }

    /// Returns the request id at the front of the queue without removing it.
    pub async fn peek(&self) -> Option<Uuid> {
        self.queue.read().await.front().map(|&(request_id, _)| request_id)
    }
}
// Global FIFO of requests waiting for an available worker.
static QUEUE: LazyLock<RequestQueue> = LazyLock::new(RequestQueue::new);
/// Per-request proxy context, created once per request by `new_ctx` and
/// threaded through every `ProxyHttp` phase.
#[derive(Debug)]
pub struct RequestContext {
    // Number of connection retries performed so far (see `fail_to_connect`).
    tries: usize,
    // Unique id assigned to this request.
    request_id: Uuid,
    // Worker serving this request; set once one is popped from the pool and
    // taken back in `logging` when the request completes.
    worker: Option<Worker>,
    // Tracing span that parents the spans of every proxy phase for this request.
    parent_span: Span,
    // Creation time, used to observe end-to-end request latency.
    created_at: Instant,
}
impl RequestContext {
    /// Builds the context for a fresh request: zero retries, no worker assigned
    /// yet, a newly generated request id, and a tracing span that will parent
    /// every proxy phase of this request.
    fn new() -> Self {
        let request_id = Uuid::new_v4();
        let parent_span =
            info_span!(target: COMPONENT, "proxy.new_request", request_id = request_id.to_string());
        Self {
            tries: 0,
            request_id,
            worker: None,
            parent_span,
            created_at: Instant::now(),
        }
    }

    /// Records the worker that will serve this request and bumps that worker's
    /// per-address request counter.
    fn set_worker(&mut self, worker: Worker) {
        WORKER_REQUEST_COUNT.with_label_values(&[&worker.address()]).inc();
        self.worker = Some(worker);
    }
}
/// Pingora proxy service that distributes incoming requests across the worker
/// pool held in the shared [`LoadBalancerState`].
#[derive(Debug)]
pub struct LoadBalancer(pub Arc<LoadBalancerState>);
#[async_trait]
impl ProxyHttp for LoadBalancer {
    type CTX = RequestContext;

    fn new_ctx(&self) -> Self::CTX {
        RequestContext::new()
    }

    /// Validates and rate-limits an incoming request before it is queued.
    ///
    /// Rejects requests with no client address, requests over the per-client
    /// rate limit, and requests arriving while the queue is full. Returning
    /// `Ok(false)` lets the request continue through the proxy pipeline.
    #[tracing::instrument(name = "proxy.request_filter", parent = &ctx.parent_span, skip(session))]
    async fn request_filter(&self, session: &mut Session, ctx: &mut Self::CTX) -> Result<bool>
    where
        Self::CTX: Send + Sync,
    {
        let client_addr = match session.client_addr() {
            Some(addr) => addr.to_string(),
            None => {
                return create_response_with_error_message(
                    session.as_downstream_mut(),
                    "No socket address".to_string(),
                )
                .await;
            },
        };
        info!("Client address: {:?}", client_addr);
        REQUEST_COUNT.inc();

        // Rate-limit per client address over RATE_LIMITER's one-second window.
        let user_id = Some(client_addr);
        let curr_window_requests = RATE_LIMITER.observe(&user_id, 1);
        if curr_window_requests > self.0.max_req_per_sec {
            RATE_LIMITED_REQUESTS.inc();
            // Count the violation once per window: only on the first excess request.
            if curr_window_requests == self.0.max_req_per_sec + 1 {
                RATE_LIMIT_VIOLATIONS.inc();
            }
            return create_too_many_requests_response(session, self.0.max_req_per_sec).await;
        }

        let queue_len = QUEUE.len().await;
        info!("New request with ID: {}", ctx.request_id);
        info!("Queue length: {}", queue_len);
        // NOTE: the request is not enqueued until `upstream_peer`, so this check
        // is best-effort against a racing burst of admissions.
        if queue_len >= self.0.max_queue_items {
            return create_queue_full_response(session).await;
        }
        Ok(false)
    }

    /// Queues the request, waits until it reaches the front of the queue AND a
    /// worker becomes available, then builds an HTTP/2 peer for that worker.
    #[tracing::instrument(name = "proxy.upstream_peer", parent = &ctx.parent_span, skip(_session))]
    async fn upstream_peer(
        &self,
        _session: &mut Session,
        ctx: &mut Self::CTX,
    ) -> Result<Box<HttpPeer>> {
        let request_id = ctx.request_id;
        QUEUE.enqueue(request_id).await;
        loop {
            // We enqueued ourselves above and only dequeue after this loop, so
            // the queue cannot be empty here.
            if QUEUE.peek().await.expect("Queue should not be empty") != request_id {
                // Not at the front yet: sleep before re-checking. A bare
                // `continue` here busy-spins on the queue lock and burns a full
                // core per waiting request.
                tokio::time::sleep(self.0.available_workers_polling_interval).await;
                continue;
            }
            if let Some(worker) = self.0.pop_available_worker().await {
                debug!("Worker {} picked up the request with ID: {}", worker.address(), request_id);
                ctx.set_worker(worker);
                break;
            }
            debug!("All workers are busy");
            tokio::time::sleep(self.0.available_workers_polling_interval).await;
        }
        // We hold a worker and are at the front of the queue; leave the queue.
        QUEUE.dequeue().await;

        // Plain (non-TLS) peer with an empty SNI; borrow the worker instead of
        // cloning it just to read its address.
        let mut http_peer = HttpPeer::new(
            ctx.worker.as_ref().expect("Failed to get worker").address(),
            false,
            String::new(),
        );
        let peer_opts =
            http_peer.get_mut_peer_options().ok_or(Error::new(ErrorType::InternalError))?;

        // One total timeout governs connection, read, write, and idle phases.
        peer_opts.total_connection_timeout = Some(self.0.timeout);
        peer_opts.connection_timeout = Some(self.0.connection_timeout);
        peer_opts.read_timeout = Some(self.0.timeout);
        peer_opts.write_timeout = Some(self.0.timeout);
        peer_opts.idle_timeout = Some(self.0.timeout);
        // gRPC upstreams require HTTP/2.
        peer_opts.alpn = ALPN::H2;

        let peer = Box::new(http_peer);
        Ok(peer)
    }

    /// Normalizes the request headers before forwarding to the worker.
    ///
    /// Re-inserts the gRPC content type when present; as written this is a
    /// no-op re-insert — presumably a hook point for future header rewriting.
    #[tracing::instrument(name = "proxy.upstream_request_filter", parent = &_ctx.parent_span, skip(_session))]
    async fn upstream_request_filter(
        &self,
        _session: &mut Session,
        upstream_request: &mut RequestHeader,
        _ctx: &mut Self::CTX,
    ) -> Result<()>
    where
        Self::CTX: Send + Sync,
    {
        if let Some(content_type) = upstream_request.headers.get("content-type") {
            if content_type == "application/grpc" {
                upstream_request.insert_header("content-type", "application/grpc")?;
            }
        }
        Ok(())
    }

    /// Decides whether a failed connection attempt should be retried.
    ///
    /// Marks the error retryable until the per-request retry budget is spent,
    /// counting each retry in the metrics.
    #[tracing::instrument(name = "proxy.fail_to_connect", parent = &ctx.parent_span, skip(_session))]
    fn fail_to_connect(
        &self,
        _session: &mut Session,
        peer: &HttpPeer,
        ctx: &mut Self::CTX,
        mut e: Box<Error>,
    ) -> Box<Error> {
        // NOTE(review): `>` allows max_retries_per_request + 1 retries before
        // giving up; `>=` would match the config name exactly — confirm intent.
        if ctx.tries > self.0.max_retries_per_request {
            return e;
        }
        REQUEST_RETRIES.inc();
        ctx.tries += 1;
        e.set_retry(true);
        e
    }

    /// End-of-request hook: records failures, releases the worker back to the
    /// pool, and updates latency and busy-worker metrics.
    #[tracing::instrument(name = "proxy.logging", parent = &ctx.parent_span, skip(_session))]
    async fn logging(&self, _session: &mut Session, e: Option<&Error>, ctx: &mut Self::CTX)
    where
        Self::CTX: Send + Sync,
    {
        if let Some(e) = e {
            REQUEST_FAILURE_COUNT.inc();
            error!("Error: {:?}", e);
        }
        // Return the worker (if one was assigned) so it can serve new requests.
        if let Some(worker) = ctx.worker.take() {
            self.0.add_available_worker(worker).await;
        }
        REQUEST_LATENCY.observe(ctx.created_at.elapsed().as_secs_f64());
        WORKER_BUSY.set(
            i64::try_from(self.0.num_busy_workers().await)
                .expect("busy worker count greater than i64::MAX"),
        );
    }

    // The remaining hooks delegate to the trait's default behavior via
    // `ProxyHttpDefaultImpl`; the wrappers exist only to attach tracing spans
    // parented on the per-request span.

    #[tracing::instrument(name = "proxy.early_request_filter", parent = &ctx.parent_span, skip(_session))]
    async fn early_request_filter(
        &self,
        _session: &mut Session,
        ctx: &mut Self::CTX,
    ) -> Result<()> {
        ProxyHttpDefaultImpl.early_request_filter(_session, &mut ()).await
    }

    #[tracing::instrument(name = "proxy.connected_to_upstream", parent = &ctx.parent_span, skip(_session, _sock, _reused, _peer, _fd, _digest))]
    async fn connected_to_upstream(
        &self,
        _session: &mut Session,
        _reused: bool,
        _peer: &HttpPeer,
        #[cfg(unix)] _fd: std::os::unix::io::RawFd,
        #[cfg(windows)] _sock: std::os::windows::io::RawSocket,
        _digest: Option<&Digest>,
        ctx: &mut Self::CTX,
    ) -> Result<()> {
        // The socket handle parameter is cfg-gated, so the forwarding call must
        // be too — passing `_fd` unconditionally does not compile on Windows.
        #[cfg(unix)]
        {
            ProxyHttpDefaultImpl
                .connected_to_upstream(_session, _reused, _peer, _fd, _digest, &mut ())
                .await
        }
        #[cfg(windows)]
        {
            ProxyHttpDefaultImpl
                .connected_to_upstream(_session, _reused, _peer, _sock, _digest, &mut ())
                .await
        }
    }

    #[tracing::instrument(name = "proxy.request_body_filter", parent = &ctx.parent_span, skip(session, body))]
    async fn request_body_filter(
        &self,
        session: &mut Session,
        body: &mut Option<Bytes>,
        end_of_stream: bool,
        ctx: &mut Self::CTX,
    ) -> Result<()> {
        ProxyHttpDefaultImpl
            .request_body_filter(session, body, end_of_stream, &mut ())
            .await
    }

    #[tracing::instrument(name = "proxy.upstream_response_filter", parent = &ctx.parent_span, skip(session, upstream_response))]
    fn upstream_response_filter(
        &self,
        session: &mut Session,
        upstream_response: &mut ResponseHeader,
        ctx: &mut Self::CTX,
    ) -> Result<()> {
        ProxyHttpDefaultImpl.upstream_response_filter(session, upstream_response, &mut ())
    }

    #[tracing::instrument(name = "proxy.response_filter", parent = &ctx.parent_span, skip(session, upstream_response))]
    async fn response_filter(
        &self,
        session: &mut Session,
        upstream_response: &mut ResponseHeader,
        ctx: &mut Self::CTX,
    ) -> Result<()>
    where
        Self::CTX: Send + Sync,
    {
        ProxyHttpDefaultImpl.response_filter(session, upstream_response, &mut ()).await
    }

    #[tracing::instrument(name = "proxy.upstream_response_body_filter", parent = &ctx.parent_span, skip(session, body))]
    fn upstream_response_body_filter(
        &self,
        session: &mut Session,
        body: &mut Option<Bytes>,
        end_of_stream: bool,
        ctx: &mut Self::CTX,
    ) -> Result<()> {
        ProxyHttpDefaultImpl.upstream_response_body_filter(session, body, end_of_stream, &mut ())
    }

    #[tracing::instrument(name = "proxy.response_body_filter", parent = &ctx.parent_span, skip(session, body))]
    fn response_body_filter(
        &self,
        session: &mut Session,
        body: &mut Option<Bytes>,
        end_of_stream: bool,
        ctx: &mut Self::CTX,
    ) -> Result<Option<Duration>>
    where
        Self::CTX: Send + Sync,
    {
        ProxyHttpDefaultImpl.response_body_filter(session, body, end_of_stream, &mut ())
    }

    #[tracing::instrument(name = "proxy.fail_to_proxy", parent = &ctx.parent_span, skip(session))]
    async fn fail_to_proxy(
        &self,
        session: &mut Session,
        e: &Error,
        ctx: &mut Self::CTX,
    ) -> FailToProxy
    where
        Self::CTX: Send + Sync,
    {
        ProxyHttpDefaultImpl.fail_to_proxy(session, e, &mut ()).await
    }

    #[tracing::instrument(name = "proxy.error_while_proxy", parent = &ctx.parent_span, skip(session))]
    fn error_while_proxy(
        &self,
        peer: &HttpPeer,
        session: &mut Session,
        e: Box<Error>,
        ctx: &mut Self::CTX,
        client_reused: bool,
    ) -> Box<Error> {
        ProxyHttpDefaultImpl.error_while_proxy(peer, session, e, &mut (), client_reused)
    }
}
/// Zero-state `ProxyHttp` implementation used only so `LoadBalancer`'s
/// instrumented wrapper methods can reach the trait's default method bodies.
struct ProxyHttpDefaultImpl;
#[async_trait]
impl ProxyHttp for ProxyHttpDefaultImpl {
    type CTX = ();
    fn new_ctx(&self) {}
    // `upstream_peer` has no default body in the trait, so a stub is required;
    // the wrappers never forward this method, so it must never be reached.
    async fn upstream_peer(
        &self,
        _session: &mut Session,
        _ctx: &mut Self::CTX,
    ) -> Result<Box<HttpPeer>> {
        unimplemented!("This is a dummy implementation, should not be called")
    }
}