beeline 0.1.0

Async Rust helpers for Foursquare/Swarm OAuth and latest checkin polling.
Documentation
use crate::{Error, Result};
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use chrono::Utc;
use hmac::{Hmac, KeyInit, Mac};
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use std::time::Duration;

/// HMAC-SHA256 instance used to sign and verify encoded [`LinkState`] tokens.
type HmacSha256 = Hmac<Sha256>;
/// Version tag that prefixes every encoded state token and is also fed into
/// the HMAC input, binding each signature to this format version.
const SIGNED_STATE_VERSION: &str = "v1";
/// Maximum number of seconds a state's `issued_at` may lie ahead of the local
/// clock before it is rejected as coming from the future (five minutes).
const MAX_FUTURE_STATE_SKEW_SECONDS: i64 = 5 * 60;

/// Payload carried inside a signed state token (see [`LinkState::encode`] and
/// [`LinkState::decode`]); presumably the OAuth `state` parameter of the
/// Foursquare/Swarm account-linking flow — confirm against callers.
///
/// Serialized to JSON with serde. Fields are emitted in declaration order, so
/// reordering them changes the bytes (and signature) of newly encoded tokens.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct LinkState {
    /// Discord user ID this state was issued for; always present in the payload.
    pub discord_user_id: u64,
    /// Optional guild ID: omitted from the JSON when `None`, and decoded as
    /// `None` when the key is absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub guild_id: Option<u64>,
    /// Optional channel ID: omitted from the JSON when `None`, and decoded as
    /// `None` when the key is absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub channel_id: Option<u64>,
    /// Caller-supplied nonce; stored and signed, but not checked by
    /// `decode` / `decode_with_max_age`.
    pub nonce: String,
    /// Issue time as a Unix timestamp in seconds (`LinkState::new` uses the
    /// current UTC time); checked by `decode_with_max_age`.
    pub issued_at: i64,
}

impl LinkState {
    /// Creates a state for `discord_user_id` carrying `nonce`, with no guild or
    /// channel and `issued_at` set to the current UTC Unix timestamp (seconds).
    pub fn new(discord_user_id: u64, nonce: impl Into<String>) -> Self {
        Self {
            discord_user_id,
            guild_id: None,
            channel_id: None,
            nonce: nonce.into(),
            issued_at: Utc::now().timestamp(),
        }
    }

    /// Returns this state with `guild_id` set to `Some(guild_id)`.
    pub fn with_guild_id(mut self, guild_id: u64) -> Self {
        self.guild_id = Some(guild_id);
        self
    }

    /// Returns this state with `channel_id` set to `Some(channel_id)`.
    pub fn with_channel_id(mut self, channel_id: u64) -> Self {
        self.channel_id = Some(channel_id);
        self
    }

    /// Returns this state with `issued_at` replaced (Unix timestamp, seconds).
    pub fn with_issued_at(mut self, issued_at: i64) -> Self {
        self.issued_at = issued_at;
        self
    }

    /// Serializes and signs this state as `v1.<payload>.<signature>`.
    ///
    /// `<payload>` is the JSON encoding of `self`. `<signature>` is an
    /// HMAC-SHA256 tag over `v1.<payload>`. Both are base64url without padding,
    /// whose alphabet contains no `.`, so the two dots are the only separators.
    ///
    /// # Errors
    ///
    /// Fails if JSON serialization fails. Also fails with `Error::InvalidState`
    /// if `signing_key` is empty.
    pub fn encode(&self, signing_key: impl AsRef<[u8]>) -> Result<String> {
        let bytes = serde_json::to_vec(self)?;
        let payload = URL_SAFE_NO_PAD.encode(bytes);
        let signature = sign_payload(&payload, signing_key.as_ref())?;
        Ok(format!(
            "{SIGNED_STATE_VERSION}.{payload}.{}",
            URL_SAFE_NO_PAD.encode(signature)
        ))
    }

