Skip to main content

ferray_core/array/
display.rs

// ferray-core: Display/Debug formatting with NumPy-style output (REQ-39)
2
3use std::fmt;
4use std::sync::atomic::{AtomicUsize, Ordering};
5
6use crate::dimension::Dimension;
7use crate::dtype::Element;
8
9use super::arc::ArcArray;
10use super::cow::CowArray;
11use super::owned::Array;
12use super::view::ArrayView;
13
// ---------------------------------------------------------------------------
// Global print options
// ---------------------------------------------------------------------------
17
/// Decimal places for floats (see [`set_print_options`]).
static PRINT_PRECISION: AtomicUsize = AtomicUsize::new(8);
/// Element count above which printed arrays are truncated.
static PRINT_THRESHOLD: AtomicUsize = AtomicUsize::new(1000);
/// Maximum characters per printed line.
static PRINT_LINEWIDTH: AtomicUsize = AtomicUsize::new(75);
/// Items shown at each edge of a truncated axis.
static PRINT_EDGEITEMS: AtomicUsize = AtomicUsize::new(3);

/// Configure how arrays are printed.
///
/// Mirrors NumPy's `set_printoptions`:
/// - `precision`: number of decimal places for floats (default 8)
/// - `threshold`: total element count above which truncation kicks in (default 1000)
/// - `linewidth`: max characters per line (default 75)
/// - `edgeitems`: number of items shown at each edge when truncated (default 3)
///
/// The options are process-wide. Each one is stored independently with
/// relaxed ordering, so a formatter running concurrently with this call may
/// observe a mix of old and new values.
///
/// NOTE(review): the formatter in this module reads `linewidth` but does not
/// apply it, and passes `precision` to the element writer, which ignores it.
pub fn set_print_options(precision: usize, threshold: usize, linewidth: usize, edgeitems: usize) {
    let updates = [
        (&PRINT_PRECISION, precision),
        (&PRINT_THRESHOLD, threshold),
        (&PRINT_LINEWIDTH, linewidth),
        (&PRINT_EDGEITEMS, edgeitems),
    ];
    for (slot, value) in updates {
        slot.store(value, Ordering::Relaxed);
    }
}

