use std::collections::VecDeque;
use std::sync::{Arc, LazyLock};
use std::time::{Duration, Instant};
use async_trait::async_trait;
use bytes::Bytes;
use metrics::{
QUEUE_LATENCY,
QUEUE_SIZE,
RATE_LIMIT_VIOLATIONS,
RATE_LIMITED_REQUESTS,
REQUEST_COUNT,
REQUEST_FAILURE_COUNT,
REQUEST_LATENCY,
REQUEST_RETRIES,
WORKER_BUSY,
WORKER_COUNT,
WORKER_REQUEST_COUNT,
};
use miden_remote_prover::COMPONENT;
use miden_remote_prover::api::ProofType;
use miden_remote_prover::error::RemoteProverError;
use miden_remote_prover::generated::remote_prover::{ProxyStatus, ProxyWorkerStatus};
use pingora::http::ResponseHeader;
use pingora::prelude::*;
use pingora::protocols::Digest;
use pingora::upstreams::peer::{ALPN, Peer};
use pingora_core::Result;
use pingora_core::upstreams::peer::HttpPeer;
use pingora_limits::rate::Rate;
use pingora_proxy::{FailToProxy, ProxyHttp, Session};
use tokio::sync::RwLock;
use tracing::{Span, debug, error, info, info_span, warn};
use uuid::Uuid;
use worker::Worker;
use crate::commands::ProxyConfig;
use crate::commands::update_workers::{Action, UpdateWorkers};
use crate::utils::{
create_queue_full_response,
create_response_with_error_message,
create_too_many_requests_response,
write_grpc_response_to_session,
};
mod health_check;
pub mod metrics;
pub(crate) mod update_workers;
pub(crate) mod worker;
/// gRPC method path that the proxy answers itself (from the status cache)
/// instead of forwarding to a worker.
const PROXY_STATUS_PATH: &str = "/remote_prover.ProxyStatusApi/Status";
/// Shared, thread-safe state backing the proxy's load balancer.
#[derive(Debug)]
pub struct LoadBalancerState {
    /// Pool of backend workers; list order encodes the rotation used when
    /// returning workers to the pool.
    workers: Arc<RwLock<Vec<Worker>>>,
    /// Total timeout applied to an upstream request.
    timeout: Duration,
    /// Timeout for establishing the upstream connection.
    connection_timeout: Duration,
    /// Maximum number of requests allowed to wait in the queue.
    max_queue_items: usize,
    /// Retry budget per request on upstream connection failure.
    max_retries_per_request: usize,
    /// Per-client requests-per-second limit (compared against the rate
    /// limiter's window count, hence `isize`).
    max_req_per_sec: isize,
    /// Sleep interval between polls while waiting for a free worker.
    available_workers_polling_interval: Duration,
    /// Configured health-check interval.
    health_check_interval: Duration,
    /// The single proof type this proxy serves.
    supported_proof_type: ProofType,
    /// Publishes refreshed proxy-status snapshots.
    status_cache_sender: tokio::sync::watch::Sender<ProxyStatus>,
    /// Holds the latest snapshot for cheap, lock-free reads.
    status_cache_receiver: tokio::sync::watch::Receiver<ProxyStatus>,
}
impl LoadBalancerState {
    /// Builds the load-balancer state from the initial worker list and config.
    ///
    /// Worker addresses that fail to construct are logged and skipped rather
    /// than aborting startup. Resets rate-limit/retry counters and seeds the
    /// status cache so [`Self::get_cached_status`] has a value before the
    /// first periodic refresh.
    #[tracing::instrument(target = COMPONENT, name = "proxy.new_load_balancer", skip(initial_workers))]
    pub(crate) async fn new(
        initial_workers: Vec<String>,
        config: &ProxyConfig,
    ) -> core::result::Result<Self, RemoteProverError> {
        let mut workers: Vec<Worker> = Vec::with_capacity(initial_workers.len());
        let connection_timeout = config.connection_timeout;
        let total_timeout = config.timeout;
        for worker_addr in initial_workers {
            match Worker::new(worker_addr, connection_timeout, total_timeout).await {
                Ok(w) => workers.push(w),
                Err(e) => {
                    // Best-effort startup: one bad address should not prevent
                    // serving with the remaining workers.
                    error!("Failed to create worker: {}", e);
                },
            }
        }
        info!("Workers created: {:?}", workers);
        WORKER_COUNT.set(i64::try_from(workers.len()).expect("worker count greater than i64::MAX"));
        RATE_LIMIT_VIOLATIONS.reset();
        RATE_LIMITED_REQUESTS.reset();
        REQUEST_RETRIES.reset();
        let workers = Arc::new(RwLock::new(workers));
        let supported_proof_type = config.proof_type;
        // Seed the watch channel with an initial snapshot built under a short
        // read lock.
        let initial_status = {
            let workers_guard = workers.read().await;
            build_proxy_status_response(&workers_guard, supported_proof_type)
        };
        let (status_cache_sender, status_cache_receiver) =
            tokio::sync::watch::channel(initial_status);
        Ok(Self {
            workers,
            timeout: total_timeout,
            connection_timeout,
            max_queue_items: config.max_queue_items,
            max_retries_per_request: config.max_retries_per_request,
            max_req_per_sec: config.max_req_per_sec,
            available_workers_polling_interval: config.available_workers_polling_interval,
            health_check_interval: config.health_check_interval,
            supported_proof_type,
            status_cache_sender,
            status_cache_receiver,
        })
    }

