ploidy_codegen_rust/
naming.rs

1use std::{borrow::Cow, cmp::Ordering, fmt::Display, ops::Deref};
2
3use heck::{AsPascalCase, AsSnekCase};
4use itertools::Itertools;
5use ploidy_core::{
6    codegen::{
7        UniqueNames,
8        unique::{UniqueNamesScope, WordSegments},
9    },
10    ir::{
11        InlineIrTypePathSegment, InlineIrTypeView, IrStructFieldName, IrStructFieldNameHint,
12        IrUntaggedVariantNameHint, PrimitiveIrType, SchemaIrTypeView, View,
13    },
14};
15use proc_macro2::{Ident, Span, TokenStream};
16use quote::{IdentFragment, ToTokens, TokenStreamExt, format_ident};
17use ref_cast::{RefCastCustom, ref_cast_custom};
18
// Keywords that can't be used as identifiers, even with `r#`: they're
// rejected as raw identifiers, so they can't be escaped the way other
// keywords (e.g., `type` -> `r#type`) are. `CodegenIdent::new` prefixes
// names that clean to one of these with `_`, and `CodegenIdentScope`
// reserves them.
const KEYWORDS: &[&str] = &["crate", "self", "super", "Self"];
21
22#[derive(Clone, Copy, Debug)]
23pub enum CodegenTypeName<'a> {
24    Schema(&'a SchemaIrTypeView<'a>),
25    Inline(&'a InlineIrTypeView<'a>),
26}
27
28impl<'a> CodegenTypeName<'a> {
29    #[inline]
30    pub fn into_sort_key(self) -> CodegenTypeNameSortKey<'a> {
31        CodegenTypeNameSortKey(self)
32    }
33}
34
35impl ToTokens for CodegenTypeName<'_> {
36    fn to_tokens(&self, tokens: &mut TokenStream) {
37        match self {
38            Self::Schema(view) => {
39                let ident = view.extensions().get::<CodegenIdent>().unwrap();
40                tokens.append_all(CodegenIdentUsage::Type(&ident).to_token_stream())
41            }
42            Self::Inline(view) => {
43                let ident = view
44                    .path()
45                    .segments
46                    .iter()
47                    .map(CodegenTypePathSegment)
48                    .map(|segment| format_ident!("{}", segment))
49                    .reduce(|a, b| format_ident!("{}{}", a, b))
50                    .unwrap();
51                tokens.append(ident);
52            }
53        }
54    }
55}
56
57/// A comparator that sorts type names lexicographically.
58#[derive(Clone, Copy, Debug)]
59pub struct CodegenTypeNameSortKey<'a>(CodegenTypeName<'a>);
60
61impl<'a> CodegenTypeNameSortKey<'a> {
62    #[inline]
63    pub fn into_name(self) -> CodegenTypeName<'a> {
64        self.0
65    }
66}
67
68impl Eq for CodegenTypeNameSortKey<'_> {}
69
70impl Ord for CodegenTypeNameSortKey<'_> {
71    fn cmp(&self, other: &Self) -> Ordering {
72        match (&self.0, &other.0) {
73            (CodegenTypeName::Schema(a), CodegenTypeName::Schema(b)) => a.name().cmp(b.name()),
74            (CodegenTypeName::Inline(a), CodegenTypeName::Inline(b)) => a.path().cmp(b.path()),
75            (CodegenTypeName::Schema(_), CodegenTypeName::Inline(_)) => Ordering::Less,
76            (CodegenTypeName::Inline(_), CodegenTypeName::Schema(_)) => Ordering::Greater,
77        }
78    }
79}
80
81impl PartialEq for CodegenTypeNameSortKey<'_> {
82    fn eq(&self, other: &Self) -> bool {
83        self.cmp(other).is_eq()
84    }
85}
86
87impl PartialOrd for CodegenTypeNameSortKey<'_> {
88    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
89        Some(self.cmp(other))
90    }
91}
92
93/// A string that's statically guaranteed to be valid for any
94/// [`CodegenIdentUsage`].
95#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
96pub struct CodegenIdent(String);
97
98impl CodegenIdent {
99    /// Creates an identifier for any usage.
100    pub fn new(s: &str) -> Self {
101        let s = clean(s);
102        if KEYWORDS.contains(&s.as_str()) {
103            Self(format!("_{s}"))
104        } else {
105            Self(s)
106        }
107    }
108}
109
110impl Deref for CodegenIdent {
111    type Target = CodegenIdentRef;
112
113    fn deref(&self) -> &Self::Target {
114        CodegenIdentRef::new(&self.0)
115    }
116}
117
/// A string slice that's guaranteed to be valid for any [`CodegenIdentUsage`].
///
/// This is the borrowed counterpart of [`CodegenIdent`], which derefs to it.
#[derive(Debug, Eq, Ord, PartialEq, PartialOrd, RefCastCustom)]
#[repr(transparent)]
pub struct CodegenIdentRef(str);

