// candid/types/internal.rs

1use super::CandidType;
2use crate::idl_hash;
3use std::cell::RefCell;
4use std::cmp::Ordering;
5use std::collections::BTreeMap;
6use std::fmt;
7
// A stand-in for `std::any::TypeId` without its `T: 'static` requirement.
// `std::any::TypeId::of` only accepts `'static` types, but Candid never cares about
// lifetimes, so identifying a type by the address of a per-type function instance
// (plus its `type_name`) is enough here.
#[derive(Debug, PartialEq, Eq, Hash, Clone, PartialOrd, Ord)]
pub struct TypeId {
    id: usize,
    pub name: &'static str,
}
impl TypeId {
    /// Returns the identifier of `T` (which may be unsized).
    ///
    /// `id` is the address of this function instantiated for `T`; `name` is
    /// `std::any::type_name::<T>()`, whose exact text is unspecified by the standard library.
    pub fn of<T: ?Sized>() -> Self {
        let id = Self::of::<T> as usize;
        Self {
            id,
            name: std::any::type_name::<T>(),
        }
    }
}
25impl std::fmt::Display for TypeId {
26    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27        let name = NAME.with(|n| n.borrow_mut().get(self));
28        write!(f, "{name}")
29    }
30}
31pub fn type_of<T>(_: &T) -> TypeId {
32    TypeId::of::<T>()
33}
34
35#[derive(Default)]
36struct TypeName {
37    type_name: BTreeMap<TypeId, String>,
38    name_index: BTreeMap<String, usize>,
39}
40impl TypeName {
41    fn get(&mut self, id: &TypeId) -> String {
42        match self.type_name.get(id) {
43            Some(n) => n.to_string(),
44            None => {
45                // The format of id.name is unspecified, and doesn't guarantee to be unique.
46                // Splitting by "::" is not ideal, as we can get types like std::Box<lib::List>, HashMap<lib::K, V>
47                // This is not a problem for correctness, but I may get misleading names.
48                let name = id.name.split('<').next().unwrap();
49                let name = name.rsplit("::").next().unwrap();
50                let name = name
51                    .chars()
52                    .map(|c| if c.is_ascii_alphanumeric() { c } else { '_' })
53                    .collect::<String>()
54                    .trim_end_matches('_')
55                    .to_string();
56                let res = match self.name_index.get_mut(&name) {
57                    None => {
58                        self.name_index.insert(name.clone(), 0);
59                        name
60                    }
61                    Some(v) => {
62                        *v += 1;
63                        format!("{name}_{v}")
64                    }
65                };
66                self.type_name.insert(id.clone(), res.clone());
67                res
68            }
69        }
70    }
71}
72
73/// Used for `candid_derive::export_service` to generate `TypeEnv` from `Type`.
74///
75/// It performs a global rewriting of `Type` to resolve:
76/// * Duplicate type names in different modules/namespaces.
77/// * Give different names to instantiated polymorphic types.
78/// * Find the type name of a recursive node `Knot(TypeId)` and convert to `Var` node.
79///
80/// There are some drawbacks of this approach:
81/// * The type name is based on `type_name::<T>()`, whose format is unspecified and long. We use some regex to shorten the name.
82/// * Several Rust types can map to the same Candid type, and we only get to remember one name (currently we choose the shortest name). As a result, some of the type names in Rust is lost.
83/// * Unless we do equivalence checking, recursive types can be unrolled and assigned to multiple names.
84#[derive(Default)]
85pub struct TypeContainer {
86    pub env: crate::TypeEnv,
87}
88impl TypeContainer {
89    pub fn new() -> Self {
90        TypeContainer {
91            env: crate::TypeEnv::new(),
92        }
93    }
94    pub fn add<T: CandidType>(&mut self) -> Type {
95        let t = T::ty();
96        self.go(&t)
97    }
98    fn go(&mut self, t: &Type) -> Type {
99        match t.as_ref() {
100            TypeInner::Opt(t) => TypeInner::Opt(self.go(t)),
101            TypeInner::Vec(t) => TypeInner::Vec(self.go(t)),
102            TypeInner::Record(fs) => {
103                let res: Type = TypeInner::Record(
104                    fs.iter()
105                        .map(|Field { id, ty }| Field {
106                            id: id.clone(),
107                            ty: self.go(ty),
108                        })
109                        .collect(),
110                )
111                .into();
112                if t.is_tuple() {
113                    return res;
114                }
115                let id = ID.with(|n| n.borrow().get(t).cloned());
116                if let Some(id) = id {
117                    self.env.0.insert(id.to_string(), res);
118                    TypeInner::Var(id.to_string())
119                } else {
120                    // if the type is part of an enum, the id won't be recorded.
121                    // we want to inline the type in this case.
122                    return res;
123                }
124            }
125            TypeInner::Variant(fs) => {
126                let res: Type = TypeInner::Variant(
127                    fs.iter()
128                        .map(|Field { id, ty }| Field {
129                            id: id.clone(),
130                            ty: self.go(ty),
131                        })
132                        .collect(),
133                )
134                .into();
135                let id = ID.with(|n| n.borrow().get(t).cloned());
136                if let Some(id) = id {
137                    self.env.0.insert(id.to_string(), res);
138                    TypeInner::Var(id.to_string())
139                } else {
140                    return res;
141                }
142            }
143            TypeInner::Knot(id) => {
144                let name = id.to_string();
145                let ty = ENV.with(|e| e.borrow().get(id).unwrap().clone());
146                self.env.0.insert(id.to_string(), ty);
147                TypeInner::Var(name)
148            }
149            TypeInner::Func(func) => TypeInner::Func(Function {
150                modes: func.modes.clone(),
151                args: func.args.iter().map(|arg| self.go(arg)).collect(),
152                rets: func.rets.iter().map(|arg| self.go(arg)).collect(),
153            }),
154            TypeInner::Service(serv) => TypeInner::Service(
155                serv.iter()
156                    .map(|(id, t)| (id.clone(), self.go(t)))
157                    .collect(),
158            ),
159            TypeInner::Class(inits, ref ty) => {
160                TypeInner::Class(inits.iter().map(|t| self.go(t)).collect(), self.go(ty))
161            }
162            t => t.clone(),
163        }
164        .into()
165    }
166}
167
168#[derive(Debug, PartialEq, Hash, Eq, Clone, PartialOrd, Ord)]
169pub struct Type(pub std::rc::Rc<TypeInner>);
170
171#[derive(Debug, PartialEq, Hash, Eq, Clone, PartialOrd, Ord)]
172pub enum TypeInner {
173    Null,
174    Bool,
175    Nat,
176    Int,
177    Nat8,
178    Nat16,
179    Nat32,
180    Nat64,
181    Int8,
182    Int16,
183    Int32,
184    Int64,
185    Float32,
186    Float64,
187    Text,
188    Reserved,
189    Empty,
190    Knot(TypeId), // For recursive types from Rust
191    Var(String),  // For variables from Candid file
192    Unknown,
193    Opt(Type),
194    Vec(Type),
195    Record(Vec<Field>),
196    Variant(Vec<Field>),
197    Func(Function),
198    Service(Vec<(String, Type)>),
199    Class(Vec<Type>, Type),
200    Principal,
201    Future,
202}
203impl std::ops::Deref for Type {
204    type Target = TypeInner;
205    fn deref(&self) -> &TypeInner {
206        &self.0
207    }
208}
209impl AsRef<TypeInner> for Type {
210    fn as_ref(&self) -> &TypeInner {
211        self.0.as_ref()
212    }
213}
214impl From<TypeInner> for Type {
215    fn from(t: TypeInner) -> Self {
216        Type(t.into())
217    }
218}
219impl TypeInner {
220    pub fn is_tuple(&self) -> bool {
221        match self {
222            TypeInner::Record(ref fs) => {
223                for (i, field) in fs.iter().enumerate() {
224                    if field.id.get_id() != (i as u32) {
225                        return false;
226                    }
227                }
228                true
229            }
230            _ => false,
231        }
232    }
233    pub fn is_blob(&self, env: &crate::TypeEnv) -> bool {
234        match self {
235            TypeInner::Vec(t) => {
236                let Ok(t) = env.trace_type(t) else {
237                    return false;
238                };
239                matches!(*t, TypeInner::Nat8)
240            }
241            _ => false,
242        }
243    }
244}
245impl Type {
246    pub fn is_tuple(&self) -> bool {
247        self.as_ref().is_tuple()
248    }
249    pub fn is_blob(&self, env: &crate::TypeEnv) -> bool {
250        self.as_ref().is_blob(env)
251    }
252    pub fn subst(&self, tau: &std::collections::BTreeMap<String, String>) -> Self {
253        use TypeInner::*;
254        match self.as_ref() {
255            Var(id) => match tau.get(id) {
256                None => Var(id.to_string()),
257                Some(new_id) => Var(new_id.to_string()),
258            },
259            Opt(t) => Opt(t.subst(tau)),
260            Vec(t) => Vec(t.subst(tau)),
261            Record(fs) => Record(
262                fs.iter()
263                    .map(|Field { id, ty }| Field {
264                        id: id.clone(),
265                        ty: ty.subst(tau),
266                    })
267                    .collect(),
268            ),
269            Variant(fs) => Variant(
270                fs.iter()
271                    .map(|Field { id, ty }| Field {
272                        id: id.clone(),
273                        ty: ty.subst(tau),
274                    })
275                    .collect(),
276            ),
277            Func(func) => {
278                let func = func.clone();
279                Func(Function {
280                    modes: func.modes,
281                    args: func.args.into_iter().map(|t| t.subst(tau)).collect(),
282                    rets: func.rets.into_iter().map(|t| t.subst(tau)).collect(),
283                })
284            }
285            Service(serv) => Service(
286                serv.iter()
287                    .map(|(meth, ty)| (meth.clone(), ty.subst(tau)))
288                    .collect(),
289            ),
290            Class(args, ty) => Class(args.iter().map(|t| t.subst(tau)).collect(), ty.subst(tau)),
291            _ => return self.clone(),
292        }
293        .into()
294    }
295}
296#[cfg(feature = "printer")]
297impl fmt::Display for Type {
298    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
299        write!(f, "{}", crate::pretty::candid::pp_ty(self).pretty(80))
300    }
301}
302#[cfg(feature = "printer")]
303impl fmt::Display for TypeInner {
304    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
305        write!(f, "{}", crate::pretty::candid::pp_ty_inner(self).pretty(80))
306    }
307}
308#[cfg(not(feature = "printer"))]
309impl fmt::Display for Type {
310    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
311        write!(f, "{:?}", self)
312    }
313}
314#[cfg(not(feature = "printer"))]
315impl fmt::Display for TypeInner {
316    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
317        write!(f, "{:?}", self)
318    }
319}
320#[allow(clippy::result_unit_err)]
321pub fn text_size(t: &Type, limit: i32) -> Result<i32, ()> {
322    use TypeInner::*;
323    if limit <= 1 {
324        return Err(());
325    }
326    let cost = match t.as_ref() {
327        Null | Bool | Text | Nat8 | Int8 => 4,
328        Nat | Int => 3,
329        Nat16 | Nat32 | Nat64 | Int16 | Int32 | Int64 | Empty => 5,
330        Float32 | Float64 => 7,
331        Reserved => 8,
332        Principal => 9,
333        Knot(_) => 10,
334        Var(id) => id.len() as i32,
335        Opt(t) => 4 + text_size(t, limit - 4)?,
336        Vec(t) => 4 + text_size(t, limit - 4)?,
337        Record(fs) | Variant(fs) => {
338            let mut cnt = 0;
339            let mut limit = limit;
340            for f in fs {
341                let id_size = match f.id.as_ref() {
342                    Label::Named(n) => n.len() as i32,
343                    Label::Id(_) => 4,
344                    Label::Unnamed(_) => 0,
345                };
346                cnt += id_size + text_size(&f.ty, limit - id_size - 3)? + 3;
347                limit -= cnt;
348            }
349            9 + cnt
350        }
351        Func(func) => {
352            let mode = if func.modes.is_empty() { 0 } else { 6 };
353            let mut cnt = mode + 6;
354            let mut limit = limit - cnt;
355            for t in &func.args {
356                cnt += text_size(t, limit)?;
357                limit -= cnt;
358            }
359            for t in &func.rets {
360                cnt += text_size(t, limit)?;
361                limit -= cnt;
362            }
363            cnt
364        }
365        Service(ms) => {
366            let mut cnt = 0;
367            let mut limit = limit;
368            for (name, f) in ms {
369                let len = name.len() as i32;
370                cnt += len + text_size(f, limit - len - 3)? + 3;
371                limit -= cnt;
372            }
373            10 + cnt
374        }
375        Future => 6,
376        Unknown => 7,
377        Class(..) => unreachable!(),
378    };
379    if cost > limit {
380        Err(())
381    } else {
382        Ok(cost)
383    }
384}
385
386#[derive(Debug, Clone)]
387pub enum Label {
388    Id(u32),
389    Named(String),
390    Unnamed(u32),
391}
392
393impl Label {
394    pub fn get_id(&self) -> u32 {
395        match *self {
396            Label::Id(n) | Label::Unnamed(n) => n,
397            Label::Named(ref n) => idl_hash(n),
398        }
399    }
400}
401
402impl std::fmt::Display for Label {
403    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
404        match self {
405            Label::Id(n) | Label::Unnamed(n) => {
406                write!(f, "{}", crate::utils::pp_num_str(&n.to_string()))
407            }
408            Label::Named(id) => write!(f, "{id}"),
409        }
410    }
411}
412
413impl PartialEq for Label {
414    fn eq(&self, other: &Self) -> bool {
415        self.get_id() == other.get_id()
416    }
417}
418
419impl Eq for Label {}
420
421impl PartialOrd for Label {
422    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
423        Some(self.cmp(other))
424    }
425}
426
427impl Ord for Label {
428    fn cmp(&self, other: &Self) -> Ordering {
429        self.get_id().cmp(&other.get_id())
430    }
431}
432
433impl std::hash::Hash for Label {
434    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
435        state.write_u32(self.get_id());
436    }
437}
438
439pub type SharedLabel = std::rc::Rc<Label>;
440
441#[derive(Debug, PartialEq, Hash, Eq, Clone, PartialOrd, Ord)]
442pub struct Field {
443    pub id: SharedLabel,
444    pub ty: Type,
445}
446#[cfg(feature = "printer")]
447impl fmt::Display for Field {
448    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
449        write!(
450            f,
451            "{}",
452            crate::pretty::candid::pp_field(self, false).pretty(80)
453        )
454    }
455}
456#[cfg(not(feature = "printer"))]
457impl fmt::Display for Field {
458    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
459        write!(f, "{:?}", self)
460    }
461}
462
#[macro_export]
/// Construct a field type, which can be used in `TypeInner::Record` and `TypeInner::Variant`.
///
/// A label whose text parses as a `u32` becomes `Label::Id`; any other label becomes
/// `Label::Named` holding the label's source text.
///
/// `field!{ a: TypeInner::Nat.into() }` expands to `Field { id: Label::Named("a"), ty: ... }`
/// `field!{ 0: Nat::ty() }` expands to `Field { id: Label::Id(0), ty: ... }`
macro_rules! field {
    { $id:tt : $ty:expr } => {{
        let label = if let Ok(n) = stringify!($id).parse::<u32>() {
            $crate::types::Label::Id(n)
        } else {
            $crate::types::Label::Named(stringify!($id).to_string())
        };
        $crate::types::internal::Field {
            id: label.into(),
            ty: $ty,
        }
    }}
}
#[macro_export]
/// Construct a record type, e.g., `record!{ label: Nat::ty(); 42: String::ty() }`.
///
/// Fields are sorted by label id; panics if `check_unique` rejects the labels.
macro_rules! record {
    { $($id:tt : $ty:expr);* $(;)? } => {{
        let mut fields: Vec<$crate::types::internal::Field> = vec![ $($crate::field!{$id : $ty}),* ];
        fields.sort_unstable_by(|a, b| a.id.get_id().cmp(&b.id.get_id()));
        match $crate::utils::check_unique(fields.iter().map(|f| &f.id)) {
            Ok(_) => {}
            Err(e) => panic!("{e}"),
        }
        Into::<$crate::types::Type>::into($crate::types::TypeInner::Record(fields))
    }}
}
#[macro_export]
/// Construct a variant type, e.g., `variant!{ tag: <()>::ty() }`.
///
/// Fields are sorted by label id; panics if `check_unique` rejects the labels.
macro_rules! variant {
    { $($id:tt : $ty:expr);* $(;)? } => {{
        let mut fields: Vec<$crate::types::internal::Field> = vec![ $($crate::field!{$id : $ty}),* ];
        fields.sort_unstable_by(|a, b| a.id.get_id().cmp(&b.id.get_id()));
        match $crate::utils::check_unique(fields.iter().map(|f| &f.id)) {
            Ok(_) => {}
            Err(e) => panic!("{e}"),
        }
        Into::<$crate::types::Type>::into($crate::types::TypeInner::Variant(fields))
    }}
}
503
504#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
505pub enum FuncMode {
506    Oneway,
507    Query,
508    CompositeQuery,
509}
510#[derive(Debug, PartialEq, Hash, Eq, Clone, PartialOrd, Ord)]
511pub struct Function {
512    pub modes: Vec<FuncMode>,
513    pub args: Vec<Type>,
514    pub rets: Vec<Type>,
515}
516
517#[cfg(feature = "printer")]
518impl fmt::Display for Function {
519    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
520        write!(f, "{}", crate::pretty::candid::pp_function(self).pretty(80))
521    }
522}
523#[cfg(not(feature = "printer"))]
524impl fmt::Display for Function {
525    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
526        write!(f, "{:?}", self)
527    }
528}
529impl Function {
530    /// Check a function is a `query` or `composite_query` method
531    pub fn is_query(&self) -> bool {
532        self.modes
533            .iter()
534            .any(|m| matches!(m, FuncMode::Query | FuncMode::CompositeQuery))
535    }
536}
#[macro_export]
/// Construct a function type.
///
/// An optional trailing keyword (`query`, `composite_query` or `oneway`) selects the single
/// function mode; without it the function has no modes.
///
/// `func!((u8, &str) -> (Nat) query)` expands to `Type(Rc::new(TypeInner::Func(...)))`
macro_rules! func {
    (@modes) => { vec![] };
    (@modes query) => { vec![$crate::types::FuncMode::Query] };
    (@modes composite_query) => { vec![$crate::types::FuncMode::CompositeQuery] };
    (@modes oneway) => { vec![$crate::types::FuncMode::Oneway] };
    ( ( $($arg:ty),* $(,)? ) -> ( $($ret:ty),* $(,)? ) $($mode:ident)? ) => {
        Into::<$crate::types::Type>::into($crate::types::TypeInner::Func($crate::types::Function {
            args: vec![$(<$arg>::ty()),*],
            rets: vec![$(<$ret>::ty()),*],
            modes: $crate::func!(@modes $($mode)?),
        }))
    };
}
#[macro_export]
/// Construct a service type.
///
/// Methods are sorted by name; panics if `check_unique` rejects the method names.
///
/// `service!{ "f": func!((HttpRequest) -> ()) }` expands to `Type(Rc::new(TypeInner::Service(...)))`
macro_rules! service {
    { $($meth:tt : $ty:expr);* $(;)? } => {{
        let mut methods: Vec<(String, $crate::types::Type)> = vec![ $(($meth.to_string(), $ty)),* ];
        methods.sort_unstable_by(|a, b| a.0.cmp(&b.0));
        if let Err(e) = $crate::utils::check_unique(methods.iter().map(|(name, _)| name)) {
            panic!("{e}");
        }
        Into::<$crate::types::Type>::into($crate::types::TypeInner::Service(methods))
    }}
}
569
/// Numeric codes for Candid types, stored as explicit negative `i64` discriminants.
#[derive(Debug, PartialEq)]
#[repr(i64)]
pub enum Opcode {
    Null = -1,
    Bool = -2,
    Nat = -3,
    Int = -4,
    Nat8 = -5,
    Nat16 = -6,
    Nat32 = -7,
    Nat64 = -8,
    Int8 = -9,
    Int16 = -10,
    Int32 = -11,
    Int64 = -12,
    Float32 = -13,
    Float64 = -14,
    Text = -15,
    Reserved = -16,
    Empty = -17,
    Opt = -18,
    Vec = -19,
    Record = -20,
    Variant = -21,
    Func = -22,
    Service = -23,
    Principal = -24,
}
598
599pub fn is_primitive(t: &Type) -> bool {
600    use self::TypeInner::*;
601    match t.as_ref() {
602        Null | Bool | Nat | Int | Text => true,
603        Nat8 | Nat16 | Nat32 | Nat64 => true,
604        Int8 | Int16 | Int32 | Int64 => true,
605        Float32 | Float64 => true,
606        Reserved | Empty => true,
607        Unknown => panic!("Unknown type"),
608        Future => panic!("Future type"),
609        Var(_) => panic!("Variable"), // Var may or may not be a primitive, so don't ask me
610        Knot(_) => true,
611        Opt(_) | Vec(_) | Record(_) | Variant(_) => false,
612        Func(_) | Service(_) | Class(_, _) => false,
613        Principal => true,
614    }
615}
616
617pub fn unroll(t: &Type) -> Type {
618    use self::TypeInner::*;
619    match t.as_ref() {
620        Knot(ref id) => return find_type(id).unwrap(),
621        Opt(ref t) => Opt(unroll(t)),
622        Vec(ref t) => Vec(unroll(t)),
623        Record(fs) => Record(
624            fs.iter()
625                .map(|Field { id, ty }| Field {
626                    id: id.clone(),
627                    ty: unroll(ty),
628                })
629                .collect(),
630        ),
631        Variant(fs) => Variant(
632            fs.iter()
633                .map(|Field { id, ty }| Field {
634                    id: id.clone(),
635                    ty: unroll(ty),
636                })
637                .collect(),
638        ),
639        t => t.clone(),
640    }
641    .into()
642}
643
644thread_local! {
645    static ENV: RefCell<BTreeMap<TypeId, Type>> = const { RefCell::new(BTreeMap::new()) };
646    // only used for TypeContainer
647    static ID: RefCell<BTreeMap<Type, TypeId>> = const { RefCell::new(BTreeMap::new()) };
648    static NAME: RefCell<TypeName> = RefCell::new(TypeName::default());
649}
650
651pub fn find_type(id: &TypeId) -> Option<Type> {
652    ENV.with(|e| e.borrow().get(id).cloned())
653}
654
655// only for debugging
656#[allow(dead_code)]
657pub(crate) fn show_env() {
658    ENV.with(|e| println!("{:?}", e.borrow()));
659}
660
661pub(crate) fn env_add(id: TypeId, t: Type) {
662    ENV.with(|e| e.borrow_mut().insert(id, t));
663}
664pub fn env_clear() {
665    ENV.with(|e| e.borrow_mut().clear());
666}
667
668pub(crate) fn env_id(id: TypeId, t: Type) {
669    // prefer shorter type names
670    let new_len = id.name.len();
671    ID.with(|n| {
672        let mut n = n.borrow_mut();
673        match n.get_mut(&t) {
674            None => {
675                n.insert(t, id);
676            }
677            Some(v) => {
678                if new_len < v.name.len() {
679                    *v = id;
680                }
681            }
682        }
683    });
684}
685
686pub fn get_type<T>(_v: &T) -> Type
687where
688    T: CandidType,
689{
690    T::ty()
691}