Skip to main content

ferray_core/array/
display.rs

1// ferray-core: Display/Debug formatting with NumPy-style output (REQ-39)
2
3use std::fmt;
4use std::sync::atomic::{AtomicUsize, Ordering};
5
6use crate::dimension::Dimension;
7use crate::dtype::Element;
8
9use super::arc::ArcArray;
10use super::cow::CowArray;
11use super::owned::Array;
12use super::view::ArrayView;
13
14// ---------------------------------------------------------------------------
15// Global print options
16// ---------------------------------------------------------------------------
17
/// Requested float precision (see `set_print_options`); defaults to 8.
static PRINT_PRECISION: AtomicUsize = AtomicUsize::new(8);
/// Element count above which printed arrays are summarized with `...`; defaults to 1000.
static PRINT_THRESHOLD: AtomicUsize = AtomicUsize::new(1000);
/// Requested maximum line width (see `set_print_options`); defaults to 75.
static PRINT_LINEWIDTH: AtomicUsize = AtomicUsize::new(75);
/// Items kept at each edge of a summarized axis; defaults to 3.
static PRINT_EDGEITEMS: AtomicUsize = AtomicUsize::new(3);
22
23/// Configure how arrays are printed.
24///
25/// Matches NumPy's `set_printoptions`:
26/// - `precision`: number of decimal places for floats (default 8)
27/// - `threshold`: total element count above which truncation kicks in (default 1000)
28/// - `linewidth`: max characters per line (default 75)
29/// - `edgeitems`: number of items shown at each edge when truncated (default 3)
30pub fn set_print_options(precision: usize, threshold: usize, linewidth: usize, edgeitems: usize) {
31    PRINT_PRECISION.store(precision, Ordering::SeqCst);
32    PRINT_THRESHOLD.store(threshold, Ordering::SeqCst);
33    PRINT_LINEWIDTH.store(linewidth, Ordering::SeqCst);
34    PRINT_EDGEITEMS.store(edgeitems, Ordering::SeqCst);
35}
36
37/// Get current print options as `(precision, threshold, linewidth, edgeitems)`.
38pub fn get_print_options() -> (usize, usize, usize, usize) {
39    (
40        PRINT_PRECISION.load(Ordering::SeqCst),
41        PRINT_THRESHOLD.load(Ordering::SeqCst),
42        PRINT_LINEWIDTH.load(Ordering::SeqCst),
43        PRINT_EDGEITEMS.load(Ordering::SeqCst),
44    )
45}
46
47// ---------------------------------------------------------------------------
48// Core formatting logic
49// ---------------------------------------------------------------------------
50
51/// Format an array's data for display, handling truncation for large arrays.
52fn format_array_data<T: Element, D: Dimension>(
53    inner: &ndarray::ArrayBase<impl ndarray::Data<Elem = T>, D::NdarrayDim>,
54    f: &mut fmt::Formatter<'_>,
55) -> fmt::Result {
56    let shape = inner.shape();
57    let ndim = shape.len();
58    let size: usize = shape.iter().product();
59    let (precision, threshold, _linewidth, edgeitems) = get_print_options();
60
61    if ndim == 0 {
62        // Scalar
63        // 0-d arrays always have exactly one element
64        let val = inner.iter().next().unwrap_or_else(|| unreachable!());
65        write!(f, "{val}")?;
66        return Ok(());
67    }
68
69    let truncate = size > threshold;
70
71    write!(f, "array(")?;
72    format_recursive(inner, shape, &[], truncate, edgeitems, precision, f)?;
73    write!(f, ")")?;
74    Ok(())
75}
76
77/// Recursively format nested brackets.
78fn format_recursive<T: fmt::Display>(
79    data: &ndarray::ArrayBase<impl ndarray::Data<Elem = T>, impl ndarray::Dimension>,
80    shape: &[usize],
81    indices: &[usize],
82    truncate: bool,
83    edgeitems: usize,
84    precision: usize,
85    f: &mut fmt::Formatter<'_>,
86) -> fmt::Result {
87    let depth = indices.len();
88    let ndim = shape.len();
89
90    if depth == ndim - 1 {
91        // Innermost dimension: print elements
92        write!(f, "[")?;
93        let n = shape[depth];
94        let show_all = !truncate || n <= 2 * edgeitems;
95
96        if show_all {
97            for i in 0..n {
98                if i > 0 {
99                    write!(f, ", ")?;
100                }
101                let mut idx = indices.to_vec();
102                idx.push(i);
103                write_element_at(data, &idx, precision, f)?;
104            }
105        } else {
106            for i in 0..edgeitems {
107                if i > 0 {
108                    write!(f, ", ")?;
109                }
110                let mut idx = indices.to_vec();
111                idx.push(i);
112                write_element_at(data, &idx, precision, f)?;
113            }
114            write!(f, ", ..., ")?;
115            for i in (n - edgeitems)..n {
116                if i > n - edgeitems {
117                    write!(f, ", ")?;
118                }
119                let mut idx = indices.to_vec();
120                idx.push(i);
121                write_element_at(data, &idx, precision, f)?;
122            }
123        }
124        write!(f, "]")?;
125    } else {
126        // Outer dimension: recurse
127        write!(f, "[")?;
128        let n = shape[depth];
129        let show_all = !truncate || n <= 2 * edgeitems;
130        let indent = " ".repeat(depth + 7); // "array(" = 6 chars + 1 for [
131
132        if show_all {
133            for i in 0..n {
134                if i > 0 {
135                    write!(f, ",\n{indent}")?;
136                }
137                let mut idx = indices.to_vec();
138                idx.push(i);
139                format_recursive(data, shape, &idx, truncate, edgeitems, precision, f)?;
140            }
141        } else {
142            for i in 0..edgeitems {
143                if i > 0 {
144                    write!(f, ",\n{indent}")?;
145                }
146                let mut idx = indices.to_vec();
147                idx.push(i);
148                format_recursive(data, shape, &idx, truncate, edgeitems, precision, f)?;
149            }
150            write!(f, ",\n{indent}...")?;
151            for i in (n - edgeitems)..n {
152                write!(f, ",\n{indent}")?;
153                let mut idx = indices.to_vec();
154                idx.push(i);
155                format_recursive(data, shape, &idx, truncate, edgeitems, precision, f)?;
156            }
157        }
158        write!(f, "]")?;
159    }
160    Ok(())
161}
162
163/// Write a single element given multi-dimensional indices.
164fn write_element_at<T: fmt::Display>(
165    data: &ndarray::ArrayBase<impl ndarray::Data<Elem = T>, impl ndarray::Dimension>,
166    indices: &[usize],
167    _precision: usize,
168    f: &mut fmt::Formatter<'_>,
169) -> fmt::Result {
170    // Convert indices to ndarray's indexing — use dynamic indexing
171    let nd_idx = ndarray::IxDyn(indices);
172    let dyn_view = data.view().into_dyn();
173    let val = &dyn_view[nd_idx];
174    write!(f, "{val}")
175}
176
177// ---------------------------------------------------------------------------
178// Display / Debug for Array<T, D>
179// ---------------------------------------------------------------------------
180
181impl<T: Element, D: Dimension> fmt::Display for Array<T, D> {
182    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
183        format_array_data::<T, D>(&self.inner, f)
184    }
185}
186
187impl<T: Element, D: Dimension> fmt::Debug for Array<T, D> {
188    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
189        write!(f, "Array(dtype={}, shape={:?}, ", T::dtype(), self.shape())?;
190        format_array_data::<T, D>(&self.inner, f)?;
191        write!(f, ")")
192    }
193}
194
195// ---------------------------------------------------------------------------
196// Display / Debug for ArrayView
197// ---------------------------------------------------------------------------
198
199impl<T: Element, D: Dimension> fmt::Display for ArrayView<'_, T, D> {
200    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201        format_array_data::<T, D>(&self.inner, f)
202    }
203}
204
205impl<T: Element, D: Dimension> fmt::Debug for ArrayView<'_, T, D> {
206    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
207        write!(
208            f,
209            "ArrayView(dtype={}, shape={:?}, ",
210            T::dtype(),
211            self.shape()
212        )?;
213        format_array_data::<T, D>(&self.inner, f)?;
214        write!(f, ")")
215    }
216}
217
218// ---------------------------------------------------------------------------
219// Display / Debug for ArcArray
220// ---------------------------------------------------------------------------
221
222impl<T: Element, D: Dimension> fmt::Display for ArcArray<T, D> {
223    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
224        // Build a temporary ndarray view for formatting
225        let nd_dim = self.dim().to_ndarray_dim();
226        let slice = self.as_slice();
227        let view =
228            ndarray::ArrayView::from_shape(nd_dim, slice).expect("ArcArray shape consistent");
229        format_array_data::<T, D>(&view, f)
230    }
231}
232
233impl<T: Element, D: Dimension> fmt::Debug for ArcArray<T, D> {
234    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
235        write!(
236            f,
237            "ArcArray(dtype={}, shape={:?}, refs={}, ",
238            T::dtype(),
239            self.shape(),
240            self.ref_count()
241        )?;
242        fmt::Display::fmt(self, f)?;
243        write!(f, ")")
244    }
245}
246
247// ---------------------------------------------------------------------------
248// Display / Debug for CowArray
249// ---------------------------------------------------------------------------
250
251impl<T: Element, D: Dimension> fmt::Display for CowArray<'_, T, D> {
252    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
253        match self {
254            CowArray::Borrowed(v) => fmt::Display::fmt(v, f),
255            CowArray::Owned(a) => fmt::Display::fmt(a, f),
256        }
257    }
258}
259
260impl<T: Element, D: Dimension> fmt::Debug for CowArray<'_, T, D> {
261    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
262        match self {
263            CowArray::Borrowed(v) => {
264                write!(f, "CowArray::Borrowed(")?;
265                fmt::Debug::fmt(v, f)?;
266                write!(f, ")")
267            }
268            CowArray::Owned(a) => {
269                write!(f, "CowArray::Owned(")?;
270                fmt::Debug::fmt(a, f)?;
271                write!(f, ")")
272            }
273        }
274    }
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    /// Overrides the process-global print options for the guard's lifetime. On drop,
    /// including when the owning test panics, it restores whatever options were active
    /// before, so an override cannot leak into other tests sharing the process.
    struct PrintOptionsGuard {
        saved: (usize, usize, usize, usize),
    }

    impl PrintOptionsGuard {
        fn set(precision: usize, threshold: usize, linewidth: usize, edgeitems: usize) -> Self {
            let saved = get_print_options();
            set_print_options(precision, threshold, linewidth, edgeitems);
            Self { saved }
        }
    }

    impl Drop for PrintOptionsGuard {
        fn drop(&mut self) {
            let (precision, threshold, linewidth, edgeitems) = self.saved;
            set_print_options(precision, threshold, linewidth, edgeitems);
        }
    }

    #[test]
    fn display_1d() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 2, 3, 4]).unwrap();
        let s = format!("{arr}");
        assert_eq!(s, "array([1, 2, 3, 4])");
    }

    #[test]
    fn display_2d() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let s = format!("{arr}");
        // Rows after the first are indented to sit under the first '[' after "array(".
        assert_eq!(s, "array([[1, 2, 3],\n       [4, 5, 6]])");
    }

    #[test]
    fn debug_format() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([2]), vec![1.0, 2.0]).unwrap();
        let s = format!("{arr:?}");
        assert!(s.contains("dtype=float64"));
        assert!(s.contains("shape=[2]"));
    }

    #[test]
    fn truncated_display() {
        // A low threshold forces summarization. The guard restores the previous options
        // when it goes out of scope, even if an assertion below fails.
        let _options = PrintOptionsGuard::set(8, 5, 75, 2);

        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([10]), (0..10).collect()).unwrap();
        let s = format!("{arr}");
        assert_eq!(s, "array([0, 1, ..., 8, 9])");
    }

    #[test]
    fn arc_display() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
        let arc = ArcArray::from_owned(arr);
        let s = format!("{arc}");
        assert!(s.contains("[10, 20, 30]"));
    }

    #[test]
    fn cow_display() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![7, 8]).unwrap();
        let cow = CowArray::from_owned(arr);
        let s = format!("{cow}");
        assert!(s.contains("[7, 8]"));
    }
}