rustrails-storage 0.1.2

File storage (ActiveStorage equivalent)
Documentation
//! Signed URL generation for blobs and variants.

use std::time::Duration;

use chrono::{DateTime, Utc};
use hmac::{Hmac, Mac};
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use thiserror::Error;
use url::Url;

use crate::{Blob, Variant, urlsafe_decode, urlsafe_encode};

type HmacSha256 = Hmac<Sha256>;

/// Errors returned by signed URL generation and verification.
///
/// The variants are deliberately coarse: verification failures do not reveal
/// which decoding step rejected a token.
#[derive(Debug, Error)]
pub enum SignedUrlError {
    /// The supplied base URL was invalid.
    ///
    /// Produced by [`SignedUrlGenerator::new`]; the payload is the parser's
    /// error rendered as a string.
    #[error("invalid url: {0}")]
    InvalidUrl(String),
    /// The signature could not be verified.
    ///
    /// The HMAC recomputed over the decoded payload did not match the
    /// signature carried in the token.
    #[error("signature verification failed")]
    InvalidSignature,
    /// The token payload could not be decoded.
    ///
    /// Covers a missing `token` query parameter, a token without a `.`
    /// separator, undecodable payload or signature segments, and claims that
    /// do not deserialize. It is also reused on the signing side when the
    /// claims cannot be serialized, the requested lifetime cannot be
    /// represented, or the HMAC key is rejected.
    #[error("invalid token payload")]
    InvalidPayload,
    /// The signed URL has expired.
    ///
    /// The verification instant is later than the expiry stored in the claims.
    #[error("signed url has expired")]
    Expired,
}

/// Resource extracted from a verified signed URL.
///
/// The value is serialized into the signed token claims when a URL is
/// generated and handed back unchanged once verification succeeds.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum SignedResource {
    /// A blob download URL.
    Blob {
        /// Key of the blob, as returned by `Blob::key` at signing time.
        key: String,
    },
    /// A processed variant download URL.
    Variant {
        /// Key of the variant, as returned by `Variant::key` at signing time.
        key: String,
    },
    /// A redirect URL.
    Redirect {
        /// String form of the redirect target URL supplied at signing time.
        location: String,
    },
}

/// Claims carried inside a signed token.
///
/// The struct is serialized to JSON, and that byte string is what gets signed
/// and embedded in the token.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SignedClaims {
    /// The resource the token grants access to.
    resource: SignedResource,
    /// Expiry as a Unix timestamp in seconds; verification fails once the
    /// verification instant's timestamp is greater than this value.
    expires_at: i64,
}

/// Signs and verifies blob and variant URLs.
///
/// Holds the base URL that tokens are attached to and the HMAC-SHA256 key
/// used to sign them. `Debug` is implemented by hand so the key never ends up
/// in logs or panic messages.
#[derive(Clone)]
pub struct SignedUrlGenerator {
    /// URL that the `token` query parameter is appended to.
    base_url: Url,
    /// Secret key used to sign and verify tokens.
    secret: Vec<u8>,
}

impl std::fmt::Debug for SignedUrlGenerator {
    /// Formats the generator with the secret redacted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SignedUrlGenerator")
            .field("base_url", &self.base_url)
            // Never print key material; a derived impl would dump every byte.
            .field("secret", &format_args!("<redacted>"))
            .finish()
    }
}

impl SignedUrlGenerator {
    /// Creates a new URL signer.
    ///
    /// `base_url` is the URL every signed token is attached to, and `secret`
    /// is the HMAC-SHA256 key used for signing and verification.
    ///
    /// # Errors
    ///
    /// Returns [`SignedUrlError::InvalidUrl`] when the base URL cannot be parsed.
    pub fn new(
        base_url: impl AsRef<str>,
        secret: impl Into<Vec<u8>>,
    ) -> Result<Self, SignedUrlError> {
        Ok(Self {
            base_url: Url::parse(base_url.as_ref())
                .map_err(|error| SignedUrlError::InvalidUrl(error.to_string()))?,
            secret: secret.into(),
        })
    }

    /// Generates a signed URL for a blob.
    ///
    /// The URL stays valid for `expires_in` from the moment of the call.
    ///
    /// # Errors
    ///
    /// Returns an error when the expiry cannot be represented or the token
    /// cannot be built.
    pub fn blob_url(&self, blob: &Blob, expires_in: Duration) -> Result<Url, SignedUrlError> {
        self.signed_url(
            SignedResource::Blob {
                key: blob.key().to_owned(),
            },
            expires_in,
        )
    }

    /// Generates a signed URL for a variant.
    ///
    /// The URL stays valid for `expires_in` from the moment of the call.
    ///
    /// # Errors
    ///
    /// Returns an error when the expiry cannot be represented or the token
    /// cannot be built.
    pub fn variant_url(
        &self,
        variant: &Variant,
        expires_in: Duration,
    ) -> Result<Url, SignedUrlError> {
        self.signed_url(
            SignedResource::Variant {
                key: variant.key().to_owned(),
            },
            expires_in,
        )
    }

