use proc_macro2::TokenStream;
use quote::quote;
use syn::{
parse::{Parse, ParseStream},
parse_macro_input,
Ident,
LitStr,
Result,
};
struct Input {
name: syn::Ident,
prefix: String,
}
impl Parse for Input {
    /// Parses `<ident> = "<string literal>"` from the macro input stream.
    ///
    /// Any token that does not fit this shape yields the `syn` parse error
    /// for the offending position.
    fn parse(input: ParseStream) -> Result<Self> {
        let name = input.parse::<Ident>()?;
        // The `=` carries no data; it only has to be present.
        let _eq: syn::Token![=] = input.parse()?;
        let prefix = input.parse::<LitStr>()?.value();
        Ok(Self { name, prefix })
    }
}
/// Function-like macro entry point: `puid!(Name = "prefix");`.
///
/// Parses the invocation into an [`Input`] and expands it with
/// [`impl_puid`]. Parse failures are reported by `parse_macro_input!`;
/// an error returned from the expansion step is emitted as a
/// `compile_error!` at the error's span instead of panicking inside the
/// compiler.
#[proc_macro]
pub fn puid(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let input = parse_macro_input!(input as Input);
    impl_puid(input)
        .unwrap_or_else(|err| err.to_compile_error())
        .into()
}
fn impl_puid(Input { name, prefix }: Input) -> Result<TokenStream> {
let prefix_len = prefix.len();
let len = prefix_len + 22 + 1;
let mut buf = Vec::with_capacity(len);
for i in prefix.bytes() {
buf.push(i);
}
buf.push(b'_');
buf.resize(len, b'0');
let serde = if cfg!(feature = "serde") {
let visitor_ident = syn::Ident::new(&format!("{name}SerdeVisitor"), name.span());
quote! {
impl ::serde::Serialize for #name {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: ::serde::Serializer,
{
serializer.serialize_str(self.as_str())
}
}
struct #visitor_ident;
impl ::serde::de::Visitor<'_> for #visitor_ident {
type Value = #name;
fn expecting(&self, formatter: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
formatter.write_str("a string with the format '#prefix_<suffix>'")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: ::serde::de::Error,
{
v.parse().map_err(E::custom)
}
}
impl<'de> ::serde::Deserialize<'de> for #name {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: ::serde::Deserializer<'de>,
{
deserializer.deserialize_str(#visitor_ident)
}
}
}
} else {
quote! {}
};
let snake_case_name = {
let mut name_str = name.to_string();
for c in 'A'..='Z' {
name_str = name_str.replace(c, &format!("_{}", c.to_ascii_lowercase()));
}
name_str.trim_start_matches('_').to_string()
};
let postgres = if cfg!(feature = "postgres") {
quote! {
impl ::sqlx::Type<::sqlx::Postgres> for #name {
fn type_info() -> ::sqlx::postgres::PgTypeInfo {
::sqlx::postgres::PgTypeInfo::with_name(#snake_case_name)
}
fn compatible(ty: &::sqlx::postgres::PgTypeInfo) -> bool {
ty == &::sqlx::postgres::PgTypeInfo::with_name("user_id") || <&str as ::sqlx::Type<::sqlx::Postgres>>::compatible(ty)
}
}
impl ::sqlx::Encode<'_, ::sqlx::Postgres> for #name {
fn encode_by_ref(&self, buf: &mut ::sqlx::postgres::PgArgumentBuffer) -> ::std::result::Result<::sqlx::encode::IsNull, ::sqlx::error::BoxDynError> {
buf.extend(self.as_str().as_bytes());
Ok(::sqlx::encode::IsNull::No)
}
}
impl<'r> ::sqlx::Decode<'r, ::sqlx::Postgres> for #name {
fn decode(value: ::sqlx::postgres::PgValueRef<'r>) -> ::std::result::Result<Self, ::sqlx::error::BoxDynError> {
let s: &str = value.as_str()?;
s.parse().map_err(::std::convert::Into::into)
}
}
}
} else {
quote! {}
};
let sea_query = if cfg!(feature = "sea-query") {
quote! {
impl From<#name> for ::sea_query::Value {
fn from(value: #name) -> Self {
::sea_query::Value::String(Some(value.to_string().into()))
}
}
impl ::sea_query::value::Nullable for #name {
fn null() -> ::sea_query::Value {
::sea_query::Value::String(None)
}
}
}
} else {
quote! {}
};
let create_domain = LitStr::new(
&format!(
"CREATE DOMAIN {snake_case_name} AS CHAR({len}) CHECK (VALUE ~ \
'^{prefix}_[0-9A-Za-z]{{22}}$');",
),
proc_macro2::Span::call_site(),
);
Ok(quote! {
#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct #name([u8; #len]);
impl #name {
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
let mut buf = [#(#buf),*];
::puid::encode_suffix(&mut buf[#prefix_len + 1..]);
#name(buf)
}
pub fn as_str(&self) -> &str {
unsafe {
::std::str::from_utf8_unchecked(&self.0)
}
}
pub fn create_domain() -> &'static str {
#create_domain
}
pub fn nil() -> Self {
#name([#(#buf),*])
}
pub fn is_nil(&self) -> bool {
self.0 == [#(#buf),*]
}
}
impl ::std::str::FromStr for #name {
type Err = ::puid::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
if s.len() != #len {
return Err(::puid::Error::InvalidLength);
}
let mut buf = [#(#buf),*];
if !s.starts_with(#prefix) {
return Err(::puid::Error::InvalidPrefix);
}
if s.as_bytes()[#prefix_len] != b'_' {
return Err(::puid::Error::InvalidFormat);
}
for c in &s.as_bytes()[#prefix_len + 1..] {
if !::puid::is_valid_suffix_byte(*c) {
return Err(::puid::Error::InvalidSuffixChar(*c));
}
}
buf[#prefix_len + 1..].copy_from_slice(&s.as_bytes()[#prefix_len + 1..]);
Ok(#name(buf))
}
}
impl ::std::fmt::Display for #name {
fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
f.write_str(self.as_str())
}
}
impl ::std::fmt::Debug for #name {
fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
f.debug_struct(stringify!(#name))
.field("value", &self.as_str())
.finish()
}
}
impl From<&str> for #name {
fn from(s: &str) -> Self {
s.parse().expect("Invalid PUID string")
}
}
impl From<String> for #name {
fn from(s: String) -> Self {
s.parse().expect("Invalid PUID string")
}
}
#serde
#postgres
#sea_query
})
}