use std::time::Duration;
use chrono::{DateTime, Utc};
use hmac::{Hmac, Mac};
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use thiserror::Error;
use url::Url;
use crate::{Blob, Variant, urlsafe_decode, urlsafe_encode};
/// HMAC keyed with SHA-256; used to sign and verify every token payload in this module.
type HmacSha256 = Hmac<Sha256>;
/// Errors produced while building or verifying signed URLs and tokens.
#[derive(Debug, Error)]
pub enum SignedUrlError {
    /// The base URL given to [`SignedUrlGenerator::new`] could not be parsed; holds the
    /// parser's error message.
    #[error("invalid url: {0}")]
    InvalidUrl(String),
    /// The token's signature does not match its payload under the configured secret.
    #[error("signature verification failed")]
    InvalidSignature,
    /// The token is missing or malformed (no `.` separator, undecodable halves, or claims that
    /// do not deserialize), or a payload could not be serialized.
    ///
    /// NOTE(review): this variant is also returned for failures unrelated to payload decoding
    /// (an unrepresentable expiry `Duration`, HMAC key setup) — consider a dedicated variant.
    #[error("invalid token payload")]
    InvalidPayload,
    /// The token verified, but the verification time is later than its expiry timestamp.
    #[error("signed url has expired")]
    Expired,
}
/// The resource a signed URL grants access to; serialized into the signed token claims and
/// returned by successful verification.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum SignedResource {
    /// A blob, identified by its storage key.
    Blob { key: String },
    /// A variant, identified by its key.
    Variant { key: String },
    /// A redirect target, stored as the string form of the target URL.
    Redirect { location: String },
}
/// The JSON payload that is HMAC-signed into a URL's `token` query parameter.
///
/// Field order determines the serialized byte layout of newly issued tokens.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SignedClaims {
    /// The resource the token grants access to.
    resource: SignedResource,
    /// Expiry as a Unix timestamp in whole seconds (UTC). Verification accepts the token up to
    /// and including this second.
    expires_at: i64,
}
/// Builds and verifies HMAC-SHA256 signed URLs for blobs, variants, and redirects.
///
/// The [`Debug`](std::fmt::Debug) output intentionally omits the signing secret so the key
/// cannot leak through `{:?}` formatting or logs.
#[derive(Clone)]
pub struct SignedUrlGenerator {
    /// URL onto which the signed `token` query parameter is appended.
    base_url: Url,
    /// HMAC-SHA256 key used to sign and verify tokens. Never printed.
    secret: Vec<u8>,
}

