use crate::error::AnnilError;
use crate::extractor::auth::AuthExtractor;
use crate::extractor::track::TrackIdentifier;
use crate::state::AnnilKeys;
use async_trait::async_trait;
use axum::extract::FromRequestParts;
use axum::http::request::Parts;
use axum::Extension;
use jwt_simple::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::num::NonZeroU8;
use std::sync::Arc;
use uuid::Uuid;
/// Custom claim payload carried inside an annil JWT.
///
/// Serialized with serde's internally tagged representation. The variant name,
/// lowercased, is written into a `"type"` field alongside the variant's own fields
/// (`"type": "user"` or `"type": "share"`).
#[derive(Serialize, Deserialize, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum AnnilClaim {
    /// Claim identifying a user; `can_fetch` grants it access to every track.
    User(UserClaim),
    /// Share (guest) claim; `can_fetch` restricts it to the tracks listed in its `audios` map.
    Share(ShareClaim),
}
/// Claim payload for a user token.
///
/// NOTE: field order and the serde attributes shape the serialized JSON payload,
/// which `test_sign` pins byte-for-byte. Reorder or rename with care.
#[derive(Serialize, Deserialize, Clone)]
pub struct UserClaim {
    /// Identifier of the user this claim belongs to.
    pub(crate) user_id: String,
    /// Optional share credentials attached to this user.
    /// Omitted from the serialized payload when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) share: Option<ShareToken>,
}
/// Share credentials that may be embedded in a [`UserClaim`].
///
/// NOTE(review): how these fields are consumed is not visible in this file. They
/// presumably let the user issue share tokens; confirm against callers.
#[derive(Serialize, Deserialize, Clone)]
pub struct ShareToken {
    /// Key identifier (presumably matched against a share key's `kid` — verify).
    pub(crate) key_id: String,
    /// Secret material for the share key; avoid logging or exposing it.
    pub(crate) secret: String,
    /// Optional list of UUIDs, omitted from the serialized payload when `None`.
    /// NOTE(review): presumably the albums this share may cover — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) allowed: Option<Vec<Uuid>>,
}
/// Claim payload for a share (guest) token.
#[derive(Serialize, Deserialize, Clone)]
pub struct ShareClaim {
    /// Tracks this claim may fetch, as used by `AnnilClaim::can_fetch`:
    /// - outer key: album id (string form);
    /// - inner key: disc id (string form);
    /// - value: the permitted track ids on that disc.
    pub(crate) audios: HashMap<String, HashMap<String, Vec<NonZeroU8>>>,
}
#[async_trait]
impl<S> FromRequestParts<S> for AnnilClaim
where
S: Send + Sync,
{
type Rejection = AnnilError;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let AuthExtractor(auth) = AuthExtractor::from_request_parts(parts, state).await?;
let keys = Extension::<Arc<AnnilKeys>>::from_request_parts(parts, state)
.await
.expect("Failed to extract keys from extension. Please re-check your code first.");
let metadata = Token::decode_metadata(&auth).map_err(|_| AnnilError::Unauthorized)?;
match metadata.key_id() {
None => {
if let Ok(token) = keys.sign_key.verify_token::<AnnilClaim>(&auth, None) {
return Ok(token.custom);
}
}
Some(_) => {
if let Ok(token) = keys.share_key.verify_token::<AnnilClaim>(
&auth,
Some(VerificationOptions {
required_key_id: Some(
keys.share_key.key_id().as_deref().unwrap().to_string(),
),
..Default::default()
}),
) {
if token.custom.is_guest() {
return Ok(token.custom);
}
}
}
}
Err(AnnilError::Unauthorized)
}
}
impl AnnilClaim {
    /// Returns whether this claim grants access to `track`.
    ///
    /// User claims may fetch any track. Share claims may only fetch tracks listed in
    /// their `audios` map. The lookup uses the string forms of the track's album id and
    /// disc id, then checks whether the track id is in the permitted list.
    pub(crate) fn can_fetch(&self, track: &TrackIdentifier) -> bool {
        match self {
            AnnilClaim::User(_) => true,
            AnnilClaim::Share(share) => share
                .audios
                .get(&track.album_id.to_string())
                .and_then(|album| album.get(&track.disc_id.to_string()))
                .map_or(false, |disc| disc.contains(&track.track_id)),
        }
    }

    /// Returns `true` if this is a share (guest) claim rather than a user claim.
    #[inline]
    pub(crate) fn is_guest(&self) -> bool {
        matches!(self, AnnilClaim::Share(_))
    }
}
/// Signing a fixed user claim with a fixed HS256 key must yield a known token.
///
/// This pins the serialized payload layout: `iat`, then the `type` tag, then `user_id`.
#[test]
fn test_sign() {
    const EXPECTED_JWT: &str = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpYXQiOjAsInR5cGUiOiJ1c2VyIiwidXNlcl9pZCI6InRlc3QifQ.qBXwC9ILW5GEdTUIt6igJTwwLsuCFCi5sAAvruXQuVM";

    let claim = AnnilClaim::User(UserClaim {
        user_id: String::from("test"),
        share: None,
    });

    // Every registered claim except `iat` is left unset to keep the payload minimal.
    let claims = JWTClaims {
        issued_at: Some(0.into()),
        expires_at: None,
        invalid_before: None,
        issuer: None,
        subject: None,
        audiences: None,
        jwt_id: None,
        nonce: None,
        custom: claim,
    };

    let signed = HS256Key::from_bytes(b"a token here")
        .authenticate(claims)
        .expect("failed to sign jwt");
    assert_eq!(signed, EXPECTED_JWT);
}