use std::collections::HashMap;
use heck::ToSnakeCase;
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned, ToTokens};
use serde::Deserialize;
use serde_tokenstream::{
from_tokenstream, from_tokenstream_spanned, ParseWrapper,
};
use syn::{parse_quote, parse_quote_spanned, spanned::Spanned, Error};
use crate::{
channel::ChannelParams,
doc::{string_to_doc_attrs, ExtractedDoc},
endpoint::EndpointParams,
error_store::{ErrorSink, ErrorStore},
metadata::{
ApiEndpointKind, ChannelMetadata, EndpointMetadata,
ValidatedChannelMetadata, ValidatedEndpointMetadata,
},
params::RqctxKind,
syn_parsing::{
ItemTraitPartParsed, TraitItemFnForSignature, TraitItemPartParsed,
UnparsedBlock,
},
util::{get_crate, MacroKind},
};
pub(crate) fn do_trait(
attr: proc_macro2::TokenStream,
item: proc_macro2::TokenStream,
) -> (proc_macro2::TokenStream, Vec<Error>) {
let mut error_store = ErrorStore::new();
let errors = error_store.sink();
let api_metadata = match from_tokenstream::<ApiMetadata>(&attr) {
Ok(m) => Some(m),
Err(e) => {
errors.push(e);
None
}
};
let item_trait = match syn::parse2::<ItemTraitPartParsed>(item) {
Ok(item_trait) => Some(item_trait),
Err(e) => {
errors.push(e);
None
}
};
let output = match (api_metadata, item_trait.as_ref()) {
(Some(api_metadata), Some(item_trait)) => {
let parser = ApiParser::new(api_metadata, &item_trait, errors);
parser.to_output()
}
(None, Some(item_trait)) => {
let parser = ApiParser::invalid_no_metadata(&item_trait, &errors);
parser.to_output()
}
(_, None) => {
ApiOutput {
output: quote! {},
context: "Self::Context".to_string(),
has_endpoint_param_errors: false,
has_channel_param_errors: false,
}
}
};
let mut errors = error_store.into_inner();
if output.has_endpoint_param_errors {
let item_trait = item_trait
.as_ref()
.expect("has_endpoint_param_errors is true => item_fn is Some");
errors.insert(
0,
Error::new_spanned(
&item_trait.ident,
crate::endpoint::usage_str(&output.context),
),
);
}
if output.has_channel_param_errors {
let item_trait = item_trait
.as_ref()
.expect("has_channel_param_errors is true => item_fn is Some");
errors.insert(
0,
Error::new_spanned(
&item_trait.ident,
crate::channel::usage_str(&output.context),
),
);
}
(output.output, errors)
}
/// Arguments accepted by the API trait macro, deserialized from the
/// attribute's token stream.
#[derive(Deserialize, Debug)]
struct ApiMetadata {
/// Name of the associated type that holds the server context. When absent,
/// `DEFAULT_CONTEXT_NAME` is used (see `context_name`).
#[serde(default)]
context: Option<String>,
/// Name of the generated support module. When absent, a name is derived
/// from the trait identifier (see `module_name`).
module: Option<String>,
/// Optional tag configuration forwarded to the generated API description.
tag_config: Option<ApiTagConfig>,
/// Optional override handed to `get_crate` to pick the path used for the
/// dropshot crate in generated code.
_dropshot_crate: Option<String>,
}
impl ApiMetadata {
    /// Context type name used when the `context` argument is not supplied.
    const DEFAULT_CONTEXT_NAME: &'static str = "Context";

    /// Returns the tokens naming the dropshot crate, honoring the
    /// `_dropshot_crate` override when one was given.
    fn dropshot_crate(&self) -> TokenStream {
        let crate_override = self._dropshot_crate.as_deref();
        get_crate(crate_override)
    }

    /// Returns the configured context type name, or the default one.
    fn context_name(&self) -> &str {
        match &self.context {
            Some(name) => name.as_str(),
            None => Self::DEFAULT_CONTEXT_NAME,
        }
    }

