agcodex_ollama/
pull.rs

1use std::collections::HashMap;
2use std::io;
3use std::io::Write;
4
/// Events emitted while pulling a model from Ollama.
///
/// Events are delivered by reference to a `PullProgressReporter`, which decides
/// how (or whether) to render each one.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PullEvent {
    /// A human-readable status message (e.g., "verifying", "writing").
    Status(String),
    /// Byte-level progress update for a specific layer digest.
    ///
    /// `total` and `completed` are each optional, so a single update may carry
    /// only one of them.
    ChunkProgress {
        /// Digest identifying the layer this update refers to.
        digest: String,
        /// Total size of the layer in bytes, if reported in this update.
        total: Option<u64>,
        /// Bytes of the layer downloaded so far, if reported in this update.
        completed: Option<u64>,
    },
    /// The pull finished successfully.
    Success,
    /// Error event with a message.
    Error(String),
}
22
23/// A simple observer for pull progress events. Implementations decide how to
24/// render progress (CLI, TUI, logs, ...).
25pub trait PullProgressReporter {
26    fn on_event(&mut self, event: &PullEvent) -> io::Result<()>;
27}
28
/// A minimal CLI reporter that writes inline progress to stderr.
///
/// Status and progress text share a single terminal line that is rewritten in
/// place; a one-time header with the download size is printed before the first
/// progress line.
pub struct CliProgressReporter {
    /// Per-layer `(total, completed)` byte counts, keyed by layer digest.
    totals_by_digest: HashMap<String, (u64, u64)>,
    /// Sum of completed bytes recorded at the previous progress sample.
    last_completed_sum: u64,
    /// Time of the previous progress sample (construction time initially).
    last_instant: std::time::Instant,
    /// Byte length of the text currently on the inline line, used for padding.
    last_line_len: usize,
    /// Whether the "Downloading model" header has been printed yet.
    printed_header: bool,
}

impl Default for CliProgressReporter {
    fn default() -> Self {
        CliProgressReporter {
            totals_by_digest: HashMap::new(),
            last_completed_sum: 0,
            last_instant: std::time::Instant::now(),
            last_line_len: 0,
            printed_header: false,
        }
    }
}

impl CliProgressReporter {
    /// Creates a reporter with no header printed, no progress recorded, and its
    /// speed-measurement clock starting now.
    pub fn new() -> Self {
        Self::default()
    }
}
55
56impl PullProgressReporter for CliProgressReporter {
57    fn on_event(&mut self, event: &PullEvent) -> io::Result<()> {
58        let mut out = std::io::stderr();
59        match event {
60            PullEvent::Status(status) => {
61                // Avoid noisy manifest messages; otherwise show status inline.
62                if status.eq_ignore_ascii_case("pulling manifest") {
63                    return Ok(());
64                }
65                let pad = self.last_line_len.saturating_sub(status.len());
66                let line = format!("\r{status}{}", " ".repeat(pad));
67                self.last_line_len = status.len();
68                out.write_all(line.as_bytes())?;
69                out.flush()
70            }
71            PullEvent::ChunkProgress {
72                digest,
73                total,
74                completed,
75            } => {
76                if let Some(t) = *total {
77                    self.totals_by_digest
78                        .entry(digest.clone())
79                        .or_insert((0, 0))
80                        .0 = t;
81                }
82                if let Some(c) = *completed {
83                    self.totals_by_digest
84                        .entry(digest.clone())
85                        .or_insert((0, 0))
86                        .1 = c;
87                }
88
89                let (sum_total, sum_completed) = self
90                    .totals_by_digest
91                    .values()
92                    .fold((0u64, 0u64), |acc, (t, c)| (acc.0 + *t, acc.1 + *c));
93                if sum_total > 0 {
94                    if !self.printed_header {
95                        let gb = (sum_total as f64) / (1024.0 * 1024.0 * 1024.0);
96                        let header = format!("Downloading model: total {gb:.2} GB\n");
97                        out.write_all(b"\r\x1b[2K")?;
98                        out.write_all(header.as_bytes())?;
99                        self.printed_header = true;
100                    }
101                    let now = std::time::Instant::now();
102                    let dt = now
103                        .duration_since(self.last_instant)
104                        .as_secs_f64()
105                        .max(0.001);
106                    let dbytes = sum_completed.saturating_sub(self.last_completed_sum) as f64;
107                    let speed_mb_s = dbytes / (1024.0 * 1024.0) / dt;
108                    self.last_completed_sum = sum_completed;
109                    self.last_instant = now;
110
111                    let done_gb = (sum_completed as f64) / (1024.0 * 1024.0 * 1024.0);
112                    let total_gb = (sum_total as f64) / (1024.0 * 1024.0 * 1024.0);
113                    let pct = (sum_completed as f64) * 100.0 / (sum_total as f64);
114                    let text =
115                        format!("{done_gb:.2}/{total_gb:.2} GB ({pct:.1}%) {speed_mb_s:.1} MB/s");
116                    let pad = self.last_line_len.saturating_sub(text.len());
117                    let line = format!("\r{text}{}", " ".repeat(pad));
118                    self.last_line_len = text.len();
119                    out.write_all(line.as_bytes())?;
120                    out.flush()
121                } else {
122                    Ok(())
123                }
124            }
125            PullEvent::Error(_) => {
126                // This will be handled by the caller, so we don't do anything
127                // here or the error will be printed twice.
128                Ok(())
129            }
130            PullEvent::Success => {
131                out.write_all(b"\n")?;
132                out.flush()
133            }
134        }
135    }
136}
137
138/// For now the TUI reporter delegates to the CLI reporter. This keeps UI and
139/// CLI behavior aligned until a dedicated TUI integration is implemented.
140#[derive(Default)]
141pub struct TuiProgressReporter(CliProgressReporter);
142
143impl PullProgressReporter for TuiProgressReporter {
144    fn on_event(&mut self, event: &PullEvent) -> io::Result<()> {
145        self.0.on_event(event)
146    }
147}