use frame_support::{pallet_prelude::*, parameter_types, traits::OriginTrait, Deserialize, Serialize};
use pallet_timestamp::Now;
use parity_scale_codec::{Decode, Encode};
use scale_info::{prelude::string::String, TypeInfo};
use serde::{de::Error, ser::SerializeStruct, Serializer};
use sp_runtime::{traits::BadOrigin, DeserializeOwned, RuntimeDebug};
pub use jwt_compact_frame::{
alg::{Ed25519, VerifyingKey},
Claims as StandardClaims, *,
};
use serde::Deserializer;
/// Category of investor that a credential attests to.
///
/// Serialized with serde as the lowercase variant name (`"retail"`, `"professional"`,
/// `"institutional"`).
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, RuntimeDebug)]
#[derive(Encode, Decode, TypeInfo)]
#[derive(Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum InvestorType {
	Retail,
	Professional,
	Institutional,
}

impl InvestorType {
	/// Returns the variant's lowercase name, identical to its serde representation.
	#[must_use]
	pub fn as_str(&self) -> &'static str {
		match *self {
			Self::Retail => "retail",
			Self::Professional => "professional",
			Self::Institutional => "institutional",
		}
	}
}
// One `parameter_types!` constant per `InvestorType` variant. The macro turns each into a type
// that yields that variant through `Get<InvestorType>`, so an investor category can be passed
// where a type parameter is expected. NOTE(review): exact consumers are not visible here.
parameter_types! {
	// Yields `InvestorType::Retail`.
	pub const Retail: InvestorType = InvestorType::Retail;
	// Yields `InvestorType::Professional`.
	pub const Professional: InvestorType = InvestorType::Professional;
	// Yields `InvestorType::Institutional`.
	pub const Institutional: InvestorType = InvestorType::Institutional;
}
/// Custom claims expected inside an investor credential JWT.
///
/// The field order matters: the derived SCALE encoding and the derived ordering both follow
/// declaration order, so fields must not be rearranged. `ipfs_cid` and `did` arrive as JSON
/// strings and are converted into bounded byte vectors while deserializing.
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, RuntimeDebug)]
#[derive(Encode, Decode, TypeInfo, Deserialize)]
pub struct SampleClaims<AccountId> {
	/// Account the credential was issued to, read from the JWT `sub` claim.
	#[serde(rename = "sub")]
	pub subject: AccountId,
	/// Credential issuer, read from the JWT `iss` claim.
	#[serde(rename = "iss")]
	pub issuer: String,
	/// `ipfs_cid` value, transported in the JWT `aud` claim as a string.
	#[serde(rename = "aud")]
	#[serde(deserialize_with = "from_bounded_cid")]
	pub ipfs_cid: Cid,
	/// Investor category, serialized under the field name `investor_type`.
	pub investor_type: InvestorType,
	/// `did` value, transported as a string and stored as a bounded byte vector.
	#[serde(deserialize_with = "from_bounded_did")]
	pub did: Did,
}
/// Bounded byte vector holding the `ipfs_cid` claim (at most 96 bytes).
pub type Cid = BoundedVec<u8, ConstU32<96>>;
/// Bounded byte vector holding the `did` claim (at most 57 bytes).
pub type Did = BoundedVec<u8, ConstU32<57>>;

