1#![deny(unsafe_code)]
2
3use facet_core::{Def, Facet, ScalarType, Shape, StructKind, Type, UserType};
9use heck::ToKebabCase;
10use roam_types::{ArgDescriptor, MethodDescriptor, MethodId};
11use roam_types::{is_rx, is_tx};
12use std::collections::HashSet;
13
/// Single-byte tags written by the signature encoder.
///
/// The encoded signature bytes are hashed into every `MethodId`, so these values are
/// part of the method-id contract: changing one changes the id of every method whose
/// signature contains that tag.
mod sig {
    // ---- Scalars -------------------------------------------------------------

    /// `bool`.
    pub const BOOL: u8 = 0x01;

    /// Unsigned integers, narrowest first.
    pub const U8: u8 = 0x02;
    pub const U16: u8 = 0x03;
    pub const U32: u8 = 0x04;
    pub const U64: u8 = 0x05;
    pub const U128: u8 = 0x06;

    /// Signed integers, narrowest first.
    pub const I8: u8 = 0x07;
    pub const I16: u8 = 0x08;
    pub const I32: u8 = 0x09;
    pub const I64: u8 = 0x0A;
    pub const I128: u8 = 0x0B;

    /// Floating point.
    pub const F32: u8 = 0x0C;
    pub const F64: u8 = 0x0D;

    /// `char`.
    pub const CHAR: u8 = 0x0E;
    /// Text (string-like scalars).
    pub const STRING: u8 = 0x0F;
    /// Unit, and the fallback tag for anything the encoder does not model.
    pub const UNIT: u8 = 0x10;
    /// A list whose element type is `u8`.
    pub const BYTES: u8 = 0x11;

    // ---- Containers ----------------------------------------------------------

    pub const LIST: u8 = 0x20;
    pub const OPTION: u8 = 0x21;
    pub const ARRAY: u8 = 0x22;
    pub const MAP: u8 = 0x23;
    pub const SET: u8 = 0x24;
    pub const TUPLE: u8 = 0x25;
    /// Sending half of a channel.
    pub const TX: u8 = 0x26;
    /// Receiving half of a channel.
    pub const RX: u8 = 0x27;

    // ---- User-defined / composite --------------------------------------------

    pub const STRUCT: u8 = 0x30;
    pub const ENUM: u8 = 0x31;
    /// Reference to a shape already being encoded higher up (recursive types).
    pub const BACKREF: u8 = 0x32;

    // ---- Enum variant payload kinds --------------------------------------------

