Skip to main content

aranya_capi_codegen/
util.rs

1use quote::{IdentFragment, format_ident};
2use syn::{
3    Ident, Path, PathSegment, Result, Token,
4    parse::{Parse, ParseStream},
5};
6
/// Builds a sequence of `#[doc = "..."]` attributes from a formatted string.
///
/// Accepts the same arguments as [`format!`]. The formatted text is trimmed
/// and split into lines, and each line becomes one `#[doc]` attribute. A
/// leading `///` marker (found after any indentation) is stripped, so the text
/// can be written in ordinary doc-comment style, e.g.
///
/// ```ignore
/// parse_doc!(r#"
///     /// Returns the thing.
///     /// Second line.
/// "#)
/// ```
///
/// Lines without a `///` marker are emitted verbatim, indentation included.
/// An input that is empty after trimming produces no attributes.
///
/// Evaluates to a `proc_macro2::TokenStream`.
macro_rules! parse_doc {
    ($($arg:tt)*) => {{
        let doc = ::std::format!($($arg)*);
        let mut tokens = ::proc_macro2::TokenStream::new();
        // `lines` (unlike `split("\n")`) also drops a trailing `\r`, so
        // CRLF input does not leak carriage returns into the docs.
        for line in doc.trim().lines() {
            // `trim` above only removes the first line's indentation, so the
            // marker is searched for past any leading whitespace on each line.
            let frag = line
                .trim_start()
                .strip_prefix("///")
                .unwrap_or(line);
            tokens.extend(::quote::quote! {
                #[doc = #frag]
            });
        }
        tokens
    }};
}
pub(crate) use parse_doc;
23
24/// Skips the next token if it's a comma.
25pub fn skip_comma(input: ParseStream<'_>) -> Result<()> {
26    let lookahead = input.lookahead1();
27    if lookahead.peek(Token![,]) {
28        let _: Token![,] = input.parse()?;
29    }
30    Ok(())
31}
32
33/// Extension trait for [`Path`].
34pub trait PathExt {
35    /// Joins the two paths.
36    #[must_use]
37    fn join<S>(&self, seg: S) -> Self
38    where
39        S: Into<PathSegment>;
40
41    /// Returns the last [`Ident`] in the path.
42    fn ty_name(&self) -> &Ident;
43}
44
45impl PathExt for Path {
46    fn join<S>(&self, seg: S) -> Self
47    where
48        S: Into<PathSegment>,
49    {
50        let mut path = self.clone();
51        path.segments.push(seg.into());
52        path
53    }
54
55    /// Returns the last [`Ident`] in the path.
56    #[allow(clippy::arithmetic_side_effects)]
57    fn ty_name(&self) -> &Ident {
58        &self.segments[self.segments.len() - 1].ident
59    }
60}
61
62/// Extension trait for [`struct@Ident`].
63pub trait IdentExt {
64    /// Adds `prefix` to the identifier.
65    #[must_use]
66    fn with_prefix<I>(&self, prefix: I) -> Self
67    where
68        I: IdentFragment;
69
70    /// Adds `suffix` to the identifier.
71    #[must_use]
72    fn with_suffix<I>(&self, suffix: I) -> Self
73    where
74        I: IdentFragment;
75
76    /// Converts the identifier to snake_case.
77    #[must_use]
78    fn to_snake_case(&self) -> Self;
79
80    /// Converts the identifier to SCREAMING_SNAKE_CASE.
81    #[must_use]
82    fn to_screaming_snake_case(&self) -> Self;
83}
84
85impl IdentExt for Ident {
86    fn with_prefix<I>(&self, prefix: I) -> Self
87    where
88        I: IdentFragment,
89    {
90        format_ident!("{}{}", prefix, self, span = self.span())
91    }
92
93    fn with_suffix<I>(&self, suffix: I) -> Self
94    where
95        I: IdentFragment,
96    {
97        format_ident!("{}{}", self, suffix, span = self.span())
98    }
99
100    fn to_snake_case(&self) -> Self {
101        let mut new = String::new();
102        let mut in_word = false;
103        for c in self.to_string().chars() {
104            if c.is_uppercase() {
105                if in_word {
106                    new.push('_');
107                }
108                in_word = true;
109                new.extend(c.to_lowercase());
110            } else {
111                new.push(c);
112            }
113        }
114        Self::new(&new, self.span())
115    }
116
117    fn to_screaming_snake_case(&self) -> Self {
118        let mut new = String::new();
119        let mut in_word = false;
120        for c in self.to_string().chars() {
121            if c.is_uppercase() {
122                if in_word {
123                    new.push('_');
124                }
125                in_word = true;
126            }
127            new.extend(c.to_uppercase());
128        }
129        Self::new(&new, self.span())
130    }
131}
132
133/// A key-value pair.
134pub struct KeyValPair<K, V> {
135    pub key: K,
136    pub val: V,
137}
138
139impl<K, V> Parse for KeyValPair<K, V>
140where
141    K: Parse,
142    V: Parse,
143{
144    fn parse(input: ParseStream<'_>) -> Result<Self> {
145        let key = input.parse()?;
146        let _: Token![=] = input.parse()?;
147        let val = input.parse()?;
148        skip_comma(input)?;
149        Ok(Self { key, val })
150    }
151}
152
#[cfg(test)]
mod tests {
    use proc_macro2::Span;
    use syn::Ident;

