Skip to main content

roam_types/
shape_classify.rs

1use facet_core::{Def, ScalarType, Shape, StructKind, Type, UserType};
2
3/// Classification of a `Shape` for code generation.
4#[derive(Debug, Clone, Copy)]
5pub enum ShapeKind<'a> {
6    Scalar(ScalarType),
7    List {
8        element: &'static Shape,
9    },
10    Array {
11        element: &'static Shape,
12        len: usize,
13    },
14    Slice {
15        element: &'static Shape,
16    },
17    Option {
18        inner: &'static Shape,
19    },
20    Map {
21        key: &'static Shape,
22        value: &'static Shape,
23    },
24    Set {
25        element: &'static Shape,
26    },
27    Struct(StructInfo<'a>),
28    Enum(EnumInfo<'a>),
29    Tuple {
30        elements: &'a [facet_core::TypeParam],
31    },
32    TupleStruct {
33        fields: &'a [facet_core::Field],
34    },
35    Tx {
36        inner: &'static Shape,
37    },
38    Rx {
39        inner: &'static Shape,
40    },
41    Pointer {
42        pointee: &'static Shape,
43    },
44    Result {
45        ok: &'static Shape,
46        err: &'static Shape,
47    },
48    Opaque,
49}
50
51/// Information about a struct type.
52#[derive(Debug, Clone, Copy)]
53pub struct StructInfo<'a> {
54    pub name: Option<&'static str>,
55    pub kind: StructKind,
56    pub fields: &'a [facet_core::Field],
57}
58
59/// Information about an enum type.
60#[derive(Debug, Clone, Copy)]
61pub struct EnumInfo<'a> {
62    pub name: Option<&'static str>,
63    pub variants: &'a [facet_core::Variant],
64}
65
66/// Information about an enum variant for code generation.
67#[derive(Debug, Clone, Copy)]
68pub enum VariantKind<'a> {
69    Unit,
70    Newtype { inner: &'static Shape },
71    Tuple { fields: &'a [facet_core::Field] },
72    Struct { fields: &'a [facet_core::Field] },
73}
74
75/// Classify an enum variant.
76pub fn classify_variant(variant: &facet_core::Variant) -> VariantKind<'_> {
77    match variant.data.kind {
78        StructKind::Unit => VariantKind::Unit,
79        StructKind::TupleStruct | StructKind::Tuple => {
80            if variant.data.fields.len() == 1 {
81                VariantKind::Newtype {
82                    inner: variant.data.fields[0].shape(),
83                }
84            } else {
85                VariantKind::Tuple {
86                    fields: variant.data.fields,
87                }
88            }
89        }
90        StructKind::Struct => VariantKind::Struct {
91            fields: variant.data.fields,
92        },
93    }
94}
95
96/// Classify a `Shape` into a higher-level semantic kind.
97pub fn classify_shape(shape: &'static Shape) -> ShapeKind<'static> {
98    if crate::is_tx(shape)
99        && let Some(inner) = shape.type_params.first()
100    {
101        return ShapeKind::Tx { inner: inner.shape };
102    }
103    if crate::is_rx(shape)
104        && let Some(inner) = shape.type_params.first()
105    {
106        return ShapeKind::Rx { inner: inner.shape };
107    }
108
109    if shape.is_transparent()
110        && let Some(inner) = shape.inner
111    {
112        return classify_shape(inner);
113    }
114
115    if let Some(scalar) = shape.scalar_type() {
116        return ShapeKind::Scalar(scalar);
117    }
118
119    match shape.def {
120        Def::List(list_def) => {
121            return ShapeKind::List {
122                element: list_def.t(),
123            };
124        }
125        Def::Array(array_def) => {
126            return ShapeKind::Array {
127                element: array_def.t(),
128                len: array_def.n,
129            };
130        }
131        Def::Slice(slice_def) => {
132            return ShapeKind::Slice {
133                element: slice_def.t(),
134            };
135        }
136        Def::Option(opt_def) => return ShapeKind::Option { inner: opt_def.t() },
137        Def::Map(map_def) => {
138            return ShapeKind::Map {
139                key: map_def.k(),
140                value: map_def.v(),
141            };
142        }
143        Def::Set(set_def) => {
144            return ShapeKind::Set {
145                element: set_def.t(),
146            };
147        }
148        Def::Result(result_def) => {
149            return ShapeKind::Result {
150                ok: result_def.t(),
151                err: result_def.e(),
152            };
153        }
154        Def::Pointer(ptr_def) => {
155            if let Some(pointee) = ptr_def.pointee {
156                return ShapeKind::Pointer { pointee };
157            }
158        }
159        _ => {}
160    }
161
162    match shape.ty {
163        Type::User(UserType::Struct(struct_type)) => {
164            if struct_type.kind == StructKind::Tuple {
165                return ShapeKind::TupleStruct {
166                    fields: struct_type.fields,
167                };
168            }
169            return ShapeKind::Struct(StructInfo {
170                name: extract_type_name(shape.type_identifier),
171                kind: struct_type.kind,
172                fields: struct_type.fields,
173            });
174        }
175        Type::User(UserType::Enum(enum_type)) => {
176            return ShapeKind::Enum(EnumInfo {
177                name: extract_type_name(shape.type_identifier),
178                variants: enum_type.variants,
179            });
180        }
181        Type::Pointer(_) => {
182            if let Some(inner) = shape.type_params.first() {
183                return classify_shape(inner.shape);
184            }
185        }
186        _ => {}
187    }
188
189    ShapeKind::Opaque
190}
191
192/// Check if a shape represents bytes (`Vec<u8>` or `&[u8]`).
193pub fn is_bytes(shape: &Shape) -> bool {
194    match shape.def {
195        Def::List(list_def) => matches!(list_def.t().scalar_type(), Some(ScalarType::U8)),
196        Def::Slice(slice_def) => matches!(slice_def.t().scalar_type(), Some(ScalarType::U8)),
197        _ => false,
198    }
199}
200
/// Return `type_identifier` as a type name, or `None` if it is empty or begins
/// with `(` or `[` (presumably tuple and array/slice identifiers).
fn extract_type_name(type_identifier: &'static str) -> Option<&'static str> {
    match type_identifier.chars().next() {
        None | Some('(' | '[') => None,
        Some(_) => Some(type_identifier),
    }
}