    pub const VARIANT_UNIT: u8 = 0x00;
    pub const VARIANT_NEWTYPE: u8 = 0x01;
    pub const VARIANT_STRUCT: u8 = 0x02;
}
55
/// Appends `value` to `out` as an unsigned LEB128 varint.
///
/// Each output byte carries 7 bits of the value, least-significant group first.
/// The high bit (`0x80`) is set on every byte except the last.
fn encode_varint_u64(value: u64, out: &mut Vec<u8>) {
    let mut remaining = value;
    loop {
        let low_bits = (remaining & 0x7F) as u8;
        remaining >>= 7;
        if remaining == 0 {
            out.push(low_bits);
            break;
        }
        out.push(low_bits | 0x80);
    }
}
64
65fn encode_str(s: &str, out: &mut Vec<u8>) {
66 encode_varint_u64(s.len() as u64, out);
67 out.extend_from_slice(s.as_bytes());
68}
69
70fn encode_shape(shape: &'static Shape, out: &mut Vec<u8>) {
79 let mut stack: Vec<&'static Shape> = Vec::new();
80 encode_shape_inner(shape, out, &mut stack);
81}
82
83fn encode_shape_inner(shape: &'static Shape, out: &mut Vec<u8>, stack: &mut Vec<&'static Shape>) {
84 if is_tx(shape) {
86 out.push(sig::TX);
87 if let Some(inner) = shape.type_params.first() {
88 encode_shape_inner(inner.shape, out, stack);
89 }
90 return;
91 }
92 if is_rx(shape) {
93 out.push(sig::RX);
94 if let Some(inner) = shape.type_params.first() {
95 encode_shape_inner(inner.shape, out, stack);
96 }
97 return;
98 }
99
100 if shape.is_transparent()
102 && let Some(inner) = shape.inner
103 {
104 encode_shape_inner(inner, out, stack);
105 return;
106 }
107
108 if let Some(scalar) = shape.scalar_type() {
110 encode_scalar(scalar, out);
111 return;
112 }
113
114 match shape.def {
116 Def::List(list_def) => {
117 if let Some(ScalarType::U8) = list_def.t().scalar_type() {
118 out.push(sig::BYTES);
120 } else {
121 out.push(sig::LIST);
122 encode_shape_inner(list_def.t(), out, stack);
123 }
124 return;
125 }
126 Def::Array(array_def) => {
127 out.push(sig::ARRAY);
128 encode_varint_u64(array_def.n as u64, out);
129 encode_shape_inner(array_def.t(), out, stack);
130 return;
131 }
132 Def::Slice(slice_def) => {
133 out.push(sig::LIST);
134 encode_shape_inner(slice_def.t(), out, stack);
135 return;
136 }
137 Def::Map(map_def) => {
138 out.push(sig::MAP);
139 encode_shape_inner(map_def.k(), out, stack);
140 encode_shape_inner(map_def.v(), out, stack);
141 return;
142 }
143 Def::Set(set_def) => {
144 out.push(sig::SET);
145 encode_shape_inner(set_def.t(), out, stack);
146 return;
147 }
148 Def::Option(opt_def) => {
149 out.push(sig::OPTION);
150 encode_shape_inner(opt_def.t(), out, stack);
151 return;
152 }
153 Def::Pointer(ptr_def) => {
154 if let Some(pointee) = ptr_def.pointee {
155 encode_shape_inner(pointee, out, stack);
156 return;
157 }
158 }
159 _ => {}
160 }
161
162 if let Some(pos) = stack.iter().rposition(|&s| s == shape) {
165 let depth = stack.len() - 1 - pos;
167 out.push(sig::BACKREF);
168 encode_varint_u64(depth as u64, out);
169 return;
170 }
171
172 stack.push(shape);
174
175 match shape.ty {
176 Type::User(UserType::Struct(struct_type)) => match struct_type.kind {
177 StructKind::Unit => {
178 out.push(sig::UNIT);
179 }
180 StructKind::TupleStruct | StructKind::Tuple => {
181 out.push(sig::TUPLE);
182 encode_varint_u64(struct_type.fields.len() as u64, out);
183 for field in struct_type.fields {
184 encode_shape_inner(field.shape(), out, stack);
185 }
186 }
187 StructKind::Struct => {
188 out.push(sig::STRUCT);
189 encode_varint_u64(struct_type.fields.len() as u64, out);
190 for field in struct_type.fields {
191 encode_str(field.name, out);
192 encode_shape_inner(field.shape(), out, stack);
193 }
194 }
195 },
196 Type::User(UserType::Enum(enum_type)) => {
197 out.push(sig::ENUM);
198 encode_varint_u64(enum_type.variants.len() as u64, out);
199 for variant in enum_type.variants {
200 encode_str(variant.name, out);
201 match variant.data.kind {
202 StructKind::Unit => {
203 out.push(sig::VARIANT_UNIT);
204 }
205 StructKind::TupleStruct | StructKind::Tuple => {
206 if variant.data.fields.len() == 1 {
207 out.push(sig::VARIANT_NEWTYPE);
208 encode_shape_inner(variant.data.fields[0].shape(), out, stack);
209 } else {
210 out.push(sig::VARIANT_STRUCT);
211 encode_varint_u64(variant.data.fields.len() as u64, out);
212 for (i, field) in variant.data.fields.iter().enumerate() {
213 encode_str(&i.to_string(), out);
214 encode_shape_inner(field.shape(), out, stack);
215 }
216 }
217 }
218 StructKind::Struct => {
219 out.push(sig::VARIANT_STRUCT);
220 encode_varint_u64(variant.data.fields.len() as u64, out);
221 for field in variant.data.fields {
222 encode_str(field.name, out);
223 encode_shape_inner(field.shape(), out, stack);
224 }
225 }
226 }
227 }
228 }
229 Type::Pointer(_) => {
230 if let Some(inner) = shape.type_params.first() {
231 encode_shape_inner(inner.shape, out, stack);
232 } else {
233 out.push(sig::UNIT);
234 }
235 }
236 _ => {
237 out.push(sig::UNIT);
238 }
239 }
240
241 stack.pop();
242}
243
244fn encode_scalar(scalar: ScalarType, out: &mut Vec<u8>) {
245 match scalar {
246 ScalarType::Unit => out.push(sig::UNIT),
247 ScalarType::Bool => out.push(sig::BOOL),
248 ScalarType::Char => out.push(sig::CHAR),
249 ScalarType::Str | ScalarType::String | ScalarType::CowStr => out.push(sig::STRING),
250 ScalarType::F32 => out.push(sig::F32),
251 ScalarType::F64 => out.push(sig::F64),
252 ScalarType::U8 => out.push(sig::U8),
253 ScalarType::U16 => out.push(sig::U16),
254 ScalarType::U32 => out.push(sig::U32),
255 ScalarType::U64 => out.push(sig::U64),
256 ScalarType::U128 => out.push(sig::U128),
257 ScalarType::USize => out.push(sig::U64), ScalarType::I8 => out.push(sig::I8),
259 ScalarType::I16 => out.push(sig::I16),
260 ScalarType::I32 => out.push(sig::I32),
261 ScalarType::I64 => out.push(sig::I64),
262 ScalarType::I128 => out.push(sig::I128),
263 ScalarType::ISize => out.push(sig::I64), ScalarType::ConstTypeId => out.push(sig::U64),
265 _ => out.push(sig::UNIT),
266 }
267}
268
269fn encode_method_signature(args: &'static Shape, return_type: &'static Shape, out: &mut Vec<u8>) {
274 encode_shape(args, out);
275 encode_shape(return_type, out);
276}
277
278pub fn method_id<'a, 'r, A: Facet<'a>, R: Facet<'r>>(
287 service_name: &str,
288 method_name: &str,
289) -> MethodId {
290 let mut sig_bytes = Vec::new();
291 encode_method_signature(A::SHAPE, R::SHAPE, &mut sig_bytes);
292 let sig_hash = blake3::hash(&sig_bytes);
293
294 let mut input = Vec::new();
295 input.extend_from_slice(service_name.to_kebab_case().as_bytes());
296 input.push(b'.');
297 input.extend_from_slice(method_name.to_kebab_case().as_bytes());
298 input.extend_from_slice(sig_hash.as_bytes());
299 let h = blake3::hash(&input);
300 let first8: [u8; 8] = h.as_bytes()[0..8].try_into().expect("slice len");
301 MethodId(u64::from_le_bytes(first8))
302}
303
304pub fn method_descriptor<'a, 'r, A: Facet<'a>, R: Facet<'r>>(
309 service_name: &'static str,
310 method_name: &'static str,
311 arg_names: &[&'static str],
312 doc: Option<&'static str>,
313) -> &'static MethodDescriptor {
314 assert!(
315 !shape_contains_channel(R::SHAPE),
316 "channels are not allowed in return types: {service_name}.{method_name}"
317 );
318
319 let id = method_id::<A, R>(service_name, method_name);
320
321 let arg_shapes: &[&'static Shape] = match A::SHAPE.ty {
323 Type::User(UserType::Struct(s)) => {
324 let fields: Vec<&'static Shape> = s.fields.iter().map(|f| f.shape()).collect();
325 Box::leak(fields.into_boxed_slice())
326 }
327 _ => &[],
328 };
329
330 assert_eq!(
331 arg_names.len(),
332 arg_shapes.len(),
333 "arg_names length mismatch for {service_name}.{method_name}"
334 );
335
336 let args: &'static [ArgDescriptor] = Box::leak(
337 arg_names
338 .iter()
339 .zip(arg_shapes.iter())
340 .map(|(&name, &shape)| ArgDescriptor { name, shape })
341 .collect::<Vec<_>>()
342 .into_boxed_slice(),
343 );
344
345 Box::leak(Box::new(MethodDescriptor {
346 id,
347 service_name,
348 method_name,
349 args,
350 return_shape: R::SHAPE,
351 doc,
352 }))
353}
354
355fn shape_contains_channel(shape: &'static Shape) -> bool {
356 fn visit(shape: &'static Shape, seen: &mut HashSet<usize>) -> bool {
357 if is_tx(shape) || is_rx(shape) {
358 return true;
359 }
360
361 let key = shape as *const Shape as usize;
362 if !seen.insert(key) {
363 return false;
364 }
365
366 if let Some(inner) = shape.inner
367 && visit(inner, seen)
368 {
369 return true;
370 }
371
372 if shape.type_params.iter().any(|t| visit(t.shape, seen)) {
373 return true;
374 }
375
376 match shape.def {
377 Def::List(list_def) => visit(list_def.t(), seen),
378 Def::Array(array_def) => visit(array_def.t(), seen),
379 Def::Slice(slice_def) => visit(slice_def.t(), seen),
380 Def::Map(map_def) => visit(map_def.k(), seen) || visit(map_def.v(), seen),
381 Def::Set(set_def) => visit(set_def.t(), seen),
382 Def::Option(opt_def) => visit(opt_def.t(), seen),
383 Def::Result(result_def) => visit(result_def.t(), seen) || visit(result_def.e(), seen),
384 Def::Pointer(ptr_def) => ptr_def.pointee.is_some_and(|p| visit(p, seen)),
385 _ => match shape.ty {
386 Type::User(UserType::Struct(s)) => s.fields.iter().any(|f| visit(f.shape(), seen)),
387 Type::User(UserType::Enum(e)) => e
388 .variants
389 .iter()
390 .any(|v| v.data.fields.iter().any(|f| visit(f.shape(), seen))),
391 _ => false,
392 },
393 }
394 }
395
396 let mut seen = HashSet::new();
397 visit(shape, &mut seen)
398}
399
#[cfg(test)]
mod tests {
    use super::*;
    use facet::Facet;
    use roam_types::{Rx, Tx};

