roam_codegen/targets/swift/
encode.rs1use facet_core::{ScalarType, Shape};
6use heck::ToLowerCamelCase;
7use roam_schema::{
8 EnumInfo, ShapeKind, StructInfo, VariantKind, classify_shape, classify_variant, is_bytes,
9};
10
11pub fn generate_encode_expr(shape: &'static Shape, value: &str) -> String {
13 if is_bytes(shape) {
14 return format!("encodeBytes(Array({value}))");
15 }
16
17 match classify_shape(shape) {
18 ShapeKind::Scalar(scalar) => {
19 let encode_fn = swift_encode_fn(scalar);
20 format!("{encode_fn}({value})")
21 }
22 ShapeKind::List { element }
23 | ShapeKind::Slice { element }
24 | ShapeKind::Array { element, .. } => {
25 let inner_encode = generate_encode_closure(element);
26 format!("encodeVec({value}, encoder: {inner_encode})")
27 }
28 ShapeKind::Option { inner } => {
29 let inner_encode = generate_encode_closure(inner);
30 format!("encodeOption({value}, encoder: {inner_encode})")
31 }
32 ShapeKind::Tuple { elements } if elements.len() == 2 => {
33 let a_encode = generate_encode_closure(elements[0].shape);
34 let b_encode = generate_encode_closure(elements[1].shape);
35 format!("{a_encode}({value}.0) + {b_encode}({value}.1)")
36 }
37 ShapeKind::TupleStruct { fields } if fields.len() == 2 => {
38 let a_encode = generate_encode_closure(fields[0].shape());
39 let b_encode = generate_encode_closure(fields[1].shape());
40 format!("{a_encode}({value}.0) + {b_encode}({value}.1)")
41 }
42 ShapeKind::Struct(StructInfo { fields, .. }) => {
43 let field_encodes: Vec<String> = fields
45 .iter()
46 .map(|f| {
47 let field_name = f.name.to_lower_camel_case();
48 generate_encode_expr(f.shape(), &format!("{value}.{field_name}"))
49 })
50 .collect();
51 if field_encodes.is_empty() {
52 "[]".into()
53 } else {
54 field_encodes.join(" + ")
55 }
56 }
57 ShapeKind::Pointer { pointee } => generate_encode_expr(pointee, value),
58 ShapeKind::Result { ok, err } => {
59 let ok_encode = generate_encode_closure(ok);
61 let err_encode = generate_encode_closure(err);
62 format!(
63 "{{ switch {value} {{ case .success(let v): return [UInt8(0)] + {ok_encode}(v); case .failure(let e): return [UInt8(1)] + {err_encode}(e) }} }}()"
64 )
65 }
66 _ => "[]".into(), }
68}
69
70pub fn generate_encode_closure(shape: &'static Shape) -> String {
72 if is_bytes(shape) {
73 return "{ encodeBytes(Array($0)) }".into();
74 }
75
76 match classify_shape(shape) {
77 ShapeKind::Scalar(scalar) => {
78 let encode_fn = swift_encode_fn(scalar);
79 format!("{{ {encode_fn}($0) }}")
80 }
81 ShapeKind::List { element } | ShapeKind::Slice { element } => {
82 let inner = generate_encode_closure(element);
83 format!("{{ encodeVec($0, encoder: {inner}) }}")
84 }
85 ShapeKind::Option { inner } => {
86 let inner_closure = generate_encode_closure(inner);
87 format!("{{ encodeOption($0, encoder: {inner_closure}) }}")
88 }
89 ShapeKind::Tuple { elements } if elements.len() == 2 => {
90 let a_encode = generate_encode_closure(elements[0].shape);
91 let b_encode = generate_encode_closure(elements[1].shape);
92 format!("{{ {a_encode}($0.0) + {b_encode}($0.1) }}")
93 }
94 ShapeKind::TupleStruct { fields } if fields.len() == 2 => {
95 let a_encode = generate_encode_closure(fields[0].shape());
96 let b_encode = generate_encode_closure(fields[1].shape());
97 format!("{{ {a_encode}($0.0) + {b_encode}($0.1) }}")
98 }
99 ShapeKind::Struct(StructInfo { fields, .. }) => {
100 let field_encodes: Vec<String> = fields
102 .iter()
103 .map(|f| {
104 let field_name = f.name.to_lower_camel_case();
105 generate_encode_expr(f.shape(), &format!("$0.{field_name}"))
106 })
107 .collect();
108 if field_encodes.is_empty() {
109 "{ _ in [] }".into()
110 } else {
111 format!("{{ {} }}", field_encodes.join(" + "))
112 }
113 }
114 ShapeKind::Enum(EnumInfo {
115 name: Some(_name),
116 variants,
117 ..
118 }) => {
119 let mut code = "{ v in\n switch v {\n".to_string();
121 for (i, v) in variants.iter().enumerate() {
122 let variant_name = v.name.to_lower_camel_case();
123 match classify_variant(v) {
124 VariantKind::Unit => {
125 code.push_str(&format!(
126 " case .{variant_name}:\n return [UInt8({i})]\n"
127 ));
128 }
129 VariantKind::Newtype { inner } => {
130 let inner_encode = generate_encode_expr(inner, "val");
131 code.push_str(&format!(
132 " case .{variant_name}(let val):\n return [UInt8({i})] + {inner_encode}\n"
133 ));
134 }
135 VariantKind::Tuple { fields } => {
136 let bindings: Vec<String> =
137 (0..fields.len()).map(|j| format!("f{j}")).collect();
138 let field_encodes: Vec<String> = fields
139 .iter()
140 .enumerate()
141 .map(|(j, f)| generate_encode_expr(f.shape(), &format!("f{j}")))
142 .collect();
143 code.push_str(&format!(
144 " case .{variant_name}({}):\n return [UInt8({i})] + {}\n",
145 bindings
146 .iter()
147 .map(|b| format!("let {b}"))
148 .collect::<Vec<_>>()
149 .join(", "),
150 field_encodes.join(" + ")
151 ));
152 }
153 VariantKind::Struct { fields } => {
154 let bindings: Vec<String> = fields
155 .iter()
156 .map(|f| f.name.to_lower_camel_case())
157 .collect();
158 let field_encodes: Vec<String> = fields
159 .iter()
160 .map(|f| {
161 let field_name = f.name.to_lower_camel_case();
162 generate_encode_expr(f.shape(), &field_name)
163 })
164 .collect();
165 code.push_str(&format!(
166 " case .{variant_name}({}):\n return [UInt8({i})] + {}\n",
167 bindings
168 .iter()
169 .map(|b| format!("let {b}"))
170 .collect::<Vec<_>>()
171 .join(", "),
172 field_encodes.join(" + ")
173 ));
174 }
175 }
176 }
177 code.push_str(" }\n}");
178 code
179 }
180 ShapeKind::Pointer { pointee } => generate_encode_closure(pointee),
181 ShapeKind::Result { ok, err } => {
182 let ok_encode = generate_encode_closure(ok);
183 let err_encode = generate_encode_closure(err);
184 format!(
185 "{{ switch $0 {{ case .success(let v): return [UInt8(0)] + {ok_encode}(v); case .failure(let e): return [UInt8(1)] + {err_encode}(e) }} }}"
186 )
187 }
188 _ => "{ _ in [] }".into(), }
190}
191
192pub fn swift_encode_fn(scalar: ScalarType) -> &'static str {
194 match scalar {
195 ScalarType::Bool => "encodeBool",
196 ScalarType::U8 => "encodeU8",
197 ScalarType::I8 => "encodeI8",
198 ScalarType::U16 => "encodeU16",
199 ScalarType::I16 => "encodeI16",
200 ScalarType::U32 => "encodeU32",
201 ScalarType::I32 => "encodeI32",
202 ScalarType::U64 | ScalarType::USize => "encodeVarint",
203 ScalarType::I64 | ScalarType::ISize => "encodeI64",
204 ScalarType::F32 => "encodeF32",
205 ScalarType::F64 => "encodeF64",
206 ScalarType::Char | ScalarType::Str | ScalarType::CowStr | ScalarType::String => {
207 "encodeString"
208 }
209 ScalarType::Unit => "{ _ in [] }",
210 _ => "encodeBytes", }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use facet::Facet;

    /// Scalars are encoded by calling their dedicated Swift encoder directly.
    #[test]
    fn test_encode_primitives() {
        let cases = [
            (<bool as Facet>::SHAPE, "encodeBool(x)"),
            (<u32 as Facet>::SHAPE, "encodeU32(x)"),
            (<String as Facet>::SHAPE, "encodeString(x)"),
        ];
        for &(shape, expected) in &cases {
            assert_eq!(generate_encode_expr(shape, "x"), expected);
        }
    }

    /// Non-byte vectors go through `encodeVec` with a per-element closure.
    #[test]
    fn test_encode_vec() {
        let generated = generate_encode_expr(<Vec<i32> as Facet>::SHAPE, "items");
        for needle in ["encodeVec", "encodeI32"] {
            assert!(generated.contains(needle), "missing `{needle}` in {generated}");
        }
    }

    /// Options go through `encodeOption` with a closure for the wrapped value.
    #[test]
    fn test_encode_option() {
        let generated = generate_encode_expr(<Option<String> as Facet>::SHAPE, "val");
        for needle in ["encodeOption", "encodeString"] {
            assert!(generated.contains(needle), "missing `{needle}` in {generated}");
        }
    }

    /// `Vec<u8>` is treated as an opaque byte buffer, not a list of `u8`s.
    #[test]
    fn test_encode_bytes() {
        let generated = generate_encode_expr(<Vec<u8> as Facet>::SHAPE, "data");
        assert_eq!(generated, "encodeBytes(Array(data))");
    }
}