impl CodegenIdentRef {
    /// Reinterprets `s` as a [`CodegenIdentRef`] without copying it.
    ///
    /// The body is generated by `ref_cast`, which relies on the
    /// `#[repr(transparent)]` layout. It performs no validation, so it's
    /// private: only this module can vouch that `s` is a valid identifier.
    #[ref_cast_custom]
    fn new(s: &str) -> &Self;
}
127
128#[derive(Clone, Copy, Debug)]
129pub enum CodegenIdentUsage<'a> {
130    Module(&'a CodegenIdentRef),
131    Type(&'a CodegenIdentRef),
132    Field(&'a CodegenIdentRef),
133    Variant(&'a CodegenIdentRef),
134    Param(&'a CodegenIdentRef),
135    Method(&'a CodegenIdentRef),
136}
137
138impl Display for CodegenIdentUsage<'_> {
139    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140        match self {
141            Self::Module(name) | Self::Field(name) | Self::Param(name) | Self::Method(name) => {
142                if name.0.starts_with(unicode_ident::is_xid_start) {
143                    write!(f, "{}", AsSnekCase(&name.0))
144                } else {
145                    // `name` doesn't start with `XID_Start` (e.g., "1099KStatus"),
146                    // so prefix it with `_`; everything after is known to be
147                    // `XID_Continue`.
148                    write!(f, "_{}", AsSnekCase(&name.0))
149                }
150            }
151            Self::Type(name) | Self::Variant(name) => {
152                if name.0.starts_with(unicode_ident::is_xid_start) {
153                    write!(f, "{}", AsPascalCase(&name.0))
154                } else {
155                    write!(f, "_{}", AsPascalCase(&name.0))
156                }
157            }
158        }
159    }
160}
161
162impl ToTokens for CodegenIdentUsage<'_> {
163    fn to_tokens(&self, tokens: &mut TokenStream) {
164        let s = self.to_string();
165        let ident = syn::parse_str(&s).unwrap_or_else(|_| Ident::new_raw(&s, Span::call_site()));
166        tokens.append(ident);
167    }
168}
169
170/// A scope for generating unique, valid Rust identifiers.
171#[derive(Debug)]
172pub struct CodegenIdentScope<'a>(UniqueNamesScope<'a>);
173
174impl<'a> CodegenIdentScope<'a> {
175    /// Creates a new identifier scope that's backed by the given arena.
176    pub fn new(arena: &'a UniqueNames) -> Self {
177        Self::with_reserved(arena, &[])
178    }
179
180    /// Creates a new identifier scope that's backed by the given arena,
181    /// with additional pre-reserved names.
182    pub fn with_reserved(arena: &'a UniqueNames, reserved: &[&str]) -> Self {
183        Self(arena.scope_with_reserved(itertools::chain!(
184            reserved.iter().copied(),
185            KEYWORDS.iter().copied(),
186            std::iter::once("")
187        )))
188    }
189
190    /// Cleans the input string and returns a name that's unique
191    /// within this scope, and valid for any [`CodegenIdentUsage`].
192    pub fn uniquify(&mut self, name: &str) -> CodegenIdent {
193        CodegenIdent(self.0.uniquify(&clean(name)).into_owned())
194    }
195}
196
197#[derive(Clone, Copy, Debug)]
198pub struct CodegenUntaggedVariantName(pub IrUntaggedVariantNameHint);
199
200impl ToTokens for CodegenUntaggedVariantName {
201    fn to_tokens(&self, tokens: &mut TokenStream) {
202        use IrUntaggedVariantNameHint::*;
203        let s = match self.0 {
204            Primitive(PrimitiveIrType::String) => "String".into(),
205            Primitive(PrimitiveIrType::I8) => "I8".into(),
206            Primitive(PrimitiveIrType::U8) => "U8".into(),
207            Primitive(PrimitiveIrType::I16) => "I16".into(),
208            Primitive(PrimitiveIrType::U16) => "U16".into(),
209            Primitive(PrimitiveIrType::I32) => "I32".into(),
210            Primitive(PrimitiveIrType::U32) => "U32".into(),
211            Primitive(PrimitiveIrType::I64) => "I64".into(),
212            Primitive(PrimitiveIrType::U64) => "U64".into(),
213            Primitive(PrimitiveIrType::F32) => "F32".into(),
214            Primitive(PrimitiveIrType::F64) => "F64".into(),
215            Primitive(PrimitiveIrType::Bool) => "Bool".into(),
216            Primitive(PrimitiveIrType::DateTime) => "DateTime".into(),
217            Primitive(PrimitiveIrType::UnixTime) => "UnixTime".into(),
218            Primitive(PrimitiveIrType::Date) => "Date".into(),
219            Primitive(PrimitiveIrType::Url) => "Url".into(),
220            Primitive(PrimitiveIrType::Uuid) => "Uuid".into(),
221            Primitive(PrimitiveIrType::Bytes) => "Bytes".into(),
222            Primitive(PrimitiveIrType::Binary) => "Binary".into(),
223            Array => "Array".into(),
224            Map => "Map".into(),
225            Index(index) => Cow::Owned(format!("V{index}")),
226        };
227        tokens.append(Ident::new(&s, Span::call_site()));
228    }
229}
230
231#[derive(Clone, Copy, Debug)]
232pub struct CodegenStructFieldName(pub IrStructFieldNameHint);
233
234impl ToTokens for CodegenStructFieldName {
235    fn to_tokens(&self, tokens: &mut TokenStream) {
236        match self.0 {
237            IrStructFieldNameHint::Index(index) => {
238                CodegenIdentUsage::Field(&CodegenIdent(format!("variant_{index}")))
239                    .to_tokens(tokens)
240            }
241        }
242    }
243}
244
245#[derive(Clone, Copy, Debug)]
246pub struct CodegenTypePathSegment<'a>(&'a InlineIrTypePathSegment<'a>);
247
248impl IdentFragment for CodegenTypePathSegment<'_> {
249    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
250        use InlineIrTypePathSegment::*;
251        match self.0 {
252            // Segments are part of an inline type path that always has a root prefix,
253            // so we don't need to check for `XID_Start`.
254            Operation(name) => write!(f, "{}", AsPascalCase(clean(name))),
255            Parameter(name) => write!(f, "{}", AsPascalCase(clean(name))),
256            Request => f.write_str("Request"),
257            Response => f.write_str("Response"),
258            Field(IrStructFieldName::Name(name)) => {
259                write!(f, "{}", AsPascalCase(clean(name)))
260            }
261            Field(IrStructFieldName::Hint(IrStructFieldNameHint::Index(index))) => {
262                write!(f, "Variant{index}")
263            }
264            MapValue => f.write_str("Value"),
265            ArrayItem => f.write_str("Item"),
266            Variant(index) => write!(f, "V{index}"),
267        }
268    }
269}
270
271/// Makes a string suitable for inclusion within a Rust identifier.
272///
273/// Cleaning segments the string on word boundaries, collapses all
274/// non-`XID_Continue` characters into new boundaries, and
275/// reassembles the string. This makes the string resilient to
276/// case transformations, which also collapse boundaries, and so
277/// can produce duplicates in some cases.
278///
279/// Note that the result may not itself be a valid Rust identifier,
280/// because Rust identifiers must start with `XID_Start`.
281/// This is checked and handled in [`CodegenIdentUsage`].
282fn clean(s: &str) -> String {
283    WordSegments::new(s)
284        .flat_map(|s| s.split(|c| !unicode_ident::is_xid_continue(c)))
285        .join("_")
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    use pretty_assertions::assert_eq;
    use syn::parse_quote;

    /// Interpolates `tokens`, parses the result as a single identifier,
    /// and asserts that it's equal to `expected`.
    #[track_caller]
    fn assert_ident(tokens: impl ToTokens, expected: syn::Ident) {
        let actual: syn::Ident = parse_quote!(#tokens);
        assert_eq!(actual, expected);
    }

    /// Returns the name of an untagged variant that holds a primitive type.
    fn primitive(ty: PrimitiveIrType) -> CodegenUntaggedVariantName {
        CodegenUntaggedVariantName(IrUntaggedVariantNameHint::Primitive(ty))
    }

    // MARK: Usages

    #[test]
    fn test_codegen_ident_type() {
        let ident = CodegenIdent::new("pet_store");
        assert_ident(CodegenIdentUsage::Type(&ident), parse_quote!(PetStore));
    }

    #[test]
    fn test_codegen_ident_field() {
        let ident = CodegenIdent::new("petStore");
        assert_ident(CodegenIdentUsage::Field(&ident), parse_quote!(pet_store));
    }

    #[test]
    fn test_codegen_ident_module() {
        let ident = CodegenIdent::new("MyModule");
        assert_ident(CodegenIdentUsage::Module(&ident), parse_quote!(my_module));
    }

    #[test]
    fn test_codegen_ident_variant() {
        let ident = CodegenIdent::new("http_error");
        assert_ident(CodegenIdentUsage::Variant(&ident), parse_quote!(HttpError));
    }

    #[test]
    fn test_codegen_ident_param() {
        let ident = CodegenIdent::new("userId");
        assert_ident(CodegenIdentUsage::Param(&ident), parse_quote!(user_id));
    }

    #[test]
    fn test_codegen_ident_method() {
        let ident = CodegenIdent::new("getUserById");
        assert_ident(CodegenIdentUsage::Method(&ident), parse_quote!(get_user_by_id));
    }

    // MARK: Special characters

    #[test]
    fn test_codegen_ident_handles_rust_keywords() {
        let ident = CodegenIdent::new("type");
        assert_ident(CodegenIdentUsage::Field(&ident), parse_quote!(r#type));
    }

    #[test]
    fn test_codegen_ident_handles_invalid_start_chars() {
        let ident = CodegenIdent::new("123foo");
        assert_ident(CodegenIdentUsage::Field(&ident), parse_quote!(_123_foo));
    }

    #[test]
    fn test_codegen_ident_handles_special_chars() {
        let ident = CodegenIdent::new("foo-bar-baz");
        assert_ident(CodegenIdentUsage::Field(&ident), parse_quote!(foo_bar_baz));
    }

    #[test]
    fn test_codegen_ident_handles_number_prefix() {
        let ident = CodegenIdent::new("1099KStatus");
        assert_ident(CodegenIdentUsage::Field(&ident), parse_quote!(_1099_k_status));
        assert_ident(CodegenIdentUsage::Type(&ident), parse_quote!(_1099KStatus));
    }

    // MARK: Untagged variant names

    #[test]
    fn test_untagged_variant_name_string() {
        assert_ident(primitive(PrimitiveIrType::String), parse_quote!(String));
    }

    #[test]
    fn test_untagged_variant_name_i32() {
        assert_ident(primitive(PrimitiveIrType::I32), parse_quote!(I32));
    }

    #[test]
    fn test_untagged_variant_name_i64() {
        assert_ident(primitive(PrimitiveIrType::I64), parse_quote!(I64));
    }

    #[test]
    fn test_untagged_variant_name_f32() {
        assert_ident(primitive(PrimitiveIrType::F32), parse_quote!(F32));
    }

    #[test]
    fn test_untagged_variant_name_f64() {
        assert_ident(primitive(PrimitiveIrType::F64), parse_quote!(F64));
    }

    #[test]
    fn test_untagged_variant_name_bool() {
        assert_ident(primitive(PrimitiveIrType::Bool), parse_quote!(Bool));
    }

    #[test]
    fn test_untagged_variant_name_datetime() {
        assert_ident(primitive(PrimitiveIrType::DateTime), parse_quote!(DateTime));
    }

    #[test]
    fn test_untagged_variant_name_date() {
        assert_ident(primitive(PrimitiveIrType::Date), parse_quote!(Date));
    }

    #[test]
    fn test_untagged_variant_name_url() {
        assert_ident(primitive(PrimitiveIrType::Url), parse_quote!(Url));
    }

    #[test]
    fn test_untagged_variant_name_uuid() {
        assert_ident(primitive(PrimitiveIrType::Uuid), parse_quote!(Uuid));
    }

    #[test]
    fn test_untagged_variant_name_bytes() {
        assert_ident(primitive(PrimitiveIrType::Bytes), parse_quote!(Bytes));
    }

    #[test]
    fn test_untagged_variant_name_index() {
        assert_ident(
            CodegenUntaggedVariantName(IrUntaggedVariantNameHint::Index(0)),
            parse_quote!(V0),
        );
        assert_ident(
            CodegenUntaggedVariantName(IrUntaggedVariantNameHint::Index(42)),
            parse_quote!(V42),
        );
    }

    #[test]
    fn test_untagged_variant_name_array() {
        assert_ident(
            CodegenUntaggedVariantName(IrUntaggedVariantNameHint::Array),
            parse_quote!(Array),
        );
    }

    #[test]
    fn test_untagged_variant_name_map() {
        assert_ident(
            CodegenUntaggedVariantName(IrUntaggedVariantNameHint::Map),
            parse_quote!(Map),
        );
    }

    // MARK: Struct field names

    #[test]
    fn test_struct_field_name_index() {
        assert_ident(
            CodegenStructFieldName(IrStructFieldNameHint::Index(0)),
            parse_quote!(variant_0),
        );
        assert_ident(
            CodegenStructFieldName(IrStructFieldNameHint::Index(5)),
            parse_quote!(variant_5),
        );
    }

    // MARK: `clean()`

    #[test]
    fn test_clean() {
        // Punctuation and whitespace become word boundaries.
        assert_eq!(clean("foo-bar"), "foo_bar");
        assert_eq!(clean("foo.bar"), "foo_bar");
        assert_eq!(clean("foo bar"), "foo_bar");
        assert_eq!(clean("foo@bar"), "foo_bar");
        assert_eq!(clean("foo#bar"), "foo_bar");
        assert_eq!(clean("foo!bar"), "foo_bar");

        assert_eq!(clean("foo_bar"), "foo_bar");
        assert_eq!(clean("FooBar"), "Foo_Bar");
        assert_eq!(clean("foo123"), "foo123");

        // Leading underscores are dropped.
        assert_eq!(clean("_foo"), "foo");
        assert_eq!(clean("__foo"), "foo");

        // Digits are in `XID_Continue`, so they should be preserved.
        assert_eq!(clean("123foo"), "123_foo");
        assert_eq!(clean("9bar"), "9_bar");

        // Non-ASCII characters that are valid in identifiers should be preserved;
        // characters that aren't should be replaced.
        assert_eq!(clean("café"), "café");
        assert_eq!(clean("foo™bar"), "foo_bar");

        // Invalid characters should be collapsed.
        assert_eq!(clean("foo---bar"), "foo_bar");
        assert_eq!(clean("foo...bar"), "foo_bar");
    }

    // MARK: Scopes

    #[test]
    fn test_codegen_ident_scope_handles_empty() {
        let unique = UniqueNames::new();
        let mut scope = CodegenIdentScope::new(&unique);
        let ident = scope.uniquify("");

        assert_ident(CodegenIdentUsage::Field(&ident), parse_quote!(_2));
        assert_ident(CodegenIdentUsage::Type(&ident), parse_quote!(_2));
    }
}