Skip to main content

candid/types/
subtype.rs

1use super::internal::{find_type, Field, Label, Type, TypeInner};
2use crate::types::TypeEnv;
3use crate::utils::RecursionDepth;
4use crate::{Error, Result};
5use anyhow::Context;
6use std::collections::{HashMap, HashSet};
7
8pub type Gamma = HashSet<(Type, Type)>;
9
/// Error reporting style for the special opt rule.
///
/// The special rule accepts a type as a subtype of an `opt` type even when the
/// inner types are unrelated. This enum selects how such a use is surfaced.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum OptReport {
    /// Accept the special opt rule without any output.
    Silence,
    /// Accept the special opt rule, but print a warning to stderr.
    Warning,
    /// Reject the special opt rule by returning an error.
    Error,
}
17/// Check if t1 <: t2
18pub fn subtype(gamma: &mut Gamma, env: &TypeEnv, t1: &Type, t2: &Type) -> Result<()> {
19    subtype_(
20        OptReport::Warning,
21        gamma,
22        env,
23        t1,
24        t2,
25        &RecursionDepth::new(),
26    )
27}
28/// Check if t1 <: t2, and report the special opt rule as `Slience`, `Warning` or `Error`.
29pub fn subtype_with_config(
30    report: OptReport,
31    gamma: &mut Gamma,
32    env: &TypeEnv,
33    t1: &Type,
34    t2: &Type,
35) -> Result<()> {
36    subtype_(report, gamma, env, t1, t2, &RecursionDepth::new())
37}
38
39fn subtype_(
40    report: OptReport,
41    gamma: &mut Gamma,
42    env: &TypeEnv,
43    t1: &Type,
44    t2: &Type,
45    depth: &RecursionDepth,
46) -> Result<()> {
47    let _guard = depth.guard()?;
48    use TypeInner::*;
49    if t1 == t2 {
50        return Ok(());
51    }
52    if matches!(t1.as_ref(), Var(_) | Knot(_)) || matches!(t2.as_ref(), Var(_) | Knot(_)) {
53        if !gamma.insert((t1.clone(), t2.clone())) {
54            return Ok(());
55        }
56        let res = match (t1.as_ref(), t2.as_ref()) {
57            (Var(id), _) => subtype_(
58                report,
59                gamma,
60                env,
61                env.rec_find_type_with_depth(id, depth).unwrap(),
62                t2,
63                depth,
64            ),
65            (_, Var(id)) => subtype_(
66                report,
67                gamma,
68                env,
69                t1,
70                env.rec_find_type_with_depth(id, depth).unwrap(),
71                depth,
72            ),
73            (Knot(id), _) => subtype_(report, gamma, env, &find_type(id).unwrap(), t2, depth),
74            (_, Knot(id)) => subtype_(report, gamma, env, t1, &find_type(id).unwrap(), depth),
75            (_, _) => unreachable!(),
76        };
77        if res.is_err() {
78            gamma.remove(&(t1.clone(), t2.clone()));
79        }
80        return res;
81    }
82    match (t1.as_ref(), t2.as_ref()) {
83        (_, Reserved) => Ok(()),
84        (Empty, _) => Ok(()),
85        (Nat, Int) => Ok(()),
86        (Vec(ty1), Vec(ty2)) => subtype_(report, gamma, env, ty1, ty2, depth),
87        (Null, Opt(_)) => Ok(()),
88        (Opt(ty1), Opt(ty2)) if subtype_(report, gamma, env, ty1, ty2, depth).is_ok() => Ok(()),
89        (_, Opt(ty2))
90            if subtype_(report, gamma, env, t1, ty2, depth).is_ok()
91                && !matches!(
92                    env.trace_type_with_depth(ty2, depth)?.as_ref(),
93                    Null | Reserved | Opt(_)
94                ) =>
95        {
96            Ok(())
97        }
98        (_, Opt(_)) => {
99            let msg = format!("WARNING: {t1} <: {t2} due to special subtyping rules involving optional types/fields (see https://github.com/dfinity/candid/blob/c7659ca/spec/Candid.md#upgrading-and-subtyping). This means the two interfaces have diverged, which could cause data loss.");
100            match report {
101                OptReport::Silence => (),
102                OptReport::Warning => eprintln!("{msg}"),
103                OptReport::Error => return Err(Error::msg(msg)),
104            };
105            Ok(())
106        }
107        (Record(fs1), Record(fs2)) => {
108            let fields: HashMap<_, _> = fs1.iter().map(|Field { id, ty }| (id, ty)).collect();
109            for Field { id, ty: ty2 } in fs2 {
110                match fields.get(id) {
111                    Some(ty1) => {
112                        subtype_(report, gamma, env, ty1, ty2, depth).with_context(|| {
113                            format!("Record field {id}: {ty1} is not a subtype of {ty2}")
114                        })?
115                    }
116                    None => {
117                        if !matches!(
118                            env.trace_type_with_depth(ty2, depth)?.as_ref(),
119                            Null | Reserved | Opt(_)
120                        ) {
121                            return Err(Error::msg(format!("Record field {id}: {ty2} is only in the expected type and is not of type opt, null or reserved")));
122                        }
123                    }
124                }
125            }
126            Ok(())
127        }
128        (Variant(fs1), Variant(fs2)) => {
129            let fields: HashMap<_, _> = fs2.iter().map(|Field { id, ty }| (id, ty)).collect();
130            for Field { id, ty: ty1 } in fs1 {
131                match fields.get(id) {
132                    Some(ty2) => {
133                        subtype_(report, gamma, env, ty1, ty2, depth).with_context(|| {
134                            format!("Variant field {id}: {ty1} is not a subtype of {ty2}")
135                        })?
136                    }
137                    None => {
138                        return Err(Error::msg(format!(
139                            "Variant field {id} not found in the expected type"
140                        )));
141                    }
142                }
143            }
144            Ok(())
145        }
146        (Service(ms1), Service(ms2)) => {
147            let meths: HashMap<_, _> = ms1.iter().cloned().collect();
148            for (name, ty2) in ms2 {
149                match meths.get(name) {
150                    Some(ty1) => {
151                        subtype_(report, gamma, env, ty1, ty2, depth).with_context(|| {
152                            format!("Method {name}: {ty1} is not a subtype of {ty2}")
153                        })?
154                    }
155                    None => {
156                        return Err(Error::msg(format!(
157                            "Method {name} is only in the expected type"
158                        )));
159                    }
160                }
161            }
162            Ok(())
163        }
164        (Func(f1), Func(f2)) => {
165            if f1.modes != f2.modes {
166                return Err(Error::msg("Function mode mismatch"));
167            }
168            let args1 = to_tuple(&f1.args);
169            let args2 = to_tuple(&f2.args);
170            let rets1 = to_tuple(&f1.rets);
171            let rets2 = to_tuple(&f2.rets);
172            subtype_(report, gamma, env, &args2, &args1, depth)
173                .context("Subtype fails at function input type")?;
174            subtype_(report, gamma, env, &rets1, &rets2, depth)
175                .context("Subtype fails at function return type")?;
176            Ok(())
177        }
178        // This only works in the first order case, but service constructor only appears at the top level according to the spec.
179        (Class(_, t), _) => subtype_(report, gamma, env, t, t2, depth),
180        (_, Class(_, t)) => subtype_(report, gamma, env, t1, t, depth),
181        (Unknown, _) => unreachable!(),
182        (_, Unknown) => unreachable!(),
183        (_, _) => Err(Error::msg(format!("{t1} is not a subtype of {t2}"))),
184    }
185}
186
187/// Check if t1 and t2 are structurally equivalent, ignoring the variable naming differences.
188/// Note that this is more strict than `t1 <: t2` and `t2 <: t1`, because of the special opt rule.
189pub fn equal(gamma: &mut Gamma, env: &TypeEnv, t1: &Type, t2: &Type) -> Result<()> {
190    equal_impl(gamma, env, t1, t2, &RecursionDepth::new())
191}
192
193fn equal_impl(
194    gamma: &mut Gamma,
195    env: &TypeEnv,
196    t1: &Type,
197    t2: &Type,
198    depth: &RecursionDepth,
199) -> Result<()> {
200    let _guard = depth.guard()?;
201    use TypeInner::*;
202    if t1 == t2 {
203        return Ok(());
204    }
205    if matches!(t1.as_ref(), Var(_) | Knot(_)) || matches!(t2.as_ref(), Var(_) | Knot(_)) {
206        if !gamma.insert((t1.clone(), t2.clone())) {
207            return Ok(());
208        }
209        let res = match (t1.as_ref(), t2.as_ref()) {
210            (Var(id), _) => equal_impl(
211                gamma,
212                env,
213                env.rec_find_type_with_depth(id, depth).unwrap(),
214                t2,
215                depth,
216            ),
217            (_, Var(id)) => equal_impl(
218                gamma,
219                env,
220                t1,
221                env.rec_find_type_with_depth(id, depth).unwrap(),
222                depth,
223            ),
224            (Knot(id), _) => equal_impl(gamma, env, &find_type(id).unwrap(), t2, depth),
225            (_, Knot(id)) => equal_impl(gamma, env, t1, &find_type(id).unwrap(), depth),
226            (_, _) => unreachable!(),
227        };
228        if res.is_err() {
229            gamma.remove(&(t1.clone(), t2.clone()));
230        }
231        return res;
232    }
233    match (t1.as_ref(), t2.as_ref()) {
234        (Opt(ty1), Opt(ty2)) => equal_impl(gamma, env, ty1, ty2, depth),
235        (Vec(ty1), Vec(ty2)) => equal_impl(gamma, env, ty1, ty2, depth),
236        (Record(fs1), Record(fs2)) | (Variant(fs1), Variant(fs2)) => {
237            assert_length(fs1, fs2, |x| x.id.clone(), |x| x.to_string())
238                .context("Different field length")?;
239            for (f1, f2) in fs1.iter().zip(fs2.iter()) {
240                if f1.id != f2.id {
241                    return Err(Error::msg(format!(
242                        "Field name mismatch: {} and {}",
243                        f1.id, f2.id
244                    )));
245                }
246                equal_impl(gamma, env, &f1.ty, &f2.ty, depth).context(format!(
247                    "Field {} has different types: {} and {}",
248                    f1.id, f1.ty, f2.ty
249                ))?;
250            }
251            Ok(())
252        }
253        (Service(ms1), Service(ms2)) => {
254            assert_length(ms1, ms2, |x| x.0.clone(), |x| format!("method {x}"))
255                .context("Different method length")?;
256            for (m1, m2) in ms1.iter().zip(ms2.iter()) {
257                if m1.0 != m2.0 {
258                    return Err(Error::msg(format!(
259                        "Method name mismatch: {} and {}",
260                        m1.0, m2.0
261                    )));
262                }
263                equal_impl(gamma, env, &m1.1, &m2.1, depth).context(format!(
264                    "Method {} has different types: {} and {}",
265                    m1.0, m1.1, m2.1
266                ))?;
267            }
268            Ok(())
269        }
270        (Func(f1), Func(f2)) => {
271            if f1.modes != f2.modes {
272                return Err(Error::msg("Function mode mismatch"));
273            }
274            let args1 = to_tuple(&f1.args);
275            let args2 = to_tuple(&f2.args);
276            let rets1 = to_tuple(&f1.rets);
277            let rets2 = to_tuple(&f2.rets);
278            equal_impl(gamma, env, &args1, &args2, depth)
279                .context("Mismatch in function input type")?;
280            equal_impl(gamma, env, &rets1, &rets2, depth)
281                .context("Mismatch in function return type")?;
282            Ok(())
283        }
284        (Class(init1, ty1), Class(init2, ty2)) => {
285            let init_1 = to_tuple(init1);
286            let init_2 = to_tuple(init2);
287            equal_impl(gamma, env, &init_1, &init_2, depth).context(format!(
288                "Mismatch in init args: {} and {}",
289                pp_args(init1),
290                pp_args(init2)
291            ))?;
292            equal_impl(gamma, env, ty1, ty2, depth)
293        }
294        (Unknown, _) => unreachable!(),
295        (_, Unknown) => unreachable!(),
296        (_, _) => Err(Error::msg(format!("{t1} is not equal to {t2}"))),
297    }
298}
299
300fn assert_length<I, F, K, D>(left: &[I], right: &[I], get_key: F, display: D) -> Result<()>
301where
302    F: Fn(&I) -> K + Clone,
303    K: std::hash::Hash + std::cmp::Eq,
304    D: Fn(&K) -> String,
305{
306    let l = left.len();
307    let r = right.len();
308    if l == r {
309        return Ok(());
310    }
311    let left: HashSet<_> = left.iter().map(get_key.clone()).collect();
312    let right: HashSet<_> = right.iter().map(get_key).collect();
313    if l < r {
314        let mut diff = right.difference(&left);
315        Err(Error::msg(format!(
316            "Left side is missing {}",
317            display(diff.next().unwrap())
318        )))
319    } else {
320        let mut diff = left.difference(&right);
321        Err(Error::msg(format!(
322            "Right side is missing {}",
323            display(diff.next().unwrap())
324        )))
325    }
326}
327
328fn to_tuple(args: &[Type]) -> Type {
329    TypeInner::Record(
330        args.iter()
331            .enumerate()
332            .map(|(i, ty)| Field {
333                id: Label::Id(i as u32).into(),
334                ty: ty.clone(),
335            })
336            .collect(),
337    )
338    .into()
339}
340#[cfg(not(feature = "printer"))]
341fn pp_args(args: &[crate::types::Type]) -> String {
342    use std::fmt::Write;
343    let mut s = String::new();
344    write!(&mut s, "(").unwrap();
345    for arg in args.iter() {
346        write!(&mut s, "{:?}, ", arg).unwrap();
347    }
348    write!(&mut s, ")").unwrap();
349    s
350}
351#[cfg(feature = "printer")]
352fn pp_args(args: &[crate::types::Type]) -> String {
353    use crate::pretty::candid::pp_args;
354    pp_args(args).pretty(80).to_string()
355}