use dropshot::{
endpoint, ApiDescription, Body, ConfigDropshot, ConfigLogging, ConfigLoggingLevel,
EmptyScanParams, HttpError, HttpResponseOk, HttpResponseUpdatedNoContent, HttpServerStarter,
PaginationParams, Path, Query, RequestContext, ResultsPage, TypedBody,
};
use futures::StreamExt;
use http::Response;
use openapiv3::OpenAPI;
use progenitor_middleware_impl::{space_out_items, GenerationSettings, Generator, InterfaceStyle};
use schemars::JsonSchema;
use serde::Deserialize;
use std::{
net::{Ipv4Addr, SocketAddr},
str::from_utf8,
sync::{Arc, Mutex},
};
fn generate_formatted(generator: &mut Generator, spec: &OpenAPI) -> String {
let content = generator.generate_tokens(spec).unwrap();
let rustfmt_config = rustfmt_wrapper::config::Config {
normalize_doc_attributes: Some(true),
wrap_comments: Some(true),
..Default::default()
};
space_out_items(rustfmt_wrapper::rustfmt_config(rustfmt_config, content).unwrap()).unwrap()
}
#[allow(dead_code)]
#[derive(Deserialize, JsonSchema)]
struct CursedPath {
#[serde(rename = "ref")]
reef: String,
#[serde(rename = "type")]
tripe: String,
#[serde(rename = "trait")]
trade: String,
}
#[allow(dead_code)]
#[derive(Deserialize, JsonSchema)]
struct CursedQuery {
#[serde(rename = "if")]
iffy: String,
#[serde(rename = "in")]
inn: String,
#[serde(rename = "use")]
youse: String,
}
#[endpoint {
method = GET,
path = "/{ref}/{type}/{trait}",
}]
async fn renamed_parameters(
_rqctx: RequestContext<()>,
_path: Path<CursedPath>,
_query: Query<CursedQuery>,
) -> Result<HttpResponseUpdatedNoContent, HttpError> {
unreachable!();
}
#[test]
fn test_renamed_parameters() {
let mut api = ApiDescription::new();
api.register(renamed_parameters).unwrap();
let mut out = Vec::new();
api.openapi("pagination-demo", semver::Version::new(9000, 0, 0))
.write(&mut out)
.unwrap();
let out = from_utf8(&out).unwrap();
let spec = serde_json::from_str::<OpenAPI>(out).unwrap();
let mut generator = Generator::default();
let output = generate_formatted(&mut generator, &spec);
expectorate::assert_contents(
format!("tests/output/src/{}.rs", "test_renamed_parameters"),
&output,
)
}
#[endpoint {
method = GET,
path = "/",
}]
async fn freeform_response(_rqctx: RequestContext<()>) -> Result<Response<Body>, HttpError> {
unreachable!();
}
#[test]
fn test_freeform_response() {
let mut api = ApiDescription::new();
api.register(freeform_response).unwrap();
let mut out = Vec::new();
api.openapi("pagination-demo", semver::Version::new(9000, 0, 0))
.write(&mut out)
.unwrap();
let out = from_utf8(&out).unwrap();
let spec = serde_json::from_str::<OpenAPI>(out).unwrap();
let mut generator = Generator::default();
let output = generate_formatted(&mut generator, &spec);
expectorate::assert_contents(
format!("tests/output/src/{}.rs", "test_freeform_response"),
&output,
)
}
#[derive(Deserialize, JsonSchema)]
#[allow(dead_code)]
struct BodyWithDefaults {
s: String,
#[serde(default)]
yes: bool,
#[serde(default = "forty_two", rename = "forty-two")]
forty_two: u32,
#[serde(default = "yes_yes")]
something: Option<bool>,
}
fn forty_two() -> u32 {
42
}
fn yes_yes() -> Option<bool> {
Some(true)
}
#[endpoint {
method = POST,
path = "/",
}]
async fn default_params(
_rqctx: RequestContext<()>,
_body: TypedBody<BodyWithDefaults>,
) -> Result<Response<Body>, HttpError> {
unreachable!();
}
#[test]
fn test_default_params() {
let mut api = ApiDescription::new();
api.register(default_params).unwrap();
let mut out = Vec::new();
api.openapi("pagination-demo", semver::Version::new(9000, 0, 0))
.write(&mut out)
.unwrap();
let out = from_utf8(&out).unwrap();
let spec = serde_json::from_str::<OpenAPI>(out).unwrap();
let mut generator = Generator::default();
let output = generate_formatted(&mut generator, &spec);
expectorate::assert_contents(
format!("tests/output/src/{}.rs", "test_default_params_positional"),
&output,
);
let mut generator =
Generator::new(GenerationSettings::default().with_interface(InterfaceStyle::Builder));
let output = generate_formatted(&mut generator, &spec);
expectorate::assert_contents(
format!("tests/output/src/{}.rs", "test_default_params_builder"),
&output,
);
}
#[derive(Debug)]
struct PaginatedU32sContext {
all_values: std::ops::Range<u32>,
page_pairs: Mutex<Vec<(usize, usize)>>,
}
#[endpoint {
method = GET,
path = "/",
}]
async fn paginated_u32s(
rqctx: RequestContext<Arc<PaginatedU32sContext>>,
query_params: Query<PaginationParams<EmptyScanParams, u32>>,
) -> Result<HttpResponseOk<ResultsPage<u32>>, HttpError> {
let ctx = rqctx.context();
let page_params = query_params.into_inner();
let limit = usize::try_from(
rqctx
.page_limit(&page_params)
.expect("invalid page limit")
.get(),
)
.expect("non-usize limit");
let offset = match page_params.page {
dropshot::WhichPage::First(EmptyScanParams {}) => 0,
dropshot::WhichPage::Next(offset) => usize::try_from(offset + 1).expect("non-usize offset"),
};
ctx.page_pairs.lock().unwrap().push((offset, limit));
let values = ctx.all_values.clone().skip(offset).take(limit).collect();
let result = ResultsPage::new(values, &(), |&x, &()| x).expect("bad results page");
Ok(HttpResponseOk(result))
}
#[tokio::test]
async fn test_stream_pagination() {
const TEST_NAME: &str = "test_stream_pagination";
let mut api = ApiDescription::new();
api.register(paginated_u32s).unwrap();
let mut out = Vec::new();
api.openapi(TEST_NAME, semver::Version::new(1, 0, 0))
.write(&mut out)
.unwrap();
let out = from_utf8(&out).unwrap();
let spec = serde_json::from_str::<OpenAPI>(out).unwrap();
let mut generator =
Generator::new(GenerationSettings::new().with_interface(InterfaceStyle::Positional));
let output = generate_formatted(&mut generator, &spec);
expectorate::assert_contents(
format!("tests/output/src/{TEST_NAME}_positional.rs"),
&output,
);
let mut generator =
Generator::new(GenerationSettings::new().with_interface(InterfaceStyle::Builder));
let output = generate_formatted(&mut generator, &spec);
expectorate::assert_contents(format!("tests/output/src/{TEST_NAME}_builder.rs"), &output);
let config_dropshot = ConfigDropshot {
bind_address: SocketAddr::from((Ipv4Addr::LOCALHOST, 0)),
..Default::default()
};
let config_logging = ConfigLogging::StderrTerminal {
level: ConfigLoggingLevel::Debug,
};
let log = config_logging
.to_logger(TEST_NAME)
.expect("failed to create logger");
let server_ctx = Arc::new(PaginatedU32sContext {
all_values: 0..35,
page_pairs: Mutex::default(),
});
let server = HttpServerStarter::new(&config_dropshot, api, Arc::clone(&server_ctx), &log)
.expect("failed to create server")
.start();
let server_addr = format!("http://{}", server.local_addr());
#[allow(dead_code)]
mod gen_client_positional {
include!("output/src/test_stream_pagination_positional.rs");
}
let client = gen_client_positional::Client::new(&server_addr);
let page_limit = 10.try_into().unwrap();
let mut stream = client.paginated_u32s_stream(Some(page_limit));
let mut all_values = Vec::new();
while let Some(result) = stream.next().await {
match result {
Ok(value) => {
all_values.push(value);
}
Err(err) => {
panic!("unexpected error: {err}");
}
}
}
let expected_values = (0..35).collect::<Vec<_>>();
assert_eq!(expected_values, all_values);
let expected_pages = vec![(0, 10), (10, 10), (20, 10), (30, 10), (35, 10)];
assert_eq!(expected_pages, *server_ctx.page_pairs.lock().unwrap());
server_ctx.page_pairs.lock().unwrap().clear();
#[allow(dead_code, unused_imports)]
mod gen_client_builder {
include!("output/src/test_stream_pagination_builder.rs");
}
let client = gen_client_builder::Client::new(&server_addr);
let mut stream = client.paginated_u32s().limit(page_limit).stream();
let mut all_values = Vec::new();
while let Some(result) = stream.next().await {
match result {
Ok(value) => {
all_values.push(value);
}
Err(err) => {
panic!("unexpected error: {err}");
}
}
}
let expected_values = (0..35).collect::<Vec<_>>();
assert_eq!(expected_values, all_values);
let expected_pages = vec![(0, 10), (10, 10), (20, 10), (30, 10), (35, 10)];
assert_eq!(expected_pages, *server_ctx.page_pairs.lock().unwrap());
server.close().await.expect("failed to close server");
}