use std::{sync::Arc, time::Duration};
use async_trait::async_trait;
use miden_node_utils::tracing::grpc::{TracedComponent, traced_span_fn};
use miden_remote_prover::{
api::ProofType,
generated::remote_prover::{
self as proto, ProxyStatusRequest, ProxyStatusResponse, WorkerStatus,
proxy_status_api_server::{ProxyStatusApi, ProxyStatusApiServer},
},
};
use pingora::{server::ListenFds, services::Service};
use tokio::{net::TcpListener, sync::watch, time::interval};
use tokio_stream::wrappers::TcpListenerStream;
use tonic::{Request, Response, Status, transport::Server};
use tower_http::trace::TraceLayer;
use tracing::{error, info, instrument};
use super::worker::WorkerHealthStatus;
use crate::{
COMPONENT,
commands::PROXY_HOST,
proxy::{LoadBalancerState, worker::Worker},
};
/// Pingora service that exposes a gRPC endpoint reporting the proxy's status.
///
/// The response served to clients is cached in a `watch` channel: a background
/// `ProxyStatusUpdater` publishes fresh snapshots through `status_tx`, and the
/// request handler reads the latest value from `status_rx` without touching
/// the worker list.
#[derive(Clone, Debug)]
pub struct ProxyStatusPingoraService {
    // Shared load-balancer state the status snapshots are derived from.
    load_balancer: Arc<LoadBalancerState>,
    // TCP port the gRPC server binds to (host comes from `PROXY_HOST`).
    port: u16,
    // Read side of the cached status.
    status_rx: watch::Receiver<ProxyStatusResponse>,
    // Write side handed to the background updater.
    status_tx: watch::Sender<ProxyStatusResponse>,
    // How often the background updater refreshes the cached status.
    status_update_interval: Duration,
}
impl ProxyStatusPingoraService {
    /// Creates the status service, seeding the watch channel with an initial
    /// snapshot of the current worker pool so early requests see valid data.
    ///
    /// `status_update_interval` controls how often the background
    /// `ProxyStatusUpdater` refreshes the cached status once the service runs.
    pub async fn new(
        load_balancer: Arc<LoadBalancerState>,
        port: u16,
        status_update_interval: Duration,
    ) -> Self {
        let version = env!("CARGO_PKG_VERSION").to_string();
        let supported_proof_type: ProofType = load_balancer.supported_proof_type;
        let supported_proof_type: i32 = supported_proof_type.into();

        // Build the initial snapshot while briefly holding the worker-list
        // read lock; the guard is dropped at the end of this block.
        let initial_status = {
            let workers = load_balancer.workers.read().await;
            let worker_statuses: Vec<WorkerStatus> =
                workers.iter().map(WorkerStatus::from).collect();
            ProxyStatusResponse {
                // `version` is not used again below, so move it instead of
                // cloning (the previous `version.clone()` was redundant).
                version,
                supported_proof_type,
                workers: worker_statuses,
            }
        };

        let (status_tx, status_rx) = watch::channel(initial_status);

        Self {
            load_balancer,
            port,
            status_rx,
            status_tx,
            status_update_interval,
        }
    }
}
#[async_trait]
impl ProxyStatusApi for ProxyStatusPingoraService {
    /// Returns the most recently cached proxy status.
    ///
    /// Serves the value currently held in the watch channel; the background
    /// updater is responsible for keeping it fresh.
    #[instrument(target = COMPONENT, name = "proxy.status", skip(_request))]
    async fn status(
        &self,
        _request: Request<ProxyStatusRequest>,
    ) -> Result<Response<ProxyStatusResponse>, Status> {
        Ok(Response::new(self.status_rx.borrow().clone()))
    }
}
#[async_trait]
impl Service for ProxyStatusPingoraService {
    /// Runs the gRPC status endpoint together with the background status
    /// updater until the shared shutdown signal fires.
    async fn start_service(
        &mut self,
        #[cfg(unix)] _fds: Option<ListenFds>,
        shutdown: watch::Receiver<bool>,
        _listeners_per_fd: usize,
    ) {
        info!("Starting gRPC status service on port {}", self.port);

        // Bind the listener up front; without a socket there is nothing to serve.
        let addr = format!("{}:{}", PROXY_HOST, self.port);
        let listener = match TcpListener::bind(&addr).await {
            Ok(l) => {
                info!("gRPC status service bound to {}", addr);
                l
            },
            Err(e) => {
                error!("Failed to bind gRPC status service to {}: {}", addr, e);
                return;
            },
        };

        // Background task that periodically refreshes the cached status.
        let updater = ProxyStatusUpdater::new(
            self.load_balancer.clone(),
            self.status_tx.clone(),
            self.status_update_interval,
        );
        let updater_shutdown = shutdown.clone();
        let updater_task = async move {
            updater.start(updater_shutdown).await;
        };

        // tonic server wired to the same shutdown signal as the updater.
        let mut grpc_shutdown = shutdown.clone();
        let server = Server::builder()
            .layer(
                TraceLayer::new_for_grpc()
                    .make_span_with(traced_span_fn(TracedComponent::RemoteProverProxy)),
            )
            .add_service(ProxyStatusApiServer::new(self.clone()))
            .serve_with_incoming_shutdown(TcpListenerStream::new(listener), async move {
                let _ = grpc_shutdown.changed().await;
                info!("gRPC status service received shutdown signal");
            });

        // Drive both futures; whichever completes first decides the outcome.
        tokio::select! {
            result = server => {
                if let Err(e) = result {
                    error!(err=?e, "gRPC status service failed");
                } else {
                    info!("gRPC status service stopped gracefully");
                }
            }
            _ = updater_task => {
                error!("Status updater task ended unexpectedly");
            }
        }
    }