/// Marker type implementing `EnsureOriginWithCredentials` for investor credentials; it carries
/// no data, only the runtime type `T`.
pub struct EnsureInvestor<T>(sp_std::marker::PhantomData<T>);
impl<T> EnsureOriginWithCredentials<T::RuntimeOrigin> for EnsureInvestor<T>
where
	T: frame_system::Config + pallet_timestamp::Config,
{
	type Claims = SampleClaims<T::AccountId>;
	type Success = (T::AccountId, Did, InvestorType, Cid);

	/// Accepts a signed origin whose credential token is correctly signed for
	/// `verifying_key`, names the signer as its subject (`sub`), and has not yet expired.
	///
	/// On success returns the signer together with the `did`, `investor_type` and `ipfs_cid`
	/// claims. On any failure the untouched `origin` is handed back.
	fn try_origin(
		origin: T::RuntimeOrigin,
		token: &jwt_compact_frame::UntrustedToken,
		verifying_key: [u8; 32],
	) -> Result<Self::Success, T::RuntimeOrigin> {
		let Some(who) = origin.clone().into_signer() else { return Err(origin) };
		let Ok(token) = Self::verify_token(token, verifying_key) else { return Err(origin) };
		let claims = token.claims();

		// A credential without an expiration is never accepted.
		let Some(expiration) = claims.expiration else { return Err(origin) };
		// Both sides are compared in milliseconds. NOTE(review): this assumes
		// `pallet_timestamp::Now` holds a millisecond timestamp — confirm runtime config.
		let Ok(expires_at): Result<u64, _> = expiration.timestamp_millis().try_into() else {
			return Err(origin);
		};
		let Ok(now): Result<u64, _> = Now::<T>::get().try_into() else { return Err(origin) };

		// RFC 7519 §4.1.4: the token MUST NOT be accepted on or after its `exp` time, so the
		// current time must be strictly earlier than the expiration.
		if claims.custom.subject != who || now >= expires_at {
			return Err(origin);
		}

		Ok((
			who,
			claims.custom.did.clone(),
			claims.custom.investor_type.clone(),
			claims.custom.ipfs_cid.clone(),
		))
	}
}
#[allow(clippy::module_name_repetitions)]
pub trait EnsureOriginWithCredentials<OuterOrigin>
where
OuterOrigin: OriginTrait,
{
type Success;
type Claims: Clone + Encode + Decode + Eq + PartialEq + Ord + PartialOrd + TypeInfo + DeserializeOwned;
fn try_origin(
origin: OuterOrigin,
token: &jwt_compact_frame::UntrustedToken,
verifying_key: [u8; 32],
) -> Result<Self::Success, OuterOrigin>;
fn ensure_origin(
origin: OuterOrigin,
token: &jwt_compact_frame::UntrustedToken,
verifying_key: [u8; 32],
) -> Result<Self::Success, BadOrigin> {
Self::try_origin(origin, token, verifying_key).map_err(|_| BadOrigin)
}
fn verify_token(
token: &jwt_compact_frame::UntrustedToken,
verifying_key: [u8; 32],
) -> Result<jwt_compact_frame::Token<Self::Claims>, ValidationError> {
let signing_key =
<<Ed25519 as Algorithm>::VerifyingKey>::from_slice(&verifying_key).expect("The Key is always valid");
Ed25519.validator::<Self::Claims>(&signing_key).validate(token)
}
}
pub fn from_bounded_did<'de, D>(deserializer: D) -> Result<Did, D::Error>
where
D: Deserializer<'de>,
{
String::deserialize(deserializer)
.map(|string| string.as_bytes().to_vec())
.and_then(|vec| vec.try_into().map_err(|_| Error::custom("failed to deserialize")))
}
pub fn from_bounded_cid<'de, D>(deserializer: D) -> Result<Cid, D::Error>
where
D: Deserializer<'de>,
{
String::deserialize(deserializer)
.map(|string| string.as_bytes().to_vec())
.and_then(|vec| vec.try_into().map_err(|_| Error::custom("failed to deserialize")))
}
/// Serializes the claims using the JWT claim names (`sub`, `iss`, `aud`). The bounded byte
/// fields (`ipfs_cid`, `did`) are rendered as strings, mirroring the string-based
/// `deserialize_with` hooks.
impl<AccountId> Serialize for SampleClaims<AccountId>
where
	AccountId: Serialize,
{
	fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
	where
		S: Serializer,
	{
		// Borrow the bounded vectors as byte slices instead of cloning them into a `Vec`.
		// Invalid UTF-8 is replaced with U+FFFD rather than failing serialization.
		let ipfs_cid = String::from_utf8_lossy(&self.ipfs_cid);
		let did = String::from_utf8_lossy(&self.did);

		let mut state = serializer.serialize_struct("SampleClaims", 5)?;
		state.serialize_field("sub", &self.subject)?;
		state.serialize_field("iss", &self.issuer)?;
		state.serialize_field("aud", &ipfs_cid)?;
		state.serialize_field("investor_type", &self.investor_type)?;
		state.serialize_field("did", &did)?;
		state.end()
	}
}