use clap::Parser;
use miden_node_utils::{
cors::cors_for_grpc_web_layer,
tracing::grpc::{TracedComponent, traced_span_fn},
};
use miden_remote_prover::{
COMPONENT,
api::{ProofType, RpcListener},
generated::api_server::ApiServer,
};
use tokio::net::TcpListener;
use tokio_stream::wrappers::TcpListenerStream;
use tonic_health::server::health_reporter;
use tonic_web::GrpcWebLayer;
use tower_http::trace::TraceLayer;
use tracing::{info, instrument};
/// Starts a remote prover worker that serves the prover gRPC API.
#[derive(Debug, Parser)]
pub struct StartWorker {
    /// Bind to 127.0.0.1 only instead of all interfaces (0.0.0.0).
    #[arg(long, env = "MRP_WORKER_LOCALHOST")]
    localhost: bool,
    /// Port the worker listens on.
    #[arg(long, default_value = "50051", env = "MRP_WORKER_PORT")]
    port: u16,
    /// Type of proof handled by this worker.
    #[arg(long, env = "MRP_WORKER_PROOF_TYPE")]
    proof_type: ProofType,
}
impl StartWorker {
    /// Binds the worker's TCP listener and serves the prover gRPC API until the server stops.
    ///
    /// The listen address is `127.0.0.1:<port>` when `--localhost` is set and
    /// `0.0.0.0:<port>` otherwise. The server accepts HTTP/1 in addition to HTTP/2.
    /// Requests pass through gRPC tracing, CORS, and gRPC-Web layers, so gRPC-Web clients
    /// can reach it. Besides the prover API, the server exposes the status service and a
    /// gRPC health service that reports the API service as serving.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message, including the failing operation, in three cases:
    /// the address cannot be bound, the bound address cannot be read back, or the server
    /// fails while running.
    #[instrument(target = COMPONENT, name = "worker.execute")]
    pub async fn execute(&self) -> Result<(), String> {
        let host = if self.localhost { "127.0.0.1" } else { "0.0.0.0" };
        let worker_addr = format!("{host}:{}", self.port);

        // Include the address in the error; a bare OS error ("Address already in use")
        // does not say which endpoint failed.
        let listener = TcpListener::bind(&worker_addr)
            .await
            .map_err(|err| format!("failed to bind worker to {worker_addr}: {err}"))?;
        let rpc = RpcListener::new(listener, self.proof_type);

        // Log the address actually bound. With port 0 it differs from the requested one.
        let server_addr = rpc
            .listener
            .local_addr()
            .map_err(|err| format!("failed to read worker listener address: {err}"))?;
        info!(target: COMPONENT,
            endpoint = %server_addr,
            proof_type = ?self.proof_type,
            host = %host,
            port = %self.port,
            "Worker server initialized and listening"
        );

        let (health_reporter, health_service) = health_reporter();
        health_reporter.set_serving::<ApiServer<RpcListener>>().await;

        tonic::transport::Server::builder()
            // gRPC-Web traffic arrives over HTTP/1.1, so HTTP/1 must be accepted.
            .accept_http1(true)
            .layer(
                TraceLayer::new_for_grpc()
                    .make_span_with(traced_span_fn(TracedComponent::RemoteProver)),
            )
            .layer(cors_for_grpc_web_layer())
            .layer(GrpcWebLayer::new())
            .add_service(rpc.api_service)
            .add_service(rpc.status_service)
            .add_service(health_service)
            .serve_with_incoming(TcpListenerStream::new(rpc.listener))
            .await
            .map_err(|err| format!("worker server on {server_addr} failed: {err}"))
    }
}