// warnings/warnings.rs

1use std::{fmt::Debug, future::Future};
2
3use pin_project::pin_project;
4
5pub trait Warning: 'static {
6    const ID: WarningId;
7
8    fn enabled() -> bool {
9        Self::ID.enabled()
10    }
11
12    fn if_enabled(item: impl FnOnce()) {
13        Self::ID.if_enabled(item)
14    }
15
16    fn allow<O>(item: impl FnOnce() -> O) -> O {
17        allow::<Self, _>(item)
18    }
19
20    fn allow_async<F: Future>(future: F) -> AllowFuture<F> {
21        AllowFuture::new(future, Self::ID)
22    }
23}
24
/// Opaque identifier for a [`Warning`] type.
///
/// With `debug_assertions` it records the warning type's `TypeId` and type name,
/// stored as function pointers (presumably so `WarningId::of` can remain a `const fn`).
/// Without `debug_assertions` it is a field-less marker.
#[derive(Clone, Copy)]
pub struct WarningId {
    #[cfg(debug_assertions)]
    type_id: fn() -> std::any::TypeId,
    #[cfg(debug_assertions)]
    name: fn() -> &'static str,
}

impl Debug for WarningId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("WarningId");
        #[cfg(debug_assertions)]
        {
            // Call the stored fn pointers so the output shows the real `TypeId` and
            // type name; formatting the pointer itself only prints a code address.
            dbg.field("type_id", &(self.type_id)())
                .field("name", &(self.name)());
        }
        dbg.finish()
    }
}
45
46impl WarningId {
47    /// Get the ID of a warning
48    #[allow(unused)]
49    pub const fn of<W: Warning + ?Sized>() -> Self {
50        Self {
51            #[cfg(debug_assertions)]
52            type_id: std::any::TypeId::of::<W>,
53            #[cfg(debug_assertions)]
54            name: std::any::type_name::<W>,
55        }
56    }
57
58    #[allow(unreachable_code)]
59    pub fn enabled(&self) -> bool {
60        #[cfg(debug_assertions)]
61        return !ALLOW_STACK.with(|stack| {
62            let stack = stack.borrow();
63            tracing::trace!("Checking if warning {self:?} is enabled, stack: {stack:?}");
64            stack.iter().any(|w| (w.type_id)() == (self.type_id)())
65        });
66        false
67    }
68
69    pub fn if_enabled(&self, f: impl FnOnce()) {
70        if self.enabled() {
71            f();
72        }
73    }
74}
75
76#[cfg(debug_assertions)]
77thread_local! {
78    static ALLOW_STACK: std::cell::RefCell<Vec<WarningId>> = const { std::cell::RefCell::new(Vec::new()) };
79}
80
81pub struct Allow {
82    _private: (),
83}
84
85impl Allow {
86    #[allow(unused)]
87    pub fn new(warning: WarningId) -> Self {
88        #[cfg(debug_assertions)]
89        ALLOW_STACK.with(|stack| {
90            stack.borrow_mut().push(warning);
91        });
92        Self { _private: () }
93    }
94}
95
96impl Drop for Allow {
97    fn drop(&mut self) {
98        #[cfg(debug_assertions)]
99        ALLOW_STACK.with(|stack| {
100            stack.borrow_mut().pop();
101        });
102    }
103}
104
/// Suppressing a warning through an `Allow` guard disables it only while the guard lives.
///
/// Gated on `debug_assertions`: without them `WarningId::enabled` always returns `false`,
/// so the final assertion would fail (e.g. under `cargo test --release`).
#[test]
#[cfg(debug_assertions)]
fn warning_guard() {
    struct Lint {}

    impl Warning for Lint {
        const ID: WarningId = WarningId::of::<Lint>();
    }

    let warning = WarningId::of::<Lint>();
    {
        let _guard = Allow::new(warning);
        assert!(!warning.enabled());
    }
    assert!(warning.enabled());
}
120
121#[pin_project]
122pub struct AllowFuture<F> {
123    #[pin]
124    future: F,
125    #[allow(unused)]
126    warning: WarningId,
127}
128
129impl<F> AllowFuture<F> {
130    pub fn new(future: F, warning: WarningId) -> Self {
131        Self { future, warning }
132    }
133}
134
135impl<F: Future> Future for AllowFuture<F> {
136    type Output = F::Output;
137
138    fn poll(
139        self: std::pin::Pin<&mut Self>,
140        cx: &mut std::task::Context<'_>,
141    ) -> std::task::Poll<Self::Output> {
142        #[cfg(debug_assertions)]
143        let _guard = Allow::new(self.warning);
144        let this = self.project();
145        this.future.poll(cx)
146    }
147}
148
149pub fn allow<W: Warning + ?Sized, O>(item: impl FnOnce() -> O) -> O {
150    #[cfg(debug_assertions)]
151    let _gaurd = Allow::new(W::ID);
152    item()
153}
154
155pub trait AllowFutureExt: Future {
156    /// Allow a lint while a future is running
157    fn allow<W: Warning + ?Sized>(self) -> AllowFuture<Self>
158    where
159        Self: Sized,
160    {
161        AllowFuture::new(self, W::ID)
162    }
163}
164
165impl<F: Future> AllowFutureExt for F {}
166
/// `AllowFuture` keeps the warning suppressed on every poll of the inner future and
/// re-enables it once the wrapper has completed.
///
/// Gated on `debug_assertions`: without them `WarningId::enabled` always returns `false`,
/// so the final assertion would fail (e.g. under `cargo test --release`).
#[cfg(all(test, debug_assertions))]
#[tokio::test]
async fn allow_future() {
    struct Lint {}

    impl Warning for Lint {
        const ID: WarningId = WarningId::of::<Lint>();
    }

    let warning = WarningId::of::<Lint>();
    // Returns `Pending` (and wakes itself) for several polls before completing, so the
    // "disabled" check runs on more than one poll of the wrapper.
    let assert_warning_disabled = async {
        let mut poll_count = 0;
        std::future::poll_fn(|cx| {
            assert!(!warning.enabled());
            match poll_count {
                ..=5 => {
                    poll_count += 1;
                    cx.waker().wake_by_ref();
                    std::task::Poll::Pending
                }
                6.. => std::task::Poll::Ready(()),
            }
        })
        .await;
    };
    let future = AllowFuture::new(assert_warning_disabled, warning);
    future.await;
    assert!(warning.enabled());
}