1#![forbid(unsafe_code)]
2
3use std::error::Error;
4use std::io;
5
6#[cfg(feature = "claims")]
7pub use jiff;
8
9use paseto_core::encodings::{Footer, Payload, WriteBytes};
10pub use paseto_core::validation::Validate;
11use serde_core::Serialize;
12use serde_core::de::DeserializeOwned;
13
/// Wrapper whose contents are (de)serialized as JSON.
///
/// Implements [`Footer`] and [`Payload`] by encoding the inner value with
/// `serde_json`. Derives the common standard traits (conditionally on `T`)
/// so wrapped values can be cloned, compared and debug-printed like the
/// `RegisteredClaims` type in this crate.
#[derive(Default, Clone, Debug, PartialEq, Eq)]
pub struct Json<T>(pub T);
27
28struct Writer<W: WriteBytes>(W);
29impl<W: WriteBytes> io::Write for Writer<W> {
30 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
31 self.0.write(buf);
32 Ok(buf.len())
33 }
34
35 fn flush(&mut self) -> io::Result<()> {
36 Ok(())
37 }
38}
39
40impl<T: Serialize + DeserializeOwned> Footer for Json<T> {
41 fn encode(&self, writer: impl WriteBytes) -> Result<(), Box<dyn Error + Send + Sync>> {
42 serde_json::to_writer(Writer(writer), &self.0).map_err(|err| Box::new(err) as _)
43 }
44
45 fn decode(footer: &[u8]) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
46 match footer {
47 [] => Err("missing footer".into()),
48 x => serde_json::from_slice(x).map(Self).map_err(|e| e.into()),
49 }
50 }
51}
52
53impl<M: Serialize + DeserializeOwned> Payload for Json<M> {
54 const SUFFIX: &'static str = "";
56
57 fn encode(self, writer: impl WriteBytes) -> Result<(), Box<dyn Error + Send + Sync>> {
58 serde_json::to_writer(Writer(writer), &self.0).map_err(|err| Box::new(err) as _)
59 }
60
61 fn decode(payload: &[u8]) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
62 serde_json::from_slice(payload)
63 .map_err(From::from)
64 .map(Self)
65 }
66}
67
/// The registered claims understood by this crate, all optional.
///
/// Serialized as a JSON object under the keys `iss`, `sub`, `aud`, `exp`,
/// `nbf`, `iat` and `jti`; claims that are `None` are omitted.
#[cfg(feature = "claims")]
#[derive(Default, Clone, Debug)]
pub struct RegisteredClaims {
    /// Issuer (`iss`).
    pub iss: Option<String>,
    /// Subject (`sub`).
    pub sub: Option<String>,
    /// Audience (`aud`), a single string.
    pub aud: Option<String>,
    /// Expiration time (`exp`).
    pub exp: Option<jiff::Timestamp>,
    /// Not-before time (`nbf`).
    pub nbf: Option<jiff::Timestamp>,
    /// Issued-at time (`iat`).
    pub iat: Option<jiff::Timestamp>,
    /// Token identifier (`jti`).
    pub jti: Option<String>,
}
79
80#[cfg(feature = "claims")]
81pub use claims_impls::{ForAudience, ForSubject, FromIssuer, HasExpiry, Time, TimeWithLeeway};
82
83#[cfg(feature = "claims")]
84mod claims_impls {
85 use core::fmt;
86 use std::error::Error;
87 use std::time::Duration;
88
89 use paseto_core::{PasetoError, validation::Validate};
90 use paseto_core::{encodings::Payload, pae::WriteBytes};
91 use serde_core::{
92 Deserialize, Deserializer, Serializer,
93 de::{MapAccess, Visitor},
94 ser::SerializeStruct,
95 };
96
97 use crate::RegisteredClaims;
98 use crate::Writer;
99
100 pub struct Time {
101 now: jiff::Timestamp,
102 }
103
104 impl Time {
105 pub fn valid_now() -> Self {
106 Self {
107 now: jiff::Timestamp::now(),
108 }
109 }
110
111 pub fn valid_at(now: jiff::Timestamp) -> Self {
112 Self { now }
113 }
114
115 pub fn with_leeway(self, leeway: Duration) -> TimeWithLeeway {
116 TimeWithLeeway {
117 now: self.now,
118 leeway,
119 }
120 }
121 }
122
123 impl Validate for Time {
124 type Claims = RegisteredClaims;
125
126 fn validate(&self, claims: &Self::Claims) -> Result<(), PasetoError> {
127 if let Some(exp) = claims.exp
128 && exp < self.now
129 {
130 return Err(PasetoError::ClaimsError);
131 }
132
133 if let Some(nbf) = claims.nbf
134 && self.now < nbf
135 {
136 return Err(PasetoError::ClaimsError);
137 }
138
139 Ok(())
140 }
141 }
142
143 pub struct TimeWithLeeway {
144 now: jiff::Timestamp,
145 leeway: std::time::Duration,
146 }
147
148 impl Validate for TimeWithLeeway {
149 type Claims = RegisteredClaims;
150
151 fn validate(&self, claims: &Self::Claims) -> Result<(), PasetoError> {
152 if let Some(exp) = claims.exp
153 && exp < self.now - self.leeway
154 {
155 return Err(PasetoError::ClaimsError);
156 }
157
158 if let Some(nbf) = claims.nbf
159 && self.now + self.leeway < nbf
160 {
161 return Err(PasetoError::ClaimsError);
162 }
163
164 Ok(())
165 }
166 }
167
168 pub struct ForSubject<T: AsRef<str>>(pub T);
169
170 impl<T: AsRef<str>> Validate for ForSubject<T> {
171 type Claims = RegisteredClaims;
172
173 fn validate(&self, claims: &Self::Claims) -> Result<(), PasetoError> {
174 if claims.sub.as_deref() != Some(self.0.as_ref()) {
175 return Err(PasetoError::ClaimsError);
176 }
177
178 Ok(())
179 }
180 }
181
182 pub struct FromIssuer<T: AsRef<str>>(pub T);
183
184 impl<T: AsRef<str>> Validate for FromIssuer<T> {
185 type Claims = RegisteredClaims;
186
187 fn validate(&self, claims: &Self::Claims) -> Result<(), PasetoError> {
188 if claims.iss.as_deref() != Some(self.0.as_ref()) {
189 return Err(PasetoError::ClaimsError);
190 }
191
192 Ok(())
193 }
194 }
195
196 pub struct ForAudience<T: AsRef<str>>(pub T);
197
198 impl<T: AsRef<str>> Validate for ForAudience<T> {
199 type Claims = RegisteredClaims;
200
201 fn validate(&self, claims: &Self::Claims) -> Result<(), PasetoError> {
202 if claims.aud.as_deref() != Some(self.0.as_ref()) {
203 return Err(PasetoError::ClaimsError);
204 }
205
206 Ok(())
207 }
208 }
209
210 pub struct HasExpiry;
211
212 impl Validate for HasExpiry {
213 type Claims = RegisteredClaims;
214 fn validate(&self, claims: &Self::Claims) -> Result<(), PasetoError> {
215 if claims.exp.is_none() {
216 return Err(PasetoError::ClaimsError);
217 }
218 Ok(())
219 }
220 }
221
222 impl RegisteredClaims {
223 pub fn new(now: jiff::Timestamp, exp: Duration) -> Self {
224 Self {
225 iss: None,
226 sub: None,
227 aud: None,
228 exp: Some(now + exp),
229 nbf: Some(now),
230 iat: Some(now),
231 jti: None,
232 }
233 }
234
235 pub fn now(exp: Duration) -> Self {
236 Self::new(jiff::Timestamp::now(), exp)
237 }
238
239 pub fn from_issuer(mut self, iss: String) -> Self {
240 self.iss = Some(iss);
241 self
242 }
243
244 pub fn for_audience(mut self, aud: String) -> Self {
245 self.aud = Some(aud);
246 self
247 }
248
249 pub fn for_subject(mut self, sub: String) -> Self {
250 self.sub = Some(sub);
251 self
252 }
253
254 pub fn with_token_id(mut self, jti: String) -> Self {
255 self.jti = Some(jti);
256 self
257 }
258 }
259
260 impl Payload for RegisteredClaims {
261 const SUFFIX: &'static str = "";
263
264 fn encode(self, writer: impl WriteBytes) -> Result<(), Box<dyn Error + Send + Sync>> {
265 serde_json::to_writer(Writer(writer), &self).map_err(|err| Box::new(err) as _)
266 }
267
268 fn decode(payload: &[u8]) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
269 serde_json::from_slice(payload).map_err(From::from)
270 }
271 }
272
273 impl serde_core::Serialize for RegisteredClaims {
274 fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
275 where
276 S: Serializer,
277 {
278 let mut state = s.serialize_struct("RegisteredClaims", 7)?;
279 if let Some(x) = &self.iss {
280 state.serialize_field("iss", &x)?;
281 }
282 if let Some(x) = &self.sub {
283 state.serialize_field("sub", &x)?;
284 }
285 if let Some(x) = &self.aud {
286 state.serialize_field("aud", &x)?;
287 }
288 if let Some(x) = &self.exp {
289 state.serialize_field("exp", &x)?;
290 }
291 if let Some(x) = &self.nbf {
292 state.serialize_field("nbf", &x)?;
293 }
294 if let Some(x) = &self.iat {
295 state.serialize_field("iat", &x)?;
296 }
297 if let Some(x) = &self.jti {
298 state.serialize_field("jti", &x)?;
299 }
300 state.end()
301 }
302 }
303
304 enum RegisteredClaimField {
305 Issuer,
306 Subject,
307 Audience,
308 Expiration,
309 NotBefore,
310 IssuedAt,
311 TokenIdentifier,
312 Ignored,
313 }
314
315 struct RegisteredClaimFieldVisitor;
316
317 impl<'de> Visitor<'de> for RegisteredClaimFieldVisitor {
318 type Value = RegisteredClaimField;
319 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
320 f.write_str("field identifier")
321 }
322
323 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
324 where
325 E: serde_core::de::Error,
326 {
327 self.visit_bytes(v.as_bytes())
328 }
329
330 fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
331 where
332 E: serde_core::de::Error,
333 {
334 match v {
335 b"iss" => Ok(RegisteredClaimField::Issuer),
336 b"sub" => Ok(RegisteredClaimField::Subject),
337 b"aud" => Ok(RegisteredClaimField::Audience),
338 b"exp" => Ok(RegisteredClaimField::Expiration),
339 b"nbf" => Ok(RegisteredClaimField::NotBefore),
340 b"iat" => Ok(RegisteredClaimField::IssuedAt),
341 b"jti" => Ok(RegisteredClaimField::TokenIdentifier),
342 _ => Ok(RegisteredClaimField::Ignored),
343 }
344 }
345 }
346
347 impl<'de> Deserialize<'de> for RegisteredClaimField {
348 #[inline]
349 fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
350 d.deserialize_identifier(RegisteredClaimFieldVisitor)
351 }
352 }
353
354 struct RegisteredClaimsVisitor;
355
356 impl<'de> Visitor<'de> for RegisteredClaimsVisitor {
357 type Value = RegisteredClaims;
358 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
359 f.write_str("struct RegisteredClaims")
360 }
361
362 #[inline]
363 fn visit_map<A: MapAccess<'de>>(self, mut map: A) -> Result<Self::Value, A::Error> {
364 let mut issuer: Option<String> = None;
365 let mut subject: Option<String> = None;
366 let mut audience: Option<String> = None;
367 let mut expiration: Option<jiff::Timestamp> = None;
368 let mut not_before: Option<jiff::Timestamp> = None;
369 let mut issued_at: Option<jiff::Timestamp> = None;
370 let mut token_identifier: Option<String> = None;
371 while let Some(key) = map.next_key()? {
372 match key {
373 RegisteredClaimField::Issuer => {
374 if issuer.is_some() {
375 return Err(serde_core::de::Error::duplicate_field("iss"));
376 }
377 issuer = map.next_value()?;
378 }
379 RegisteredClaimField::Subject => {
380 if subject.is_some() {
381 return Err(serde_core::de::Error::duplicate_field("sub"));
382 }
383 subject = map.next_value()?;
384 }
385 RegisteredClaimField::Audience => {
386 if audience.is_some() {
387 return Err(serde_core::de::Error::duplicate_field("aud"));
388 }
389 audience = map.next_value()?;
390 }
391 RegisteredClaimField::Expiration => {
392 if expiration.is_some() {
393 return Err(serde_core::de::Error::duplicate_field("exp"));
394 }
395 expiration = map.next_value()?;
396 }
397 RegisteredClaimField::NotBefore => {
398 if not_before.is_some() {
399 return Err(serde_core::de::Error::duplicate_field("nbf"));
400 }
401 not_before = map.next_value()?;
402 }
403 RegisteredClaimField::IssuedAt => {
404 if issued_at.is_some() {
405 return Err(serde_core::de::Error::duplicate_field("iat"));
406 }
407 issued_at = map.next_value()?;
408 }
409 RegisteredClaimField::TokenIdentifier => {
410 if token_identifier.is_some() {
411 return Err(serde_core::de::Error::duplicate_field("jti"));
412 }
413 token_identifier = map.next_value()?;
414 }
415 _ => {
416 map.next_value::<serde_core::de::IgnoredAny>()?;
417 }
418 }
419 }
420 Ok(RegisteredClaims {
421 iss: issuer,
422 sub: subject,
423 aud: audience,
424 exp: expiration,
425 nbf: not_before,
426 iat: issued_at,
427 jti: token_identifier,
428 })
429 }
430 }
431
432 impl<'de> Deserialize<'de> for RegisteredClaims {
433 fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
434 const FIELDS: &[&str] = &["iss", "sub", "aud", "exp", "nbf", "iat", "jti"];
435 d.deserialize_struct("RegisteredClaims", FIELDS, RegisteredClaimsVisitor)
436 }
437 }
438}