use std::sync::LazyLock;
use std::time::{Duration, Instant};
use anyhow::Context;
use miden_node_utils::ErrorReport;
use miden_remote_prover::COMPONENT;
use miden_remote_prover::api::ProofType;
use miden_remote_prover::error::RemoteProverError;
use miden_remote_prover::generated::ProxyWorkerStatus;
use miden_remote_prover::generated::remote_prover::worker_status_api_client::WorkerStatusApiClient;
use pingora::lb::Backend;
use semver::{Version, VersionReq};
use serde::Serialize;
use tonic::transport::Channel;
use tracing::{error, info};
use super::metrics::WORKER_UNHEALTHY;
/// Caps the health-check backoff exponent: an unhealthy worker is re-checked
/// at most every 2^9 = 512 seconds, no matter how many checks have failed.
const MAX_BACKOFF_EXPONENT: usize = 9;
/// This proxy's own version, baked in at compile time from Cargo metadata.
const MRP_PROXY_VERSION: &str = env!("CARGO_PKG_VERSION");
/// Semver requirement that worker versions must satisfy: `~major.minor` of
/// the proxy, i.e. same major and minor version, any patch level.
static WORKER_VERSION_REQUIREMENT: LazyLock<VersionReq> = LazyLock::new(|| {
    let current =
        Version::parse(MRP_PROXY_VERSION).expect("Proxy version should be valid at this point");
    VersionReq::parse(&format!("~{}.{}", current.major, current.minor))
        .expect("Version should be valid at this point")
});
/// A remote prover worker as tracked by the proxy: its load-balancer backend,
/// gRPC status client, and current health/availability state.
#[derive(Debug, Clone)]
pub struct Worker {
    /// Pingora backend holding the worker's address.
    backend: Backend,
    /// gRPC client for the worker's status endpoint; `None` when the
    /// connection could not be established (it is recreated lazily on the
    /// next health check).
    status_client: Option<WorkerStatusApiClient<Channel>>,
    /// Whether the worker may currently receive proxied requests.
    is_available: bool,
    /// Result of the most recent health evaluation.
    health_status: WorkerHealthStatus,
    /// Version last reported by the worker; empty until the first
    /// successful status check.
    version: String,
    /// Timeout for establishing a connection to the worker.
    connection_timeout: Duration,
    /// Overall timeout for a status request.
    total_timeout: Duration,
}
/// Health state of a worker as observed by the proxy's periodic checks.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub enum WorkerHealthStatus {
    /// The most recent status check succeeded.
    Healthy,
    /// One or more consecutive status checks failed.
    Unhealthy {
        /// Number of consecutive failed checks in the current streak.
        num_failed_attempts: usize,
        /// When the current failure streak began; drives the exponential
        /// backoff window. Skipped when serializing because `Instant` has
        /// no serde representation.
        #[serde(skip_serializing)]
        first_fail_timestamp: Instant,
        /// Human-readable description of the most recent failure.
        reason: String,
    },
    /// No status check has completed yet (e.g. right after creation).
    Unknown,
}
impl Worker {
/// Creates a new worker for the given address.
///
/// A gRPC status client is established eagerly; if that fails the worker
/// starts out `Unhealthy` (and unavailable) with the connection error
/// recorded as the reason, and the client is recreated lazily later. The
/// worker's version is unknown until the first successful status check.
pub async fn new(
    worker_addr: String,
    connection_timeout: Duration,
    total_timeout: Duration,
) -> Result<Self, RemoteProverError> {
    let backend =
        Backend::new(&worker_addr).map_err(RemoteProverError::BackendCreationFailed)?;

    let client_result =
        create_status_client(&worker_addr, connection_timeout, total_timeout).await;
    let (status_client, health_status) = match client_result {
        Ok(client) => (Some(client), WorkerHealthStatus::Unknown),
        Err(err) => {
            error!("Failed to create status client for worker {}: {}", worker_addr, err);
            let status = WorkerHealthStatus::Unhealthy {
                num_failed_attempts: 1,
                first_fail_timestamp: Instant::now(),
                reason: err.as_report_context("failed to create status client"),
            };
            (None, status)
        },
    };

    // Only a worker that connected successfully starts in rotation.
    let is_available = health_status == WorkerHealthStatus::Unknown;
    Ok(Self {
        backend,
        is_available,
        status_client,
        health_status,
        version: String::new(),
        connection_timeout,
        total_timeout,
    })
}
/// Rebuilds the gRPC status client for this worker, storing it on success
/// and logging (then propagating) the error on failure.
async fn recreate_status_client(&mut self) -> Result<(), RemoteProverError> {
    let address = self.address();
    let client = create_status_client(&address, self.connection_timeout, self.total_timeout)
        .await
        .inspect_err(|err| {
            error!("Failed to recreate status client for worker {}: {}", address, err);
        })?;
    self.status_client = Some(client);
    Ok(())
}
/// Performs a health check against the worker's status endpoint.
///
/// Returns `Ok(())` without contacting the worker while it is inside its
/// exponential-backoff window. Otherwise it (re)creates the status client
/// if needed, queries the worker, and validates in order: the reported
/// version is non-empty, the version satisfies the proxy's semver
/// requirement, and the worker supports the proxy's proof type. On success
/// the worker's cached `version` is updated.
///
/// # Errors
/// Returns a human-readable reason on any failure; the caller is expected
/// to feed the result into [`Self::update_status`].
#[allow(clippy::too_many_lines)]
#[tracing::instrument(target = COMPONENT, name = "worker.check_status")]
pub async fn check_status(&mut self, supported_proof_type: ProofType) -> Result<(), String> {
    // Respect the exponential backoff window for unhealthy workers.
    if !self.should_do_health_check() {
        return Ok(());
    }

    // Lazily re-establish the status client if the last connection failed.
    if self.status_client.is_none() {
        match self.recreate_status_client().await {
            Ok(()) => {
                info!("Successfully recreated status client for worker {}", self.address());
            },
            Err(err) => {
                return Err(err.as_report_context("failed to recreate status client"));
            },
        }
    }

    let worker_status = match self
        .status_client
        .as_mut()
        .expect("status client should exist after successful recreation")
        .status(())
        .await
    {
        Ok(response) => response.into_inner(),
        Err(e) => {
            error!("Failed to check worker status ({}): {}", self.address(), e);
            return Err(e.message().to_string());
        },
    };

    // Validate the reported version against the proxy's requirement.
    if worker_status.version.is_empty() {
        return Err("Worker version is empty".to_string());
    }
    if !is_valid_version(&WORKER_VERSION_REQUIREMENT, &worker_status.version).unwrap_or(false) {
        return Err(format!("Worker version is invalid ({})", worker_status.version));
    }
    self.version = worker_status.version;

    let worker_supported_proof_type = ProofType::try_from(worker_status.supported_proof_type)
        .inspect_err(|err| {
            error!(%err, address=%self.address(), "Failed to convert worker supported proof type");
        })?;
    if supported_proof_type != worker_supported_proof_type {
        // The previous message reported the proxy's own required proof type
        // as "unsupported", which was misleading; name both sides instead.
        return Err(format!(
            "Unsupported proof type: worker supports {worker_supported_proof_type}, proxy requires {supported_proof_type}"
        ));
    }
    Ok(())
}
/// Folds the result of a health check into the worker's health status.
///
/// A success resets the worker to `Healthy`; a failure increments the
/// consecutive-failure counter while preserving the timestamp of the first
/// failure in the current streak (so backoff grows from the streak start).
#[tracing::instrument(target = COMPONENT, name = "worker.update_status")]
pub fn update_status(&mut self, check_result: Result<(), String>) {
    let new_status = match check_result {
        Ok(()) => WorkerHealthStatus::Healthy,
        Err(reason) => {
            // Keep the original streak-start time across repeated failures.
            let first_fail_timestamp = if let WorkerHealthStatus::Unhealthy {
                first_fail_timestamp, ..
            } = &self.health_status
            {
                *first_fail_timestamp
            } else {
                Instant::now()
            };
            WorkerHealthStatus::Unhealthy {
                num_failed_attempts: self.num_failures() + 1,
                first_fail_timestamp,
                reason,
            }
        },
    };
    self.set_health_status(new_status);
}
/// Marks the worker as available (or not) for receiving proxied requests.
pub fn set_availability(&mut self, is_available: bool) {
    self.is_available = is_available;
}

/// Number of consecutive failed health checks; 0 when healthy or unknown.
pub fn num_failures(&self) -> usize {
    match &self.health_status {
        WorkerHealthStatus::Healthy | WorkerHealthStatus::Unknown => 0,
        WorkerHealthStatus::Unhealthy {
            num_failed_attempts: failed_attempts,
            first_fail_timestamp: _,
            reason: _,
        } => *failed_attempts,
    }
}

/// Returns the worker's current health status.
pub fn health_status(&self) -> &WorkerHealthStatus {
    &self.health_status
}

/// Version string last reported by the worker (empty until the first
/// successful status check).
pub fn version(&self) -> &str {
    &self.version
}

/// Whether the worker is currently eligible to receive requests.
pub fn is_available(&self) -> bool {
    self.is_available
}

/// The worker's backend address rendered as a string.
pub fn address(&self) -> String {
    self.backend.addr.to_string()
}

/// A worker counts as healthy unless it is explicitly `Unhealthy`;
/// note that `Unknown` (not yet checked) is treated as healthy here.
pub fn is_healthy(&self) -> bool {
    !matches!(self.health_status, WorkerHealthStatus::Unhealthy { .. })
}
/// Decides whether a health check is due.
///
/// Healthy and unknown workers are always checked. Unhealthy workers are
/// checked only once the time since the first failure exceeds
/// 2^num_failed_attempts seconds, with the exponent capped at
/// `MAX_BACKOFF_EXPONENT` (512 s).
fn should_do_health_check(&self) -> bool {
    let WorkerHealthStatus::Unhealthy {
        num_failed_attempts,
        first_fail_timestamp,
        reason: _,
    } = self.health_status
    else {
        return true;
    };
    let backoff_secs = 2u64.pow(num_failed_attempts.min(MAX_BACKOFF_EXPONENT) as u32);
    first_fail_timestamp.elapsed() > Duration::from_secs(backoff_secs)
}
/// Records a new health status and adjusts availability to match.
///
/// Entering (or remaining in) the unhealthy state bumps the unhealthy
/// metric and pulls the worker out of rotation; recovering from an
/// unhealthy state puts it back in.
fn set_health_status(&mut self, health_status: WorkerHealthStatus) {
    let was_healthy = self.is_healthy();
    self.health_status = health_status;
    if let WorkerHealthStatus::Unhealthy { .. } = self.health_status {
        WORKER_UNHEALTHY.with_label_values(&[&self.address()]).inc();
        self.is_available = false;
    } else if !was_healthy {
        // Transition back to healthy/unknown: re-enable the worker.
        self.is_available = true;
    }
}
}
impl PartialEq for Worker {
    /// Workers are considered equal when they point at the same backend
    /// address; health, version, and availability are ignored.
    fn eq(&self, other: &Self) -> bool {
        self.backend == other.backend
    }
}
impl From<&Worker> for ProxyWorkerStatus {
    /// Converts a worker into its protobuf status representation.
    fn from(worker: &Worker) -> Self {
        use miden_remote_prover::generated::remote_prover::WorkerHealthStatus as ProtoWorkerHealthStatus;
        // Map the internal enum onto the wire-level enum, then to its i32 tag.
        let proto_status = match worker.health_status() {
            WorkerHealthStatus::Healthy => ProtoWorkerHealthStatus::Healthy,
            WorkerHealthStatus::Unhealthy { .. } => ProtoWorkerHealthStatus::Unhealthy,
            WorkerHealthStatus::Unknown => ProtoWorkerHealthStatus::Unknown,
        };
        Self {
            address: worker.address(),
            version: worker.version().to_string(),
            status: proto_status as i32,
        }
    }
}
/// Connects to `address` over plain HTTP and returns a gRPC client for the
/// worker's status API, applying the given connect and request timeouts.
async fn create_status_client(
    address: &str,
    connection_timeout: Duration,
    total_timeout: Duration,
) -> Result<WorkerStatusApiClient<Channel>, RemoteProverError> {
    let endpoint = Channel::from_shared(format!("http://{address}"))
        .map_err(|err| RemoteProverError::InvalidURI(err, address.to_string()))?
        .connect_timeout(connection_timeout)
        .timeout(total_timeout);
    let channel = endpoint
        .connect()
        .await
        .map_err(|err| RemoteProverError::ConnectionFailed(err, address.to_string()))?;
    Ok(WorkerStatusApiClient::new(channel))
}
/// Checks whether `version` parses as semver and satisfies `version_req`.
///
/// # Errors
/// Returns an error when `version` is not valid semver (e.g. missing the
/// patch component or containing non-numeric parts).
fn is_valid_version(version_req: &VersionReq, version: &str) -> anyhow::Result<bool> {
    // `Context::context` takes a plain string, so the previous
    // "Invalid worker version: {err}" literal was never interpolated and the
    // braces appeared verbatim in errors. Include the offending input instead.
    let received = Version::parse(version)
        .with_context(|| format!("Invalid worker version: {version}"))?;
    Ok(version_req.matches(&received))
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Exercises `is_valid_version` against a `~1.0` requirement:
    /// any 1.0.x matches, other major/minor versions do not, and inputs
    /// that are not full semver (or not semver at all) are parse errors.
    #[test]
    fn test_is_valid_version() {
        let version_req = VersionReq::parse("~1.0").unwrap();
        // Matching patch releases.
        assert!(is_valid_version(&version_req, "1.0.0").unwrap());
        assert!(is_valid_version(&version_req, "1.0.1").unwrap());
        assert!(is_valid_version(&version_req, "1.0.12").unwrap());
        // Missing patch component is a parse error, not a mismatch.
        assert!(is_valid_version(&version_req, "1.0").is_err());
        // Different major/minor versions do not satisfy `~1.0`.
        assert!(!is_valid_version(&version_req, "2.0.0").unwrap());
        assert!(!is_valid_version(&version_req, "1.1.0").unwrap());
        assert!(!is_valid_version(&version_req, "0.9.0").unwrap());
        assert!(!is_valid_version(&version_req, "0.9.1").unwrap());
        assert!(!is_valid_version(&version_req, "0.10.0").unwrap());
        // Non-semver strings are parse errors.
        assert!(is_valid_version(&version_req, "miden").is_err());
        assert!(is_valid_version(&version_req, "1.miden.12").is_err());
    }
}