use std::future::Future;
use axum::extract::{Request, State};
use axum::Router;
use http_body_util::BodyExt;
use crate::{malformed, serialize_proto_message, server, Result, TwirpErrorResponse};
/// Builder that collects the routes of one service into an axum router.
///
/// `S` is the service value: `build` installs it as the router state, and every handler
/// registered through `route` receives it as its first argument.
pub struct TwirpRouterBuilder<S> {
/// Service value that `build` installs as the router state.
service: S,
/// Path prefix under which `build` nests all registered routes
/// (presumably the service's fully-qualified name; confirm against callers).
fqn: &'static str,
/// Routes registered so far; the state is attached later, in `build`.
router: Router<S>,
}
impl<S> TwirpRouterBuilder<S>
where
S: Clone + Send + Sync + 'static,
{
pub fn new(fqn: &'static str, service: S) -> Self {
TwirpRouterBuilder {
service,
fqn,
router: Router::new(),
}
}
pub fn route<F, Fut, Req, Res>(self, url: &str, f: F) -> Self
where
F: Fn(S, http::Request<Req>) -> Fut + Clone + Sync + Send + 'static,
Fut: Future<Output = Result<http::Response<Res>, TwirpErrorResponse>> + Send,
Req: prost::Message + Default + serde::de::DeserializeOwned,
Res: prost::Message + Default + serde::Serialize,
{
TwirpRouterBuilder {
service: self.service,
fqn: self.fqn,
router: self.router.route(
url,
axum::routing::post(move |State(api): State<S>, req: Request| async move {
server::handle_request(api, req, f).await
}),
),
}
}
pub fn build(self) -> axum::Router {
Router::new().nest(
self.fqn,
self.router
.fallback(crate::server::not_found_handler)
.with_state(self.service),
)
}
}
/// Converts an outgoing `reqwest::Request` into an `http::Request<I>` by decoding its
/// body as the protobuf message `I`.
///
/// The resulting request always uses the `POST` method, keeps the original URL as its
/// URI and carries over all of the original headers.
///
/// # Errors
///
/// Returns a `malformed` error when the request has no body, when the body bytes do not
/// decode as `I`, or when the `http` request cannot be built (for example because the
/// URL is not accepted as an `http::Uri`). A failure while collecting the body is
/// propagated through `?`.
pub async fn decode_request<I>(mut req: reqwest::Request) -> Result<http::Request<I>>
where
    I: prost::Message + Default,
{
    let body = req
        .body_mut()
        .take()
        .ok_or_else(|| malformed("failed to read the request body"))?
        .collect()
        .await?
        .to_bytes();
    let data = I::decode(body).map_err(|e| malformed(format!("failed to decode request: {e}")))?;

    // Finish the builder before touching the headers: a builder that already holds an
    // error (e.g. a rejected URI) exposes no header map, so going through
    // `Builder::headers_mut` would force a panic instead of reporting the error here.
    let mut decoded = http::Request::builder()
        .method("POST")
        .uri(req.url().as_str())
        .body(data)
        .map_err(|e| malformed(format!("failed to build the request: {e}")))?;

    // `req` is not used again, so its headers can be moved rather than cloned.
    decoded
        .headers_mut()
        .extend(std::mem::take(req.headers_mut()));
    Ok(decoded)
}
/// Converts an `http::Response` carrying the protobuf message `O` into a
/// `reqwest::Response`.
///
/// The message is passed through `serialize_proto_message` to become the body, and the
/// `Content-Type` header is set to `application/protobuf`, replacing any existing value.
///
/// # Errors
///
/// Propagates the conversion error if the content-type value is not a valid header value.
pub fn encode_response<O>(resp: http::Response<O>) -> Result<reqwest::Response>
where
    O: prost::Message + Default,
{
    let content_type: http::HeaderValue = "application/protobuf".try_into()?;
    let mut encoded = resp.map(serialize_proto_message);
    encoded.headers_mut().insert("Content-Type", content_type);
    Ok(encoded.into())
}