    /// Generates a signed redirect URL.
    ///
    /// The string form of `location` is stored in the signed claims and
    /// returned by verification as [`SignedResource::Redirect`].
    ///
    /// # Errors
    ///
    /// Returns an error when the expiry cannot be represented or the token
    /// cannot be built.
    pub fn redirect_url(
        &self,
        location: &Url,
        expires_in: Duration,
    ) -> Result<Url, SignedUrlError> {
        self.signed_url(
            SignedResource::Redirect {
                location: location.to_string(),
            },
            expires_in,
        )
    }

    /// Verifies a signed URL at the current time.
    ///
    /// # Errors
    ///
    /// Returns an error when the signature is invalid, the payload cannot be decoded, or the URL has expired.
    pub fn verify(&self, url: &Url) -> Result<SignedResource, SignedUrlError> {
        self.verify_at(url, Utc::now())
    }

    /// Verifies a signed URL at the supplied instant.
    ///
    /// Only the first `token` query parameter is inspected. The signature is
    /// checked before the claims are parsed, so unauthenticated bytes never
    /// reach the JSON deserializer. A URL is still accepted while `now` falls
    /// in the same second as its expiry timestamp.
    ///
    /// # Errors
    ///
    /// Returns an error when the signature is invalid, the payload cannot be decoded, or the URL has expired.
    pub fn verify_at(
        &self,
        url: &Url,
        now: DateTime<Utc>,
    ) -> Result<SignedResource, SignedUrlError> {
        let token = url
            .query_pairs()
            .find(|(key, _)| key == "token")
            .map(|(_, value)| value.into_owned())
            .ok_or(SignedUrlError::InvalidPayload)?;
        let (payload, signature) = token
            .split_once('.')
            .ok_or(SignedUrlError::InvalidPayload)?;
        let payload_bytes = urlsafe_decode(payload).map_err(|_| SignedUrlError::InvalidPayload)?;
        let signature_bytes =
            urlsafe_decode(signature).map_err(|_| SignedUrlError::InvalidPayload)?;

        // Let the MAC implementation compare the tags: a plain `==` on byte
        // vectors stops at the first differing byte, which leaks how much of a
        // forged signature was correct through response timing.
        let mut mac = HmacSha256::new_from_slice(&self.secret)
            .map_err(|_| SignedUrlError::InvalidPayload)?;
        mac.update(&payload_bytes);
        mac.verify_slice(&signature_bytes)
            .map_err(|_| SignedUrlError::InvalidSignature)?;

        let claims: SignedClaims =
            serde_json::from_slice(&payload_bytes).map_err(|_| SignedUrlError::InvalidPayload)?;
        if now.timestamp() > claims.expires_at {
            return Err(SignedUrlError::Expired);
        }
        Ok(claims.resource)
    }

    /// Builds the signed URL for `resource`, expiring `expires_in` from now.
    ///
    /// The claims are serialized to JSON, signed, and attached to a copy of
    /// the base URL as the `token` query parameter.
    fn signed_url(
        &self,
        resource: SignedResource,
        expires_in: Duration,
    ) -> Result<Url, SignedUrlError> {
        let lifetime =
            chrono::Duration::from_std(expires_in).map_err(|_| SignedUrlError::InvalidPayload)?;
        // `DateTime + Duration` panics on overflow, and a lifetime can be small
        // enough to convert yet still push the expiry past the representable
        // range. Report that as an error instead of aborting the caller.
        let expires_at = Utc::now()
            .checked_add_signed(lifetime)
            .ok_or(SignedUrlError::InvalidPayload)?;
        let claims = SignedClaims {
            resource,
            expires_at: expires_at.timestamp(),
        };
        let payload = serde_json::to_vec(&claims).map_err(|_| SignedUrlError::InvalidPayload)?;
        // Reuse the shared helper so URL tokens and raw payload tokens cannot
        // drift apart in format.
        let token = sign_payload(&self.secret, &payload)?;
        let mut url = self.base_url.clone();
        url.query_pairs_mut().append_pair("token", &token);
        Ok(url)
    }
}

/// Signs `payload` with `secret` and returns a `payload.signature` token.
///
/// Both segments are passed through `urlsafe_encode` and joined with a `.`.
///
/// # Errors
///
/// Propagates any error reported while computing the signature.
pub(crate) fn sign_payload(secret: &[u8], payload: &[u8]) -> Result<String, SignedUrlError> {
    let mac_bytes = sign_bytes(secret, payload)?;
    let encoded_payload = urlsafe_encode(payload);
    let encoded_signature = urlsafe_encode(mac_bytes);
    Ok(format!("{encoded_payload}.{encoded_signature}"))
}

