composable_tower_http/extension/modify/
service.rs1use std::{
2 future::Future,
3 marker::PhantomData,
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use http::Request;
9use tower::Service;
10
11use crate::{extract::SealedExtracted, modify::Modifier};
12
/// Tower service that replaces the request's `SealedExtracted<T>` extension with the output of a
/// `Modifier<T>` before forwarding the request to the wrapped `service`.
///
/// `Clone` and `Debug` are implemented by hand, so they only require `S` and `M` to implement
/// them. The marker type `T` is never stored and therefore needs neither.
pub struct ModificationService<S, M, T> {
    service: S,
    modifier: M,
    // `fn() -> T` keeps `T` as a pure type-level marker: unlike `PhantomData<T>`, it does not make
    // `Send`/`Sync`/`Unpin` of this service depend on `T`, and it is still covariant in `T`.
    _phantom: PhantomData<fn() -> T>,
}

impl<S: Clone, M: Clone, T> Clone for ModificationService<S, M, T> {
    fn clone(&self) -> Self {
        Self {
            service: self.service.clone(),
            modifier: self.modifier.clone(),
            _phantom: PhantomData,
        }
    }
}

impl<S: std::fmt::Debug, M: std::fmt::Debug, T> std::fmt::Debug for ModificationService<S, M, T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ModificationService")
            .field("service", &self.service)
            .field("modifier", &self.modifier)
            .finish_non_exhaustive()
    }
}
19
20impl<S, M, T> ModificationService<S, M, T> {
21 pub const fn new(service: S, modifier: M) -> Self {
22 Self {
23 service,
24 modifier,
25 _phantom: PhantomData,
26 }
27 }
28}
29
30impl<S, M, B, T> Service<Request<B>> for ModificationService<S, M, T>
31where
32 M: Modifier<T> + Clone + Send + 'static,
33 T: Send + Sync + 'static,
34 S: Service<Request<B>> + Clone + Send + 'static,
35 S::Future: Send,
36 S::Response: From<ModificationError<M::Error>>,
37 B: Send + 'static,
38{
39 type Response = S::Response;
40 type Error = S::Error;
41 type Future = Pin<Box<dyn Future<Output = Result<S::Response, S::Error>> + Send>>;
42
43 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
44 self.service.poll_ready(cx)
45 }
46
47 fn call(&mut self, mut request: Request<B>) -> Self::Future {
48 let mut service = self.service.clone();
49 let modifier = self.modifier.clone();
50
51 Box::pin(async move {
52 match request.extensions_mut().remove::<SealedExtracted<T>>() {
53 Some(SealedExtracted(extracted)) => {
54 match modifier.modify(extracted).await {
55 Ok(modified) => request.extensions_mut().insert(SealedExtracted(modified)),
56 Err(err) => return Ok(From::from(ModificationError::Modification(err))),
57 };
58 }
59 None => return Ok(From::from(ModificationError::Extract)),
60 }
61
62 service.call(request).await
63 })
64 }
65}
66
/// Reason `ModificationService` could not hand a modified value to the wrapped service.
///
/// `ModificationService::call` does not fail with this error. It converts it into the service's
/// response type via `From` and returns that response as `Ok`.
#[derive(Debug, thiserror::Error)]
pub enum ModificationError<E> {
    /// The request had no `SealedExtracted<T>` extension to modify.
    #[error("Extraction error")]
    Extract,
    /// The modifier returned an error. `E` is the modifier's error type (`M::Error`).
    // NOTE(review): the inner error is both interpolated into Display (`{0}`) and exposed via
    // `#[source]`, so reporters that walk the source chain will print it twice — confirm intent.
    #[error("Modification error: {0}")]
    Modification(#[source] E),
}
74
#[cfg(feature = "axum")]
mod axum {
    use axum::response::{IntoResponse, Response};
    use http::StatusCode;

    use super::ModificationError;

    /// Renders a `ModificationError` as an HTTP response.
    ///
    /// A missing extracted value becomes a bare `500 Internal Server Error`. A modifier failure
    /// is rendered by the modifier error's own `IntoResponse` implementation.
    impl<E: IntoResponse> IntoResponse for ModificationError<E> {
        fn into_response(self) -> Response {
            match self {
                Self::Modification(inner) => inner.into_response(),
                Self::Extract => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
            }
        }
    }

    /// Lets a service whose response type is axum's `Response` satisfy the
    /// `Response: From<ModificationError<_>>` bound, by delegating to `IntoResponse`.
    impl<E: IntoResponse> From<ModificationError<E>> for Response {
        fn from(error: ModificationError<E>) -> Self {
            IntoResponse::into_response(error)
        }
    }
}