ploidy_codegen_rust/
naming.rs

1use std::{borrow::Cow, fmt::Display, ops::Deref};
2
3use heck::{AsPascalCase, AsSnekCase};
4use itertools::Itertools;
5use ploidy_core::{
6    codegen::{
7        UniqueNames,
8        unique::{UniqueNamesScope, WordSegments},
9    },
10    ir::{
11        InlineIrTypePathSegment, InlineIrTypeView, IrStructFieldName, IrStructFieldNameHint,
12        IrUntaggedVariantNameHint, PrimitiveIrType, SchemaIrTypeView, View,
13    },
14};
15use proc_macro2::{Ident, Span, TokenStream};
16use quote::{IdentFragment, ToTokens, TokenStreamExt, format_ident};
17use ref_cast::{RefCastCustom, ref_cast_custom};
18
/// Keywords that can't be used as identifiers, even in raw (`r#`) form.
///
/// [`CodegenIdent::new`] prefixes an exact match with `_`, and
/// [`CodegenIdentScope::with_reserved`] reserves these names in every scope.
const KEYWORDS: &[&str] = &["crate", "self", "super", "Self"];
21
/// The name of a generated Rust type, as it appears in emitted code.
///
/// Implements [`ToTokens`], so it can be interpolated directly into
/// `quote!` invocations.
#[derive(Clone, Debug)]
pub enum CodegenTypeName<'a> {
    /// A named schema type. Its name is the [`CodegenIdent`] stored in
    /// the view's extensions, rendered as a type name.
    Schema(&'a SchemaIrTypeView<'a>),
    /// An inline type. Its name is built by concatenating the rendered
    /// segments of the view's path.
    Inline(&'a InlineIrTypeView<'a>),
}
27
28impl ToTokens for CodegenTypeName<'_> {
29    fn to_tokens(&self, tokens: &mut TokenStream) {
30        match self {
31            Self::Schema(view) => {
32                let ident = view.extensions().get::<CodegenIdent>().unwrap();
33                tokens.append_all(CodegenIdentUsage::Type(&ident).to_token_stream())
34            }
35            Self::Inline(view) => {
36                let ident = view
37                    .path()
38                    .segments
39                    .iter()
40                    .map(CodegenTypePathSegment)
41                    .map(|segment| format_ident!("{}", segment))
42                    .reduce(|a, b| format_ident!("{}{}", a, b))
43                    .unwrap();
44                tokens.append(ident);
45            }
46        }
47    }
48}
49
50/// A string that's statically guaranteed to be valid for any
51/// [`CodegenIdentUsage`].
52#[derive(Debug, Eq, Ord, PartialEq, PartialOrd)]
53pub struct CodegenIdent(String);
54
55impl CodegenIdent {
56    /// Creates an identifier for any usage.
57    pub fn new(s: &str) -> Self {
58        let s = clean(s);
59        if KEYWORDS.contains(&s.as_str()) {
60            Self(format!("_{s}"))
61        } else {
62            Self(s)
63        }
64    }
65}
66
67impl Deref for CodegenIdent {
68    type Target = CodegenIdentRef;
69
70    fn deref(&self) -> &Self::Target {
71        CodegenIdentRef::new(&self.0)
72    }
73}
74
/// A string slice that's guaranteed to be valid for any [`CodegenIdentUsage`].
///
/// This is the borrowed form of [`CodegenIdent`], which dereferences to it.
#[derive(Debug, Eq, Ord, PartialEq, PartialOrd, RefCastCustom)]
// A transparent wrapper around `str`, so that a `&str` can be
// reinterpreted as a `&CodegenIdentRef` by the `ref_cast` derive.
#[repr(transparent)]
pub struct CodegenIdentRef(str);

