1use std::collections::HashSet;
9
10use facet_core::{ScalarType, Shape};
11use roam_types::{
12 EnumInfo, RpcPlan, ServiceDescriptor, ShapeKind, StructInfo, VariantKind, classify_shape,
13 classify_variant, is_bytes,
14};
15
16fn transparent_named_alias(shape: &'static Shape) -> Option<(&'static str, &'static Shape)> {
17 if !shape.is_transparent() {
18 return None;
19 }
20 let name = extract_type_name(shape.type_identifier)?;
21 let inner = shape.inner?;
22 Some((name, inner))
23}
24
/// Returns `type_identifier` if it can be used as a TypeScript type name.
///
/// Empty identifiers and identifiers starting with `(` or `[` (tuple, array and
/// slice spellings) are rejected with `None`.
fn extract_type_name(type_identifier: &'static str) -> Option<&'static str> {
    let unnameable = matches!(type_identifier.chars().next(), None | Some('(' | '['));
    (!unnameable).then_some(type_identifier)
}
34
/// Builds a TypeScript expression that reads `field_name` from `expr`.
///
/// Positional (tuple) fields, whose names start with an ASCII digit, use index
/// syntax (`expr[0]`); every other name uses dot syntax (`expr.name`).
pub fn ts_field_access(expr: &str, field_name: &str) -> String {
    let is_positional = field_name.starts_with(|c: char| c.is_ascii_digit());
    match is_positional {
        true => format!("{expr}[{field_name}]"),
        false => format!("{expr}.{field_name}"),
    }
}
48
49pub fn collect_named_types(service: &ServiceDescriptor) -> Vec<(String, &'static Shape)> {
52 let mut seen = HashSet::new();
53 let mut types = Vec::new();
54
55 fn visit(
56 shape: &'static Shape,
57 seen: &mut HashSet<String>,
58 types: &mut Vec<(String, &'static Shape)>,
59 ) {
60 if let Some((name, inner)) = transparent_named_alias(shape) {
61 if !seen.contains(name) {
62 seen.insert(name.to_string());
63 visit(inner, seen, types);
64 types.push((name.to_string(), shape));
65 }
66 return;
67 }
68
69 match classify_shape(shape) {
70 ShapeKind::Struct(StructInfo {
71 name: Some(name),
72 fields,
73 ..
74 }) => {
75 if !seen.contains(name) {
76 seen.insert(name.to_string());
77 for field in fields {
79 visit(field.shape(), seen, types);
80 }
81 types.push((name.to_string(), shape));
82 }
83 }
84 ShapeKind::Enum(EnumInfo {
85 name: Some(name),
86 variants,
87 }) => {
88 if !seen.contains(name) {
89 seen.insert(name.to_string());
90 for variant in variants {
92 match classify_variant(variant) {
93 VariantKind::Newtype { inner } => visit(inner, seen, types),
94 VariantKind::Struct { fields } | VariantKind::Tuple { fields } => {
95 for field in fields {
96 visit(field.shape(), seen, types);
97 }
98 }
99 VariantKind::Unit => {}
100 }
101 }
102 types.push((name.to_string(), shape));
103 }
104 }
105 ShapeKind::List { element } => visit(element, seen, types),
106 ShapeKind::Option { inner } => visit(inner, seen, types),
107 ShapeKind::Array { element, .. } => visit(element, seen, types),
108 ShapeKind::Map { key, value } => {
109 visit(key, seen, types);
110 visit(value, seen, types);
111 }
112 ShapeKind::Set { element } => visit(element, seen, types),
113 ShapeKind::Tuple { elements } => {
114 for param in elements {
115 visit(param.shape, seen, types);
116 }
117 }
118 ShapeKind::Tx { inner } | ShapeKind::Rx { inner } => visit(inner, seen, types),
119 ShapeKind::Pointer { pointee } => visit(pointee, seen, types),
120 ShapeKind::Result { ok, err } => {
121 visit(ok, seen, types);
122 visit(err, seen, types);
123 }
124 _ => {}
126 }
127 }
128
129 for method in service.methods {
130 for arg in method.args {
131 visit(arg.shape, &mut seen, &mut types);
132 }
133 visit(method.return_shape, &mut seen, &mut types);
134 }
135
136 types
137}
138
139pub fn generate_named_types(named_types: &[(String, &'static Shape)]) -> String {
141 let mut out = String::new();
142
143 if named_types.is_empty() {
144 return out;
145 }
146
147 out.push_str("// Named type definitions\n");
148
149 for (name, shape) in named_types {
150 if let Some((_, inner)) = transparent_named_alias(shape) {
151 out.push_str(&format!(
152 "export type {name} = {};\n\n",
153 ts_type_base_named(inner)
154 ));
155 continue;
156 }
157
158 match classify_shape(shape) {
159 ShapeKind::Struct(StructInfo { fields, .. }) => {
160 out.push_str(&format!("export interface {} {{\n", name));
161 for field in fields {
162 out.push_str(&format!(
163 " {}: {};\n",
164 field.name,
165 ts_type_base_named(field.shape())
166 ));
167 }
168 out.push_str("}\n\n");
169 }
170 ShapeKind::Enum(EnumInfo { variants, .. }) => {
171 out.push_str(&format!("export type {} =\n", name));
172 for (i, variant) in variants.iter().enumerate() {
173 let variant_type = match classify_variant(variant) {
174 VariantKind::Unit => format!("{{ tag: '{}' }}", variant.name),
175 VariantKind::Newtype { inner } => {
176 format!(
177 "{{ tag: '{}'; value: {} }}",
178 variant.name,
179 ts_type_base_named(inner)
180 )
181 }
182 VariantKind::Tuple { fields } | VariantKind::Struct { fields } => {
183 let field_strs = fields
184 .iter()
185 .map(|f| format!("{}: {}", f.name, ts_type_base_named(f.shape())))
186 .collect::<Vec<_>>()
187 .join("; ");
188 format!("{{ tag: '{}'; {} }}", variant.name, field_strs)
189 }
190 };
191 let sep = if i < variants.len() - 1 { "" } else { ";" };
192 out.push_str(&format!(" | {}{}\n", variant_type, sep));
193 }
194 out.push('\n');
195 }
196 _ => {}
197 }
198 }
199
200 out
201}
202
203pub fn ts_type_base_named(shape: &'static Shape) -> String {
206 if let Some((name, _)) = transparent_named_alias(shape) {
207 return name.to_string();
208 }
209
210 match classify_shape(shape) {
211 ShapeKind::Struct(StructInfo {
213 name: Some(name), ..
214 }) => name.to_string(),
215 ShapeKind::Enum(EnumInfo {
216 name: Some(name), ..
217 }) => name.to_string(),
218
219 ShapeKind::List { element } => {
221 if is_bytes(shape) {
223 return "Uint8Array".into();
224 }
225 if matches!(
227 classify_shape(element),
228 ShapeKind::Enum(EnumInfo { name: None, .. })
229 ) {
230 format!("({})[]", ts_type_base_named(element))
231 } else {
232 format!("{}[]", ts_type_base_named(element))
233 }
234 }
235 ShapeKind::Option { inner } => format!("{} | null", ts_type_base_named(inner)),
236 ShapeKind::Array { element, len } => format!("[{}; {}]", ts_type_base_named(element), len),
237 ShapeKind::Map { key, value } => {
238 format!(
239 "Map<{}, {}>",
240 ts_type_base_named(key),
241 ts_type_base_named(value)
242 )
243 }
244 ShapeKind::Set { element } => format!("Set<{}>", ts_type_base_named(element)),
245 ShapeKind::Tuple { elements } => {
246 let inner = elements
247 .iter()
248 .map(|p| ts_type_base_named(p.shape))
249 .collect::<Vec<_>>()
250 .join(", ");
251 format!("[{inner}]")
252 }
253 ShapeKind::Tx { inner } => format!("Tx<{}>", ts_type_base_named(inner)),
254 ShapeKind::Rx { inner } => format!("Rx<{}>", ts_type_base_named(inner)),
255
256 ShapeKind::Struct(StructInfo {
258 name: None, fields, ..
259 }) => {
260 let inner = fields
261 .iter()
262 .map(|f| format!("{}: {}", f.name, ts_type_base_named(f.shape())))
263 .collect::<Vec<_>>()
264 .join("; ");
265 format!("{{ {inner} }}")
266 }
267
268 ShapeKind::Enum(EnumInfo {
270 name: None,
271 variants,
272 }) => variants
273 .iter()
274 .map(|v| match classify_variant(v) {
275 VariantKind::Unit => format!("{{ tag: '{}' }}", v.name),
276 VariantKind::Newtype { inner } => {
277 format!(
278 "{{ tag: '{}'; value: {} }}",
279 v.name,
280 ts_type_base_named(inner)
281 )
282 }
283 VariantKind::Tuple { fields } | VariantKind::Struct { fields } => {
284 let field_strs = fields
285 .iter()
286 .map(|f| format!("{}: {}", f.name, ts_type_base_named(f.shape())))
287 .collect::<Vec<_>>()
288 .join("; ");
289 format!("{{ tag: '{}'; {} }}", v.name, field_strs)
290 }
291 })
292 .collect::<Vec<_>>()
293 .join(" | "),
294
295 ShapeKind::Scalar(scalar) => ts_scalar_type(scalar),
297 ShapeKind::Slice { element } => format!("{}[]", ts_type_base_named(element)),
298 ShapeKind::Pointer { pointee } => ts_type_base_named(pointee),
299 ShapeKind::Result { ok, err } => {
300 format!(
301 "{{ ok: true; value: {} }} | {{ ok: false; error: {} }}",
302 ts_type_base_named(ok),
303 ts_type_base_named(err)
304 )
305 }
306 ShapeKind::TupleStruct { fields } => {
307 let inner = fields
308 .iter()
309 .map(|f| ts_type_base_named(f.shape()))
310 .collect::<Vec<_>>()
311 .join(", ");
312 format!("[{inner}]")
313 }
314 ShapeKind::Opaque => "unknown".into(),
315 }
316}
317
318pub fn ts_scalar_type(scalar: ScalarType) -> String {
320 match scalar {
321 ScalarType::Bool => "boolean".into(),
322 ScalarType::U8
323 | ScalarType::U16
324 | ScalarType::U32
325 | ScalarType::I8
326 | ScalarType::I16
327 | ScalarType::I32
328 | ScalarType::F32
329 | ScalarType::F64 => "number".into(),
330 ScalarType::U64
331 | ScalarType::U128
332 | ScalarType::I64
333 | ScalarType::I128
334 | ScalarType::USize
335 | ScalarType::ISize => "bigint".into(),
336 ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
337 "string".into()
338 }
339 ScalarType::Unit => "void".into(),
340 _ => "unknown".into(),
341 }
342}
343
344pub fn ts_type_client_arg(shape: &'static Shape) -> String {
348 match classify_shape(shape) {
349 ShapeKind::Tx { inner } => format!("Tx<{}>", ts_type_client_arg(inner)),
350 ShapeKind::Rx { inner } => format!("Rx<{}>", ts_type_client_arg(inner)),
351 _ => ts_type_base_named(shape),
352 }
353}
354
355pub fn ts_type_client_return(shape: &'static Shape) -> String {
358 assert_no_channels_in_return_shape(shape);
359 ts_type_base_named(shape)
360}
361
362pub fn ts_type_server_arg(shape: &'static Shape) -> String {
366 match classify_shape(shape) {
367 ShapeKind::Tx { inner } => format!("Tx<{}>", ts_type_server_arg(inner)),
368 ShapeKind::Rx { inner } => format!("Rx<{}>", ts_type_server_arg(inner)),
369 _ => ts_type_base_named(shape),
370 }
371}
372
373pub fn ts_type_server_return(shape: &'static Shape) -> String {
375 assert_no_channels_in_return_shape(shape);
376 ts_type_base_named(shape)
377}
378
379pub fn ts_type(shape: &'static Shape) -> String {
382 ts_type_base_named(shape)
383}
384
385pub fn is_fully_supported(shape: &'static Shape) -> bool {
388 match classify_shape(shape) {
389 ShapeKind::Tx { inner } | ShapeKind::Rx { inner } => is_fully_supported(inner),
391 ShapeKind::List { element }
392 | ShapeKind::Option { inner: element }
393 | ShapeKind::Set { element }
394 | ShapeKind::Array { element, .. }
395 | ShapeKind::Slice { element } => is_fully_supported(element),
396 ShapeKind::Map { key, value } => is_fully_supported(key) && is_fully_supported(value),
397 ShapeKind::Tuple { elements } => elements.iter().all(|p| is_fully_supported(p.shape)),
398 ShapeKind::TupleStruct { fields } => fields.iter().all(|f| is_fully_supported(f.shape())),
399 ShapeKind::Struct(StructInfo { fields, .. }) => {
400 fields.iter().all(|f| is_fully_supported(f.shape()))
401 }
402 ShapeKind::Enum(EnumInfo { variants, .. }) => {
403 variants.iter().all(|v| match classify_variant(v) {
404 VariantKind::Unit => true,
405 VariantKind::Newtype { inner } => is_fully_supported(inner),
406 VariantKind::Tuple { fields } | VariantKind::Struct { fields } => {
407 fields.iter().all(|f| is_fully_supported(f.shape()))
408 }
409 })
410 }
411 ShapeKind::Pointer { pointee } => is_fully_supported(pointee),
412 ShapeKind::Scalar(_) => true,
413 ShapeKind::Result { ok, err } => is_fully_supported(ok) && is_fully_supported(err),
414 ShapeKind::Opaque => false,
415 }
416}
417
418fn assert_no_channels_in_return_shape(shape: &'static Shape) {
419 assert!(
420 RpcPlan::for_shape(shape).channel_locations.is_empty(),
421 "channels are not allowed in return types"
422 );
423}