Skip to main content

diskann_benchmark_runner/utils/
fmt.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::{
7    collections::HashMap,
8    fmt::{Display, Write},
9};
10
/// A 2-d table for laying out [`Display`] values in right-aligned columns.
///
/// A table is created with a fixed set of column headers and a fixed number of rows. Cells are
/// filled individually with [`Table::insert`] (or through a [`Row`] handle from [`Table::row`]).
/// Cells that are never filled are rendered as blank padding.
///
/// When displayed, every column is right-aligned to the width of its widest entry (header
/// included). Columns are separated by `",   "`, and a line of `=` separates the header from the
/// body.
pub struct Table {
    // The number of columns is implicitly described by the number of entries in `header`.
    header: Box<[Box<dyn Display>]>,
    // Sparse storage of the populated cells, keyed by `(row, col)`.
    body: HashMap<(usize, usize), Box<dyn Display>>,
    nrows: usize,
}

impl Table {
    /// Create an empty table whose columns are labelled by `header` and which has `nrows` rows.
    ///
    /// The number of columns is the number of items yielded by `header`.
    pub fn new<I>(header: I, nrows: usize) -> Self
    where
        I: IntoIterator<Item: Display + 'static>,
    {
        fn as_dyn_display<T: Display + 'static>(x: T) -> Box<dyn Display> {
            Box::new(x)
        }

        let header: Box<[_]> = header.into_iter().map(as_dyn_display).collect();
        Self {
            header,
            body: HashMap::new(),
            nrows,
        }
    }

    /// Return the number of rows in the table.
    pub fn nrows(&self) -> usize {
        self.nrows
    }

    /// Return the number of columns, i.e. the number of headers.
    pub fn ncols(&self) -> usize {
        self.header.len()
    }

    /// Store `item` in the cell at (`row`, `col`), replacing any previous value.
    ///
    /// Returns `true` if the cell already held a value and `false` otherwise.
    ///
    /// # Panics
    ///
    /// Panics if `row >= self.nrows()` or `col >= self.ncols()`.
    pub fn insert<T>(&mut self, item: T, row: usize, col: usize) -> bool
    where
        T: Display + 'static,
    {
        self.check_bounds(row, col);
        self.body.insert((row, col), Box::new(item)).is_some()
    }

    /// Return the value stored at (`row`, `col`), or `None` if the cell is empty.
    ///
    /// # Panics
    ///
    /// Panics if `row >= self.nrows()` or `col >= self.ncols()`.
    pub fn get(&self, row: usize, col: usize) -> Option<&dyn Display> {
        self.check_bounds(row, col);
        self.body.get(&(row, col)).map(|x| &**x)
    }

    /// Return a handle for inserting values into row `row`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= self.nrows()`. Column indices are validated later, when inserting
    /// through the returned [`Row`].
    pub fn row(&mut self, row: usize) -> Row<'_> {
        // Only the row is validated here: checking column 0 as well would make every call
        // panic on a table with no columns, even for an in-bounds row.
        self.check_row(row);
        Row::new(self, row)
    }

    fn check_bounds(&self, row: usize, col: usize) {
        self.check_row(row);
        self.check_col(col);
    }

    // NOTE: the reported "max" is the exclusive bound (the row/column count). The message text
    // is kept as-is because callers/tests match on it.
    fn check_row(&self, row: usize) {
        if row >= self.nrows() {
            panic!("row {} is out of bounds (max {})", row, self.nrows());
        }
    }

    fn check_col(&self, col: usize) {
        if col >= self.ncols() {
            panic!("col {} is out of bounds (max {})", col, self.ncols());
        }
    }
}

/// A mutable handle to a single row of a [`Table`], obtained from [`Table::row`].
pub struct Row<'a> {
    table: &'a mut Table,
    row: usize,
}

impl<'a> Row<'a> {
    // A **private** constructor assuming that `row` is inbounds.
    fn new(table: &'a mut Table, row: usize) -> Self {
        Self { table, row }
    }

    /// Insert a value into the specified column of this row.
    ///
    /// Returns `true` if the cell already held a value and `false` otherwise.
    ///
    /// # Panics
    ///
    /// Panics if `col` is not less than the table's column count.
    pub fn insert<T>(&mut self, item: T, col: usize) -> bool
    where
        T: Display + 'static,
    {
        self.table.insert(item, self.row, col)
    }
}

impl Display for Table {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Column separator. It is pure ASCII, so its byte length equals its character count.
        const SEP: &str = ",   ";

        // A `Write` sink that only counts the characters written to it.
        //
        // Characters (not bytes) are counted because the `{:>width$}` padding used below is
        // measured in `char`s. Counting bytes would over-pad columns containing non-ASCII text
        // such as "µs" and break the alignment.
        struct Count(usize);

        impl Write for Count {
            fn write_str(&mut self, s: &str) -> std::fmt::Result {
                self.0 += s.chars().count();
                Ok(())
            }
        }

        // Number of characters `x` renders to. Returns 0 if its `Display` impl fails; the
        // same failure will surface again when the value is actually written below.
        fn formatted_size<T>(x: &T) -> usize
        where
            T: Display + ?Sized,
        {
            let mut buf = Count(0);
            match write!(&mut buf, "{}", x) {
                Ok(()) => buf.0,
                Err(_) => 0,
            }
        }

        let ncols = self.ncols();

        // Compute the maximum width of each column (header included).
        let mut widths: Vec<usize> = self.header.iter().map(formatted_size).collect();
        for row in 0..self.nrows() {
            for (col, width) in widths.iter_mut().enumerate() {
                if let Some(v) = self.body.get(&(row, col)) {
                    *width = (*width).max(formatted_size(v));
                }
            }
        }

