use bytes::Bytes;
use http::request::Parts;
use http::StatusCode;
use typeway_protobuf::ProtoMessage;
use crate::body::{body_from_bytes, body_from_string, BoxBody};
use crate::extract::FromRequest;
use crate::response::IntoResponse;
/// Wrapper for protobuf-described messages that can be used both as a request
/// extractor and as a response body.
///
/// When extracting, the body is decoded as binary protobuf or as JSON depending
/// on the request's `Content-Type`; when responding, the message is serialized
/// as JSON. The inner value is public and `Proto<T>` derefs to `T`, so fields and
/// methods of the message are reachable directly on the wrapper.
// `Debug` prints as `Proto(<inner>)` and `Clone` clones the inner value; both
// require the corresponding bound on `T`, exactly as the derives generate.
#[derive(Debug, Clone)]
pub struct Proto<T>(pub T);

impl<T> std::ops::Deref for Proto<T> {
    type Target = T;

    /// Borrows the wrapped message.
    fn deref(&self) -> &T {
        &self.0
    }
}

impl<T> std::ops::DerefMut for Proto<T> {
    /// Mutably borrows the wrapped message.
    fn deref_mut(&mut self) -> &mut T {
        &mut self.0
    }
}
/// Returns `true` when `content_type` names a binary protobuf media type.
///
/// Recognised types: `application/grpc`, `application/grpc+proto`,
/// `application/protobuf` and `application/x-protobuf`. Media types are
/// case-insensitive and may carry `;`-separated parameters
/// (e.g. `application/x-protobuf; messageType=foo.Bar`), so only the
/// parameter-free essence is compared, ignoring ASCII case and surrounding
/// whitespace. Everything else, including `application/grpc+json`, returns `false`.
fn is_binary_protobuf(content_type: &str) -> bool {
    const BINARY_TYPES: [&str; 4] = [
        "application/grpc",
        "application/grpc+proto",
        "application/protobuf",
        "application/x-protobuf",
    ];

    // Drop any parameters; previously a header such as
    // `application/protobuf; charset=utf-8` fell through to the JSON decoder.
    let essence = content_type
        .split_once(';')
        .map_or(content_type, |(essence, _params)| essence)
        .trim();

    BINARY_TYPES
        .iter()
        .any(|&known| essence.eq_ignore_ascii_case(known))
}
impl<T: ProtoMessage + 'static> FromRequest for Proto<T> {
type Error = (StatusCode, String);
async fn from_request(parts: &Parts, body: Bytes) -> Result<Self, Self::Error> {
let content_type = parts
.headers
.get(http::header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.unwrap_or("application/json");
if is_binary_protobuf(content_type) {
T::typeway_decode_bytes(body).map(Proto).map_err(|e| {
(
StatusCode::BAD_REQUEST,
format!("protobuf decode error: {e}"),
)
})
} else {
serde_json::from_slice(&body)
.map(Proto)
.map_err(|e| (StatusCode::BAD_REQUEST, format!("invalid JSON: {e}")))
}
}
}
impl<T: ProtoMessage> IntoResponse for Proto<T> {
fn into_response(self) -> http::Response<BoxBody> {
match serde_json::to_vec(&self.0) {
Ok(bytes) => {
let body = body_from_bytes(Bytes::from(bytes));
let mut res = http::Response::new(body);
res.headers_mut().insert(
http::header::CONTENT_TYPE,
http::HeaderValue::from_static("application/json"),
);
res
}
Err(e) => {
let body = body_from_string(format!("serialization error: {e}"));
let mut res = http::Response::new(body);
*res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
res
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use typeway_protobuf::{TypewayDecode, TypewayDecodeError, TypewayEncode};

    /// Minimal message used to exercise both the JSON (serde) path and the
    /// binary protobuf path: field 1 = `id` (varint), field 2 = `name`
    /// (length-delimited UTF-8).
    #[derive(Debug, Clone, Default, PartialEq, serde::Serialize, serde::Deserialize)]
    struct TestUser {
        id: u32,
        name: String,
    }

    // Hand-written encoder; default-valued fields (0 / empty string) are
    // omitted from the output.
    impl TypewayEncode for TestUser {
        fn encoded_len(&self) -> usize {
            let mut len = 0;
            if self.id != 0 {
                // 1 byte for the tag of field 1, plus the varint payload.
                len += 1 + typeway_protobuf::tw_varint_len(self.id as u64);
            }
            if !self.name.is_empty() {
                // 1 byte tag + varint length prefix + raw bytes.
                len +=
                    1 + typeway_protobuf::tw_varint_len(self.name.len() as u64) + self.name.len();
            }
            len
        }

        fn encode_to(&self, buf: &mut Vec<u8>) {
            if self.id != 0 {
                typeway_protobuf::tw_encode_tag(buf, 1, 0);
                typeway_protobuf::tw_encode_varint(buf, self.id as u64);
            }
            if !self.name.is_empty() {
                typeway_protobuf::tw_encode_tag(buf, 2, 2);
                typeway_protobuf::tw_encode_varint(buf, self.name.len() as u64);
                buf.extend_from_slice(self.name.as_bytes());
            }
        }
    }

    // Hand-written decoder; unknown field numbers are skipped by wire type.
    // NOTE(review): the wire type of known fields is not checked, and a
    // truncated `name` field panics on the slice index instead of returning an
    // error — acceptable only because inputs here are test-controlled.
    impl TypewayDecode for TestUser {
        fn typeway_decode(bytes: &[u8]) -> Result<Self, TypewayDecodeError> {
            let mut user = TestUser::default();
            let mut offset = 0;
            while offset < bytes.len() {
                let (tag_wire, consumed) = typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
                offset += consumed;
                let field_number = (tag_wire >> 3) as u32;
                let wire_type = (tag_wire & 0x07) as u8;
                match field_number {
                    1 => {
                        let (val, consumed) = typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
                        offset += consumed;
                        user.id = val as u32;
                    }
                    2 => {
                        let (len, consumed) = typeway_protobuf::tw_decode_varint(&bytes[offset..])?;
                        offset += consumed;
                        let len = len as usize;
                        user.name = String::from_utf8(bytes[offset..offset + len].to_vec())
                            .map_err(|_| TypewayDecodeError::InvalidUtf8("name"))?;
                        offset += len;
                    }
                    _ => {
                        offset +=
                            typeway_protobuf::tw_skip_wire_value(&bytes[offset..], wire_type)?;
                    }
                }
            }
            Ok(user)
        }
    }

    /// Builds request `Parts` carrying only the given `Content-Type` header.
    fn make_parts(content_type: &str) -> Parts {
        let (parts, _) = http::Request::builder()
            .header("content-type", content_type)
            .body(())
            .unwrap()
            .into_parts();
        parts
    }

    #[tokio::test]
    async fn proto_from_json() {
        let parts = make_parts("application/json");
        let body = Bytes::from(r#"{"id":42,"name":"Alice"}"#);
        let Proto(user) = Proto::<TestUser>::from_request(&parts, body).await.unwrap();
        assert_eq!(user.id, 42);
        assert_eq!(user.name, "Alice");
    }

    // A request with no Content-Type header at all must fall back to JSON.
    #[tokio::test]
    async fn proto_missing_content_type_defaults_to_json() {
        let (parts, _) = http::Request::builder().body(()).unwrap().into_parts();
        let body = Bytes::from(r#"{"id":5,"name":"Eve"}"#);
        let Proto(user) = Proto::<TestUser>::from_request(&parts, body).await.unwrap();
        assert_eq!(user.id, 5);
        assert_eq!(user.name, "Eve");
    }

    // `application/grpc+json` is not a binary type, so it is parsed as JSON.
    #[tokio::test]
    async fn proto_from_grpc_json() {
        let parts = make_parts("application/grpc+json");
        let body = Bytes::from(r#"{"id":1,"name":"Bob"}"#);
        let Proto(user) = Proto::<TestUser>::from_request(&parts, body).await.unwrap();
        assert_eq!(user.id, 1);
        assert_eq!(user.name, "Bob");
    }

    #[tokio::test]
    async fn proto_from_binary() {
        let user = TestUser {
            id: 42,
            name: "Alice".into(),
        };
        let binary = user.encode_to_vec();
        let parts = make_parts("application/grpc+proto");
        let Proto(decoded) = Proto::<TestUser>::from_request(&parts, Bytes::from(binary))
            .await
            .unwrap();
        assert_eq!(decoded.id, 42);
        assert_eq!(decoded.name, "Alice");
    }

    #[tokio::test]
    async fn proto_from_application_grpc() {
        let user = TestUser {
            id: 7,
            name: "Charlie".into(),
        };
        let binary = user.encode_to_vec();
        let parts = make_parts("application/grpc");
        let Proto(decoded) = Proto::<TestUser>::from_request(&parts, Bytes::from(binary))
            .await
            .unwrap();
        assert_eq!(decoded.id, 7);
    }

    #[tokio::test]
    async fn proto_from_application_protobuf() {
        let user = TestUser {
            id: 99,
            name: "Dave".into(),
        };
        let binary = user.encode_to_vec();
        let parts = make_parts("application/protobuf");
        let Proto(decoded) = Proto::<TestUser>::from_request(&parts, Bytes::from(binary))
            .await
            .unwrap();
        assert_eq!(decoded.id, 99);
    }

    #[tokio::test]
    async fn proto_invalid_json() {
        let parts = make_parts("application/json");
        let result = Proto::<TestUser>::from_request(&parts, Bytes::from("not json")).await;
        assert!(result.is_err());
        let (status, message) = result.unwrap_err();
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert!(message.starts_with("invalid JSON"));
    }

    #[tokio::test]
    async fn proto_invalid_binary() {
        let parts = make_parts("application/grpc+proto");
        let result =
            Proto::<TestUser>::from_request(&parts, Bytes::from_static(&[0xFF, 0xFF])).await;
        assert!(result.is_err());
        let (status, message) = result.unwrap_err();
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert!(message.starts_with("protobuf decode error"));
    }

    #[test]
    fn proto_into_response_json() {
        let response = Proto(TestUser {
            id: 42,
            name: "Alice".into(),
        })
        .into_response();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(
            response.headers().get("content-type").unwrap(),
            "application/json"
        );
    }

    #[test]
    fn proto_tuple_destructure() {
        let Proto(user) = Proto(TestUser {
            id: 1,
            name: "test".into(),
        });
        assert_eq!(user.id, 1);
    }

    #[test]
    fn proto_deref() {
        let p = Proto(TestUser {
            id: 1,
            name: "test".into(),
        });
        assert_eq!(p.id, 1);
        assert_eq!(p.name, "test");
    }

    #[test]
    fn proto_debug_and_clone() {
        let p = Proto(TestUser {
            id: 1,
            name: "test".into(),
        });
        let p2 = p.clone();
        assert_eq!(format!("{p:?}"), format!("{p2:?}"));
    }
}