dprint_cli_core/logging/progress_bars.rs

1use crossterm::style::Stylize;
2use crossterm::tty::IsTty;
3use parking_lot::RwLock;
4use std::sync::Arc;
5use std::time::Duration;
6use std::time::SystemTime;
7
8use crate::logging::Logger;
9use crate::logging::LoggerRefreshItemKind;
10use crate::logging::LoggerTextItem;
11
12// Inspired by Indicatif, but this custom implementation allows for more control over
13// what's going on under the hood and it works better with the multi-threading model
14// going on in dprint.
15
/// Determines how a progress bar line is rendered.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ProgressBarStyle {
  /// Renders the elapsed time and the bar, followed by the current and total
  /// size formatted in KB or MB.
  Download,
  /// Renders only the elapsed time and the bar.
  Action,
}
21
22#[derive(Clone)]
23pub struct ProgressBar {
24  id: usize,
25  start_time: SystemTime,
26  progress_bars: ProgressBars,
27  message: String,
28  size: usize,
29  style: ProgressBarStyle,
30  pos: Arc<RwLock<usize>>,
31}
32
33impl ProgressBar {
34  pub fn set_position(&self, new_pos: usize) {
35    let mut pos = self.pos.write();
36    *pos = new_pos;
37  }
38
39  pub fn finish(&self) {
40    self.progress_bars.finish_progress(self.id);
41  }
42}
43
44#[derive(Clone)]
45pub struct ProgressBars {
46  logger: Logger,
47  state: Arc<RwLock<InternalState>>,
48}
49
50struct InternalState {
51  // this ensures only one draw thread is running
52  drawer_id: usize,
53  progress_bar_counter: usize,
54  progress_bars: Vec<ProgressBar>,
55}
56
57impl ProgressBars {
58  /// Checks if progress bars are supported
59  pub fn are_supported() -> bool {
60    std::io::stderr().is_tty() && crate::terminal::get_terminal_width().is_some()
61  }
62
63  /// Creates a new ProgressBars or returns None when not supported.
64  pub fn new(logger: &Logger) -> Option<Self> {
65    if ProgressBars::are_supported() {
66      Some(ProgressBars {
67        logger: logger.clone(),
68        state: Arc::new(RwLock::new(InternalState {
69          drawer_id: 0,
70          progress_bar_counter: 0,
71          progress_bars: Vec::new(),
72        })),
73      })
74    } else {
75      None
76    }
77  }
78
79  pub fn add_progress(&self, message: String, style: ProgressBarStyle, total_size: usize) -> ProgressBar {
80    let mut internal_state = self.state.write();
81    let id = internal_state.progress_bar_counter;
82    let pb = ProgressBar {
83      id,
84      progress_bars: self.clone(),
85      start_time: SystemTime::now(),
86      message,
87      size: total_size,
88      style,
89      pos: Arc::new(RwLock::new(0)),
90    };
91    internal_state.progress_bars.push(pb.clone());
92    internal_state.progress_bar_counter += 1;
93
94    if internal_state.progress_bars.len() == 1 {
95      self.start_draw_thread(&mut internal_state);
96    }
97
98    pb
99  }
100
101  fn finish_progress(&self, progress_bar_id: usize) {
102    let mut internal_state = self.state.write();
103
104    if let Some(index) = internal_state.progress_bars.iter().position(|p| p.id == progress_bar_id) {
105      internal_state.progress_bars.remove(index);
106    }
107
108    if internal_state.progress_bars.is_empty() {
109      self.logger.remove_refresh_item(LoggerRefreshItemKind::ProgressBars)
110    }
111  }
112
113  fn start_draw_thread(&self, internal_state: &mut InternalState) {
114    internal_state.drawer_id += 1;
115    let drawer_id = internal_state.drawer_id;
116    let internal_state = self.state.clone();
117    let logger = self.logger.clone();
118    std::thread::spawn(move || {
119      loop {
120        {
121          let internal_state = internal_state.read();
122          // exit if not the current draw thread or there are no more progress bars
123          if internal_state.drawer_id != drawer_id || internal_state.progress_bars.is_empty() {
124            break;
125          }
126
127          let terminal_width = crate::terminal::get_terminal_width().unwrap();
128          let mut text = String::new();
129          for (i, progress_bar) in internal_state.progress_bars.iter().enumerate() {
130            if i > 0 {
131              text.push('\n');
132            }
133            text.push_str(&progress_bar.message);
134            text.push('\n');
135            text.push_str(&get_progress_bar_text(
136              terminal_width,
137              *progress_bar.pos.read(),
138              progress_bar.size,
139              progress_bar.style,
140              progress_bar.start_time.elapsed().unwrap(),
141            ));
142          }
143
144          logger.set_refresh_item(LoggerRefreshItemKind::ProgressBars, vec![LoggerTextItem::Text(text)]);
145        }
146
147        std::thread::sleep(Duration::from_millis(100));
148      }
149    });
150  }
151}
152
153fn get_progress_bar_text(terminal_width: u16, pos: usize, total: usize, pb_style: ProgressBarStyle, duration: Duration) -> String {
154  let total = std::cmp::max(pos, total); // increase the total when pos > total
155  let bytes_text = if pb_style == ProgressBarStyle::Download {
156    format!(" {}/{}", get_bytes_text(pos, total), get_bytes_text(total, total))
157  } else {
158    String::new()
159  };
160
161  let elapsed_text = get_elapsed_text(duration);
162  let mut text = String::new();
163  text.push_str(&elapsed_text);
164  // get progress bar
165  let percent = pos as f32 / total as f32;
166  // don't include the bytes text in this because a string going from X.XXMB to XX.XXMB should not adjust the progress bar
167  let total_bars = (std::cmp::min(50, terminal_width - 15) as usize) - elapsed_text.len() - 1 - 2;
168  let completed_bars = (total_bars as f32 * percent).floor() as usize;
169  text.push_str(" [");
170  if completed_bars != total_bars {
171    if completed_bars > 0 {
172      text.push_str(&format!("{}", format!("{}{}", "#".repeat(completed_bars - 1), ">").cyan()))
173    }
174    text.push_str(&format!("{}", "-".repeat(total_bars - completed_bars).blue()))
175  } else {
176    text.push_str(&format!("{}", "#".repeat(completed_bars).cyan()))
177  }
178  text.push(']');
179
180  // bytes text
181  text.push_str(&bytes_text);
182
183  text
184}
185
/// Formats `byte_count` with two decimal places, e.g. `9.52MB`.
///
/// The unit comes from `total_bytes`: MB once the total reaches 1,000,000
/// bytes, otherwise KB. This keeps the current and total values of a download
/// in the same unit. Units are decimal (1 KB = 1,000 bytes), and the fraction
/// is truncated rather than rounded.
fn get_bytes_text(byte_count: usize, total_bytes: usize) -> String {
  const KILOBYTE: usize = 1_000;
  const MEGABYTE: usize = 1_000_000;

  let (divisor, unit) = if total_bytes >= MEGABYTE {
    (MEGABYTE, "MB")
  } else {
    (KILOBYTE, "KB")
  };

  let whole = byte_count / divisor;
  // first two digits of the remainder, truncated
  let hundredths = (byte_count % divisor) * 100 / divisor;
  format!("{}.{:02}{}", whole, hundredths, unit)
}
201
/// Formats an elapsed duration as `[HH:MM:SS]`, ignoring sub-second precision.
///
/// Hours are never wrapped, so 100 or more hours render with extra digits
/// (e.g. `[120:00:00]`).
fn get_elapsed_text(elapsed: Duration) -> String {
  const SECS_PER_MINUTE: u64 = 60;
  const SECS_PER_HOUR: u64 = 60 * SECS_PER_MINUTE;

  let total_secs = elapsed.as_secs();
  let hours = total_secs / SECS_PER_HOUR;
  let minutes = total_secs % SECS_PER_HOUR / SECS_PER_MINUTE;
  let seconds = total_secs % SECS_PER_MINUTE;
  format!("[{:02}:{:02}:{:02}]", hours, minutes, seconds)
}
209
#[cfg(test)]
mod test {
  use super::*;
  use std::time::Duration;

