1use facet_core::{ScalarType, Shape};
6use heck::ToLowerCamelCase;
7use roam_schema::{
8 EnumInfo, ShapeKind, StructInfo, VariantKind, classify_shape, classify_variant, is_bytes,
9};
10
11pub fn generate_decode_stmt(shape: &'static Shape, var_name: &str, indent: &str) -> String {
14 generate_decode_stmt_from(shape, var_name, indent, "payload")
15}
16
17pub fn generate_decode_stmt_from(
20 shape: &'static Shape,
21 var_name: &str,
22 indent: &str,
23 data_var: &str,
24) -> String {
25 if is_bytes(shape) {
27 return format!(
28 "{indent}let {var_name} = try decodeBytes(from: {data_var}, offset: &offset)\n"
29 );
30 }
31
32 match classify_shape(shape) {
33 ShapeKind::Scalar(scalar) => {
34 let decode_fn = swift_decode_fn(scalar);
35 format!("{indent}let {var_name} = try {decode_fn}(from: {data_var}, offset: &offset)\n")
36 }
37 ShapeKind::List { element }
38 | ShapeKind::Slice { element }
39 | ShapeKind::Array { element, .. } => {
40 let inner_decode = generate_decode_closure(element);
41 format!(
42 "{indent}let {var_name} = try decodeVec(from: {data_var}, offset: &offset, decoder: {inner_decode})\n"
43 )
44 }
45 ShapeKind::Option { inner } => {
46 let inner_decode = generate_decode_closure(inner);
47 format!(
48 "{indent}let {var_name} = try decodeOption(from: {data_var}, offset: &offset, decoder: {inner_decode})\n"
49 )
50 }
51 ShapeKind::Tuple { elements } if elements.len() == 2 => {
52 let a_decode = generate_decode_closure(elements[0].shape);
53 let b_decode = generate_decode_closure(elements[1].shape);
54 format!(
55 "{indent}let {var_name} = try decodeTuple2(from: {data_var}, offset: &offset, decoderA: {a_decode}, decoderB: {b_decode})\n"
56 )
57 }
58 ShapeKind::TupleStruct { fields } if fields.len() == 2 => {
59 let a_decode = generate_decode_closure(fields[0].shape());
60 let b_decode = generate_decode_closure(fields[1].shape());
61 format!(
62 "{indent}let {var_name} = try decodeTuple2(from: {data_var}, offset: &offset, decoderA: {a_decode}, decoderB: {b_decode})\n"
63 )
64 }
65 ShapeKind::Struct(StructInfo {
66 name: Some(name),
67 fields,
68 ..
69 }) => {
70 let mut out = String::new();
72 for f in fields.iter() {
73 let field_name = f.name.to_lower_camel_case();
74 out.push_str(&generate_decode_stmt_from(
75 f.shape(),
76 &format!("_{var_name}_{field_name}"),
77 indent,
78 data_var,
79 ));
80 }
81 let field_inits: Vec<String> = fields
82 .iter()
83 .map(|f| {
84 let field_name = f.name.to_lower_camel_case();
85 format!("{field_name}: _{var_name}_{field_name}")
86 })
87 .collect();
88 out.push_str(&format!(
89 "{indent}let {var_name} = {name}({})\n",
90 field_inits.join(", ")
91 ));
92 out
93 }
94 ShapeKind::Enum(EnumInfo {
95 name: Some(name),
96 variants,
97 ..
98 }) => {
99 let mut out = String::new();
101 out.push_str(&format!(
102 "{indent}let _{var_name}_disc = try decodeU8(from: {data_var}, offset: &offset)\n"
103 ));
104 out.push_str(&format!("{indent}let {var_name}: {name}\n"));
105 out.push_str(&format!("{indent}switch _{var_name}_disc {{\n"));
106 for (i, v) in variants.iter().enumerate() {
107 out.push_str(&format!("{indent}case {i}:\n"));
108 match classify_variant(v) {
109 VariantKind::Unit => {
110 out.push_str(&format!(
111 "{indent} {var_name} = .{}\n",
112 v.name.to_lower_camel_case()
113 ));
114 }
115 VariantKind::Newtype { inner } => {
116 out.push_str(&generate_decode_stmt_from(
117 inner,
118 &format!("_{var_name}_val"),
119 &format!("{indent} "),
120 data_var,
121 ));
122 out.push_str(&format!(
123 "{indent} {var_name} = .{}(_{var_name}_val)\n",
124 v.name.to_lower_camel_case()
125 ));
126 }
127 VariantKind::Tuple { fields } => {
128 for (j, f) in fields.iter().enumerate() {
129 out.push_str(&generate_decode_stmt_from(
130 f.shape(),
131 &format!("_{var_name}_f{j}"),
132 &format!("{indent} "),
133 data_var,
134 ));
135 }
136 let args: Vec<String> = (0..fields.len())
137 .map(|j| format!("_{var_name}_f{j}"))
138 .collect();
139 out.push_str(&format!(
140 "{indent} {var_name} = .{}({})\n",
141 v.name.to_lower_camel_case(),
142 args.join(", ")
143 ));
144 }
145 VariantKind::Struct { fields } => {
146 for f in fields.iter() {
147 let field_name = f.name.to_lower_camel_case();
148 out.push_str(&generate_decode_stmt_from(
149 f.shape(),
150 &format!("_{var_name}_{field_name}"),
151 &format!("{indent} "),
152 data_var,
153 ));
154 }
155 let args: Vec<String> = fields
156 .iter()
157 .map(|f| {
158 let field_name = f.name.to_lower_camel_case();
159 format!("{field_name}: _{var_name}_{field_name}")
160 })
161 .collect();
162 out.push_str(&format!(
163 "{indent} {var_name} = .{}({})\n",
164 v.name.to_lower_camel_case(),
165 args.join(", ")
166 ));
167 }
168 }
169 }
170 out.push_str(&format!("{indent}default:\n"));
171 out.push_str(&format!(
172 "{indent} throw RoamError.decodeError(\"unknown enum variant\")\n"
173 ));
174 out.push_str(&format!("{indent}}}\n"));
175 out
176 }
177 ShapeKind::Pointer { pointee } => generate_decode_stmt(pointee, var_name, indent),
178 ShapeKind::Result { ok, err } => {
179 let ok_type = super::types::swift_type_base(ok);
181 let err_type = super::types::swift_type_base(err);
182 let mut out = String::new();
183 out.push_str(&format!(
184 "{indent}let _{var_name}_disc = try decodeU8(from: {data_var}, offset: &offset)\n"
185 ));
186 out.push_str(&format!(
187 "{indent}let {var_name}: Result<{ok_type}, {err_type}>\n"
188 ));
189 out.push_str(&format!("{indent}switch _{var_name}_disc {{\n"));
190 out.push_str(&format!("{indent}case 0:\n"));
191 out.push_str(&generate_decode_stmt_from(
192 ok,
193 &format!("_{var_name}_ok"),
194 &format!("{indent} "),
195 data_var,
196 ));
197 out.push_str(&format!(
198 "{indent} {var_name} = .success(_{var_name}_ok)\n"
199 ));
200 out.push_str(&format!("{indent}case 1:\n"));
201 out.push_str(&generate_decode_stmt_from(
202 err,
203 &format!("_{var_name}_err"),
204 &format!("{indent} "),
205 data_var,
206 ));
207 out.push_str(&format!(
208 "{indent} {var_name} = .failure(_{var_name}_err)\n"
209 ));
210 out.push_str(&format!("{indent}default:\n"));
211 out.push_str(&format!(
212 "{indent} throw RoamError.decodeError(\"invalid Result discriminant\")\n"
213 ));
214 out.push_str(&format!("{indent}}}\n"));
215 out
216 }
217 _ => {
218 format!("{indent}let {var_name}: Any = () // unsupported type\n")
220 }
221 }
222}
223
224pub fn generate_decode_closure(shape: &'static Shape) -> String {
226 if is_bytes(shape) {
227 return "{ data, off in try decodeBytes(from: data, offset: &off) }".into();
228 }
229
230 match classify_shape(shape) {
231 ShapeKind::Scalar(scalar) => {
232 let decode_fn = swift_decode_fn(scalar);
233 format!("{{ data, off in try {decode_fn}(from: data, offset: &off) }}")
234 }
235 ShapeKind::List { element } | ShapeKind::Slice { element } => {
236 let inner = generate_decode_closure(element);
237 format!("{{ data, off in try decodeVec(from: data, offset: &off, decoder: {inner}) }}")
238 }
239 ShapeKind::Option { inner } => {
240 let inner_closure = generate_decode_closure(inner);
241 format!(
242 "{{ data, off in try decodeOption(from: data, offset: &off, decoder: {inner_closure}) }}"
243 )
244 }
245 ShapeKind::Tuple { elements } if elements.len() == 2 => {
246 let a_decode = generate_decode_closure(elements[0].shape);
247 let b_decode = generate_decode_closure(elements[1].shape);
248 format!(
249 "{{ data, off in try decodeTuple2(from: data, offset: &off, decoderA: {a_decode}, decoderB: {b_decode}) }}"
250 )
251 }
252 ShapeKind::TupleStruct { fields } if fields.len() == 2 => {
253 let a_decode = generate_decode_closure(fields[0].shape());
254 let b_decode = generate_decode_closure(fields[1].shape());
255 format!(
256 "{{ data, off in try decodeTuple2(from: data, offset: &off, decoderA: {a_decode}, decoderB: {b_decode}) }}"
257 )
258 }
259 ShapeKind::Struct(StructInfo {
260 name: Some(name),
261 fields,
262 ..
263 }) => {
264 let mut code = "{ data, off in\n".to_string();
266 for f in fields.iter() {
267 let field_name = f.name.to_lower_camel_case();
268 let decode_call = generate_inline_decode(f.shape(), "data", "off");
269 code.push_str(&format!(" let _{field_name} = try {decode_call}\n"));
270 }
271 let field_inits: Vec<String> = fields
272 .iter()
273 .map(|f| {
274 let field_name = f.name.to_lower_camel_case();
275 format!("{field_name}: _{field_name}")
276 })
277 .collect();
278 code.push_str(&format!(
279 " return {name}({})\n}}",
280 field_inits.join(", ")
281 ));
282 code
283 }
284 ShapeKind::Enum(EnumInfo {
285 name: Some(name),
286 variants,
287 ..
288 }) => {
289 let mut code = format!(
291 "{{ data, off in\n let disc = try decodeU8(from: data, offset: &off)\n let result: {name}\n switch disc {{\n"
292 );
293 for (i, v) in variants.iter().enumerate() {
294 code.push_str(&format!(" case {i}:\n"));
295 match classify_variant(v) {
296 VariantKind::Unit => {
297 code.push_str(&format!(
298 " result = .{}\n",
299 v.name.to_lower_camel_case()
300 ));
301 }
302 VariantKind::Newtype { inner } => {
303 let inner_decode = generate_inline_decode(inner, "data", "off");
304 code.push_str(&format!(
305 " let val = try {inner_decode}\n result = .{}(val)\n",
306 v.name.to_lower_camel_case()
307 ));
308 }
309 VariantKind::Tuple { fields } => {
310 for (j, f) in fields.iter().enumerate() {
311 let inner_decode = generate_inline_decode(f.shape(), "data", "off");
312 code.push_str(&format!(" let f{j} = try {inner_decode}\n"));
313 }
314 let args: Vec<String> =
315 (0..fields.len()).map(|j| format!("f{j}")).collect();
316 code.push_str(&format!(
317 " result = .{}({})\n",
318 v.name.to_lower_camel_case(),
319 args.join(", ")
320 ));
321 }
322 VariantKind::Struct { fields } => {
323 for f in fields.iter() {
324 let field_name = f.name.to_lower_camel_case();
325 let inner_decode = generate_inline_decode(f.shape(), "data", "off");
326 code.push_str(&format!(
327 " let _{field_name} = try {inner_decode}\n"
328 ));
329 }
330 let args: Vec<String> = fields
331 .iter()
332 .map(|f| {
333 let field_name = f.name.to_lower_camel_case();
334 format!("{field_name}: _{field_name}")
335 })
336 .collect();
337 code.push_str(&format!(
338 " result = .{}({})\n",
339 v.name.to_lower_camel_case(),
340 args.join(", ")
341 ));
342 }
343 }
344 }
345 code.push_str(" default:\n throw RoamError.decodeError(\"unknown enum variant\")\n }\n return result\n}");
346 code
347 }
348 ShapeKind::Pointer { pointee } => generate_decode_closure(pointee),
349 _ => "{ _, _ in throw RoamError.decodeError(\"unsupported type\") }".into(),
350 }
351}
352
353pub fn generate_inline_decode(shape: &'static Shape, data_var: &str, offset_var: &str) -> String {
355 if is_bytes(shape) {
356 return format!("decodeBytes(from: {data_var}, offset: &{offset_var})");
357 }
358
359 match classify_shape(shape) {
360 ShapeKind::Scalar(scalar) => {
361 let decode_fn = swift_decode_fn(scalar);
362 format!("{decode_fn}(from: {data_var}, offset: &{offset_var})")
363 }
364 ShapeKind::List { element } | ShapeKind::Slice { element } => {
365 let inner = generate_decode_closure(element);
366 format!("decodeVec(from: {data_var}, offset: &{offset_var}, decoder: {inner})")
367 }
368 ShapeKind::Option { inner } => {
369 let inner_closure = generate_decode_closure(inner);
370 format!(
371 "decodeOption(from: {data_var}, offset: &{offset_var}, decoder: {inner_closure})"
372 )
373 }
374 ShapeKind::Pointer { pointee } => generate_inline_decode(pointee, data_var, offset_var),
375 _ => "{ throw RoamError.decodeError(\"unsupported\") }()".to_string(),
376 }
377}
378
379pub fn swift_decode_fn(scalar: ScalarType) -> &'static str {
381 match scalar {
382 ScalarType::Bool => "decodeBool",
383 ScalarType::U8 => "decodeU8",
384 ScalarType::I8 => "decodeI8",
385 ScalarType::U16 => "decodeU16",
386 ScalarType::I16 => "decodeI16",
387 ScalarType::U32 => "decodeU32",
388 ScalarType::I32 => "decodeI32",
389 ScalarType::U64 | ScalarType::USize => "decodeVarint",
390 ScalarType::I64 | ScalarType::ISize => "decodeI64",
391 ScalarType::F32 => "decodeF32",
392 ScalarType::F64 => "decodeF64",
393 ScalarType::Char | ScalarType::Str | ScalarType::CowStr | ScalarType::String => {
394 "decodeString"
395 }
396 ScalarType::Unit => "{ _, _ in () }",
397 _ => "decodeBytes", }
399}
400
#[cfg(test)]
mod tests {
    use super::*;
    use facet::Facet;

