ploidy_codegen_rust/
naming.rs

1use std::{borrow::Cow, cmp::Ordering, fmt::Display, ops::Deref};
2
3use heck::{AsPascalCase, AsSnekCase};
4use itertools::Itertools;
5use ploidy_core::{
6    codegen::{
7        UniqueNames,
8        unique::{UniqueNamesScope, WordSegments},
9    },
10    ir::{
11        InlineIrTypePathSegment, InlineIrTypeView, IrStructFieldName, IrStructFieldNameHint,
12        IrUntaggedVariantNameHint, PrimitiveIrType, SchemaIrTypeView, View,
13    },
14};
15use proc_macro2::{Ident, Span, TokenStream};
16use quote::{IdentFragment, ToTokens, TokenStreamExt, format_ident};
17use ref_cast::{RefCastCustom, ref_cast_custom};
18
// Keywords that can't be used as identifiers, even with `r#`.
//
// `CodegenIdent::new` prefixes an exact match with `_`, and
// `CodegenIdentScope::with_reserved` passes every entry to the
// underlying scope as a reserved name.
const KEYWORDS: &[&str] = &["crate", "self", "super", "Self"];
21
/// The name of a generated Rust type, backed by a view of the IR type
/// that it names.
///
/// When converted to tokens, a schema type uses the [`CodegenIdent`]
/// stored in the view's extensions, and an inline type uses the
/// concatenation of its path segments.
#[derive(Clone, Copy, Debug)]
pub enum CodegenTypeName<'a> {
    /// A name for a schema type.
    Schema(&'a SchemaIrTypeView<'a>),
    /// A name for an inline type, built from the type's path.
    Inline(&'a InlineIrTypeView<'a>),
}
27
impl<'a> CodegenTypeName<'a> {
    /// Wraps this name in a [`CodegenTypeNameSortKey`], which implements
    /// [`Ord`] so that names can be sorted and compared.
    #[inline]
    pub fn into_sort_key(self) -> CodegenTypeNameSortKey<'a> {
        CodegenTypeNameSortKey(self)
    }
}
34
35impl ToTokens for CodegenTypeName<'_> {
36    fn to_tokens(&self, tokens: &mut TokenStream) {
37        match self {
38            Self::Schema(view) => {
39                let ident = view.extensions().get::<CodegenIdent>().unwrap();
40                tokens.append_all(CodegenIdentUsage::Type(&ident).to_token_stream())
41            }
42            Self::Inline(view) => {
43                let ident = view
44                    .path()
45                    .segments
46                    .iter()
47                    .map(CodegenTypePathSegment)
48                    .map(|segment| format_ident!("{}", segment))
49                    .reduce(|a, b| format_ident!("{}{}", a, b))
50                    .unwrap();
51                tokens.append(ident);
52            }
53        }
54    }
55}
56
/// A comparator that sorts type names lexicographically.
///
/// Schema types are compared by name and inline types by path;
/// every schema type orders before every inline type.
#[derive(Clone, Copy, Debug)]
pub struct CodegenTypeNameSortKey<'a>(CodegenTypeName<'a>);
60
impl<'a> CodegenTypeNameSortKey<'a> {
    /// Unwraps the sort key, returning the [`CodegenTypeName`] that
    /// it was created from.
    #[inline]
    pub fn into_name(self) -> CodegenTypeName<'a> {
        self.0
    }
}
67
// Equality for sort keys is defined in terms of `Ord::cmp`
// (see the `PartialEq` implementation).
impl Eq for CodegenTypeNameSortKey<'_> {}
69
70impl Ord for CodegenTypeNameSortKey<'_> {
71    fn cmp(&self, other: &Self) -> Ordering {
72        match (&self.0, &other.0) {
73            (CodegenTypeName::Schema(a), CodegenTypeName::Schema(b)) => a.name().cmp(b.name()),
74            (CodegenTypeName::Inline(a), CodegenTypeName::Inline(b)) => a.path().cmp(b.path()),
75            (CodegenTypeName::Schema(_), CodegenTypeName::Inline(_)) => Ordering::Less,
76            (CodegenTypeName::Inline(_), CodegenTypeName::Schema(_)) => Ordering::Greater,
77        }
78    }
79}
80
81impl PartialEq for CodegenTypeNameSortKey<'_> {
82    fn eq(&self, other: &Self) -> bool {
83        self.cmp(other).is_eq()
84    }
85}
86
impl PartialOrd for CodegenTypeNameSortKey<'_> {
    /// Delegates to [`Ord::cmp`], so the result is always `Some`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