    /// Claims the first available worker (marking it busy) and returns a
    /// clone of it, or `None` when every worker is occupied.
    pub async fn pop_available_worker(&self) -> Option<Worker> {
        let mut available_workers = self.workers.write().await;
        available_workers.iter_mut().find(|w| w.is_available()).map(|w| {
            w.set_availability(false);
            WORKER_BUSY.inc();
            w.clone()
        })
    }

    /// Marks `worker` available again and moves it to the back of the list,
    /// rotating the pool so picks cycle through workers.
    ///
    /// If the worker was removed from the pool while busy (e.g. via
    /// [`Self::update_workers`]) this is a no-op and the handle is dropped.
    pub async fn add_available_worker(&self, worker: Worker) {
        let mut workers = self.workers.write().await;
        if let Some(pos) = workers.iter().position(|w| *w == worker) {
            let mut w = workers.remove(pos);
            w.set_availability(true);
            workers.push(w);
        }
    }

    /// Adds or removes workers per `update_workers`.
    ///
    /// Note: any single invalid address fails the whole update (the `?` on
    /// `Worker::new`) before any mutation of the pool is applied.
    pub async fn update_workers(
        &self,
        update_workers: UpdateWorkers,
    ) -> std::result::Result<(), RemoteProverError> {
        let mut workers = self.workers.write().await;
        info!("Current workers: {:?}", workers);
        let mut native_workers = Vec::new();
        for worker_addr in update_workers.workers {
            native_workers
                .push(Worker::new(worker_addr, self.connection_timeout, self.timeout).await?);
        }
        match update_workers.action {
            Action::Add => {
                // Skip duplicates already present in the pool.
                for worker in native_workers {
                    if !workers.iter().any(|w| w == &worker) {
                        workers.push(worker);
                    }
                }
            },
            Action::Remove => {
                for worker in native_workers {
                    workers.retain(|w| w != &worker);
                }
            },
        }
        info!("Workers updated: {:?}", workers);
        WORKER_COUNT.set(i64::try_from(workers.len()).expect("worker count greater than i64::MAX"));
        Ok(())
    }

    /// Total number of workers in the pool.
    pub async fn num_workers(&self) -> usize {
        self.workers.read().await.len()
    }

    /// Number of workers currently marked busy.
    pub async fn num_busy_workers(&self) -> usize {
        self.workers.read().await.iter().filter(|w| !w.is_available()).count()
    }