    /// Returns the support module name: the configured one if present,
    /// otherwise the snake-cased trait identifier with a `_mod` suffix.
    fn module_name(&self, trait_ident: &syn::Ident) -> String {
        match &self.module {
            Some(name) => name.clone(),
            None => {
                let mut derived = trait_ident.to_string().to_snake_case();
                derived.push_str("_mod");
                derived
            }
        }
    }
}
/// Tag configuration supplied through the `tag_config` macro argument.
#[derive(Deserialize, Debug)]
struct ApiTagConfig {
/// Emitted verbatim as the `allow_other_tags` field of the generated
/// `TagConfig`; defaults to `false` when omitted.
#[serde(default)]
allow_other_tags: bool,
/// Expression emitted as the generated `TagConfig`'s `policy`. When omitted,
/// the generated code uses `EndpointTagPolicy::Any`.
#[serde(default)]
policy: Option<ParseWrapper<syn::Expr>>,
/// Per-tag details, keyed by tag name.
tags: HashMap<String, ApiTagDetails>,
}
/// Details for a single tag inside [`ApiTagConfig`].
#[derive(Deserialize, Debug)]
struct ApiTagDetails {
/// Optional description emitted into the generated `TagDetails`.
#[serde(default)]
description: Option<String>,
/// Optional external documentation emitted into the generated `TagDetails`.
#[serde(default)]
external_docs: Option<ApiTagExternalDocs>,
}
/// External documentation reference attached to a tag.
#[derive(Deserialize, Debug)]
struct ApiTagExternalDocs {
/// Optional description emitted into the generated `TagExternalDocs`.
#[serde(default)]
description: Option<String>,
/// URL emitted into the generated `TagExternalDocs`; required.
url: String,
}
/// Parsed view of an API trait, ready to be turned into macro output.
struct ApiParser<'ast> {
/// Tokens naming the dropshot crate in generated code.
dropshot: TokenStream,
/// The trait being processed, along with whether it passed validation.
item_trait: ApiItemTrait<'ast>,
/// The trait's items, excluding the context type (which is tracked in
/// `context_item` instead).
items: Vec<ApiItem<'ast>>,
/// Tag configuration taken from the macro arguments, if any.
tag_config: Option<ApiTagConfig>,
/// The associated type that holds the server context.
context_item: ContextItem<'ast>,
/// Identifier of the support module; `None` when the macro arguments could
/// not be parsed.
module_ident: Option<syn::Ident>,
}
/// Attribute name recognized on trait methods as marking an endpoint.
const ENDPOINT_IDENT: &str = "endpoint";
/// Attribute name recognized on trait methods as marking a channel.
const CHANNEL_IDENT: &str = "channel";
impl<'ast> ApiParser<'ast> {
    /// Builds a parser from successfully parsed metadata and the trait.
    ///
    /// Validation errors are pushed to `errors` in this order: trait-level
    /// checks, context type checks (or the "missing context type" error), and
    /// then per-item checks in declaration order.
    fn new(
        metadata: ApiMetadata,
        item_trait: &'ast ItemTraitPartParsed,
        errors: ErrorSink<'_, Error>,
    ) -> Self {
        let dropshot = metadata.dropshot_crate();
        let mut items = Vec::with_capacity(item_trait.items.len());
        let trait_ident = &item_trait.ident;
        let item_trait = ApiItemTrait::new(item_trait, &errors);

        // Locate the first associated type whose name matches the configured
        // context name; report an error if there is none.
        let context_name = metadata.context_name();
        let context_item = item_trait
            .item
            .items
            .iter()
            .find_map(|item| match item {
                TraitItemPartParsed::Other(syn::TraitItem::Type(ty))
                    if ty.ident == context_name =>
                {
                    Some(ContextItem::new(ty, &errors))
                }
                _ => None,
            })
            .unwrap_or_else(|| {
                ContextItem::new_missing(context_name, trait_ident, &errors)
            });

        let module_ident =
            format_ident!("{}", metadata.module_name(trait_ident));

        for item in &item_trait.item.items {
            let other = match item {
                TraitItemPartParsed::Fn(f) => {
                    items.push(ApiItem::Fn(ApiFnItem::new(
                        &dropshot,
                        f,
                        trait_ident,
                        context_item.ident(),
                        &errors,
                    )));
                    continue;
                }
                TraitItemPartParsed::Other(other) => other,
            };

            // Non-function items must not carry endpoint/channel attributes.
            // The context type is dropped here because it is re-emitted
            // separately with an extra bound.
            let keep = match other {
                syn::TraitItem::Const(c) => {
                    check_endpoint_or_channel_on_non_fn(
                        "const",
                        &c.ident.to_string(),
                        &c.attrs,
                        &errors,
                    );
                    true
                }
                syn::TraitItem::Fn(_) => {
                    unreachable!(
                        "function items should have been handled above"
                    )
                }
                syn::TraitItem::Type(t) => {
                    check_endpoint_or_channel_on_non_fn(
                        "type",
                        &t.ident.to_string(),
                        &t.attrs,
                        &errors,
                    );
                    t.ident != context_name
                }
                syn::TraitItem::Macro(m) => {
                    check_endpoint_or_channel_on_non_fn(
                        "macro",
                        &m.mac.path.to_token_stream().to_string(),
                        &m.attrs,
                        &errors,
                    );
                    true
                }
                _ => true,
            };

            if keep {
                items.push(ApiItem::Other(item));
            }
        }

        Self {
            dropshot,
            item_trait,
            items,
            tag_config: metadata.tag_config,
            context_item,
            module_ident: Some(module_ident),
        }
    }

    /// Builds a parser for the case where the macro arguments failed to
    /// parse. Every trait item is passed through as-is and no support module
    /// will be generated.
    fn invalid_no_metadata(
        item_trait: &'ast ItemTraitPartParsed,
        errors: &ErrorSink<'_, Error>,
    ) -> Self {
        let item_trait = ApiItemTrait::new(item_trait, errors);
        let items: Vec<_> =
            item_trait.item.items.iter().map(ApiItem::Other).collect();

        Self {
            dropshot: get_crate(None),
            item_trait,
            items,
            tag_config: None,
            context_item: ContextItem::new_invalid_metadata(),
            module_ident: None,
        }
    }

