Skip to main content

roam_hash/
lib.rs

1#![deny(unsafe_code)]
2
3//! Hashing and method identity for roam.
4//!
5//! Encodes types using `facet::Shape` for signature hashing, following
6//! `docs/content/spec-sig.md`.
7
8use facet_core::{Def, Facet, ScalarType, Shape, StructKind, Type, UserType};
9use heck::ToKebabCase;
10use roam_types::{ArgDescriptor, MethodDescriptor, MethodId};
11use roam_types::{is_rx, is_tx};
12use std::collections::HashSet;
13
/// Signature encoding tags for type serialization.
///
/// Every encoded type starts with one of these tag bytes; some tags are
/// followed by further data, noted per constant below. These bytes are hashed
/// into method IDs, so changing any value changes every method ID whose
/// signature uses it.
mod sig {
    // Primitives (0x01-0x11)
    pub const BOOL: u8 = 0x01;
    pub const U8: u8 = 0x02;
    pub const U16: u8 = 0x03;
    pub const U32: u8 = 0x04;
    // Also emitted for `usize` and `ConstTypeId` scalars.
    pub const U64: u8 = 0x05;
    pub const U128: u8 = 0x06;
    pub const I8: u8 = 0x07;
    pub const I16: u8 = 0x08;
    pub const I32: u8 = 0x09;
    // Also emitted for `isize` scalars.
    pub const I64: u8 = 0x0A;
    pub const I128: u8 = 0x0B;
    pub const F32: u8 = 0x0C;
    pub const F64: u8 = 0x0D;
    pub const CHAR: u8 = 0x0E;
    // Shared by `&str`, `String` and `Cow<str>` scalars.
    pub const STRING: u8 = 0x0F;
    // `()`, unit structs, and the fallback for shapes with no dedicated encoding.
    pub const UNIT: u8 = 0x10;
    // Lists whose element scalar is `u8`; no element shape follows.
    pub const BYTES: u8 = 0x11;

    // Containers (0x20-0x27)
    // Followed by the element shape (also used for slices).
    pub const LIST: u8 = 0x20;
    // Followed by the inner shape.
    pub const OPTION: u8 = 0x21;
    // Followed by the varint length, then the element shape.
    pub const ARRAY: u8 = 0x22;
    // Followed by the key shape, then the value shape.
    pub const MAP: u8 = 0x23;
    // Followed by the element shape.
    pub const SET: u8 = 0x24;
    // Followed by the varint field count, then each field's shape (tuples and tuple structs).
    pub const TUPLE: u8 = 0x25;
    // Followed by the element shape (first type parameter) when one is present.
    pub const TX: u8 = 0x26;
    pub const RX: u8 = 0x27;

    // Composite (0x30-0x32)
    // Followed by the varint field count, then (length-prefixed name, shape) per field.
    pub const STRUCT: u8 = 0x30;
    // Followed by the varint variant count, then per variant: name, a VARIANT_* byte, payload.
    pub const ENUM: u8 = 0x31;
    // Followed by a varint depth; 0 refers to the innermost enclosing type being encoded.
    pub const BACKREF: u8 = 0x32;

