actix_jwt_session/
extractors.rs1use std::sync::Arc;
4
5use crate::*;
6
7#[derive(Clone, Debug)]
8pub struct Extractors<ClaimsType: Claims + std::fmt::Debug> {
9 pub(crate) jwt_extractors: Vec<Arc<dyn SessionExtractor<ClaimsType>>>,
10 pub(crate) refresh_extractors: Vec<Arc<dyn SessionExtractor<RefreshToken>>>,
11}
12
13impl<ClaimsType: Claims> Default for Extractors<ClaimsType> {
14 fn default() -> Self {
15 Self {
16 jwt_extractors: vec![],
17 refresh_extractors: vec![],
18 }
19 }
20}
21
22impl<ClaimsType: Claims> Extractors<ClaimsType> {
23 #[must_use]
25 pub fn with_refresh_cookie(mut self, name: &'static str) -> Self {
26 self.refresh_extractors
27 .push(Arc::new(CookieExtractor::<RefreshToken>::new(name)));
28 self
29 }
30
31 #[must_use]
33 pub fn with_refresh_header(mut self, name: &'static str) -> Self {
34 self.refresh_extractors
35 .push(Arc::new(HeaderExtractor::<RefreshToken>::new(name)));
36 self
37 }
38
39 #[must_use]
41 pub fn with_jwt_cookie(mut self, name: &'static str) -> Self {
42 self.jwt_extractors
43 .push(Arc::new(CookieExtractor::<ClaimsType>::new(name)));
44 self
45 }
46
47 #[must_use]
49 pub fn with_jwt_header(mut self, name: &'static str) -> Self {
50 self.jwt_extractors
51 .push(Arc::new(HeaderExtractor::<ClaimsType>::new(name)));
52 self
53 }
54}
55
/// The kind of request location a `SessionExtractor` reads its token from.
///
/// Derives `Eq` alongside `Hash` so the kind can be used as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExtractorKind {
    /// A request header.
    Header,
    /// A request cookie.
    Cookie,
    /// A URL parameter.
    UrlParam,
    /// A field of the request body.
    ReqBody,
}