    /// Produces the rewritten trait, the support module, and the flags that
    /// tell the caller whether usage messages should be emitted.
    fn to_output(&self) -> ApiOutput {
        let item_trait = self.item_trait.item;

        // The output trait additionally requires `'static`.
        let mut supertraits = item_trait.supertraits.clone();
        supertraits.push(parse_quote!('static));

        // Non-public traits get `#[allow(dead_code)]`.
        let mut attrs = item_trait.attrs.clone();
        if !matches!(item_trait.vis, syn::Visibility::Public(_)) {
            attrs.push(parse_quote!(#[allow(dead_code)]));
        }

        // The context type (if present) comes first, followed by the rest of
        // the items.
        let out_items = self
            .make_context_trait_item()
            .map(TraitItemPartParsed::Other)
            .into_iter()
            .chain(self.items.iter().map(|item| item.to_out_trait_item()));

        let out_trait = ItemTraitPartParsed {
            attrs,
            supertraits,
            items: out_items.collect(),
            ..item_trait.clone()
        };

        let module = self.make_module();
        let output = quote! {
            #out_trait

            #module
        };

        let has_endpoint_param_errors = self.items.iter().any(|item| {
            matches!(
                item,
                ApiItem::Fn(ApiFnItem::Invalid {
                    kind: InvalidApiItemKind::Endpoint(summary),
                    ..
                }) if summary.has_param_errors
            )
        });
        let has_channel_param_errors = self.items.iter().any(|item| {
            matches!(
                item,
                ApiItem::Fn(ApiFnItem::Invalid {
                    kind: InvalidApiItemKind::Channel(summary),
                    ..
                }) if summary.has_param_errors
            )
        });

        ApiOutput {
            output,
            context: format!("Self::{}", self.context_item.ident()),
            has_endpoint_param_errors,
            has_channel_param_errors,
        }
    }

    /// Returns the context associated type with a `ServerContext` bound
    /// appended, or `None` when the trait declares no context type.
    fn make_context_trait_item(&self) -> Option<syn::TraitItem> {
        let original = self.context_item.original_item()?;
        let dropshot = &self.dropshot;

        let mut out_item = original.clone();
        out_item.bounds.push(parse_quote!(#dropshot::ServerContext));
        Some(syn::TraitItem::Type(out_item))
    }

    /// Generates the support module.
    ///
    /// Emits nothing when no module name is known, an empty module with an
    /// "invalid" doc comment when the trait or context type failed
    /// validation, and the full module otherwise.
    fn make_module(&self) -> TokenStream {
        let module_ident = match self.module_ident.as_ref() {
            Some(ident) => ident,
            None => return quote! {},
        };

        let valid_trait = self.item_trait.valid_item();
        let valid_context = self.context_item.valid_item();

        match (valid_trait, valid_context) {
            (Some(item_trait), Some(context_item)) => {
                SupportModuleGenerator {
                    dropshot: &self.dropshot,
                    module_ident,
                    item_trait,
                    context_item,
                    items: &self.items,
                    tag_config: self.tag_config.as_ref(),
                }
                .to_token_stream()
            }
            _ => {
                let doc = ModuleDocComments::generate_invalid(
                    &self.item_trait.item.ident,
                );
                let outer = doc.outer();
                let vis = &self.item_trait.item.vis;
                quote! {
                    #outer
                    #vis mod #module_ident {}
                }
            }
        }
    }
}
/// Result of processing an API trait.
struct ApiOutput {
/// The generated tokens (rewritten trait plus support module).
output: TokenStream,
/// Path to the context type in `Self::<name>` form, passed to the usage
/// message helpers.
context: String,
/// Whether any endpoint was rejected because of parameter errors.
has_endpoint_param_errors: bool,
/// Whether any channel was rejected because of parameter errors.
has_channel_param_errors: bool,
}
/// A reference to the parsed trait together with its validation result.
#[derive(Clone, Copy)]
struct ApiItemTrait<'ast> {
/// The trait as parsed from the macro input.
item: &'ast ItemTraitPartParsed,
/// `true` when none of the trait-level checks in `new` reported an error.
is_valid: bool,
}
impl<'ast> ApiItemTrait<'ast> {
fn new(
item: &'ast ItemTraitPartParsed,
errors: &ErrorSink<'_, Error>,
) -> Self {
let trait_ident = &item.ident;
let errors = errors.new();
if item.unsafety.is_some() {
errors.push(Error::new_spanned(
&item.unsafety,
format!(
"API trait `{trait_ident}` must not be marked as `unsafe`"
),
));
}
if item.auto_token.is_some() {
errors.push(Error::new_spanned(
&item.auto_token,
format!("API trait `{trait_ident}` must not be an auto trait"),
));
}
if !item.generics.params.is_empty() {
errors.push(Error::new_spanned(
&item.generics,
format!("API trait `{trait_ident}` must not have generics"),
));
}
if let Some(where_clause) = &item.generics.where_clause {
if !where_clause.predicates.is_empty() {
errors.push(Error::new_spanned(
where_clause,
format!(
"API trait `{trait_ident}` must not have a where clause"
),
));
}
}
Self { item, is_valid: !errors.has_errors() }
}
fn valid_item(&self) -> Option<&'ast ItemTraitPartParsed> {
self.is_valid.then_some(self.item)
}
}
/// The associated type that names the server context, in one of three states.
#[derive(Clone, Debug)]
enum ContextItem<'ast> {
/// The context type was found and passed validation.
Valid(&'ast syn::TraitItemType),
/// The context type was found but failed validation.
Invalid(&'ast syn::TraitItemType),
/// No context type was found; `ident` is the name that was looked for (or
/// the default name when the macro arguments were invalid).
Missing {
ident: syn::Ident,
},
}
impl<'ast> ContextItem<'ast> {
fn new(
ty: &'ast syn::TraitItemType,
errors: &ErrorSink<'_, Error>,
) -> Self {
let errors = errors.new();
if !ty.generics.params.is_empty() {
errors.push(Error::new_spanned(
&ty.generics,
format!("context type `{}` must not have generics", ty.ident),
));
}
if errors.has_errors() {
Self::Invalid(ty)
} else {
Self::Valid(ty)
}
}
fn new_missing(
context_name: &str,
trait_ident: &syn::Ident,
errors: &ErrorSink<'_, Error>,
) -> Self {
errors.push(Error::new_spanned(
&trait_ident,
format!(
"API trait `{trait_ident}` does not have associated type \
`{context_name}`\n\
(this type specifies the shared context for endpoints)",
),
));
Self::Missing { ident: format_ident!("{context_name}") }
}
fn new_invalid_metadata() -> Self {
Self::Missing {
ident: format_ident!("{}", ApiMetadata::DEFAULT_CONTEXT_NAME),
}
}
fn ident(&self) -> &syn::Ident {
match self {
Self::Valid(ty) => &ty.ident,
Self::Invalid(ty) => &ty.ident,
Self::Missing { ident } => ident,
}
}
fn original_item(&self) -> Option<&'ast syn::TraitItemType> {
match self {
Self::Valid(ty) | Self::Invalid(ty) => Some(ty),
Self::Missing { .. } => None,
}
}
fn valid_item(&self) -> Option<&'ast syn::TraitItemType> {
match self {
Self::Valid(ty) => Some(ty),
Self::Invalid(_) | Self::Missing { .. } => None,
}
}
}
/// Inputs needed to generate the support module for a validated API trait.
struct SupportModuleGenerator<'ast> {
/// Tokens naming the dropshot crate in generated code.
dropshot: &'ast TokenStream,
/// Name of the module to generate.
module_ident: &'ast syn::Ident,
/// The validated trait.
item_trait: &'ast ItemTraitPartParsed,
/// The validated context associated type.
context_item: &'ast syn::TraitItemType,
/// The trait's items; endpoints and channels among them are registered in
/// the generated API description.
items: &'ast [ApiItem<'ast>],
/// Optional tag configuration to apply to the generated API description.
tag_config: Option<&'ast ApiTagConfig>,
}
impl SupportModuleGenerator<'_> {
    /// Generates the `api_description` function, which is generic over an
    /// implementation of the trait. `doc` is emitted as its documentation.
    fn make_api_description(&self, doc: TokenStream) -> TokenStream {
        let dropshot = self.dropshot;
        let trait_ident = &self.item_trait.ident;
        let context_ident = &self.context_item.ident;
        let body = self.make_api_factory_body(FactoryKind::Regular);

        quote! {
            #doc
            #[automatically_derived]
            pub fn api_description<ServerImpl: #trait_ident>() -> ::std::result::Result<
                #dropshot::ApiDescription<<ServerImpl as #trait_ident>::#context_ident>,
                #dropshot::ApiDescriptionBuildErrors,
            > {
                #body
            }
        }
    }

    /// Generates the `stub_api_description` function, which needs no
    /// implementation of the trait. `doc` is emitted as its documentation.
    fn make_stub_api_description(&self, doc: TokenStream) -> TokenStream {
        let dropshot = self.dropshot;
        let body = self.make_api_factory_body(FactoryKind::Stub);

        quote! {
            #doc
            #[automatically_derived]
            pub fn stub_api_description() -> ::std::result::Result<
                #dropshot::ApiDescription<#dropshot::StubContext>,
                #dropshot::ApiDescriptionBuildErrors,
            > {
                #body
            }
        }
    }