    /// Verifies and parses a token produced by [`LinkState::encode`].
    ///
    /// The HMAC signature is checked before the payload is base64- or
    /// JSON-decoded, so unauthenticated bytes never reach the JSON parser.
    /// `issued_at` is NOT checked here; use [`LinkState::decode_with_max_age`]
    /// to also enforce freshness.
    ///
    /// # Errors
    ///
    /// * `Error::InvalidState` — wrong number of `.`-separated sections,
    ///   unsupported version, invalid base64, empty key, or signature mismatch.
    /// * `Error::Decode` — the authenticated payload is not valid `LinkState` JSON.
    pub fn decode(encoded: impl AsRef<str>, signing_key: impl AsRef<[u8]>) -> Result<Self> {
        let encoded = encoded.as_ref();
        let mut parts = encoded.split('.');
        // `split` always yields at least one item; this branch is purely defensive.
        let version = parts
            .next()
            .ok_or_else(|| Error::InvalidState("missing state version".to_string()))?;
        let payload = parts
            .next()
            .ok_or_else(|| Error::InvalidState("missing state payload".to_string()))?;
        let signature = parts
            .next()
            .ok_or_else(|| Error::InvalidState("missing state signature".to_string()))?;

        if parts.next().is_some() {
            return Err(Error::InvalidState("too many state sections".to_string()));
        }

        if version != SIGNED_STATE_VERSION {
            return Err(Error::InvalidState("unsupported state version".to_string()));
        }

        let signature = URL_SAFE_NO_PAD
            .decode(signature)
            .map_err(|err| Error::InvalidState(err.to_string()))?;
        verify_payload(payload, signing_key.as_ref(), &signature)?;
        Self::decode_payload(payload)
    }

    /// Like [`LinkState::decode`], but also enforces the age of `issued_at`.
    ///
    /// The state is rejected if `issued_at` is more than
    /// `MAX_FUTURE_STATE_SKEW_SECONDS` ahead of now. It is also rejected if it
    /// is older than `max_age`. Only whole seconds of `max_age` count; any
    /// `max_age` of `i64::MAX` seconds or more means no age limit in practice.
    ///
    /// # Errors
    ///
    /// Returns every error [`LinkState::decode`] can return. Also returns
    /// `Error::InvalidState` for a future-dated state and
    /// `Error::ExpiredState` for an over-age one.
    pub fn decode_with_max_age(
        encoded: impl AsRef<str>,
        signing_key: impl AsRef<[u8]>,
        max_age: Duration,
    ) -> Result<Self> {
        let state = Self::decode(encoded, signing_key)?;
        state.validate_max_age(max_age)
    }

    /// Base64url-decodes an already-authenticated payload and parses its JSON.
    fn decode_payload(payload: &str) -> Result<Self> {
        let bytes = URL_SAFE_NO_PAD
            .decode(payload)
            .map_err(|err| Error::InvalidState(err.to_string()))?;
        serde_json::from_slice(&bytes).map_err(Error::Decode)
    }

    /// Rejects states issued too far in the future or older than `max_age`.
    fn validate_max_age(self, max_age: Duration) -> Result<Self> {
        let now = Utc::now().timestamp();
        let max_age_seconds = max_age.as_secs();
        let future_skew_seconds = self.issued_at.saturating_sub(now);

        if future_skew_seconds > MAX_FUTURE_STATE_SKEW_SECONDS {
            return Err(Error::InvalidState(
                "state issued_at is too far in the future".to_string(),
            ));
        }

        // Compare in i128: every i64 difference and every u64 limit fit exactly.
        // The previous `max_age_seconds as i64` cast wrapped huge durations
        // (e.g. `Duration::MAX`) to a negative limit, rejecting every state.
        let age_seconds = i128::from(now) - i128::from(self.issued_at);
        if age_seconds > i128::from(max_age_seconds) {
            return Err(Error::ExpiredState {
                issued_at: self.issued_at,
                max_age_seconds,
            });
        }

        Ok(self)
    }
}

/// Computes the raw HMAC-SHA256 tag over `"<SIGNED_STATE_VERSION>.<payload>"`
/// using `signing_key`.
///
/// The version prefix is part of the MAC input, so the tag is bound to the
/// token format version as well as the payload. Fails if the key is empty.
fn sign_payload(payload: &str, signing_key: &[u8]) -> Result<Vec<u8>> {
    let mut hmac = state_mac(signing_key)?;
    let message: [&[u8]; 3] = [SIGNED_STATE_VERSION.as_bytes(), b".", payload.as_bytes()];
    for segment in message {
        hmac.update(segment);
    }
    let tag = hmac.finalize().into_bytes();
    Ok(tag.to_vec())
}

fn verify_payload(payload: &str, signing_key: &[u8], signature: &[u8]) -> Result<()> {
    let mut mac = state_mac(signing_key)?;
    mac.update(SIGNED_STATE_VERSION.as_bytes());
    mac.update(b".");
    mac.update(payload.as_bytes());
    mac.verify_slice(signature)
        .map_err(|_| Error::InvalidState("invalid state signature".to_string()))
}

fn state_mac(signing_key: &[u8]) -> Result<HmacSha256> {
    if signing_key.is_empty() {
        return Err(Error::InvalidState(
            "state signing key must not be empty".to_string(),
        ));
    }

    HmacSha256::new_from_slice(signing_key)
        .map_err(|_| Error::InvalidState("invalid state signing key".to_string()))
}