    // Variant payloads (the byte written after each variant's name inside ENUM)
    pub const VARIANT_UNIT: u8 = 0x00;
    // Followed by the single field's shape.
    pub const VARIANT_NEWTYPE: u8 = 0x01;
    // Followed by the varint field count, then (name, shape) per field; tuple
    // variants with != 1 fields use positional names "0", "1", ...
    pub const VARIANT_STRUCT: u8 = 0x02;
}
55
/// Append `value` to `out` as an unsigned LEB128 varint.
///
/// Each byte carries seven payload bits, least-significant group first. The
/// high bit is set on every byte except the final one.
// r[impl signature.varint]
fn encode_varint_u64(value: u64, out: &mut Vec<u8>) {
    let mut rest = value;
    loop {
        let group = (rest & 0x7F) as u8;
        rest >>= 7;
        if rest == 0 {
            // Last group: continuation bit clear.
            out.push(group);
            return;
        }
        out.push(group | 0x80);
    }
}
64
65fn encode_str(s: &str, out: &mut Vec<u8>) {
66    encode_varint_u64(s.len() as u64, out);
67    out.extend_from_slice(s.as_bytes());
68}
69
70/// Encode a `Shape` into its canonical signature byte representation.
71// r[impl signature.primitive]
72// r[impl signature.container]
73// r[impl signature.struct]
74// r[impl signature.enum]
75// r[impl signature.recursive]
76// r[impl signature.recursive.encoding]
77// r[impl signature.recursive.stack]
78fn encode_shape(shape: &'static Shape, out: &mut Vec<u8>) {
79    let mut stack: Vec<&'static Shape> = Vec::new();
80    encode_shape_inner(shape, out, &mut stack);
81}
82
/// Recursive worker for `encode_shape`: appends the signature bytes for
/// `shape` to `out`.
///
/// The checks run in a fixed order and the first match wins:
/// 1. channel types (`Tx`/`Rx`);
/// 2. transparent wrappers;
/// 3. scalars;
/// 4. container `Def`s (list, array, slice, map, set, option, pointer);
/// 5. the remaining user-defined struct/enum/pointer handling.
///
/// This output is hashed by `method_id`, so changing the order or any emitted
/// byte changes method IDs.
///
/// `stack` holds the shapes currently being encoded, outermost first. Only
/// shapes that fall through steps 1-4 are pushed. Wrappers and containers
/// never appear on it, so a back-reference depth counts only those enclosing
/// shapes.
fn encode_shape_inner(shape: &'static Shape, out: &mut Vec<u8>, stack: &mut Vec<&'static Shape>) {
    // Channel types: the TX/RX tag, then the element type (first type
    // parameter) if the shape declares one.
    if is_tx(shape) {
        out.push(sig::TX);
        if let Some(inner) = shape.type_params.first() {
            encode_shape_inner(inner.shape, out, stack);
        }
        return;
    }
    if is_rx(shape) {
        out.push(sig::RX);
        if let Some(inner) = shape.type_params.first() {
            encode_shape_inner(inner.shape, out, stack);
        }
        return;
    }

    // Transparent wrappers: encoded exactly as the wrapped shape, adding no
    // bytes of their own.
    if shape.is_transparent()
        && let Some(inner) = shape.inner
    {
        encode_shape_inner(inner, out, stack);
        return;
    }

    // Scalars
    if let Some(scalar) = shape.scalar_type() {
        encode_scalar(scalar, out);
        return;
    }

    // Semantic definitions
    match shape.def {
        Def::List(list_def) => {
            // A list of `u8` collapses to the single BYTES tag (no element
            // shape); every other list is LIST + element shape.
            if let Some(ScalarType::U8) = list_def.t().scalar_type() {
                // r[impl signature.bytes.equivalence]
                out.push(sig::BYTES);
            } else {
                out.push(sig::LIST);
                encode_shape_inner(list_def.t(), out, stack);
            }
            return;
        }
        Def::Array(array_def) => {
            // The fixed length is part of the signature.
            out.push(sig::ARRAY);
            encode_varint_u64(array_def.n as u64, out);
            encode_shape_inner(array_def.t(), out, stack);
            return;
        }
        Def::Slice(slice_def) => {
            // NOTE(review): unlike Def::List above, a slice whose element is
            // `u8` is NOT collapsed to BYTES here; confirm this matches
            // r[signature.bytes.equivalence] in the spec.
            out.push(sig::LIST);
            encode_shape_inner(slice_def.t(), out, stack);
            return;
        }
        Def::Map(map_def) => {
            out.push(sig::MAP);
            encode_shape_inner(map_def.k(), out, stack);
            encode_shape_inner(map_def.v(), out, stack);
            return;
        }
        Def::Set(set_def) => {
            out.push(sig::SET);
            encode_shape_inner(set_def.t(), out, stack);
            return;
        }
        Def::Option(opt_def) => {
            out.push(sig::OPTION);
            encode_shape_inner(opt_def.t(), out, stack);
            return;
        }
        Def::Pointer(ptr_def) => {
            // Pointers with a known pointee encode as the pointee. Without one,
            // fall through to the cycle check and `shape.ty` match below.
            if let Some(pointee) = ptr_def.pointee {
                encode_shape_inner(pointee, out, stack);
                return;
            }
        }
        // NOTE(review): `Def::Result` has no case here (shape_contains_channel
        // does handle it), so Result shapes fall through to the `shape.ty`
        // match below; confirm that yields the intended encoding.
        _ => {}
    }

    // Cycle detection for user-defined types: check if this shape is
    // already on the encoding stack (indicates recursion).
    if let Some(pos) = stack.iter().rposition(|&s| s == shape) {
        // Depth = distance from top of stack (0 = immediate parent)
        let depth = stack.len() - 1 - pos;
        out.push(sig::BACKREF);
        encode_varint_u64(depth as u64, out);
        return;
    }

    // Push onto stack before encoding children, pop after.
    stack.push(shape);

    match shape.ty {
        Type::User(UserType::Struct(struct_type)) => match struct_type.kind {
            // Unit structs share the UNIT tag with `()`.
            StructKind::Unit => {
                out.push(sig::UNIT);
            }
            // Tuples and tuple structs: positional field shapes, no names.
            StructKind::TupleStruct | StructKind::Tuple => {
                out.push(sig::TUPLE);
                encode_varint_u64(struct_type.fields.len() as u64, out);
                for field in struct_type.fields {
                    encode_shape_inner(field.shape(), out, stack);
                }
            }
            // Named structs: field names are part of the signature.
            StructKind::Struct => {
                out.push(sig::STRUCT);
                encode_varint_u64(struct_type.fields.len() as u64, out);
                for field in struct_type.fields {
                    encode_str(field.name, out);
                    encode_shape_inner(field.shape(), out, stack);
                }
            }
        },
        Type::User(UserType::Enum(enum_type)) => {
            // Variant count, then per variant: name + payload-kind byte + payload.
            out.push(sig::ENUM);
            encode_varint_u64(enum_type.variants.len() as u64, out);
            for variant in enum_type.variants {
                encode_str(variant.name, out);
                match variant.data.kind {
                    StructKind::Unit => {
                        out.push(sig::VARIANT_UNIT);
                    }
                    StructKind::TupleStruct | StructKind::Tuple => {
                        if variant.data.fields.len() == 1 {
                            // Single-field tuple variant: newtype payload.
                            out.push(sig::VARIANT_NEWTYPE);
                            encode_shape_inner(variant.data.fields[0].shape(), out, stack);
                        } else {
                            // Any other arity (including zero) is encoded like a
                            // struct variant with positional names "0", "1", ...
                            out.push(sig::VARIANT_STRUCT);
                            encode_varint_u64(variant.data.fields.len() as u64, out);
                            for (i, field) in variant.data.fields.iter().enumerate() {
                                encode_str(&i.to_string(), out);
                                encode_shape_inner(field.shape(), out, stack);
                            }
                        }
                    }
                    StructKind::Struct => {
                        out.push(sig::VARIANT_STRUCT);
                        encode_varint_u64(variant.data.fields.len() as u64, out);
                        for field in variant.data.fields {
                            encode_str(field.name, out);
                            encode_shape_inner(field.shape(), out, stack);
                        }
                    }
                }
            }
        }
        Type::Pointer(_) => {
            // Pointer types without a pointee Def: encode the first type
            // parameter, or UNIT if there is none.
            if let Some(inner) = shape.type_params.first() {
                encode_shape_inner(inner.shape, out, stack);
            } else {
                out.push(sig::UNIT);
            }
        }
        // Anything else falls back to UNIT.
        _ => {
            out.push(sig::UNIT);
        }
    }

    stack.pop();
}
243
244fn encode_scalar(scalar: ScalarType, out: &mut Vec<u8>) {
245    match scalar {
246        ScalarType::Unit => out.push(sig::UNIT),
247        ScalarType::Bool => out.push(sig::BOOL),
248        ScalarType::Char => out.push(sig::CHAR),
249        ScalarType::Str | ScalarType::String | ScalarType::CowStr => out.push(sig::STRING),
250        ScalarType::F32 => out.push(sig::F32),
251        ScalarType::F64 => out.push(sig::F64),
252        ScalarType::U8 => out.push(sig::U8),
253        ScalarType::U16 => out.push(sig::U16),
254        ScalarType::U32 => out.push(sig::U32),
255        ScalarType::U64 => out.push(sig::U64),
256        ScalarType::U128 => out.push(sig::U128),
257        ScalarType::USize => out.push(sig::U64), // portable: usize → u64
258        ScalarType::I8 => out.push(sig::I8),
259        ScalarType::I16 => out.push(sig::I16),
260        ScalarType::I32 => out.push(sig::I32),
261        ScalarType::I64 => out.push(sig::I64),
262        ScalarType::I128 => out.push(sig::I128),
263        ScalarType::ISize => out.push(sig::I64), // portable: isize → i64
264        ScalarType::ConstTypeId => out.push(sig::U64),
265        _ => out.push(sig::UNIT),
266    }
267}
268
269/// Encode a method signature: args tuple type followed by return type.
270// r[impl rpc.schema-evolution]
271// r[impl signature.method]
272// r[impl signature.hash.algorithm]
273fn encode_method_signature(args: &'static Shape, return_type: &'static Shape, out: &mut Vec<u8>) {
274    encode_shape(args, out);
275    encode_shape(return_type, out);
276}
277
278/// Compute the final method ID from type parameters.
279///
280/// `A` is the args tuple type (e.g. `(f64, f64)`), `R` is the return type.
281// r[impl rpc.method-id]
282// r[impl rpc.method-id.algorithm]
283// r[impl rpc.method-id.no-collisions]
284// r[impl method.identity.computation]
285// r[impl signature.endianness]
286pub fn method_id<'a, 'r, A: Facet<'a>, R: Facet<'r>>(
287    service_name: &str,
288    method_name: &str,
289) -> MethodId {
290    let mut sig_bytes = Vec::new();
291    encode_method_signature(A::SHAPE, R::SHAPE, &mut sig_bytes);
292    let sig_hash = blake3::hash(&sig_bytes);
293
294    let mut input = Vec::new();
295    input.extend_from_slice(service_name.to_kebab_case().as_bytes());
296    input.push(b'.');
297    input.extend_from_slice(method_name.to_kebab_case().as_bytes());
298    input.extend_from_slice(sig_hash.as_bytes());
299    let h = blake3::hash(&input);
300    let first8: [u8; 8] = h.as_bytes()[0..8].try_into().expect("slice len");
301    MethodId(u64::from_le_bytes(first8))
302}
303
304/// Build and leak a `MethodDescriptor` from type parameters and arg names.
305///
306/// Called once per method inside a `OnceLock::get_or_init` in macro-generated code.
307/// `A` is the args tuple type, `R` is the return type.
308pub fn method_descriptor<'a, 'r, A: Facet<'a>, R: Facet<'r>>(
309    service_name: &'static str,
310    method_name: &'static str,
311    arg_names: &[&'static str],
312    doc: Option<&'static str>,
313) -> &'static MethodDescriptor {
314    assert!(
315        !shape_contains_channel(R::SHAPE),
316        "channels are not allowed in return types: {service_name}.{method_name}"
317    );
318
319    let id = method_id::<A, R>(service_name, method_name);
320
321    // Extract per-arg shapes from the tuple fields of A::SHAPE.
322    let arg_shapes: &[&'static Shape] = match A::SHAPE.ty {
323        Type::User(UserType::Struct(s)) => {
324            let fields: Vec<&'static Shape> = s.fields.iter().map(|f| f.shape()).collect();
325            Box::leak(fields.into_boxed_slice())
326        }
327        _ => &[],
328    };
329
330    assert_eq!(
331        arg_names.len(),
332        arg_shapes.len(),
333        "arg_names length mismatch for {service_name}.{method_name}"
334    );
335
336    let args: &'static [ArgDescriptor] = Box::leak(
337        arg_names
338            .iter()
339            .zip(arg_shapes.iter())
340            .map(|(&name, &shape)| ArgDescriptor { name, shape })
341            .collect::<Vec<_>>()
342            .into_boxed_slice(),
343    );
344
345    Box::leak(Box::new(MethodDescriptor {
346        id,
347        service_name,
348        method_name,
349        args,
350        return_shape: R::SHAPE,
351        doc,
352    }))
353}
354
355fn shape_contains_channel(shape: &'static Shape) -> bool {
356    fn visit(shape: &'static Shape, seen: &mut HashSet<usize>) -> bool {
357        if is_tx(shape) || is_rx(shape) {
358            return true;
359        }
360
361        let key = shape as *const Shape as usize;
362        if !seen.insert(key) {
363            return false;
364        }
365
366        if let Some(inner) = shape.inner
367            && visit(inner, seen)
368        {
369            return true;
370        }
371
372        if shape.type_params.iter().any(|t| visit(t.shape, seen)) {
373            return true;
374        }
375
376        match shape.def {
377            Def::List(list_def) => visit(list_def.t(), seen),
378            Def::Array(array_def) => visit(array_def.t(), seen),
379            Def::Slice(slice_def) => visit(slice_def.t(), seen),
380            Def::Map(map_def) => visit(map_def.k(), seen) || visit(map_def.v(), seen),
381            Def::Set(set_def) => visit(set_def.t(), seen),
382            Def::Option(opt_def) => visit(opt_def.t(), seen),
383            Def::Result(result_def) => visit(result_def.t(), seen) || visit(result_def.e(), seen),
384            Def::Pointer(ptr_def) => ptr_def.pointee.is_some_and(|p| visit(p, seen)),
385            _ => match shape.ty {
386                Type::User(UserType::Struct(s)) => s.fields.iter().any(|f| visit(f.shape(), seen)),
387                Type::User(UserType::Enum(e)) => e
388                    .variants
389                    .iter()
390                    .any(|v| v.data.fields.iter().any(|f| visit(f.shape(), seen))),
391                _ => false,
392            },
393        }
394    }
395
396    let mut seen = HashSet::new();
397    visit(shape, &mut seen)
398}
399
#[cfg(test)]
mod tests {
    use super::*;
    use facet::Facet;
    use roam_types::{Rx, Tx};

