Skip to main content

ferray_core/array/
display.rs

1// ferray-core: Display/Debug formatting with NumPy-style output (REQ-39)
2
3use std::fmt;
4use std::sync::atomic::{AtomicUsize, Ordering};
5
6use crate::dimension::Dimension;
7use crate::dtype::Element;
8
9use super::arc::ArcArray;
10use super::cow::CowArray;
11use super::owned::Array;
12use super::view::ArrayView;
13
14// ---------------------------------------------------------------------------
15// Global print options
16// ---------------------------------------------------------------------------
17
/// Requested number of decimal places for float elements (initially 8).
static PRINT_PRECISION: AtomicUsize = AtomicUsize::new(8);
/// Total element count above which array output is truncated (initially 1000).
static PRINT_THRESHOLD: AtomicUsize = AtomicUsize::new(1000);
/// Maximum characters per output line (initially 75).
static PRINT_LINEWIDTH: AtomicUsize = AtomicUsize::new(75);
/// Number of items shown at each edge of a truncated axis (initially 3).
static PRINT_EDGEITEMS: AtomicUsize = AtomicUsize::new(3);
22
23/// Configure how arrays are printed.
24///
25/// Matches `NumPy`'s `set_printoptions`:
26/// - `precision`: number of decimal places for floats (default 8)
27/// - `threshold`: total element count above which truncation kicks in (default 1000)
28/// - `linewidth`: max characters per line (default 75)
29/// - `edgeitems`: number of items shown at each edge when truncated (default 3)
30pub fn set_print_options(precision: usize, threshold: usize, linewidth: usize, edgeitems: usize) {
31    PRINT_PRECISION.store(precision, Ordering::SeqCst);
32    PRINT_THRESHOLD.store(threshold, Ordering::SeqCst);
33    PRINT_LINEWIDTH.store(linewidth, Ordering::SeqCst);
34    PRINT_EDGEITEMS.store(edgeitems, Ordering::SeqCst);
35}
36
37/// Get current print options as `(precision, threshold, linewidth, edgeitems)`.
38pub fn get_print_options() -> (usize, usize, usize, usize) {
39    (
40        PRINT_PRECISION.load(Ordering::SeqCst),
41        PRINT_THRESHOLD.load(Ordering::SeqCst),
42        PRINT_LINEWIDTH.load(Ordering::SeqCst),
43        PRINT_EDGEITEMS.load(Ordering::SeqCst),
44    )
45}
46
47// ---------------------------------------------------------------------------
48// Core formatting logic
49// ---------------------------------------------------------------------------
50
51/// Format an array's data for display, handling truncation for large arrays.
52fn format_array_data<T: Element, D: Dimension>(
53    inner: &ndarray::ArrayBase<impl ndarray::Data<Elem = T>, D::NdarrayDim>,
54    f: &mut fmt::Formatter<'_>,
55) -> fmt::Result {
56    let shape = inner.shape();
57    let ndim = shape.len();
58    let size: usize = shape.iter().product();
59    let (precision, threshold, _linewidth, edgeitems) = get_print_options();
60
61    if ndim == 0 {
62        // Scalar
63        // 0-d arrays always have exactly one element
64        let val = inner.iter().next().unwrap_or_else(|| unreachable!());
65        write!(f, "{val}")?;
66        return Ok(());
67    }
68
69    let truncate = size > threshold;
70
71    write!(f, "array(")?;
72    format_recursive(inner, shape, &[], truncate, edgeitems, precision, f)?;
73    write!(f, ")")?;
74    Ok(())
75}
76
77/// Recursively format nested brackets.
78fn format_recursive<T: fmt::Display>(
79    data: &ndarray::ArrayBase<impl ndarray::Data<Elem = T>, impl ndarray::Dimension>,
80    shape: &[usize],
81    indices: &[usize],
82    truncate: bool,
83    edgeitems: usize,
84    precision: usize,
85    f: &mut fmt::Formatter<'_>,
86) -> fmt::Result {
87    let depth = indices.len();
88    let ndim = shape.len();
89    let n = shape[depth];
90    let show_all = !truncate || n <= 2 * edgeitems;
91
92    write!(f, "[")?;
93    if depth == ndim - 1 {
94        // Innermost dimension: print elements
95        if show_all {
96            for i in 0..n {
97                if i > 0 {
98                    write!(f, ", ")?;
99                }
100                let mut idx = indices.to_vec();
101                idx.push(i);
102                write_element_at(data, &idx, precision, f)?;
103            }
104        } else {
105            for i in 0..edgeitems {
106                if i > 0 {
107                    write!(f, ", ")?;
108                }
109                let mut idx = indices.to_vec();
110                idx.push(i);
111                write_element_at(data, &idx, precision, f)?;
112            }
113            write!(f, ", ..., ")?;
114            for i in (n - edgeitems)..n {
115                if i > n - edgeitems {
116                    write!(f, ", ")?;
117                }
118                let mut idx = indices.to_vec();
119                idx.push(i);
120                write_element_at(data, &idx, precision, f)?;
121            }
122        }
123    } else {
124        // Outer dimension: recurse
125        let indent = " ".repeat(depth + 7); // "array(" = 6 chars + 1 for [
126
127        if show_all {
128            for i in 0..n {
129                if i > 0 {
130                    write!(f, ",\n{indent}")?;
131                }
132                let mut idx = indices.to_vec();
133                idx.push(i);
134                format_recursive(data, shape, &idx, truncate, edgeitems, precision, f)?;
135            }
136        } else {
137            for i in 0..edgeitems {
138                if i > 0 {
139                    write!(f, ",\n{indent}")?;
140                }
141                let mut idx = indices.to_vec();
142                idx.push(i);
143                format_recursive(data, shape, &idx, truncate, edgeitems, precision, f)?;
144            }
145            write!(f, ",\n{indent}...")?;
146            for i in (n - edgeitems)..n {
147                write!(f, ",\n{indent}")?;
148                let mut idx = indices.to_vec();
149                idx.push(i);
150                format_recursive(data, shape, &idx, truncate, edgeitems, precision, f)?;
151            }
152        }
153    }
154    write!(f, "]")?;
155    Ok(())
156}
157
158/// Write a single element given multi-dimensional indices.
159fn write_element_at<T: fmt::Display>(
160    data: &ndarray::ArrayBase<impl ndarray::Data<Elem = T>, impl ndarray::Dimension>,
161    indices: &[usize],
162    _precision: usize,
163    f: &mut fmt::Formatter<'_>,
164) -> fmt::Result {
165    // Convert indices to ndarray's indexing — use dynamic indexing
166    let nd_idx = ndarray::IxDyn(indices);
167    let dyn_view = data.view().into_dyn();
168    let val = &dyn_view[nd_idx];
169    write!(f, "{val}")
170}
171
172// ---------------------------------------------------------------------------
173// Display / Debug for Array<T, D>
174// ---------------------------------------------------------------------------
175
176impl<T: Element, D: Dimension> fmt::Display for Array<T, D> {
177    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
178        format_array_data::<T, D>(&self.inner, f)
179    }
180}
181
182impl<T: Element, D: Dimension> fmt::Debug for Array<T, D> {
183    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
184        write!(f, "Array(dtype={}, shape={:?}, ", T::dtype(), self.shape())?;
185        format_array_data::<T, D>(&self.inner, f)?;
186        write!(f, ")")
187    }
188}
189
190// ---------------------------------------------------------------------------
191// Display / Debug for ArrayView
192// ---------------------------------------------------------------------------
193
194impl<T: Element, D: Dimension> fmt::Display for ArrayView<'_, T, D> {
195    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
196        format_array_data::<T, D>(&self.inner, f)
197    }
198}
199
200impl<T: Element, D: Dimension> fmt::Debug for ArrayView<'_, T, D> {
201    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
202        write!(
203            f,
204            "ArrayView(dtype={}, shape={:?}, ",
205            T::dtype(),
206            self.shape()
207        )?;
208        format_array_data::<T, D>(&self.inner, f)?;
209        write!(f, ")")
210    }
211}
212
213// ---------------------------------------------------------------------------
214// Display / Debug for ArcArray
215// ---------------------------------------------------------------------------
216
217impl<T: Element, D: Dimension> fmt::Display for ArcArray<T, D> {
218    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
219        // Build a temporary ndarray view for formatting
220        let nd_dim = self.dim().to_ndarray_dim();
221        let slice = self.as_slice();
222        let view =
223            ndarray::ArrayView::from_shape(nd_dim, slice).expect("ArcArray shape consistent");
224        format_array_data::<T, D>(&view, f)
225    }
226}
227
228impl<T: Element, D: Dimension> fmt::Debug for ArcArray<T, D> {
229    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
230        write!(
231            f,
232            "ArcArray(dtype={}, shape={:?}, refs={}, ",
233            T::dtype(),
234            self.shape(),
235            self.ref_count()
236        )?;
237        fmt::Display::fmt(self, f)?;
238        write!(f, ")")
239    }
240}
241
242// ---------------------------------------------------------------------------
243// Display / Debug for CowArray
244// ---------------------------------------------------------------------------
245
246impl<T: Element, D: Dimension> fmt::Display for CowArray<'_, T, D> {
247    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
248        match self {
249            CowArray::Borrowed(v) => fmt::Display::fmt(v, f),
250            CowArray::Owned(a) => fmt::Display::fmt(a, f),
251        }
252    }
253}
254
255impl<T: Element, D: Dimension> fmt::Debug for CowArray<'_, T, D> {
256    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
257        match self {
258            CowArray::Borrowed(v) => {
259                write!(f, "CowArray::Borrowed(")?;
260                fmt::Debug::fmt(v, f)?;
261                write!(f, ")")
262            }
263            CowArray::Owned(a) => {
264                write!(f, "CowArray::Owned(")?;
265                fmt::Debug::fmt(a, f)?;
266                write!(f, ")")
267            }
268        }
269    }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    /// Captures the global print options on creation and writes them back on
    /// drop, so a test that changes them cannot leak its settings into other
    /// tests, even when one of its assertions panics.
    struct PrintOptionsGuard((usize, usize, usize, usize));

    impl PrintOptionsGuard {
        fn capture() -> Self {
            Self(get_print_options())
        }
    }

    impl Drop for PrintOptionsGuard {
        fn drop(&mut self) {
            let (precision, threshold, linewidth, edgeitems) = self.0;
            set_print_options(precision, threshold, linewidth, edgeitems);
        }
    }

    #[test]
    fn display_1d() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 2, 3, 4]).unwrap();
        let s = format!("{arr}");
        assert!(s.contains("[1, 2, 3, 4]"));
        assert!(s.starts_with("array("));
    }

    #[test]
    fn display_2d() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let s = format!("{arr}");
        assert!(s.contains("[1, 2, 3]"));
        assert!(s.contains("[4, 5, 6]"));
    }

    #[test]
    fn debug_format() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([2]), vec![1.0, 2.0]).unwrap();
        let s = format!("{arr:?}");
        assert!(s.contains("dtype=float64"));
        assert!(s.contains("shape=[2]"));
    }

    #[test]
    fn truncated_display() {
        // Put back whatever options were active before this test, even if an
        // assertion below fails.
        let _restore = PrintOptionsGuard::capture();

        // Set low threshold to force truncation
        set_print_options(8, 5, 75, 2);

        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([10]), (0..10).collect()).unwrap();
        let s = format!("{arr}");
        assert_eq!(s, "array([0, 1, ..., 8, 9])");
    }

    #[test]
    fn arc_display() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
        let arc = ArcArray::from_owned(arr);
        let s = format!("{arc}");
        assert!(s.contains("[10, 20, 30]"));
    }

    #[test]
    fn cow_display() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![7, 8]).unwrap();
        let cow = CowArray::from_owned(arr);
        let s = format!("{cow}");
        assert!(s.contains("[7, 8]"));
    }
}