Skip to main content

llama_cpp_bindings/
timing.rs

//! Safe wrapper around `llama_perf_context_data`, the timing data exposed by the C API.
2use std::fmt::{Debug, Display, Formatter};
3
4/// A wrapper around `llama_timings`.
5#[derive(Clone, Copy, Debug)]
6pub struct LlamaTimings {
7    /// The underlying `llama_perf_context_data` from the C API.
8    pub timings: llama_cpp_bindings_sys::llama_perf_context_data,
9}
10
11impl LlamaTimings {
12    /// Create a new `LlamaTimings`.
13    /// ```
14    /// # use llama_cpp_bindings::timing::LlamaTimings;
15    /// let timings = LlamaTimings::new(1.0, 2.0, 3.0, 4.0, 5, 6, 1);
16    /// let timings_str = "load time = 2.00 ms
17    /// prompt eval time = 3.00 ms / 5 tokens (0.60 ms per token, 1666.67 tokens per second)
18    /// eval time = 4.00 ms / 6 runs (0.67 ms per token, 1500.00 tokens per second)\n";
19    /// assert_eq!(timings_str, format!("{}", timings));
20    /// ```
21    #[must_use]
22    pub const fn new(
23        t_start_ms: f64,
24        t_load_ms: f64,
25        t_p_eval_ms: f64,
26        t_eval_ms: f64,
27        n_p_eval: i32,
28        n_eval: i32,
29        n_reused: i32,
30    ) -> Self {
31        Self {
32            timings: llama_cpp_bindings_sys::llama_perf_context_data {
33                t_start_ms,
34                t_load_ms,
35                t_p_eval_ms,
36                t_eval_ms,
37                n_p_eval,
38                n_eval,
39                n_reused,
40            },
41        }
42    }
43
44    /// Get the start time in milliseconds.
45    #[must_use]
46    pub const fn t_start_ms(&self) -> f64 {
47        self.timings.t_start_ms
48    }
49
50    /// Get the load time in milliseconds.
51    #[must_use]
52    pub const fn t_load_ms(&self) -> f64 {
53        self.timings.t_load_ms
54    }
55
56    /// Get the prompt evaluation time in milliseconds.
57    #[must_use]
58    pub const fn t_p_eval_ms(&self) -> f64 {
59        self.timings.t_p_eval_ms
60    }
61
62    /// Get the evaluation time in milliseconds.
63    #[must_use]
64    pub const fn t_eval_ms(&self) -> f64 {
65        self.timings.t_eval_ms
66    }
67
68    /// Get the number of prompt evaluations.
69    #[must_use]
70    pub const fn n_p_eval(&self) -> i32 {
71        self.timings.n_p_eval
72    }
73
74    /// Get the number of evaluations.
75    #[must_use]
76    pub const fn n_eval(&self) -> i32 {
77        self.timings.n_eval
78    }
79
80    /// Set the start time in milliseconds.
81    pub const fn set_t_start_ms(&mut self, t_start_ms: f64) {
82        self.timings.t_start_ms = t_start_ms;
83    }
84
85    /// Set the load time in milliseconds.
86    pub const fn set_t_load_ms(&mut self, t_load_ms: f64) {
87        self.timings.t_load_ms = t_load_ms;
88    }
89
90    /// Set the prompt evaluation time in milliseconds.
91    pub const fn set_t_p_eval_ms(&mut self, t_p_eval_ms: f64) {
92        self.timings.t_p_eval_ms = t_p_eval_ms;
93    }
94
95    /// Set the evaluation time in milliseconds.
96    pub const fn set_t_eval_ms(&mut self, t_eval_ms: f64) {
97        self.timings.t_eval_ms = t_eval_ms;
98    }
99
100    /// Set the number of prompt evaluations.
101    pub const fn set_n_p_eval(&mut self, n_p_eval: i32) {
102        self.timings.n_p_eval = n_p_eval;
103    }
104
105    /// Set the number of evaluations.
106    pub const fn set_n_eval(&mut self, n_eval: i32) {
107        self.timings.n_eval = n_eval;
108    }
109}
110
111fn write_timings(timings: &LlamaTimings, writer: &mut dyn std::fmt::Write) -> std::fmt::Result {
112    writeln!(writer, "load time = {:.2} ms", timings.t_load_ms())?;
113
114    if timings.n_p_eval() > 0 {
115        writeln!(
116            writer,
117            "prompt eval time = {:.2} ms / {} tokens ({:.2} ms per token, {:.2} tokens per second)",
118            timings.t_p_eval_ms(),
119            timings.n_p_eval(),
120            timings.t_p_eval_ms() / f64::from(timings.n_p_eval()),
121            1e3 / timings.t_p_eval_ms() * f64::from(timings.n_p_eval())
122        )?;
123    } else {
124        writeln!(
125            writer,
126            "prompt eval time = {:.2} ms / 0 tokens",
127            timings.t_p_eval_ms(),
128        )?;
129    }
130
131    if timings.n_eval() > 0 {
132        writeln!(
133            writer,
134            "eval time = {:.2} ms / {} runs ({:.2} ms per token, {:.2} tokens per second)",
135            timings.t_eval_ms(),
136            timings.n_eval(),
137            timings.t_eval_ms() / f64::from(timings.n_eval()),
138            1e3 / timings.t_eval_ms() * f64::from(timings.n_eval())
139        )?;
140    } else {
141        writeln!(writer, "eval time = {:.2} ms / 0 runs", timings.t_eval_ms(),)?;
142    }
143
144    Ok(())
145}
146
147impl Display for LlamaTimings {
148    fn fmt(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result {
149        write_timings(self, formatter)
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::LlamaTimings;

    /// Writer that accepts `remaining` calls to `write_str` and fails on every
    /// call after that budget is used up.
    struct FailAfterNWrites {
        remaining: usize,
    }

    impl std::fmt::Write for FailAfterNWrites {
        fn write_str(&mut self, _text: &str) -> std::fmt::Result {
            if self.remaining == 0 {
                return Err(std::fmt::Error);
            }
            self.remaining -= 1;

            Ok(())
        }
    }

    /// Runs `write_timings` against a writer limited to `budget` successful writes.
    fn write_with_budget(timings: &LlamaTimings, budget: usize) -> std::fmt::Result {
        super::write_timings(timings, &mut FailAfterNWrites { remaining: budget })
    }

    #[test]
    fn display_format_with_valid_counts() {
        let timings = LlamaTimings::new(1.0, 2.0, 3.0, 4.0, 5, 6, 1);
        let output = format!("{timings}");

        assert!(output.contains("load time = 2.00 ms"));
        assert!(output.contains("prompt eval time = 3.00 ms / 5 tokens"));
        assert!(output.contains("eval time = 4.00 ms / 6 runs"));
        // Pin the complete report, not just fragments of it.
        assert_eq!(
            output,
            "load time = 2.00 ms\n\
             prompt eval time = 3.00 ms / 5 tokens (0.60 ms per token, 1666.67 tokens per second)\n\
             eval time = 4.00 ms / 6 runs (0.67 ms per token, 1500.00 tokens per second)\n"
        );
    }

    #[test]
    fn display_format_handles_zero_eval_counts() {
        let timings = LlamaTimings::new(0.0, 1.0, 2.0, 3.0, 0, 0, 0);
        let output = format!("{timings}");

        assert!(output.contains("load time = 1.00 ms"));
        assert!(output.contains("prompt eval time = 2.00 ms / 0 tokens"));
        assert!(output.contains("eval time = 3.00 ms / 0 runs"));
        assert!(!output.contains("NaN"));
        assert!(!output.contains("inf"));
    }

    #[test]
    fn set_t_start_ms() {
        let mut timings = LlamaTimings::new(0.0, 0.0, 0.0, 0.0, 0, 0, 0);

        timings.set_t_start_ms(42.0);

        assert!((timings.t_start_ms() - 42.0).abs() < f64::EPSILON);
    }

    #[test]
    fn set_t_load_ms() {
        let mut timings = LlamaTimings::new(0.0, 0.0, 0.0, 0.0, 0, 0, 0);

        timings.set_t_load_ms(10.5);

        assert!((timings.t_load_ms() - 10.5).abs() < f64::EPSILON);
    }

    #[test]
    fn set_t_p_eval_ms() {
        let mut timings = LlamaTimings::new(0.0, 0.0, 0.0, 0.0, 0, 0, 0);

        timings.set_t_p_eval_ms(7.7);

        assert!((timings.t_p_eval_ms() - 7.7).abs() < f64::EPSILON);
    }

    #[test]
    fn set_t_eval_ms() {
        let mut timings = LlamaTimings::new(0.0, 0.0, 0.0, 0.0, 0, 0, 0);

        timings.set_t_eval_ms(3.3);

        assert!((timings.t_eval_ms() - 3.3).abs() < f64::EPSILON);
    }

    #[test]
    fn set_n_p_eval() {
        let mut timings = LlamaTimings::new(0.0, 0.0, 0.0, 0.0, 0, 0, 0);

        timings.set_n_p_eval(100);

        assert_eq!(timings.n_p_eval(), 100);
    }

    #[test]
    fn set_n_eval() {
        let mut timings = LlamaTimings::new(0.0, 0.0, 0.0, 0.0, 0, 0, 0);

        timings.set_n_eval(200);

        assert_eq!(timings.n_eval(), 200);
    }

    #[test]
    fn write_timings_propagates_writer_errors() {
        let timings = LlamaTimings::new(1.0, 2.0, 3.0, 4.0, 5, 6, 1);

        // A zero budget makes the very first write fail.
        assert!(write_with_budget(&timings, 0).is_err());
    }

    #[test]
    fn write_timings_zero_p_eval_with_failing_writer() {
        let timings = LlamaTimings::new(1.0, 2.0, 3.0, 4.0, 0, 6, 1);

        assert!(write_with_budget(&timings, 1).is_err());
    }

    #[test]
    fn write_timings_fails_at_each_writeln_boundary() {
        let with_counts = LlamaTimings::new(1.0, 2.0, 3.0, 4.0, 5, 6, 1);
        let zero_counts = LlamaTimings::new(1.0, 2.0, 3.0, 4.0, 0, 0, 1);

        for timings in &[with_counts, zero_counts] {
            assert!(write_with_budget(timings, 0).is_err());
            assert!(write_with_budget(timings, usize::MAX).is_ok());

            // The exact number of `write_str` calls is a formatting detail, so
            // only check the shape: failures come first, and once a budget is
            // large enough every larger budget must succeed as well.
            let mut succeeded_before = false;
            for writes_allowed in 0..64 {
                let succeeded = write_with_budget(timings, writes_allowed).is_ok();
                assert!(
                    succeeded || !succeeded_before,
                    "write failed with budget {writes_allowed} after succeeding with a smaller one"
                );
                succeeded_before = succeeded;
            }
        }
    }
}