impl std::fmt::Debug for SignedUrlGenerator {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `finish_non_exhaustive` renders a trailing `..`, signalling that a field (the secret)
        // was deliberately left out.
        f.debug_struct("SignedUrlGenerator")
            .field("base_url", &self.base_url)
            .finish_non_exhaustive()
    }
}
impl SignedUrlGenerator {
pub fn new(
base_url: impl AsRef<str>,
secret: impl Into<Vec<u8>>,
) -> Result<Self, SignedUrlError> {
Ok(Self {
base_url: Url::parse(base_url.as_ref())
.map_err(|error| SignedUrlError::InvalidUrl(error.to_string()))?,
secret: secret.into(),
})
}
pub fn blob_url(&self, blob: &Blob, expires_in: Duration) -> Result<Url, SignedUrlError> {
self.signed_url(
SignedResource::Blob {
key: blob.key().to_owned(),
},
expires_in,
)
}
pub fn variant_url(
&self,
variant: &Variant,
expires_in: Duration,
) -> Result<Url, SignedUrlError> {
self.signed_url(
SignedResource::Variant {
key: variant.key().to_owned(),
},
expires_in,
)
}
pub fn redirect_url(
&self,
location: &Url,
expires_in: Duration,
) -> Result<Url, SignedUrlError> {
self.signed_url(
SignedResource::Redirect {
location: location.to_string(),
},
expires_in,
)
}
pub fn verify(&self, url: &Url) -> Result<SignedResource, SignedUrlError> {
self.verify_at(url, Utc::now())
}
pub fn verify_at(
&self,
url: &Url,
now: DateTime<Utc>,
) -> Result<SignedResource, SignedUrlError> {
let token = url
.query_pairs()
.find(|(key, _)| key == "token")
.map(|(_, value)| value.into_owned())
.ok_or(SignedUrlError::InvalidPayload)?;
let (payload, signature) = token
.split_once('.')
.ok_or(SignedUrlError::InvalidPayload)?;
let payload_bytes = urlsafe_decode(payload).map_err(|_| SignedUrlError::InvalidPayload)?;
let signature_bytes =
urlsafe_decode(signature).map_err(|_| SignedUrlError::InvalidPayload)?;
let expected = sign_bytes(&self.secret, &payload_bytes)?;
if expected != signature_bytes {
return Err(SignedUrlError::InvalidSignature);
}
let claims: SignedClaims =
serde_json::from_slice(&payload_bytes).map_err(|_| SignedUrlError::InvalidPayload)?;
if now.timestamp() > claims.expires_at {
return Err(SignedUrlError::Expired);
}
Ok(claims.resource)
}
fn signed_url(
&self,
resource: SignedResource,
expires_in: Duration,
) -> Result<Url, SignedUrlError> {
let expires_at = Utc::now()
+ chrono::Duration::from_std(expires_in).map_err(|_| SignedUrlError::InvalidPayload)?;
let claims = SignedClaims {
resource,
expires_at: expires_at.timestamp(),
};
let payload = serde_json::to_vec(&claims).map_err(|_| SignedUrlError::InvalidPayload)?;
let signature = sign_bytes(&self.secret, &payload)?;
let token = format!("{}.{}", urlsafe_encode(&payload), urlsafe_encode(signature));
let mut url = self.base_url.clone();
url.query_pairs_mut().append_pair("token", &token);
Ok(url)
}
}
/// Signs `payload` with HMAC-SHA256 under `secret` and returns a token shaped like
/// `<urlsafe(payload)>.<urlsafe(signature)>`.
///
/// # Errors
///
/// Propagates any error returned by [`sign_bytes`].
pub(crate) fn sign_payload(secret: &[u8], payload: &[u8]) -> Result<String, SignedUrlError> {
    let tag = sign_bytes(secret, payload)?;
    let (encoded_payload, encoded_tag) = (urlsafe_encode(payload), urlsafe_encode(tag));
    Ok(format!("{}.{}", encoded_payload, encoded_tag))
}
pub(crate) fn verify_payload(token: &str, secret: &[u8]) -> Result<Vec<u8>, SignedUrlError> {
let (payload, signature) = token
.split_once('.')
.ok_or(SignedUrlError::InvalidPayload)?;
let payload_bytes = urlsafe_decode(payload).map_err(|_| SignedUrlError::InvalidPayload)?;
let signature_bytes = urlsafe_decode(signature).map_err(|_| SignedUrlError::InvalidPayload)?;
let expected = sign_bytes(secret, &payload_bytes)?;
if expected != signature_bytes {
return Err(SignedUrlError::InvalidSignature);
}
Ok(payload_bytes)
}
/// Computes the raw HMAC-SHA256 tag of `payload` keyed by `secret`.
///
/// # Errors
///
/// Returns [`SignedUrlError::InvalidPayload`] if the HMAC cannot be keyed with `secret`.
fn sign_bytes(secret: &[u8], payload: &[u8]) -> Result<Vec<u8>, SignedUrlError> {
    let keyed = HmacSha256::new_from_slice(secret).map_err(|_| SignedUrlError::InvalidPayload)?;
    let tag = keyed.chain_update(payload).finalize().into_bytes();
    Ok(tag.to_vec())
}
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::*;

    /// Generator shared by the tests, keyed with `b"secret"`.
    fn generator() -> SignedUrlGenerator {
        SignedUrlGenerator::new("https://example.test/storage", b"secret".to_vec())
            .expect("generator should build")
    }

    /// Small in-memory blob used as a signing target.
    fn blob() -> Blob {
        Blob::create(
            Bytes::from_static(b"hello"),
            "hello.txt",
            None,
            Default::default(),
            "memory",
        )
        .expect("blob should build")
    }

    #[test]
    fn test_blob_url_round_trip_verification() {
        let generator = generator();
        let url = generator
            .blob_url(&blob(), Duration::from_secs(60))
            .expect("url should build");
        let resource = generator.verify(&url).expect("url should verify");
        assert!(matches!(resource, SignedResource::Blob { .. }));
    }

    #[test]
    fn test_variant_url_round_trip_verification() {
        let generator = generator();
        let variant = Variant::new(blob(), Default::default());
        let url = generator
            .variant_url(&variant, Duration::from_secs(60))
            .expect("url should build");
        let resource = generator.verify(&url).expect("url should verify");
        assert!(matches!(resource, SignedResource::Variant { .. }));
    }

    #[test]
    fn test_redirect_url_round_trip_verification() {
        let generator = generator();
        let location = Url::parse("https://cdn.example/files/1").expect("url should parse");
        let url = generator
            .redirect_url(&location, Duration::from_secs(60))
            .expect("url should build");
        let resource = generator.verify(&url).expect("url should verify");
        assert_eq!(
            resource,
            SignedResource::Redirect {
                location: location.to_string()
            }
        );
    }

    #[test]
    fn test_verify_rejects_expired_url() {
        let generator = generator();
        let url = generator
            .blob_url(&blob(), Duration::from_secs(1))
            .expect("url should build");
        let future = Utc::now() + chrono::Duration::seconds(2);
        let error = generator
            .verify_at(&url, future)
            .expect_err("url should be expired");
        assert!(matches!(error, SignedUrlError::Expired));
    }

    #[test]
    fn test_verify_rejects_tampered_token() {
        let generator = generator();
        let mut url = generator
            .blob_url(&blob(), Duration::from_secs(60))
            .expect("url should build");
        url.query_pairs_mut()
            .clear()
            .append_pair("token", "tampered");
        let error = generator.verify(&url).expect_err("url should fail");
        assert!(matches!(error, SignedUrlError::InvalidPayload));
    }

    #[test]
    fn test_verify_rejects_url_without_token() {
        let url = Url::parse("https://example.test/storage").expect("url should parse");
        let error = generator().verify(&url).expect_err("url should fail");
        assert!(matches!(error, SignedUrlError::InvalidPayload));
    }

    #[test]
    fn test_verify_rejects_url_signed_with_other_secret() {
        let url = generator()
            .blob_url(&blob(), Duration::from_secs(60))
            .expect("url should build");
        let other = SignedUrlGenerator::new("https://example.test/storage", b"other".to_vec())
            .expect("generator should build");
        let error = other.verify(&url).expect_err("url should fail");
        assert!(matches!(error, SignedUrlError::InvalidSignature));
    }

    #[test]
    fn test_sign_payload_and_verify_payload_round_trip() {
        let payload = b"hello";
        let token = sign_payload(b"secret", payload).expect("token should build");
        let decoded = verify_payload(&token, b"secret").expect("token should verify");
        assert_eq!(decoded, payload);
    }

    #[test]
    fn test_verify_payload_rejects_swapped_payload() {
        let token = sign_payload(b"secret", b"hello").expect("token should build");
        let (_, signature) = token
            .split_once('.')
            .expect("token should contain a separator");
        // Keep the genuine signature but pair it with a different (validly encoded) payload.
        let forged = format!("{}.{}", urlsafe_encode(&b"hellp"[..]), signature);
        let error = verify_payload(&forged, b"secret").expect_err("forged token should fail");
        assert!(matches!(error, SignedUrlError::InvalidSignature));
    }

    #[test]
    fn test_verify_payload_rejects_wrong_secret() {
        let token = sign_payload(b"secret", b"hello").expect("token should build");
        let error = verify_payload(&token, b"other").expect_err("token should fail");
        assert!(matches!(error, SignedUrlError::InvalidSignature));
    }
}