92
/// A string that's statically guaranteed to be valid for any
/// [`CodegenIdentUsage`].
///
/// This is the owned counterpart of [`CodegenIdentRef`], and
/// dereferences to it.
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub struct CodegenIdent(String);
97
98impl CodegenIdent {
99    /// Creates an identifier for any usage.
100    pub fn new(s: &str) -> Self {
101        let s = clean(s);
102        if KEYWORDS.contains(&s.as_str()) {
103            Self(format!("_{s}"))
104        } else {
105            Self(s)
106        }
107    }
108}
109
110impl Deref for CodegenIdent {
111    type Target = CodegenIdentRef;
112
113    fn deref(&self) -> &Self::Target {
114        CodegenIdentRef::new(&self.0)
115    }
116}
117
/// A string slice that's guaranteed to be valid for any [`CodegenIdentUsage`].
///
/// This is the borrowed counterpart of [`CodegenIdent`]. It's a
/// `#[repr(transparent)]` wrapper around `str`, so a `&str` can be
/// reinterpreted as a `&CodegenIdentRef`.
#[derive(Debug, Eq, Ord, PartialEq, PartialOrd, RefCastCustom)]
#[repr(transparent)]
pub struct CodegenIdentRef(str);
122
impl CodegenIdentRef {
    /// Reinterprets a string slice as an identifier slice.
    ///
    /// The body is supplied by `#[ref_cast_custom]`. This constructor is
    /// private and has no validation of its own, so callers in this
    /// module are responsible for passing a suitable string.
    #[ref_cast_custom]
    fn new(s: &str) -> &Self;
}
127
/// An identifier paired with the position it's used in, which
/// determines the case that it's formatted in.
///
/// Modules, fields, parameters, and methods are formatted in snake case;
/// types and variants are formatted in Pascal case.
#[derive(Clone, Copy, Debug)]
pub enum CodegenIdentUsage<'a> {
    /// A module name (snake case).
    Module(&'a CodegenIdentRef),
    /// A type name (Pascal case).
    Type(&'a CodegenIdentRef),
    /// A struct field name (snake case).
    Field(&'a CodegenIdentRef),
    /// An enum variant name (Pascal case).
    Variant(&'a CodegenIdentRef),
    /// A parameter name (snake case).
    Param(&'a CodegenIdentRef),
    /// A method name (snake case).
    Method(&'a CodegenIdentRef),
}
137
138impl Display for CodegenIdentUsage<'_> {
139    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140        match self {
141            Self::Module(name) | Self::Field(name) | Self::Param(name) | Self::Method(name) => {
142                if name.0.starts_with(unicode_ident::is_xid_start) {
143                    write!(f, "{}", AsSnekCase(&name.0))
144                } else {
145                    // `name` doesn't start with `XID_Start` (e.g., "1099KStatus"),
146                    // so prefix it with `_`; everything after is known to be
147                    // `XID_Continue`.
148                    write!(f, "_{}", AsSnekCase(&name.0))
149                }
150            }
151            Self::Type(name) | Self::Variant(name) => {
152                if name.0.starts_with(unicode_ident::is_xid_start) {
153                    write!(f, "{}", AsPascalCase(&name.0))
154                } else {
155                    write!(f, "_{}", AsPascalCase(&name.0))
156                }
157            }
158        }
159    }
160}
161
162impl ToTokens for CodegenIdentUsage<'_> {
163    fn to_tokens(&self, tokens: &mut TokenStream) {
164        let s = self.to_string();
165        let ident = syn::parse_str(&s).unwrap_or_else(|_| Ident::new_raw(&s, Span::call_site()));
166        tokens.append(ident);
167    }
168}
169
/// A scope for generating unique, valid Rust identifiers.
///
/// Wraps a [`UniqueNamesScope`]; names are cleaned before they're
/// made unique.
#[derive(Debug)]
pub struct CodegenIdentScope<'a>(UniqueNamesScope<'a>);
173
174impl<'a> CodegenIdentScope<'a> {
175    /// Creates a new identifier scope that's backed by the given arena.
176    pub fn new(arena: &'a UniqueNames) -> Self {
177        Self::with_reserved(arena, &[])
178    }
179
180    /// Creates a new identifier scope that's backed by the given arena,
181    /// with additional pre-reserved names.
182    pub fn with_reserved(arena: &'a UniqueNames, reserved: &[&str]) -> Self {
183        Self(arena.scope_with_reserved(itertools::chain!(
184            reserved.iter().copied(),
185            KEYWORDS.iter().copied(),
186            std::iter::once("")
187        )))
188    }
189
190    /// Cleans the input string and returns a name that's unique
191    /// within this scope, and valid for any [`CodegenIdentUsage`].
192    pub fn uniquify(&mut self, name: &str) -> CodegenIdent {
193        CodegenIdent(self.0.uniquify(&clean(name)).into_owned())
194    }
195}
196
/// The name of a variant of an untagged enum, derived from
/// a name hint: a primitive type, an array, a map, or a numeric index.
#[derive(Clone, Copy, Debug)]
pub struct CodegenUntaggedVariantName(pub IrUntaggedVariantNameHint);
199
200impl ToTokens for CodegenUntaggedVariantName {
201    fn to_tokens(&self, tokens: &mut TokenStream) {
202        use IrUntaggedVariantNameHint::*;
203        let s = match self.0 {
204            Primitive(PrimitiveIrType::String) => "String".into(),
205            Primitive(PrimitiveIrType::I32) => "I32".into(),
206            Primitive(PrimitiveIrType::I64) => "I64".into(),
207            Primitive(PrimitiveIrType::F32) => "F32".into(),
208            Primitive(PrimitiveIrType::F64) => "F64".into(),
209            Primitive(PrimitiveIrType::Bool) => "Bool".into(),
210            Primitive(PrimitiveIrType::DateTime) => "DateTime".into(),
211            Primitive(PrimitiveIrType::UnixTime) => "UnixTime".into(),
212            Primitive(PrimitiveIrType::Date) => "Date".into(),
213            Primitive(PrimitiveIrType::Url) => "Url".into(),
214            Primitive(PrimitiveIrType::Uuid) => "Uuid".into(),
215            Primitive(PrimitiveIrType::Bytes) => "Bytes".into(),
216            Array => "Array".into(),
217            Map => "Map".into(),
218            Index(index) => Cow::Owned(format!("V{index}")),
219        };
220        tokens.append(Ident::new(&s, Span::call_site()));
221    }
222}
223
/// The name of a struct field that's derived from a name hint,
/// rather than from an explicit name.
#[derive(Clone, Copy, Debug)]
pub struct CodegenStructFieldName(pub IrStructFieldNameHint);
226
227impl ToTokens for CodegenStructFieldName {
228    fn to_tokens(&self, tokens: &mut TokenStream) {
229        match self.0 {
230            IrStructFieldNameHint::Index(index) => {
231                CodegenIdentUsage::Field(&CodegenIdent(format!("variant_{index}")))
232                    .to_tokens(tokens)
233            }
234        }
235    }
236}
237
/// A segment of an inline type's path, formatted as a fragment of
/// a Pascal case type name.
#[derive(Clone, Copy, Debug)]
pub struct CodegenTypePathSegment<'a>(&'a InlineIrTypePathSegment<'a>);
240
241impl IdentFragment for CodegenTypePathSegment<'_> {
242    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
243        use InlineIrTypePathSegment::*;
244        match self.0 {
245            // Segments are part of an inline type path that always has a root prefix,
246            // so we don't need to check for `XID_Start`.
247            Operation(name) => write!(f, "{}", AsPascalCase(clean(name))),
248            Parameter(name) => write!(f, "{}", AsPascalCase(clean(name))),
249            Request => f.write_str("Request"),
250            Response => f.write_str("Response"),
251            Field(IrStructFieldName::Name(name)) => {
252                write!(f, "{}", AsPascalCase(clean(name)))
253            }
254            Field(IrStructFieldName::Hint(IrStructFieldNameHint::Index(index))) => {
255                write!(f, "Variant{index}")
256            }
257            MapValue => f.write_str("Value"),
258            ArrayItem => f.write_str("Item"),
259            Variant(index) => write!(f, "V{index}"),
260        }
261    }
262}
263
264/// Makes a string suitable for inclusion within a Rust identifier.
265///
266/// Cleaning segments the string on word boundaries, collapses all
267/// non-`XID_Continue` characters into new boundaries, and
268/// reassembles the string. This makes the string resilient to
269/// case transformations, which also collapse boundaries, and so
270/// can produce duplicates in some cases.
271///
272/// Note that the result may not itself be a valid Rust identifier,
273/// because Rust identifiers must start with `XID_Start`.
274/// This is checked and handled in [`CodegenIdentUsage`].
275fn clean(s: &str) -> String {
276    WordSegments::new(s)
277        .flat_map(|s| s.split(|c| !unicode_ident::is_xid_continue(c)))
278        .join("_")
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    use pretty_assertions::assert_eq;
    use syn::parse_quote;

    /// Converts `value` to tokens, parses the tokens as a single
    /// identifier, and asserts that the identifier is spelled exactly
    /// like `expected` (including any `r#` prefix).
    #[track_caller]
    fn assert_ident(value: impl ToTokens, expected: &str) {
        let actual: syn::Ident = parse_quote!(#value);
        assert_eq!(actual.to_string(), expected);
    }

    /// Builds the untagged variant name for a primitive type.
    fn primitive(ty: PrimitiveIrType) -> CodegenUntaggedVariantName {
        CodegenUntaggedVariantName(IrUntaggedVariantNameHint::Primitive(ty))
    }

    // MARK: Usages

    #[test]
    fn test_codegen_ident_type() {
        let ident = CodegenIdent::new("pet_store");
        assert_ident(CodegenIdentUsage::Type(&ident), "PetStore");
    }

    #[test]
    fn test_codegen_ident_field() {
        let ident = CodegenIdent::new("petStore");
        assert_ident(CodegenIdentUsage::Field(&ident), "pet_store");
    }

    #[test]
    fn test_codegen_ident_module() {
        let ident = CodegenIdent::new("MyModule");
        assert_ident(CodegenIdentUsage::Module(&ident), "my_module");
    }

    #[test]
    fn test_codegen_ident_variant() {
        let ident = CodegenIdent::new("http_error");
        assert_ident(CodegenIdentUsage::Variant(&ident), "HttpError");
    }

    #[test]
    fn test_codegen_ident_param() {
        let ident = CodegenIdent::new("userId");
        assert_ident(CodegenIdentUsage::Param(&ident), "user_id");
    }

    #[test]
    fn test_codegen_ident_method() {
        let ident = CodegenIdent::new("getUserById");
        assert_ident(CodegenIdentUsage::Method(&ident), "get_user_by_id");
    }

    // MARK: Special characters

    #[test]
    fn test_codegen_ident_handles_rust_keywords() {
        let ident = CodegenIdent::new("type");
        assert_ident(CodegenIdentUsage::Field(&ident), "r#type");
    }

    #[test]
    fn test_codegen_ident_handles_invalid_start_chars() {
        let ident = CodegenIdent::new("123foo");
        assert_ident(CodegenIdentUsage::Field(&ident), "_123_foo");
    }

    #[test]
    fn test_codegen_ident_handles_special_chars() {
        let ident = CodegenIdent::new("foo-bar-baz");
        assert_ident(CodegenIdentUsage::Field(&ident), "foo_bar_baz");
    }

    #[test]
    fn test_codegen_ident_handles_number_prefix() {
        let ident = CodegenIdent::new("1099KStatus");
        assert_ident(CodegenIdentUsage::Field(&ident), "_1099_k_status");
        assert_ident(CodegenIdentUsage::Type(&ident), "_1099KStatus");
    }

    // MARK: Untagged variant names

    #[test]
    fn test_untagged_variant_name_string() {
        assert_ident(primitive(PrimitiveIrType::String), "String");
    }

    #[test]
    fn test_untagged_variant_name_i32() {
        assert_ident(primitive(PrimitiveIrType::I32), "I32");
    }

    #[test]
    fn test_untagged_variant_name_i64() {
        assert_ident(primitive(PrimitiveIrType::I64), "I64");
    }

    #[test]
    fn test_untagged_variant_name_f32() {
        assert_ident(primitive(PrimitiveIrType::F32), "F32");
    }

    #[test]
    fn test_untagged_variant_name_f64() {
        assert_ident(primitive(PrimitiveIrType::F64), "F64");
    }

    #[test]
    fn test_untagged_variant_name_bool() {
        assert_ident(primitive(PrimitiveIrType::Bool), "Bool");
    }

    #[test]
    fn test_untagged_variant_name_datetime() {
        assert_ident(primitive(PrimitiveIrType::DateTime), "DateTime");
    }

    #[test]
    fn test_untagged_variant_name_unix_time() {
        assert_ident(primitive(PrimitiveIrType::UnixTime), "UnixTime");
    }

    #[test]
    fn test_untagged_variant_name_date() {
        assert_ident(primitive(PrimitiveIrType::Date), "Date");
    }

    #[test]
    fn test_untagged_variant_name_url() {
        assert_ident(primitive(PrimitiveIrType::Url), "Url");
    }

    #[test]
    fn test_untagged_variant_name_uuid() {
        assert_ident(primitive(PrimitiveIrType::Uuid), "Uuid");
    }

    #[test]
    fn test_untagged_variant_name_bytes() {
        assert_ident(primitive(PrimitiveIrType::Bytes), "Bytes");
    }

    #[test]
    fn test_untagged_variant_name_index() {
        assert_ident(
            CodegenUntaggedVariantName(IrUntaggedVariantNameHint::Index(0)),
            "V0",
        );
        assert_ident(
            CodegenUntaggedVariantName(IrUntaggedVariantNameHint::Index(42)),
            "V42",
        );
    }

    #[test]
    fn test_untagged_variant_name_array() {
        assert_ident(
            CodegenUntaggedVariantName(IrUntaggedVariantNameHint::Array),
            "Array",
        );
    }

    #[test]
    fn test_untagged_variant_name_map() {
        assert_ident(
            CodegenUntaggedVariantName(IrUntaggedVariantNameHint::Map),
            "Map",
        );
    }

    // MARK: Struct field names

    #[test]
    fn test_struct_field_name_index() {
        assert_ident(
            CodegenStructFieldName(IrStructFieldNameHint::Index(0)),
            "variant_0",
        );
        assert_ident(
            CodegenStructFieldName(IrStructFieldNameHint::Index(5)),
            "variant_5",
        );
    }

    // MARK: `clean()`

    #[test]
    fn test_clean() {
        // Characters that aren't valid in identifiers become boundaries.
        assert_eq!(clean("foo-bar"), "foo_bar");
        assert_eq!(clean("foo.bar"), "foo_bar");
        assert_eq!(clean("foo bar"), "foo_bar");
        assert_eq!(clean("foo@bar"), "foo_bar");
        assert_eq!(clean("foo#bar"), "foo_bar");
        assert_eq!(clean("foo!bar"), "foo_bar");

        // Existing word boundaries should be preserved.
        assert_eq!(clean("foo_bar"), "foo_bar");
        assert_eq!(clean("FooBar"), "Foo_Bar");
        assert_eq!(clean("foo123"), "foo123");

        // Leading underscores should be dropped.
        assert_eq!(clean("_foo"), "foo");
        assert_eq!(clean("__foo"), "foo");

        // Digits are in `XID_Continue`, so they should be preserved.
        assert_eq!(clean("123foo"), "123_foo");
        assert_eq!(clean("9bar"), "9_bar");

        // Non-ASCII characters that are valid in identifiers should be preserved;
        // characters that aren't should be replaced.
        assert_eq!(clean("café"), "café");
        assert_eq!(clean("foo™bar"), "foo_bar");

        // Invalid characters should be collapsed.
        assert_eq!(clean("foo---bar"), "foo_bar");
        assert_eq!(clean("foo...bar"), "foo_bar");
    }

    // MARK: Scopes

    #[test]
    fn test_codegen_ident_scope_handles_empty() {
        let unique = UniqueNames::new();
        let mut scope = CodegenIdentScope::new(&unique);
        let ident = scope.uniquify("");

        assert_ident(CodegenIdentUsage::Field(&ident), "_2");
        assert_ident(CodegenIdentUsage::Type(&ident), "_2");
    }
}