impl CodegenIdentRef {
    // Private on purpose: only code in this module may wrap a string
    // slice, which is what upholds the validity guarantee. The body is
    // supplied by the `ref_cast_custom` attribute.
    #[ref_cast_custom]
    fn new(s: &str) -> &Self;
}
84
/// The syntactic position in which an identifier is used.
///
/// The usage selects the case convention applied when the identifier is
/// displayed or turned into tokens.
#[derive(Clone, Copy, Debug)]
pub enum CodegenIdentUsage<'a> {
    /// A module name; rendered in snake case.
    Module(&'a CodegenIdentRef),
    /// A type name; rendered in Pascal case.
    Type(&'a CodegenIdentRef),
    /// A struct field name; rendered in snake case.
    Field(&'a CodegenIdentRef),
    /// An enum variant name; rendered in Pascal case.
    Variant(&'a CodegenIdentRef),
    /// A function parameter name; rendered in snake case.
    Param(&'a CodegenIdentRef),
    /// A method name; rendered in snake case.
    Method(&'a CodegenIdentRef),
}
94
95impl Display for CodegenIdentUsage<'_> {
96    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97        match self {
98            Self::Module(name) | Self::Field(name) | Self::Param(name) | Self::Method(name) => {
99                if name.0.starts_with(unicode_ident::is_xid_start) {
100                    write!(f, "{}", AsSnekCase(&name.0))
101                } else {
102                    // `name` doesn't start with `XID_Start` (e.g., "1099KStatus"),
103                    // so prefix it with `_`; everything after is known to be
104                    // `XID_Continue`.
105                    write!(f, "_{}", AsSnekCase(&name.0))
106                }
107            }
108            Self::Type(name) | Self::Variant(name) => {
109                if name.0.starts_with(unicode_ident::is_xid_start) {
110                    write!(f, "{}", AsPascalCase(&name.0))
111                } else {
112                    write!(f, "_{}", AsPascalCase(&name.0))
113                }
114            }
115        }
116    }
117}
118
119impl ToTokens for CodegenIdentUsage<'_> {
120    fn to_tokens(&self, tokens: &mut TokenStream) {
121        let s = self.to_string();
122        let ident = syn::parse_str(&s).unwrap_or_else(|_| Ident::new_raw(&s, Span::call_site()));
123        tokens.append(ident);
124    }
125}
126
127/// A scope for generating unique, valid Rust identifiers.
128#[derive(Debug)]
129pub struct CodegenIdentScope<'a>(UniqueNamesScope<'a>);
130
131impl<'a> CodegenIdentScope<'a> {
132    /// Creates a new identifier scope that's backed by the given arena.
133    pub fn new(arena: &'a UniqueNames) -> Self {
134        Self::with_reserved(arena, &[])
135    }
136
137    /// Creates a new identifier scope that's backed by the given arena,
138    /// with additional pre-reserved names.
139    pub fn with_reserved(arena: &'a UniqueNames, reserved: &[&str]) -> Self {
140        Self(arena.scope_with_reserved(itertools::chain!(
141            reserved.iter().copied(),
142            KEYWORDS.iter().copied(),
143            std::iter::once("")
144        )))
145    }
146
147    /// Cleans the input string and returns a name that's unique
148    /// within this scope, and valid for any [`CodegenIdentUsage`].
149    pub fn uniquify(&mut self, name: &str) -> CodegenIdent {
150        CodegenIdent(self.0.uniquify(&clean(name)).into_owned())
151    }
152}
153
154#[derive(Clone, Copy, Debug)]
155pub struct CodegenUntaggedVariantName(pub IrUntaggedVariantNameHint);
156
157impl ToTokens for CodegenUntaggedVariantName {
158    fn to_tokens(&self, tokens: &mut TokenStream) {
159        use IrUntaggedVariantNameHint::*;
160        let s = match self.0 {
161            Primitive(PrimitiveIrType::String) => "String".into(),
162            Primitive(PrimitiveIrType::I32) => "I32".into(),
163            Primitive(PrimitiveIrType::I64) => "I64".into(),
164            Primitive(PrimitiveIrType::F32) => "F32".into(),
165            Primitive(PrimitiveIrType::F64) => "F64".into(),
166            Primitive(PrimitiveIrType::Bool) => "Bool".into(),
167            Primitive(PrimitiveIrType::DateTime) => "DateTime".into(),
168            Primitive(PrimitiveIrType::Date) => "Date".into(),
169            Primitive(PrimitiveIrType::Url) => "Url".into(),
170            Primitive(PrimitiveIrType::Uuid) => "Uuid".into(),
171            Primitive(PrimitiveIrType::Bytes) => "Bytes".into(),
172            Array => "Array".into(),
173            Map => "Map".into(),
174            Index(index) => Cow::Owned(format!("V{index}")),
175        };
176        tokens.append(Ident::new(&s, Span::call_site()));
177    }
178}
179
180#[derive(Clone, Copy, Debug)]
181pub struct CodegenStructFieldName(pub IrStructFieldNameHint);
182
183impl ToTokens for CodegenStructFieldName {
184    fn to_tokens(&self, tokens: &mut TokenStream) {
185        match self.0 {
186            IrStructFieldNameHint::Index(index) => {
187                CodegenIdentUsage::Field(&CodegenIdent(format!("variant_{index}")))
188                    .to_tokens(tokens)
189            }
190        }
191    }
192}
193
194#[derive(Clone, Copy, Debug)]
195pub struct CodegenTypePathSegment<'a>(&'a InlineIrTypePathSegment<'a>);
196
197impl IdentFragment for CodegenTypePathSegment<'_> {
198    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
199        use InlineIrTypePathSegment::*;
200        match self.0 {
201            // Segments are part of an inline type path that always has a root prefix,
202            // so we don't need to check for `XID_Start`.
203            Operation(name) => write!(f, "{}", AsPascalCase(clean(name))),
204            Parameter(name) => write!(f, "{}", AsPascalCase(clean(name))),
205            Request => f.write_str("Request"),
206            Response => f.write_str("Response"),
207            Field(IrStructFieldName::Name(name)) => {
208                write!(f, "{}", AsPascalCase(clean(name)))
209            }
210            Field(IrStructFieldName::Hint(IrStructFieldNameHint::Index(index))) => {
211                write!(f, "Variant{index}")
212            }
213            MapValue => f.write_str("Value"),
214            ArrayItem => f.write_str("Item"),
215            Variant(index) => write!(f, "V{index}"),
216        }
217    }
218}
219
220/// Makes a string suitable for inclusion within a Rust identifier.
221///
222/// Cleaning segments the string on word boundaries, collapses all
223/// non-`XID_Continue` characters into new boundaries, and
224/// reassembles the string. This makes the string resilient to
225/// case transformations, which also collapse boundaries, and so
226/// can produce duplicates in some cases.
227///
228/// Note that the result may not itself be a valid Rust identifier,
229/// because Rust identifiers must start with `XID_Start`.
230/// This is checked and handled in [`CodegenIdentUsage`].
231fn clean(s: &str) -> String {
232    WordSegments::new(s)
233        .flat_map(|s| s.split(|c| !unicode_ident::is_xid_continue(c)))
234        .join("_")
235}
236
// Unit tests for identifier rendering. Most tests interpolate a value
// into `parse_quote!` and compare the resulting `syn::Ident` against
// the expected identifier.
#[cfg(test)]
mod tests {
    use super::*;