    /// Returns the most recently published status snapshot without touching
    /// the worker lock.
    pub fn get_cached_status(&self) -> ProxyStatus {
        self.status_cache_receiver.borrow().clone()
    }

    /// Rebuilds the status snapshot from current worker state and publishes
    /// it. `send` only fails if all receivers dropped; `self` holds one, so
    /// the expect cannot trip while the state is alive.
    pub async fn update_status_cache(&self) {
        let workers = self.workers.read().await;
        let new_status = build_proxy_status_response(&workers, self.supported_proof_type);
        self.status_cache_sender.send(new_status).expect("Failed to send new status");
    }
}
/// Global sliding-window rate limiter (1-second window), keyed per client address.
static RATE_LIMITER: LazyLock<Rate> = LazyLock::new(|| Rate::new(Duration::from_secs(1)));
/// FIFO queue of requests waiting for an available worker.
pub struct RequestQueue {
    // Each entry pairs the request ID with its enqueue time (for the
    // queue-latency metric).
    queue: RwLock<VecDeque<(Uuid, Instant)>>,
}
impl RequestQueue {
    /// Creates an empty queue and zeroes the queue-size gauge.
    #[allow(clippy::new_without_default)]
    pub fn new() -> Self {
        QUEUE_SIZE.set(0);
        Self { queue: RwLock::new(VecDeque::new()) }
    }

    /// Number of requests currently waiting.
    #[allow(clippy::len_without_is_empty)]
    pub async fn len(&self) -> usize {
        let guard = self.queue.read().await;
        guard.len()
    }

    /// Appends a request to the back of the queue, stamping its arrival time.
    pub async fn enqueue(&self, request_id: Uuid) {
        QUEUE_SIZE.inc();
        self.queue.write().await.push_back((request_id, Instant::now()));
    }

    /// Removes and returns the front request, recording how long it waited.
    pub async fn dequeue(&self) -> Option<Uuid> {
        let mut guard = self.queue.write().await;
        let (request_id, queued_at) = guard.pop_front()?;
        QUEUE_SIZE.dec();
        QUEUE_LATENCY.observe(queued_at.elapsed().as_secs_f64());
        Some(request_id)
    }

    /// Returns the front request's ID without removing it.
    pub async fn peek(&self) -> Option<Uuid> {
        let guard = self.queue.read().await;
        guard.front().map(|&(request_id, _)| request_id)
    }
}
/// Process-wide wait queue shared by all in-flight proxy requests.
static QUEUE: LazyLock<RequestQueue> = LazyLock::new(RequestQueue::new);
/// Adapter letting the OpenTelemetry propagator write trace-context headers
/// directly into a pingora request header.
struct PingoraHeaderInjector<'a>(&'a mut pingora::http::RequestHeader);
impl opentelemetry::propagation::Injector for PingoraHeaderInjector<'_> {
    /// Stores `value` under `key` on the wrapped header; failures are logged
    /// and otherwise ignored so tracing never breaks proxying.
    fn set(&mut self, key: &str, value: String) {
        let insert_result = self.0.insert_header(key.to_string(), value);
        if let Err(e) = insert_result {
            tracing::warn!(target: COMPONENT, header = %key, err = %e, "Failed to inject OpenTelemetry header");
        }
    }
}
/// Per-request state threaded through the pingora proxy callbacks.
#[derive(Debug)]
pub struct RequestContext {
    /// Number of retry attempts made so far.
    tries: usize,
    /// Unique ID used for logs and the `x-request-id` upstream header.
    request_id: Uuid,
    /// Worker assigned to this request, once one is claimed.
    worker: Option<Worker>,
    /// Span parenting all tracing emitted for this request.
    parent_span: Span,
    /// Arrival time, used for the end-to-end latency metric.
    created_at: Instant,
}
impl RequestContext {
    /// Builds the context for a brand-new downstream request: fresh UUID,
    /// a dedicated tracing span, and the arrival timestamp.
    fn new() -> Self {
        let request_id = Uuid::new_v4();
        let parent_span =
            info_span!(target: COMPONENT, "proxy.new_request", request_id = request_id.to_string());
        Self {
            tries: 0,
            request_id,
            worker: None,
            parent_span,
            created_at: Instant::now(),
        }
    }

