paseto_json/
lib.rs

1#![forbid(unsafe_code)]
2
3use std::error::Error;
4use std::io;
5
6#[cfg(feature = "claims")]
7pub use jiff;
8
9use paseto_core::encodings::{Footer, Payload, WriteBytes};
10pub use paseto_core::validation::Validate;
11use serde_core::Serialize;
12use serde_core::de::DeserializeOwned;
13
/// `Json` is a type wrapper that implements [`Footer`] and [`Payload`] for all types that implement
/// [`serde_core::Serialize`] and [`serde_core::de::DeserializeOwned`].
///
/// The wrapped value is public and can be accessed as `.0`.
///
/// When using a JSON footer, you should be aware of the risks of parsing user-provided JSON:
/// <https://github.com/paseto-standard/paseto-spec/blob/master/docs/02-Implementation-Guide/01-Payload-Processing.md#storing-json-in-the-footer>.
///
/// Currently, this uses [`serde_json`] internally, which by default offers a stack-overflow protection limit on parsing JSON.
/// You should also parse into a known struct layout, and avoid arbitrary key-value mappings.
///
/// If you need stricter checks, you can make your own [`Footer`] encodings that give access to the bytes before
/// the footer is decoded.
#[derive(Default)]
pub struct Json<T>(pub T);
27
28struct Writer<W: WriteBytes>(W);
29impl<W: WriteBytes> io::Write for Writer<W> {
30    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
31        self.0.write(buf);
32        Ok(buf.len())
33    }
34
35    fn flush(&mut self) -> io::Result<()> {
36        Ok(())
37    }
38}
39
40impl<T: Serialize + DeserializeOwned> Footer for Json<T> {
41    fn encode(&self, writer: impl WriteBytes) -> Result<(), Box<dyn Error + Send + Sync>> {
42        serde_json::to_writer(Writer(writer), &self.0).map_err(|err| Box::new(err) as _)
43    }
44
45    fn decode(footer: &[u8]) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
46        match footer {
47            [] => Err("missing footer".into()),
48            x => serde_json::from_slice(x).map(Self).map_err(|e| e.into()),
49        }
50    }
51}
52
53impl<M: Serialize + DeserializeOwned> Payload for Json<M> {
54    /// JSON is the standard payload and requires no version suffix
55    const SUFFIX: &'static str = "";
56
57    fn encode(self, writer: impl WriteBytes) -> Result<(), Box<dyn Error + Send + Sync>> {
58        serde_json::to_writer(Writer(writer), &self.0).map_err(|err| Box::new(err) as _)
59    }
60
61    fn decode(payload: &[u8]) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
62        serde_json::from_slice(payload)
63            .map_err(From::from)
64            .map(Self)
65    }
66}
67
/// The registered claims that can be carried in a token payload.
///
/// Every claim is optional. Field names match the JSON keys used when the
/// claims are serialized and deserialized.
#[cfg(feature = "claims")]
#[derive(Default, Clone, Debug)]
pub struct RegisteredClaims {
    /// Issuer of the token.
    pub iss: Option<String>,
    /// Subject of the token.
    pub sub: Option<String>,
    /// Audience the token is intended for.
    pub aud: Option<String>,
    /// Expiration time.
    pub exp: Option<jiff::Timestamp>,
    /// Not-before time.
    pub nbf: Option<jiff::Timestamp>,
    /// Issued-at time.
    pub iat: Option<jiff::Timestamp>,
    /// Token identifier.
    pub jti: Option<String>,
}
79
80#[cfg(feature = "claims")]
81pub use claims_impls::{ForAudience, ForSubject, FromIssuer, HasExpiry, Time, TimeWithLeeway};
82
83#[cfg(feature = "claims")]
84mod claims_impls {
85    use core::fmt;
86    use std::error::Error;
87    use std::time::Duration;
88
89    use paseto_core::{PasetoError, validation::Validate};
90    use paseto_core::{encodings::Payload, pae::WriteBytes};
91    use serde_core::{
92        Deserialize, Deserializer, Serializer,
93        de::{MapAccess, Visitor},
94        ser::SerializeStruct,
95    };
96
97    use crate::RegisteredClaims;
98    use crate::Writer;
99
100    pub struct Time {
101        now: jiff::Timestamp,
102    }
103
104    impl Time {
105        pub fn valid_now() -> Self {
106            Self {
107                now: jiff::Timestamp::now(),
108            }
109        }
110
111        pub fn valid_at(now: jiff::Timestamp) -> Self {
112            Self { now }
113        }
114
115        pub fn with_leeway(self, leeway: Duration) -> TimeWithLeeway {
116            TimeWithLeeway {
117                now: self.now,
118                leeway,
119            }
120        }
121    }
122
123    impl Validate for Time {
124        type Claims = RegisteredClaims;
125
126        fn validate(&self, claims: &Self::Claims) -> Result<(), PasetoError> {
127            if let Some(exp) = claims.exp
128                && exp < self.now
129            {
130                return Err(PasetoError::ClaimsError);
131            }
132
133            if let Some(nbf) = claims.nbf
134                && self.now < nbf
135            {
136                return Err(PasetoError::ClaimsError);
137            }
138
139            Ok(())
140        }
141    }
142
143    pub struct TimeWithLeeway {
144        now: jiff::Timestamp,
145        leeway: std::time::Duration,
146    }
147
148    impl Validate for TimeWithLeeway {
149        type Claims = RegisteredClaims;
150
151        fn validate(&self, claims: &Self::Claims) -> Result<(), PasetoError> {
152            if let Some(exp) = claims.exp
153                && exp < self.now - self.leeway
154            {
155                return Err(PasetoError::ClaimsError);
156            }
157
158            if let Some(nbf) = claims.nbf
159                && self.now + self.leeway < nbf
160            {
161                return Err(PasetoError::ClaimsError);
162            }
163
164            Ok(())
165        }
166    }
167
168    pub struct ForSubject<T: AsRef<str>>(pub T);
169
170    impl<T: AsRef<str>> Validate for ForSubject<T> {
171        type Claims = RegisteredClaims;
172
173        fn validate(&self, claims: &Self::Claims) -> Result<(), PasetoError> {
174            if claims.sub.as_deref() != Some(self.0.as_ref()) {
175                return Err(PasetoError::ClaimsError);
176            }
177
178            Ok(())
179        }
180    }
181
182    pub struct FromIssuer<T: AsRef<str>>(pub T);
183
184    impl<T: AsRef<str>> Validate for FromIssuer<T> {
185        type Claims = RegisteredClaims;
186
187        fn validate(&self, claims: &Self::Claims) -> Result<(), PasetoError> {
188            if claims.iss.as_deref() != Some(self.0.as_ref()) {
189                return Err(PasetoError::ClaimsError);
190            }
191
192            Ok(())
193        }
194    }
195
196    pub struct ForAudience<T: AsRef<str>>(pub T);
197
198    impl<T: AsRef<str>> Validate for ForAudience<T> {
199        type Claims = RegisteredClaims;
200
201        fn validate(&self, claims: &Self::Claims) -> Result<(), PasetoError> {
202            if claims.aud.as_deref() != Some(self.0.as_ref()) {
203                return Err(PasetoError::ClaimsError);
204            }
205
206            Ok(())
207        }
208    }
209
210    pub struct HasExpiry;
211
212    impl Validate for HasExpiry {
213        type Claims = RegisteredClaims;
214        fn validate(&self, claims: &Self::Claims) -> Result<(), PasetoError> {
215            if claims.exp.is_none() {
216                return Err(PasetoError::ClaimsError);
217            }
218            Ok(())
219        }
220    }
221
222    impl RegisteredClaims {
223        pub fn new(now: jiff::Timestamp, exp: Duration) -> Self {
224            Self {
225                iss: None,
226                sub: None,
227                aud: None,
228                exp: Some(now + exp),
229                nbf: Some(now),
230                iat: Some(now),
231                jti: None,
232            }
233        }
234
235        pub fn now(exp: Duration) -> Self {
236            Self::new(jiff::Timestamp::now(), exp)
237        }
238
239        pub fn from_issuer(mut self, iss: String) -> Self {
240            self.iss = Some(iss);
241            self
242        }
243
244        pub fn for_audience(mut self, aud: String) -> Self {
245            self.aud = Some(aud);
246            self
247        }
248
249        pub fn for_subject(mut self, sub: String) -> Self {
250            self.sub = Some(sub);
251            self
252        }
253
254        pub fn with_token_id(mut self, jti: String) -> Self {
255            self.jti = Some(jti);
256            self
257        }
258    }
259
260    impl Payload for RegisteredClaims {
261        /// JSON is the standard payload and requires no version suffix
262        const SUFFIX: &'static str = "";
263
264        fn encode(self, writer: impl WriteBytes) -> Result<(), Box<dyn Error + Send + Sync>> {
265            serde_json::to_writer(Writer(writer), &self).map_err(|err| Box::new(err) as _)
266        }
267
268        fn decode(payload: &[u8]) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
269            serde_json::from_slice(payload).map_err(From::from)
270        }
271    }
272
273    impl serde_core::Serialize for RegisteredClaims {
274        fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
275        where
276            S: Serializer,
277        {
278            let mut state = s.serialize_struct("RegisteredClaims", 7)?;
279            if let Some(x) = &self.iss {
280                state.serialize_field("iss", &x)?;
281            }
282            if let Some(x) = &self.sub {
283                state.serialize_field("sub", &x)?;
284            }
285            if let Some(x) = &self.aud {
286                state.serialize_field("aud", &x)?;
287            }
288            if let Some(x) = &self.exp {
289                state.serialize_field("exp", &x)?;
290            }
291            if let Some(x) = &self.nbf {
292                state.serialize_field("nbf", &x)?;
293            }
294            if let Some(x) = &self.iat {
295                state.serialize_field("iat", &x)?;
296            }
297            if let Some(x) = &self.jti {
298                state.serialize_field("jti", &x)?;
299            }
300            state.end()
301        }
302    }
303
304    enum RegisteredClaimField {
305        Issuer,
306        Subject,
307        Audience,
308        Expiration,
309        NotBefore,
310        IssuedAt,
311        TokenIdentifier,
312        Ignored,
313    }
314
315    struct RegisteredClaimFieldVisitor;
316
317    impl<'de> Visitor<'de> for RegisteredClaimFieldVisitor {
318        type Value = RegisteredClaimField;
319        fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
320            f.write_str("field identifier")
321        }
322
323        fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
324        where
325            E: serde_core::de::Error,
326        {
327            self.visit_bytes(v.as_bytes())
328        }
329
330        fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
331        where
332            E: serde_core::de::Error,
333        {
334            match v {
335                b"iss" => Ok(RegisteredClaimField::Issuer),
336                b"sub" => Ok(RegisteredClaimField::Subject),
337                b"aud" => Ok(RegisteredClaimField::Audience),
338                b"exp" => Ok(RegisteredClaimField::Expiration),
339                b"nbf" => Ok(RegisteredClaimField::NotBefore),
340                b"iat" => Ok(RegisteredClaimField::IssuedAt),
341                b"jti" => Ok(RegisteredClaimField::TokenIdentifier),
342                _ => Ok(RegisteredClaimField::Ignored),
343            }
344        }
345    }
346
347    impl<'de> Deserialize<'de> for RegisteredClaimField {
348        #[inline]
349        fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
350            d.deserialize_identifier(RegisteredClaimFieldVisitor)
351        }
352    }
353
354    struct RegisteredClaimsVisitor;
355
356    impl<'de> Visitor<'de> for RegisteredClaimsVisitor {
357        type Value = RegisteredClaims;
358        fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
359            f.write_str("struct RegisteredClaims")
360        }
361
362        #[inline]
363        fn visit_map<A: MapAccess<'de>>(self, mut map: A) -> Result<Self::Value, A::Error> {
364            let mut issuer: Option<String> = None;
365            let mut subject: Option<String> = None;
366            let mut audience: Option<String> = None;
367            let mut expiration: Option<jiff::Timestamp> = None;
368            let mut not_before: Option<jiff::Timestamp> = None;
369            let mut issued_at: Option<jiff::Timestamp> = None;
370            let mut token_identifier: Option<String> = None;
371            while let Some(key) = map.next_key()? {
372                match key {
373                    RegisteredClaimField::Issuer => {
374                        if issuer.is_some() {
375                            return Err(serde_core::de::Error::duplicate_field("iss"));
376                        }
377                        issuer = map.next_value()?;
378                    }
379                    RegisteredClaimField::Subject => {
380                        if subject.is_some() {
381                            return Err(serde_core::de::Error::duplicate_field("sub"));
382                        }
383                        subject = map.next_value()?;
384                    }
385                    RegisteredClaimField::Audience => {
386                        if audience.is_some() {
387                            return Err(serde_core::de::Error::duplicate_field("aud"));
388                        }
389                        audience = map.next_value()?;
390                    }
391                    RegisteredClaimField::Expiration => {
392                        if expiration.is_some() {
393                            return Err(serde_core::de::Error::duplicate_field("exp"));
394                        }
395                        expiration = map.next_value()?;
396                    }
397                    RegisteredClaimField::NotBefore => {
398                        if not_before.is_some() {
399                            return Err(serde_core::de::Error::duplicate_field("nbf"));
400                        }
401                        not_before = map.next_value()?;
402                    }
403                    RegisteredClaimField::IssuedAt => {
404                        if issued_at.is_some() {
405                            return Err(serde_core::de::Error::duplicate_field("iat"));
406                        }
407                        issued_at = map.next_value()?;
408                    }
409                    RegisteredClaimField::TokenIdentifier => {
410                        if token_identifier.is_some() {
411                            return Err(serde_core::de::Error::duplicate_field("jti"));
412                        }
413                        token_identifier = map.next_value()?;
414                    }
415                    _ => {
416                        map.next_value::<serde_core::de::IgnoredAny>()?;
417                    }
418                }
419            }
420            Ok(RegisteredClaims {
421                iss: issuer,
422                sub: subject,
423                aud: audience,
424                exp: expiration,
425                nbf: not_before,
426                iat: issued_at,
427                jti: token_identifier,
428            })
429        }
430    }
431
432    impl<'de> Deserialize<'de> for RegisteredClaims {
433        fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
434            const FIELDS: &[&str] = &["iss", "sub", "aud", "exp", "nbf", "iat", "jti"];
435            d.deserialize_struct("RegisteredClaims", FIELDS, RegisteredClaimsVisitor)
436        }
437    }
438}