    /// Generates the shared body of the two factory functions.
    ///
    /// If any method of the trait was rejected, the body is a single
    /// `panic!`. Otherwise it creates an `ApiDescription`, registers every
    /// endpoint and channel, and returns either the description or the
    /// accumulated registration errors.
    fn make_api_factory_body(&self, kind: FactoryKind) -> TokenStream {
        let dropshot = self.dropshot;
        let trait_ident = &self.item_trait.ident;

        if self.has_invalid_fn_items() {
            let err_msg = format!(
                "invalid endpoints encountered while parsing API trait `{trait_ident}`"
            );
            return quote! {
                panic!(#err_msg);
            };
        }

        let tag_config = self.make_tag_config();

        let endpoints = self.items.iter().filter_map(|item| match item {
            ApiItem::Fn(ApiFnItem::Endpoint(e)) => {
                Some(e.to_api_endpoint(self.dropshot, kind))
            }
            ApiItem::Fn(ApiFnItem::Channel(c)) => {
                Some(c.to_api_endpoint(self.dropshot, kind))
            }
            ApiItem::Fn(ApiFnItem::Invalid { .. })
            | ApiItem::Fn(ApiFnItem::Unmanaged(_))
            | ApiItem::Other(_) => None,
        });

        quote! {
            let mut dropshot_api = #dropshot::ApiDescription::new()#tag_config;
            let mut dropshot_errors: Vec<#dropshot::ApiDescriptionRegisterError> = Vec::new();

            #(#endpoints)*

            if !dropshot_errors.is_empty() {
                Err(#dropshot::ApiDescriptionBuildErrors::new(dropshot_errors))
            } else {
                Ok(dropshot_api)
            }
        }
    }

    /// Generates the `.tag_config(...)` call appended to
    /// `ApiDescription::new()`, or `None` when no tag configuration was
    /// given.
    fn make_tag_config(&self) -> Option<TokenStream> {
        let dropshot = self.dropshot;
        let tag_config = self.tag_config.as_ref()?;

        let allow_other_tags = tag_config.allow_other_tags;
        let policy = tag_config.policy.as_ref().map_or_else(
            || {
                quote! { #dropshot::EndpointTagPolicy::Any }
            },
            |wrapper| wrapper.to_token_stream(),
        );

        // `HashMap` iteration order is unspecified, so sort by tag name to
        // keep the generated token stream identical from one expansion to the
        // next. The runtime map built by the generated code is unaffected by
        // insertion order.
        let mut sorted_tags: Vec<_> = tag_config.tags.iter().collect();
        sorted_tags.sort_unstable_by(|a, b| a.0.cmp(b.0));

        let tags = sorted_tags.into_iter().map(|(tag, details)| {
            let description =
                quote_project_option(details.description.as_deref());
            let external_docs = details.external_docs.as_ref().map(|ed| {
                let description =
                    quote_project_option(ed.description.as_deref());
                let url = &ed.url;
                quote! {
                    #dropshot::TagExternalDocs {
                        description: #description,
                        url: #url.to_string(),
                    }
                }
            });
            let external_docs = quote_project_option(external_docs);

            quote! {
                tags.insert(
                    #tag.to_string(),
                    #dropshot::TagDetails {
                        description: #description,
                        external_docs: #external_docs,
                    }
                );
            }
        });

        Some(quote! {
            .tag_config({
                let mut tags = ::std::collections::HashMap::new();
                #(#tags)*

                #dropshot::TagConfig {
                    allow_other_tags: #allow_other_tags,
                    policy: #policy,
                    tags,
                }
            })
        })
    }

    /// Picks the documentation set for the module: the "invalid" variant if
    /// any method was rejected, the regular one otherwise.
    fn make_doc_comments(&self) -> ModuleDocComments {
        if self.has_invalid_fn_items() {
            ModuleDocComments::generate_invalid(&self.item_trait.ident)
        } else {
            ModuleDocComments::generate(
                self.dropshot,
                &self.item_trait.ident,
                &self.context_item.ident,
                self.module_ident,
            )
        }
    }