        // Total width of one rendered line. `saturating_sub` keeps a table with no columns
        // from underflowing.
        let header_width: usize =
            widths.iter().sum::<usize>() + widths.len().saturating_sub(1) * SEP.len();

        // Each value is rendered into `buf` before padding. Arbitrary `Display` impls are
        // free to ignore width/alignment flags, but a `String` always honours them.
        let mut buf = String::new();

        // Print the header.
        for (col, (&width, head)) in widths.iter().zip(self.header.iter()).enumerate() {
            buf.clear();
            write!(buf, "{}", head)?;
            write!(f, "{:>width$}", buf)?;
            if col + 1 != ncols {
                f.write_str(SEP)?;
            }
        }

        // Banner
        write!(f, "\n{:=>header_width$}\n", "")?;

        // Write out each row; missing cells become blank padding.
        for row in 0..self.nrows() {
            for (col, &width) in widths.iter().enumerate() {
                match self.body.get(&(row, col)) {
                    Some(v) => {
                        buf.clear();
                        write!(buf, "{}", v)?;
                        write!(f, "{:>width$}", buf)?;
                    }
                    None => write!(f, "{:>width$}", "")?,
                }
                if col + 1 != ncols {
                    f.write_str(SEP)?;
                } else {
                    writeln!(f)?;
                }
            }
        }
        Ok(())
    }
}
167
168////////////
169// Banner //
170////////////
171
/// A three-line banner that frames a message with `#` characters.
///
/// For example, the message `foo` renders as:
///
/// ```text
/// #######
/// # foo #
/// #######
/// ```
pub(crate) struct Banner<'a>(&'a str);

impl<'a> Banner<'a> {
    /// Create a banner that displays `message`.
    pub(crate) fn new(message: &'a str) -> Self {
        Self(message)
    }
}

impl std::fmt::Display for Banner<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let st = format!("# {} #", self.0);
        // Measure in `char`s, not bytes. The `{:#>len$}` fill below counts characters, so a
        // byte length would make the border longer than the text line for non-ASCII messages.
        let len = st.chars().count();
        writeln!(f, "{:#>len$}", "")?;
        writeln!(f, "{}", st)?;
        writeln!(f, "{:#>len$}", "")?;
        Ok(())
    }
}
190
191///////////
192// Tests //
193///////////
194
#[cfg(test)]
mod tests {
    use super::*;

    /// Render `table` and compare it with the expected output, given one entry per line.
    /// Every rendered line (including the last) ends in a newline.
    fn assert_lines(table: &Table, lines: &[&str]) {
        let mut expected = lines.join("\n");
        expected.push('\n');
        assert_eq!(table.to_string(), expected);
    }

    #[test]
    fn test_banner() {
        let cases = [
            (
                "hello world",
                "###############\n# hello world #\n###############\n",
            ),
            ("", "####\n#  #\n####\n"),
            ("foo", "#######\n# foo #\n#######\n"),
        ];

        for (message, expected) in cases {
            let rendered = Banner::new(message).to_string();
            assert_eq!(rendered, expected);
        }
    }

    #[test]
    fn test_format() {
        // A single column, sized by its widest body entry.
        let mut single = Table::new(["h 0"], 3);
        single.insert("a", 0, 0);
        single.insert("hello world", 1, 0);
        single.insert(62, 2, 0);
        assert_lines(
            &single,
            &[
                "        h 0",
                "===========",
                "          a",
                "hello world",
                "         62",
            ],
        );

        // Two columns: the first is sized by its header, the second by a body entry.
        let mut double = Table::new(["a really really long header", "h1"], 3);
        double.insert("a", 0, 0);
        double.insert("b", 0, 1);
        double.insert("hello world", 1, 0);
        double.insert("hello world version 2", 1, 1);
        double.insert(7, 2, 0);
        double.insert("bar", 2, 1);
        assert_lines(
            &double,
            &[
                "a really really long header,                      h1",
                "====================================================",
                "                          a,                       b",
                "                hello world,   hello world version 2",
                "                          7,                     bar",
            ],
        );
    }

    #[test]
    fn test_row_api() {
        let mut table = Table::new(["a", "b", "c"], 2);

        {
            let mut first = table.row(0);
            first.insert(1, 0);
            first.insert("long", 1);
            first.insert("s", 2);
        }

        {
            let mut second = table.row(1);
            second.insert("string", 0);
            second.insert(2, 1);
            second.insert(3, 2);
        }

        assert_lines(
            &table,
            &[
                "     a,      b,   c",
                "===================",
                "     1,   long,   s",
                "string,      2,   3",
            ],
        );
    }

    #[test]
    fn missing_values() {
        let mut table = Table::new(["a", "loong", "c"], 1);
        {
            let mut only = table.row(0);
            only.insert("string", 0);
            only.insert("string", 2);
        }

        assert_lines(
            &table,
            &[
                "     a,   loong,        c",
                "=========================",
                "string,        ,   string",
            ],
        );
    }

    #[test]
    #[should_panic(expected = "row 3 is out of bounds (max 2)")]
    fn test_panic_row() {
        let mut t = Table::new([1, 2, 3], 2);
        let _ = t.row(3);
    }

    #[test]
    #[should_panic(expected = "col 3 is out of bounds (max 2)")]
    fn test_panic_col() {
        let mut t = Table::new([1, 2], 1);
        t.row(0).insert(1, 3);
    }
}