pub(crate) fn verify_payload(token: &str, secret: &[u8]) -> Result<Vec<u8>, SignedUrlError> {
    let (payload, signature) = token
        .split_once('.')
        .ok_or(SignedUrlError::InvalidPayload)?;
    let payload_bytes = urlsafe_decode(payload).map_err(|_| SignedUrlError::InvalidPayload)?;
    let signature_bytes = urlsafe_decode(signature).map_err(|_| SignedUrlError::InvalidPayload)?;
    let expected = sign_bytes(secret, &payload_bytes)?;
    if expected != signature_bytes {
        return Err(SignedUrlError::InvalidSignature);
    }
    Ok(payload_bytes)
}

/// Computes the raw HMAC-SHA256 tag of `payload` keyed with `secret`.
///
/// # Errors
///
/// Returns [`SignedUrlError::InvalidPayload`] when the MAC cannot be
/// initialised from `secret`.
fn sign_bytes(secret: &[u8], payload: &[u8]) -> Result<Vec<u8>, SignedUrlError> {
    match HmacSha256::new_from_slice(secret) {
        Ok(mut hasher) => {
            hasher.update(payload);
            let tag = hasher.finalize().into_bytes();
            Ok(tag.to_vec())
        }
        Err(_) => Err(SignedUrlError::InvalidPayload),
    }
}

#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::*;

    /// Builds a generator with a fixed base URL and secret.
    fn generator() -> SignedUrlGenerator {
        SignedUrlGenerator::new("https://example.test/storage", b"secret".to_vec())
            .expect("generator should build")
    }

    /// Builds a small in-memory blob to sign URLs for.
    fn blob() -> Blob {
        Blob::create(
            Bytes::from_static(b"hello"),
            "hello.txt",
            None,
            Default::default(),
            "memory",
        )
        .expect("blob should build")
    }

    #[test]
    fn test_blob_url_round_trip_verification() {
        let generator = generator();
        let url = generator
            .blob_url(&blob(), Duration::from_secs(60))
            .expect("url should build");
        let resource = generator.verify(&url).expect("url should verify");
        assert!(matches!(resource, SignedResource::Blob { .. }));
    }

    #[test]
    fn test_variant_url_round_trip_verification() {
        let generator = generator();
        let variant = Variant::new(blob(), Default::default());
        let url = generator
            .variant_url(&variant, Duration::from_secs(60))
            .expect("url should build");
        let resource = generator.verify(&url).expect("url should verify");
        assert!(matches!(resource, SignedResource::Variant { .. }));
    }

    #[test]
    fn test_redirect_url_round_trip_verification() {
        let generator = generator();
        let location = Url::parse("https://cdn.example/files/1").expect("url should parse");
        let url = generator
            .redirect_url(&location, Duration::from_secs(60))
            .expect("url should build");
        let resource = generator.verify(&url).expect("url should verify");
        assert_eq!(
            resource,
            SignedResource::Redirect {
                location: location.to_string()
            }
        );
    }

    #[test]
    fn test_verify_rejects_expired_url() {
        let generator = generator();
        let url = generator
            .blob_url(&blob(), Duration::from_secs(1))
            .expect("url should build");
        let future = Utc::now() + chrono::Duration::seconds(2);
        let error = generator
            .verify_at(&url, future)
            .expect_err("url should be expired");
        assert!(matches!(error, SignedUrlError::Expired));
    }

    // The replacement token has no `.` separator, so this covers the
    // malformed-token path rather than a signature mismatch.
    #[test]
    fn test_verify_rejects_tampered_token() {
        let generator = generator();
        let mut url = generator
            .blob_url(&blob(), Duration::from_secs(60))
            .expect("url should build");
        url.query_pairs_mut()
            .clear()
            .append_pair("token", "tampered");
        let error = generator.verify(&url).expect_err("url should fail");
        assert!(matches!(error, SignedUrlError::InvalidPayload));
    }

    // A well-formed token signed with a different key must be rejected with
    // `InvalidSignature`, not merely fail to decode.
    #[test]
    fn test_verify_rejects_url_signed_with_other_secret() {
        let other = SignedUrlGenerator::new("https://example.test/storage", b"other".to_vec())
            .expect("generator should build");
        let url = other
            .blob_url(&blob(), Duration::from_secs(60))
            .expect("url should build");
        let error = generator().verify(&url).expect_err("url should fail");
        assert!(matches!(error, SignedUrlError::InvalidSignature));
    }

    #[test]
    fn test_sign_payload_and_verify_payload_round_trip() {
        let payload = b"hello";
        let token = sign_payload(b"secret", payload).expect("token should build");
        let decoded = verify_payload(&token, b"secret").expect("token should verify");
        assert_eq!(decoded, payload);
    }

    #[test]
    fn test_verify_payload_rejects_wrong_secret() {
        let token = sign_payload(b"secret", b"hello").expect("token should build");
        let error = verify_payload(&token, b"other").expect_err("token should fail");
        assert!(matches!(error, SignedUrlError::InvalidSignature));
    }
}