  #[test]
  fn should_get_bytes_text() {
    // (byte_count, total_bytes, expected text)
    let cases = [
      (9, 999, "0.00KB"),
      (10, 999, "0.01KB"),
      (100, 999, "0.10KB"),
      (200, 999, "0.20KB"),
      (520, 999, "0.52KB"),
      (1000, 10_000, "1.00KB"),
      (10_000, 10_000, "10.00KB"),
      (999_999, 990_999, "999.99KB"),
      (1_000_000, 1_000_000, "1.00MB"),
      (9_524_102, 10_000_000, "9.52MB"),
    ];
    for &(byte_count, total_bytes, expected) in cases.iter() {
      assert_eq!(get_bytes_text(byte_count, total_bytes), expected, "bytes: {}, total: {}", byte_count, total_bytes);
    }
  }

  #[test]
  fn should_get_elapsed_text() {
    // (elapsed seconds, expected text)
    let cases = [
      (1, "[00:00:01]"),
      (20, "[00:00:20]"),
      (59, "[00:00:59]"),
      (60, "[00:01:00]"),
      (60 * 5 + 23, "[00:05:23]"),
      (60 * 59 + 59, "[00:59:59]"),
      (60 * 60, "[01:00:00]"),
      (60 * 60 * 3 + 20 * 60 + 2, "[03:20:02]"),
      (60 * 60 * 99, "[99:00:00]"),
      (60 * 60 * 120, "[120:00:00]"),
    ];
    for &(secs, expected) in cases.iter() {
      assert_eq!(get_elapsed_text(Duration::from_secs(secs)), expected, "secs: {}", secs);
    }
  }
}