    use pretty_assertions::assert_eq;
    use syn::parse_quote;

    // MARK: Usages

    // Type usages are rendered in Pascal case.
    #[test]
    fn test_codegen_ident_type() {
        let ident = CodegenIdent::new("pet_store");
        let usage = CodegenIdentUsage::Type(&ident);
        let actual: syn::Ident = parse_quote!(#usage);
        let expected: syn::Ident = parse_quote!(PetStore);
        assert_eq!(actual, expected);
    }

    // Field usages are rendered in snake case.
    #[test]
    fn test_codegen_ident_field() {
        let ident = CodegenIdent::new("petStore");
        let usage = CodegenIdentUsage::Field(&ident);
        let actual: syn::Ident = parse_quote!(#usage);
        let expected: syn::Ident = parse_quote!(pet_store);
        assert_eq!(actual, expected);
    }

    // Module usages are rendered in snake case.
    #[test]
    fn test_codegen_ident_module() {
        let ident = CodegenIdent::new("MyModule");
        let usage = CodegenIdentUsage::Module(&ident);
        let actual: syn::Ident = parse_quote!(#usage);
        let expected: syn::Ident = parse_quote!(my_module);
        assert_eq!(actual, expected);
    }

    // Variant usages are rendered in Pascal case.
    #[test]
    fn test_codegen_ident_variant() {
        let ident = CodegenIdent::new("http_error");
        let usage = CodegenIdentUsage::Variant(&ident);
        let actual: syn::Ident = parse_quote!(#usage);
        let expected: syn::Ident = parse_quote!(HttpError);
        assert_eq!(actual, expected);
    }

    // Parameter usages are rendered in snake case.
    #[test]
    fn test_codegen_ident_param() {
        let ident = CodegenIdent::new("userId");
        let usage = CodegenIdentUsage::Param(&ident);
        let actual: syn::Ident = parse_quote!(#usage);
        let expected: syn::Ident = parse_quote!(user_id);
        assert_eq!(actual, expected);
    }

    // Method usages are rendered in snake case.
    #[test]
    fn test_codegen_ident_method() {
        let ident = CodegenIdent::new("getUserById");
        let usage = CodegenIdentUsage::Method(&ident);
        let actual: syn::Ident = parse_quote!(#usage);
        let expected: syn::Ident = parse_quote!(get_user_by_id);
        assert_eq!(actual, expected);
    }

    // MARK: Special characters

    // Keywords that can be raw identifiers are emitted with `r#`.
    #[test]
    fn test_codegen_ident_handles_rust_keywords() {
        let ident = CodegenIdent::new("type");
        let usage = CodegenIdentUsage::Field(&ident);
        let actual: syn::Ident = parse_quote!(#usage);
        let expected: syn::Ident = parse_quote!(r#type);
        assert_eq!(actual, expected);
    }

    // Names that start with a digit are prefixed with `_`.
    #[test]
    fn test_codegen_ident_handles_invalid_start_chars() {
        let ident = CodegenIdent::new("123foo");
        let usage = CodegenIdentUsage::Field(&ident);
        let actual: syn::Ident = parse_quote!(#usage);
        let expected: syn::Ident = parse_quote!(_123_foo);
        assert_eq!(actual, expected);
    }

    // Characters that aren't valid in identifiers become word boundaries.
    #[test]
    fn test_codegen_ident_handles_special_chars() {
        let ident = CodegenIdent::new("foo-bar-baz");
        let usage = CodegenIdentUsage::Field(&ident);
        let actual: syn::Ident = parse_quote!(#usage);
        let expected: syn::Ident = parse_quote!(foo_bar_baz);
        assert_eq!(actual, expected);
    }

    // The `_` prefix is applied for both snake-case and Pascal-case usages.
    #[test]
    fn test_codegen_ident_handles_number_prefix() {
        let ident = CodegenIdent::new("1099KStatus");

        let usage = CodegenIdentUsage::Field(&ident);
        let actual: syn::Ident = parse_quote!(#usage);
        let expected: syn::Ident = parse_quote!(_1099_k_status);
        assert_eq!(actual, expected);

        let usage = CodegenIdentUsage::Type(&ident);
        let actual: syn::Ident = parse_quote!(#usage);
        let expected: syn::Ident = parse_quote!(_1099KStatus);
        assert_eq!(actual, expected);
    }

    // MARK: Untagged variant names

    #[test]
    fn test_untagged_variant_name_string() {
        let variant_name = CodegenUntaggedVariantName(IrUntaggedVariantNameHint::Primitive(
            PrimitiveIrType::String,
        ));
        let actual: syn::Ident = parse_quote!(#variant_name);
        let expected: syn::Ident = parse_quote!(String);
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_untagged_variant_name_i32() {
        let variant_name =
            CodegenUntaggedVariantName(IrUntaggedVariantNameHint::Primitive(PrimitiveIrType::I32));
        let actual: syn::Ident = parse_quote!(#variant_name);
        let expected: syn::Ident = parse_quote!(I32);
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_untagged_variant_name_i64() {
        let variant_name =
            CodegenUntaggedVariantName(IrUntaggedVariantNameHint::Primitive(PrimitiveIrType::I64));
        let actual: syn::Ident = parse_quote!(#variant_name);
        let expected: syn::Ident = parse_quote!(I64);
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_untagged_variant_name_f32() {
        let variant_name =
            CodegenUntaggedVariantName(IrUntaggedVariantNameHint::Primitive(PrimitiveIrType::F32));
        let actual: syn::Ident = parse_quote!(#variant_name);
        let expected: syn::Ident = parse_quote!(F32);
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_untagged_variant_name_f64() {
        let variant_name =
            CodegenUntaggedVariantName(IrUntaggedVariantNameHint::Primitive(PrimitiveIrType::F64));
        let actual: syn::Ident = parse_quote!(#variant_name);
        let expected: syn::Ident = parse_quote!(F64);
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_untagged_variant_name_bool() {
        let variant_name =
            CodegenUntaggedVariantName(IrUntaggedVariantNameHint::Primitive(PrimitiveIrType::Bool));
        let actual: syn::Ident = parse_quote!(#variant_name);
        let expected: syn::Ident = parse_quote!(Bool);
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_untagged_variant_name_datetime() {
        let variant_name = CodegenUntaggedVariantName(IrUntaggedVariantNameHint::Primitive(
            PrimitiveIrType::DateTime,
        ));
        let actual: syn::Ident = parse_quote!(#variant_name);
        let expected: syn::Ident = parse_quote!(DateTime);
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_untagged_variant_name_date() {
        let variant_name =
            CodegenUntaggedVariantName(IrUntaggedVariantNameHint::Primitive(PrimitiveIrType::Date));
        let actual: syn::Ident = parse_quote!(#variant_name);
        let expected: syn::Ident = parse_quote!(Date);
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_untagged_variant_name_url() {
        let variant_name =
            CodegenUntaggedVariantName(IrUntaggedVariantNameHint::Primitive(PrimitiveIrType::Url));
        let actual: syn::Ident = parse_quote!(#variant_name);
        let expected: syn::Ident = parse_quote!(Url);
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_untagged_variant_name_uuid() {
        let variant_name =
            CodegenUntaggedVariantName(IrUntaggedVariantNameHint::Primitive(PrimitiveIrType::Uuid));
        let actual: syn::Ident = parse_quote!(#variant_name);
        let expected: syn::Ident = parse_quote!(Uuid);
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_untagged_variant_name_bytes() {
        let variant_name = CodegenUntaggedVariantName(IrUntaggedVariantNameHint::Primitive(
            PrimitiveIrType::Bytes,
        ));
        let actual: syn::Ident = parse_quote!(#variant_name);
        let expected: syn::Ident = parse_quote!(Bytes);
        assert_eq!(actual, expected);
    }

    // Index hints are rendered as `V{index}`.
    #[test]
    fn test_untagged_variant_name_index() {
        let variant_name = CodegenUntaggedVariantName(IrUntaggedVariantNameHint::Index(0));
        let actual: syn::Ident = parse_quote!(#variant_name);
        let expected: syn::Ident = parse_quote!(V0);
        assert_eq!(actual, expected);

        let variant_name = CodegenUntaggedVariantName(IrUntaggedVariantNameHint::Index(42));
        let actual: syn::Ident = parse_quote!(#variant_name);
        let expected: syn::Ident = parse_quote!(V42);
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_untagged_variant_name_array() {
        let variant_name = CodegenUntaggedVariantName(IrUntaggedVariantNameHint::Array);
        let actual: syn::Ident = parse_quote!(#variant_name);
        let expected: syn::Ident = parse_quote!(Array);
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_untagged_variant_name_map() {
        let variant_name = CodegenUntaggedVariantName(IrUntaggedVariantNameHint::Map);
        let actual: syn::Ident = parse_quote!(#variant_name);
        let expected: syn::Ident = parse_quote!(Map);
        assert_eq!(actual, expected);
    }

    // MARK: Struct field names

    // Index hints are rendered as `variant_{index}`.
    #[test]
    fn test_struct_field_name_index() {
        let field_name = CodegenStructFieldName(IrStructFieldNameHint::Index(0));
        let actual: syn::Ident = parse_quote!(#field_name);
        let expected: syn::Ident = parse_quote!(variant_0);
        assert_eq!(actual, expected);

        let field_name = CodegenStructFieldName(IrStructFieldNameHint::Index(5));
        let actual: syn::Ident = parse_quote!(#field_name);
        let expected: syn::Ident = parse_quote!(variant_5);
        assert_eq!(actual, expected);
    }

    // MARK: `clean()`

    #[test]
    fn test_clean() {
        // Characters that aren't valid in identifiers become `_`.
        assert_eq!(clean("foo-bar"), "foo_bar");
        assert_eq!(clean("foo.bar"), "foo_bar");
        assert_eq!(clean("foo bar"), "foo_bar");
        assert_eq!(clean("foo@bar"), "foo_bar");
        assert_eq!(clean("foo#bar"), "foo_bar");
        assert_eq!(clean("foo!bar"), "foo_bar");

        // Word boundaries are made explicit; existing ones are kept.
        assert_eq!(clean("foo_bar"), "foo_bar");
        assert_eq!(clean("FooBar"), "Foo_Bar");
        assert_eq!(clean("foo123"), "foo123");
        assert_eq!(clean("_foo"), "foo");

        // Leading underscores are dropped.
        assert_eq!(clean("_foo"), "foo");
        assert_eq!(clean("__foo"), "foo");

        // Digits are in `XID_Continue`, so they should be preserved.
        assert_eq!(clean("123foo"), "123_foo");
        assert_eq!(clean("9bar"), "9_bar");

        // Non-ASCII characters that are valid in identifiers should be preserved;
        // characters that aren't should be replaced.
        assert_eq!(clean("café"), "café");
        assert_eq!(clean("foo™bar"), "foo_bar");

        // Invalid characters should be collapsed.
        assert_eq!(clean("foo---bar"), "foo_bar");
        assert_eq!(clean("foo...bar"), "foo_bar");
    }

    // MARK: Scopes

    // The empty string is reserved in every scope, so an empty name is
    // uniquified into a non-empty one; the expected result is `_2` for
    // both snake-case and Pascal-case usages.
    #[test]
    fn test_codegen_ident_scope_handles_empty() {
        let unique = UniqueNames::new();
        let mut scope = CodegenIdentScope::new(&unique);
        let ident = scope.uniquify("");

        let usage = CodegenIdentUsage::Field(&ident);
        let actual: syn::Ident = parse_quote!(#usage);
        let expected: syn::Ident = parse_quote!(_2);
        assert_eq!(actual, expected);

        let usage = CodegenIdentUsage::Type(&ident);
        let actual: syn::Ident = parse_quote!(#usage);
        let expected: syn::Ident = parse_quote!(_2);
        assert_eq!(actual, expected);
    }
}