use dropshot::{ConfigLogging, ConfigLoggingLevel, ServerBuilder};
mod api {
    use dropshot::{
        HttpError, HttpResponseUpdatedNoContent, RequestContext, TypedBody,
        WebsocketConnection,
    };
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};

    // Interface of the counter service. The implementation lives elsewhere
    // (see `imp::CounterImpl`). Plain `//` comments are used on the items the
    // dropshot macros process: `///` doc comments may be picked up as OpenAPI
    // descriptions, which would alter the spec printed at startup.
    #[dropshot::api_description]
    pub(crate) trait CounterApi {
        // Server-wide state made available to handlers through `RequestContext`.
        type Context;

        // PUT /counter: accepts a `CounterValue` body and responds with no
        // content on success.
        #[endpoint { method = PUT, path = "/counter" }]
        async fn put_counter(
            rqctx: RequestContext<Self::Context>,
            update: TypedBody<CounterValue>,
        ) -> Result<HttpResponseUpdatedNoContent, HttpError>;

        // WebSocket channel at /ws; the handler receives the upgraded
        // connection.
        #[channel { protocol = WEBSOCKETS, path = "/ws" }]
        async fn get_counter_ws(
            rqctx: RequestContext<Self::Context>,
            upgraded: WebsocketConnection,
        ) -> dropshot::WebsocketChannelResult;
    }

    // JSON request body for `put_counter`: a single byte-sized counter value.
    #[derive(Deserialize, Serialize, JsonSchema)]
    pub(crate) struct CounterValue {
        pub(crate) counter: u8,
    }

    /// Builds the OpenAPI document for `CounterApi` from the macro-generated
    /// stub description and renders it as pretty-printed JSON.
    ///
    /// # Panics
    ///
    /// Panics if the stub description cannot be built, or if the document
    /// cannot be converted to / serialized as JSON.
    pub(crate) fn generate_openapi_spec() -> String {
        let stub = counter_api_mod::stub_api_description().unwrap();
        let document = stub
            .openapi("Counter Server", semver::Version::new(1, 0, 0))
            .json()
            .unwrap();
        serde_json::to_string_pretty(&document).unwrap()
    }
}
mod imp {
    use std::sync::atomic::{AtomicU8, Ordering};
    use dropshot::{
        HttpError, HttpResponseUpdatedNoContent, RequestContext, TypedBody,
        WebsocketConnection,
    };
    use futures::SinkExt;
    use tokio_tungstenite::tungstenite::{protocol::Role, Message};
    use crate::api::{CounterApi, CounterValue};

    /// Server state shared by all handlers: one atomically updated byte.
    pub(crate) struct AtomicCounter {
        counter: AtomicU8,
    }

    impl AtomicCounter {
        /// Creates a counter whose initial value is zero.
        pub(crate) fn new() -> AtomicCounter {
            let counter = AtomicU8::new(0);
            AtomicCounter { counter }
        }
    }

    /// Uninhabited marker type that carries the `CounterApi` implementation;
    /// it is only ever used as a type parameter, never instantiated.
    pub(crate) enum CounterImpl {}

    impl CounterApi for CounterImpl {
        type Context = AtomicCounter;

        /// Stores the requested value in the shared counter.
        ///
        /// # Errors
        ///
        /// Returns a 400-class `HttpError` (code `BadInput`) when the
        /// requested value is exactly 10; the counter is left unchanged.
        async fn put_counter(
            rqctx: RequestContext<Self::Context>,
            update: TypedBody<CounterValue>,
        ) -> Result<HttpResponseUpdatedNoContent, HttpError> {
            let state = rqctx.context();
            let CounterValue { counter: requested } = update.into_inner();

            // 10 is rejected on purpose to demonstrate an error response.
            if requested == 10 {
                return Err(HttpError::for_bad_request(
                    Some(String::from("BadInput")),
                    format!("do not like the number {}", requested),
                ));
            }

            state.counter.store(requested, Ordering::SeqCst);
            Ok(HttpResponseUpdatedNoContent())
        }

        /// Streams successive counter values as single-byte binary frames,
        /// starting from the counter's value at connection time.
        ///
        /// The stream ends when a send fails; that case is reported as
        /// success (`Ok(())`), and the send error itself is discarded.
        async fn get_counter_ws(
            rqctx: RequestContext<Self::Context>,
            upgraded: WebsocketConnection,
        ) -> dropshot::WebsocketChannelResult {
            let raw = upgraded.into_inner();
            let mut socket = tokio_tungstenite::WebSocketStream::from_raw_socket(
                raw,
                Role::Server,
                None,
            )
            .await;

            // The starting value is read once, after the socket is set up;
            // later PUTs do not affect an already-open stream.
            let mut next = rqctx.context().counter.load(Ordering::Relaxed);

            // NOTE(review): there is no pacing and no upper bound here, so each
            // connection sends frames as fast as the transport accepts them
            // until a send fails (presumably when the client goes away).
            loop {
                if socket.send(Message::Binary(vec![next])).await.is_err() {
                    break;
                }
                next = next.wrapping_add(1);
            }
            Ok(())
        }
    }
}
/// Prints the OpenAPI document for `CounterApi`, then serves the
/// `imp::CounterImpl` implementation until the server finishes.
///
/// # Errors
///
/// Returns a human-readable message if the logger, the API description, or
/// the HTTP server cannot be created, or if the running server itself
/// completes with an error.
///
/// # Panics
///
/// `api::generate_openapi_spec` panics if the spec cannot be produced.
#[tokio::main]
async fn main() -> Result<(), String> {
    let config_logging =
        ConfigLogging::StderrTerminal { level: ConfigLoggingLevel::Info };
    let log = config_logging
        .to_logger("example-server-trait-websocket")
        .map_err(|error| format!("failed to create logger: {}", error))?;

    println!("OpenAPI spec:");
    println!("{}", api::generate_openapi_spec());

    // Report endpoint-registration failures through the same
    // `Result<(), String>` path as the other startup errors instead of
    // panicking.
    let my_server = api::counter_api_mod::api_description::<imp::CounterImpl>()
        .map_err(|error| {
            format!("failed to create API description: {}", error)
        })?;
    let server = ServerBuilder::new(my_server, imp::AtomicCounter::new(), log)
        .start()
        .map_err(|error| format!("failed to create server: {}", error))?;
    server.await
}