annotation_rs_helpers/
helper.rs

1use crate::Symbol;
2use proc_macro2::TokenStream;
3use std::fmt::Display;
4use std::str::FromStr;
5use syn::punctuated::Punctuated;
6use syn::{
7    Attribute, Error, GenericArgument, Lit, Meta, PathArguments, PathSegment, Type, TypePath,
8};
9
10#[inline]
11pub fn unwrap_punctuated_first<T, P>(
12    punctuated: &Punctuated<T, P>,
13    error: Error,
14) -> Result<&T, Error> {
15    match punctuated.first() {
16        Some(s) => Ok(s),
17        None => Err(error),
18    }
19}
20
21#[inline]
22pub fn get_nested_type<'a>(
23    segment: &'a PathSegment,
24    message: &'static str,
25) -> Result<&'a Type, Error> {
26    let error = Error::new_spanned(segment, message);
27    match &segment.arguments {
28        PathArguments::AngleBracketed(argument) => {
29            match unwrap_punctuated_first(&argument.args, error.clone())? {
30                GenericArgument::Type(nested_type) => Ok(nested_type),
31                _ => Err(error),
32            }
33        }
34        _ => Err(error),
35    }
36}
37
38pub fn get_nested_types<'a>(
39    segment: &'a PathSegment,
40    message: &'static str,
41) -> Result<Vec<&'a Type>, Error> {
42    let error = Error::new_spanned(segment, message);
43    match &segment.arguments {
44        PathArguments::AngleBracketed(arguments) => arguments
45            .args
46            .iter()
47            .map(|argument| match argument {
48                GenericArgument::Type(nested_type) => Ok(nested_type),
49                _ => Err(error.clone()),
50            })
51            .collect(),
52        _ => Err(error),
53    }
54}
55
56#[inline]
57pub fn unwrap_type_path<'a>(ty: &'a Type, message: &'static str) -> Result<&'a TypePath, Error> {
58    match ty {
59        Type::Path(type_path) => Ok(type_path),
60        _ => Err(Error::new_spanned(ty, message)),
61    }
62}
63
64#[inline]
65pub fn get_lit_str<U: Display>(lit: &Lit, ident: &U) -> Result<String, Error> {
66    match lit {
67        Lit::Str(lit_str) => Ok(lit_str.value()),
68        _ => Err(Error::new_spanned(
69            lit,
70            format!("expected {} lit to be a string", ident),
71        )),
72    }
73}
74
75#[inline]
76pub fn get_lit_as_string<U: Display>(lit: &Lit, ident: &U) -> Result<String, Error> {
77    match lit {
78        Lit::Str(lit_str) => Ok(lit_str.value()),
79        Lit::Int(lit_int) => Ok(lit_int.to_string()),
80        Lit::Float(lit_float) => Ok(lit_float.to_string()),
81        Lit::Bool(lit_bool) => Ok(lit_bool.value.to_string()),
82        _ => Err(Error::new_spanned(
83            lit,
84            format!("expected {} lit to be a string/integer/float/boll", ident),
85        )),
86    }
87}
88
89#[inline]
90pub fn get_lit_int<T: FromStr, U: Display>(lit: &Lit, ident: &U) -> Result<T, Error>
91where
92    <T as std::str::FromStr>::Err: std::fmt::Display,
93{
94    match lit {
95        Lit::Int(lit_int) => Ok(lit_int.base10_parse().unwrap()),
96        _ => Err(Error::new_spanned(
97            lit,
98            format!("expected {} lit to be a integer", ident),
99        )),
100    }
101}
102
103#[inline]
104pub fn get_lit_float<T: FromStr, U: Display>(lit: &Lit, ident: &U) -> Result<T, Error>
105where
106    <T as std::str::FromStr>::Err: std::fmt::Display,
107{
108    match lit {
109        Lit::Float(lit_float) => Ok(lit_float.base10_parse().unwrap()),
110        _ => Err(Error::new_spanned(
111            lit,
112            format!("expected {} lit to be a float", ident),
113        )),
114    }
115}
116
117#[inline]
118pub fn get_lit_bool<U: Display>(lit: &Lit, ident: &U) -> Result<bool, Error> {
119    match lit {
120        Lit::Bool(lit_bool) => Ok(lit_bool.value),
121        _ => Err(Error::new_spanned(
122            lit,
123            format!("expected {} lit to be a bool", ident),
124        )),
125    }
126}
127
128pub fn get_mod_path(attrs: &[Attribute]) -> Result<Option<TokenStream>, Error> {
129    let mut mod_path = None;
130    for attr in attrs.iter() {
131        if attr.path == Symbol::new("mod_path") {
132            let meta = attr.parse_meta()?;
133            mod_path = match &meta {
134                Meta::NameValue(mod_path_value) => {
135                    let mod_path_str = get_lit_str(
136                        &mod_path_value.lit,
137                        mod_path_value.path.get_ident().as_ref().unwrap(),
138                    )?;
139
140                    Some(
141                        TokenStream::from_str(mod_path_str.as_str())
142                            .map_err(|_| Error::new_spanned(&meta, "Invalid mod_path"))?,
143                    )
144                }
145                _ => None,
146            }
147        }
148    }
149
150    Ok(mod_path)
151}