    #[derive(Facet)]
    struct PlainRet {
        value: u64,
    }

    #[derive(Facet)]
    struct NestedRet {
        nested: Option<Result<Rx<u8>, u32>>,
    }

    #[test]
    fn allows_non_channel_return_types() {
        let _ = method_descriptor::<(), PlainRet>("TestSvc", "plain", &[], None);
    }

    #[test]
    #[should_panic(expected = "channels are not allowed in return types: TestSvc.nested")]
    fn rejects_nested_channel_in_return_types() {
        let _ = method_descriptor::<(Tx<u8>,), NestedRet>("TestSvc", "nested", &["input"], None);
    }

    #[test]
    fn encode_varint_encodes_expected_boundaries() {
        let mut out = Vec::new();
        encode_varint_u64(0, &mut out);
        assert_eq!(out, vec![0x00]);

        out.clear();
        encode_varint_u64(127, &mut out);
        assert_eq!(out, vec![0x7F]);

        out.clear();
        encode_varint_u64(128, &mut out);
        assert_eq!(out, vec![0x80, 0x01]);

        out.clear();
        encode_varint_u64(300, &mut out);
        assert_eq!(out, vec![0xAC, 0x02]);
    }

    #[test]
    fn encode_str_prefixes_utf8_byte_length() {
        let mut out = Vec::new();
        encode_str("", &mut out);
        assert_eq!(out, vec![0x00]);

        // The prefix counts UTF-8 bytes, not chars: "é" is two bytes.
        out.clear();
        encode_str("héllo", &mut out);
        let mut expected = vec![6u8];
        expected.extend_from_slice("héllo".as_bytes());
        assert_eq!(out, expected);
    }

