use dropshot::endpoint;
use dropshot::ApiDescription;
use dropshot::ConfigDropshot;
use dropshot::ConfigLogging;
use dropshot::ConfigLoggingLevel;
use dropshot::HttpError;
use dropshot::HttpResponseCreated;
use dropshot::HttpResponseDeleted;
use dropshot::HttpResponseOk;
use dropshot::HttpServer;
use dropshot::Path;
use dropshot::RequestContext;
use dropshot::ServerBuilder;
use dropshot::TypedBody;
use futures::future::BoxFuture;
use futures::stream::FuturesUnordered;
use futures::FutureExt;
use futures::StreamExt;
use schemars::JsonSchema;
use serde::Deserialize;
use serde::Serialize;
use slog::info;
use slog::Logger;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::mem;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::Mutex;
#[tokio::main]
async fn main() -> Result<(), String> {
let initial_servers = [("A", "127.0.0.1:12345"), ("B", "127.0.0.1:12346")];
let mut running_servers = FuturesUnordered::new();
let (running_servers_tx, mut running_servers_rx) = mpsc::channel(8);
let shared_context =
Arc::new(SharedMultiServerContext::new(running_servers_tx));
for (name, bind_address) in initial_servers {
let bind_address = bind_address.parse().unwrap();
shared_context.start_server(name, bind_address).await?;
}
mem::drop(shared_context);
loop {
tokio::select! {
maybe_new_server = running_servers_rx.recv() => {
match maybe_new_server {
Some(server) => running_servers.push(server),
None => return Ok(()),
}
}
maybe_result = running_servers.next() => {
if let Some(result) = maybe_result {
result?;
}
}
}
}
}
/// Boxed form of the future returned by `HttpServer::wait_for_shutdown`.
///
/// Boxing gives the futures of all servers one common type. They can then be
/// sent over a single channel to `main` and collected in a `FuturesUnordered`.
type ServerShutdownFuture = BoxFuture<'static, Result<(), String>>;
/// Per-server context handed to each Dropshot server created by
/// `SharedMultiServerContext::start_server`.
struct MultiServerContext {
    /// Registry shared by all servers. It lets any server's endpoints list,
    /// start or stop servers. Holding this `Arc` also keeps the shutdown-handle
    /// sender inside it alive.
    shared: Arc<SharedMultiServerContext>,
    /// Name under which this server is registered in `shared.servers`.
    name: String,
    /// Logger for this server; also used by `Drop` to record shutdown.
    log: Logger,
}
impl Drop for MultiServerContext {
    /// Emits an info-level log line naming the server whose context is being
    /// torn down.
    fn drop(&mut self) {
        let (log, name) = (&self.log, &self.name);
        info!(log, "shut down server {:?}", name);
    }
}
/// State shared by every server.
///
/// It holds the registry of running servers and the channel on which each
/// newly started server's shutdown future is sent to `main`.
struct SharedMultiServerContext {
    /// Running servers keyed by name; guarded by an async mutex.
    servers: Mutex<HashMap<String, HttpServer<MultiServerContext>>>,
    /// Sender half of the channel that `main` reads shutdown futures from.
    /// When this (the only sender) is dropped, `main`'s `recv()` yields `None`
    /// and `main` returns.
    started_server_shutdown_handles: mpsc::Sender<ServerShutdownFuture>,
}
impl SharedMultiServerContext {
fn new(
started_server_shutdown_handles: mpsc::Sender<ServerShutdownFuture>,
) -> Self {
Self { servers: Mutex::default(), started_server_shutdown_handles }
}
async fn start_server(
self: &Arc<Self>,
name: &str,
bind_address: SocketAddr,
) -> Result<(), String> {
let mut servers = self.servers.lock().await;
let slot = match servers.entry(name.to_string()) {
Entry::Occupied(_) => {
return Err(format!("already running a server named {name:?}",))
}
Entry::Vacant(slot) => slot,
};
let config_logging =
ConfigLogging::StderrTerminal { level: ConfigLoggingLevel::Info };
let log = config_logging
.to_logger(format!("example-multiserver-{name}"))
.map_err(|error| format!("failed to create logger: {}", error))?;
let mut api = ApiDescription::new();
api.register(api_get_servers).unwrap();
api.register(api_start_server).unwrap();
api.register(api_stop_server).unwrap();
let config_dropshot =
ConfigDropshot { bind_address, ..Default::default() };
let context = MultiServerContext {
shared: Arc::clone(self),
name: name.to_string(),
log: log.clone(),
};
let server = ServerBuilder::new(api, context, log)
.config(config_dropshot)
.start()
.map_err(|error| format!("failed to create server: {}", error))?;
let shutdown_handle = server.wait_for_shutdown();
slot.insert(server);
mem::drop(servers);
_ = self
.started_server_shutdown_handles
.send(shutdown_handle.boxed())
.await;
Ok(())
}
}
/// Name and socket address of one server, as reported by the HTTP API.
#[derive(Debug, Serialize, JsonSchema)]
struct ServerDescription {
    /// Name the server is registered under.
    name: String,
    /// Socket address of the server.
    bind_addr: SocketAddr,
}
/// Lists every server currently registered in the shared registry, along with
/// the local address each one is bound to.
///
/// Entries are sorted by name so the response is deterministic; the
/// underlying `HashMap` iterates in arbitrary order.
#[endpoint {
    method = GET,
    path = "/servers",
}]
async fn api_get_servers(
    rqctx: RequestContext<MultiServerContext>,
) -> Result<HttpResponseOk<Vec<ServerDescription>>, HttpError> {
    let api_context = rqctx.context();
    let mut servers: Vec<ServerDescription> = {
        let servers = api_context.shared.servers.lock().await;
        servers
            .iter()
            .map(|(name, server)| ServerDescription {
                name: name.clone(),
                bind_addr: server.local_addr(),
            })
            .collect()
    };
    // Sort after the lock is released to keep the critical section short.
    // Names are unique map keys, so an unstable sort is fine.
    servers.sort_unstable_by(|a, b| a.name.cmp(&b.name));
    Ok(HttpResponseOk(servers))
}
/// Path parameters for the `/servers/{name}` endpoints.
#[derive(Deserialize, JsonSchema)]
struct PathName {
    /// Name of the server to start or stop.
    name: String,
}
/// Starts a new server named `name`, listening on the socket address given in
/// the request body.
///
/// The response reports the address the new server actually bound to. This
/// differs from the requested address when, for example, port 0 is requested.
///
/// Fails with 400 Bad Request if a server with that name already exists or
/// if the server could not be started.
#[endpoint {
    method = POST,
    path = "/servers/{name}",
}]
async fn api_start_server(
    rqctx: RequestContext<MultiServerContext>,
    path: Path<PathName>,
    body: TypedBody<SocketAddr>,
) -> Result<HttpResponseCreated<ServerDescription>, HttpError> {
    let api_context = rqctx.context();
    let name = path.into_inner().name;
    let requested_addr = body.into_inner();
    api_context.shared.start_server(&name, requested_addr).await.map_err(|err| {
        HttpError::for_bad_request(
            Some("StartServerFailed".to_string()),
            format!("failed to start server {name:?}: {err}"),
        )
    })?;
    // `start_server` registered the new server in the shared map, so read back
    // its real local address. If a concurrent DELETE already removed it, fall
    // back to the requested address.
    let bind_addr = api_context
        .shared
        .servers
        .lock()
        .await
        .get(&name)
        .map(|server| server.local_addr())
        .unwrap_or(requested_addr);
    Ok(HttpResponseCreated(ServerDescription { name, bind_addr }))
}
#[endpoint {
method = DELETE,
path = "/servers/{name}",
}]
async fn api_stop_server(
rqctx: RequestContext<MultiServerContext>,
path: Path<PathName>,
) -> Result<HttpResponseDeleted, HttpError> {
let api_context = rqctx.context();
let name = path.into_inner().name;
let mut servers = api_context.shared.servers.lock().await;
let server = servers.remove(&name).ok_or_else(|| {
HttpError::for_bad_request(
Some("InvalidServerName".to_string()),
format!("no server named {name:?}"),
)
})?;
tokio::spawn(server.close());
Ok(HttpResponseDeleted())
}