    /// Scalars produce a `let` binding that calls the matching decoder.
    #[test]
    fn test_decode_primitives() {
        let bool_stmt = generate_decode_stmt(<bool as Facet>::SHAPE, "x", " ");
        for needle in ["decodeBool", "let x"] {
            assert!(bool_stmt.contains(needle));
        }

        let string_stmt = generate_decode_stmt(<String as Facet>::SHAPE, "msg", " ");
        for needle in ["decodeString", "let msg"] {
            assert!(string_stmt.contains(needle));
        }
    }

    /// A vector decodes through `decodeVec` with an element decoder.
    #[test]
    fn test_decode_vec() {
        let stmt = generate_decode_stmt(<Vec<i32> as Facet>::SHAPE, "items", " ");
        for needle in ["decodeVec", "decodeI32"] {
            assert!(stmt.contains(needle));
        }
    }

    /// An option decodes through `decodeOption` with an inner decoder.
    #[test]
    fn test_decode_option() {
        let stmt = generate_decode_stmt(<Option<String> as Facet>::SHAPE, "val", " ");
        for needle in ["decodeOption", "decodeString"] {
            assert!(stmt.contains(needle));
        }
    }

    /// `Vec<u8>` uses the dedicated bytes decoder.
    #[test]
    fn test_decode_bytes() {
        let stmt = generate_decode_stmt(<Vec<u8> as Facet>::SHAPE, "data", " ");
        assert!(stmt.contains("decodeBytes"));
    }

    /// Inline decoding emits a bare call using the given variable names.
    #[test]
    fn test_inline_decode() {
        let call = generate_inline_decode(<u32 as Facet>::SHAPE, "buf", "pos");
        assert_eq!(call, "decodeU32(from: buf, offset: &pos)");
    }
}