    #[test]
    fn method_id_is_stable_and_uses_kebab_case_names() {
        let a = method_id::<(u32,), u64>("MyService", "DoThingFast");
        let b = method_id::<(u32,), u64>("my-service", "do-thing-fast");
        let c = method_id::<(u32,), u64>("MY_SERVICE", "DO_THING_FAST");
        assert_eq!(a, b);
        assert_eq!(b, c);
    }

    #[test]
    fn method_id_changes_when_signature_changes() {
        let a = method_id::<(u32,), u64>("svc", "m");
        let b = method_id::<(u64,), u64>("svc", "m");
        let c = method_id::<(u32,), u32>("svc", "m");
        assert_ne!(a, b);
        assert_ne!(a, c);
    }

    #[test]
    fn method_descriptor_populates_args_and_doc() {
        let descriptor = method_descriptor::<(u32, String), PlainRet>(
            "Svc",
            "do_it",
            &["count", "name"],
            Some("doc"),
        );
        assert_eq!(descriptor.service_name, "Svc");
        assert_eq!(descriptor.method_name, "do_it");
        assert_eq!(descriptor.args.len(), 2);
        assert_eq!(descriptor.args[0].name, "count");
        assert_eq!(descriptor.args[1].name, "name");
        assert_eq!(descriptor.doc, Some("doc"));
    }