/// Return the current print options as `(precision, threshold, linewidth, edgeitems)`.
///
/// Each field is loaded independently with relaxed ordering.
pub fn get_print_options() -> (usize, usize, usize, usize) {
    let read = |option: &AtomicUsize| option.load(Ordering::Relaxed);
    (
        read(&PRINT_PRECISION),
        read(&PRINT_THRESHOLD),
        read(&PRINT_LINEWIDTH),
        read(&PRINT_EDGEITEMS),
    )
}
46
// ---------------------------------------------------------------------------
// Core formatting logic
// ---------------------------------------------------------------------------
50
51/// Format an array's data for display, handling truncation for large arrays.
52fn format_array_data<T: Element, D: Dimension>(
53    inner: &ndarray::ArrayBase<impl ndarray::Data<Elem = T>, D::NdarrayDim>,
54    f: &mut fmt::Formatter<'_>,
55) -> fmt::Result {
56    let shape = inner.shape();
57    let ndim = shape.len();
58    let size: usize = shape.iter().product();
59    let (precision, threshold, _linewidth, edgeitems) = get_print_options();
60
61    if ndim == 0 {
62        // Scalar
63        let val = inner.iter().next().unwrap();
64        write!(f, "{val}")?;
65        return Ok(());
66    }
67
68    let truncate = size > threshold;
69
70    write!(f, "array(")?;
71    format_recursive(inner, shape, &[], truncate, edgeitems, precision, f)?;
72    write!(f, ")")?;
73    Ok(())
74}
75
76/// Recursively format nested brackets.
77fn format_recursive<T: fmt::Display>(
78    data: &ndarray::ArrayBase<impl ndarray::Data<Elem = T>, impl ndarray::Dimension>,
79    shape: &[usize],
80    indices: &[usize],
81    truncate: bool,
82    edgeitems: usize,
83    precision: usize,
84    f: &mut fmt::Formatter<'_>,
85) -> fmt::Result {
86    let depth = indices.len();
87    let ndim = shape.len();
88
89    if depth == ndim - 1 {
90        // Innermost dimension: print elements
91        write!(f, "[")?;
92        let n = shape[depth];
93        let show_all = !truncate || n <= 2 * edgeitems;
94
95        if show_all {
96            for i in 0..n {
97                if i > 0 {
98                    write!(f, ", ")?;
99                }
100                let mut idx = indices.to_vec();
101                idx.push(i);
102                write_element_at(data, &idx, precision, f)?;
103            }
104        } else {
105            for i in 0..edgeitems {
106                if i > 0 {
107                    write!(f, ", ")?;
108                }
109                let mut idx = indices.to_vec();
110                idx.push(i);
111                write_element_at(data, &idx, precision, f)?;
112            }
113            write!(f, ", ..., ")?;
114            for i in (n - edgeitems)..n {
115                if i > n - edgeitems {
116                    write!(f, ", ")?;
117                }
118                let mut idx = indices.to_vec();
119                idx.push(i);
120                write_element_at(data, &idx, precision, f)?;
121            }
122        }
123        write!(f, "]")?;
124    } else {
125        // Outer dimension: recurse
126        write!(f, "[")?;
127        let n = shape[depth];
128        let show_all = !truncate || n <= 2 * edgeitems;
129        let indent = " ".repeat(depth + 7); // "array(" = 6 chars + 1 for [
130
131        if show_all {
132            for i in 0..n {
133                if i > 0 {
134                    write!(f, ",\n{indent}")?;
135                }
136                let mut idx = indices.to_vec();
137                idx.push(i);
138                format_recursive(data, shape, &idx, truncate, edgeitems, precision, f)?;
139            }
140        } else {
141            for i in 0..edgeitems {
142                if i > 0 {
143                    write!(f, ",\n{indent}")?;
144                }
145                let mut idx = indices.to_vec();
146                idx.push(i);
147                format_recursive(data, shape, &idx, truncate, edgeitems, precision, f)?;
148            }
149            write!(f, ",\n{indent}...")?;
150            for i in (n - edgeitems)..n {
151                write!(f, ",\n{indent}")?;
152                let mut idx = indices.to_vec();
153                idx.push(i);
154                format_recursive(data, shape, &idx, truncate, edgeitems, precision, f)?;
155            }
156        }
157        write!(f, "]")?;
158    }
159    Ok(())
160}
161
162/// Write a single element given multi-dimensional indices.
163fn write_element_at<T: fmt::Display>(
164    data: &ndarray::ArrayBase<impl ndarray::Data<Elem = T>, impl ndarray::Dimension>,
165    indices: &[usize],
166    _precision: usize,
167    f: &mut fmt::Formatter<'_>,
168) -> fmt::Result {
169    // Convert indices to ndarray's indexing — use dynamic indexing
170    let nd_idx = ndarray::IxDyn(indices);
171    let dyn_view = data.view().into_dyn();
172    let val = &dyn_view[nd_idx];
173    write!(f, "{val}")
174}
175
// ---------------------------------------------------------------------------
// Display / Debug for Array<T, D>
// ---------------------------------------------------------------------------
179
180impl<T: Element, D: Dimension> fmt::Display for Array<T, D> {
181    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182        format_array_data::<T, D>(&self.inner, f)
183    }
184}
185
186impl<T: Element, D: Dimension> fmt::Debug for Array<T, D> {
187    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188        write!(f, "Array(dtype={}, shape={:?}, ", T::dtype(), self.shape())?;
189        format_array_data::<T, D>(&self.inner, f)?;
190        write!(f, ")")
191    }
192}
193
// ---------------------------------------------------------------------------
// Display / Debug for ArrayView
// ---------------------------------------------------------------------------
197
198impl<T: Element, D: Dimension> fmt::Display for ArrayView<'_, T, D> {
199    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
200        format_array_data::<T, D>(&self.inner, f)
201    }
202}
203
204impl<T: Element, D: Dimension> fmt::Debug for ArrayView<'_, T, D> {
205    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
206        write!(
207            f,
208            "ArrayView(dtype={}, shape={:?}, ",
209            T::dtype(),
210            self.shape()
211        )?;
212        format_array_data::<T, D>(&self.inner, f)?;
213        write!(f, ")")
214    }
215}
216
// ---------------------------------------------------------------------------
// Display / Debug for ArcArray
// ---------------------------------------------------------------------------
220
221impl<T: Element, D: Dimension> fmt::Display for ArcArray<T, D> {
222    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
223        // Build a temporary ndarray view for formatting
224        let nd_dim = self.dim().to_ndarray_dim();
225        let slice = self.as_slice();
226        let view =
227            ndarray::ArrayView::from_shape(nd_dim, slice).expect("ArcArray shape consistent");
228        format_array_data::<T, D>(&view, f)
229    }
230}
231
232impl<T: Element, D: Dimension> fmt::Debug for ArcArray<T, D> {
233    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
234        write!(
235            f,
236            "ArcArray(dtype={}, shape={:?}, refs={}, ",
237            T::dtype(),
238            self.shape(),
239            self.ref_count()
240        )?;
241        fmt::Display::fmt(self, f)?;
242        write!(f, ")")
243    }
244}
245
// ---------------------------------------------------------------------------
// Display / Debug for CowArray
// ---------------------------------------------------------------------------
249
250impl<T: Element, D: Dimension> fmt::Display for CowArray<'_, T, D> {
251    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
252        match self {
253            CowArray::Borrowed(v) => fmt::Display::fmt(v, f),
254            CowArray::Owned(a) => fmt::Display::fmt(a, f),
255        }
256    }
257}
258
259impl<T: Element, D: Dimension> fmt::Debug for CowArray<'_, T, D> {
260    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
261        match self {
262            CowArray::Borrowed(v) => {
263                write!(f, "CowArray::Borrowed(")?;
264                fmt::Debug::fmt(v, f)?;
265                write!(f, ")")
266            }
267            CowArray::Owned(a) => {
268                write!(f, "CowArray::Owned(")?;
269                fmt::Debug::fmt(a, f)?;
270                write!(f, ")")
271            }
272        }
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    /// Overrides the global print options for the lifetime of the guard.
    ///
    /// On drop, it restores the options that were active before. Drop runs even
    /// when an assertion panics, so a failing test cannot leak modified global
    /// options into other tests running in parallel.
    struct PrintOptionsGuard((usize, usize, usize, usize));

    impl PrintOptionsGuard {
        fn set(precision: usize, threshold: usize, linewidth: usize, edgeitems: usize) -> Self {
            let saved = get_print_options();
            set_print_options(precision, threshold, linewidth, edgeitems);
            Self(saved)
        }
    }

    impl Drop for PrintOptionsGuard {
        fn drop(&mut self) {
            let (precision, threshold, linewidth, edgeitems) = self.0;
            set_print_options(precision, threshold, linewidth, edgeitems);
        }
    }

    #[test]
    fn display_1d() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 2, 3, 4]).unwrap();
        let s = format!("{arr}");
        assert!(s.contains("[1, 2, 3, 4]"));
        assert!(s.starts_with("array("));
    }

    #[test]
    fn display_2d() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let s = format!("{arr}");
        assert!(s.contains("[1, 2, 3]"));
        assert!(s.contains("[4, 5, 6]"));
    }

    #[test]
    fn debug_format() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([2]), vec![1.0, 2.0]).unwrap();
        let s = format!("{arr:?}");
        assert!(s.contains("dtype=float64"));
        assert!(s.contains("shape=[2]"));
    }

    #[test]
    fn truncated_display() {
        // A low threshold forces truncation; the guard restores the previous
        // options when it goes out of scope, even if an assertion fails.
        let _guard = PrintOptionsGuard::set(8, 5, 75, 2);

        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([10]), (0..10).collect()).unwrap();
        let s = format!("{arr}");
        assert_eq!(s, "array([0, 1, ..., 8, 9])");
    }

    #[test]
    fn arc_display() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
        let arc = ArcArray::from_owned(arr);
        let s = format!("{arc}");
        assert!(s.contains("[10, 20, 30]"));
    }

    #[test]
    fn cow_display() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![7, 8]).unwrap();
        let cow = CowArray::from_owned(arr);
        let s = format!("{cow}");
        assert!(s.contains("[7, 8]"));
    }
}