    #[derive(Facet)]
    struct PlainRet {
        value: u64,
    }

    #[derive(Facet)]
    struct NestedRet {
        nested: Option<Result<Rx<u8>, u32>>,
    }

    #[test]
    fn allows_non_channel_return_types() {
        let _ = method_descriptor::<(), PlainRet>("TestSvc", "plain", &[], None);
    }

    #[test]
    #[should_panic(expected = "channels are not allowed in return types: TestSvc.nested")]
    fn rejects_nested_channel_in_return_types() {
        let _ = method_descriptor::<(Tx<u8>,), NestedRet>("TestSvc", "nested", &["input"], None);
    }

    #[test]
    fn encode_varint_encodes_expected_boundaries() {
        let mut out = Vec::new();
        encode_varint_u64(0, &mut out);
        assert_eq!(out, vec![0x00]);

        out.clear();
        encode_varint_u64(127, &mut out);
        assert_eq!(out, vec![0x7F]);

        out.clear();
        encode_varint_u64(128, &mut out);
        assert_eq!(out, vec![0x80, 0x01]);

        out.clear();
        encode_varint_u64(300, &mut out);
        assert_eq!(out, vec![0xAC, 0x02]);
    }

    #[test]
    fn encode_varint_encodes_u64_max_in_ten_bytes() {
        let mut out = Vec::new();
        encode_varint_u64(u64::MAX, &mut out);
        let mut expected = vec![0xFF; 9];
        expected.push(0x01);
        assert_eq!(out, expected);
    }