    #[test]
    #[should_panic(expected = "arg_names length mismatch for Svc.bad")]
    fn method_descriptor_panics_when_arg_names_length_mismatches_shape() {
        let _ = method_descriptor::<(u32, u64), PlainRet>("Svc", "bad", &["only_one"], None);
    }

    #[test]
    fn list_of_u8_uses_bytes_tag_while_other_lists_do_not() {
        let mut vec_u8_sig = Vec::new();
        encode_shape(<Vec<u8> as Facet>::SHAPE, &mut vec_u8_sig);
        assert_eq!(vec_u8_sig, vec![sig::BYTES]);

        let mut vec_u16_sig = Vec::new();
        encode_shape(<Vec<u16> as Facet>::SHAPE, &mut vec_u16_sig);

        assert_ne!(vec_u8_sig, vec_u16_sig);
        assert_eq!(vec_u16_sig[0], sig::LIST);
    }

    #[test]
    fn shape_contains_channel_handles_recursive_and_non_recursive_shapes() {
        #[derive(Facet)]
        struct Recursive {
            next: Option<Box<Recursive>>,
        }

        #[derive(Facet)]
        struct ChannelNested {
            inner: Option<Result<Tx<u16>, u8>>,
        }

        assert!(!shape_contains_channel(Recursive::SHAPE));
        assert!(shape_contains_channel(ChannelNested::SHAPE));
    }

