use super::types::swift_field_name;
use facet_core::{Field, ScalarType, Shape};
use vox_types::{
DEFAULT_INITIAL_CHANNEL_CREDIT, EnumInfo, ShapeKind, StructInfo, VariantKind, classify_shape,
classify_variant, extract_schemas, is_bytes,
};
/// One wire type to emit: the Swift-side name paired with the reflected
/// shape that the generators in this file classify.
pub struct WireType {
/// Name given to the generated Swift declaration; also what other types'
/// fields resolve to when they reference this shape.
pub swift_name: String,
/// Reflected shape that is classified to decide which Swift declaration,
/// encoder and decoder to emit.
pub shape: &'static Shape,
}
/// Renders the complete generated Swift source for `types` and serializes the
/// wire schemas to CBOR.
///
/// Returns `(swift_source, cbor_bytes)`. The schema root is the entry whose
/// Swift name is `Message`, falling back to the last entry; when `types` is
/// empty the byte vector is empty.
///
/// # Panics
///
/// Panics if schema extraction or CBOR serialization of the root shape fails.
pub fn generate_wire_types(types: &[WireType]) -> (String, Vec<u8>) {
    let mut swift = String::new();

    // File banner and module imports.
    for line in [
        "// @generated by vox-codegen\n",
        "// DO NOT EDIT — regenerate with `cargo xtask codegen --swift-wire`\n\n",
        "import Foundation\n",
        "@preconcurrency import NIOCore\n\n",
    ] {
        swift.push_str(line);
    }

    swift.push_str(&generate_preamble());
    swift.push_str("public typealias Metadata = [MetadataEntry]\n\n");

    // One declaration per wire type, each followed by a blank line.
    for entry in types {
        swift.push_str(&generate_one_type(&entry.swift_name, entry.shape, types));
        swift.push('\n');
    }

    swift.push_str(&generate_factory_methods(types));

    // Accessor that loads the bundled schema resource on the Swift side.
    for line in [
        "\n/// CBOR-encoded wire message schemas, loaded from the bundled binary resource.\n",
        "public let wireMessageSchemasCbor: [UInt8] = {\n",
        " guard let url = Bundle.module.url(forResource: \"wireMessageSchemas\", withExtension: \"bin\"),\n",
        " let data = try? Data(contentsOf: url) else {\n",
        " preconditionFailure(\"wireMessageSchemas.bin resource not found in Bundle.module\")\n",
        " }\n",
        " return Array(data)\n",
        "}()\n",
    ] {
        swift.push_str(line);
    }

    let schema_root = types
        .iter()
        .find(|wt| wt.swift_name == "Message")
        .or_else(|| types.last());
    let cbor_bytes = match schema_root {
        Some(root) => {
            let extracted =
                extract_schemas(root.shape).expect("wire schema extraction should succeed");
            facet_cbor::to_vec(&extracted.schemas)
                .expect("wire schema CBOR serialization should succeed")
        }
        None => Vec::new(),
    };

    (swift, cbor_bytes)
}
/// Builds the fixed Swift preamble that precedes the generated wire types.
///
/// The preamble declares, in order:
/// - the `WireError` enum,
/// - the `OpaquePayload` struct with length-prefixed and trailing
///   encode/decode helpers,
/// - the private helpers `decodeWireVarintU32`, `decodeWireString` and
///   `decodeWireBytes`, which translate `PostcardError` cases into
///   `WireError` cases.
///
/// The text is constant; the function takes no input.
fn generate_preamble() -> String {
    let mut out = String::new();

    // --- WireError -------------------------------------------------------
    out.push_str("public enum WireError: Error, Equatable {\n");
    out.push_str(" case truncated\n");
    out.push_str(" case unknownVariant(UInt64)\n");
    out.push_str(" case overflow\n");
    out.push_str(" case invalidUtf8\n");
    out.push_str(" case trailingBytes\n");
    out.push_str("}\n\n");

    // --- OpaquePayload ---------------------------------------------------
    out.push_str("public struct OpaquePayload: Sendable, Equatable {\n");
    out.push_str(" public var bytes: ByteBuffer\n\n");
    out.push_str(" public init(_ bytes: ByteBuffer) {\n");
    out.push_str(" self.bytes = bytes\n");
    out.push_str(" }\n\n");
    out.push_str(" /// Init from a [UInt8] for convenience (e.g. from legacy code)\n");
    out.push_str(" public init(_ bytes: [UInt8]) {\n");
    out.push_str(" var buf = ByteBufferAllocator().buffer(capacity: bytes.count)\n");
    out.push_str(" buf.writeBytes(bytes)\n");
    out.push_str(" self.bytes = buf\n");
    out.push_str(" }\n\n");
    out.push_str(" func encode(into buffer: inout ByteBuffer) {\n");
    out.push_str(" let len = UInt32(bytes.readableBytes)\n");
    out.push_str(" buffer.writeInteger(len, endianness: .little)\n");
    out.push_str(" var copy = bytes\n");
    // `writeBuffer` is handed the local declared on the previous line as an
    // inout argument, hence `&copy`.
    out.push_str(" buffer.writeBuffer(&copy)\n");
    out.push_str(" }\n\n");
    out.push_str(" static func decode(from buffer: inout ByteBuffer) throws -> Self {\n");
    out.push_str(
        " guard let len: UInt32 = buffer.readInteger(endianness: .little) else {\n",
    );
    out.push_str(" throw WireError.truncated\n");
    out.push_str(" }\n");
    out.push_str(" guard let slice = buffer.readSlice(length: Int(len)) else {\n");
    out.push_str(" throw WireError.truncated\n");
    out.push_str(" }\n");
    out.push_str(" return .init(slice)\n");
    out.push_str(" }\n\n");
    out.push_str(" /// Encode without a length prefix — for trailing fields only.\n");
    out.push_str(" func encodeTrailing(into buffer: inout ByteBuffer) {\n");
    out.push_str(" var copy = bytes\n");
    out.push_str(" buffer.writeBuffer(&copy)\n");
    out.push_str(" }\n\n");
    out.push_str(" /// Decode by consuming all remaining bytes — for trailing fields only.\n");
    out.push_str(" static func decodeTrailing(from buffer: inout ByteBuffer) -> Self {\n");
    out.push_str(
        " let slice = buffer.readSlice(length: buffer.readableBytes) ?? ByteBuffer()\n",
    );
    out.push_str(" return .init(slice)\n");
    out.push_str(" }\n");
    out.push_str("}\n\n");

    // --- decodeWireVarintU32: varint narrowed to UInt32 with overflow check
    out.push_str("@inline(__always)\n");
    out.push_str(
        "private func decodeWireVarintU32(from buffer: inout ByteBuffer) throws -> UInt32 {\n",
    );
    out.push_str(" let value = try decodeVarint(from: &buffer)\n");
    out.push_str(" guard value <= UInt64(UInt32.max) else {\n");
    out.push_str(" throw WireError.overflow\n");
    out.push_str(" }\n");
    out.push_str(" return UInt32(value)\n");
    out.push_str("}\n\n");

    // --- decodeWireString: maps PostcardError onto WireError --------------
    out.push_str("@inline(__always)\n");
    out.push_str(
        "private func decodeWireString(from buffer: inout ByteBuffer) throws -> String {\n",
    );
    out.push_str(" do {\n");
    out.push_str(" return try decodeString(from: &buffer)\n");
    out.push_str(" } catch PostcardError.invalidUtf8 {\n");
    out.push_str(" throw WireError.invalidUtf8\n");
    out.push_str(" } catch PostcardError.truncated {\n");
    out.push_str(" throw WireError.truncated\n");
    out.push_str(" } catch {\n");
    out.push_str(" throw error\n");
    out.push_str(" }\n");
    out.push_str("}\n\n");

    // --- decodeWireBytes: maps PostcardError.truncated onto WireError -----
    out.push_str("@inline(__always)\n");
    out.push_str(
        "private func decodeWireBytes(from buffer: inout ByteBuffer) throws -> ByteBuffer {\n",
    );
    out.push_str(" do {\n");
    out.push_str(" return try decodeBytes(from: &buffer)\n");
    out.push_str(" } catch PostcardError.truncated {\n");
    out.push_str(" throw WireError.truncated\n");
    out.push_str(" } catch {\n");
    out.push_str(" throw error\n");
    out.push_str(" }\n");
    out.push_str("}\n\n");

    out
}
/// Returns the Swift type spelling used for `shape` in generated declarations.
///
/// Byte sequences become `[UInt8]`; lists, slices and arrays become `[T]`;
/// options become `T?`; named structs/enums resolve through
/// [`lookup_wire_name`]; pointers and single-field tuple structs are
/// transparent; opaque shapes become `OpaquePayload`. Anything else yields
/// the placeholder `Any /* unsupported */`.
///
/// `_field` is only forwarded when unwrapping a pointer.
fn swift_wire_type(shape: &'static Shape, _field: Option<&Field>, types: &[WireType]) -> String {
    if is_bytes(shape) {
        return String::from("[UInt8]");
    }

    // Nested positions never carry field context.
    let nested = |inner: &'static Shape| swift_wire_type(inner, None, types);

    match classify_shape(shape) {
        ShapeKind::Scalar(scalar) => swift_scalar_type(scalar),
        ShapeKind::List { element } | ShapeKind::Slice { element } => {
            format!("[{}]", nested(element))
        }
        ShapeKind::Option { inner } => format!("{}?", nested(inner)),
        ShapeKind::Array { element, .. } => format!("[{}]", nested(element)),
        ShapeKind::Struct(StructInfo {
            name: Some(name), ..
        }) => lookup_wire_name(name, types),
        ShapeKind::Enum(EnumInfo {
            name: Some(name), ..
        }) => lookup_wire_name(name, types),
        ShapeKind::Pointer { pointee } => swift_wire_type(pointee, _field, types),
        ShapeKind::Opaque => String::from("OpaquePayload"),
        ShapeKind::TupleStruct { fields } if fields.len() == 1 => nested(fields[0].shape()),
        _ => String::from("Any /* unsupported */"),
    }
}
/// Maps a scalar kind to the Swift type name used in generated code.
///
/// Pointer-sized integers are widened to the 64-bit Swift type, every
/// string-like scalar (including `Char`) becomes `String`, and scalars with
/// no mapping fall back to `Any`.
fn swift_scalar_type(scalar: ScalarType) -> String {
    let swift_name: &str = match scalar {
        ScalarType::Bool => "Bool",
        ScalarType::U8 => "UInt8",
        ScalarType::U16 => "UInt16",
        ScalarType::U32 => "UInt32",
        ScalarType::U64 | ScalarType::USize => "UInt64",
        ScalarType::I8 => "Int8",
        ScalarType::I16 => "Int16",
        ScalarType::I32 => "Int32",
        ScalarType::I64 | ScalarType::ISize => "Int64",
        ScalarType::F32 => "Float",
        ScalarType::F64 => "Double",
        ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => "String",
        ScalarType::Unit => "Void",
        _ => "Any",
    };
    swift_name.to_owned()
}
/// Resolves a Rust type name to the Swift name registered for it in `types`.
///
/// Candidates are tried from most to least specific so that a short name
/// cannot be captured by an unrelated longer one that merely ends with it
/// (e.g. `Message` vs. `RequestMessage`):
///
/// 1. a type identifier equal to `rust_name`;
/// 2. a type identifier ending in `::rust_name`;
/// 3. any type identifier ending in `rust_name` (the previous rule, kept as a
///    last resort so existing resolutions keep working).
///
/// Within each tier the first entry in `types` wins. If nothing matches,
/// `rust_name` itself is returned unchanged.
fn lookup_wire_name(rust_name: &str, types: &[WireType]) -> String {
    let exact = types
        .iter()
        .find(|wt| wt.shape.type_identifier == rust_name);
    let path_qualified = || {
        types.iter().find(|wt| {
            wt.shape
                .type_identifier
                .strip_suffix(rust_name)
                .is_some_and(|prefix| prefix.ends_with("::"))
        })
    };
    let any_suffix = || {
        types
            .iter()
            .find(|wt| wt.shape.type_identifier.ends_with(rust_name))
    };

    exact
        .or_else(path_qualified)
        .or_else(any_suffix)
        .map_or_else(|| rust_name.to_string(), |wt| wt.swift_name.clone())
}
/// Emits the Swift declaration for a single wire type.
///
/// Struct shapes go through [`generate_struct`] (the type named `Message` is
/// treated as the top-level one), enum shapes through [`generate_enum`], and
/// every other shape produces a one-line "unsupported" comment.
fn generate_one_type(swift_name: &str, shape: &'static Shape, types: &[WireType]) -> String {
    match classify_shape(shape) {
        ShapeKind::Struct(StructInfo { fields, .. }) => {
            let is_top_level = swift_name == "Message";
            generate_struct(swift_name, fields, types, is_top_level)
        }
        ShapeKind::Enum(EnumInfo { variants, .. }) => generate_enum(swift_name, variants, types),
        _ => format!("// Unsupported shape for {swift_name}\n"),
    }
}
/// Emits a Swift `struct` named `name` with stored properties, a memberwise
/// initializer (omitted when there are no fields), and `encode`/`decode`
/// methods.
///
/// When `is_top_level` is true the codec methods are marked `public` and two
/// `[UInt8]` bridge methods are added; the byte-array decoder rejects
/// leftover input with `WireError.trailingBytes`.
fn generate_struct(name: &str, fields: &[Field], types: &[WireType], is_top_level: bool) -> String {
    // (swift property name, swift type) for each field, in declaration order.
    let members: Vec<(String, String)> = fields
        .iter()
        .map(|f| {
            (
                swift_field_name(f.name),
                swift_wire_type(f.shape(), Some(f), types),
            )
        })
        .collect();
    let vis = if is_top_level { "public " } else { "" };

    let mut swift = format!("public struct {name}: Sendable, Equatable {{\n");
    for (member, ty) in &members {
        swift.push_str(&format!(" public var {member}: {ty}\n"));
    }

    // Memberwise initializer.
    if !members.is_empty() {
        let params = fields
            .iter()
            .zip(&members)
            .map(|(f, (member, ty))| match swift_default_argument(f) {
                Some(default_value) => format!("{member}: {ty} = {default_value}"),
                None => format!("{member}: {ty}"),
            })
            .collect::<Vec<_>>()
            .join(", ");
        swift.push_str(&format!("\n public init({params}) {{\n"));
        for (member, _) in &members {
            swift.push_str(&format!(" self.{member} = {member}\n"));
        }
        swift.push_str(" }\n");
    }

    // Encoder: one statement per field.
    swift.push_str(&format!(
        "\n {vis}func encode(into buffer: inout ByteBuffer) {{\n"
    ));
    for f in fields {
        swift.push_str(&format!(" {}\n", encode_field_stmt(f, types)));
    }
    swift.push_str(" }\n");
    if is_top_level {
        swift.push_str(concat!(
            "\n /// Encode to a `[UInt8]` array (bridge for callers that need bytes).\n",
            " public func encode() -> [UInt8] {\n",
            " var buffer = ByteBufferAllocator().buffer(capacity: 64)\n",
            " encode(into: &buffer)\n",
            " return buffer.readBytes(length: buffer.readableBytes) ?? []\n",
            " }\n",
        ));
    }

    // Decoder: bind every field to a local, then call the initializer.
    swift.push_str(&format!(
        "\n {vis}static func decode(from buffer: inout ByteBuffer) throws -> Self {{\n"
    ));
    for stmt in fields.iter().flat_map(|f| decode_field_stmts(f, types)) {
        swift.push_str(&format!(" {stmt}\n"));
    }
    let init_args = members
        .iter()
        .map(|(member, _)| format!("{member}: {member}"))
        .collect::<Vec<_>>()
        .join(", ");
    swift.push_str(&format!(" return .init({init_args})\n"));
    swift.push_str(" }\n");
    if is_top_level {
        swift.push_str(concat!(
            "\n /// Decode from a `[UInt8]` array (bridge for callers that have raw bytes).\n",
            " public static func decode(fromBytes data: [UInt8]) throws -> Self {\n",
            " var buffer = ByteBufferAllocator().buffer(capacity: data.count)\n",
            " buffer.writeBytes(data)\n",
            " let result = try decode(from: &buffer)\n",
            " guard buffer.readableBytes == 0 else {\n",
            " throw WireError.trailingBytes\n",
            " }\n",
            " return result\n",
            " }\n",
        ));
    }

    swift.push_str("}\n");
    swift
}
/// Returns the Swift default-argument text for `field`, if it gets one.
///
/// Only a field named `initial_channel_credit` that reports a default
/// receives one: the value of `DEFAULT_INITIAL_CHANNEL_CREDIT`.
fn swift_default_argument(field: &Field) -> Option<String> {
    let gets_credit_default = field.name == "initial_channel_credit" && field.has_default();
    gets_credit_default.then(|| DEFAULT_INITIAL_CHANNEL_CREDIT.to_string())
}
/// Emits a Swift `enum` named `name` with one case per variant plus
/// `encode`/`decode` methods.
///
/// The wire discriminant is the variant's position in `variants`, written as
/// a varint ahead of the payload. Decoding an unknown discriminant throws
/// `WireError.unknownVariant`.
fn generate_enum(name: &str, variants: &[facet_core::Variant], types: &[WireType]) -> String {
    // `let a, let b` binding list wrapped in parentheses for a case pattern.
    let binding_pattern = |names: &[String]| -> String {
        let bindings = names
            .iter()
            .map(|b| format!("let {b}"))
            .collect::<Vec<_>>()
            .join(", ");
        format!("({bindings})")
    };

    let mut swift = format!("public enum {name}: Sendable, Equatable {{\n");

    // --- case declarations ---------------------------------------------
    for variant in variants {
        let case_name = swift_field_name(variant.name);
        let payload: Option<String> = match classify_variant(variant) {
            VariantKind::Unit => None,
            VariantKind::Newtype { inner } => {
                Some(swift_wire_type(inner, variant.data.fields.first(), types))
            }
            VariantKind::Tuple { fields } => Some(
                fields
                    .iter()
                    .map(|f| swift_wire_type(f.shape(), Some(f), types))
                    .collect::<Vec<_>>()
                    .join(", "),
            ),
            VariantKind::Struct { fields } => Some(
                fields
                    .iter()
                    .map(|f| {
                        format!(
                            "{}: {}",
                            swift_field_name(f.name),
                            swift_wire_type(f.shape(), Some(f), types)
                        )
                    })
                    .collect::<Vec<_>>()
                    .join(", "),
            ),
        };
        match payload {
            Some(payload) => swift.push_str(&format!(" case {case_name}({payload})\n")),
            None => swift.push_str(&format!(" case {case_name}\n")),
        }
    }

    // --- encode ----------------------------------------------------------
    swift.push_str("\n func encode(into buffer: inout ByteBuffer) {\n switch self {\n");
    for (index, variant) in variants.iter().enumerate() {
        let case_name = swift_field_name(variant.name);
        // `pattern` is the binding list after the case name; `body` holds the
        // payload statements that follow the discriminant.
        let (pattern, body): (String, Vec<String>) = match classify_variant(variant) {
            VariantKind::Unit => (String::new(), Vec::new()),
            VariantKind::Newtype { inner } => (
                String::from("(let val)"),
                vec![encode_shape_stmt(
                    inner,
                    "val",
                    variant.data.fields.first(),
                    types,
                )],
            ),
            VariantKind::Tuple { fields } => {
                let names: Vec<String> = (0..fields.len()).map(|j| format!("f{j}")).collect();
                let body = fields
                    .iter()
                    .zip(&names)
                    .map(|(f, local)| encode_shape_stmt(f.shape(), local, Some(f), types))
                    .collect();
                (binding_pattern(&names), body)
            }
            VariantKind::Struct { fields } => {
                let names: Vec<String> = fields.iter().map(|f| swift_field_name(f.name)).collect();
                let body = fields
                    .iter()
                    .zip(&names)
                    .map(|(f, local)| encode_shape_stmt(f.shape(), local, Some(f), types))
                    .collect();
                (binding_pattern(&names), body)
            }
        };
        swift.push_str(&format!(
            " case .{case_name}{pattern}:\n encodeVarint(UInt64({index}), into: &buffer)\n"
        ));
        for stmt in &body {
            swift.push_str(&format!(" {stmt}\n"));
        }
    }
    swift.push_str(" }\n }\n");

    // --- decode ----------------------------------------------------------
    swift.push_str(
        "\n static func decode(from buffer: inout ByteBuffer) throws -> Self {\n let disc = try decodeVarint(from: &buffer)\n switch disc {\n",
    );
    for (index, variant) in variants.iter().enumerate() {
        let case_name = swift_field_name(variant.name);
        swift.push_str(&format!(" case {index}:\n"));
        // `stmts` bind the payload locals; `ctor_args` is the argument list
        // of the case constructor (absent for unit variants).
        let (stmts, ctor_args): (Vec<String>, Option<String>) = match classify_variant(variant) {
            VariantKind::Unit => (Vec::new(), None),
            VariantKind::Newtype { inner } => (
                decode_stmts_for(inner, variant.data.fields.first(), "_newtype_val", types),
                Some(String::from("_newtype_val")),
            ),
            VariantKind::Tuple { fields } => {
                let names: Vec<String> = (0..fields.len()).map(|j| format!("f{j}")).collect();
                let stmts = fields
                    .iter()
                    .zip(&names)
                    .flat_map(|(f, local)| decode_stmts_for(f.shape(), Some(f), local, types))
                    .collect();
                (stmts, Some(names.join(", ")))
            }
            VariantKind::Struct { fields } => {
                let names: Vec<String> = fields.iter().map(|f| swift_field_name(f.name)).collect();
                let stmts = fields
                    .iter()
                    .zip(&names)
                    .flat_map(|(f, local)| decode_stmts_for(f.shape(), Some(f), local, types))
                    .collect();
                let labelled = names
                    .iter()
                    .map(|n| format!("{n}: {n}"))
                    .collect::<Vec<_>>()
                    .join(", ");
                (stmts, Some(labelled))
            }
        };
        for stmt in &stmts {
            swift.push_str(&format!(" {stmt}\n"));
        }
        match ctor_args {
            Some(args) => swift.push_str(&format!(" return .{case_name}({args})\n")),
            None => swift.push_str(&format!(" return .{case_name}\n")),
        }
    }
    swift.push_str(" default:\n throw WireError.unknownVariant(disc)\n }\n }\n}\n");

    swift
}
/// Produces the Swift statement that encodes `field`, addressing the value by
/// the field's Swift property name.
fn encode_field_stmt(field: &Field, types: &[WireType]) -> String {
    let property = swift_field_name(field.name);
    encode_shape_stmt(field.shape(), property.as_str(), Some(field), types)
}
/// Produces the Swift statement that writes `value` (a Swift expression of
/// the given `shape`) into the in-scope `buffer`.
///
/// Opaque payloads are written without a length prefix when `field` carries
/// the builtin `trailing` attribute. Pointers and single-field tuple structs
/// are transparent. Shapes with no encoder yield a Swift comment instead of a
/// statement.
fn encode_shape_stmt(
    shape: &'static Shape,
    value: &str,
    field: Option<&Field>,
    types: &[WireType],
) -> String {
    let trailing = matches!(field, Some(f) if f.has_builtin_attr("trailing"));
    let kind = classify_shape(shape);

    // Opaque takes precedence over the byte-sequence check below.
    if matches!(kind, ShapeKind::Opaque) {
        let method = if trailing { "encodeTrailing" } else { "encode" };
        return format!("{value}.{method}(into: &buffer)");
    }
    if is_bytes(shape) {
        return format!("encodeByteSeq({value}, into: &buffer)");
    }

    match kind {
        ShapeKind::Scalar(scalar) => encode_scalar_stmt(scalar, value),
        ShapeKind::List { element } | ShapeKind::Slice { element } => format!(
            "encodeVec({value}, into: &buffer, encoder: {})",
            encode_element_closure(element, types)
        ),
        ShapeKind::Option { inner } => format!(
            "encodeOption({value}, into: &buffer, encoder: {})",
            encode_element_closure(inner, types)
        ),
        ShapeKind::Struct(_) | ShapeKind::Enum(_) => format!("{value}.encode(into: &buffer)"),
        ShapeKind::Pointer { pointee } => encode_shape_stmt(pointee, value, field, types),
        ShapeKind::TupleStruct { fields } if fields.len() == 1 => {
            encode_shape_stmt(fields[0].shape(), value, field, types)
        }
        _ => format!("/* unsupported encode for {value} */"),
    }
}
/// Produces the Swift call that encodes the scalar expression `value` into
/// the in-scope `buffer`.
///
/// `U32` is widened to `UInt64` and written with `encodeVarint`, the same
/// function used for `U64`/`USize`. Scalars without an encoder yield a Swift
/// comment.
fn encode_scalar_stmt(scalar: ScalarType, value: &str) -> String {
    let encoder = match scalar {
        ScalarType::Bool => "encodeBool",
        ScalarType::U8 => "encodeU8",
        ScalarType::I8 => "encodeI8",
        ScalarType::U16 => "encodeU16",
        ScalarType::I16 => "encodeI16",
        ScalarType::U32 | ScalarType::U64 | ScalarType::USize => "encodeVarint",
        ScalarType::I32 => "encodeI32",
        ScalarType::I64 | ScalarType::ISize => "encodeI64",
        ScalarType::F32 => "encodeF32",
        ScalarType::F64 => "encodeF64",
        ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
            "encodeString"
        }
        _ => return format!("/* unsupported scalar encode for {value} */"),
    };

    // Only the 32-bit unsigned case needs an explicit widening conversion.
    let argument = if matches!(scalar, ScalarType::U32) {
        format!("UInt64({value})")
    } else {
        value.to_owned()
    };
    format!("{encoder}({argument}, into: &buffer)")
}
/// Produces a Swift closure literal `{ val, buf in … }` that encodes one
/// element of `shape`, for use as the `encoder:` argument of `encodeVec` /
/// `encodeOption`.
///
/// Handles byte sequences, scalars, structs, enums, nested lists/slices,
/// nested options, opaque payloads, pointers (transparent) and single-field
/// tuple structs (transparent). Options and single-field tuple structs are
/// covered so that element positions accept the same shapes that
/// `swift_wire_type` and `encode_shape_stmt` accept at field level; without
/// them such elements fell through to the no-op closure and were silently
/// dropped from the output.
///
/// Any remaining shape yields a closure that writes nothing.
///
/// `_types` is only threaded through recursive calls.
fn encode_element_closure(shape: &'static Shape, _types: &[WireType]) -> String {
    if is_bytes(shape) {
        return "{ val, buf in encodeByteSeq(val, into: &buf) }".into();
    }
    match classify_shape(shape) {
        ShapeKind::Scalar(scalar) => {
            // Reuse the field-level statement, retargeted at the closure's buffer.
            let stmt = encode_scalar_stmt(scalar, "val").replace("into: &buffer", "into: &buf");
            format!("{{ val, buf in {stmt} }}")
        }
        ShapeKind::Struct(StructInfo { .. }) | ShapeKind::Enum(EnumInfo { .. }) => {
            "{ val, buf in val.encode(into: &buf) }".into()
        }
        ShapeKind::List { element } | ShapeKind::Slice { element } => {
            let inner = encode_element_closure(element, _types);
            format!("{{ val, buf in encodeVec(val, into: &buf, encoder: {inner}) }}")
        }
        ShapeKind::Option { inner } => {
            let inner = encode_element_closure(inner, _types);
            format!("{{ val, buf in encodeOption(val, into: &buf, encoder: {inner}) }}")
        }
        ShapeKind::Opaque => "{ val, buf in val.encode(into: &buf) }".into(),
        ShapeKind::Pointer { pointee } => encode_element_closure(pointee, _types),
        ShapeKind::TupleStruct { fields } if fields.len() == 1 => {
            encode_element_closure(fields[0].shape(), _types)
        }
        _ => "{ _, _ in /* unsupported */ }".into(),
    }
}
/// Produces the Swift statements that decode `field` into a local named after
/// the field's Swift property name.
fn decode_field_stmts(field: &Field, types: &[WireType]) -> Vec<String> {
    let local = swift_field_name(field.name);
    decode_stmts_for(field.shape(), Some(field), local.as_str(), types)
}
/// Produces the Swift statements that decode a value of `shape` from the
/// in-scope `buffer` and bind it to `var_name`.
///
/// Pointers are unwrapped first. Byte sequences need two statements: the
/// bytes are read into a scratch `ByteBuffer` named `_<var_name>Buf`, then
/// drained into a `[UInt8]`. Everything else is a single `let` around
/// [`decode_shape_expr`].
fn decode_stmts_for(
    shape: &'static Shape,
    field: Option<&Field>,
    var_name: &str,
    types: &[WireType],
) -> Vec<String> {
    match classify_shape(shape) {
        ShapeKind::Pointer { pointee } => decode_stmts_for(pointee, field, var_name, types),
        _ if is_bytes(shape) => {
            let scratch = format!("_{var_name}Buf");
            vec![
                format!("var {scratch} = try decodeWireBytes(from: &buffer)"),
                format!(
                    "let {var_name} = {scratch}.readBytes(length: {scratch}.readableBytes) ?? []"
                ),
            ]
        }
        _ => {
            let expr = decode_shape_expr(shape, field, types);
            vec![format!("let {var_name} = {expr}")]
        }
    }
}
/// Produces the Swift expression that decodes a value of `shape` from the
/// in-scope `buffer`.
///
/// Opaque payloads consume the rest of the buffer when `field` carries the
/// builtin `trailing` attribute, and are length-prefixed otherwise. Pointers
/// and single-field tuple structs are transparent. Shapes with no decoder
/// yield a `nil` placeholder.
fn decode_shape_expr(shape: &'static Shape, field: Option<&Field>, types: &[WireType]) -> String {
    let trailing = matches!(field, Some(f) if f.has_builtin_attr("trailing"));
    let kind = classify_shape(shape);

    if matches!(kind, ShapeKind::Opaque) {
        let expr = if trailing {
            "OpaquePayload.decodeTrailing(from: &buffer)"
        } else {
            "try OpaquePayload.decode(from: &buffer)"
        };
        return expr.to_owned();
    }

    // Named structs and enums decode through their generated static method.
    let decode_named =
        |rust_name: &str| format!("try {}.decode(from: &buffer)", lookup_wire_name(rust_name, types));

    match kind {
        ShapeKind::Scalar(scalar) => decode_scalar(scalar),
        ShapeKind::List { element } | ShapeKind::Slice { element } => format!(
            "try decodeVec(from: &buffer, decoder: {})",
            decode_element_closure(element, types)
        ),
        ShapeKind::Option { inner } => format!(
            "try decodeOption(from: &buffer, decoder: {})",
            decode_element_closure(inner, types)
        ),
        ShapeKind::Struct(StructInfo {
            name: Some(name), ..
        }) => decode_named(name),
        ShapeKind::Enum(EnumInfo {
            name: Some(name), ..
        }) => decode_named(name),
        ShapeKind::Pointer { pointee } => decode_shape_expr(pointee, field, types),
        ShapeKind::TupleStruct { fields } if fields.len() == 1 => {
            decode_shape_expr(fields[0].shape(), field, types)
        }
        _ => "nil /* unsupported decode */".to_owned(),
    }
}
/// Produces the Swift expression that decodes one scalar from the in-scope
/// `buffer`.
///
/// `U32` and string-like scalars go through the preamble's `decodeWire…`
/// wrappers; scalars without a decoder yield a `nil` placeholder.
fn decode_scalar(scalar: ScalarType) -> String {
    let decoder = match scalar {
        ScalarType::Bool => "decodeBool",
        ScalarType::U8 => "decodeU8",
        ScalarType::I8 => "decodeI8",
        ScalarType::U16 => "decodeU16",
        ScalarType::I16 => "decodeI16",
        ScalarType::U32 => "decodeWireVarintU32",
        ScalarType::I32 => "decodeI32",
        ScalarType::U64 | ScalarType::USize => "decodeVarint",
        ScalarType::I64 | ScalarType::ISize => "decodeI64",
        ScalarType::F32 => "decodeF32",
        ScalarType::F64 => "decodeF64",
        ScalarType::Char | ScalarType::Str | ScalarType::String | ScalarType::CowStr => {
            "decodeWireString"
        }
        _ => return String::from("nil /* unsupported scalar decode */"),
    };
    format!("try {decoder}(from: &buffer)")
}
/// Produces a Swift closure literal `{ buf in … }` that decodes one element
/// of `shape`, for use as the `decoder:` argument of `decodeVec` /
/// `decodeOption`.
///
/// Handles byte sequences, scalars, named structs and enums, nested
/// lists/slices, nested options, opaque payloads, pointers (transparent) and
/// single-field tuple structs (transparent). Options and single-field tuple
/// structs are covered so that element positions accept the same shapes that
/// `swift_wire_type` and `decode_shape_expr` accept at field level; without
/// them such elements always hit the throwing fallback.
///
/// Any remaining shape yields a closure that throws `WireError.truncated`.
fn decode_element_closure(shape: &'static Shape, types: &[WireType]) -> String {
    if is_bytes(shape) {
        return "{ buf in var s = try decodeWireBytes(from: &buf); return s.readBytes(length: s.readableBytes) ?? [] }".into();
    }
    match classify_shape(shape) {
        ShapeKind::Scalar(scalar) => {
            // Reuse the field-level expression, retargeted at the closure's buffer.
            let expr = decode_scalar(scalar).replace("from: &buffer", "from: &buf");
            format!("{{ buf in {expr} }}")
        }
        ShapeKind::Struct(StructInfo {
            name: Some(name), ..
        }) => {
            let swift_name = lookup_wire_name(name, types);
            format!("{{ buf in try {swift_name}.decode(from: &buf) }}")
        }
        ShapeKind::Enum(EnumInfo {
            name: Some(name), ..
        }) => {
            let swift_name = lookup_wire_name(name, types);
            format!("{{ buf in try {swift_name}.decode(from: &buf) }}")
        }
        ShapeKind::List { element } | ShapeKind::Slice { element } => {
            let inner = decode_element_closure(element, types);
            format!("{{ buf in try decodeVec(from: &buf, decoder: {inner}) }}")
        }
        ShapeKind::Option { inner } => {
            let inner = decode_element_closure(inner, types);
            format!("{{ buf in try decodeOption(from: &buf, decoder: {inner}) }}")
        }
        ShapeKind::Opaque => "{ buf in try OpaquePayload.decode(from: &buf) }".into(),
        ShapeKind::Pointer { pointee } => decode_element_closure(pointee, types),
        ShapeKind::TupleStruct { fields } if fields.len() == 1 => {
            decode_element_closure(fields[0].shape(), types)
        }
        _ => "{ _ in throw WireError.truncated }".into(),
    }
}
/// Emits the `public extension Message { … }` block of static factory
/// methods.
///
/// Returns an empty string unless `types` contains an enum whose Swift name
/// is `MessagePayload`. For each newtype variant of that enum a factory named
/// after the variant is generated; the variants `Hello`, `HelloYourself`,
/// `ProtocolError`, `Ping` and `Pong` use connection id `0`, every other one
/// takes a `connId` parameter. A fixed set of hand-written convenience
/// factories follows.
fn generate_factory_methods(types: &[WireType]) -> String {
    /// Hand-written convenience factories appended after the per-variant ones,
    /// including the closing brace of the extension.
    const HANDWRITTEN_FACTORIES: &[&str] = &[
        " static func protocolError(description: String) -> Message {\n",
        " Message(connectionId: 0, payload: .protocolError(.init(description: description)))\n",
        " }\n\n",
        " static func connectionOpen(\n",
        " connId: UInt64, settings: ConnectionSettings, metadata: [MetadataEntry]\n",
        " ) -> Message {\n",
        " Message(\n",
        " connectionId: connId,\n",
        " payload: .connectionOpen(.init(connectionSettings: settings, metadata: metadata)))\n",
        " }\n\n",
        " static func connectionAccept(\n",
        " connId: UInt64, settings: ConnectionSettings, metadata: [MetadataEntry]\n",
        " ) -> Message {\n",
        " Message(\n",
        " connectionId: connId,\n",
        " payload: .connectionAccept(.init(connectionSettings: settings, metadata: metadata)))\n",
        " }\n\n",
        " static func connectionReject(connId: UInt64, metadata: [MetadataEntry]) -> Message {\n",
        " Message(connectionId: connId, payload: .connectionReject(.init(metadata: metadata)))\n",
        " }\n\n",
        " static func connectionClose(connId: UInt64, metadata: [MetadataEntry]) -> Message {\n",
        " Message(connectionId: connId, payload: .connectionClose(.init(metadata: metadata)))\n",
        " }\n\n",
        " static func request(\n",
        " connId: UInt64,\n",
        " requestId: UInt64,\n",
        " methodId: UInt64,\n",
        " metadata: [MetadataEntry],\n",
        " schemas: [UInt8] = [],\n",
        " payload: [UInt8]\n",
        " ) -> Message {\n",
        " Message(\n",
        " connectionId: connId,\n",
        " payload: .requestMessage(\n",
        " .init(\n",
        " id: requestId,\n",
        " body: .call(.init(methodId: methodId, metadata: metadata, args: .init(payload), schemas: schemas))\n",
        " ))\n",
        " )\n",
        " }\n\n",
        " static func response(\n",
        " connId: UInt64,\n",
        " requestId: UInt64,\n",
        " metadata: [MetadataEntry],\n",
        " schemas: [UInt8] = [],\n",
        " payload: [UInt8]\n",
        " ) -> Message {\n",
        " Message(\n",
        " connectionId: connId,\n",
        " payload: .requestMessage(\n",
        " .init(\n",
        " id: requestId,\n",
        " body: .response(.init(metadata: metadata, ret: .init(payload), schemas: schemas))\n",
        " ))\n",
        " )\n",
        " }\n\n",
        " static func cancel(connId: UInt64, requestId: UInt64, metadata: [MetadataEntry] = []) -> Message {\n",
        " Message(\n",
        " connectionId: connId,\n",
        " payload: .requestMessage(\n",
        " .init(\n",
        " id: requestId,\n",
        " body: .cancel(.init(metadata: metadata))\n",
        " ))\n",
        " )\n",
        " }\n\n",
        " static func data(connId: UInt64, channelId: UInt64, payload: [UInt8]) -> Message {\n",
        " Message(\n",
        " connectionId: connId,\n",
        " payload: .channelMessage(.init(id: channelId, body: .item(.init(item: .init(payload)))))\n",
        " )\n",
        " }\n\n",
        " static func close(connId: UInt64, channelId: UInt64, metadata: [MetadataEntry] = []) -> Message {\n",
        " Message(\n",
        " connectionId: connId,\n",
        " payload: .channelMessage(.init(id: channelId, body: .close(.init(metadata: metadata))))\n",
        " )\n",
        " }\n\n",
        " static func reset(connId: UInt64, channelId: UInt64, metadata: [MetadataEntry] = []) -> Message {\n",
        " Message(\n",
        " connectionId: connId,\n",
        " payload: .channelMessage(.init(id: channelId, body: .reset(.init(metadata: metadata))))\n",
        " )\n",
        " }\n\n",
        " static func credit(connId: UInt64, channelId: UInt64, bytes: UInt32) -> Message {\n",
        " Message(\n",
        " connectionId: connId,\n",
        " payload: .channelMessage(.init(id: channelId, body: .grantCredit(.init(additional: bytes))))\n",
        " )\n",
        " }\n",
        "}\n",
    ];

    let Some(payload) = types.iter().find(|wt| wt.swift_name == "MessagePayload") else {
        return String::new();
    };
    let ShapeKind::Enum(EnumInfo { variants, .. }) = classify_shape(payload.shape) else {
        return String::new();
    };

    let mut swift = String::from("public extension Message {\n");

    // One factory per newtype variant of the payload enum.
    for v in variants {
        let VariantKind::Newtype { inner } = classify_variant(v) else {
            continue;
        };
        let method = swift_field_name(v.name);
        let value_type = swift_wire_type(inner, None, types);

        let is_control =
            ["Hello", "HelloYourself", "ProtocolError", "Ping", "Pong"].contains(&v.name);
        let (params, connection) = if is_control {
            (format!("_ value: {value_type}"), "0")
        } else {
            (format!("connId: UInt64, _ value: {value_type}"), "connId")
        };

        swift.push_str(&format!(
            " static func {method}({params}) -> Message {{\n Message(connectionId: {connection}, payload: .{method}(value))\n }}\n\n"
        ));
    }

    swift.extend(HANDWRITTEN_FACTORIES.iter().copied());
    swift
}