1use facet_core::{Field, ScalarType, Shape};
12use heck::ToLowerCamelCase;
13use roam_types::{
14 EnumInfo, ShapeKind, StructInfo, VariantKind, classify_shape, classify_variant, is_bytes,
15};
16
17pub struct WireType {
19 pub swift_name: String,
21 pub shape: &'static Shape,
23}
24
25pub fn generate_wire_types(types: &[WireType]) -> String {
27 let mut out = String::new();
28 out.push_str("// @generated by roam-codegen\n");
29 out.push_str("// DO NOT EDIT — regenerate with `cargo xtask codegen --swift-wire`\n\n");
30 out.push_str("import Foundation\n\n");
31
32 out.push_str(&generate_preamble());
34
35 out.push_str("public typealias MetadataV7 = [MetadataEntryV7]\n\n");
37
38 for wt in types {
40 out.push_str(&generate_one_type(&wt.swift_name, wt.shape, types));
41 out.push('\n');
42 }
43
44 out.push_str(&generate_factory_methods(types));
46
47 out
48}
49
/// Swift runtime support emitted ahead of every generated type.
///
/// The preamble has three sections:
/// - the `WireV7Error` enum thrown by generated decoders;
/// - the `OpaquePayloadV7` byte wrapper, with length-prefixed `encode`/`decode`
///   and prefix-less `encodeTrailing`/`decodeTrailing` variants;
/// - three `private` top-level decode helpers. One range-checks a varint into
///   `UInt32`. The other two map `PostcardError.invalidUtf8`/`.truncated` onto
///   `WireV7Error` and rethrow everything else unchanged.
fn generate_preamble() -> String {
    const ERROR_ENUM: &str = concat!(
        "public enum WireV7Error: Error, Equatable {\n",
        " case truncated\n",
        " case unknownVariant(UInt64)\n",
        " case overflow\n",
        " case invalidUtf8\n",
        " case trailingBytes\n",
        "}\n\n",
    );

    const OPAQUE_PAYLOAD: &str = concat!(
        "public struct OpaquePayloadV7: Sendable, Equatable {\n",
        " public var bytes: [UInt8]\n\n",
        " public init(_ bytes: [UInt8]) {\n",
        " self.bytes = bytes\n",
        " }\n\n",
        " func encode() -> [UInt8] {\n",
        " encodeBytes(bytes)\n",
        " }\n\n",
        " static func decode(from data: Data, offset: inout Int) throws -> Self {\n",
        " .init(Array(try decodeBytesV7(from: data, offset: &offset)))\n",
        " }\n\n",
        " /// Encode without a length prefix — for trailing fields only.\n",
        " func encodeTrailing() -> [UInt8] {\n",
        " bytes\n",
        " }\n\n",
        " /// Decode by consuming all remaining bytes — for trailing fields only.\n",
        " static func decodeTrailing(from data: Data, offset: inout Int) -> Self {\n",
        " let start = data.startIndex + offset\n",
        " let remaining = Array(data[start...])\n",
        " offset = data.count\n",
        " return .init(remaining)\n",
        " }\n",
        "}\n\n",
    );

    const DECODE_HELPERS: &str = concat!(
        // UInt32 fields travel as varints; reject values that do not fit.
        "@inline(__always)\n",
        "private func decodeVarintU32V7(from data: Data, offset: inout Int) throws -> UInt32 {\n",
        " let value = try decodeVarint(from: data, offset: &offset)\n",
        " guard value <= UInt64(UInt32.max) else {\n",
        " throw WireV7Error.overflow\n",
        " }\n",
        " return UInt32(value)\n",
        "}\n\n",
        // String decoding, translating the postcard errors we surface.
        "@inline(__always)\n",
        "private func decodeStringV7(from data: Data, offset: inout Int) throws -> String {\n",
        " do {\n",
        " return try decodeString(from: data, offset: &offset)\n",
        " } catch PostcardError.invalidUtf8 {\n",
        " throw WireV7Error.invalidUtf8\n",
        " } catch PostcardError.truncated {\n",
        " throw WireV7Error.truncated\n",
        " } catch {\n",
        " throw error\n",
        " }\n",
        "}\n\n",
        // Byte-buffer decoding, translating truncation.
        "@inline(__always)\n",
        "private func decodeBytesV7(from data: Data, offset: inout Int) throws -> Data {\n",
        " do {\n",
        " return try decodeBytes(from: data, offset: &offset)\n",
        " } catch PostcardError.truncated {\n",
        " throw WireV7Error.truncated\n",
        " } catch {\n",
        " throw error\n",
        " }\n",
        "}\n\n",
    );

    [ERROR_ENUM, OPAQUE_PAYLOAD, DECODE_HELPERS].concat()
}
130
131fn swift_wire_type(shape: &'static Shape, _field: Option<&Field>, types: &[WireType]) -> String {
133 if is_bytes(shape) {
134 return "[UInt8]".into();
135 }
136
137 match classify_shape(shape) {
138 ShapeKind::Scalar(scalar) => swift_scalar_type(scalar),
139 ShapeKind::List { element } | ShapeKind::Slice { element } => {
140 format!("[{}]", swift_wire_type(element, None, types))
141 }
142 ShapeKind::Option { inner } => {
143 format!("{}?", swift_wire_type(inner, None, types))
144 }
145 ShapeKind::Array { element, .. } => {
146 format!("[{}]", swift_wire_type(element, None, types))
147 }
148 ShapeKind::Struct(StructInfo {
149 name: Some(name), ..
150 }) => lookup_wire_name(name, types),
151 ShapeKind::Enum(EnumInfo {
152 name: Some(name), ..
153 }) => lookup_wire_name(name, types),
154 ShapeKind::Pointer { pointee } => swift_wire_type(pointee, _field, types),
155 ShapeKind::Opaque => "OpaquePayloadV7".into(),
156 ShapeKind::TupleStruct { fields } if fields.len() == 1 => {
157 swift_wire_type(fields[0].shape(), None, types)
158 }
159 _ => "Any /* unsupported */".into(),
160 }
161}
162
163fn swift_scalar_type(scalar: ScalarType) -> String {
164 match scalar {
165 ScalarType::Bool => "Bool".into(),
166 ScalarType::U8 => "UInt8".into(),
167 ScalarType::U16 => "UInt16".into(),
168 ScalarType::U32 => "UInt32".into(),
169 ScalarType::U64 | ScalarType::USize => "UInt64".into(),
170 ScalarType::I8 => "Int8".into(),
171 ScalarType::I16 => "Int16".into(),
172 ScalarType::I32 => "Int32".into(),
173 ScalarType::I64 | ScalarType::ISize => "Int64".into(),
174 ScalarType::F32 => "Float".into(),
175 ScalarType::F64 => "Double".into(),
176 ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
177 "String".into()
178 }
179 ScalarType::Unit => "Void".into(),
180 _ => "Data".into(),
181 }
182}
183
184fn lookup_wire_name(rust_name: &str, types: &[WireType]) -> String {
187 for wt in types {
188 if wt.shape.type_identifier.ends_with(rust_name) {
192 return wt.swift_name.clone();
193 }
194 }
195 rust_name.to_string()
196}
197
198fn generate_one_type(swift_name: &str, shape: &'static Shape, types: &[WireType]) -> String {
200 match classify_shape(shape) {
201 ShapeKind::Struct(StructInfo { fields, .. }) => {
202 generate_struct(swift_name, fields, types, swift_name == "MessageV7")
203 }
204 ShapeKind::Enum(EnumInfo { variants, .. }) => generate_enum(swift_name, variants, types),
205 _ => format!("// Unsupported shape for {swift_name}\n"),
206 }
207}
208
209fn generate_struct(name: &str, fields: &[Field], types: &[WireType], is_top_level: bool) -> String {
211 let mut out = String::new();
212
213 out.push_str(&format!("public struct {name}: Sendable, Equatable {{\n"));
215 for f in fields {
216 let field_name = f.name.to_lower_camel_case();
217 let field_type = swift_wire_type(f.shape(), Some(f), types);
218 out.push_str(&format!(" public var {field_name}: {field_type}\n"));
219 }
220
221 if !fields.is_empty() {
223 out.push_str("\n public init(");
224 for (i, f) in fields.iter().enumerate() {
225 if i > 0 {
226 out.push_str(", ");
227 }
228 let field_name = f.name.to_lower_camel_case();
229 let field_type = swift_wire_type(f.shape(), Some(f), types);
230 out.push_str(&format!("{field_name}: {field_type}"));
231 }
232 out.push_str(") {\n");
233 for f in fields {
234 let field_name = f.name.to_lower_camel_case();
235 out.push_str(&format!(" self.{field_name} = {field_name}\n"));
236 }
237 out.push_str(" }\n");
238 }
239
240 let vis = if is_top_level { "public " } else { "" };
242 out.push_str(&format!("\n {vis}func encode() -> [UInt8] {{\n"));
243 if fields.is_empty() {
244 out.push_str(" []\n");
245 } else if fields.len() == 1 {
246 let f = &fields[0];
247 let expr = encode_field_expr(f, types);
248 out.push_str(&format!(" {expr}\n"));
249 } else {
250 out.push_str(" var out: [UInt8] = []\n");
251 for f in fields {
252 let expr = encode_field_expr(f, types);
253 out.push_str(&format!(" out += {expr}\n"));
254 }
255 out.push_str(" return out\n");
256 }
257 out.push_str(" }\n");
258
259 out.push_str(&format!(
261 "\n {vis}static func decode(from data: Data, offset: inout Int) throws -> Self {{\n"
262 ));
263 for f in fields {
264 let field_name = f.name.to_lower_camel_case();
265 let decode_expr = decode_field_expr(f, types);
266 out.push_str(&format!(" let {field_name} = {decode_expr}\n"));
267 }
268 let field_names: Vec<String> = fields
269 .iter()
270 .map(|f| {
271 let n = f.name.to_lower_camel_case();
272 format!("{n}: {n}")
273 })
274 .collect();
275 out.push_str(&format!(
276 " return .init({})\n",
277 field_names.join(", ")
278 ));
279 out.push_str(" }\n");
280
281 if is_top_level {
283 out.push_str("\n public static func decode(from data: Data) throws -> Self {\n");
284 out.push_str(" var offset = 0\n");
285 out.push_str(" let result = try decode(from: data, offset: &offset)\n");
286 out.push_str(" guard offset == data.count else {\n");
287 out.push_str(" throw WireV7Error.trailingBytes\n");
288 out.push_str(" }\n");
289 out.push_str(" return result\n");
290 out.push_str(" }\n");
291 }
292
293 out.push_str("}\n");
294 out
295}
296
297fn generate_enum(name: &str, variants: &[facet_core::Variant], types: &[WireType]) -> String {
299 let mut out = String::new();
300
301 out.push_str(&format!("public enum {name}: Sendable, Equatable {{\n"));
303 for v in variants {
304 let variant_name = v.name.to_lower_camel_case();
305 match classify_variant(v) {
306 VariantKind::Unit => {
307 out.push_str(&format!(" case {variant_name}\n"));
308 }
309 VariantKind::Newtype { inner } => {
310 let inner_type = swift_wire_type(inner, v.data.fields.first(), types);
311 out.push_str(&format!(" case {variant_name}({inner_type})\n"));
312 }
313 VariantKind::Tuple { fields } => {
314 let field_types: Vec<String> = fields
315 .iter()
316 .map(|f| swift_wire_type(f.shape(), Some(f), types))
317 .collect();
318 out.push_str(&format!(
319 " case {variant_name}({})\n",
320 field_types.join(", ")
321 ));
322 }
323 VariantKind::Struct { fields } => {
324 let field_decls: Vec<String> = fields
325 .iter()
326 .map(|f| {
327 format!(
328 "{}: {}",
329 f.name.to_lower_camel_case(),
330 swift_wire_type(f.shape(), Some(f), types)
331 )
332 })
333 .collect();
334 out.push_str(&format!(
335 " case {variant_name}({})\n",
336 field_decls.join(", ")
337 ));
338 }
339 }
340 }
341
342 out.push_str("\n func encode() -> [UInt8] {\n");
344 out.push_str(" switch self {\n");
345 for (i, v) in variants.iter().enumerate() {
346 let variant_name = v.name.to_lower_camel_case();
347 match classify_variant(v) {
348 VariantKind::Unit => {
349 out.push_str(&format!(
350 " case .{variant_name}:\n return encodeVarint(UInt64({i}))\n"
351 ));
352 }
353 VariantKind::Newtype { inner } => {
354 let encode = encode_shape_expr(inner, "val", v.data.fields.first(), types);
355 out.push_str(&format!(
356 " case .{variant_name}(let val):\n return encodeVarint(UInt64({i})) + {encode}\n"
357 ));
358 }
359 VariantKind::Tuple { fields } => {
360 let bindings: Vec<String> = (0..fields.len()).map(|j| format!("f{j}")).collect();
361 let binding_str = bindings
362 .iter()
363 .map(|b| format!("let {b}"))
364 .collect::<Vec<_>>()
365 .join(", ");
366 let encodes: Vec<String> = fields
367 .iter()
368 .enumerate()
369 .map(|(j, f)| encode_shape_expr(f.shape(), &format!("f{j}"), Some(f), types))
370 .collect();
371 out.push_str(&format!(
372 " case .{variant_name}({binding_str}):\n return encodeVarint(UInt64({i})) + {}\n",
373 encodes.join(" + ")
374 ));
375 }
376 VariantKind::Struct { fields } => {
377 let bindings: Vec<String> = fields
378 .iter()
379 .map(|f| f.name.to_lower_camel_case())
380 .collect();
381 let binding_str = bindings
382 .iter()
383 .map(|b| format!("let {b}"))
384 .collect::<Vec<_>>()
385 .join(", ");
386 let encodes: Vec<String> = fields
387 .iter()
388 .map(|f| {
389 encode_shape_expr(f.shape(), &f.name.to_lower_camel_case(), Some(f), types)
390 })
391 .collect();
392 out.push_str(&format!(
393 " case .{variant_name}({binding_str}):\n return encodeVarint(UInt64({i})) + {}\n",
394 encodes.join(" + ")
395 ));
396 }
397 }
398 }
399 out.push_str(" }\n");
400 out.push_str(" }\n");
401
402 out.push_str("\n static func decode(from data: Data, offset: inout Int) throws -> Self {\n");
404 out.push_str(" let disc = try decodeVarint(from: data, offset: &offset)\n");
405 out.push_str(" switch disc {\n");
406 for (i, v) in variants.iter().enumerate() {
407 let variant_name = v.name.to_lower_camel_case();
408 out.push_str(&format!(" case {i}:\n"));
409 match classify_variant(v) {
410 VariantKind::Unit => {
411 out.push_str(&format!(" return .{variant_name}\n"));
412 }
413 VariantKind::Newtype { inner } => {
414 let decode = decode_shape_expr(inner, v.data.fields.first(), types);
415 out.push_str(&format!(" return .{variant_name}({decode})\n"));
416 }
417 VariantKind::Tuple { fields } => {
418 for (j, f) in fields.iter().enumerate() {
419 let decode = decode_shape_expr(f.shape(), Some(f), types);
420 out.push_str(&format!(" let f{j} = {decode}\n"));
421 }
422 let args: Vec<String> = (0..fields.len()).map(|j| format!("f{j}")).collect();
423 out.push_str(&format!(
424 " return .{variant_name}({})\n",
425 args.join(", ")
426 ));
427 }
428 VariantKind::Struct { fields } => {
429 for f in fields {
430 let field_name = f.name.to_lower_camel_case();
431 let decode = decode_shape_expr(f.shape(), Some(f), types);
432 out.push_str(&format!(" let {field_name} = {decode}\n"));
433 }
434 let args: Vec<String> = fields
435 .iter()
436 .map(|f| {
437 let n = f.name.to_lower_camel_case();
438 format!("{n}: {n}")
439 })
440 .collect();
441 out.push_str(&format!(
442 " return .{variant_name}({})\n",
443 args.join(", ")
444 ));
445 }
446 }
447 }
448 out.push_str(" default:\n");
449 out.push_str(" throw WireV7Error.unknownVariant(disc)\n");
450 out.push_str(" }\n");
451 out.push_str(" }\n");
452
453 out.push_str("}\n");
454 out
455}
456
457fn encode_field_expr(field: &Field, types: &[WireType]) -> String {
459 let field_name = field.name.to_lower_camel_case();
460 encode_shape_expr(field.shape(), &field_name, Some(field), types)
461}
462
463fn encode_shape_expr(
465 shape: &'static Shape,
466 value: &str,
467 field: Option<&Field>,
468 types: &[WireType],
469) -> String {
470 let is_trailing = field.is_some_and(|f| f.has_builtin_attr("trailing"));
471
472 if matches!(classify_shape(shape), ShapeKind::Opaque) {
474 return if is_trailing {
475 format!("{value}.encodeTrailing()")
476 } else {
477 format!("{value}.encode()")
478 };
479 }
480
481 if is_bytes(shape) {
482 return format!("encodeBytes({value})");
483 }
484
485 match classify_shape(shape) {
486 ShapeKind::Scalar(scalar) => encode_scalar(scalar, value),
487 ShapeKind::List { element } | ShapeKind::Slice { element } => {
488 let inner = encode_element_closure(element, types);
489 format!("encodeVec({value}, encoder: {inner})")
490 }
491 ShapeKind::Option { inner } => {
492 let inner_closure = encode_element_closure(inner, types);
493 format!("encodeOption({value}, encoder: {inner_closure})")
494 }
495 ShapeKind::Struct(StructInfo { .. }) | ShapeKind::Enum(EnumInfo { .. }) => {
496 format!("{value}.encode()")
497 }
498 ShapeKind::Pointer { pointee } => encode_shape_expr(pointee, value, field, types),
499 ShapeKind::TupleStruct { fields } if fields.len() == 1 => {
500 encode_shape_expr(fields[0].shape(), value, field, types)
501 }
502 _ => format!("[] /* unsupported encode for {value} */"),
503 }
504}
505
506fn encode_scalar(scalar: ScalarType, value: &str) -> String {
507 match scalar {
508 ScalarType::Bool => format!("encodeBool({value})"),
509 ScalarType::U8 => format!("encodeU8({value})"),
510 ScalarType::I8 => format!("encodeI8({value})"),
511 ScalarType::U16 => format!("encodeU16({value})"),
512 ScalarType::I16 => format!("encodeI16({value})"),
513 ScalarType::U32 => format!("encodeVarint(UInt64({value}))"),
514 ScalarType::I32 => format!("encodeI32({value})"),
515 ScalarType::U64 | ScalarType::USize => format!("encodeVarint({value})"),
516 ScalarType::I64 | ScalarType::ISize => format!("encodeI64({value})"),
517 ScalarType::F32 => format!("encodeF32({value})"),
518 ScalarType::F64 => format!("encodeF64({value})"),
519 ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
520 format!("encodeString({value})")
521 }
522 _ => "[] /* unsupported scalar */".to_string(),
523 }
524}
525
526fn encode_element_closure(shape: &'static Shape, _types: &[WireType]) -> String {
528 if is_bytes(shape) {
529 return "{ encodeBytes($0) }".into();
530 }
531
532 match classify_shape(shape) {
533 ShapeKind::Scalar(scalar) => {
534 let expr = encode_scalar(scalar, "$0");
535 format!("{{ {expr} }}")
536 }
537 ShapeKind::Struct(StructInfo { .. }) | ShapeKind::Enum(EnumInfo { .. }) => {
538 "{ $0.encode() }".into()
539 }
540 ShapeKind::List { element } | ShapeKind::Slice { element } => {
541 let inner = encode_element_closure(element, _types);
542 format!("{{ encodeVec($0, encoder: {inner}) }}")
543 }
544 ShapeKind::Opaque => "{ $0.encode() }".into(),
545 ShapeKind::Pointer { pointee } => encode_element_closure(pointee, _types),
546 _ => "{ _ in [] }".into(),
547 }
548}
549
550fn decode_field_expr(field: &Field, types: &[WireType]) -> String {
552 decode_shape_expr(field.shape(), Some(field), types)
553}
554
555fn decode_shape_expr(shape: &'static Shape, field: Option<&Field>, types: &[WireType]) -> String {
557 let is_trailing = field.is_some_and(|f| f.has_builtin_attr("trailing"));
558
559 if matches!(classify_shape(shape), ShapeKind::Opaque) {
561 return if is_trailing {
562 "OpaquePayloadV7.decodeTrailing(from: data, offset: &offset)".into()
563 } else {
564 "try OpaquePayloadV7.decode(from: data, offset: &offset)".into()
565 };
566 }
567
568 if is_bytes(shape) {
569 return "Array(try decodeBytesV7(from: data, offset: &offset))".into();
570 }
571
572 match classify_shape(shape) {
573 ShapeKind::Scalar(scalar) => decode_scalar(scalar),
574 ShapeKind::List { element } | ShapeKind::Slice { element } => {
575 let inner = decode_element_closure(element, types);
576 format!("try decodeVec(from: data, offset: &offset, decoder: {inner})")
577 }
578 ShapeKind::Option { inner } => {
579 let inner_closure = decode_element_closure(inner, types);
580 format!("try decodeOption(from: data, offset: &offset, decoder: {inner_closure})")
581 }
582 ShapeKind::Struct(StructInfo {
583 name: Some(name), ..
584 }) => {
585 let swift_name = lookup_wire_name(name, types);
586 format!("try {swift_name}.decode(from: data, offset: &offset)")
587 }
588 ShapeKind::Enum(EnumInfo {
589 name: Some(name), ..
590 }) => {
591 let swift_name = lookup_wire_name(name, types);
592 format!("try {swift_name}.decode(from: data, offset: &offset)")
593 }
594 ShapeKind::Pointer { pointee } => decode_shape_expr(pointee, field, types),
595 ShapeKind::TupleStruct { fields } if fields.len() == 1 => {
596 decode_shape_expr(fields[0].shape(), field, types)
597 }
598 _ => "nil /* unsupported decode */".into(),
599 }
600}
601
602fn decode_scalar(scalar: ScalarType) -> String {
603 match scalar {
604 ScalarType::Bool => "try decodeBool(from: data, offset: &offset)".into(),
605 ScalarType::U8 => "try decodeU8(from: data, offset: &offset)".into(),
606 ScalarType::I8 => "try decodeI8(from: data, offset: &offset)".into(),
607 ScalarType::U16 => "try decodeU16(from: data, offset: &offset)".into(),
608 ScalarType::I16 => "try decodeI16(from: data, offset: &offset)".into(),
609 ScalarType::U32 => "try decodeVarintU32V7(from: data, offset: &offset)".into(),
610 ScalarType::I32 => "try decodeI32(from: data, offset: &offset)".into(),
611 ScalarType::U64 | ScalarType::USize => {
612 "try decodeVarint(from: data, offset: &offset)".into()
613 }
614 ScalarType::I64 | ScalarType::ISize => "try decodeI64(from: data, offset: &offset)".into(),
615 ScalarType::F32 => "try decodeF32(from: data, offset: &offset)".into(),
616 ScalarType::F64 => "try decodeF64(from: data, offset: &offset)".into(),
617 ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
618 "try decodeStringV7(from: data, offset: &offset)".into()
619 }
620 _ => "nil /* unsupported scalar decode */".into(),
621 }
622}
623
624fn decode_element_closure(shape: &'static Shape, types: &[WireType]) -> String {
626 if is_bytes(shape) {
627 return "{ data, off in Array(try decodeBytesV7(from: data, offset: &off)) }".into();
628 }
629
630 match classify_shape(shape) {
631 ShapeKind::Scalar(scalar) => {
632 let expr = decode_scalar_with_params(scalar, "data", "off");
633 format!("{{ data, off in {expr} }}")
634 }
635 ShapeKind::Struct(StructInfo {
636 name: Some(name), ..
637 }) => {
638 let swift_name = lookup_wire_name(name, types);
639 format!("{{ data, off in try {swift_name}.decode(from: data, offset: &off) }}")
640 }
641 ShapeKind::Enum(EnumInfo {
642 name: Some(name), ..
643 }) => {
644 let swift_name = lookup_wire_name(name, types);
645 format!("{{ data, off in try {swift_name}.decode(from: data, offset: &off) }}")
646 }
647 ShapeKind::List { element } | ShapeKind::Slice { element } => {
648 let inner = decode_element_closure(element, types);
649 format!("{{ data, off in try decodeVec(from: data, offset: &off, decoder: {inner}) }}")
650 }
651 ShapeKind::Opaque => {
652 "{ data, off in try OpaquePayloadV7.decode(from: data, offset: &off) }".into()
653 }
654 ShapeKind::Pointer { pointee } => decode_element_closure(pointee, types),
655 _ => "{ _, _ in throw WireV7Error.truncated }".into(),
656 }
657}
658
659fn decode_scalar_with_params(scalar: ScalarType, data: &str, offset: &str) -> String {
660 match scalar {
661 ScalarType::Bool => format!("try decodeBool(from: {data}, offset: &{offset})"),
662 ScalarType::U8 => format!("try decodeU8(from: {data}, offset: &{offset})"),
663 ScalarType::I8 => format!("try decodeI8(from: {data}, offset: &{offset})"),
664 ScalarType::U16 => format!("try decodeU16(from: {data}, offset: &{offset})"),
665 ScalarType::I16 => format!("try decodeI16(from: {data}, offset: &{offset})"),
666 ScalarType::U32 => format!("try decodeVarintU32V7(from: {data}, offset: &{offset})"),
667 ScalarType::I32 => format!("try decodeI32(from: {data}, offset: &{offset})"),
668 ScalarType::U64 | ScalarType::USize => {
669 format!("try decodeVarint(from: {data}, offset: &{offset})")
670 }
671 ScalarType::I64 | ScalarType::ISize => {
672 format!("try decodeI64(from: {data}, offset: &{offset})")
673 }
674 ScalarType::F32 => format!("try decodeF32(from: {data}, offset: &{offset})"),
675 ScalarType::F64 => format!("try decodeF64(from: {data}, offset: &{offset})"),
676 ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
677 format!("try decodeStringV7(from: {data}, offset: &{offset})")
678 }
679 _ => "nil /* unsupported scalar */".to_string(),
680 }
681}
682
683fn generate_factory_methods(types: &[WireType]) -> String {
685 let payload_wt = types.iter().find(|wt| wt.swift_name == "MessagePayloadV7");
687 let payload_wt = match payload_wt {
688 Some(wt) => wt,
689 None => return String::new(),
690 };
691
692 let variants = match classify_shape(payload_wt.shape) {
693 ShapeKind::Enum(EnumInfo { variants, .. }) => variants,
694 _ => return String::new(),
695 };
696
697 let mut out = String::new();
698 out.push_str("public extension MessageV7 {\n");
699
700 for v in variants {
701 let variant_name = v.name.to_lower_camel_case();
702 if let VariantKind::Newtype { inner } = classify_variant(v) {
703 let inner_swift = swift_wire_type(inner, None, types);
704 let is_control = matches!(
706 v.name,
707 "Hello" | "HelloYourself" | "ProtocolError" | "Ping" | "Pong"
708 );
709
710 if is_control {
711 out.push_str(&format!(
712 " static func {variant_name}(_ value: {inner_swift}) -> MessageV7 {{\n"
713 ));
714 out.push_str(&format!(
715 " MessageV7(connectionId: 0, payload: .{variant_name}(value))\n"
716 ));
717 out.push_str(" }\n\n");
718 } else {
719 out.push_str(&format!(
720 " static func {variant_name}(connId: UInt64, _ value: {inner_swift}) -> MessageV7 {{\n"
721 ));
722 out.push_str(&format!(
723 " MessageV7(connectionId: connId, payload: .{variant_name}(value))\n"
724 ));
725 out.push_str(" }\n\n");
726 }
727 }
728 }
729
730 out.push_str(" static func protocolError(description: String) -> MessageV7 {\n");
733 out.push_str(" MessageV7(connectionId: 0, payload: .protocolError(.init(description: description)))\n");
734 out.push_str(" }\n\n");
735
736 out.push_str(" static func connectionOpen(connId: UInt64, settings: ConnectionSettingsV7, metadata: [MetadataEntryV7]) -> MessageV7 {\n");
737 out.push_str(" MessageV7(connectionId: connId, payload: .connectionOpen(.init(connectionSettings: settings, metadata: metadata)))\n");
738 out.push_str(" }\n\n");
739
740 out.push_str(" static func connectionAccept(connId: UInt64, settings: ConnectionSettingsV7, metadata: [MetadataEntryV7]) -> MessageV7 {\n");
741 out.push_str(" MessageV7(connectionId: connId, payload: .connectionAccept(.init(connectionSettings: settings, metadata: metadata)))\n");
742 out.push_str(" }\n\n");
743
744 out.push_str(" static func connectionReject(connId: UInt64, metadata: [MetadataEntryV7]) -> MessageV7 {\n");
745 out.push_str(" MessageV7(connectionId: connId, payload: .connectionReject(.init(metadata: metadata)))\n");
746 out.push_str(" }\n\n");
747
748 out.push_str(" static func connectionClose(connId: UInt64, metadata: [MetadataEntryV7]) -> MessageV7 {\n");
749 out.push_str(" MessageV7(connectionId: connId, payload: .connectionClose(.init(metadata: metadata)))\n");
750 out.push_str(" }\n\n");
751
752 out.push_str(" static func request(\n");
753 out.push_str(" connId: UInt64,\n");
754 out.push_str(" requestId: UInt64,\n");
755 out.push_str(" methodId: UInt64,\n");
756 out.push_str(" metadata: [MetadataEntryV7],\n");
757 out.push_str(" channels: [UInt64],\n");
758 out.push_str(" payload: [UInt8]\n");
759 out.push_str(" ) -> MessageV7 {\n");
760 out.push_str(" MessageV7(\n");
761 out.push_str(" connectionId: connId,\n");
762 out.push_str(" payload: .requestMessage(\n");
763 out.push_str(" .init(\n");
764 out.push_str(" id: requestId,\n");
765 out.push_str(" body: .call(.init(methodId: methodId, channels: channels, metadata: metadata, args: .init(payload)))\n");
766 out.push_str(" ))\n");
767 out.push_str(" )\n");
768 out.push_str(" }\n\n");
769
770 out.push_str(" static func response(\n");
771 out.push_str(" connId: UInt64,\n");
772 out.push_str(" requestId: UInt64,\n");
773 out.push_str(" metadata: [MetadataEntryV7],\n");
774 out.push_str(" channels: [UInt64],\n");
775 out.push_str(" payload: [UInt8]\n");
776 out.push_str(" ) -> MessageV7 {\n");
777 out.push_str(" MessageV7(\n");
778 out.push_str(" connectionId: connId,\n");
779 out.push_str(" payload: .requestMessage(\n");
780 out.push_str(" .init(\n");
781 out.push_str(" id: requestId,\n");
782 out.push_str(" body: .response(.init(channels: channels, metadata: metadata, ret: .init(payload)))\n");
783 out.push_str(" ))\n");
784 out.push_str(" )\n");
785 out.push_str(" }\n\n");
786
787 out.push_str(" static func cancel(connId: UInt64, requestId: UInt64, metadata: [MetadataEntryV7] = []) -> MessageV7 {\n");
788 out.push_str(" MessageV7(\n");
789 out.push_str(" connectionId: connId,\n");
790 out.push_str(" payload: .requestMessage(\n");
791 out.push_str(" .init(\n");
792 out.push_str(" id: requestId,\n");
793 out.push_str(" body: .cancel(.init(metadata: metadata))\n");
794 out.push_str(" ))\n");
795 out.push_str(" )\n");
796 out.push_str(" }\n\n");
797
798 out.push_str(" static func data(connId: UInt64, channelId: UInt64, payload: [UInt8]) -> MessageV7 {\n");
799 out.push_str(" MessageV7(\n");
800 out.push_str(" connectionId: connId,\n");
801 out.push_str(" payload: .channelMessage(.init(id: channelId, body: .item(.init(item: .init(payload)))))\n");
802 out.push_str(" )\n");
803 out.push_str(" }\n\n");
804
805 out.push_str(" static func close(connId: UInt64, channelId: UInt64, metadata: [MetadataEntryV7] = []) -> MessageV7 {\n");
806 out.push_str(" MessageV7(\n");
807 out.push_str(" connectionId: connId,\n");
808 out.push_str(" payload: .channelMessage(.init(id: channelId, body: .close(.init(metadata: metadata))))\n");
809 out.push_str(" )\n");
810 out.push_str(" }\n\n");
811
812 out.push_str(" static func reset(connId: UInt64, channelId: UInt64, metadata: [MetadataEntryV7] = []) -> MessageV7 {\n");
813 out.push_str(" MessageV7(\n");
814 out.push_str(" connectionId: connId,\n");
815 out.push_str(" payload: .channelMessage(.init(id: channelId, body: .reset(.init(metadata: metadata))))\n");
816 out.push_str(" )\n");
817 out.push_str(" }\n\n");
818
819 out.push_str(
820 " static func credit(connId: UInt64, channelId: UInt64, bytes: UInt32) -> MessageV7 {\n",
821 );
822 out.push_str(" MessageV7(\n");
823 out.push_str(" connectionId: connId,\n");
824 out.push_str(" payload: .channelMessage(.init(id: channelId, body: .grantCredit(.init(additional: bytes))))\n");
825 out.push_str(" )\n");
826 out.push_str(" }\n");
827
828 out.push_str("}\n");
829 out
830}