1use crate::header::PROXY_AUTHENTICATE;
6use crate::headers::authorization::Authority;
7use crate::headers::{HeaderMapExt, ProxyAuthorization, authorization::Credentials};
8use crate::{Request, Response, StatusCode};
9use rama_core::{Context, Layer, Service};
10use rama_net::user::UserId;
11use rama_utils::macros::define_inner_service_accessors;
12use std::fmt;
13use std::marker::PhantomData;
14
/// Layer that wraps a service in a [`ProxyAuthService`], requiring proxy
/// credentials before requests reach the inner service.
///
/// Type parameters:
/// - `A`: the authority consulted to validate credentials.
/// - `C`: the credential type read from the `Proxy-Authorization` header.
/// - `L`: label type forwarded to the authority (`Authority<C, L>`); defaults to `()`.
pub struct ProxyAuthLayer<A, C, L = ()> {
    /// Authority used to check the credentials of each request.
    proxy_auth: A,
    /// When `true`, requests without any credentials are let through as anonymous.
    allow_anonymous: bool,
    // A fn-pointer phantom marks `C` and `L` as used without owning them, so the
    // layer's `Send`/`Sync` do not depend on `C` or `L`.
    _phantom: PhantomData<fn(C, L)>,
}
23
24impl<A: fmt::Debug, C, L> fmt::Debug for ProxyAuthLayer<A, C, L> {
25 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
26 f.debug_struct("ProxyAuthLayer")
27 .field("proxy_auth", &self.proxy_auth)
28 .field(
29 "_phantom",
30 &format_args!("{}", std::any::type_name::<fn(C, L) -> ()>()),
31 )
32 .finish()
33 }
34}
35
36impl<A: Clone, C, L> Clone for ProxyAuthLayer<A, C, L> {
37 fn clone(&self) -> Self {
38 Self {
39 proxy_auth: self.proxy_auth.clone(),
40 allow_anonymous: self.allow_anonymous,
41 _phantom: PhantomData,
42 }
43 }
44}
45
46impl<A, C> ProxyAuthLayer<A, C, ()> {
47 pub const fn new(proxy_auth: A) -> Self {
49 ProxyAuthLayer {
50 proxy_auth,
51 allow_anonymous: false,
52 _phantom: PhantomData,
53 }
54 }
55
56 pub fn set_allow_anonymous(&mut self, allow_anonymous: bool) -> &mut Self {
58 self.allow_anonymous = allow_anonymous;
59 self
60 }
61
62 pub fn with_allow_anonymous(mut self, allow_anonymous: bool) -> Self {
64 self.allow_anonymous = allow_anonymous;
65 self
66 }
67}
68
69impl<A, C, L> ProxyAuthLayer<A, C, L> {
70 pub fn with_labels<L2>(self) -> ProxyAuthLayer<A, C, L2> {
80 ProxyAuthLayer {
81 proxy_auth: self.proxy_auth,
82 allow_anonymous: self.allow_anonymous,
83 _phantom: PhantomData,
84 }
85 }
86}
87
88impl<A, C, L, S> Layer<S> for ProxyAuthLayer<A, C, L>
89where
90 A: Authority<C, L> + Clone,
91 C: Credentials + Clone + Send + Sync + 'static,
92{
93 type Service = ProxyAuthService<A, C, S, L>;
94
95 fn layer(&self, inner: S) -> Self::Service {
96 ProxyAuthService::new(self.proxy_auth.clone(), inner)
97 }
98
99 fn into_layer(self, inner: S) -> Self::Service {
100 ProxyAuthService::new(self.proxy_auth, inner)
101 }
102}
103
/// Middleware that checks proxy credentials before calling the inner service.
///
/// Type parameters:
/// - `A`: the authority consulted to validate credentials.
/// - `C`: the credential type read from the `Proxy-Authorization` header.
/// - `S`: the wrapped inner service.
/// - `L`: label type forwarded to the authority (`Authority<C, L>`); defaults to `()`.
pub struct ProxyAuthService<A, C, S, L = ()> {
    /// Authority used to check the credentials of each request.
    proxy_auth: A,
    /// When `true`, requests without any credentials are let through as anonymous.
    allow_anonymous: bool,
    /// Service that receives requests once they pass authentication.
    inner: S,
    // Fn-pointer phantom: marks `C`/`L` as used without affecting `Send`/`Sync`.
    _phantom: PhantomData<fn(C, L)>,
}
117
118impl<A, C, S, L> ProxyAuthService<A, C, S, L> {
119 pub const fn new(proxy_auth: A, inner: S) -> Self {
121 Self {
122 proxy_auth,
123 allow_anonymous: false,
124 inner,
125 _phantom: PhantomData,
126 }
127 }
128
129 pub fn set_allow_anonymous(&mut self, allow_anonymous: bool) -> &mut Self {
131 self.allow_anonymous = allow_anonymous;
132 self
133 }
134
135 pub fn with_allow_anonymous(mut self, allow_anonymous: bool) -> Self {
137 self.allow_anonymous = allow_anonymous;
138 self
139 }
140
141 define_inner_service_accessors!();
142}
143
144impl<A: fmt::Debug, C, S: fmt::Debug, L> fmt::Debug for ProxyAuthService<A, C, S, L> {
145 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
146 f.debug_struct("ProxyAuthService")
147 .field("proxy_auth", &self.proxy_auth)
148 .field("allow_anonymous", &self.allow_anonymous)
149 .field("inner", &self.inner)
150 .field(
151 "_phantom",
152 &format_args!("{}", std::any::type_name::<fn(C, L) -> ()>()),
153 )
154 .finish()
155 }
156}
157
158impl<A: Clone, C, S: Clone, L> Clone for ProxyAuthService<A, C, S, L> {
159 fn clone(&self) -> Self {
160 ProxyAuthService {
161 proxy_auth: self.proxy_auth.clone(),
162 allow_anonymous: self.allow_anonymous,
163 inner: self.inner.clone(),
164 _phantom: PhantomData,
165 }
166 }
167}
168
169impl<A, C, L, S, State, ReqBody, ResBody> Service<State, Request<ReqBody>>
170 for ProxyAuthService<A, C, S, L>
171where
172 A: Authority<C, L>,
173 C: Credentials + Clone + Send + Sync + 'static,
174 S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
175 L: 'static,
176 ReqBody: Send + 'static,
177 ResBody: Default + Send + 'static,
178 State: Clone + Send + Sync + 'static,
179{
180 type Response = S::Response;
181 type Error = S::Error;
182
183 async fn serve(
184 &self,
185 mut ctx: Context<State>,
186 req: Request<ReqBody>,
187 ) -> Result<Self::Response, Self::Error> {
188 if let Some(credentials) = req
189 .headers()
190 .typed_get::<ProxyAuthorization<C>>()
191 .map(|h| h.0)
192 .or_else(|| ctx.get::<C>().cloned())
193 {
194 if let Some(ext) = self.proxy_auth.authorized(credentials).await {
195 ctx.extend(ext);
196 self.inner.serve(ctx, req).await
197 } else {
198 Ok(Response::builder()
199 .status(StatusCode::PROXY_AUTHENTICATION_REQUIRED)
200 .header(PROXY_AUTHENTICATE, C::SCHEME)
201 .body(Default::default())
202 .unwrap())
203 }
204 } else if self.allow_anonymous {
205 ctx.insert(UserId::Anonymous);
206 self.inner.serve(ctx, req).await
207 } else {
208 Ok(Response::builder()
209 .status(StatusCode::PROXY_AUTHENTICATION_REQUIRED)
210 .header(PROXY_AUTHENTICATE, C::SCHEME)
211 .body(Default::default())
212 .unwrap())
213 }
214 }
215}