use dotenvy::dotenv;
use env_logger::Env;
use log::{info, warn};
use serde::{Deserialize, Serialize};
use warp::http::StatusCode;
use warp::Filter;
use std::collections::VecDeque;
use std::convert::Infallible;
use std::env;
use std::sync::{Arc, RwLock};
use ppoprf::ppoprf;
/// Default epoch length, in seconds, used when `EPOCH_DURATION` is not set.
const DEFAULT_EPOCH_DURATION: u64 = 5;
/// Default epoch metadata tags: a `;`-separated list whose entries are each
/// parsed as a `u8`. The first tag starts active; the rest are queued in order.
const DEFAULT_MDS: &str = "116;117;118;119;120";
/// Environment variable holding the epoch length in seconds.
const EPOCH_DURATION_ENV_KEY: &str = "EPOCH_DURATION";
/// Environment variable holding the `;`-separated list of metadata tags.
const MDS_ENV_KEY: &str = "METADATA_TAGS";
/// Mutable server state shared between the request handlers (which read it)
/// and the background epoch-rotation task (which writes it).
struct ServerState {
    /// PPOPRF server used to evaluate points and to puncture expired tags.
    prf_server: ppoprf::Server,
    /// Metadata tag under which evaluations are currently performed.
    active_md: u8,
    /// Tags not yet activated, in rotation order (the front is next).
    future_mds: VecDeque<u8>,
}
/// Shared, lock-protected handle to the server state.
type State = Arc<RwLock<ServerState>>;
/// Builds a warp filter that injects a clone of the shared state handle into
/// every request it matches.
///
/// The filter matches any request and never rejects. Each request receives its
/// own `Arc` clone, which is a cheap reference-count bump rather than a copy.
fn with_state(
    state: State,
) -> impl Filter<Extract = (State,), Error = Infallible> + Clone {
    let inject = move || state.clone();
    warp::any().map(inject)
}
/// JSON body of an evaluation request.
#[derive(Deserialize)]
struct EvalRequest {
    /// Caller-chosen identifier, echoed back unchanged in the response.
    name: String,
    /// Points to evaluate under the currently active epoch tag.
    points: Vec<ppoprf::Point>,
}
/// JSON body of a successful evaluation response.
#[derive(Serialize)]
struct EvalResponse {
    /// The `name` supplied in the corresponding request.
    name: String,
    /// One evaluation per requested point, in request order.
    results: Vec<ppoprf::Evaluation>,
}
/// JSON body returned alongside a `500 Internal Server Error` status.
#[derive(Serialize)]
struct ServerErrorResponse {
    /// Human-readable description of what went wrong.
    error: String,
}
/// Plain-text description of this service, with a pointer to the STAR paper.
fn help() -> &'static str {
    // The trailing backslash continues the literal and skips the next line's
    // leading whitespace, so the text contains exactly one `\n` per sentence.
    const HELP_TEXT: &str = "STAR protocol randomness server.\n\
                             See https://arxiv.org/abs/2109.10074 for more information.\n";
    HELP_TEXT
}
async fn eval(
data: EvalRequest,
state: State,
) -> Result<impl warp::Reply, Infallible> {
let state = state.read().unwrap();
let result: Result<Vec<ppoprf::Evaluation>, ppoprf::PPRFError> = data
.points
.iter()
.map(|p| state.prf_server.eval(p, state.active_md, false))
.collect();
match result {
Ok(results) => Ok(warp::reply::with_status(
warp::reply::json(&EvalResponse {
name: data.name,
results,
}),
StatusCode::OK,
)),
Err(error) => Ok(warp::reply::with_status(
warp::reply::json(&ServerErrorResponse {
error: format!("{error}"),
}),
StatusCode::INTERNAL_SERVER_ERROR,
)),
}
}
#[tokio::main]
async fn main() {
dotenv().ok();
let host = "localhost";
let port = 8080;
env_logger::init_from_env(Env::default().default_filter_or("info"));
info!("Server configured on {host} port {port}");
let mds_str = match env::var(MDS_ENV_KEY) {
Ok(val) => val,
Err(_) => {
info!(
"{} env var not defined, using default: {}",
MDS_ENV_KEY, DEFAULT_MDS
);
DEFAULT_MDS.to_string()
}
};
let mds: Vec<u8> = mds_str
.split(';')
.map(|y| {
y.parse().expect(
"Could not parse metadata tags. Must contain 8-bit unsigned values!",
)
})
.collect();
let epoch =
std::time::Duration::from_secs(match env::var(EPOCH_DURATION_ENV_KEY) {
Ok(val) => val.parse().expect(
"Could not parse epoch duration. It must be a positive number!",
),
Err(_) => {
info!(
"{} env var not defined, using default: {} seconds",
EPOCH_DURATION_ENV_KEY, DEFAULT_EPOCH_DURATION
);
DEFAULT_EPOCH_DURATION
}
});
let state = Arc::new(RwLock::new(ServerState {
prf_server: ppoprf::Server::new(mds.clone()).unwrap(),
active_md: mds[0],
future_mds: VecDeque::from(mds[1..].to_vec()),
}));
info!("PPOPRF initialized with epoch metadata tags {:?}", &mds);
let background_state = state.clone();
tokio::spawn(async move {
info!(
"Background task will rotate epoch every {} seconds",
epoch.as_secs()
);
for &md in &mds {
info!(
"Epoch tag now '{:?}'; next rotation in {} seconds",
md,
epoch.as_secs()
);
tokio::time::sleep(epoch).await;
if let Ok(mut state) = background_state.write() {
info!("Epoch rotation: puncturing '{:?}'", md);
state.prf_server.puncture(md).unwrap();
let new_md = state.future_mds.pop_front().unwrap();
state.active_md = new_md;
}
}
warn!("All epoch tags punctured! No further evaluations possible.");
});
let info = warp::get().map(help);
let rand = warp::post()
.and(warp::body::content_length_limit(8 * 1024))
.and(warp::body::json())
.and(with_state(state))
.and_then(eval);
let routes = rand.or(info);
warp::serve(routes).run(([127, 0, 0, 1], 8080)).await;
}