aranya_capi_codegen/
util.rs

1use quote::{format_ident, IdentFragment};
2use syn::{
3    parse::{Parse, ParseStream},
4    Ident, Path, PathSegment, Result, Token,
5};
6
/// Expands to a `proc_macro2::TokenStream` of `#[doc = "..."]` attributes
/// built from a `format!`-style invocation.
///
/// The formatted text is trimmed and split on `\n`. Each resulting line
/// becomes one `#[doc]` attribute; a line that starts with `///` has that
/// marker removed first, any other line is used unchanged.
macro_rules! parse_doc {
    ($($fmt:tt)*) => {{
        let text = ::std::format!($($fmt)*);
        let mut attrs = ::proc_macro2::TokenStream::new();
        for line in text.trim().split('\n') {
            // Keep the line as-is when it does not start with `///`.
            let line = line.strip_prefix("///").unwrap_or(line);
            attrs.extend(::quote::quote! {
                #[doc = #line]
            });
        }
        attrs
    }};
}
22pub(crate) use parse_doc;
23
24/// Skips the next token if it's a comma.
25pub fn skip_comma(input: ParseStream<'_>) -> Result<()> {
26    let lookahead = input.lookahead1();
27    if lookahead.peek(Token![,]) {
28        let _: Token![,] = input.parse()?;
29    }
30    Ok(())
31}
32
33/// Extension trait for [`Path`].
34pub trait PathExt {
35    /// Joins the two paths.
36    #[must_use]
37    fn join<S>(&self, seg: S) -> Self
38    where
39        S: Into<PathSegment>;
40
41    /// Returns the last [`Ident`] in the path.
42    fn ty_name(&self) -> &Ident;
43}
44
45impl PathExt for Path {
46    fn join<S>(&self, seg: S) -> Self
47    where
48        S: Into<PathSegment>,
49    {
50        let mut path = self.clone();
51        path.segments.push(seg.into());
52        path
53    }
54
55    /// Returns the last [`Ident`] in the path.
56    #[allow(clippy::arithmetic_side_effects)]
57    fn ty_name(&self) -> &Ident {
58        &self.segments[self.segments.len() - 1].ident
59    }
60}
61
62/// Extension trait for [`struct@Ident`].
63pub trait IdentExt {
64    /// Adds `prefix` to the identifier.
65    #[must_use]
66    fn with_prefix<I>(&self, prefix: I) -> Self
67    where
68        I: IdentFragment;
69
70    /// Adds `suffix` to the identifier.
71    #[must_use]
72    fn with_suffix<I>(&self, suffix: I) -> Self
73    where
74        I: IdentFragment;
75
76    /// Converts the identifier to snake_case.
77    fn to_snake_case(&self) -> Self;
78
79    /// Converts the identifier to SCREAMING_SNAKE_CASE.
80    fn to_screaming_snake_case(&self) -> Self;
81}
82
83impl IdentExt for Ident {
84    fn with_prefix<I>(&self, prefix: I) -> Self
85    where
86        I: IdentFragment,
87    {
88        format_ident!("{}{}", prefix, self, span = self.span())
89    }
90
91    fn with_suffix<I>(&self, suffix: I) -> Self
92    where
93        I: IdentFragment,
94    {
95        format_ident!("{}{}", self, suffix, span = self.span())
96    }
97
98    fn to_snake_case(&self) -> Ident {
99        let mut new = String::new();
100        let mut in_word = false;
101        for c in self.to_string().chars() {
102            if c.is_uppercase() {
103                if in_word {
104                    new.push('_');
105                }
106                in_word = true;
107                new.extend(c.to_lowercase());
108            } else {
109                new.push(c);
110            }
111        }
112        Ident::new(&new, self.span())
113    }
114
115    fn to_screaming_snake_case(&self) -> Ident {
116        let mut new = String::new();
117        let mut in_word = false;
118        for c in self.to_string().chars() {
119            if c.is_uppercase() {
120                if in_word {
121                    new.push('_');
122                }
123                in_word = true;
124            }
125            new.extend(c.to_uppercase());
126        }
127        Ident::new(&new, self.span())
128    }
129}
130
131/// A key-value pair.
132pub struct KeyValPair<K, V> {
133    pub key: K,
134    pub val: V,
135}
136
137impl<K, V> Parse for KeyValPair<K, V>
138where
139    K: Parse,
140    V: Parse,
141{
142    fn parse(input: ParseStream<'_>) -> Result<Self> {
143        let key = input.parse()?;
144        let _: Token![=] = input.parse()?;
145        let val = input.parse()?;
146        skip_comma(input)?;
147        Ok(Self { key, val })
148    }
149}
150
#[cfg(test)]
mod tests {
    use proc_macro2::Span;
    use syn::Ident;

    use super::*;

    /// Builds an [`Ident`] named `name` with a call-site span.
    fn ident(name: &str) -> Ident {
        Ident::new(name, Span::call_site())
    }

    #[test]
    fn test_ident_with_prefix() {
        let tests = [
            ("SomeType", "OsSomeType", format_ident!("Os")),
            ("some_func", "os_some_func", format_ident!("os_")),
        ];
        for (i, (input, want, prefix)) in tests.into_iter().enumerate() {
            let orig = ident(input);
            let got = orig.with_prefix(&prefix);
            assert_eq!(got, ident(want), "#{i}");

            // Removing the prefix again must give back the original name.
            let got = got.to_string();
            let got = got.strip_prefix(&prefix.to_string()).unwrap();
            assert_eq!(got, orig.to_string(), "#{i}");
        }
    }

    #[test]
    fn test_ident_with_suffix() {
        let tests = [
            ("os_some_func", "os_some_func_ext", format_ident!("_ext")),
            ("os_some_func", "os_some_func_v2_1", format_ident!("_v2_1")),
        ];
        for (i, (input, want, suffix)) in tests.into_iter().enumerate() {
            let got = ident(input).with_suffix(suffix);
            assert_eq!(got, ident(want), "#{i}");
        }
    }

    #[test]
    fn test_ident_to_snake_case() {
        let tests = [
            ("ABCD", "a_b_c_d"),
            ("OsFoo", "os_foo"),
            ("OsFooInit", "os_foo_init"),
        ];
        for (i, (input, want)) in tests.into_iter().enumerate() {
            let got = ident(input).to_snake_case();
            assert_eq!(got, ident(want), "#{i}");
        }
    }

    #[test]
    fn test_ident_to_screaming_snake_case() {
        let tests = [
            ("ABCD", "A_B_C_D"),
            ("OsFoo", "OS_FOO"),
            ("OsFooInit", "OS_FOO_INIT"),
            // Input that is already snake_case is only uppercased.
            ("some_func", "SOME_FUNC"),
        ];
        for (i, (input, want)) in tests.into_iter().enumerate() {
            let got = ident(input).to_screaming_snake_case();
            assert_eq!(got, ident(want), "#{i}");
        }
    }
}