use buffa_codegen::idents::{escape_mod_ident, make_field_ident};
use connectrpc_codegen::plugin::CodeGeneratorResponseFile;
use heck::ToSnakeCase;
use proc_macro2::TokenStream;
use quote::quote;
use syn::{Ident, Path, Type};
use uni_error::UniError;
use crate::error::{CodegenErrKind, CodegenResult};
use crate::internal::guardrails::{ensure_unique_generated_identifiers, ensure_unique_routes};
use crate::internal::ir::{
CommentSet, DescriptorIr, Field, FieldKind, FieldLabel, HttpVerb, Method, ProtoFile, Service,
};
use crate::internal::options::CodegenOptions;
use crate::internal::resolver::TypeResolver;
use crate::internal::shape::{
FieldSource, FileShapes, GeneratedDto, RequestPartShape, RequestReconstruction, RequestShape,
ShapeField, plan_file_shapes,
};
/// Suffix appended to the snake-cased service name to form the generated module
/// name (e.g. `TestService` becomes `test_service_rest`).
const REST_MODULE_SUFFIX: &str = "_rest";
/// Name of the generated function that builds each service's `axum::Router`.
const ROUTER_FUNCTION_NAME: &str = "make_router";
/// Generates the REST bindings file for `file_name`.
///
/// Returns `Ok(None)` when the file declares no HTTP bindings, otherwise a single
/// `*.connect2rest.rs` response file holding the generated axum modules.
///
/// # Errors
/// Returns `FileToGenerateNotFound` when `file_name` is absent from `ir`, and
/// propagates any planning, validation, or rendering error from the generator.
pub fn generate_file(
    ir: &DescriptorIr,
    file_name: &str,
    options: &CodegenOptions,
) -> CodegenResult<Option<CodeGeneratorResponseFile>> {
    let proto_file = match ir.file(file_name) {
        Some(found) => found,
        None => {
            return Err(UniError::from_kind_context(
                CodegenErrKind::FileToGenerateNotFound,
                format!("file_to_generate {file_name:?} was not present in proto_file"),
            ));
        }
    };
    if !proto_file.has_http_bindings() {
        return Ok(None);
    }
    let generator = RustGenerator::new(ir, proto_file, options)?;
    let rendered = generator.generate()?;
    Ok(Some(CodeGeneratorResponseFile {
        name: Some(output_file_name(proto_file.name.as_ref())),
        content: Some(rendered),
        ..Default::default()
    }))
}
/// Maps a proto file name such as `foo/bar.proto` to `foo/bar.connect2rest.rs`.
///
/// Names without a `.proto` suffix keep their full text as the stem.
fn output_file_name(proto_file_name: &str) -> String {
    match proto_file_name.strip_suffix(".proto") {
        Some(stem) => format!("{stem}.connect2rest.rs"),
        None => format!("{proto_file_name}.connect2rest.rs"),
    }
}
/// Per-file generator that turns a proto file's HTTP-bound services into axum
/// handler modules.
struct RustGenerator<'a> {
    /// The proto file being generated.
    proto_file: &'a ProtoFile,
    /// Generator options (e.g. runtime module path, streaming content type).
    options: &'a CodegenOptions,
    /// Resolves proto names to Rust paths, types, and identifiers.
    resolver: TypeResolver<'a>,
    /// Request shapes and generated DTOs planned for the whole file.
    shapes: FileShapes,
}
impl<'a> RustGenerator<'a> {
fn new(
ir: &'a DescriptorIr,
proto_file: &'a ProtoFile,
options: &'a CodegenOptions,
) -> CodegenResult<Self> {
Ok(Self {
proto_file,
options,
resolver: TypeResolver::new(ir, options),
shapes: plan_file_shapes(ir, proto_file, options)?,
})
}
fn generate(&self) -> CodegenResult<String> {
self.validate_file()?;
let modules = self
.proto_file
.services
.iter()
.filter(|service| service.has_http_bindings())
.map(|service| self.service_module(service))
.collect::<CodegenResult<Vec<_>>>()?;
let source = format_tokens(quote! {
#(#modules)*
})?;
Ok(format!(
"// @generated by connect2rest.\n\
// This file is generated from {}. Do not edit by hand.\n\n\
{source}",
self.proto_file.name.as_ref()
))
}
fn service_module(&self, service: &Service) -> CodegenResult<TokenStream> {
self.validate_service(service)?;
let module_ident = parse_ident(
&service_module_name(service.name.as_ref()),
"generated REST module",
)?;
let dto_items = self
.shapes
.generated_dtos
.iter()
.map(dto_tokens)
.collect::<CodegenResult<Vec<_>>>()?;
let handler_items = service
.methods
.iter()
.filter(|method| method.http.is_some())
.map(|method| self.handler_tokens(service, method))
.collect::<CodegenResult<Vec<_>>>()?;
let router_item = self.router_tokens(service)?;
Ok(quote! {
pub mod #module_ident {
#![allow(unused_imports)]
use std::sync::Arc;
use axum::extract::{Path, Query, State};
use axum::Json;
use http::{Extensions, HeaderMap};
#(#dto_items)*
#(#handler_items)*
#router_item
}
})
}
fn handler_tokens(&self, service: &Service, method: &Method) -> CodegenResult<TokenStream> {
let shape = self.shape_for(method)?;
let service_trait = parse_path(
self.resolver
.connect_service_trait(service.full_name.as_ref())
.as_str(),
"Connect service trait",
)?;
let method_ident = parse_ident(
self.resolver.method_fn_name(method.name.as_ref()).as_ref(),
"handler function",
)?;
let service_value = parse_ident(
self.resolver.value_ident("service", self.options).as_ref(),
"service binding",
)?;
let request_value = parse_ident(
self.resolver.value_ident("request", self.options).as_ref(),
"request binding",
)?;
let ctx_value = parse_ident(
self.resolver.value_ident("ctx", self.options).as_ref(),
"context binding",
)?;
let query_value = parse_ident(
self.resolver.value_ident("query", self.options).as_ref(),
"query binding",
)?;
let headers_value = parse_ident(
self.resolver.value_ident("headers", self.options).as_ref(),
"headers binding",
)?;
let extensions_value = parse_ident(
self.resolver
.value_ident("extensions", self.options)
.as_ref(),
"extensions binding",
)?;
let body_value = parse_ident(
self.resolver.value_ident("body", self.options).as_ref(),
"body binding",
)?;
let runtime = parse_path(self.options.runtime_module.as_ref(), "runtime module")?;
let output_type = parse_type(
self.resolver
.owned_message_type(method.output_type.as_ref())?
.as_str(),
"method output type",
)?;
let request_type = parse_type(shape.request_type.as_str(), "method request type")?;
let request_view_type = parse_type(
&format!("{}<'static>", shape.request_view_type.as_str()),
"method request view type",
)?;
let streaming_content_type = self.options.streaming_content_type.as_ref();
let mut params = vec![quote! {
State(#service_value): State<Arc<S>>
}];
if let Some(path_param) = path_extractor_param(shape, &self.resolver, self.options)? {
params.push(path_param);
}
if let Some(query_type) = shape.query_shape.as_ref().map(part_type).transpose()? {
params.push(quote! {
Query(#query_value): Query<#query_type>
});
}
params.push(quote! {
#headers_value: HeaderMap
});
params.push(quote! {
#extensions_value: Extensions
});
if method.client_streaming {
params.push(quote! {
#body_value: #runtime::JsonLines<#request_type>
});
} else if let Some(body_type) = shape.body_shape.as_ref().map(part_type).transpose()? {
params.push(quote! {
Json(#body_value): Json<#body_type>
});
}
let request_preparation = if method.client_streaming {
quote! {
let #request_value = #runtime::ndjson_request_stream::<#request_view_type>(#body_value);
}
} else {
let request_reconstruction = request_reconstruction_tokens(
shape,
&request_value,
&query_value,
&body_value,
&self.resolver,
self.options,
)?;
quote! {
#request_reconstruction
let #request_value = match #runtime::owned_view::<#request_view_type>(&#request_value) {
Ok(#request_value) => #request_value,
Err(err) => return #runtime::error_response(err),
};
}
};
let response_conversion = if method.server_streaming {
quote! {
#runtime::stream_response::<#output_type, _>(
#service_value.#method_ident(#ctx_value, #request_value).await,
#streaming_content_type
)
}
} else {
quote! {
#runtime::service_response::<#output_type, _>(
#service_value.#method_ident(#ctx_value, #request_value).await
)
}
};
Ok(quote! {
pub async fn #method_ident<S>(
#(#params),*
) -> http::Response<axum::body::Body>
where
S: #service_trait + Send + Sync + 'static,
{
let #ctx_value = #runtime::request_context(#headers_value, #extensions_value);
#request_preparation
#response_conversion
}
})
}
fn router_tokens(&self, service: &Service) -> CodegenResult<TokenStream> {
let service_trait = parse_path(
self.resolver
.connect_service_trait(service.full_name.as_ref())
.as_str(),
"Connect service trait",
)?;
let router_ident = parse_ident(ROUTER_FUNCTION_NAME, "router function")?;
let route_calls = service
.methods
.iter()
.filter(|method| method.http.is_some())
.map(|method| self.route_tokens(method))
.collect::<CodegenResult<Vec<_>>>()?;
Ok(quote! {
pub fn #router_ident<S>(service: Arc<S>) -> axum::Router
where
S: #service_trait + Send + Sync + 'static,
{
axum::Router::new()
#(#route_calls)*
.with_state(service)
}
})
}
fn route_tokens(&self, method: &Method) -> CodegenResult<TokenStream> {
let binding = method.http.as_ref().ok_or_else(|| {
UniError::from_kind_context(
CodegenErrKind::InvalidDescriptor,
format!("method {} has no HTTP binding", method.full_name.as_ref()),
)
})?;
let path = binding.path.as_ref();
let method_ident = parse_ident(
self.resolver.method_fn_name(method.name.as_ref()).as_ref(),
"route handler",
)?;
let route = match binding.verb {
HttpVerb::Get => quote! {
.route(#path, axum::routing::get(#method_ident::<S>))
},
HttpVerb::Post => quote! {
.route(#path, axum::routing::post(#method_ident::<S>))
},
HttpVerb::Put => quote! {
.route(#path, axum::routing::put(#method_ident::<S>))
},
HttpVerb::Delete => quote! {
.route(#path, axum::routing::delete(#method_ident::<S>))
},
HttpVerb::Patch => quote! {
.route(#path, axum::routing::patch(#method_ident::<S>))
},
};
Ok(route)
}
fn shape_for(&self, method: &Method) -> CodegenResult<&RequestShape> {
self.shapes
.request_shapes
.iter()
.find(|shape| shape.method == method.full_name)
.ok_or_else(|| {
UniError::from_kind_context(
CodegenErrKind::InvalidDescriptor,
format!(
"request shape for {} was not planned",
method.full_name.as_ref()
),
)
})
}
fn validate_file(&self) -> CodegenResult<()> {
let scope = format!("REST file {}", self.proto_file.name.as_ref());
ensure_unique_generated_identifiers(
&scope,
self.proto_file
.services
.iter()
.filter(|service| service.has_http_bindings())
.map(|service| {
(
service_module_name(service.name.as_ref()),
format!("service {}", service.full_name.as_ref()),
)
}),
)?;
ensure_unique_generated_identifiers(
&scope,
self.shapes.generated_dtos.iter().map(|dto| {
(
dto.name.as_ref().to_owned(),
format!(
"generated {:?} DTO for {}",
dto.kind,
dto.source_message.as_ref()
),
)
}),
)
}
fn validate_service(&self, service: &Service) -> CodegenResult<()> {
let scope = format!("REST service {}", service.full_name.as_ref());
let methods = service
.methods
.iter()
.filter(|method| method.http.is_some())
.collect::<Vec<_>>();
ensure_unique_generated_identifiers(
&scope,
methods.iter().map(|method| {
(
self.resolver
.method_fn_name(method.name.as_ref())
.as_ref()
.to_owned(),
format!("method {}", method.full_name.as_ref()),
)
}),
)?;
ensure_unique_routes(
&scope,
methods
.iter()
.map(|method| {
rest_route_key(method)
.map(|key| (key, format!("method {}", method.full_name.as_ref())))
})
.collect::<CodegenResult<Vec<_>>>()?,
)
}
}
/// Builds the `"<VERB> <path>"` key used to detect duplicate REST routes.
///
/// # Errors
/// Returns `InvalidDescriptor` when `method` carries no HTTP binding.
fn rest_route_key(method: &Method) -> CodegenResult<String> {
    match method.http.as_ref() {
        Some(binding) => Ok(format!(
            "{} {}",
            binding.verb.as_str(),
            binding.path.as_ref()
        )),
        None => Err(UniError::from_kind_context(
            CodegenErrKind::InvalidDescriptor,
            format!("method {} has no HTTP binding", method.full_name.as_ref()),
        )),
    }
}
fn dto_tokens(dto: &GeneratedDto) -> CodegenResult<TokenStream> {
let comments = comment_attrs(&dto.comments);
let dto_ident = parse_ident(dto.name.as_ref(), "generated DTO")?;
let fields = dto
.fields
.iter()
.map(dto_field_tokens)
.collect::<CodegenResult<Vec<_>>>()?;
Ok(quote! {
#(#comments)*
#[derive(Clone, Debug, Default, serde::Deserialize)]
#[serde(default)]
pub struct #dto_ident {
#(#fields),*
}
})
}
/// Emits one `pub` field of a generated DTO, with its doc comments and serde
/// attributes.
fn dto_field_tokens(field: &ShapeField) -> CodegenResult<TokenStream> {
    let proto_field = &field.field;
    let docs = comment_attrs(&proto_field.comments);
    let ident_text = make_field_ident(proto_field.name.as_ref()).to_string();
    let member_ident = parse_ident(&ident_text, "DTO field")?;
    let member_type = parse_type(field.rust_type.as_str(), "DTO field type")?;
    let attrs = serde_field_attrs(proto_field)?;
    Ok(quote! {
        #(#docs)*
        #(#attrs)*
        pub #member_ident: #member_type
    })
}
fn serde_field_attrs(field: &Field) -> CodegenResult<Vec<TokenStream>> {
let json_name = field.json_name.as_ref();
let proto_name = field.name.as_ref();
let alias = if json_name == proto_name {
quote! {}
} else {
quote! { , alias = #proto_name }
};
let with = serde_with_module(field)
.map(|module| quote! { , with = #module })
.unwrap_or_default();
let deserialize_with = serde_deserialize_with(field)
.map(|module| quote! { , deserialize_with = #module })
.unwrap_or_default();
Ok(vec![quote! {
#[serde(rename = #json_name #alias #with #deserialize_with)]
}])
}
/// Picks the `::buffa::json_helpers` module used as a field's serde `with` helper.
///
/// Repeated fields use `repeated_enum` for enums and `proto_seq` otherwise. Scalar
/// and enum singular fields use the matching per-kind helper. Group, message, and
/// unknown kinds get no helper.
fn serde_with_module(field: &Field) -> Option<&'static str> {
    let repeated = field.label == Some(FieldLabel::Repeated);
    let module = match (repeated, &field.kind) {
        (true, FieldKind::Enum(_)) => "::buffa::json_helpers::repeated_enum",
        (true, _) => "::buffa::json_helpers::proto_seq",
        (false, FieldKind::Double) => "::buffa::json_helpers::double",
        (false, FieldKind::Float) => "::buffa::json_helpers::float",
        (false, FieldKind::Int64 | FieldKind::Sint64 | FieldKind::Sfixed64) => {
            "::buffa::json_helpers::int64"
        }
        (false, FieldKind::Uint64 | FieldKind::Fixed64) => "::buffa::json_helpers::uint64",
        (false, FieldKind::Int32 | FieldKind::Sint32 | FieldKind::Sfixed32) => {
            "::buffa::json_helpers::int32"
        }
        (false, FieldKind::Uint32 | FieldKind::Fixed32) => "::buffa::json_helpers::uint32",
        (false, FieldKind::Bool) => "::buffa::json_helpers::proto_bool",
        (false, FieldKind::String) => "::buffa::json_helpers::proto_string",
        (false, FieldKind::Bytes) => "::buffa::json_helpers::bytes",
        (false, FieldKind::Enum(_)) => "::buffa::json_helpers::proto_enum",
        (false, FieldKind::Group(_) | FieldKind::Message(_) | FieldKind::Unknown) => {
            return None;
        }
    };
    Some(module)
}
/// Returns the serde `deserialize_with` helper for `field`, if any.
///
/// This helper only applies to repeated fields for which `serde_with_module`
/// yields nothing, so `with` and `deserialize_with` are never emitted together.
///
/// NOTE(review): `serde_with_module` currently returns `Some` for every repeated
/// field, so this never yields a value. Confirm whether repeated message fields
/// were meant to use `null_as_default` instead of `proto_seq`.
fn serde_deserialize_with(field: &Field) -> Option<&'static str> {
    let applies = field.label == Some(FieldLabel::Repeated) && serde_with_module(field).is_none();
    applies.then_some("::buffa::json_helpers::null_as_default")
}
/// Converts leading-detached and leading proto comments into `#[doc = ...]`
/// attributes, one per non-blank line, with trailing whitespace trimmed.
fn comment_attrs(comments: &CommentSet) -> Vec<TokenStream> {
    let mut attrs = Vec::new();
    for comment in comments.leading_detached.iter().chain(comments.leading.iter()) {
        for raw_line in comment.lines() {
            let text = raw_line.trim_end();
            if text.is_empty() {
                continue;
            }
            attrs.push(quote! {
                #[doc = #text]
            });
        }
    }
    attrs
}
/// Derives a service's generated module name: the escaped snake-case service name
/// followed by [`REST_MODULE_SUFFIX`].
fn service_module_name(service_name: &str) -> String {
    let snake = service_name.to_snake_case();
    let mut module_name = escape_mod_ident(&snake).to_string();
    module_name.push_str(REST_MODULE_SUFFIX);
    module_name
}
/// Builds the handler's axum `Path` extractor parameter, if the route has path
/// fields.
///
/// A single field is extracted directly. Several fields are extracted as a tuple
/// in declaration order.
fn path_extractor_param(
    shape: &RequestShape,
    resolver: &TypeResolver<'_>,
    options: &CodegenOptions,
) -> CodegenResult<Option<TokenStream>> {
    let path_fields = shape.path_fields.as_slice();
    if path_fields.is_empty() {
        return Ok(None);
    }
    let mut bindings = Vec::new();
    for path_field in path_fields {
        bindings.push(path_field_binding(path_field, resolver, options)?);
    }
    let mut types = Vec::new();
    for path_field in path_fields {
        types.push(parse_type(path_field.rust_type.as_str(), "path extractor type")?);
    }
    if let ([binding], [single_type]) = (bindings.as_slice(), types.as_slice()) {
        return Ok(Some(quote! {
            Path(#binding): Path<#single_type>
        }));
    }
    let tuple_type = parse_quoted_type(quote! {
        (#(#types),*)
    })?;
    Ok(Some(quote! {
        Path((#(#bindings),*)): Path<#tuple_type>
    }))
}
/// Emits the `let <request> = ...;` statement that rebuilds the owned request
/// message from the extracted parts, following the shape's reconstruction plan.
fn request_reconstruction_tokens(
    shape: &RequestShape,
    request_value: &Ident,
    query_value: &Ident,
    body_value: &Ident,
    resolver: &TypeResolver<'_>,
    options: &CodegenOptions,
) -> CodegenResult<TokenStream> {
    let request_type = parse_type(shape.request_type.as_str(), "method request type")?;
    let planned_fields = match &shape.reconstruction {
        RequestReconstruction::Empty => {
            return Ok(quote! {
                let #request_value = #request_type::default();
            });
        }
        RequestReconstruction::VerbatimBody => {
            return Ok(quote! {
                let #request_value = #body_value;
            });
        }
        RequestReconstruction::VerbatimQuery => {
            return Ok(quote! {
                let #request_value = #query_value;
            });
        }
        RequestReconstruction::FromParts { fields } => fields,
    };
    let mut initializers = Vec::new();
    for planned in planned_fields.iter() {
        let member = parse_ident(
            &make_field_ident(planned.field.as_ref()).to_string(),
            "request field",
        )?;
        let source_expr = field_source_expr(
            planned.field.as_ref(),
            planned.source,
            shape,
            query_value,
            body_value,
            resolver,
            options,
        )?;
        initializers.push(quote! {
            #member: #source_expr
        });
    }
    Ok(quote! {
        let #request_value = #request_type {
            #(#initializers,)*
            ..::core::default::Default::default()
        };
    })
}
/// Emits the expression that reads request field `field_name` from the part it was
/// planned to come from.
///
/// Path fields come from the path binding. Body and query fields come from the
/// matching extracted part.
fn field_source_expr(
    field_name: &str,
    source: FieldSource,
    shape: &RequestShape,
    query_value: &Ident,
    body_value: &Ident,
    resolver: &TypeResolver<'_>,
    options: &CodegenOptions,
) -> CodegenResult<TokenStream> {
    let (part, part_binding) = match source {
        FieldSource::Path => {
            let path_binding = parse_ident(
                resolver.value_ident(field_name, options).as_ref(),
                "path field binding",
            )?;
            return Ok(quote! { #path_binding });
        }
        FieldSource::Body => (shape.body_shape.as_ref(), body_value),
        FieldSource::Query => (shape.query_shape.as_ref(), query_value),
    };
    part_field_expr(field_name, part, part_binding)
}
/// Emits the expression for `field_name` inside an extracted request part.
///
/// DTO and verbatim-request parts expose the field as a member. An existing
/// message part *is* the field, so its binding is used as-is.
///
/// # Errors
/// Returns `InvalidDescriptor` when the part is missing or does not carry the field.
fn part_field_expr(
    field_name: &str,
    part: Option<&RequestPartShape>,
    binding: &Ident,
) -> CodegenResult<TokenStream> {
    let ident_context = match part {
        Some(RequestPartShape::GeneratedDto { .. }) => "generated DTO field",
        Some(RequestPartShape::VerbatimRequest { .. }) => "request part field",
        Some(RequestPartShape::ExistingMessage { field, .. })
            if field.field.name.as_ref() == field_name =>
        {
            return Ok(quote! { #binding });
        }
        _ => {
            return Err(UniError::from_kind_context(
                CodegenErrKind::InvalidDescriptor,
                format!("request field {field_name} was not found in planned request part"),
            ));
        }
    };
    let member = parse_ident(&make_field_ident(field_name).to_string(), ident_context)?;
    Ok(quote! { #binding.#member })
}
/// Returns the local identifier a path field is bound to in the generated handler.
fn path_field_binding(
    field: &ShapeField,
    resolver: &TypeResolver<'_>,
    options: &CodegenOptions,
) -> CodegenResult<Ident> {
    let binding_name = resolver.value_ident(field.field.name.as_ref(), options);
    parse_ident(binding_name.as_ref(), "path field binding")
}
/// Parses the Rust type that a request part is extracted as.
fn part_type(part: &RequestPartShape) -> CodegenResult<Type> {
    let (RequestPartShape::VerbatimRequest { rust_type }
    | RequestPartShape::ExistingMessage { rust_type, .. }
    | RequestPartShape::GeneratedDto { rust_type, .. }) = part;
    parse_type(rust_type.as_str(), "request part type")
}
/// Parses `value` as a Rust identifier, reporting `context` on failure.
fn parse_ident(value: &str, context: &'static str) -> CodegenResult<Ident> {
    match syn::parse_str::<Ident>(value) {
        Ok(ident) => Ok(ident),
        Err(err) => Err(UniError::from_kind_context(
            CodegenErrKind::InvalidDescriptor,
            format!("{context} identifier {value:?} did not parse: {err}"),
        )),
    }
}
/// Parses `value` as a Rust path, reporting `context` on failure.
fn parse_path(value: &str, context: &'static str) -> CodegenResult<Path> {
    let parsed: syn::Result<Path> = syn::parse_str(value);
    parsed.map_err(|source| {
        let message = format!("{context} path {value:?} did not parse: {source}");
        UniError::from_kind_context(CodegenErrKind::InvalidDescriptor, message)
    })
}
/// Parses `value` as a Rust type, reporting `context` on failure.
fn parse_type(value: &str, context: &'static str) -> CodegenResult<Type> {
    syn::parse_str::<Type>(value).map_err(|source| {
        let message = format!("{context} {value:?} did not parse: {source}");
        UniError::from_kind_context(CodegenErrKind::InvalidDescriptor, message)
    })
}
/// Parses a quoted token stream as a Rust type. On failure, the error includes the
/// rendered tokens.
fn parse_quoted_type(tokens: TokenStream) -> CodegenResult<Type> {
    match syn::parse2::<Type>(tokens.clone()) {
        Ok(parsed) => Ok(parsed),
        Err(err) => Err(UniError::from_kind_context(
            CodegenErrKind::InvalidDescriptor,
            format!("generated Rust type {tokens} did not parse: {err}"),
        )),
    }
}
/// Parses generated tokens as a Rust file and pretty-prints them with prettyplease.
///
/// # Errors
/// Returns `InvalidDescriptor` with the parse error and the raw tokens when the
/// output is not valid Rust.
fn format_tokens(tokens: TokenStream) -> CodegenResult<String> {
    match syn::parse2::<syn::File>(tokens.clone()) {
        Ok(file) => Ok(prettyplease::unparse(&file)),
        Err(err) => Err(UniError::from_kind_context(
            CodegenErrKind::InvalidDescriptor,
            format!("generated Rust file did not parse: {err}\n{tokens}"),
        )),
    }
}
#[cfg(test)]
mod tests {
use std::fs;
use std::process::Command;
use std::time::{SystemTime, UNIX_EPOCH};
use buffa::encoding::{Tag, WireType};
use buffa::{MessageField, UnknownField, UnknownFieldData};
use connectrpc_codegen::codegen::descriptor::{
DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FileDescriptorProto,
MethodDescriptorProto, MethodOptions, ServiceDescriptorProto,
field_descriptor_proto::{Label, Type},
};
use crate::{CodeGeneratorRequest, try_generate_rest};
#[test]
fn generates_expected_rust_skeleton() {
let content = generated_content();
assert_eq!(
content,
r#"// @generated by connect2rest.
// This file is generated from test/v1/test.proto. Do not edit by hand.
/// Generated axum handlers and router.
pub mod test_service_rest {
#![allow(unused_imports)]
use std::sync::Arc;
use axum::extract::{Path, Query, State};
use axum::Json;
use http::{Extensions, HeaderMap};
#[derive(Clone, Debug, Default, serde::Deserialize)]
#[serde(default)]
pub struct TestRequestQuery__ {
#[serde(
rename = "testType",
alias = "test_type",
with = "::buffa::json_helpers::proto_enum"
)]
pub test_type: ::buffa::EnumValue<crate::proto::test::v1::Tester>,
#[serde(rename = "tester")]
pub tester: crate::proto::test::v1::Nested,
}
pub async fn get_one<S>(
State(service__): State<Arc<S>>,
Path(data__): Path<::std::string::String>,
Query(query__): Query<TestRequestQuery__>,
headers__: HeaderMap,
extensions__: Extensions,
) -> http::Response<axum::body::Body>
where
S: crate::connect::test::v1::TestService + Send + Sync + 'static,
{
let ctx__ = ::connect2axum::request_context(headers__, extensions__);
let request__ = crate::proto::test::v1::TestRequest {
data: data__,
test_type: query__.test_type,
tester: query__.tester,
..::core::default::Default::default()
};
let request__ = match ::connect2axum::owned_view::<
crate::proto::test::v1::__buffa::view::TestRequestView<'static>,
>(&request__) {
Ok(request__) => request__,
Err(err) => return ::connect2axum::error_response(err),
};
::connect2axum::service_response::<
crate::proto::test::v1::TestResponse,
_,
>(service__.get_one(ctx__, request__).await)
}
pub async fn do_test<S>(
State(service__): State<Arc<S>>,
Path(
(data__, test_type__),
): Path<
(::std::string::String, ::buffa::EnumValue<crate::proto::test::v1::Tester>),
>,
headers__: HeaderMap,
extensions__: Extensions,
Json(body__): Json<crate::proto::test::v1::Nested>,
) -> http::Response<axum::body::Body>
where
S: crate::connect::test::v1::TestService + Send + Sync + 'static,
{
let ctx__ = ::connect2axum::request_context(headers__, extensions__);
let request__ = crate::proto::test::v1::TestRequest {
data: data__,
test_type: test_type__,
tester: body__,
..::core::default::Default::default()
};
let request__ = match ::connect2axum::owned_view::<
crate::proto::test::v1::__buffa::view::TestRequestView<'static>,
>(&request__) {
Ok(request__) => request__,
Err(err) => return ::connect2axum::error_response(err),
};
::connect2axum::service_response::<
crate::proto::test::v1::TestResponse,
_,
>(service__.do_test(ctx__, request__).await)
}
pub async fn patch_all<S>(
State(service__): State<Arc<S>>,
headers__: HeaderMap,
extensions__: Extensions,
Json(body__): Json<crate::proto::test::v1::TestRequest>,
) -> http::Response<axum::body::Body>
where
S: crate::connect::test::v1::TestService + Send + Sync + 'static,
{
let ctx__ = ::connect2axum::request_context(headers__, extensions__);
let request__ = body__;
let request__ = match ::connect2axum::owned_view::<
crate::proto::test::v1::__buffa::view::TestRequestView<'static>,
>(&request__) {
Ok(request__) => request__,
Err(err) => return ::connect2axum::error_response(err),
};
::connect2axum::service_response::<
crate::proto::test::v1::TestResponse,
_,
>(service__.patch_all(ctx__, request__).await)
}
pub async fn ping<S>(
State(service__): State<Arc<S>>,
headers__: HeaderMap,
extensions__: Extensions,
) -> http::Response<axum::body::Body>
where
S: crate::connect::test::v1::TestService + Send + Sync + 'static,
{
let ctx__ = ::connect2axum::request_context(headers__, extensions__);
let request__ = crate::proto::test::v1::EmptyRequest::default();
let request__ = match ::connect2axum::owned_view::<
crate::proto::test::v1::__buffa::view::EmptyRequestView<'static>,
>(&request__) {
Ok(request__) => request__,
Err(err) => return ::connect2axum::error_response(err),
};
::connect2axum::service_response::<
crate::proto::test::v1::TestResponse,
_,
>(service__.ping(ctx__, request__).await)
}
pub fn make_router<S>(service: Arc<S>) -> axum::Router
where
S: crate::connect::test::v1::TestService + Send + Sync + 'static,
{
axum::Router::new()
.route("/test/{data}", axum::routing::get(get_one::<S>))
.route("/test/{data}/testing/{test_type}", axum::routing::post(do_test::<S>))
.route("/test", axum::routing::patch(patch_all::<S>))
.route("/ping", axum::routing::get(ping::<S>))
.with_state(service)
}
}
"#
);
}
#[test]
fn generated_handlers_compile_and_call_service() {
let content = generated_content();
let temp_dir = temp_crate_dir();
fs::create_dir_all(temp_dir.join("src")).expect("temp src dir");
fs::write(temp_dir.join("Cargo.toml"), temp_cargo_toml()).expect("temp Cargo.toml");
fs::write(temp_dir.join("src/lib.rs"), temp_lib_rs(&content)).expect("temp lib.rs");
let output = Command::new("cargo")
.args(["test", "--quiet"])
.current_dir(&temp_dir)
.output()
.expect("cargo test runs");
if !output.status.success() {
panic!(
"generated crate tests failed\nstdout:\n{}\nstderr:\n{}\nsource:\n{}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr),
content
);
}
fs::remove_dir_all(temp_dir).expect("remove temp crate");
}
#[test]
fn generates_streaming_handlers_for_all_stream_shapes() {
let content = streaming_content();
assert!(content.contains("pub async fn server_stream"));
assert!(content.contains("pub async fn client_stream"));
assert!(content.contains("pub async fn bidi_stream"));
assert!(content.contains("::connect2axum::JsonLines<crate::proto::test::v1::TestRequest>"));
assert!(content.contains("::connect2axum::ndjson_request_stream::<"));
assert!(content.contains("::connect2axum::stream_response::<"));
assert!(content.contains("\"application/x-ndjson\""));
assert!(content.contains(".route(\"/test:server-stream\""));
assert!(content.contains(".route(\"/test:client-stream\""));
assert!(content.contains(".route(\"/test:bidi-stream\""));
}
#[test]
fn rejects_duplicate_rest_routes() {
let mut file = test_file();
file.service[0].method.push(method(
"PatchAgain",
".test.v1.TestRequest",
http_rule(6, "/test", Some("*")),
));
let request = CodeGeneratorRequest {
file_to_generate: vec!["test/v1/test.proto".into()],
proto_file: vec![file],
..Default::default()
};
let err = try_generate_rest(&request).unwrap_err();
assert!(err.to_string().contains("duplicate route"));
assert!(err.to_string().contains("PATCH /test"));
}
#[test]
fn rejects_duplicate_rest_handler_idents() {
let mut file = test_file();
file.service[0].method.push(method(
"Get_One",
".test.v1.TestRequest",
http_rule(2, "/test-again/{data}", None),
));
let request = CodeGeneratorRequest {
file_to_generate: vec!["test/v1/test.proto".into()],
proto_file: vec![file],
..Default::default()
};
let err = try_generate_rest(&request).unwrap_err();
assert!(
err.to_string()
.contains("duplicate generated Rust identifier")
);
assert!(err.to_string().contains("get_one"));
}
#[test]
fn generates_modules_for_multiple_services_in_one_file() {
let mut file = test_file();
file.service.push(ServiceDescriptorProto {
name: Some("OtherService".into()),
method: vec![method(
"Ping",
".test.v1.EmptyRequest",
http_rule(2, "/other/ping", None),
)],
..Default::default()
});
let request = CodeGeneratorRequest {
file_to_generate: vec!["test/v1/test.proto".into()],
proto_file: vec![file],
..Default::default()
};
let response = try_generate_rest(&request).unwrap();
let content = response.file[0].content.as_deref().unwrap();
assert!(content.contains("pub mod test_service_rest"));
assert!(content.contains("pub mod other_service_rest"));
}
#[test]
fn resolves_cross_package_inputs_and_outputs() {
let request = CodeGeneratorRequest {
file_to_generate: vec!["api/v1/api.proto".into()],
proto_file: vec![model_file(), cross_package_service_file()],
..Default::default()
};
let response = try_generate_rest(&request).unwrap();
let content = response.file[0].content.as_deref().unwrap();
assert!(content.contains("crate::proto::model::v1::SharedRequest"));
assert!(content.contains("crate::proto::model::v1::__buffa::view::SharedRequestView"));
assert!(content.contains("crate::proto::model::v1::SharedResponse"));
}
#[test]
fn resolves_same_package_types_split_across_files() {
let request = CodeGeneratorRequest {
file_to_generate: vec!["split/v1/service.proto".into()],
proto_file: vec![split_messages_file(), split_service_file()],
..Default::default()
};
let response = try_generate_rest(&request).unwrap();
let content = response.file[0].content.as_deref().unwrap();
assert!(content.contains("crate::proto::split::v1::SplitRequest"));
assert!(content.contains("crate::proto::split::v1::__buffa::view::SplitRequestView"));
assert!(content.contains("crate::proto::split::v1::SplitResponse"));
}
#[test]
fn resolves_google_protobuf_empty() {
let request = CodeGeneratorRequest {
file_to_generate: vec!["test/v1/empty.proto".into()],
proto_file: vec![google_empty_file(), empty_service_file()],
..Default::default()
};
let response = try_generate_rest(&request).unwrap();
let content = response.file[0].content.as_deref().unwrap();
assert!(content.contains("::buffa_types::google::protobuf::Empty"));
assert!(content.contains("::buffa_types::google::protobuf::__buffa::view::EmptyView"));
}
fn generated_content() -> String {
let request = request();
let response = try_generate_rest(&request).unwrap();
assert_eq!(response.file.len(), 1);
response.file[0]
.content
.as_ref()
.expect("generated file has content")
.clone()
}
fn streaming_content() -> String {
let request = streaming_request();
let response = try_generate_rest(&request).unwrap();
assert_eq!(response.file.len(), 1);
response.file[0]
.content
.as_ref()
.expect("generated file has content")
.clone()
}
fn request() -> CodeGeneratorRequest {
CodeGeneratorRequest {
file_to_generate: vec!["test/v1/test.proto".into()],
proto_file: vec![test_file()],
..Default::default()
}
}
fn streaming_request() -> CodeGeneratorRequest {
CodeGeneratorRequest {
file_to_generate: vec!["test/v1/streaming.proto".into()],
proto_file: vec![streaming_test_file()],
..Default::default()
}
}
fn test_file() -> FileDescriptorProto {
FileDescriptorProto {
name: Some("test/v1/test.proto".into()),
package: Some("test.v1".into()),
message_type: vec![
DescriptorProto {
name: Some("TestRequest".into()),
field: vec![
field("data", 1, Type::TYPE_STRING, None),
field("test_type", 2, Type::TYPE_ENUM, Some(".test.v1.Tester")),
field("tester", 8, Type::TYPE_MESSAGE, Some(".test.v1.Nested")),
],
..Default::default()
},
DescriptorProto {
name: Some("Nested".into()),
..Default::default()
},
DescriptorProto {
name: Some("EmptyRequest".into()),
..Default::default()
},
DescriptorProto {
name: Some("TestResponse".into()),
..Default::default()
},
],
enum_type: vec![EnumDescriptorProto {
name: Some("Tester".into()),
..Default::default()
}],
service: vec![ServiceDescriptorProto {
name: Some("TestService".into()),
method: vec![
method(
"GetOne",
".test.v1.TestRequest",
http_rule(2, "/test/{data}", None),
),
method(
"DoTest",
".test.v1.TestRequest",
http_rule(4, "/test/{data}/testing/{test_type}", Some("tester")),
),
method(
"PatchAll",
".test.v1.TestRequest",
http_rule(6, "/test", Some("*")),
),
method("Ping", ".test.v1.EmptyRequest", http_rule(2, "/ping", None)),
],
..Default::default()
}],
..Default::default()
}
}
fn streaming_test_file() -> FileDescriptorProto {
FileDescriptorProto {
name: Some("test/v1/streaming.proto".into()),
package: Some("test.v1".into()),
message_type: vec![
DescriptorProto {
name: Some("TestRequest".into()),
field: vec![field("data", 1, Type::TYPE_STRING, None)],
..Default::default()
},
DescriptorProto {
name: Some("TestResponse".into()),
field: vec![field("message", 1, Type::TYPE_STRING, None)],
..Default::default()
},
],
service: vec![ServiceDescriptorProto {
name: Some("TestService".into()),
method: vec![
streaming_method(
"ServerStream",
".test.v1.TestRequest",
http_rule(4, "/test:server-stream", Some("*")),
false,
true,
),
streaming_method(
"ClientStream",
".test.v1.TestRequest",
http_rule(4, "/test:client-stream", Some("*")),
true,
false,
),
streaming_method(
"BidiStream",
".test.v1.TestRequest",
http_rule(4, "/test:bidi-stream", Some("*")),
true,
true,
),
],
..Default::default()
}],
..Default::default()
}
}
fn method(
name: &str,
input_type: &str,
options: MessageField<MethodOptions>,
) -> MethodDescriptorProto {
streaming_method(name, input_type, options, false, false)
}
fn streaming_method(
name: &str,
input_type: &str,
options: MessageField<MethodOptions>,
client_streaming: bool,
server_streaming: bool,
) -> MethodDescriptorProto {
MethodDescriptorProto {
name: Some(name.into()),
input_type: Some(input_type.into()),
output_type: Some(".test.v1.TestResponse".into()),
client_streaming: Some(client_streaming),
server_streaming: Some(server_streaming),
options,
..Default::default()
}
}
fn method_with_output(
name: &str,
input_type: &str,
output_type: &str,
options: MessageField<MethodOptions>,
) -> MethodDescriptorProto {
MethodDescriptorProto {
name: Some(name.into()),
input_type: Some(input_type.into()),
output_type: Some(output_type.into()),
options,
..Default::default()
}
}
/// Descriptor for `model/v1/model.proto`: messages only (no services).
///
/// Declares `SharedRequest { name }` and `SharedResponse { message }`, each with a single
/// optional string field numbered 1.
fn model_file() -> FileDescriptorProto {
    let single_string_message = |message_name: &str, field_name: &str| DescriptorProto {
        name: Some(message_name.into()),
        field: vec![field(field_name, 1, Type::TYPE_STRING, None)],
        ..Default::default()
    };
    let request = single_string_message("SharedRequest", "name");
    let response = single_string_message("SharedResponse", "message");
    FileDescriptorProto {
        name: Some("model/v1/model.proto".into()),
        package: Some("model.v1".into()),
        message_type: vec![request, response],
        ..Default::default()
    }
}
/// Descriptor for `api/v1/api.proto` (package `api.v1`).
///
/// Its only RPC, `ApiService.Share`, takes and returns messages declared in package
/// `model.v1` (see `model_file`). It is bound via `HttpRule` field 4 (`post`) to `/share`
/// with `body: "*"`.
fn cross_package_service_file() -> FileDescriptorProto {
    let share = method_with_output(
        "Share",
        ".model.v1.SharedRequest",
        ".model.v1.SharedResponse",
        http_rule(4, "/share", Some("*")),
    );
    let api_service = ServiceDescriptorProto {
        name: Some("ApiService".into()),
        method: vec![share],
        ..Default::default()
    };
    FileDescriptorProto {
        name: Some("api/v1/api.proto".into()),
        package: Some("api.v1".into()),
        service: vec![api_service],
        ..Default::default()
    }
}
/// Descriptor for `split/v1/messages.proto`: the message half of a package split across
/// two files (the service half is `split_service_file`).
fn split_messages_file() -> FileDescriptorProto {
    // (message name, name of its single string field #1)
    let messages: Vec<_> = [("SplitRequest", "name"), ("SplitResponse", "message")]
        .into_iter()
        .map(|(message_name, field_name)| DescriptorProto {
            name: Some(message_name.into()),
            field: vec![field(field_name, 1, Type::TYPE_STRING, None)],
            ..Default::default()
        })
        .collect();
    FileDescriptorProto {
        name: Some("split/v1/messages.proto".into()),
        package: Some("split.v1".into()),
        message_type: messages,
        ..Default::default()
    }
}
/// Descriptor for `split/v1/service.proto`: `SplitService.Send`.
///
/// Uses messages declared in the sibling file `split/v1/messages.proto` of the same
/// `split.v1` package. It is bound via `HttpRule` field 4 (`post`) to `/split/send` with
/// `body: "*"`.
fn split_service_file() -> FileDescriptorProto {
    let send = method_with_output(
        "Send",
        ".split.v1.SplitRequest",
        ".split.v1.SplitResponse",
        http_rule(4, "/split/send", Some("*")),
    );
    let split_service = ServiceDescriptorProto {
        name: Some("SplitService".into()),
        method: vec![send],
        ..Default::default()
    };
    FileDescriptorProto {
        name: Some("split/v1/service.proto".into()),
        package: Some("split.v1".into()),
        service: vec![split_service],
        ..Default::default()
    }
}
/// Minimal stand-in for `google/protobuf/empty.proto`: a single field-less `Empty` message.
fn google_empty_file() -> FileDescriptorProto {
    let empty = DescriptorProto {
        name: Some("Empty".into()),
        ..Default::default()
    };
    FileDescriptorProto {
        name: Some("google/protobuf/empty.proto".into()),
        package: Some("google.protobuf".into()),
        message_type: vec![empty],
        ..Default::default()
    }
}
/// Descriptor for `test/v1/empty.proto`: `EmptyService.Ping`.
///
/// The RPC takes and returns `.google.protobuf.Empty`. It is bound via `HttpRule` field 2
/// (`get`) to `/empty/ping` with no body.
fn empty_service_file() -> FileDescriptorProto {
    const EMPTY: &str = ".google.protobuf.Empty";
    let ping = method_with_output("Ping", EMPTY, EMPTY, http_rule(2, "/empty/ping", None));
    let empty_service = ServiceDescriptorProto {
        name: Some("EmptyService".into()),
        method: vec![ping],
        ..Default::default()
    };
    FileDescriptorProto {
        name: Some("test/v1/empty.proto".into()),
        package: Some("test.v1".into()),
        service: vec![empty_service],
        ..Default::default()
    }
}
/// `LABEL_OPTIONAL` field descriptor.
///
/// `json_name` is the field name verbatim, except that `test_type` is given the
/// lowerCamelCase form `testType` that protoc would emit for it.
fn field(name: &str, number: i32, kind: Type, type_name: Option<&str>) -> FieldDescriptorProto {
    let json_name = match name {
        "test_type" => "testType",
        other => other,
    };
    FieldDescriptorProto {
        json_name: Some(json_name.to_owned()),
        type_name: type_name.map(String::from),
        r#type: Some(kind),
        label: Some(Label::LABEL_OPTIONAL),
        number: Some(number),
        name: Some(name.to_owned()),
        ..Default::default()
    }
}
/// Builds `MethodOptions` carrying a `google.api.http` rule as a raw unknown field.
///
/// `verb_field` is the `HttpRule` field number holding `path` (2 = get, 3 = put, 4 = post,
/// 5 = delete, 6 = patch). `body`, when given, is written as `HttpRule` field 7.
fn http_rule(verb_field: u32, path: &str, body: Option<&str>) -> MessageField<MethodOptions> {
    /// Appends one length-delimited string field (tag, then length-prefixed bytes) to `buf`.
    fn put_string_field(buf: &mut Vec<u8>, number: u32, value: &str) {
        Tag::new(number, WireType::LengthDelimited).encode(&mut *buf);
        buffa::types::encode_string(value, &mut *buf);
    }

    let mut encoded_rule = Vec::new();
    put_string_field(&mut encoded_rule, verb_field, path);
    if let Some(body) = body {
        put_string_field(&mut encoded_rule, 7, body);
    }

    let mut method_options = MethodOptions::default();
    // 72295728 is the extension field number of `google.api.http` on `MethodOptions`.
    method_options.__buffa_unknown_fields.push(UnknownField {
        number: 72_295_728,
        data: UnknownFieldData::LengthDelimited(encoded_rule),
    });
    MessageField::some(method_options)
}
/// Returns a unique path under the system temp dir for a throwaway compile-check crate.
///
/// The directory itself is not created here. The name combines the process id, the current
/// wall-clock time in nanoseconds, and a per-process sequence number.
///
/// The sequence number is what guarantees uniqueness within one test run. Tests execute on
/// parallel threads of the same process (same pid), and `SystemTime` can tick coarsely
/// enough (e.g. 100 ns on Windows) for two calls to see the same timestamp.
fn temp_crate_dir() -> std::path::PathBuf {
    static NEXT_SEQUENCE: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
    // `fetch_add` is atomic under any ordering; only uniqueness is needed, so `Relaxed` suffices.
    let sequence = NEXT_SEQUENCE.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    // A clock set before the epoch should not abort the test; pid + sequence still keep the
    // name unique within this run.
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_nanos())
        .unwrap_or_default();
    std::env::temp_dir().join(format!(
        "connect2rest-codegen-{}-{nanos}-{sequence}",
        std::process::id()
    ))
}
/// `Cargo.toml` for the throwaway crate that compile-checks generated code.
///
/// `serde` (with `derive`) is the only real dependency. The other crates the generated code
/// names (`axum`, `buffa`, `connectrpc`, ...) are aliased to the crate itself by the
/// `extern crate self as ...` lines of its `lib.rs`.
fn temp_cargo_toml() -> &'static str {
    concat!(
        "[package]\n",
        "name = \"connect2rest-generated-check\"\n",
        "version = \"0.0.0\"\n",
        "edition = \"2024\"\n",
        "[dependencies]\n",
        "serde = { version = \"1\", features = [\"derive\"] }\n"
    )
}
/// Builds `src/lib.rs` for the throwaway compile-check crate.
///
/// The result is three pieces concatenated verbatim:
/// 1. a prelude that aliases `axum`, `buffa`, `connect2axum`, `connectrpc` and `http` to the
///    crate itself (`extern crate self as ...`) and defines minimal stand-ins for the items
///    the generated code refers to;
/// 2. `generated`, the code under test;
/// 3. a `#[cfg(test)]` module that calls the generated `test_service_rest` handlers with a
///    recording `TestService` and checks the reconstructed requests and mapped responses.
fn temp_lib_rs(generated: &str) -> String {
    // Stub crate surface the generated module compiles against.
    const PRELUDE: &str = r#"extern crate self as axum;
extern crate self as buffa;
extern crate self as connect2axum;
extern crate self as connectrpc;
extern crate self as http;
pub struct Json<T>(pub T);
pub mod body {
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Body(pub &'static str);
impl Body {
pub fn from<T>(_value: T) -> Self {
Self("body")
}
}
}
pub mod extract {
pub struct Path<T>(pub T);
pub struct Query<T>(pub T);
pub struct State<T>(pub T);
}
pub mod routing {
pub struct MethodRouter;
pub fn delete<H>(_handler: H) -> MethodRouter {
MethodRouter
}
pub fn get<H>(_handler: H) -> MethodRouter {
MethodRouter
}
pub fn patch<H>(_handler: H) -> MethodRouter {
MethodRouter
}
pub fn post<H>(_handler: H) -> MethodRouter {
MethodRouter
}
pub fn put<H>(_handler: H) -> MethodRouter {
MethodRouter
}
}
pub struct Router;
impl Router {
pub fn new() -> Self {
Self
}
pub fn route(self, _path: &str, _method: routing::MethodRouter) -> Self {
self
}
pub fn with_state<S>(self, _state: S) -> Self {
self
}
}
pub struct HeaderMap;
pub struct Extensions;
pub struct RequestContext;
pub fn request_context(_headers: HeaderMap, _extensions: Extensions) -> RequestContext {
RequestContext
}
pub fn owned_view<V>(
_message: &<V as view::MessageView<'static>>::Owned,
) -> Result<view::OwnedView<V>, ConnectError>
where
V: view::MessageView<'static>,
V::Owned: Clone,
{
Ok(view::OwnedView(_message.clone(), ::std::marker::PhantomData))
}
pub fn error_response(_err: ConnectError) -> Response<body::Body> {
Response {
status: 400,
body: body::Body("error"),
}
}
pub fn service_response<M, B>(_response: ServiceResult<B>) -> Response<body::Body>
where
B: Encodable<M>,
{
match _response {
Ok(_) => Response {
status: 200,
body: body::Body("ok"),
},
Err(err) => error_response(err),
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Response<T> {
pub status: u16,
pub body: T,
}
impl<T> Response<T> {
pub fn new(body: T) -> Self {
Self { status: 200, body }
}
}
#[derive(Clone, Debug)]
pub struct ConnectError;
pub type ServiceResult<B> = Result<Response<B>, ConnectError>;
pub trait Encodable<M> {}
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum EnumValue<E> {
Known(E),
Unknown(i32),
}
impl<E: Default> Default for EnumValue<E> {
fn default() -> Self {
Self::Known(E::default())
}
}
impl<E> From<E> for EnumValue<E> {
fn from(value: E) -> Self {
Self::Known(value)
}
}
impl<'de, E> serde::Deserialize<'de> for EnumValue<E>
where
E: serde::Deserialize<'de>,
{
fn deserialize<D>(d: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
E::deserialize(d).map(Self::Known)
}
}
pub mod json_helpers {
pub mod proto_string {
pub fn deserialize<'de, D>(d: D) -> Result<::std::string::String, D::Error>
where
D: serde::Deserializer<'de>,
{
serde::Deserialize::deserialize(d)
}
}
pub mod proto_enum {
pub fn deserialize<'de, E, D>(d: D) -> Result<crate::EnumValue<E>, D::Error>
where
E: serde::Deserialize<'de>,
D: serde::Deserializer<'de>,
{
E::deserialize(d).map(crate::EnumValue::Known)
}
}
}
pub mod view {
pub trait MessageView<'a> {
type Owned;
}
pub struct OwnedView<V: MessageView<'static>>(
pub V::Owned,
pub ::std::marker::PhantomData<V>,
);
}
pub mod connect {
pub mod test {
pub mod v1 {
pub trait TestService: Send + Sync + 'static {
fn get_one<'a>(
&'a self,
_ctx: crate::RequestContext,
_request: crate::view::OwnedView<
crate::proto::test::v1::__buffa::view::TestRequestView<'static>,
>,
) -> impl ::std::future::Future<
Output = crate::ServiceResult<crate::proto::test::v1::TestResponse>,
> + Send + 'a;
fn do_test<'a>(
&'a self,
_ctx: crate::RequestContext,
_request: crate::view::OwnedView<
crate::proto::test::v1::__buffa::view::TestRequestView<'static>,
>,
) -> impl ::std::future::Future<
Output = crate::ServiceResult<crate::proto::test::v1::TestResponse>,
> + Send + 'a;
fn patch_all<'a>(
&'a self,
_ctx: crate::RequestContext,
_request: crate::view::OwnedView<
crate::proto::test::v1::__buffa::view::TestRequestView<'static>,
>,
) -> impl ::std::future::Future<
Output = crate::ServiceResult<crate::proto::test::v1::TestResponse>,
> + Send + 'a;
fn ping<'a>(
&'a self,
_ctx: crate::RequestContext,
_request: crate::view::OwnedView<
crate::proto::test::v1::__buffa::view::EmptyRequestView<'static>,
>,
) -> impl ::std::future::Future<
Output = crate::ServiceResult<crate::proto::test::v1::TestResponse>,
> + Send + 'a;
}
}
}
}
pub mod proto {
pub mod test {
pub mod v1 {
#[derive(Clone, Debug, Default, Eq, PartialEq, serde::Deserialize)]
pub struct TestRequest {
pub data: ::std::string::String,
pub test_type: ::buffa::EnumValue<Tester>,
pub tester: Nested,
}
#[derive(Clone, Debug, Default, Eq, PartialEq, serde::Deserialize)]
pub struct Nested {
pub data: ::std::string::String,
}
#[derive(Clone, Debug, Default, Eq, PartialEq, serde::Deserialize)]
pub enum Tester {
#[default]
Unknown,
Known,
}
#[derive(Clone, Debug, Default, Eq, PartialEq, serde::Deserialize)]
pub struct EmptyRequest;
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct TestResponse;
impl crate::Encodable<TestResponse> for TestResponse {}
pub mod __buffa {
pub mod view {
pub struct TestRequestView<'a>(::std::marker::PhantomData<&'a ()>);
pub struct EmptyRequestView<'a>(::std::marker::PhantomData<&'a ()>);
}
}
impl crate::view::MessageView<'static> for __buffa::view::TestRequestView<'static> {
type Owned = TestRequest;
}
impl crate::view::MessageView<'static> for __buffa::view::EmptyRequestView<'static> {
type Owned = EmptyRequest;
}
}
}
}
"#;

    // Unit tests compiled after the generated code; they drive its handlers directly.
    const HANDLER_TESTS: &str = r#"
#[cfg(test)]
mod generated_handler_tests {
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use crate::extract::{Path, Query, State};
use crate::proto::test::v1::{EmptyRequest, Nested, TestRequest, TestResponse, Tester};
use crate::{Extensions, HeaderMap, Json, Response, ServiceResult};
#[derive(Clone, Debug, Eq, PartialEq)]
enum Call {
Test(TestRequest),
Empty(EmptyRequest),
}
#[derive(Default)]
struct RecordingService {
calls: Mutex<Vec<Call>>,
fail: bool,
}
impl RecordingService {
fn failing() -> Self {
Self {
calls: Mutex::new(Vec::new()),
fail: true,
}
}
fn calls(&self) -> Vec<Call> {
self.calls.lock().expect("calls lock").clone()
}
fn respond(&self, call: Call) -> ServiceResult<TestResponse> {
if self.fail {
Err(crate::ConnectError)
} else {
self.calls.lock().expect("calls lock").push(call);
Ok(Response::new(TestResponse))
}
}
}
impl crate::connect::test::v1::TestService for RecordingService {
fn get_one<'a>(
&'a self,
_ctx: crate::RequestContext,
request: crate::view::OwnedView<
crate::proto::test::v1::__buffa::view::TestRequestView<'static>,
>,
) -> impl Future<Output = ServiceResult<TestResponse>> + Send + 'a {
async move { self.respond(Call::Test(request.0)) }
}
fn do_test<'a>(
&'a self,
_ctx: crate::RequestContext,
request: crate::view::OwnedView<
crate::proto::test::v1::__buffa::view::TestRequestView<'static>,
>,
) -> impl Future<Output = ServiceResult<TestResponse>> + Send + 'a {
async move { self.respond(Call::Test(request.0)) }
}
fn patch_all<'a>(
&'a self,
_ctx: crate::RequestContext,
request: crate::view::OwnedView<
crate::proto::test::v1::__buffa::view::TestRequestView<'static>,
>,
) -> impl Future<Output = ServiceResult<TestResponse>> + Send + 'a {
async move { self.respond(Call::Test(request.0)) }
}
fn ping<'a>(
&'a self,
_ctx: crate::RequestContext,
request: crate::view::OwnedView<
crate::proto::test::v1::__buffa::view::EmptyRequestView<'static>,
>,
) -> impl Future<Output = ServiceResult<TestResponse>> + Send + 'a {
async move { self.respond(Call::Empty(request.0)) }
}
}
#[test]
fn successful_unary_request_reconstructs_path_and_query() {
let service = Arc::new(RecordingService::default());
let response = block_on(crate::test_service_rest::get_one(
State(service.clone()),
Path("alpha".to_owned()),
Query(crate::test_service_rest::TestRequestQuery__ {
test_type: Tester::Known.into(),
tester: Nested {
data: "query".to_owned(),
},
}),
HeaderMap,
Extensions,
));
assert_eq!(response.status, 200);
assert_eq!(
service.calls(),
vec![Call::Test(TestRequest {
data: "alpha".to_owned(),
test_type: Tester::Known.into(),
tester: Nested {
data: "query".to_owned(),
},
})]
);
}
#[test]
fn service_error_response_is_mapped() {
let service = Arc::new(RecordingService::failing());
let response = block_on(crate::test_service_rest::get_one(
State(service),
Path("alpha".to_owned()),
Query(crate::test_service_rest::TestRequestQuery__ {
test_type: Tester::Known.into(),
tester: Nested {
data: "query".to_owned(),
},
}),
HeaderMap,
Extensions,
));
assert_eq!(response.status, 400);
assert_eq!(response.body, crate::body::Body("error"));
}
#[test]
fn path_and_body_request_is_reconstructed() {
let service = Arc::new(RecordingService::default());
let response = block_on(crate::test_service_rest::do_test(
State(service.clone()),
Path(("alpha".to_owned(), Tester::Known.into())),
HeaderMap,
Extensions,
Json(Nested {
data: "body".to_owned(),
}),
));
assert_eq!(response.status, 200);
assert_eq!(
service.calls(),
vec![Call::Test(TestRequest {
data: "alpha".to_owned(),
test_type: Tester::Known.into(),
tester: Nested {
data: "body".to_owned(),
},
})]
);
}
#[test]
fn body_star_request_is_forwarded_verbatim() {
let service = Arc::new(RecordingService::default());
let request = TestRequest {
data: "whole".to_owned(),
test_type: Tester::Known.into(),
tester: Nested {
data: "body".to_owned(),
},
};
let response = block_on(crate::test_service_rest::patch_all(
State(service.clone()),
HeaderMap,
Extensions,
Json(request.clone()),
));
assert_eq!(response.status, 200);
assert_eq!(service.calls(), vec![Call::Test(request)]);
}
#[test]
fn empty_request_uses_default_owned_message() {
let service = Arc::new(RecordingService::default());
let response = block_on(crate::test_service_rest::ping(
State(service.clone()),
HeaderMap,
Extensions,
));
assert_eq!(response.status, 200);
assert_eq!(service.calls(), vec![Call::Empty(EmptyRequest)]);
}
fn block_on<F: Future>(future: F) -> F::Output {
let mut future = Pin::from(Box::new(future));
let waker = std::task::Waker::noop();
let mut context = Context::from_waker(waker);
loop {
match future.as_mut().poll(&mut context) {
Poll::Ready(output) => return output,
Poll::Pending => std::thread::yield_now(),
}
}
}
}
"#;

    [PRELUDE, generated, HANDLER_TESTS].concat()
}
}