use std::{convert::Infallible, net::SocketAddr, sync::Arc, time::Duration};
use futures::{future, TryFutureExt};
use hyper::server::{conn::AddrIncoming, Builder};
use once_cell::sync::OnceCell;
use tokio::sync::oneshot;
use tower::builder::ServiceBuilder;
use tracing::{info, warn};
use warp::Filter;
use casper_types::ProtocolVersion;
use super::{filters, ReactorEventT};
use crate::effect::EffectBuilder;
/// Runs the REST server (without CORS) until a shutdown signal is received.
///
/// Serves the status, metrics, validator-changes and chainspec filters, tried in that order,
/// on the listener held by `builder`. The make-service is wrapped in a tower rate limiter
/// admitting `qps_limit` calls per one-second window.
///
/// * `builder` - hyper server builder for the listening socket.
/// * `effect_builder` - passed to every filter constructor.
/// * `api_version` - passed to the status, validator-changes and chainspec filters.
/// * `shutdown_receiver` - the server shuts down gracefully once this resolves. A dropped
///   sender also triggers shutdown, because the receive error is discarded.
/// * `qps_limit` - number of calls the rate limiter admits per second.
/// * `local_addr` - receives the address the server is actually bound to, for reflection by
///   other code. If it already holds a value, a warning is logged and the old value is kept.
///
/// Errors from the running server are logged at `warn` level and are not propagated.
pub(super) async fn run<REv: ReactorEventT>(
    builder: Builder<AddrIncoming>,
    effect_builder: EffectBuilder<REv>,
    api_version: ProtocolVersion,
    shutdown_receiver: oneshot::Receiver<()>,
    qps_limit: u64,
    local_addr: Arc<OnceCell<SocketAddr>>,
) {
    // REST filters.
    let rest_status = filters::create_status_filter(effect_builder, api_version);
    let rest_metrics = filters::create_metrics_filter(effect_builder);
    let rest_validator_changes =
        filters::create_validator_changes_filter(effect_builder, api_version);
    let rest_chainspec_filter = filters::create_chainspec_filter(effect_builder, api_version);

    let service = warp::service(
        rest_status
            .or(rest_metrics)
            .or(rest_validator_changes)
            .or(rest_chainspec_filter),
    );

    // Every call of the make-service hands out a clone of the same warp service.
    let make_svc =
        hyper::service::make_service_fn(move |_| future::ok::<_, Infallible>(service.clone()));

    // NOTE(review): the limiter wraps the make-service, which hyper calls once per accepted
    // connection. This presumably throttles new connections rather than individual requests
    // on a kept-alive connection. Confirm this is the intended meaning of `qps_limit`.
    let rate_limited_service = ServiceBuilder::new()
        .rate_limit(qps_limit, Duration::from_secs(1))
        .service(make_svc);

    let server = builder.serve(rate_limited_service);
    let address = server.local_addr();

    // `OnceCell::set` fails only when the cell is already populated. Its `Err` carries the
    // address we tried to store, not an error value, so log it as such next to the address
    // that is already registered.
    if let Err(rejected_addr) = local_addr.set(address) {
        warn!(
            %rejected_addr,
            existing_addr = ?local_addr.get(),
            "failed to set local addr for reflection"
        );
    }
    info!(%address, "started REST server");

    // Serve until the shutdown signal arrives. `.ok()` discards the receive error from a
    // dropped sender, so that case also ends the server gracefully.
    let _ = server
        .with_graceful_shutdown(async move {
            shutdown_receiver.await.ok();
        })
        .map_err(|error| {
            warn!(%error, "error running REST server");
        })
        .await;
}
/// Runs the REST server with CORS enabled until a shutdown signal is received.
///
/// Serves the status, metrics, validator-changes and chainspec filters, tried in that order,
/// wrapped in a warp CORS filter built from `cors_origin`. The make-service is wrapped in a
/// tower rate limiter admitting `qps_limit` calls per one-second window.
///
/// * `builder` - hyper server builder for the listening socket.
/// * `effect_builder` - passed to every filter constructor.
/// * `api_version` - passed to the status, validator-changes and chainspec filters.
/// * `shutdown_receiver` - the server shuts down gracefully once this resolves. A dropped
///   sender also triggers shutdown, because the receive error is discarded.
/// * `qps_limit` - number of calls the rate limiter admits per second.
/// * `local_addr` - receives the address the server is actually bound to, for reflection by
///   other code. If it already holds a value, a warning is logged and the old value is kept.
/// * `cors_origin` - `"*"` allows any origin; any other value is the single allowed origin.
///
/// Errors from the running server are logged at `warn` level and are not propagated.
pub(super) async fn run_with_cors<REv: ReactorEventT>(
    builder: Builder<AddrIncoming>,
    effect_builder: EffectBuilder<REv>,
    api_version: ProtocolVersion,
    shutdown_receiver: oneshot::Receiver<()>,
    qps_limit: u64,
    local_addr: Arc<OnceCell<SocketAddr>>,
    cors_origin: String,
) {
    // REST filters.
    let rest_status = filters::create_status_filter(effect_builder, api_version);
    let rest_metrics = filters::create_metrics_filter(effect_builder);
    let rest_validator_changes =
        filters::create_validator_changes_filter(effect_builder, api_version);
    let rest_chainspec_filter = filters::create_chainspec_filter(effect_builder, api_version);

    // NOTE(review): warp's `allow_origin` is expected to panic if the string cannot be parsed
    // as an origin (e.g. no scheme). Confirm `cors_origin` is validated before it reaches here.
    let cors = match cors_origin.as_str() {
        "*" => warp::cors().allow_any_origin(),
        origin => warp::cors().allow_origin(origin),
    };

    let service = warp::service(
        rest_status
            .or(rest_metrics)
            .or(rest_validator_changes)
            .or(rest_chainspec_filter)
            .with(cors),
    );

    // Every call of the make-service hands out a clone of the same warp service.
    let make_svc =
        hyper::service::make_service_fn(move |_| future::ok::<_, Infallible>(service.clone()));

    // NOTE(review): the limiter wraps the make-service, which hyper calls once per accepted
    // connection. This presumably throttles new connections rather than individual requests
    // on a kept-alive connection. Confirm this is the intended meaning of `qps_limit`.
    let rate_limited_service = ServiceBuilder::new()
        .rate_limit(qps_limit, Duration::from_secs(1))
        .service(make_svc);

    let server = builder.serve(rate_limited_service);
    let address = server.local_addr();

    // `OnceCell::set` fails only when the cell is already populated. Its `Err` carries the
    // address we tried to store, not an error value, so log it as such next to the address
    // that is already registered.
    if let Err(rejected_addr) = local_addr.set(address) {
        warn!(
            %rejected_addr,
            existing_addr = ?local_addr.get(),
            "failed to set local addr for reflection"
        );
    }
    info!(%address, "started REST server");

    // Serve until the shutdown signal arrives. `.ok()` discards the receive error from a
    // dropped sender, so that case also ends the server gracefully.
    let _ = server
        .with_graceful_shutdown(async move {
            shutdown_receiver.await.ok();
        })
        .map_err(|error| {
            warn!(%error, "error running REST server");
        })
        .await;
}