1use super::internal::{find_type, Field, Label, Type, TypeInner};
2use crate::types::TypeEnv;
3use crate::utils::RecursionDepth;
4use crate::{Error, Result};
5use anyhow::Context;
6use std::collections::{HashMap, HashSet};
7
8pub type Gamma = HashSet<(Type, Type)>;
9
/// How a subtyping check reports a success that relies on the special `opt` rule
/// (the fallback that accepts `t <: opt t'` when no structural rule applies).
#[derive(Clone, Copy, Debug)]
pub enum OptReport {
    /// Accept silently.
    Silence,
    /// Accept, but print a warning to stderr.
    Warning,
    /// Reject: the check returns an error carrying the warning text.
    Error,
}
17pub fn subtype(gamma: &mut Gamma, env: &TypeEnv, t1: &Type, t2: &Type) -> Result<()> {
19 subtype_(
20 OptReport::Warning,
21 gamma,
22 env,
23 t1,
24 t2,
25 &RecursionDepth::new(),
26 )
27}
28pub fn subtype_with_config(
30 report: OptReport,
31 gamma: &mut Gamma,
32 env: &TypeEnv,
33 t1: &Type,
34 t2: &Type,
35) -> Result<()> {
36 subtype_(report, gamma, env, t1, t2, &RecursionDepth::new())
37}
38
39fn subtype_(
40 report: OptReport,
41 gamma: &mut Gamma,
42 env: &TypeEnv,
43 t1: &Type,
44 t2: &Type,
45 depth: &RecursionDepth,
46) -> Result<()> {
47 let _guard = depth.guard()?;
48 use TypeInner::*;
49 if t1 == t2 {
50 return Ok(());
51 }
52 if matches!(t1.as_ref(), Var(_) | Knot(_)) || matches!(t2.as_ref(), Var(_) | Knot(_)) {
53 if !gamma.insert((t1.clone(), t2.clone())) {
54 return Ok(());
55 }
56 let res = match (t1.as_ref(), t2.as_ref()) {
57 (Var(id), _) => subtype_(
58 report,
59 gamma,
60 env,
61 env.rec_find_type_with_depth(id, depth).unwrap(),
62 t2,
63 depth,
64 ),
65 (_, Var(id)) => subtype_(
66 report,
67 gamma,
68 env,
69 t1,
70 env.rec_find_type_with_depth(id, depth).unwrap(),
71 depth,
72 ),
73 (Knot(id), _) => subtype_(report, gamma, env, &find_type(id).unwrap(), t2, depth),
74 (_, Knot(id)) => subtype_(report, gamma, env, t1, &find_type(id).unwrap(), depth),
75 (_, _) => unreachable!(),
76 };
77 if res.is_err() {
78 gamma.remove(&(t1.clone(), t2.clone()));
79 }
80 return res;
81 }
82 match (t1.as_ref(), t2.as_ref()) {
83 (_, Reserved) => Ok(()),
84 (Empty, _) => Ok(()),
85 (Nat, Int) => Ok(()),
86 (Vec(ty1), Vec(ty2)) => subtype_(report, gamma, env, ty1, ty2, depth),
87 (Null, Opt(_)) => Ok(()),
88 (Opt(ty1), Opt(ty2)) if subtype_(report, gamma, env, ty1, ty2, depth).is_ok() => Ok(()),
89 (_, Opt(ty2))
90 if subtype_(report, gamma, env, t1, ty2, depth).is_ok()
91 && !matches!(
92 env.trace_type_with_depth(ty2, depth)?.as_ref(),
93 Null | Reserved | Opt(_)
94 ) =>
95 {
96 Ok(())
97 }
98 (_, Opt(_)) => {
99 let msg = format!("WARNING: {t1} <: {t2} due to special subtyping rules involving optional types/fields (see https://github.com/dfinity/candid/blob/c7659ca/spec/Candid.md#upgrading-and-subtyping). This means the two interfaces have diverged, which could cause data loss.");
100 match report {
101 OptReport::Silence => (),
102 OptReport::Warning => eprintln!("{msg}"),
103 OptReport::Error => return Err(Error::msg(msg)),
104 };
105 Ok(())
106 }
107 (Record(fs1), Record(fs2)) => {
108 let fields: HashMap<_, _> = fs1.iter().map(|Field { id, ty }| (id, ty)).collect();
109 for Field { id, ty: ty2 } in fs2 {
110 match fields.get(id) {
111 Some(ty1) => {
112 subtype_(report, gamma, env, ty1, ty2, depth).with_context(|| {
113 format!("Record field {id}: {ty1} is not a subtype of {ty2}")
114 })?
115 }
116 None => {
117 if !matches!(
118 env.trace_type_with_depth(ty2, depth)?.as_ref(),
119 Null | Reserved | Opt(_)
120 ) {
121 return Err(Error::msg(format!("Record field {id}: {ty2} is only in the expected type and is not of type opt, null or reserved")));
122 }
123 }
124 }
125 }
126 Ok(())
127 }
128 (Variant(fs1), Variant(fs2)) => {
129 let fields: HashMap<_, _> = fs2.iter().map(|Field { id, ty }| (id, ty)).collect();
130 for Field { id, ty: ty1 } in fs1 {
131 match fields.get(id) {
132 Some(ty2) => {
133 subtype_(report, gamma, env, ty1, ty2, depth).with_context(|| {
134 format!("Variant field {id}: {ty1} is not a subtype of {ty2}")
135 })?
136 }
137 None => {
138 return Err(Error::msg(format!(
139 "Variant field {id} not found in the expected type"
140 )));
141 }
142 }
143 }
144 Ok(())
145 }
146 (Service(ms1), Service(ms2)) => {
147 let meths: HashMap<_, _> = ms1.iter().cloned().collect();
148 for (name, ty2) in ms2 {
149 match meths.get(name) {
150 Some(ty1) => {
151 subtype_(report, gamma, env, ty1, ty2, depth).with_context(|| {
152 format!("Method {name}: {ty1} is not a subtype of {ty2}")
153 })?
154 }
155 None => {
156 return Err(Error::msg(format!(
157 "Method {name} is only in the expected type"
158 )));
159 }
160 }
161 }
162 Ok(())
163 }
164 (Func(f1), Func(f2)) => {
165 if f1.modes != f2.modes {
166 return Err(Error::msg("Function mode mismatch"));
167 }
168 let args1 = to_tuple(&f1.args);
169 let args2 = to_tuple(&f2.args);
170 let rets1 = to_tuple(&f1.rets);
171 let rets2 = to_tuple(&f2.rets);
172 subtype_(report, gamma, env, &args2, &args1, depth)
173 .context("Subtype fails at function input type")?;
174 subtype_(report, gamma, env, &rets1, &rets2, depth)
175 .context("Subtype fails at function return type")?;
176 Ok(())
177 }
178 (Class(_, t), _) => subtype_(report, gamma, env, t, t2, depth),
180 (_, Class(_, t)) => subtype_(report, gamma, env, t1, t, depth),
181 (Unknown, _) => unreachable!(),
182 (_, Unknown) => unreachable!(),
183 (_, _) => Err(Error::msg(format!("{t1} is not a subtype of {t2}"))),
184 }
185}
186
187pub fn equal(gamma: &mut Gamma, env: &TypeEnv, t1: &Type, t2: &Type) -> Result<()> {
190 equal_impl(gamma, env, t1, t2, &RecursionDepth::new())
191}
192
193fn equal_impl(
194 gamma: &mut Gamma,
195 env: &TypeEnv,
196 t1: &Type,
197 t2: &Type,
198 depth: &RecursionDepth,
199) -> Result<()> {
200 let _guard = depth.guard()?;
201 use TypeInner::*;
202 if t1 == t2 {
203 return Ok(());
204 }
205 if matches!(t1.as_ref(), Var(_) | Knot(_)) || matches!(t2.as_ref(), Var(_) | Knot(_)) {
206 if !gamma.insert((t1.clone(), t2.clone())) {
207 return Ok(());
208 }
209 let res = match (t1.as_ref(), t2.as_ref()) {
210 (Var(id), _) => equal_impl(
211 gamma,
212 env,
213 env.rec_find_type_with_depth(id, depth).unwrap(),
214 t2,
215 depth,
216 ),
217 (_, Var(id)) => equal_impl(
218 gamma,
219 env,
220 t1,
221 env.rec_find_type_with_depth(id, depth).unwrap(),
222 depth,
223 ),
224 (Knot(id), _) => equal_impl(gamma, env, &find_type(id).unwrap(), t2, depth),
225 (_, Knot(id)) => equal_impl(gamma, env, t1, &find_type(id).unwrap(), depth),
226 (_, _) => unreachable!(),
227 };
228 if res.is_err() {
229 gamma.remove(&(t1.clone(), t2.clone()));
230 }
231 return res;
232 }
233 match (t1.as_ref(), t2.as_ref()) {
234 (Opt(ty1), Opt(ty2)) => equal_impl(gamma, env, ty1, ty2, depth),
235 (Vec(ty1), Vec(ty2)) => equal_impl(gamma, env, ty1, ty2, depth),
236 (Record(fs1), Record(fs2)) | (Variant(fs1), Variant(fs2)) => {
237 assert_length(fs1, fs2, |x| x.id.clone(), |x| x.to_string())
238 .context("Different field length")?;
239 for (f1, f2) in fs1.iter().zip(fs2.iter()) {
240 if f1.id != f2.id {
241 return Err(Error::msg(format!(
242 "Field name mismatch: {} and {}",
243 f1.id, f2.id
244 )));
245 }
246 equal_impl(gamma, env, &f1.ty, &f2.ty, depth).context(format!(
247 "Field {} has different types: {} and {}",
248 f1.id, f1.ty, f2.ty
249 ))?;
250 }
251 Ok(())
252 }
253 (Service(ms1), Service(ms2)) => {
254 assert_length(ms1, ms2, |x| x.0.clone(), |x| format!("method {x}"))
255 .context("Different method length")?;
256 for (m1, m2) in ms1.iter().zip(ms2.iter()) {
257 if m1.0 != m2.0 {
258 return Err(Error::msg(format!(
259 "Method name mismatch: {} and {}",
260 m1.0, m2.0
261 )));
262 }
263 equal_impl(gamma, env, &m1.1, &m2.1, depth).context(format!(
264 "Method {} has different types: {} and {}",
265 m1.0, m1.1, m2.1
266 ))?;
267 }
268 Ok(())
269 }
270 (Func(f1), Func(f2)) => {
271 if f1.modes != f2.modes {
272 return Err(Error::msg("Function mode mismatch"));
273 }
274 let args1 = to_tuple(&f1.args);
275 let args2 = to_tuple(&f2.args);
276 let rets1 = to_tuple(&f1.rets);
277 let rets2 = to_tuple(&f2.rets);
278 equal_impl(gamma, env, &args1, &args2, depth)
279 .context("Mismatch in function input type")?;
280 equal_impl(gamma, env, &rets1, &rets2, depth)
281 .context("Mismatch in function return type")?;
282 Ok(())
283 }
284 (Class(init1, ty1), Class(init2, ty2)) => {
285 let init_1 = to_tuple(init1);
286 let init_2 = to_tuple(init2);
287 equal_impl(gamma, env, &init_1, &init_2, depth).context(format!(
288 "Mismatch in init args: {} and {}",
289 pp_args(init1),
290 pp_args(init2)
291 ))?;
292 equal_impl(gamma, env, ty1, ty2, depth)
293 }
294 (Unknown, _) => unreachable!(),
295 (_, Unknown) => unreachable!(),
296 (_, _) => Err(Error::msg(format!("{t1} is not equal to {t2}"))),
297 }
298}
299
300fn assert_length<I, F, K, D>(left: &[I], right: &[I], get_key: F, display: D) -> Result<()>
301where
302 F: Fn(&I) -> K + Clone,
303 K: std::hash::Hash + std::cmp::Eq,
304 D: Fn(&K) -> String,
305{
306 let l = left.len();
307 let r = right.len();
308 if l == r {
309 return Ok(());
310 }
311 let left: HashSet<_> = left.iter().map(get_key.clone()).collect();
312 let right: HashSet<_> = right.iter().map(get_key).collect();
313 if l < r {
314 let mut diff = right.difference(&left);
315 Err(Error::msg(format!(
316 "Left side is missing {}",
317 display(diff.next().unwrap())
318 )))
319 } else {
320 let mut diff = left.difference(&right);
321 Err(Error::msg(format!(
322 "Right side is missing {}",
323 display(diff.next().unwrap())
324 )))
325 }
326}
327
328fn to_tuple(args: &[Type]) -> Type {
329 TypeInner::Record(
330 args.iter()
331 .enumerate()
332 .map(|(i, ty)| Field {
333 id: Label::Id(i as u32).into(),
334 ty: ty.clone(),
335 })
336 .collect(),
337 )
338 .into()
339}
340#[cfg(not(feature = "printer"))]
341fn pp_args(args: &[crate::types::Type]) -> String {
342 use std::fmt::Write;
343 let mut s = String::new();
344 write!(&mut s, "(").unwrap();
345 for arg in args.iter() {
346 write!(&mut s, "{:?}, ", arg).unwrap();
347 }
348 write!(&mut s, ")").unwrap();
349 s
350}
#[cfg(feature = "printer")]
/// Renders an argument list with the Candid pretty printer at a line width of 80.
fn pp_args(args: &[crate::types::Type]) -> String {
    let doc = crate::pretty::candid::pp_args(args);
    doc.pretty(80).to_string()
}