use ndarray::{ArrayBase, Data, Dimension};
use std::fmt::{Display, Formatter, Result as FmtResult};
use crate::{
DisplayArray,
display_method::{DisplayMethod, formatter::ElementFormatter},
};
/// Formats a single array of any dimensionality, one innermost-axis row per line.
///
/// Every element is passed to `element_formatter` together with the byte length of the
/// longest `to_string()` rendering in the array, and a flag telling whether it is the last
/// element of its row (the innermost axis).
///
/// Line breaks: after each completed row one newline is written, and one additional newline
/// is written for every higher axis that completes at the same position, so 3-D slices are
/// separated by a blank line, and so on. Nothing is written after the final element.
///
/// A 0-d array is treated as a single row containing its one element.
#[inline]
pub fn fmt_single_array<S, D, F>(arr: &ArrayBase<S, D>, f: &mut Formatter<'_>, element_formatter: &F) -> FmtResult
where
    S: Data,
    S::Elem: Display,
    D: Dimension,
    F: ElementFormatter,
{
    let shape = arr.shape();
    let ndim = shape.len();
    // Byte length of the rendered text; equals the display width for ASCII output.
    let width = arr.iter().map(|e| e.to_string().len()).max().unwrap_or(0);
    // chunk_sizes[k] = number of elements in one sub-array spanning axes k+1..ndim.
    // A flat position that is a multiple of chunk_sizes[k] closes such a sub-array.
    let chunk_sizes: Vec<usize> = (0..ndim.saturating_sub(1))
        .map(|k| shape[k + 1..].iter().product())
        .collect();
    // A row is always the innermost axis (previously chunk_sizes[0] was used, which is a whole
    // 2-D slab for ndim >= 3). A 0-d array has no axes: treat its single element as one row.
    // A zero-length axis means the array is empty, so the loop below never divides by zero.
    let row_size = shape.last().copied().unwrap_or(1);
    let total = arr.len();
    for (i, elem) in arr.iter().enumerate() {
        let pos = i + 1;
        let is_last_in_row = pos % row_size == 0;
        element_formatter.format_element(f, elem, width, is_last_in_row)?;
        // No trailing line breaks after the very last element.
        if pos < total {
            for &chunk in &chunk_sizes {
                if pos % chunk == 0 {
                    writeln!(f)?;
                }
            }
        }
    }
    Ok(())
}
/// Formats several arrays of identical shape side by side.
///
/// An empty slice produces no output. If any array's shape differs from the first one's, an
/// error description is written into the output (instead of a `fmt::Error`, which cannot
/// carry a message) and formatting stops.
///
/// Each array's elements are formatted with the byte length of that array's own longest
/// `to_string()` rendering. Arrays with fewer than two axes (0-d and 1-d) are written as one
/// line: all elements of array 0, the array separator, all elements of array 1, and so on.
/// Higher-dimensional arrays are delegated to `recursive_format_arrays`, which interleaves
/// the arrays row by row.
#[inline]
pub fn fmt_multiple_arrays<S, D, F>(arrays: &[&ArrayBase<S, D>], f: &mut Formatter<'_>, element_formatter: &F) -> FmtResult
where
    S: Data,
    S::Elem: Display,
    D: Dimension,
    F: ElementFormatter,
{
    if arrays.is_empty() {
        return Ok(());
    }
    let first_shape = arrays[0].shape();
    for (i, arr) in arrays.iter().enumerate().skip(1) {
        if arr.shape() != first_shape {
            return write!(
                f,
                "Error: Arrays have different shapes. Array 0 shape: {:?}, Array {} shape: {:?}",
                first_shape,
                i,
                arr.shape()
            );
        }
    }
    let ndim = first_shape.len();
    let array_widths: Vec<usize> = arrays
        .iter()
        .map(|arr| arr.iter().map(|e| e.to_string().len()).max().unwrap_or(0))
        .collect();
    // The row-interleaving path needs at least two axes. Route 0-d arrays here too: previously
    // they reached `recursive_format_arrays`, where `shape.len() - 2` underflowed and panicked.
    if ndim <= 1 {
        for (arr_idx, arr) in arrays.iter().enumerate() {
            for (i, elem) in arr.iter().enumerate() {
                let is_last_in_row = i + 1 == arr.len();
                element_formatter.format_element(f, elem, array_widths[arr_idx], is_last_in_row)?;
            }
            if arr_idx + 1 < arrays.len() {
                element_formatter.write_array_separator(f)?;
            }
        }
        return Ok(());
    }
    recursive_format_arrays(
        f,
        arrays,
        &array_widths,
        first_shape,
        0,
        &mut vec![0; ndim],
        element_formatter,
    )
}
/// Recursively formats several equally-shaped arrays (at least two axes) side by side.
///
/// Axes `0..shape.len() - 2` are walked recursively, with the current position recorded in
/// `indices`. At the innermost 2-D level, each row is written as: row `r` of array 0, the
/// array separator, row `r` of array 1, … followed by a newline. Between consecutive blocks
/// of an outer axis, one extra newline is written.
///
/// # Parameters
/// - `array_widths[k]`: width handed to the element formatter for every element of `arrays[k]`.
/// - `shape`: the common shape of all arrays; must have at least two axes.
/// - `level`: the axis currently being iterated (callers start at `0`).
/// - `indices`: scratch multi-index buffer with exactly `shape.len()` entries.
///
/// # Panics
/// If `shape` has fewer than two axes, or if `indices` or the arrays do not match `shape`.
#[inline]
fn recursive_format_arrays<S, D, F>(
    f: &mut Formatter<'_>,
    arrays: &[&ArrayBase<S, D>],
    array_widths: &[usize],
    shape: &[usize],
    level: usize,
    indices: &mut [usize],
    element_formatter: &F,
) -> FmtResult
where
    S: Data,
    S::Elem: Display,
    D: Dimension,
    F: ElementFormatter,
{
    debug_assert!(shape.len() >= 2, "recursive_format_arrays requires at least two axes");
    if level == shape.len() - 2 {
        // Dynamic-dimension views can be indexed directly by the multi-index in `indices`.
        // This replaces `arr.iter().nth(flat_index)`, which re-walked the iterator for every
        // element (quadratic for non-contiguous layouts).
        let views: Vec<_> = arrays.iter().map(|arr| arr.view().into_dyn()).collect();
        let rows = shape[level];
        let cols = shape[level + 1];
        for row in 0..rows {
            indices[level] = row;
            for (arr_idx, view) in views.iter().enumerate() {
                for col in 0..cols {
                    indices[level + 1] = col;
                    let elem = &view[&indices[..]];
                    let is_last_in_row = col + 1 == cols;
                    element_formatter.format_element(f, elem, array_widths[arr_idx], is_last_in_row)?;
                }
                if arr_idx + 1 < views.len() {
                    element_formatter.write_array_separator(f)?;
                }
            }
            writeln!(f)?;
        }
        return Ok(());
    }
    for i in 0..shape[level] {
        indices[level] = i;
        recursive_format_arrays(f, arrays, array_widths, shape, level + 1, indices, element_formatter)?;
        // Blank line between blocks of this axis, but not after the last one.
        if i + 1 < shape[level] {
            writeln!(f)?;
        }
    }
    Ok(())
}
/// Entry point used by `DisplayArray`'s formatting.
///
/// Exactly one array is rendered with `fmt_single_array`. Any other count, including zero,
/// is handed to `fmt_multiple_arrays`.
#[inline]
pub fn display_impl<S, D, F>(
    display_array: &DisplayArray<'_, S, D, F>,
    f: &mut Formatter<'_>,
    element_formatter: &F,
) -> FmtResult
where
    S: Data,
    S::Elem: Display,
    D: Dimension,
    F: DisplayMethod + ElementFormatter,
{
    match &display_array.arrays[..] {
        [only] => fmt_single_array(*only, f, element_formatter),
        several => fmt_multiple_arrays(several, f, element_formatter),
    }
}