64#[async_trait(?Send)]
78pub trait SessionExtractor<ClaimsType: Claims>: Send + Sync + 'static + std::fmt::Debug {
79 async fn extract_claims(
126 &self,
127 req: &mut ServiceRequest,
128 jwt_encoding_key: Arc<EncodingKey>,
129 jwt_decoding_key: Arc<DecodingKey>,
130 algorithm: Algorithm,
131 storage: SessionStorage,
132 ) -> Result<(), Error> {
133 let Some(as_str) = self.extract_token_text(req).await else {
134 return Ok(());
135 };
136 let decoded_claims = self.decode(&as_str, jwt_decoding_key, algorithm)?;
137 self.validate(&decoded_claims, storage).await?;
138 req.extensions_mut().insert(Authenticated {
139 claims: Arc::new(decoded_claims),
140 jwt_encoding_key,
141 algorithm,
142 });
143 Ok(())
144 }
145
146 fn extractor_key(&self) -> Option<(ExtractorKind, Cow<'static, str>)>;
147
148 fn decode(
150 &self,
151 value: &str,
152 jwt_decoding_key: Arc<DecodingKey>,
153 algorithm: Algorithm,
154 ) -> Result<ClaimsType, Error> {
155 let mut validation = Validation::new(algorithm);
156 validation.validate_exp = false;
157 validation.validate_nbf = false;
158 validation.leeway = 0;
159 validation.required_spec_claims.clear();
160
161 decode::<ClaimsType>(value, &jwt_decoding_key, &validation)
162 .map_err(|e| {
163 #[cfg(feature = "use-tracing")]
164 tracing::debug!("Failed to decode claims: {e:?}. {e}");
165 Error::CantDecode
166 })
167 .map(|t| t.claims)
168 }
169
170 async fn validate(&self, claims: &ClaimsType, storage: SessionStorage) -> Result<(), Error> {
175 let stored = storage
176 .clone()
177 .find_jwt::<ClaimsType>(claims.jti())
178 .await
179 .map_err(|e| {
180 #[cfg(feature = "use-tracing")]
181 tracing::debug!(
182 "Failed to load {} from storage: {e:?}",
183 std::any::type_name::<ClaimsType>()
184 );
185 Error::LoadError
186 })?;
187
188 if &stored != claims {
189 #[cfg(feature = "use-tracing")]
190 tracing::debug!("{claims:?} != {stored:?}");
191 Err(Error::DontMatch)
192 } else {
193 Ok(())
194 }
195 }
196
197 async fn extract_token_text<'req>(
203 &self,
204 req: &'req mut ServiceRequest,
205 ) -> Option<Cow<'req, str>>;
206}
207
208#[derive(Debug)]
215pub struct CookieExtractor<ClaimsType> {
216 __ty: PhantomData<ClaimsType>,
217 cookie_name: &'static str,
218}
219
220impl<ClaimsType: Claims> CookieExtractor<ClaimsType> {
221 pub fn new(cookie_name: &'static str) -> Self {
224 Self {
225 __ty: Default::default(),
226 cookie_name,
227 }
228 }
229}
230
231#[async_trait(?Send)]
232impl<ClaimsType: Claims> SessionExtractor<ClaimsType> for CookieExtractor<ClaimsType> {
233 async fn extract_token_text<'req>(
234 &self,
235 req: &'req mut ServiceRequest,
236 ) -> Option<Cow<'req, str>> {
237 req.cookie(self.cookie_name)
238 .map(|c| c.value().to_string().into())
239 }
240 fn extractor_key(&self) -> Option<(ExtractorKind, Cow<'static, str>)> {
241 Some((ExtractorKind::Cookie, self.cookie_name.into()))
242 }
243}
244
245#[derive(Debug)]
253pub struct HeaderExtractor<ClaimsType> {
254 __ty: PhantomData<ClaimsType>,
255 header_name: &'static str,
256}
257
258impl<ClaimsType: Claims> HeaderExtractor<ClaimsType> {
259 pub fn new(header_name: &'static str) -> Self {
262 Self {
263 __ty: Default::default(),
264 header_name,
265 }
266 }
267}
268
269#[async_trait(?Send)]
270impl<ClaimsType: Claims> SessionExtractor<ClaimsType> for HeaderExtractor<ClaimsType> {
271 async fn extract_token_text<'req>(
272 &self,
273 req: &'req mut ServiceRequest,
274 ) -> Option<Cow<'req, str>> {
275 req.headers()
276 .get(self.header_name)
277 .and_then(|h| h.to_str().ok())
278 .map(|h| h.to_owned().into())
279 }
280 fn extractor_key(&self) -> Option<(ExtractorKind, Cow<'static, str>)> {
281 Some((ExtractorKind::Header, self.header_name.into()))
282 }
283}
284
285#[derive(Debug)]
286pub struct JsonExtractor<ClaimsType> {
287 __ty: PhantomData<ClaimsType>,
288 path: &'static [&'static str],
290}
291
292impl<ClaimsType: Claims> JsonExtractor<ClaimsType> {
293 pub fn new(path: &'static [&'static str]) -> Self {
308 Self {
309 __ty: Default::default(),
310 path,
311 }
312 }
313}
314
315#[async_trait(?Send)]
316impl<ClaimsType: Claims> SessionExtractor<ClaimsType> for JsonExtractor<ClaimsType> {
317 async fn extract_token_text<'req>(
318 &self,
319 req: &'req mut ServiceRequest,
320 ) -> Option<Cow<'req, str>> {
321 let Ok(v) = req
322 .extract::<actix_web::web::Json<serde_json::Value>>()
323 .await
324 else {
325 return None;
326 };
327 let json = v.into_inner();
328 let mut v = &json;
329
330 let len = self.path.len();
331 self.path.iter().enumerate().fold(None, |_, (idx, piece)| {
332 if idx + 1 == len {
333 v.as_object()?
334 .get(*piece)?
335 .as_str()
336 .map(ToOwned::to_owned)
337 .map(Into::into)
338 } else {
339 v = v.as_object()?.get(*piece)?;
340 None
341 }
342 })
343 }
344 fn extractor_key(&self) -> Option<(ExtractorKind, Cow<'static, str>)> {
345 Some((ExtractorKind::ReqBody, self.path.join(".").into()))
346 }
347}