1use std::collections::HashMap;
2use std::io;
3use std::io::Write;
4
/// A progress event emitted while pulling a model.
///
/// Consumed by [`PullProgressReporter::on_event`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PullEvent {
    /// A free-form status message (for example `"pulling manifest"`).
    Status(String),
    /// A progress update for one chunk, identified by `digest`.
    ///
    /// Either counter may be absent from a given update. The CLI reporter
    /// treats both as byte counts (it renders them in GB).
    /// NOTE(review): confirm the unit against the event producer.
    ChunkProgress {
        /// Identifier of the chunk this update refers to.
        digest: String,
        /// Total size of the chunk, if reported in this update.
        total: Option<u64>,
        /// Amount of the chunk received so far, if reported in this update.
        completed: Option<u64>,
    },
    /// Success event; the CLI reporter terminates its progress line on it.
    Success,
    /// Failure carrying a message; the CLI reporter prints nothing for it.
    Error(String),
}
22
23pub trait PullProgressReporter {
26 fn on_event(&mut self, event: &PullEvent) -> io::Result<()>;
27}
28
/// Renders pull progress to stderr as a single, continuously rewritten line.
pub struct CliProgressReporter {
    /// Per-digest `(total, completed)` counters collected from chunk updates.
    totals_by_digest: HashMap<String, (u64, u64)>,
    /// Sum of `completed` over all digests at the previous progress update;
    /// paired with `last_instant` to estimate download speed.
    last_completed_sum: u64,
    /// Time of the previous progress update (set initially at construction).
    last_instant: std::time::Instant,
    /// Length of the most recent in-place line, so a shorter successor can be
    /// padded with spaces to hide leftover characters.
    last_line_len: usize,
    /// Whether the one-time "Downloading model" header has been written.
    printed_header: bool,
}
37
38impl Default for CliProgressReporter {
39 fn default() -> Self {
40 Self::new()
41 }
42}
43
44impl CliProgressReporter {
45 pub fn new() -> Self {
46 Self {
47 printed_header: false,
48 last_line_len: 0,
49 last_completed_sum: 0,
50 last_instant: std::time::Instant::now(),
51 totals_by_digest: HashMap::new(),
52 }
53 }
54}
55
56impl PullProgressReporter for CliProgressReporter {
57 fn on_event(&mut self, event: &PullEvent) -> io::Result<()> {
58 let mut out = std::io::stderr();
59 match event {
60 PullEvent::Status(status) => {
61 if status.eq_ignore_ascii_case("pulling manifest") {
63 return Ok(());
64 }
65 let pad = self.last_line_len.saturating_sub(status.len());
66 let line = format!("\r{status}{}", " ".repeat(pad));
67 self.last_line_len = status.len();
68 out.write_all(line.as_bytes())?;
69 out.flush()
70 }
71 PullEvent::ChunkProgress {
72 digest,
73 total,
74 completed,
75 } => {
76 if let Some(t) = *total {
77 self.totals_by_digest
78 .entry(digest.clone())
79 .or_insert((0, 0))
80 .0 = t;
81 }
82 if let Some(c) = *completed {
83 self.totals_by_digest
84 .entry(digest.clone())
85 .or_insert((0, 0))
86 .1 = c;
87 }
88
89 let (sum_total, sum_completed) = self
90 .totals_by_digest
91 .values()
92 .fold((0u64, 0u64), |acc, (t, c)| (acc.0 + *t, acc.1 + *c));
93 if sum_total > 0 {
94 if !self.printed_header {
95 let gb = (sum_total as f64) / (1024.0 * 1024.0 * 1024.0);
96 let header = format!("Downloading model: total {gb:.2} GB\n");
97 out.write_all(b"\r\x1b[2K")?;
98 out.write_all(header.as_bytes())?;
99 self.printed_header = true;
100 }
101 let now = std::time::Instant::now();
102 let dt = now
103 .duration_since(self.last_instant)
104 .as_secs_f64()
105 .max(0.001);
106 let dbytes = sum_completed.saturating_sub(self.last_completed_sum) as f64;
107 let speed_mb_s = dbytes / (1024.0 * 1024.0) / dt;
108 self.last_completed_sum = sum_completed;
109 self.last_instant = now;
110
111 let done_gb = (sum_completed as f64) / (1024.0 * 1024.0 * 1024.0);
112 let total_gb = (sum_total as f64) / (1024.0 * 1024.0 * 1024.0);
113 let pct = (sum_completed as f64) * 100.0 / (sum_total as f64);
114 let text =
115 format!("{done_gb:.2}/{total_gb:.2} GB ({pct:.1}%) {speed_mb_s:.1} MB/s");
116 let pad = self.last_line_len.saturating_sub(text.len());
117 let line = format!("\r{text}{}", " ".repeat(pad));
118 self.last_line_len = text.len();
119 out.write_all(line.as_bytes())?;
120 out.flush()
121 } else {
122 Ok(())
123 }
124 }
125 PullEvent::Error(_) => {
126 Ok(())
129 }
130 PullEvent::Success => {
131 out.write_all(b"\n")?;
132 out.flush()
133 }
134 }
135 }
136}
137
138#[derive(Default)]
141pub struct TuiProgressReporter(CliProgressReporter);
142
143impl PullProgressReporter for TuiProgressReporter {
144 fn on_event(&mut self, event: &PullEvent) -> io::Result<()> {
145 self.0.on_event(event)
146 }
147}