    #[test]
    fn encode_str_prefixes_utf8_byte_length() {
        let text = "héllo";
        let mut out = Vec::new();
        encode_str(text, &mut out);
        // "é" is two bytes in UTF-8: the prefix must be the byte length (6),
        // not the char count (5).
        let mut expected = vec![6];
        expected.extend_from_slice(text.as_bytes());
        assert_eq!(out, expected);
    }

    #[test]
    fn method_id_is_stable_and_uses_kebab_case_names() {
        let a = method_id::<(u32,), u64>("MyService", "DoThingFast");
        let b = method_id::<(u32,), u64>("my-service", "do-thing-fast");
        let c = method_id::<(u32,), u64>("MY_SERVICE", "DO_THING_FAST");
        assert_eq!(a, b);
        assert_eq!(b, c);
    }

    #[test]
    fn method_id_changes_when_signature_changes() {
        let a = method_id::<(u32,), u64>("svc", "m");
        let b = method_id::<(u64,), u64>("svc", "m");
        let c = method_id::<(u32,), u32>("svc", "m");
        assert_ne!(a, b);
        assert_ne!(a, c);
    }

    #[test]
    fn method_id_changes_when_names_change() {
        let base = method_id::<(u32,), u64>("svc", "m");
        assert_ne!(base, method_id::<(u32,), u64>("svc", "other"));
        assert_ne!(base, method_id::<(u32,), u64>("other-svc", "m"));
    }