    /// Service name reported to Pingora.
    fn name(&self) -> &'static str {
        "grpc-status"
    }

    /// A single thread is sufficient for this lightweight endpoint.
    fn threads(&self) -> Option<usize> {
        Some(1)
    }
}
/// Background task state for periodically publishing proxy status snapshots.
pub struct ProxyStatusUpdater {
    // Source of worker information for each snapshot.
    load_balancer: Arc<LoadBalancerState>,
    // Channel the freshly built statuses are published on.
    status_tx: watch::Sender<ProxyStatusResponse>,
    // Delay between consecutive status refreshes.
    update_interval: Duration,
    // Crate version, computed once at construction.
    version: String,
    // Wire (i32) representation of the supported proof type, computed once.
    supported_proof_type: i32,
}
impl ProxyStatusUpdater {
    /// Builds an updater that publishes a fresh status snapshot through
    /// `status_tx` every `update_interval`.
    pub fn new(
        load_balancer: Arc<LoadBalancerState>,
        status_tx: watch::Sender<ProxyStatusResponse>,
        update_interval: Duration,
    ) -> Self {
        // The version and proof type never change at runtime, so compute
        // their wire representations once up front.
        let proof_type: ProofType = load_balancer.supported_proof_type;
        Self {
            version: env!("CARGO_PKG_VERSION").to_string(),
            supported_proof_type: proof_type.into(),
            load_balancer,
            status_tx,
            update_interval,
        }
    }

    /// Publishes a new status on every timer tick until shutdown is signalled.
    pub async fn start(&self, mut shutdown: watch::Receiver<bool>) {
        let mut ticker = interval(self.update_interval);
        loop {
            tokio::select! {
                _ = ticker.tick() => {
                    let snapshot = self.build_status().await;
                    // A send error only means every receiver is gone; ignore it.
                    let _ = self.status_tx.send(snapshot);
                }
                _ = shutdown.changed() => {
                    info!("Status updater received shutdown signal");
                    break;
                }
            }
        }
    }

    /// Snapshots the current worker pool into a `ProxyStatusResponse`.
    async fn build_status(&self) -> ProxyStatusResponse {
        let workers = self.load_balancer.workers.read().await;
        ProxyStatusResponse {
            version: self.version.clone(),
            supported_proof_type: self.supported_proof_type,
            workers: workers.iter().map(WorkerStatus::from).collect(),
        }
    }
}
impl From<&WorkerHealthStatus> for proto::WorkerHealthStatus {
    /// Maps the internal worker health state onto its protobuf counterpart.
    fn from(status: &WorkerHealthStatus) -> Self {
        match status {
            WorkerHealthStatus::Unknown => Self::Unknown,
            WorkerHealthStatus::Healthy => Self::Healthy,
            WorkerHealthStatus::Unhealthy { .. } => Self::Unhealthy,
        }
    }
}
impl From<&Worker> for WorkerStatus {
    /// Builds the protobuf status entry for a single worker.
    fn from(worker: &Worker) -> Self {
        let health = proto::WorkerHealthStatus::from(worker.health_status());
        Self {
            status: health.into(),
            address: worker.address(),
            version: worker.version().to_string(),
        }
    }
}