    #[test]
    fn encode_shape_emits_expected_scalar_and_container_tags() {
        fn head(shape: &'static facet_core::Shape) -> u8 {
            let mut out = Vec::new();
            encode_shape(shape, &mut out);
            out[0]
        }

        assert_eq!(head(<bool as Facet>::SHAPE), sig::BOOL);
        assert_eq!(head(<u64 as Facet>::SHAPE), sig::U64);
        assert_eq!(head(<i32 as Facet>::SHAPE), sig::I32);
        assert_eq!(head(<String as Facet>::SHAPE), sig::STRING);
        assert_eq!(head(<Option<u8> as Facet>::SHAPE), sig::OPTION);
        assert_eq!(head(<Vec<u16> as Facet>::SHAPE), sig::LIST);
        assert_eq!(head(<[u16; 4] as Facet>::SHAPE), sig::ARRAY);
        assert_eq!(
            head(<std::collections::BTreeMap<u8, u16> as Facet>::SHAPE),
            sig::MAP
        );
        assert_eq!(
            head(<std::collections::BTreeSet<u8> as Facet>::SHAPE),
            sig::SET
        );
        assert_eq!(head(<(u8, u16) as Facet>::SHAPE), sig::TUPLE);
    }

    #[test]
    fn encode_shape_encodes_named_struct_fields() {
        let mut out = Vec::new();
        encode_shape(PlainRet::SHAPE, &mut out);

        let mut expected = vec![sig::STRUCT, 1, 5];
        expected.extend_from_slice(b"value");
        expected.push(sig::U64);
        assert_eq!(out, expected);
    }

    #[test]
    fn encode_shape_marks_recursive_types_with_backref() {
        #[derive(Facet)]
        struct Node {
            next: Option<Box<Node>>,
        }

        let mut out = Vec::new();
        encode_shape(Node::SHAPE, &mut out);

        // Pin the exact bytes rather than `out.contains(&sig::BACKREF)`:
        // 0x32 is also ASCII '2' and a valid varint byte, so a bare membership
        // check can pass without any back-reference, and it ignores the depth.
        let mut expected = vec![sig::STRUCT, 1, 4];
        expected.extend_from_slice(b"next");
        expected.extend_from_slice(&[sig::OPTION, sig::BACKREF, 0]);
        assert_eq!(
            out, expected,
            "recursive encoding should end in a BACKREF of depth 0"
        );
    }

    #[test]
    fn backref_depth_counts_enclosing_user_types() {
        #[derive(Facet)]
        struct Outer {
            inner: Inner,
        }

        #[derive(Facet)]
        struct Inner {
            back: Option<Box<Outer>>,
        }

        let mut out = Vec::new();
        encode_shape(Outer::SHAPE, &mut out);

        let mut expected = vec![sig::STRUCT, 1, 5];
        expected.extend_from_slice(b"inner");
        expected.extend_from_slice(&[sig::STRUCT, 1, 4]);
        expected.extend_from_slice(b"back");
        // `Outer` sits one user type above `Inner` on the encoding stack.
        expected.extend_from_slice(&[sig::OPTION, sig::BACKREF, 1]);
        assert_eq!(out, expected);
    }
}