use std::borrow::Cow;
use heck::{ToKebabCase, ToLowerCamelCase, ToShoutySnakeCase, ToSnakeCase, ToUpperCamelCase};
use rustc_ast::{
token::{Delimiter, LitKind, TokenKind},
tokenstream::{TokenStream, TokenTree},
};
use rustc_errors::DiagCtxtHandle;
use rustc_hash::FxHashSet;
use rustc_hir::{AttrArgs, AttrKind, Attribute, def_id::DefId};
use rustc_middle::ty::TyCtxt;
use rustc_span::Symbol;
/// One serde attribute understood by `riptc`, as parsed from a `#[serde(...)]` list.
///
/// String-valued attributes keep their value as an interned [`Symbol`].
// NOTE(review): the derived `Hash` includes the variant discriminant, so reordering these
// variants can change the iteration order of the `FxHashSet` inside `SerdeFieldAttrs`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SerdeAttr {
    /// `#[serde(flatten)]`.
    Flatten,
    /// `#[serde(skip)]`.
    Skip,
    /// `#[serde(default)]` or `#[serde(default = "...")]`; the path value is not stored.
    Default,
    /// `#[serde(rename = "...")]`; holds the literal's symbol.
    Rename(Symbol),
    /// `#[serde(rename_all = "...")]`; holds the rule name, e.g. `camelCase`.
    RenameAll(Symbol),
    /// `#[serde(tag = "...")]`.
    Tag(Symbol),
    /// `#[serde(content = "...")]`.
    Content(Symbol),
}
/// The set of serde attributes that apply to one item, gathered from the item itself and
/// from its parent by [`serde_attrs_for_field_def`].
#[derive(Debug)]
pub struct SerdeFieldAttrs(FxHashSet<SerdeAttr>);
impl SerdeFieldAttrs {
pub fn skip(&self) -> bool {
self.0.contains(&SerdeAttr::Skip)
}
pub fn default(&self) -> bool {
self.0.contains(&SerdeAttr::Default)
}
pub fn flatten(&self) -> bool {
self.0.contains(&SerdeAttr::Flatten)
}
pub fn rename_field<'s>(&'s self, field_name: impl Into<Cow<'s, str>>) -> Cow<'s, str> {
for attr in &self.0 {
if let SerdeAttr::Rename(sym) = attr {
return Cow::Borrowed(sym.as_str());
}
}
for attr in &self.0 {
if let SerdeAttr::RenameAll(rule) = attr {
return apply_rename_all(field_name.into(), rule.as_str());
}
}
field_name.into()
}
pub fn tag_name(&self) -> Option<&str> {
for attr in &self.0 {
if let SerdeAttr::Tag(sym) = attr {
return Some(sym.as_str());
}
}
None
}
pub fn content_name(&self) -> Option<&str> {
for attr in &self.0 {
if let SerdeAttr::Content(sym) = attr {
return Some(sym.as_str());
}
}
None
}
}
pub fn serde_attrs_for_field_def(tcx: TyCtxt<'_>, field_def_did: DefId) -> SerdeFieldAttrs {
let attrs = tcx
.get_attrs(field_def_did, Symbol::intern("serde"))
.filter_map(|a| extract_serde_attrs(tcx.dcx(), a))
.flatten();
let container_attrs = tcx
.get_attrs(tcx.parent(field_def_did), Symbol::intern("serde"))
.filter_map(|a| extract_serde_attrs(tcx.dcx(), a))
.flatten();
SerdeFieldAttrs(attrs.chain(container_attrs).collect())
}
/// Parses `attr` into serde attributes if it is a `#[serde(...)]` attribute.
///
/// Returns `None` for attributes whose path is not the single segment `serde`. A `serde`
/// attribute in any form other than `#[serde(...)]` is reported through `dcx` and also yields
/// `None`. This covers `#[serde]`, `#[serde = ...]`, `#[serde[...]]` and `#[serde{...}]`.
fn extract_serde_attrs(
    dcx: DiagCtxtHandle<'_>,
    attr: &Attribute,
) -> Option<impl Iterator<Item = SerdeAttr>> {
    let AttrKind::Normal(normal_attr) = &attr.kind else {
        return None;
    };
    // `||` short-circuits, so `segments[0]` is only read when there is exactly one segment.
    if normal_attr.path.segments.len() != 1
        || normal_attr.path.segments[0].name != Symbol::intern("serde")
    {
        return None;
    }
    match &normal_attr.args {
        AttrArgs::Delimited(delimited) if matches!(delimited.delim, Delimiter::Parenthesis) => {
            Some(parse_serde_meta_list(dcx, &delimited.tokens))
        }
        // Every other shape, including bracket/brace delimiters, is rejected with a
        // diagnostic. Skipping it silently would drop attributes that change the final type.
        _ => {
            dcx.span_err(
                attr.span,
                "serde attributes must be in the form of `#[serde(...)]`",
            );
            None
        }
    }
}
/// Applies the serde `rename_all` rule named `rename_all` (e.g. `"camelCase"`) to `original`.
///
/// Handles all eight rules serde accepts:
/// - `lowercase`, `UPPERCASE`
/// - `PascalCase`, `camelCase`
/// - `snake_case`, `SCREAMING_SNAKE_CASE`
/// - `kebab-case`, `SCREAMING-KEBAB-CASE`
///
/// Any other rule name returns `original` unchanged.
///
/// NOTE(review): conversions go through `heck`, whose word splitting may differ from serde's
/// own `RenameRule` for unusual identifiers (e.g. ones containing digits) — confirm if exact
/// parity matters.
fn apply_rename_all<'s>(original: Cow<'s, str>, rename_all: &str) -> Cow<'s, str> {
    match rename_all {
        "lowercase" => Cow::Owned(original.to_ascii_lowercase()),
        "UPPERCASE" => Cow::Owned(original.to_ascii_uppercase()),
        "snake_case" => Cow::Owned(original.to_snake_case()),
        "SCREAMING_SNAKE_CASE" => Cow::Owned(original.to_shouty_snake_case()),
        "kebab-case" => Cow::Owned(original.to_kebab_case()),
        // Shouty snake case with `_` swapped for `-`. This avoids an extra `heck` trait
        // import and gives the same result for identifiers, which contain no `-`.
        "SCREAMING-KEBAB-CASE" => Cow::Owned(original.to_shouty_snake_case().replace('_', "-")),
        "camelCase" => Cow::Owned(original.to_lower_camel_case()),
        "PascalCase" => Cow::Owned(original.to_upper_camel_case()),
        _ => original,
    }
}
/// Parses the comma-separated contents of a `#[serde(...)]` attribute.
///
/// Every item is parsed, and its diagnostics emitted, before this function returns. Items
/// for which `parse_meta_item` yields `None` are dropped.
fn parse_serde_meta_list(
    dcx: DiagCtxtHandle<'_>,
    tokens: &TokenStream,
) -> impl Iterator<Item = SerdeAttr> {
    // Collect eagerly so diagnostics are emitted now rather than whenever the caller iterates.
    let parsed: Vec<SerdeAttr> = split_by_comma(tokens)
        .iter()
        .filter_map(|chunk| parse_meta_item(dcx, chunk))
        .collect();
    parsed.into_iter()
}
fn split_by_comma(stream: &TokenStream) -> Vec<TokenStream> {
let mut result = Vec::new();
let mut current = Vec::new();
for tt in stream.iter() {
match tt {
TokenTree::Token(token, spacing) => {
if let TokenKind::Comma = token.kind {
if !current.is_empty() {
result.push(TokenStream::new(current));
current = Vec::new();
}
} else {
current.push(TokenTree::Token(token.clone(), *spacing));
}
}
TokenTree::Delimited(delim_span, delim_kind, inside, span_kind) => {
current.push(TokenTree::Delimited(
*delim_span,
*delim_kind,
*inside,
span_kind.clone(),
));
}
}
}
if !current.is_empty() {
result.push(TokenStream::new(current));
}
result
}
fn parse_meta_item(dcx: DiagCtxtHandle<'_>, chunk: &TokenStream) -> Option<SerdeAttr> {
let mut iter = chunk.iter().peekable();
let (ident_sym, ident_span) = match iter.next() {
Some(TokenTree::Token(token, _)) => match token.ident() {
Some((ident, _)) => (ident.name, ident.span),
None => return None, },
_ => return None, };
if let Some(TokenTree::Token(eq_token, _)) = iter.peek() {
if let TokenKind::Eq = eq_token.kind {
iter.next();
if let Some(TokenTree::Token(tok, _)) = iter.next() {
match &tok.kind {
TokenKind::Literal(rustc_ast::token::Lit {
kind: LitKind::Str,
symbol,
..
}) => {
let str_val = *symbol;
return match ident_sym.as_str() {
"rename" => Some(SerdeAttr::Rename(str_val)),
"rename_all" => Some(SerdeAttr::RenameAll(str_val)),
"tag" => Some(SerdeAttr::Tag(str_val)),
"content" => Some(SerdeAttr::Content(str_val)),
"default" => Some(SerdeAttr::Default),
"deserialize_with" => None,
other => {
dcx.span_err(
tok.span,
format!("serde attribute `{other}` unrecognized by `riptc`, so we cannot confidently determine final type"),
);
return None;
}
};
}
_ => {
dcx.span_err(
tok.span,
"invalid serde attribute, not a string literal".to_string(),
);
return None;
}
}
} else {
return None;
}
}
}
match ident_sym.as_str() {
"flatten" => Some(SerdeAttr::Flatten),
"skip" => Some(SerdeAttr::Skip),
"default" => Some(SerdeAttr::Default),
"deserialize_with" => None,
other => {
dcx.span_err(
ident_span,
format!("serde attribute `{other}` unrecognized by `riptc`, so we cannot confidently determine final type"),
);
None
}
}
}