1use std::collections::HashSet;
9
10use facet_core::{ScalarType, Shape};
11use roam_types::{
12 EnumInfo, RpcPlan, ServiceDescriptor, ShapeKind, StructInfo, VariantKind, classify_shape,
13 classify_variant, is_bytes,
14};
15
/// Renders a TypeScript property access of `field_name` on the expression `expr`.
///
/// Field names that begin with an ASCII digit (Rust tuple-field indices such as
/// `0`, `1`, …) cannot follow a `.` in TypeScript, so they are emitted with
/// bracket notation (`expr[0]`). All other names use dot notation (`expr.name`).
pub fn ts_field_access(expr: &str, field_name: &str) -> String {
    // Checking the first byte is equivalent to checking the first char here:
    // ASCII digits are single bytes, and no multi-byte UTF-8 lead byte is a digit.
    let is_positional = field_name
        .as_bytes()
        .first()
        .is_some_and(u8::is_ascii_digit);
    if is_positional {
        return format!("{expr}[{field_name}]");
    }
    format!("{expr}.{field_name}")
}
29
30pub fn collect_named_types(service: &ServiceDescriptor) -> Vec<(String, &'static Shape)> {
33 let mut seen = HashSet::new();
34 let mut types = Vec::new();
35
36 fn visit(
37 shape: &'static Shape,
38 seen: &mut HashSet<String>,
39 types: &mut Vec<(String, &'static Shape)>,
40 ) {
41 match classify_shape(shape) {
42 ShapeKind::Struct(StructInfo {
43 name: Some(name),
44 fields,
45 ..
46 }) => {
47 if !seen.contains(name) {
48 seen.insert(name.to_string());
49 for field in fields {
51 visit(field.shape(), seen, types);
52 }
53 types.push((name.to_string(), shape));
54 }
55 }
56 ShapeKind::Enum(EnumInfo {
57 name: Some(name),
58 variants,
59 }) => {
60 if !seen.contains(name) {
61 seen.insert(name.to_string());
62 for variant in variants {
64 match classify_variant(variant) {
65 VariantKind::Newtype { inner } => visit(inner, seen, types),
66 VariantKind::Struct { fields } | VariantKind::Tuple { fields } => {
67 for field in fields {
68 visit(field.shape(), seen, types);
69 }
70 }
71 VariantKind::Unit => {}
72 }
73 }
74 types.push((name.to_string(), shape));
75 }
76 }
77 ShapeKind::List { element } => visit(element, seen, types),
78 ShapeKind::Option { inner } => visit(inner, seen, types),
79 ShapeKind::Array { element, .. } => visit(element, seen, types),
80 ShapeKind::Map { key, value } => {
81 visit(key, seen, types);
82 visit(value, seen, types);
83 }
84 ShapeKind::Set { element } => visit(element, seen, types),
85 ShapeKind::Tuple { elements } => {
86 for param in elements {
87 visit(param.shape, seen, types);
88 }
89 }
90 ShapeKind::Tx { inner } | ShapeKind::Rx { inner } => visit(inner, seen, types),
91 ShapeKind::Pointer { pointee } => visit(pointee, seen, types),
92 ShapeKind::Result { ok, err } => {
93 visit(ok, seen, types);
94 visit(err, seen, types);
95 }
96 _ => {}
98 }
99 }
100
101 for method in service.methods {
102 for arg in method.args {
103 visit(arg.shape, &mut seen, &mut types);
104 }
105 visit(method.return_shape, &mut seen, &mut types);
106 }
107
108 types
109}
110
111pub fn generate_named_types(named_types: &[(String, &'static Shape)]) -> String {
113 let mut out = String::new();
114
115 if named_types.is_empty() {
116 return out;
117 }
118
119 out.push_str("// Named type definitions\n");
120
121 for (name, shape) in named_types {
122 match classify_shape(shape) {
123 ShapeKind::Struct(StructInfo { fields, .. }) => {
124 out.push_str(&format!("export interface {} {{\n", name));
125 for field in fields {
126 out.push_str(&format!(
127 " {}: {};\n",
128 field.name,
129 ts_type_base_named(field.shape())
130 ));
131 }
132 out.push_str("}\n\n");
133 }
134 ShapeKind::Enum(EnumInfo { variants, .. }) => {
135 out.push_str(&format!("export type {} =\n", name));
136 for (i, variant) in variants.iter().enumerate() {
137 let variant_type = match classify_variant(variant) {
138 VariantKind::Unit => format!("{{ tag: '{}' }}", variant.name),
139 VariantKind::Newtype { inner } => {
140 format!(
141 "{{ tag: '{}'; value: {} }}",
142 variant.name,
143 ts_type_base_named(inner)
144 )
145 }
146 VariantKind::Tuple { fields } | VariantKind::Struct { fields } => {
147 let field_strs = fields
148 .iter()
149 .map(|f| format!("{}: {}", f.name, ts_type_base_named(f.shape())))
150 .collect::<Vec<_>>()
151 .join("; ");
152 format!("{{ tag: '{}'; {} }}", variant.name, field_strs)
153 }
154 };
155 let sep = if i < variants.len() - 1 { "" } else { ";" };
156 out.push_str(&format!(" | {}{}\n", variant_type, sep));
157 }
158 out.push('\n');
159 }
160 _ => {}
161 }
162 }
163
164 out
165}
166
167pub fn ts_type_base_named(shape: &'static Shape) -> String {
170 match classify_shape(shape) {
171 ShapeKind::Struct(StructInfo {
173 name: Some(name), ..
174 }) => name.to_string(),
175 ShapeKind::Enum(EnumInfo {
176 name: Some(name), ..
177 }) => name.to_string(),
178
179 ShapeKind::List { element } => {
181 if is_bytes(shape) {
183 return "Uint8Array".into();
184 }
185 if matches!(
187 classify_shape(element),
188 ShapeKind::Enum(EnumInfo { name: None, .. })
189 ) {
190 format!("({})[]", ts_type_base_named(element))
191 } else {
192 format!("{}[]", ts_type_base_named(element))
193 }
194 }
195 ShapeKind::Option { inner } => format!("{} | null", ts_type_base_named(inner)),
196 ShapeKind::Array { element, len } => format!("[{}; {}]", ts_type_base_named(element), len),
197 ShapeKind::Map { key, value } => {
198 format!(
199 "Map<{}, {}>",
200 ts_type_base_named(key),
201 ts_type_base_named(value)
202 )
203 }
204 ShapeKind::Set { element } => format!("Set<{}>", ts_type_base_named(element)),
205 ShapeKind::Tuple { elements } => {
206 let inner = elements
207 .iter()
208 .map(|p| ts_type_base_named(p.shape))
209 .collect::<Vec<_>>()
210 .join(", ");
211 format!("[{inner}]")
212 }
213 ShapeKind::Tx { inner } => format!("Tx<{}>", ts_type_base_named(inner)),
214 ShapeKind::Rx { inner } => format!("Rx<{}>", ts_type_base_named(inner)),
215
216 ShapeKind::Struct(StructInfo {
218 name: None, fields, ..
219 }) => {
220 let inner = fields
221 .iter()
222 .map(|f| format!("{}: {}", f.name, ts_type_base_named(f.shape())))
223 .collect::<Vec<_>>()
224 .join("; ");
225 format!("{{ {inner} }}")
226 }
227
228 ShapeKind::Enum(EnumInfo {
230 name: None,
231 variants,
232 }) => variants
233 .iter()
234 .map(|v| match classify_variant(v) {
235 VariantKind::Unit => format!("{{ tag: '{}' }}", v.name),
236 VariantKind::Newtype { inner } => {
237 format!(
238 "{{ tag: '{}'; value: {} }}",
239 v.name,
240 ts_type_base_named(inner)
241 )
242 }
243 VariantKind::Tuple { fields } | VariantKind::Struct { fields } => {
244 let field_strs = fields
245 .iter()
246 .map(|f| format!("{}: {}", f.name, ts_type_base_named(f.shape())))
247 .collect::<Vec<_>>()
248 .join("; ");
249 format!("{{ tag: '{}'; {} }}", v.name, field_strs)
250 }
251 })
252 .collect::<Vec<_>>()
253 .join(" | "),
254
255 ShapeKind::Scalar(scalar) => ts_scalar_type(scalar),
257 ShapeKind::Slice { element } => format!("{}[]", ts_type_base_named(element)),
258 ShapeKind::Pointer { pointee } => ts_type_base_named(pointee),
259 ShapeKind::Result { ok, err } => {
260 format!(
261 "{{ ok: true; value: {} }} | {{ ok: false; error: {} }}",
262 ts_type_base_named(ok),
263 ts_type_base_named(err)
264 )
265 }
266 ShapeKind::TupleStruct { fields } => {
267 let inner = fields
268 .iter()
269 .map(|f| ts_type_base_named(f.shape()))
270 .collect::<Vec<_>>()
271 .join(", ");
272 format!("[{inner}]")
273 }
274 ShapeKind::Opaque => "unknown".into(),
275 }
276}
277
278pub fn ts_scalar_type(scalar: ScalarType) -> String {
280 match scalar {
281 ScalarType::Bool => "boolean".into(),
282 ScalarType::U8
283 | ScalarType::U16
284 | ScalarType::U32
285 | ScalarType::I8
286 | ScalarType::I16
287 | ScalarType::I32
288 | ScalarType::F32
289 | ScalarType::F64 => "number".into(),
290 ScalarType::U64
291 | ScalarType::U128
292 | ScalarType::I64
293 | ScalarType::I128
294 | ScalarType::USize
295 | ScalarType::ISize => "bigint".into(),
296 ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
297 "string".into()
298 }
299 ScalarType::Unit => "void".into(),
300 _ => "unknown".into(),
301 }
302}
303
304pub fn ts_type_client_arg(shape: &'static Shape) -> String {
308 match classify_shape(shape) {
309 ShapeKind::Tx { inner } => format!("Tx<{}>", ts_type_client_arg(inner)),
310 ShapeKind::Rx { inner } => format!("Rx<{}>", ts_type_client_arg(inner)),
311 _ => ts_type_base_named(shape),
312 }
313}
314
315pub fn ts_type_client_return(shape: &'static Shape) -> String {
318 assert_no_channels_in_return_shape(shape);
319 ts_type_base_named(shape)
320}
321
322pub fn ts_type_server_arg(shape: &'static Shape) -> String {
326 match classify_shape(shape) {
327 ShapeKind::Tx { inner } => format!("Tx<{}>", ts_type_server_arg(inner)),
328 ShapeKind::Rx { inner } => format!("Rx<{}>", ts_type_server_arg(inner)),
329 _ => ts_type_base_named(shape),
330 }
331}
332
333pub fn ts_type_server_return(shape: &'static Shape) -> String {
335 assert_no_channels_in_return_shape(shape);
336 ts_type_base_named(shape)
337}
338
339pub fn ts_type(shape: &'static Shape) -> String {
342 ts_type_base_named(shape)
343}
344
345pub fn is_fully_supported(shape: &'static Shape) -> bool {
348 match classify_shape(shape) {
349 ShapeKind::Tx { inner } | ShapeKind::Rx { inner } => is_fully_supported(inner),
351 ShapeKind::List { element }
352 | ShapeKind::Option { inner: element }
353 | ShapeKind::Set { element }
354 | ShapeKind::Array { element, .. }
355 | ShapeKind::Slice { element } => is_fully_supported(element),
356 ShapeKind::Map { key, value } => is_fully_supported(key) && is_fully_supported(value),
357 ShapeKind::Tuple { elements } => elements.iter().all(|p| is_fully_supported(p.shape)),
358 ShapeKind::TupleStruct { fields } => fields.iter().all(|f| is_fully_supported(f.shape())),
359 ShapeKind::Struct(StructInfo { fields, .. }) => {
360 fields.iter().all(|f| is_fully_supported(f.shape()))
361 }
362 ShapeKind::Enum(EnumInfo { variants, .. }) => {
363 variants.iter().all(|v| match classify_variant(v) {
364 VariantKind::Unit => true,
365 VariantKind::Newtype { inner } => is_fully_supported(inner),
366 VariantKind::Tuple { fields } | VariantKind::Struct { fields } => {
367 fields.iter().all(|f| is_fully_supported(f.shape()))
368 }
369 })
370 }
371 ShapeKind::Pointer { pointee } => is_fully_supported(pointee),
372 ShapeKind::Scalar(_) => true,
373 ShapeKind::Result { ok, err } => is_fully_supported(ok) && is_fully_supported(err),
374 ShapeKind::Opaque => false,
375 }
376}
377
378fn assert_no_channels_in_return_shape(shape: &'static Shape) {
379 assert!(
380 RpcPlan::for_shape(shape).channel_locations.is_empty(),
381 "channels are not allowed in return types"
382 );
383}