Skip to main content

demand/
spinner.rs

1use std::{
2    io::{self, Write},
3    marker::PhantomData,
4    sync::{
5        LazyLock,
6        mpsc::{self, Sender, TryRecvError},
7    },
8    thread::sleep,
9    time::Duration,
10};
11
12use console::Term;
13use termcolor::{Buffer, WriteColor};
14
15use crate::{Theme, ctrlc, theme};
16
17/// tell a prompt to do something while running
18/// currently its only useful for spinner
19/// but that could change
20pub enum SpinnerAction {
21    /// change the theme
22    Theme(&'static Theme),
23    /// change the style
24    Style(&'static SpinnerStyle),
25    /// change the title
26    Title(String),
27}
28
29// SAFETY: ensure that 'spinner lives longer than any use of style or theme by spinner
30pub struct SpinnerActionRunner<'spinner> {
31    sender: Sender<SpinnerAction>,
32    r: PhantomData<&'spinner ()>, // need to use 'spinner to have it on the struct
33}
34
35impl<'spinner> SpinnerActionRunner<'spinner> {
36    fn new(sender: Sender<SpinnerAction>) -> Self {
37        Self {
38            sender,
39            r: PhantomData,
40        }
41    }
42
43    /// set the spinner theme
44    /// will not compile if ref to theme doesn't outlast spinner
45    pub fn theme(
46        &mut self, // with just this the compiler assumes that theme might be stored in self so it wont let u mutate it after this fn call
47        theme: &'spinner Theme,
48    ) -> Result<(), std::sync::mpsc::SendError<SpinnerAction>> {
49        let theme = unsafe { std::mem::transmute::<&Theme, &Theme>(theme) };
50        self.sender.send(SpinnerAction::Theme(theme))
51    }
52
53    /// set the spinner style
54    /// will not compile if ref to style doesn't outlast spinner
55    pub fn style(
56        &mut self, // with just this the compiler assumes that theme might be stored in self so it wont let u mutate it after this fn call
57        style: &'spinner SpinnerStyle,
58    ) -> Result<(), std::sync::mpsc::SendError<SpinnerAction>> {
59        let style = unsafe { std::mem::transmute::<&SpinnerStyle, &SpinnerStyle>(style) };
60        self.sender.send(SpinnerAction::Style(style))
61    }
62
63    /// set the spinner title
64    pub fn title<S: Into<String>>(
65        &self,
66        title: S,
67    ) -> Result<(), std::sync::mpsc::SendError<SpinnerAction>> {
68        self.sender.send(SpinnerAction::Title(title.into()))
69    }
70}
71
72/// Show a spinner
73///
74/// # Example
75/// ```rust
76/// use demand::{Spinner,SpinnerStyle};
77/// use std::time::Duration;
78/// use std::thread::sleep;
79///
80/// let spinner = Spinner::new("Loading data...")
81///   .style(&SpinnerStyle::line())
82///   .run(|_| {
83///        sleep(Duration::from_secs(2));
84///    })
85///   .expect("error running spinner");
86/// ```
87pub struct Spinner<'a> {
88    // The title of the spinner
89    pub title: String,
90    // The style of the spinner
91    pub style: &'a SpinnerStyle,
92    /// The colors/style of the spinner
93    pub theme: &'a Theme,
94
95    term: Term,
96    frame: usize,
97    height: usize,
98}
99
100impl<'a> Spinner<'a> {
101    /// Create a new spinner with the given title
102    pub fn new<S: Into<String>>(title: S) -> Self {
103        Self {
104            title: title.into(),
105            style: &DEFAULT,
106            theme: &theme::DEFAULT,
107            term: Term::stderr(),
108            frame: 0,
109            height: 0,
110        }
111    }
112
113    /// Set the style of the spinner
114    pub fn style(mut self, style: &'a SpinnerStyle) -> Self {
115        self.style = style;
116        self
117    }
118
119    /// Set the theme of the dialog
120    pub fn theme(mut self, theme: &'a Theme) -> Self {
121        self.theme = theme;
122        self
123    }
124
125    /// Displays the dialog to the user and returns their response
126    // SAFETY: 'spinner must out live 'scope
127    // this ensures that as long as the spinner doesnt try to access the theme
128    // or style outside of the scope closure the theme and style will still be valid
129    pub fn run<'scope, 'spinner: 'scope, F, T>(mut self, func: F) -> io::Result<T>
130    where
131        F: FnOnce(&mut SpinnerActionRunner<'spinner>) -> T + Send + 'scope,
132        T: Send + 'scope,
133    {
134        let t = self.term.clone();
135        let _ctrlc_handle = ctrlc::set_ctrlc_handler(move || {
136            t.show_cursor().unwrap();
137            std::process::exit(130);
138        })?;
139
140        std::thread::scope(|s| {
141            let (sender, receiver) = mpsc::channel();
142            let handle = s.spawn(move || {
143                // so you can just |s| instead of |mut s|
144                let mut sender = SpinnerActionRunner::new(sender);
145                func(&mut sender)
146            });
147            self.term.hide_cursor()?;
148            loop {
149                match receiver.try_recv() {
150                    Ok(a) => match a {
151                        SpinnerAction::Title(title) => self.title = title,
152                        SpinnerAction::Style(s) => self.style = s,
153                        SpinnerAction::Theme(theme) => self.theme = theme,
154                    },
155                    Err(TryRecvError::Empty) => (),
156                    Err(TryRecvError::Disconnected) => {
157                        self.clear()?;
158                        self.term.show_cursor()?;
159                        break;
160                    }
161                }
162                self.clear()?;
163                let output = self.render()?;
164                self.height = output.lines().count() - 1;
165                self.term.write_all(output.as_bytes())?;
166                sleep(self.style.fps);
167                if handle.is_finished() {
168                    self.clear()?;
169                    self.term.show_cursor()?;
170                    break;
171                }
172            }
173            handle
174                .join()
175                .map_err(|e| io::Error::other(format!("thread panicked: {e:?}")))
176        })
177    }
178
179    /// Render the spinner and return the output
180    fn render(&mut self) -> io::Result<String> {
181        let mut out = Buffer::ansi();
182
183        if self.frame > self.style.frames.len() - 1 {
184            self.frame = 0
185        }
186
187        out.set_color(&self.theme.input_prompt)?;
188        write!(out, "{} ", self.style.frames[self.frame])?;
189        out.reset()?;
190
191        write!(out, "{}", self.title)?;
192
193        self.frame += 1;
194
195        Ok(std::str::from_utf8(out.as_slice()).unwrap().to_string())
196    }
197
198    fn clear(&mut self) -> io::Result<()> {
199        if self.height == 0 {
200            self.term.clear_line()?;
201        } else {
202            self.term.clear_last_lines(self.height)?;
203        }
204        self.height = 0;
205        Ok(())
206    }
207}
208
209pub(crate) static DEFAULT: LazyLock<SpinnerStyle> = LazyLock::new(SpinnerStyle::line);
210
211/// The style of the spinner
212///
213/// # Example
214/// ```rust
215/// use demand::SpinnerStyle;
216/// use std::time::Duration;
217///
218/// let dots_style = SpinnerStyle::dots();
219/// let custom_style = SpinnerStyle {
220///   frames: vec!["  ", ". ", "..", "..."],
221///   fps: Duration::from_millis(1000 / 10),
222/// };
223/// ```
224pub struct SpinnerStyle {
225    /// The characters to use as frames for the spinner
226    pub frames: Vec<&'static str>,
227    /// The frames per second of the spinner
228    /// Usually represented as a fraction of a second in milliseconds for example `Duration::from_millis(1000/10)`
229    /// which would be 10 frames per second
230    pub fps: Duration,
231}
232
233impl SpinnerStyle {
234    // Create a new spinner type of dots
235    pub fn dots() -> Self {
236        Self {
237            frames: vec!["⣾", "⣽", "⣻", "⢿", "⡿", "⣟", "⣯", "⣷"],
238            fps: Duration::from_millis(1000 / 10),
239        }
240    }
241    // Create a new spinner type of jump
242    pub fn jump() -> Self {
243        Self {
244            frames: vec!["⢄", "⢂", "⢁", "⡁", "⡈", "⡐", "⡠"],
245            fps: Duration::from_millis(1000 / 10),
246        }
247    }
248    // Create a new spinner type of line
249    pub fn line() -> Self {
250        Self {
251            frames: vec!["-", "\\", "|", "/"],
252            fps: Duration::from_millis(1000 / 10),
253        }
254    }
255    // Create a new spinner type of points
256    pub fn points() -> Self {
257        Self {
258            frames: vec!["∙∙∙", "●∙∙", "∙●∙", "∙∙●"],
259            fps: Duration::from_millis(1000 / 7),
260        }
261    }
262    // Create a new spinner type of meter
263    pub fn meter() -> Self {
264        Self {
265            frames: vec!["▱▱▱", "▰▱▱", "▰▰▱", "▰▰▰", "▰▰▱", "▰▱▱", "▱▱▱"],
266            fps: Duration::from_millis(1000 / 7),
267        }
268    }
269    // Create a new spinner type of mini dots
270    pub fn minidots() -> Self {
271        Self {
272            frames: vec!["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"],
273            fps: Duration::from_millis(1000 / 12),
274        }
275    }
276    // Create a new spinner type of ellipsis
277    pub fn ellipsis() -> Self {
278        Self {
279            frames: vec!["   ", ".  ", ".. ", "..."],
280            fps: Duration::from_millis(1000 / 3),
281        }
282    }
283}
284
#[cfg(test)]
mod test {
    use crate::test::without_ansi;

    use super::*;

    /// Every built-in style renders each frame followed by the title, and the
    /// frame index wraps back to the first frame after the last one.
    #[test]
    fn test_render() {
        let styles = [
            SpinnerStyle::dots(),
            SpinnerStyle::jump(),
            SpinnerStyle::line(),
            SpinnerStyle::points(),
            SpinnerStyle::meter(),
            SpinnerStyle::minidots(),
            SpinnerStyle::ellipsis(),
        ];
        for style in &styles {
            let mut spinner = Spinner::new("Loading data...").style(style);
            // Second pass checks the wrap-around back to frame 0.
            for _ in 0..2 {
                for frame in &style.frames {
                    assert_eq!(
                        format!("{frame} Loading data..."),
                        without_ansi(spinner.render().unwrap().as_str())
                    );
                }
            }
        }
    }

    /// The closure may borrow (and mutate) locals, and its result is returned.
    #[test]
    fn scope_test() {
        let spinner = Spinner::new("Scoped");
        let mut a = [1, 2, 3];
        let mut i = 0;
        let out = spinner
            .run(|_| {
                for n in &mut a {
                    if i == 1 {
                        *n = 5;
                    }
                    i += 1;
                    std::thread::sleep(Duration::from_millis(*n));
                }
                i * 5
            })
            .unwrap();
        assert_eq!(a, [1, 5, 3]);
        assert_eq!(out, 15);
    }
}