    /// Yields the type-check tokens contributed by each channel. Endpoints
    /// and all other items contribute nothing here.
    fn make_type_checks(&self) -> impl Iterator<Item = TokenStream> + '_ {
        self.items.iter().filter_map(|item| match item {
            ApiItem::Fn(ApiFnItem::Endpoint(_)) => None,
            ApiItem::Fn(ApiFnItem::Channel(c)) => {
                Some(c.params.to_trait_type_checks())
            }
            _ => None,
        })
    }

    /// Returns `true` if any method of the trait was rejected.
    fn has_invalid_fn_items(&self) -> bool {
        self.items.iter().any(|item| item.is_invalid())
    }
}
impl ToTokens for SupportModuleGenerator<'_> {
    /// Emits the support module: the type checks, then the stub factory,
    /// then the regular factory, all inside a module that glob-imports its
    /// parent.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let doc_comments = self.make_doc_comments();
        let outer = doc_comments.outer();
        let api_doc = doc_comments.api_description();
        let stub_doc = doc_comments.stub_api_description();

        let api = self.make_api_description(api_doc);
        let stub_api = self.make_stub_api_description(stub_doc);
        let type_checks = self.make_type_checks();

        let vis = &self.item_trait.vis;
        let module_ident = self.module_ident;

        let module = quote! {
            #outer
            #[automatically_derived]
            #vis mod #module_ident {
                use super::*;

                #(#type_checks)*

                #stub_api

                #api
            }
        };
        tokens.extend(module);
    }
}
/// Turns a macro-time `Option` into tokens for a runtime `Option`:
/// `None` becomes `None`, and `Some(t)` becomes `Some(t.into())`.
fn quote_project_option<T: ToTokens>(t: Option<T>) -> TokenStream {
    match t {
        Some(value) => quote! { Some(#value.into()) },
        None => quote! { None },
    }
}
/// Documentation text for the generated support module and its functions.
struct ModuleDocComments {
/// Documentation for the module itself.
outer: String,
/// Documentation for the generated `api_description` function.
api_description: String,
/// Documentation for the generated `stub_api_description` function.
stub_api_description: String,
}
impl ModuleDocComments {
/// Builds the documentation used when the API trait was parsed without
/// rejected methods. The identifiers and the dropshot crate path are
/// interpolated into the text and into its link definitions.
fn generate(
dropshot: &TokenStream,
trait_ident: &syn::Ident,
context_ident: &syn::Ident,
module_ident: &syn::Ident,
) -> ModuleDocComments {
// Module-level documentation.
let outer = format!(
"Support module for the Dropshot API trait \
[`{trait_ident}`]({trait_ident}).",
);
// Documentation for `api_description`, including a usage example.
let api_description = format!(
"Given an implementation of [`{trait_ident}`], generate an API description.
This function accepts a single type argument `ServerImpl`, turning it into a
Dropshot [`ApiDescription`]`<ServerImpl::`[`{context_ident}`]`>`.
The returned `ApiDescription` can then be turned into a Dropshot server that
accepts a concrete `{context_ident}`.
## Example
```rust,ignore
/// A type used to define the concrete implementation for `{trait_ident}`.
///
/// This type is never constructed -- it is just a place to define your
/// implementation of `{trait_ident}`.
enum {trait_ident}Impl {{}}
impl {trait_ident} for {trait_ident}Impl {{
type {context_ident} = /* context type */;
// ... trait methods
}}
#[tokio::main]
async fn main() {{
// Generate the description for `{trait_ident}Impl`.
let description = {module_ident}::api_description::<{trait_ident}Impl>().unwrap();
// Create a value of the concrete context type.
let context = /* some value of type `{trait_ident}Impl::{context_ident}` */;
// Create a Dropshot server from the description.
let log = /* ... */;
let server = dropshot::ServerBuilder::new(description, context, log)
.start()
.unwrap();
// Run the server.
server.await
}}
```
[`ApiDescription`]: {dropshot}::ApiDescription
[`{trait_ident}`]: {trait_ident}
[`{context_ident}`]: {trait_ident}::{context_ident}
",
);
// Documentation for `stub_api_description`, including a usage example.
let stub_api_description = format!(
"Generate a _stub_ API description for [`{trait_ident}`], meant for OpenAPI
generation.
Unlike [`api_description`], this function does not require an implementation
of [`{trait_ident}`] to be available, instead generating handlers that panic.
The return value is of type [`ApiDescription`]`<`[`StubContext`]`>`.
The main use of this function is in cases where [`{trait_ident}`] is defined
in a separate crate from its implementation. The OpenAPI spec can then be
generated directly from the stub API description.
## Example
A function that prints the OpenAPI spec to standard output:
```rust,ignore
fn print_openapi_spec() {{
let stub = {module_ident}::stub_api_description().unwrap();
// Generate OpenAPI spec from `stub`.
let spec = stub.openapi(\"{trait_ident}\", \"0.1.0\");
spec.write(&mut std::io::stdout()).unwrap();
}}
```
[`{trait_ident}`]: {trait_ident}
[`api_description`]: {module_ident}::api_description
[`ApiDescription`]: {dropshot}::ApiDescription
[`StubContext`]: {dropshot}::StubContext
");
ModuleDocComments { outer, api_description, stub_api_description }
}
/// Builds the documentation used when errors were found while parsing the
/// API. The text states that the module is invalid and that the generated
/// functions panic; it refers to the trait by name only, without links.
fn generate_invalid(trait_ident: &syn::Ident) -> Self {
let outer = format!(
"**Invalid**: Support module for the Dropshot API trait `{trait_ident}`.
Errors were encountered while parsing the API.
");
let api_description = format!(
"**Invalid, panics:** Given an implementation of `{trait_ident}`, generate an API description.
Errors were encountered while parsing the API, so this function panics.
");
let stub_api_description = format!(
"**Invalid, panics:** Generate a _stub_ API description for `{trait_ident}`, meant for OpenAPI
generation.
Errors were encountered while parsing the API, so this function panics.
");
ModuleDocComments { outer, api_description, stub_api_description }
}
/// Returns the module-level documentation as doc attribute tokens.
fn outer(&self) -> TokenStream {
string_to_doc_attrs(&self.outer)
}
/// Returns the `api_description` documentation as doc attribute tokens.
fn api_description(&self) -> TokenStream {
string_to_doc_attrs(&self.api_description)
}
/// Returns the `stub_api_description` documentation as doc attribute tokens.
fn stub_api_description(&self) -> TokenStream {
string_to_doc_attrs(&self.stub_api_description)
}
}
/// Which of the two generated factory functions a body is being built for.
#[derive(Clone, Copy, Debug)]
enum FactoryKind {
/// `api_description`, which is generic over an implementation of the trait.
Regular,
/// `stub_api_description`, which needs no implementation of the trait.
Stub,
}
fn check_endpoint_or_channel_on_non_fn(
kind: &str,
name: &str,
attrs: &[syn::Attribute],
errors: &ErrorSink<'_, Error>,
) {
if let Some(attr) = attrs.iter().find(|a| a.path().is_ident(ENDPOINT_IDENT))
{
errors.push(Error::new_spanned(
attr,
format!("{kind} `{name}` marked as endpoint is not a method"),
));
}
if let Some(attr) = attrs.iter().find(|a| a.path().is_ident(CHANNEL_IDENT))
{
errors.push(Error::new_spanned(
attr,
format!("{kind} `{name}` marked as channel is not a method"),
));
}
}
/// A single item of the API trait.
// NOTE(review): the lint is presumably silenced because `Fn` is much larger
// than `Other` -- confirm before boxing either variant.
#[allow(clippy::large_enum_variant)]
enum ApiItem<'ast> {
/// A method, classified as an endpoint, channel, invalid, or unmanaged.
Fn(ApiFnItem<'ast>),
/// Any other trait item, passed through to the output trait.
Other(&'ast TraitItemPartParsed),
}
impl ApiItem<'_> {
    /// Returns `true` if this item is a method that was rejected.
    fn is_invalid(&self) -> bool {
        if let Self::Fn(fn_item) = self {
            matches!(fn_item, ApiFnItem::Invalid { .. })
        } else {
            false
        }
    }

    /// Returns the item as it should appear in the output trait, with the
    /// recognized attributes removed.
    fn to_out_trait_item(&self) -> TraitItemPartParsed {
        match self {
            Self::Fn(fn_item) => {
                let out_fn = fn_item.to_out_trait_item();
                TraitItemPartParsed::Fn(out_fn)
            }
            Self::Other(other) => other.clone_and_strip_recognized_attrs(),
        }
    }
}
/// Classification of a trait method.
// NOTE(review): the lint is presumably silenced because the endpoint and
// channel variants are much larger than the others -- confirm before boxing.
#[allow(clippy::large_enum_variant)]
enum ApiFnItem<'ast> {
/// A method with a single, valid endpoint attribute.
Endpoint(ApiEndpoint<'ast>),
/// A method with a single, valid channel attribute.
Channel(ApiChannel<'ast>),
/// A method whose endpoint/channel annotation was rejected; `kind` records
/// what it was meant to be, when that is known.
Invalid { f: &'ast TraitItemFnForSignature, kind: InvalidApiItemKind },
/// A method with neither attribute, passed through to the output trait.
Unmanaged(&'ast TraitItemFnForSignature),
}
impl<'ast> ApiFnItem<'ast> {
fn new(
dropshot: &TokenStream,
f: &'ast TraitItemFnForSignature,
trait_ident: &'ast syn::Ident,
context_ident: &syn::Ident,
errors: &ErrorSink<'_, Error>,
) -> Self {
let attrs = f
.attrs
.iter()
.filter_map(|attr| {
if attr.path().is_ident(ENDPOINT_IDENT) {
Some(ApiAttr::Endpoint(attr))
} else if attr.path().is_ident(CHANNEL_IDENT) {
Some(ApiAttr::Channel(attr))
} else {
None
}
})
.collect::<Vec<_>>();
match attrs.as_slice() {
[] => {
Self::Unmanaged(f)
}
[ApiAttr::Endpoint(eattr)] => {
match ApiEndpoint::new(
dropshot,
f,
eattr,
trait_ident,
context_ident,
errors,
) {
Ok(endpoint) => Self::Endpoint(endpoint),
Err(summary) => Self::Invalid {
f,
kind: InvalidApiItemKind::Endpoint(summary),
},
}
}
[ApiAttr::Channel(cattr)] => {
match ApiChannel::new(
dropshot,
f,
cattr,
trait_ident,
context_ident,
errors,
) {
Ok(channel) => Self::Channel(channel),
Err(summary) => Self::Invalid {
f,
kind: InvalidApiItemKind::Channel(summary),
},
}
}
[first, rest @ ..] => {
let name = &f.sig.ident;
for attr in rest {
let msg = match (first, attr) {
(ApiAttr::Endpoint(_), ApiAttr::Endpoint(_)) => {
format!("method `{name}` marked as endpoint multiple times")
}
(ApiAttr::Channel(_), ApiAttr::Channel(_)) => {
format!("method `{name}` marked as channel multiple times")
}
_ => {
format!("method `{name}` marked as both endpoint and channel")
}
};
errors.push(Error::new_spanned(attr, msg));
}
Self::Invalid { f, kind: InvalidApiItemKind::Unknown }
}
}
}
fn to_out_trait_item(&self) -> TraitItemFnForSignature {
match self {
Self::Endpoint(e) => e.to_out_trait_item(),
Self::Channel(c) => c.to_out_trait_item(),
Self::Invalid { f, .. } | Self::Unmanaged(f) => {
f.clone_and_strip_recognized_attrs()
}
}
}
}
fn parse_endpoint_metadata(
name_str: &str,
attr: &syn::Attribute,
errors: &ErrorSink<'_, Error>,
) -> Option<ValidatedEndpointMetadata> {
let l = match &attr.meta {
syn::Meta::List(l) => l,
_ => {
errors.push(Error::new_spanned(
&attr,
format!(
"endpoint `{name_str}` must be of the form \
#[endpoint {{ method = GET, path = \"/path\", ... }}]"
),
));
return None;
}
};
match from_tokenstream_spanned::<EndpointMetadata>(
l.delimiter.span(),
&l.tokens,
) {
Ok(m) => m.validate(name_str, attr, MacroKind::Trait, errors),
Err(error) => {
errors.push(Error::new(
error.span(),
format!(
"endpoint `{name_str}` has invalid attributes: {error}"
),
));
return None;
}
}
}
/// A trait method successfully parsed as an endpoint.
struct ApiEndpoint<'ast> {
/// The method as written in the trait.
f: &'ast TraitItemFnForSignature,
/// The endpoint attribute; its span is used for the generated code.
attr: &'ast syn::Attribute,
/// Identifier of the enclosing trait.
trait_ident: &'ast syn::Ident,
/// Validated arguments of the endpoint attribute.
metadata: ValidatedEndpointMetadata,
/// Parameters parsed from the method signature.
params: EndpointParams<'ast>,
}
impl<'ast> ApiEndpoint<'ast> {
fn new(
dropshot: &TokenStream,
f: &'ast TraitItemFnForSignature,
attr: &'ast syn::Attribute,
trait_ident: &'ast syn::Ident,
context_ident: &syn::Ident,
errors: &ErrorSink<'_, Error>,
) -> Result<Self, ApiItemErrorSummary> {
let name_str = f.sig.ident.to_string();
let metadata = parse_endpoint_metadata(&name_str, attr, errors);
let params = EndpointParams::new(
dropshot,
&f.sig,
RqctxKind::Trait { trait_ident, context_ident },
errors,
);
match (metadata, params) {
(Some(metadata), Some(params)) => {
Ok(Self { f, attr, trait_ident, metadata, params })
}
(_, params) => {
Err(ApiItemErrorSummary { has_param_errors: params.is_none() })
}
}
}
fn to_out_trait_item(&self) -> TraitItemFnForSignature {
let mut f = self.f.clone();
transform_signature(&mut f, &self.params.ret_ty);
f
}
fn to_api_endpoint(
&self,
dropshot: &TokenStream,
kind: FactoryKind,
) -> TokenStream {
match kind {
FactoryKind::Regular => {
let name = &self.f.sig.ident;
let trait_ident = self.trait_ident;
let path_to_name = quote_spanned! {self.attr.span()=>
<ServerImpl as #trait_ident>::#name
};
self.to_api_endpoint_impl(
dropshot,
&ApiEndpointKind::Regular(&path_to_name),
)
}
FactoryKind::Stub => {
let extractor_types = self.params.extractor_types().collect();
let ret_ty = self.params.ret_ty;
self.to_api_endpoint_impl(
dropshot,
&ApiEndpointKind::Stub {
attr: &self.attr,
extractor_types,
ret_ty,
},
)
}
}
}
fn to_api_endpoint_impl(
&self,
dropshot: &TokenStream,
kind: &ApiEndpointKind<'_>,
) -> TokenStream {
let name = &self.f.sig.ident;
let name_str = name.to_string();
let doc = ExtractedDoc::from_attrs(&self.f.attrs);
let endpoint_fn =
self.metadata.to_api_endpoint_fn(dropshot, &name_str, kind, &doc);
let endpoint_name = format_ident!("endpoint_{}", name_str);
quote_spanned! {self.attr.span()=>
{
let #endpoint_name = #endpoint_fn;
if let Err(error) = dropshot_api.register(#endpoint_name) {
dropshot_errors.push(error);
}
}
}
}
}
fn parse_channel_metadata(
name_str: &str,
attr: &syn::Attribute,
errors: &ErrorSink<'_, Error>,
) -> Option<ValidatedChannelMetadata> {
let l = match &attr.meta {
syn::Meta::List(l) => l,
_ => {
errors.push(Error::new_spanned(
&attr,
format!(
"endpoint `{name_str}` must be of the form \
#[channel {{ protocol = WEBSOCKETS, path = \"/path\", ... }}]"
),
));
return None;
}
};
match from_tokenstream_spanned::<ChannelMetadata>(
l.delimiter.span(),
&l.tokens,
) {
Ok(m) => m.validate(name_str, attr, MacroKind::Trait, errors),
Err(error) => {
errors.push(Error::new(
error.span(),
format!(
"endpoint `{name_str}` has invalid attributes: {error}"
),
));
return None;
}
}
}
/// A trait method successfully parsed as a channel.
struct ApiChannel<'ast> {
/// The method as written in the trait.
f: &'ast TraitItemFnForSignature,
/// The channel attribute; its span is used for the generated code.
attr: &'ast syn::Attribute,
/// Identifier of the enclosing trait.
trait_ident: &'ast syn::Ident,
/// Validated arguments of the channel attribute.
metadata: ValidatedChannelMetadata,
/// Parameters parsed from the method signature.
params: ChannelParams<'ast>,
}
impl<'ast> ApiChannel<'ast> {
fn new(
dropshot: &TokenStream,
f: &'ast TraitItemFnForSignature,
attr: &'ast syn::Attribute,
trait_ident: &'ast syn::Ident,
context_ident: &syn::Ident,
errors: &ErrorSink<'_, Error>,
) -> Result<Self, ApiItemErrorSummary> {
let name_str = f.sig.ident.to_string();
let metadata = parse_channel_metadata(&name_str, attr, errors);
let params = ChannelParams::new(
dropshot,
&f.sig,
RqctxKind::Trait { trait_ident, context_ident },
errors,
);
match (metadata, params) {
(Some(metadata), Some(params)) => {
Ok(Self { f, attr, trait_ident, metadata, params })
}
(_, params) => {
Err(ApiItemErrorSummary { has_param_errors: params.is_none() })
}
}
}
fn to_out_trait_item(&self) -> TraitItemFnForSignature {
let mut f = self.f.clone();
transform_signature(&mut f, &self.params.ret_ty);
f
}
fn to_api_endpoint(
&self,
dropshot: &TokenStream,
kind: FactoryKind,
) -> TokenStream {
match kind {
FactoryKind::Regular => {
let adapter_fn =
self.params.to_trait_adapter_fn(self.trait_ident);
let adapter_name = &self.params.adapter_name;
let path_to_name = quote_spanned! {self.attr.span()=>
#adapter_name::<ServerImpl>
};
let endpoint = self.to_api_endpoint_impl(
&dropshot,
&ApiEndpointKind::Regular(&path_to_name),
);
quote_spanned! {self.attr.span()=>
{
#adapter_fn
#endpoint
}
}
}
FactoryKind::Stub => {
let extractor_types = self.params.extractor_types().collect();
let ret_ty = &self.params.endpoint_result_ty;
self.to_api_endpoint_impl(
dropshot,
&ApiEndpointKind::Stub {
attr: &self.attr,
extractor_types,
ret_ty,
},
)
}
}
}
fn to_api_endpoint_impl(
&self,
dropshot: &TokenStream,
kind: &ApiEndpointKind<'_>,
) -> TokenStream {
let name = &self.f.sig.ident;
let name_str = name.to_string();
let doc = ExtractedDoc::from_attrs(&self.f.attrs);
let endpoint_fn =
self.metadata.to_api_endpoint_fn(dropshot, &name_str, kind, &doc);
let endpoint_name = format_ident!("endpoint_{}", name_str);
quote_spanned! {self.attr.span()=>
{
let #endpoint_name = #endpoint_fn;
if let Err(error) = dropshot_api.register(#endpoint_name) {
dropshot_errors.push(error);
}
}
}
}
}
/// Rewrites a method so that it returns a future explicitly.
///
/// The recognized attributes are removed, `async` is dropped from the
/// signature, the return type becomes
/// `impl Future<Output = ret_ty> + Send + 'static`, and a default body (if
/// any) is wrapped in `async move`.
fn transform_signature(f: &mut TraitItemFnForSignature, ret_ty: &syn::Type) {
    f.strip_recognized_attrs();

    // Build the `impl Future<...>` return type, spanned at the original
    // return type.
    let future_bounds = parse_quote_spanned! {ret_ty.span()=>
        ::core::future::Future<Output = #ret_ty> + Send + 'static
    };
    let future_ty = syn::Type::ImplTrait(syn::TypeImplTrait {
        impl_token: Default::default(),
        bounds: future_bounds,
    });

    f.sig.asyncness = None;
    f.sig.output =
        syn::ReturnType::Type(Default::default(), Box::new(future_ty));

    // Wrap the default body, if any, in `async move`.
    f.block = f.block.take().map(|body| {
        let tokens = quote_spanned! {body.span()=> async move #body };
        UnparsedBlock { brace_token: body.brace_token, tokens }
    });
}
/// A recognized attribute found on a trait method.
enum ApiAttr<'ast> {
/// An attribute whose path is `ENDPOINT_IDENT`.
Endpoint(&'ast syn::Attribute),
/// An attribute whose path is `CHANNEL_IDENT`.
Channel(&'ast syn::Attribute),
}
impl ToTokens for ApiAttr<'_> {
    /// Emits the wrapped attribute unchanged, so errors can be spanned at it.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let (ApiAttr::Endpoint(attr) | ApiAttr::Channel(attr)) = self;
        attr.to_tokens(tokens);
    }
}
/// What a rejected method was meant to be.
#[derive(Clone, Copy, Debug)]
enum InvalidApiItemKind {
/// The method had a single endpoint attribute but failed to parse.
Endpoint(ApiItemErrorSummary),
/// The method had a single channel attribute but failed to parse.
Channel(ApiItemErrorSummary),
/// The method had more than one endpoint/channel attribute.
Unknown,
}
/// Summary of why an endpoint or channel was rejected.
#[derive(Clone, Copy, Debug)]
struct ApiItemErrorSummary {
/// `true` when the method's signature parameters could not be parsed; used
/// to decide whether a usage message is emitted.
has_param_errors: bool,
}
/// Removal of the attributes this macro recognizes from a syntax node.
trait StripRecognizedAttrs {
    /// Removes the recognized attributes in place.
    fn strip_recognized_attrs(&mut self);

    /// Returns a copy with the recognized attributes removed, leaving `self`
    /// untouched.
    fn clone_and_strip_recognized_attrs(&self) -> Self
    where
        Self: Clone,
    {
        let mut stripped = Self::clone(self);
        Self::strip_recognized_attrs(&mut stripped);
        stripped
    }
}
impl StripRecognizedAttrs for TraitItemPartParsed {
    /// Delegates to the wrapped function or other trait item.
    fn strip_recognized_attrs(&mut self) {
        match self {
            Self::Fn(fn_item) => fn_item.strip_recognized_attrs(),
            Self::Other(other_item) => other_item.strip_recognized_attrs(),
        }
    }
}
impl StripRecognizedAttrs for TraitItemFnForSignature {
    /// Strips the recognized attributes from the function's attribute list.
    fn strip_recognized_attrs(&mut self) {
        let attrs = &mut self.attrs;
        attrs.strip_recognized_attrs();
    }
}
impl StripRecognizedAttrs for syn::TraitItem {
    /// Strips the recognized attributes from const, function, type and macro
    /// items. Any other kind of item is left untouched.
    fn strip_recognized_attrs(&mut self) {
        let attrs = match self {
            syn::TraitItem::Const(item) => &mut item.attrs,
            syn::TraitItem::Fn(item) => &mut item.attrs,
            syn::TraitItem::Type(item) => &mut item.attrs,
            syn::TraitItem::Macro(item) => &mut item.attrs,
            _ => return,
        };
        attrs.strip_recognized_attrs();
    }
}
impl StripRecognizedAttrs for Vec<syn::Attribute> {
    /// Keeps only the attributes that are neither endpoint nor channel
    /// attributes.
    fn strip_recognized_attrs(&mut self) {
        self.retain(|attr| {
            let path = attr.path();
            !path.is_ident(ENDPOINT_IDENT) && !path.is_ident(CHANNEL_IDENT)
        });
    }
}
// Snapshot tests: each one expands a trait through `do_trait`, asserts that no
// errors were reported, and compares the pretty-printed output with a file
// under `tests/output/`.
#[cfg(test)]
mod tests {
use expectorate::assert_contents;
use crate::{test_util::assert_banned_idents, util::DROPSHOT};
use super::*;
// A trait with one endpoint and one channel, both carrying `versions`, and no
// macro arguments.
#[test]
fn test_api_trait_basic() {
let (item, errors) = do_trait(
quote! {},
quote! {
trait MyTrait {
type Context;
#[endpoint {
method = GET,
path = "/xyz",
versions = "1.2.3"..
}]
async fn handler_xyz(
rqctx: RequestContext<Self::Context>,
) -> Result<HttpResponseOk<()>, HttpError>;
#[channel {
protocol = WEBSOCKETS,
path = "/ws",
versions = ..
}]
async fn handler_ws(
rqctx: RequestContext<Self::Context>,
upgraded: WebsocketConnection,
) -> WebsocketChannelResult;
}
},
);
assert!(errors.is_empty());
assert_contents(
"tests/output/api_trait_basic.rs",
&prettyplease::unparse(&parse_quote! { #item }),
);
}
// Every macro argument is overridden: context name, module name, tag
// configuration and the dropshot crate name.
#[test]
fn test_api_trait_with_custom_params() {
let (item, errors) = do_trait(
quote! {
context = Situation,
module = my_support_module,
tag_config = {
allow_other_tags = true,
policy = EndpointTagPolicy::Any,
tags = {
topspin = {
description =
"Topspin is a tennis shot that \
causes the ball to spin forward",
external_docs = {
description = "Wikipedia entry",
url = "https://en.wikipedia.org/wiki/Topspin",
},
},
},
},
_dropshot_crate = "topspin",
},
quote! {
pub trait MyTrait {
type Situation;
#[endpoint { method = GET, path = "/xyz" }]
async fn handler_xyz(
rqctx: RequestContext<Self::Situation>,
) -> Result<HttpResponseOk<()>, HttpError>;
#[channel { protocol = WEBSOCKETS, path = "/ws" }]
async fn handler_ws(
rqctx: RequestContext<Self::Situation>,
upgraded: WebsocketConnection,
) -> WebsocketChannelResult;
}
},
);
// Print the errors so a failure of the assertion below is easier to diagnose.
eprintln!("errors: {:#?}", errors);
assert!(errors.is_empty());
let file = parse_quote! { #item };
assert_contents(
"tests/output/api_trait_with_custom_params.rs",
&prettyplease::unparse(&file),
);
// With every default overridden, the default context name, the default crate
// name and the snake-cased trait name are passed to `assert_banned_idents`.
let banned = [ApiMetadata::DEFAULT_CONTEXT_NAME, DROPSHOT, "my_trait"];
assert_banned_idents(&file, banned);
}
// A `pub(crate)` trait that declares only the context type.
#[test]
fn test_api_trait_no_endpoints() {
let (item, errors) = do_trait(
quote! {},
quote! {
pub(crate) trait MyTrait {
type Context;
}
},
);
assert!(errors.is_empty());
assert_contents(
"tests/output/api_trait_no_endpoints.rs",
&prettyplease::unparse(&parse_quote! { #item }),
);
}
// An endpoint and a channel that each set an explicit `operation_id`.
#[test]
fn test_api_trait_operation_id() {
let (item, errors) = do_trait(
quote! {},
quote! {
pub trait MyTrait {
type Context;
#[endpoint {
operation_id = "vzerolower",
method = GET,
path = "/xyz"
}]
async fn handler_xyz(
rqctx: RequestContext<Self::Context>,
) -> Result<HttpResponseOk<()>, HttpError>;
#[channel {
protocol = WEBSOCKETS,
path = "/ws",
operation_id = "vzeroupper",
}]
async fn handler_ws(
rqctx: RequestContext<Self::Context>,
upgraded: WebsocketConnection,
) -> WebsocketChannelResult;
}
},
);
assert!(errors.is_empty());
assert_contents(
"tests/output/api_trait_operation_id.rs",
&prettyplease::unparse(&parse_quote! { #item }),
);
}
}