    #[test]
    fn method_descriptor_populates_args_and_doc() {
        let descriptor = method_descriptor::<(u32, String), PlainRet>(
            "Svc",
            "do_it",
            &["count", "name"],
            Some("doc"),
        );
        assert_eq!(descriptor.service_name, "Svc");
        assert_eq!(descriptor.method_name, "do_it");
        assert_eq!(descriptor.args.len(), 2);
        assert_eq!(descriptor.args[0].name, "count");
        assert_eq!(descriptor.args[1].name, "name");
        assert_eq!(descriptor.doc, Some("doc"));
        assert_eq!(
            descriptor.id,
            method_id::<(u32, String), PlainRet>("Svc", "do_it")
        );
    }

    #[test]
    #[should_panic(expected = "arg_names length mismatch for Svc.bad")]
    fn method_descriptor_panics_when_arg_names_length_mismatches_shape() {
        let _ = method_descriptor::<(u32, u64), PlainRet>("Svc", "bad", &["only_one"], None);
    }

    #[test]
    fn list_of_u8_uses_bytes_tag_while_other_lists_do_not() {
        let mut vec_u8_sig = Vec::new();
        encode_shape(<Vec<u8> as Facet>::SHAPE, &mut vec_u8_sig);
        assert_eq!(vec_u8_sig, vec![sig::BYTES]);

        let mut vec_u16_sig = Vec::new();
        encode_shape(<Vec<u16> as Facet>::SHAPE, &mut vec_u16_sig);

        assert_ne!(vec_u8_sig, vec_u16_sig);
        assert_eq!(vec_u16_sig[0], sig::LIST);
    }

    #[test]
    fn shape_contains_channel_handles_recursive_and_non_recursive_shapes() {
        #[derive(Facet)]
        struct Recursive {
            next: Option<Box<Recursive>>,
        }

        #[derive(Facet)]
        struct ChannelNested {
            inner: Option<Result<Tx<u16>, u8>>,
        }

        assert!(!shape_contains_channel(Recursive::SHAPE));
        assert!(shape_contains_channel(ChannelNested::SHAPE));
    }

    #[test]
    fn encode_shape_emits_expected_scalar_and_container_tags() {
        fn head(shape: &'static facet_core::Shape) -> u8 {
            let mut out = Vec::new();
            encode_shape(shape, &mut out);
            out[0]
        }

        assert_eq!(head(<bool as Facet>::SHAPE), sig::BOOL);
        assert_eq!(head(<u64 as Facet>::SHAPE), sig::U64);
        assert_eq!(head(<i32 as Facet>::SHAPE), sig::I32);
        assert_eq!(head(<String as Facet>::SHAPE), sig::STRING);
        assert_eq!(head(<Option<u8> as Facet>::SHAPE), sig::OPTION);
        assert_eq!(head(<Vec<u16> as Facet>::SHAPE), sig::LIST);
        assert_eq!(head(<[u16; 4] as Facet>::SHAPE), sig::ARRAY);
        assert_eq!(
            head(<std::collections::BTreeMap<u8, u16> as Facet>::SHAPE),
            sig::MAP
        );
        assert_eq!(
            head(<std::collections::BTreeSet<u8> as Facet>::SHAPE),
            sig::SET
        );
        assert_eq!(head(<(u8, u16) as Facet>::SHAPE), sig::TUPLE);
    }

    #[test]
    fn encode_shape_marks_recursive_types_with_backref() {
        #[derive(Facet)]
        struct Node {
            next: Option<Box<Node>>,
        }

        let mut out = Vec::new();
        encode_shape(Node::SHAPE, &mut out);
        assert!(
            out.contains(&sig::BACKREF),
            "recursive encoding should contain BACKREF marker"
        );
    }
}