    /// Records `worker` as the handler of this request and bumps that
    /// worker's request counter.
    fn set_worker(&mut self, worker: Worker) {
        WORKER_REQUEST_COUNT.with_label_values(&[&worker.name()]).inc();
        self.worker = Some(worker);
    }
}
/// Pingora proxy service wrapping the shared load-balancer state.
#[derive(Debug)]
pub struct LoadBalancer(pub Arc<LoadBalancerState>);
#[async_trait]
impl ProxyHttp for LoadBalancer {
    type CTX = RequestContext;

    /// Creates a fresh per-request context (ID, tracing span, start time).
    fn new_ctx(&self) -> Self::CTX {
        RequestContext::new()
    }
    /// Downstream entry point: serves the status endpoint from cache, applies
    /// per-client rate limiting, and rejects requests when the wait queue is
    /// full.
    ///
    /// Returns `Ok(true)` when a response has already been written downstream
    /// (request fully handled, stop proxying) and `Ok(false)` to continue on
    /// to `upstream_peer`.
    #[tracing::instrument(name = "proxy.request_filter", parent = &ctx.parent_span, skip(session))]
    async fn request_filter(&self, session: &mut Session, ctx: &mut Self::CTX) -> Result<bool>
    where
        Self::CTX: Send + Sync,
    {
        // The client address doubles as the rate-limiting key; without one we
        // cannot attribute the request, so reject it outright.
        let client_addr = match session.client_addr() {
            Some(addr) => addr.to_string(),
            None => {
                return create_response_with_error_message(
                    session.as_downstream_mut(),
                    "No socket address".to_string(),
                )
                .await
                .map(|_| true);
            },
        };
        Span::current().record("client_addr", client_addr.clone());
        let path = session.downstream_session.req_header().uri.path();
        Span::current().record("path", path);
        // Status requests are answered from the cache and bypass rate
        // limiting and the queue.
        if path == PROXY_STATUS_PATH {
            let status = self.0.get_cached_status();
            return write_grpc_response_to_session(session, status).await.map(|_| true);
        }
        REQUEST_COUNT.inc();
        let user_id = Some(client_addr);
        // Sliding-window rate limit per client address.
        let curr_window_requests = RATE_LIMITER.observe(&user_id, 1);
        if curr_window_requests > self.0.max_req_per_sec {
            RATE_LIMITED_REQUESTS.inc();
            // Count the violation only once per window, on the first request
            // that crosses the limit.
            if curr_window_requests == self.0.max_req_per_sec + 1 {
                RATE_LIMIT_VIOLATIONS.inc();
            }
            return create_too_many_requests_response(session, self.0.max_req_per_sec)
                .await
                .map(|_| true);
        }
        // Back-pressure: refuse new work when the wait queue is at capacity.
        let queue_len = QUEUE.len().await;
        info!("New request with ID: {}", ctx.request_id);
        info!("Queue length: {}", queue_len);
        if queue_len >= self.0.max_queue_items {
            return create_queue_full_response(session).await.map(|_| true);
        }
        Ok(false)
    }
#[tracing::instrument(name = "proxy.upstream_peer", parent = &ctx.parent_span, skip(_session))]
async fn upstream_peer(
&self,
_session: &mut Session,
ctx: &mut Self::CTX,
) -> Result<Box<HttpPeer>> {
let request_id = ctx.request_id;
QUEUE.enqueue(request_id).await;
loop {
if QUEUE.peek().await.expect("Queue should not be empty") != request_id {
continue;
}
if let Some(worker) = self.0.pop_available_worker().await {
debug!("Worker {} picked up the request with ID: {}", worker.name(), request_id);
ctx.set_worker(worker);
break;
}
debug!("All workers are busy");
tokio::time::sleep(self.0.available_workers_polling_interval).await;
}
QUEUE.dequeue().await;
let mut http_peer = HttpPeer::new(
ctx.worker.clone().expect("Failed to get worker").name(),
false,
String::new(),
);
let peer_opts =
http_peer.get_mut_peer_options().ok_or(Error::new(ErrorType::InternalError))?;
peer_opts.total_connection_timeout = Some(self.0.timeout);
peer_opts.connection_timeout = Some(self.0.connection_timeout);
peer_opts.alpn = ALPN::H2;
let peer = Box::new(http_peer);
Ok(peer)
}
#[tracing::instrument(name = "proxy.upstream_request_filter", parent = &_ctx.parent_span, skip(_session))]
async fn upstream_request_filter(
&self,
_session: &mut Session,
upstream_request: &mut RequestHeader,
_ctx: &mut Self::CTX,
) -> Result<()>
where
Self::CTX: Send + Sync,
{
if let Some(content_type) = upstream_request.headers.get("content-type")
&& content_type == "application/grpc"
{
upstream_request.insert_header("content-type", "application/grpc")?;
}
upstream_request.insert_header("x-request-id", _ctx.request_id.to_string())?;
{
use tracing_opentelemetry::OpenTelemetrySpanExt;
let ctx = tracing::Span::current().context();
opentelemetry::global::get_text_map_propagator(|propagator| {
propagator.inject_context(&ctx, &mut PingoraHeaderInjector(upstream_request));
});
}
Ok(())
}
#[tracing::instrument(name = "proxy.fail_to_connect", parent = &ctx.parent_span, skip(_session))]
fn fail_to_connect(
&self,
_session: &mut Session,
peer: &HttpPeer,
ctx: &mut Self::CTX,
mut e: Box<Error>,
) -> Box<Error> {
if ctx.tries > self.0.max_retries_per_request {
return e;
}
REQUEST_RETRIES.inc();
ctx.tries += 1;
e.set_retry(true);
e
}
#[tracing::instrument(name = "proxy.logging", parent = &ctx.parent_span, skip(_session))]
async fn logging(&self, _session: &mut Session, e: Option<&Error>, ctx: &mut Self::CTX)
where
Self::CTX: Send + Sync,
{
if let Some(e) = e {
REQUEST_FAILURE_COUNT.inc();
error!("Error: {:?}", e);
}
if let Some(worker) = ctx.worker.take() {
self.0.add_available_worker(worker).await;
}
REQUEST_LATENCY.observe(ctx.created_at.elapsed().as_secs_f64());
WORKER_BUSY.set(
i64::try_from(self.0.num_busy_workers().await)
.expect("busy worker count greater than i64::MAX"),
);
}
#[tracing::instrument(name = "proxy.early_request_filter", parent = &ctx.parent_span, skip(_session))]
async fn early_request_filter(
&self,
_session: &mut Session,
ctx: &mut Self::CTX,
) -> Result<()> {
ProxyHttpDefaultImpl.early_request_filter(_session, &mut ()).await
}
#[tracing::instrument(name = "proxy.connected_to_upstream", parent = &ctx.parent_span, skip(_session, _sock, _reused, _peer, _fd, _digest))]
async fn connected_to_upstream(
&self,
_session: &mut Session,
_reused: bool,
_peer: &HttpPeer,
#[cfg(unix)] _fd: std::os::unix::io::RawFd,
#[cfg(windows)] _sock: std::os::windows::io::RawSocket,
_digest: Option<&Digest>,
ctx: &mut Self::CTX,
) -> Result<()> {
ProxyHttpDefaultImpl
.connected_to_upstream(_session, _reused, _peer, _fd, _digest, &mut ())
.await
}
#[tracing::instrument(name = "proxy.request_body_filter", parent = &ctx.parent_span, skip(session, body))]
async fn request_body_filter(
&self,
session: &mut Session,
body: &mut Option<Bytes>,
end_of_stream: bool,
ctx: &mut Self::CTX,
) -> Result<()> {
ProxyHttpDefaultImpl
.request_body_filter(session, body, end_of_stream, &mut ())
.await
}
#[tracing::instrument(name = "proxy.upstream_response_filter", parent = &ctx.parent_span, skip(session, upstream_response))]
fn upstream_response_filter(
&self,
session: &mut Session,
upstream_response: &mut ResponseHeader,
ctx: &mut Self::CTX,
) -> Result<()> {
ProxyHttpDefaultImpl.upstream_response_filter(session, upstream_response, &mut ())
}
#[tracing::instrument(name = "proxy.response_filter", parent = &ctx.parent_span, skip(session, upstream_response))]
async fn response_filter(
&self,
session: &mut Session,
upstream_response: &mut ResponseHeader,
ctx: &mut Self::CTX,
) -> Result<()>
where
Self::CTX: Send + Sync,
{
ProxyHttpDefaultImpl.response_filter(session, upstream_response, &mut ()).await
}
#[tracing::instrument(name = "proxy.upstream_response_body_filter", parent = &ctx.parent_span, skip(session, body))]
fn upstream_response_body_filter(
&self,
session: &mut Session,
body: &mut Option<Bytes>,
end_of_stream: bool,
ctx: &mut Self::CTX,
) -> Result<()> {
ProxyHttpDefaultImpl.upstream_response_body_filter(session, body, end_of_stream, &mut ())
}
#[tracing::instrument(name = "proxy.response_body_filter", parent = &ctx.parent_span, skip(session, body))]
fn response_body_filter(
&self,
session: &mut Session,
body: &mut Option<Bytes>,
end_of_stream: bool,
ctx: &mut Self::CTX,
) -> Result<Option<Duration>>
where
Self::CTX: Send + Sync,
{
ProxyHttpDefaultImpl.response_body_filter(session, body, end_of_stream, &mut ())
}
#[tracing::instrument(name = "proxy.fail_to_proxy", parent = &ctx.parent_span, skip(session))]
async fn fail_to_proxy(
&self,
session: &mut Session,
e: &Error,
ctx: &mut Self::CTX,
) -> FailToProxy
where
Self::CTX: Send + Sync,
{
ProxyHttpDefaultImpl.fail_to_proxy(session, e, &mut ()).await
}
#[tracing::instrument(name = "proxy.error_while_proxy", parent = &ctx.parent_span, skip(session))]
fn error_while_proxy(
&self,
peer: &HttpPeer,
session: &mut Session,
e: Box<Error>,
ctx: &mut Self::CTX,
client_reused: bool,
) -> Box<Error> {
ProxyHttpDefaultImpl.error_while_proxy(peer, session, e, &mut (), client_reused)
}
}
/// Zero-sized stand-in used to invoke `ProxyHttp`'s default method bodies.
struct ProxyHttpDefaultImpl;
// Minimal `ProxyHttp` impl whose sole purpose is to expose the trait's
// default method implementations for delegation from `LoadBalancer`.
#[async_trait]
impl ProxyHttp for ProxyHttpDefaultImpl {
    type CTX = ();
    fn new_ctx(&self) {}

    /// Required trait method with no default body; `LoadBalancer` never
    /// delegates `upstream_peer`, so reaching this is a bug.
    async fn upstream_peer(
        &self,
        _session: &mut Session,
        _ctx: &mut Self::CTX,
    ) -> Result<Box<HttpPeer>> {
        unimplemented!("This is a dummy implementation, should not be called")
    }
}
/// Builds a status snapshot from the current worker pool and the proof type
/// this proxy supports, stamped with the crate version.
fn build_proxy_status_response(workers: &[Worker], supported_proof_type: ProofType) -> ProxyStatus {
    let mut worker_statuses = Vec::with_capacity(workers.len());
    for worker in workers {
        worker_statuses.push(ProxyWorkerStatus::from(worker));
    }
    ProxyStatus {
        version: env!("CARGO_PKG_VERSION").to_string(),
        supported_proof_type: supported_proof_type.into(),
        workers: worker_statuses,
    }
}