assert_call/
lib.rs

//! A testing tool that checks that parts of your code are called as expected.
2//!
//! Creating an instance of [`CallRecorder`] starts recording calls to [`call`];
//! [`CallRecorder::verify`] then checks that those calls match the expected pattern.
5//!
//! The expected call pattern passed to [`CallRecorder::verify`] is described with [`Call`].
7//!
8//! ## Examples
9//!
10//! ```should_panic
11//! use assert_call::{call, CallRecorder};
12//!
13//! let mut c = CallRecorder::new();
14//!
15//! call!("1");
16//! call!("2");
17//!
18//! c.verify(["1", "3"]);
19//! ```
20//!
//! The above code panics with the following message
//! because the calls to the [`call`] macro differ from what is specified in [`CallRecorder::verify`].
23//!
24//! ```txt
25//! actual calls :
26//!   1
27//! * 2
28//!   (end)
29//!
30//! mismatch call
31//! src\lib.rs:10
32//! actual : 2
33//! expect : 3
34//! ```
35//!
36//! # Backtrace support
37//!
//! If backtrace capture is enabled (see [`Backtrace::capture`]),
//! [`CallRecorder::verify`] also outputs detailed information, including the backtrace of each [`call!`] invocation.
40//!
41use std::{
42    backtrace::{Backtrace, BacktraceStatus},
43    collections::VecDeque,
44    error::Error,
45    fmt::Display,
46    thread::{self, ThreadId},
47};
48
49use records::{Global, Local, Records, Thread};
50use yansi::Condition;
51
52pub mod records;
53
54#[cfg(test)]
55mod tests;
56
/// Records a call at the place where the macro is invoked.
///
/// The arguments form the call ID and accept exactly the same syntax as [`std::format`].
/// The invocation's `file!()` and `line!()` are passed to the recorder alongside the ID.
///
/// # Panics
///
/// Panics if [`CallRecorder`] is not initialized.
///
/// This is deliberate: if `call!()` were accepted without an initialized `CallRecorder`,
/// running a test that creates a `CallRecorder` at the same time as a test that does not
/// would produce wrong results, so uninitialized use is rejected instead.
///
/// # Examples
///
/// ```
/// use assert_call::call;
/// let _c = assert_call::CallRecorder::new_local();
///
/// call!("1");
/// call!("{}-{}", 1, 2);
/// ```
#[macro_export]
macro_rules! call {
    ($($fmt_args:tt)*) => {
        $crate::records::Records::push(
            ::std::format!($($fmt_args)*),
            ::std::file!(),
            ::std::line!(),
        );
    };
}
85
86/// Records and verifies calls to [`call`].
87pub struct CallRecorder<T: Thread = Global> {
88    thread: T,
89}
90impl CallRecorder {
91    /// Start recording [`call`] macro calls in all threads.
92    ///
93    /// If there are other instances of `CallRecorder` created by this function,
94    /// wait until the other instances are dropped.
95    pub fn new() -> Self {
96        Self::new_raw()
97    }
98}
99impl CallRecorder<Local> {
100    /// Start recording [`call`] macro calls in current thread.
101    ///
102    /// # Panics
103    ///
104    /// Panics if an instance of `CallRecorder` created by `new_local` already exists in this thread.
105    pub fn new_local() -> Self {
106        Self::new_raw()
107    }
108}
109impl<T: Thread> CallRecorder<T> {
110    fn new_raw() -> Self {
111        Self { thread: T::init() }
112    }
113
114    /// Panic if [`call`] call does not match the expected pattern.
115    ///
116    /// Calling this method clears the recorded [`call`] calls.
117    #[track_caller]
118    pub fn verify(&mut self, expect: impl ToCall) {
119        self.verify_with_msg(expect, "mismatch call");
120    }
121
122    /// Panic with specified message if [`call`] call does not match the expected pattern.
123    ///
124    /// Calling this method clears the recorded [`call`] calls.
125    #[track_caller]
126    pub fn verify_with_msg(&mut self, expect: impl ToCall, msg: &str) {
127        match self.result_with_msg(expect, msg) {
128            Ok(_) => {}
129            Err(e) => {
130                panic!("{:#}", e.display(true, Condition::tty_and_color()));
131            }
132        }
133    }
134
135    /// Return `Err` with specified message if [`call`] call does not match the expected pattern.
136    ///
137    /// Calling this method clears the recorded [`call`] calls.
138    fn result_with_msg(&mut self, expect: impl ToCall, msg: &str) -> Result<(), CallMismatchError> {
139        let expect: Call = expect.to_call();
140        let actual = self.thread.take_actual();
141        expect.verify(actual, msg)
142    }
143}
144impl<T: Thread> Default for CallRecorder<T> {
145    fn default() -> Self {
146        Self::new_raw()
147    }
148}
149impl<T: Thread> Drop for CallRecorder<T> {
150    fn drop(&mut self) {}
151}
152
/// Pattern of expected [`call`] invocations.
///
/// Values are usually built with the constructors on this type
/// ([`Call::id`], [`Call::seq`], [`Call::par`], [`Call::any`], [`Call::empty`])
/// or through a [`ToCall`] conversion.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub enum Call {
    /// Exactly one call whose ID equals this string.
    Id(String),
    /// Sub-patterns matched one after another, in order.
    Seq(VecDeque<Call>),
    /// Sub-patterns whose calls may be interleaved; every one of them must complete.
    Par(Vec<Call>),
    /// Alternatives; the pattern is satisfied by matching one of them.
    Any(Vec<Call>),
}
163
164impl Call {
165    /// Create `Call` to represent a single [`call`] call.
166    ///
167    /// # Examples
168    ///
169    /// ```
170    /// use assert_call::{call, Call, CallRecorder};
171    ///
172    /// let mut c = CallRecorder::new();
173    /// call!("1");
174    /// c.verify(Call::id("1"));
175    /// ```
176    pub fn id(id: impl Display) -> Self {
177        Self::Id(id.to_string())
178    }
179
180    /// Create `Call` to represent no [`call`] call.
181    ///
182    /// # Examples
183    ///
184    /// ```
185    /// use assert_call::{Call, CallRecorder};
186    ///
187    /// let mut c = CallRecorder::new();
188    /// c.verify(Call::empty());
189    /// ```
190    pub fn empty() -> Self {
191        Self::Seq(VecDeque::new())
192    }
193
194    /// Create `Call` to represent all specified `Call`s will be called in sequence.
195    ///
196    /// # Examples
197    ///
198    /// ```
199    /// use assert_call::{call, Call, CallRecorder};
200    ///
201    /// let mut c = CallRecorder::new();
202    /// call!("1");
203    /// call!("2");
204    /// c.verify(Call::seq(["1", "2"]));
205    /// ```
206    pub fn seq(p: impl IntoIterator<Item = impl ToCall>) -> Self {
207        Self::Seq(p.into_iter().map(|x| x.to_call()).collect())
208    }
209
210    /// Create `Call` to represent all specified `Call`s will be called in parallel.
211    ///
212    /// # Examples
213    ///
214    /// ```
215    /// use assert_call::{call, Call, CallRecorder};
216    ///
217    /// let mut c = CallRecorder::new();
218    /// call!("a-1");
219    /// call!("b-1");
220    /// call!("b-2");
221    /// call!("a-2");
222    /// c.verify(Call::par([["a-1", "a-2"], ["b-1", "b-2"]]));
223    /// ```
224    pub fn par(p: impl IntoIterator<Item = impl ToCall>) -> Self {
225        Self::Par(p.into_iter().map(|x| x.to_call()).collect())
226    }
227
228    /// Create `Call` to represent one of the specified `Call`s will be called.
229    ///
230    /// # Examples
231    ///
232    /// ```
233    /// use assert_call::{call, Call, CallRecorder};
234    ///
235    /// let mut c = CallRecorder::new();
236    /// call!("1");
237    /// c.verify(Call::any(["1", "2"]));
238    /// call!("4");
239    /// c.verify(Call::any(["3", "4"]));
240    /// ```
241    pub fn any(p: impl IntoIterator<Item = impl ToCall>) -> Self {
242        Self::Any(p.into_iter().map(|x| x.to_call()).collect())
243    }
244
245    fn verify(mut self, actual: Records, msg: &str) -> Result<(), CallMismatchError> {
246        match self.verify_nexts(&actual.0) {
247            Ok(_) => Ok(()),
248            Err(mut e) => {
249                e.actual = actual;
250                e.expect.sort();
251                e.expect.dedup();
252                e.msg = msg.to_string();
253                Err(e)
254            }
255        }
256    }
257    fn verify_nexts(&mut self, actual: &[Record]) -> Result<(), CallMismatchError> {
258        for index in 0..=actual.len() {
259            self.verify_next(index, actual.get(index))?;
260        }
261        Ok(())
262    }
263    fn verify_next(&mut self, index: usize, a: Option<&Record>) -> Result<(), CallMismatchError> {
264        if let Err(e) = self.next(a) {
265            if a.is_none() && e.is_empty() {
266                return Ok(());
267            }
268            Err(CallMismatchError::new(e, index))
269        } else {
270            Ok(())
271        }
272    }
273
274    fn next(&mut self, p: Option<&Record>) -> Result<(), Vec<String>> {
275        match self {
276            Call::Id(id) => {
277                if Some(id.as_str()) == p.as_ref().map(|x| x.id.as_str()) {
278                    *self = Call::Seq(VecDeque::new());
279                    Ok(())
280                } else {
281                    Err(vec![id.to_string()])
282                }
283            }
284            Call::Seq(list) => {
285                while !list.is_empty() {
286                    match list[0].next(p) {
287                        Err(e) if e.is_empty() => list.pop_front(),
288                        ret => return ret,
289                    };
290                }
291                Err(Vec::new())
292            }
293            Call::Par(s) => {
294                let mut es = Vec::new();
295                for i in s.iter_mut() {
296                    match i.next(p) {
297                        Ok(_) => return Ok(()),
298                        Err(mut e) => es.append(&mut e),
299                    }
300                }
301                Err(es)
302            }
303            Call::Any(s) => {
304                let mut is_end = false;
305                let mut is_ok = false;
306                let mut es = Vec::new();
307                s.retain_mut(|s| match s.next(p) {
308                    Ok(_) => {
309                        is_ok = true;
310                        true
311                    }
312                    Err(e) => {
313                        is_end |= e.is_empty();
314                        es.extend(e);
315                        false
316                    }
317                });
318                if is_ok {
319                    Ok(())
320                } else if is_end {
321                    Err(Vec::new())
322                } else {
323                    Err(es)
324                }
325            }
326        }
327    }
328}
329
330/// Types convertible to [`Call`].
331pub trait ToCall {
332    fn to_call(&self) -> Call;
333}
334
335impl<T: ?Sized + ToCall> ToCall for &T {
336    fn to_call(&self) -> Call {
337        T::to_call(self)
338    }
339}
340
341impl ToCall for Call {
342    fn to_call(&self) -> Call {
343        self.clone()
344    }
345}
346
347/// Equivalent to [`Call::id`].
348impl ToCall for str {
349    fn to_call(&self) -> Call {
350        Call::id(self)
351    }
352}
353
354/// Equivalent to [`Call::id`].
355impl ToCall for String {
356    fn to_call(&self) -> Call {
357        Call::id(self)
358    }
359}
360
361/// Equivalent to [`Call::id`].
362impl ToCall for usize {
363    fn to_call(&self) -> Call {
364        Call::id(self)
365    }
366}
367
368/// Equivalent to [`Call::seq`].
369impl<T: ToCall> ToCall for [T] {
370    fn to_call(&self) -> Call {
371        Call::seq(self)
372    }
373}
374
375/// Equivalent to [`Call::seq`].
376impl<T: ToCall, const N: usize> ToCall for [T; N] {
377    fn to_call(&self) -> Call {
378        Call::seq(self)
379    }
380}
381
382/// Equivalent to [`Call::seq`].
383impl<T: ToCall> ToCall for Vec<T> {
384    fn to_call(&self) -> Call {
385        Call::seq(self)
386    }
387}
388
389/// Equivalent to [`Call::empty`].
390impl ToCall for () {
391    fn to_call(&self) -> Call {
392        Call::empty()
393    }
394}
395
396/// The error type representing that the call to [`call`] is different from what was expected.
397#[derive(Debug)]
398struct CallMismatchError {
399    msg: String,
400    actual: Records,
401    expect: Vec<String>,
402    mismatch_index: usize,
403    thread_id: ThreadId,
404}
405impl CallMismatchError {
406    fn new(expect: Vec<String>, mismatch_index: usize) -> Self {
407        Self {
408            msg: String::new(),
409            actual: Records::empty(),
410            expect,
411            mismatch_index,
412            thread_id: thread::current().id(),
413        }
414    }
415
416    fn actual_id(&self, index: usize) -> &str {
417        if let Some(a) = self.actual.0.get(index) {
418            &a.id
419        } else {
420            "(end)"
421        }
422    }
423    #[cfg(test)]
424    fn set_dummy_file_line(&mut self) {
425        for a in &mut self.actual.0 {
426            a.set_dummy_file_line();
427        }
428    }
429
430    pub fn display(&self, backtrace: bool, color: bool) -> impl Display + '_ {
431        struct CallMismatchErrorDisplay<'a> {
432            this: &'a CallMismatchError,
433            backtrace: bool,
434            color: bool,
435        }
436        impl std::fmt::Display for CallMismatchErrorDisplay<'_> {
437            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
438                self.this.fmt_with(f, self.backtrace, self.color)
439            }
440        }
441        CallMismatchErrorDisplay {
442            this: self,
443            backtrace,
444            color,
445        }
446    }
447
448    fn fmt_with(
449        &self,
450        f: &mut std::fmt::Formatter<'_>,
451        backtrace: bool,
452        color: bool,
453    ) -> std::fmt::Result {
454        let around = 5;
455        if backtrace && self.actual.has_bakctrace() {
456            writeln!(f, "actual calls with backtrace :")?;
457            self.actual.fmt_backtrace(f, self.mismatch_index, around)?;
458            writeln!(f)?;
459        }
460
461        writeln!(f, "actual calls :")?;
462        self.actual
463            .fmt_summary(f, self.mismatch_index, around, color)?;
464
465        writeln!(f)?;
466        writeln!(f, "{}", self.msg)?;
467        if let Some(a) = self.actual.0.get(self.mismatch_index) {
468            writeln!(f, "{}:{}", a.file, a.line)?;
469        }
470        if backtrace {
471            writeln!(f, "thread : {:?}", self.thread_id)?;
472        }
473        writeln!(f, "actual : {}", self.actual_id(self.mismatch_index))?;
474        writeln!(f, "expect : {}", self.expect.join(", "))?;
475        Ok(())
476    }
477}
478impl Display for CallMismatchError {
479    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
480        self.fmt_with(f, false, false)
481    }
482}
483impl Error for CallMismatchError {}
484
/// Record of one [`call`] invocation.
#[derive(Debug)]
struct Record {
    /// Call ID (the formatted `call!` arguments).
    id: String,
    /// Source file of the invocation (presumably `file!()` forwarded by `call!` — confirm in `Records::push`).
    file: &'static str,
    /// Source line of the invocation (presumably `line!()` forwarded by `call!`).
    line: u32,
    /// Backtrace for this call; printed only when its status is `Captured`.
    backtrace: Backtrace,
    /// Thread associated with this call (presumably the recording thread — confirm).
    thread_id: ThreadId,
}
impl Record {
    /// Overwrites the location with a fixed placeholder,
    /// presumably so that formatted output is stable in tests.
    #[cfg(test)]
    fn set_dummy_file_line(&mut self) {
        (self.file, self.line) = (r"tests\test.rs", 10);
    }
}

impl Display for Record {
    /// Writes `# <id>`, then `<file>:<line>`, then (if captured) a blank line and the backtrace.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            id,
            file,
            line,
            backtrace,
            ..
        } = self;
        writeln!(f, "# {id}")?;
        writeln!(f, "{file}:{line}")?;
        if matches!(backtrace.status(), BacktraceStatus::Captured) {
            writeln!(f)?;
            writeln!(f, "{backtrace}")?;
        }
        Ok(())
    }
}