    use super::*;

    #[test]
    fn test_parse_doc() {
        // Each `///` line becomes one `#[doc]` attribute with the marker removed.
        let got = parse_doc!("/// first\n/// second {}", 2).to_string();
        let want = quote::quote! {
            #[doc = " first"]
            #[doc = " second 2"]
        }
        .to_string();
        assert_eq!(got, want);
    }

    #[test]
    fn test_ident_with_prefix() {
        let tests = [
            ("SomeType", "OsSomeType", format_ident!("Os")),
            ("some_func", "os_some_func", format_ident!("os_")),
        ];
        for (i, (input, want, prefix)) in tests.into_iter().enumerate() {
            let orig = Ident::new(input, Span::call_site());
            let got = orig.with_prefix(&prefix);
            let want = Ident::new(want, Span::call_site());
            assert_eq!(got, want, "#{i}");

            let got = got.to_string();
            let got = got.strip_prefix(&prefix.to_string()).unwrap();
            assert_eq!(got, orig.to_string(), "#{i}");
        }
    }

    #[test]
    fn test_ident_with_suffix() {
        let tests = [
            ("os_some_func", "os_some_func_ext", format_ident!("_ext")),
            ("os_some_func", "os_some_func_v2_1", format_ident!("_v2_1")),
        ];
        for (i, (input, want, suffix)) in tests.into_iter().enumerate() {
            let got = Ident::new(input, Span::call_site()).with_suffix(suffix);
            let want = Ident::new(want, Span::call_site());
            assert_eq!(got, want, "#{i}");
        }
    }

    #[test]
    fn test_ident_to_snake_case() {
        let tests = [
            ("ABCD", "a_b_c_d"),
            ("OsFoo", "os_foo"),
            ("OsFooInit", "os_foo_init"),
        ];
        for (i, (input, want)) in tests.into_iter().enumerate() {
            let got = Ident::new(input, Span::call_site()).to_snake_case();
            let want = Ident::new(want, Span::call_site());
            assert_eq!(got, want, "#{i}");
        }
    }

    #[test]
    fn test_ident_to_screaming_snake_case() {
        let tests = [
            ("ABCD", "A_B_C_D"),
            ("OsFoo", "OS_FOO"),
            ("OsFooInit", "OS_FOO_INIT"),
            ("os_foo_init", "OS_FOO_INIT"),
        ];
        for (i, (input, want)) in tests.into_iter().enumerate() {
            let got = Ident::new(input, Span::call_site()).to_screaming_snake_case();
            let want = Ident::new(want, Span::call_site());
            assert_eq!(got, want, "#{i}");
        }
    }

    #[test]
    fn test_path_join_and_ty_name() {
        let path: Path = syn::parse_quote!(foo::bar);
        assert_eq!(path.ty_name(), "bar");

        let joined = path.join(format_ident!("Baz"));
        assert_eq!(joined.ty_name(), "Baz");
        assert_eq!(joined.segments.len(), 3);
        // `join` must leave the receiver untouched.
        assert_eq!(path.segments.len(), 2);
    }

    #[test]
    fn test_key_val_pair() {
        // The trailing comma is optional and consumed when present.
        for input in [r#"name = "value""#, r#"name = "value","#] {
            let kv: KeyValPair<Ident, syn::LitStr> = syn::parse_str(input).unwrap();
            assert_eq!(kv.key, "name", "{input}");
            assert_eq!(kv.val.value(), "value", "{input}");
        }
    }
}