use ferray_core::dimension::{Dimension, Ix1, Ix2, IxDyn};
use ferray_core::error::{FerrayError, FerrayResult};
/// N-dimensional array of owned `String` values, generic over its dimension type `D`.
///
/// Elements live in a single flat `Vec<String>`. The indexing methods in this file
/// (`get`, `get_row`, `slice_axis`) treat that vector as row-major: the last axis
/// varies fastest.
#[derive(Debug, Clone)]
pub struct StringArray<D: Dimension> {
/// Flat element storage; `from_vec` only accepts data whose length equals `dim.size()`.
data: Vec<String>,
/// Shape descriptor for the array.
dim: D,
}
/// A `StringArray` whose dimension type is `Ix1`.
pub type StringArray1 = StringArray<Ix1>;
/// A `StringArray` whose dimension type is `Ix2`.
pub type StringArray2 = StringArray<Ix2>;
impl<D: Dimension> StringArray<D> {
/// Builds an array from a shape and a flat vector of elements.
///
/// # Errors
///
/// Returns a shape-mismatch error when `data.len()` is not equal to `dim.size()`.
pub fn from_vec(dim: D, data: Vec<String>) -> FerrayResult<Self> {
    let expected = dim.size();
    if data.len() == expected {
        return Ok(Self { data, dim });
    }
    Err(FerrayError::shape_mismatch(format!(
        "data length {} does not match shape {:?} (expected {})",
        data.len(),
        dim.as_slice(),
        expected,
    )))
}
/// Creates an array of the given shape in which every element is the empty string.
///
/// The `Result` wrapper is kept for interface symmetry with `from_vec`; this
/// constructor has no failing path of its own.
pub fn empty(dim: D) -> FerrayResult<Self> {
    let data: Vec<String> = std::iter::repeat_with(String::new)
        .take(dim.size())
        .collect();
    Ok(Self { data, dim })
}
/// Returns the array's shape as a slice, as reported by the dimension's `as_slice`.
#[inline]
pub fn shape(&self) -> &[usize] {
self.dim.as_slice()
}
/// Returns the number of axes, as reported by the dimension's `ndim`.
#[inline]
pub fn ndim(&self) -> usize {
self.dim.ndim()
}
/// Returns the total number of stored elements (the length of the flat storage).
#[inline]
pub fn len(&self) -> usize {
self.data.len()
}
/// Returns `true` when the array stores no elements at all.
///
/// This is about element count, not string content: an array of empty strings
/// is not "empty" in this sense.
#[inline]
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
/// Borrows the typed dimension value describing this array's shape.
#[inline]
pub const fn dim(&self) -> &D {
&self.dim
}
/// Borrows all elements as one flat slice, in storage order.
#[inline]
pub fn as_slice(&self) -> &[String] {
&self.data
}
/// Mutably borrows all elements as one flat slice, in storage order.
///
/// Individual strings can be modified or replaced; the number of elements
/// cannot change through a slice borrow, so the shape stays consistent.
#[inline]
pub fn as_slice_mut(&mut self) -> &mut [String] {
&mut self.data
}
/// Consumes the array and returns its flat element vector, discarding the shape.
#[inline]
pub fn into_vec(self) -> Vec<String> {
self.data
}
/// Applies `f` to every element and returns a new array with the same shape.
///
/// # Errors
///
/// Propagates any error from `from_vec`; the mapped vector has one entry per
/// source element, so its length matches the cloned shape.
pub fn map<F>(&self, f: F) -> FerrayResult<Self>
where
    F: Fn(&str) -> String,
{
    let mapped = self.map_to_vec(f);
    Self::from_vec(self.dim.clone(), mapped)
}
/// Applies `f` to every element, in storage order, and collects the results
/// into a flat `Vec<T>` (the shape is not carried over).
pub fn map_to_vec<T, F>(&self, f: F) -> Vec<T>
where
    F: Fn(&str) -> T,
{
    self.data.iter().map(String::as_str).map(f).collect()
}
/// Returns an iterator over references to the elements, in flat storage order.
pub fn iter(&self) -> std::slice::Iter<'_, String> {
self.data.iter()
}
pub fn reshape<D2: Dimension>(self, new_dim: D2) -> FerrayResult<StringArray<D2>> {
StringArray::<D2>::from_vec(new_dim, self.data)
}
pub fn flatten(self) -> StringArray1 {
let n = self.data.len();
StringArray::<Ix1>::from_vec(Ix1::new([n]), self.data)
.expect("flatten: length check is trivially satisfied")
}
pub fn into_dyn(self) -> StringArray<IxDyn> {
let shape = self.dim.as_slice().to_vec();
StringArray::<IxDyn>::from_vec(IxDyn::new(&shape), self.data)
.expect("into_dyn: shape length check is trivially satisfied")
}
/// Looks up the element at the multi-dimensional index `idx`.
///
/// The flat offset is computed in row-major order (last axis has stride 1).
///
/// Returns `None` when `idx` has a different number of components than the
/// array has axes, or when any component is out of range for its axis.
pub fn get(&self, idx: &[usize]) -> Option<&String> {
    let shape = self.dim.as_slice();
    if idx.len() != shape.len() {
        return None;
    }
    // Validate every axis before doing any arithmetic. Interleaving the check
    // with the stride multiplication could overflow `stride` (a panic in debug
    // builds) on a zero-size array with large trailing extents, before the
    // zero-length axis had a chance to reject the index.
    if idx.iter().zip(shape).any(|(&k, &extent)| k >= extent) {
        return None;
    }
    // Every extent is now known to be >= 1, so the running stride never
    // exceeds the total element count. NOTE(review): this relies on
    // `Dimension::size` being the product of the extents — confirm.
    let mut flat = 0usize;
    let mut stride = 1usize;
    for (&extent, &k) in shape.iter().zip(idx).rev() {
        flat += k * stride;
        stride *= extent;
    }
    self.data.get(flat)
}
/// Returns the element at flat (storage-order) position `idx`, or `None`
/// when `idx` is past the end of the storage.
#[must_use]
pub fn at(&self, idx: usize) -> Option<&String> {
self.data.get(idx)
}
/// Copies the sub-array selected by `range` along `axis` into a new
/// dynamically ranked array; all other axes are kept whole.
///
/// # Errors
///
/// * axis-out-of-bounds when `axis` is not a valid axis of this array;
/// * invalid-value when `range` is reversed or extends past the axis length;
/// * any error reported by `from_vec` for the resulting shape.
pub fn slice_axis(
    &self,
    axis: usize,
    range: std::ops::Range<usize>,
) -> FerrayResult<StringArray<IxDyn>> {
    let mut new_shape = self.dim.as_slice().to_vec();
    let ndim = new_shape.len();
    if axis >= ndim {
        return Err(ferray_core::error::FerrayError::axis_out_of_bounds(
            axis, ndim,
        ));
    }
    let axis_len = new_shape[axis];
    if range.start > range.end || range.end > axis_len {
        return Err(ferray_core::error::FerrayError::invalid_value(format!(
            "slice_axis: range {:?} out of bounds for axis {axis} with size {axis_len}",
            range
        )));
    }

    // Elements per step along `axis`, and number of leading "outer" blocks.
    let inner_stride: usize = new_shape[axis + 1..].iter().product();
    let outer_size: usize = new_shape[..axis].iter().product();
    let block = axis_len * inner_stride;
    new_shape[axis] = range.end - range.start;

    // Within one outer block the selected indices are adjacent in storage,
    // so the whole selection is a single contiguous run.
    let run_start = range.start * inner_stride;
    let run_end = range.end * inner_stride;
    let mut out: Vec<String> = Vec::with_capacity(outer_size * (run_end - run_start));
    for o in 0..outer_size {
        let base = o * block;
        out.extend_from_slice(&self.data[base + run_start..base + run_end]);
    }
    StringArray::<IxDyn>::from_vec(IxDyn::new(&new_shape), out)
}
/// Copies row `idx` of a 2-D array into a new 1-D array.
///
/// # Errors
///
/// * shape-mismatch when the array does not have exactly two axes;
/// * index-out-of-bounds when `idx` is not less than the number of rows.
pub fn get_row(&self, idx: usize) -> FerrayResult<crate::string_array::StringArray1> {
    let (nrows, ncols) = match self.dim.as_slice() {
        &[nrows, ncols] => (nrows, ncols),
        other => {
            return Err(ferray_core::error::FerrayError::shape_mismatch(format!(
                "get_row: expected a 2-D StringArray, got {}-D",
                other.len()
            )));
        }
    };
    if idx >= nrows {
        return Err(ferray_core::error::FerrayError::index_out_of_bounds(
            idx as isize,
            0,
            nrows,
        ));
    }
    // Rows are stored back to back, each `ncols` elements long.
    let start = idx * ncols;
    let row: Vec<String> = self.data[start..start + ncols].to_vec();
    crate::string_array::StringArray1::from_vec(ferray_core::dimension::Ix1::new([ncols]), row)
}
}
/// Two arrays are equal when both their shapes and their elements are equal.
impl<D: Dimension> PartialEq for StringArray<D> {
    fn eq(&self, other: &Self) -> bool {
        // Tuple equality compares the shape first and only then the data.
        (&self.dim, &self.data) == (&other.dim, &other.data)
    }
}
/// Marks the `PartialEq` implementation above as a full equivalence relation.
impl<D: Dimension> Eq for StringArray<D> {}
/// Formats the array as `array(["a", "b", ...])`.
///
/// Elements are written in flat storage order using their `Debug` form
/// (quoted and escaped); the shape is not reflected in the output.
impl<D: Dimension> std::fmt::Display for StringArray<D> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("array([")?;
        let mut elements = self.data.iter();
        if let Some(first) = elements.next() {
            write!(f, "{first:?}")?;
            for s in elements {
                write!(f, ", {s:?}")?;
            }
        }
        f.write_str("])")
    }
}
/// Borrowing iteration: `for s in &array` yields `&String` in flat storage order.
impl<'a, D: Dimension> IntoIterator for &'a StringArray<D> {
    type Item = &'a String;
    type IntoIter = std::slice::Iter<'a, String>;

    fn into_iter(self) -> Self::IntoIter {
        self.iter()
    }
}
/// Consuming iteration: `for s in array` yields owned `String`s in flat storage order.
impl<D: Dimension> IntoIterator for StringArray<D> {
    type Item = String;
    type IntoIter = std::vec::IntoIter<String>;

    fn into_iter(self) -> Self::IntoIter {
        self.into_vec().into_iter()
    }
}
impl StringArray<Ix1> {
    /// Builds a 1-D array by copying each `&str` in `items` into an owned `String`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `from_vec`; the shape is derived from
    /// `items.len()`, so the lengths agree by construction.
    pub fn from_slice(items: &[&str]) -> FerrayResult<Self> {
        let dim = Ix1::new([items.len()]);
        let data: Vec<String> = items.iter().copied().map(String::from).collect();
        Self::from_vec(dim, data)
    }
}
impl StringArray<Ix2> {
    /// Returns a new array with rows and columns swapped.
    ///
    /// Elements are cloned; the source array is left untouched.
    ///
    /// # Errors
    ///
    /// Propagates any error from `from_vec` for the swapped shape.
    pub fn transpose(&self) -> FerrayResult<Self> {
        let shape = self.shape();
        let (nrows, ncols) = (shape[0], shape[1]);
        let mut data = Vec::with_capacity(nrows * ncols);
        // Each source column becomes one output row.
        for c in 0..ncols {
            data.extend((0..nrows).map(|r| self.data[r * ncols + c].clone()));
        }
        Self::from_vec(Ix2::new([ncols, nrows]), data)
    }

    /// Builds a 2-D array from a slice of equally long rows.
    ///
    /// An empty `rows` slice produces an array of shape `[0, 0]`.
    ///
    /// # Errors
    ///
    /// Returns a shape-mismatch error naming the first row whose length
    /// differs from that of row 0.
    pub fn from_rows(rows: &[&[&str]]) -> FerrayResult<Self> {
        let ncols = match rows.first() {
            Some(first) => first.len(),
            None => return Self::from_vec(Ix2::new([0, 0]), Vec::new()),
        };
        let ragged = rows
            .iter()
            .enumerate()
            .find(|(_, row)| row.len() != ncols);
        if let Some((i, row)) = ragged {
            return Err(FerrayError::shape_mismatch(format!(
                "row {} has length {} but row 0 has length {}",
                i,
                row.len(),
                ncols
            )));
        }
        let nrows = rows.len();
        let mut data: Vec<String> = Vec::with_capacity(nrows * ncols);
        for row in rows {
            data.extend(row.iter().map(|s| (*s).to_string()));
        }
        Self::from_vec(Ix2::new([nrows, ncols]), data)
    }
}
impl StringArray<IxDyn> {
    /// Builds a dynamically ranked array from a shape slice and flat data.
    ///
    /// # Errors
    ///
    /// Returns the `from_vec` shape-mismatch error when `data.len()` does not
    /// match the size implied by `shape`.
    pub fn from_vec_dyn(shape: &[usize], data: Vec<String>) -> FerrayResult<Self> {
        let dim = IxDyn::new(shape);
        Self::from_vec(dim, data)
    }
}
/// Convenience constructor: builds a 1-D `StringArray` from string slices.
///
/// Thin wrapper around `StringArray1::from_slice`; errors are passed through unchanged.
pub fn array(items: &[&str]) -> FerrayResult<StringArray1> {
StringArray1::from_slice(items)
}
use ferray_core::dimension::broadcast::broadcast_shapes;
/// Iterator over pairs of flat source indices `(idx_a, idx_b)` for two
/// operands broadcast against each other.
///
/// One pair is produced per element of the broadcast output, in increasing
/// output linear order. Instances are created by `broadcast_binary`.
pub(crate) struct BroadcastIter {
/// Shape of the broadcast result.
out_shape: Vec<usize>,
/// Shape of the first operand.
shape_a: Vec<usize>,
/// Shape of the second operand.
shape_b: Vec<usize>,
/// Strides of the first operand, as computed by `compute_strides`.
strides_a: Vec<usize>,
/// Strides of the second operand, as computed by `compute_strides`.
strides_b: Vec<usize>,
/// Total number of output elements (product of `out_shape`).
out_size: usize,
/// Linear position of the next output element to produce.
linear: usize,
}
impl Iterator for BroadcastIter {
    /// `(flat index into operand a, flat index into operand b)`.
    type Item = (usize, usize);

    /// Produces the source index pair for the next output position, or `None`
    /// once every output position has been visited.
    fn next(&mut self) -> Option<Self::Item> {
        let position = self.linear;
        if position >= self.out_size {
            return None;
        }
        self.linear = position + 1;

        // Decompose the output position once, then project it onto each operand.
        let multi = linear_to_multi(position, &self.out_shape);
        Some((
            multi_to_broadcast_linear(&multi, &self.shape_a, &self.strides_a),
            multi_to_broadcast_linear(&multi, &self.shape_b, &self.strides_b),
        ))
    }

    /// Exact number of pairs still to be produced.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.out_size - self.linear;
        (left, Some(left))
    }
}
/// `size_hint` reports identical lower and upper bounds, so the default `len()` is exact.
impl ExactSizeIterator for BroadcastIter {}
/// Prepares element-wise iteration over two arrays under broadcasting.
///
/// Returns the broadcast output shape together with a `BroadcastIter` that
/// yields, for each output element in order, the flat indices to read from
/// `a` and from `b`.
///
/// # Errors
///
/// Propagates the error from `broadcast_shapes` when the two shapes cannot
/// be broadcast together.
pub(crate) fn broadcast_binary<Da: Dimension, Db: Dimension>(
    a: &StringArray<Da>,
    b: &StringArray<Db>,
) -> FerrayResult<(Vec<usize>, BroadcastIter)> {
    let (shape_a, shape_b) = (a.shape().to_vec(), b.shape().to_vec());
    let out_shape = broadcast_shapes(&shape_a, &shape_b)?;
    let iter = BroadcastIter {
        out_size: out_shape.iter().product(),
        out_shape: out_shape.clone(),
        // Strides are derived before the shape vectors are moved into the iterator.
        strides_a: compute_strides(&shape_a),
        strides_b: compute_strides(&shape_b),
        shape_a,
        shape_b,
        linear: 0,
    };
    Ok((out_shape, iter))
}
/// Computes row-major strides, in elements, for `shape`.
///
/// The last axis gets stride 1 and each earlier axis gets the product of all
/// extents after it. An empty shape yields an empty vector.
fn compute_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; shape.len()];
    // Walk from the second-to-last axis outward, building on the stride to the right.
    for axis in (1..shape.len()).rev() {
        strides[axis - 1] = strides[axis] * shape[axis];
    }
    strides
}
/// Converts a row-major linear position into a multi-dimensional index for `shape`.
///
/// Axes with extent 0 are skipped: their index stays 0 and they do not
/// consume any part of `linear`.
fn linear_to_multi(mut linear: usize, shape: &[usize]) -> Vec<usize> {
    let mut indices = vec![0usize; shape.len()];
    // The last axis varies fastest, so peel digits off from the right.
    for (slot, &extent) in indices.iter_mut().zip(shape).rev() {
        if extent == 0 {
            continue;
        }
        *slot = linear % extent;
        linear /= extent;
    }
    indices
}
/// Maps an output multi-index onto a flat index in a (possibly broadcast) source.
///
/// The source shape is right-aligned against `multi`; leading output axes the
/// source does not have are ignored. A source axis of extent 1 always
/// contributes offset 0, which is what repeats it across the output axis.
fn multi_to_broadcast_linear(multi: &[usize], src_shape: &[usize], src_strides: &[usize]) -> usize {
    // Number of leading output axes with no counterpart in the source.
    let pad = multi.len().saturating_sub(src_shape.len());
    src_shape
        .iter()
        .enumerate()
        .map(|(axis, &extent)| {
            let idx = multi[axis + pad];
            let stride = src_strides[axis];
            if extent == 1 {
                0
            } else {
                idx * stride
            }
        })
        .sum()
}
// Unit tests for `StringArray` construction, indexing, slicing, reshaping and
// the broadcasting helpers defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// --- flat indexing via `at` ---
#[test]
fn at_returns_flat_index() {
let a = array(&["a", "b", "c"]).unwrap();
assert_eq!(a.at(0).unwrap(), "a");
assert_eq!(a.at(2).unwrap(), "c");
assert!(a.at(3).is_none());
}
// --- `get_row` ---
#[test]
fn get_row_returns_1d() {
let a = StringArray::<Ix2>::from_vec(
Ix2::new([2, 3]),
vec![
"a0".into(),
"a1".into(),
"a2".into(),
"b0".into(),
"b1".into(),
"b2".into(),
],
)
.unwrap();
let row1 = a.get_row(1).unwrap();
assert_eq!(row1.shape(), &[3]);
assert_eq!(row1.as_slice(), &["b0", "b1", "b2"]);
}
#[test]
fn get_row_rejects_non_2d() {
let a = array(&["a", "b", "c"]).unwrap();
assert!(a.get_row(0).is_err());
}
#[test]
fn get_row_index_out_of_bounds_errors() {
let a = StringArray::<Ix2>::from_vec(Ix2::new([2, 2]), vec!["x".into(); 4]).unwrap();
assert!(a.get_row(5).is_err());
}
// --- `slice_axis` ---
// Slicing along axis 0 keeps whole rows 1 and 2.
#[test]
fn slice_axis_rows_2d() {
let a = StringArray::<Ix2>::from_vec(
Ix2::new([4, 2]),
vec![
"0,0".into(),
"0,1".into(),
"1,0".into(),
"1,1".into(),
"2,0".into(),
"2,1".into(),
"3,0".into(),
"3,1".into(),
],
)
.unwrap();
let r = a.slice_axis(0, 1..3).unwrap();
assert_eq!(r.shape(), &[2, 2]);
assert_eq!(r.as_slice(), &["1,0", "1,1", "2,0", "2,1"]);
}
// Slicing along axis 1 keeps columns 1 and 2 of every row.
#[test]
fn slice_axis_columns_2d() {
let a = StringArray::<Ix2>::from_vec(
Ix2::new([2, 4]),
vec![
"0,0".into(),
"0,1".into(),
"0,2".into(),
"0,3".into(),
"1,0".into(),
"1,1".into(),
"1,2".into(),
"1,3".into(),
],
)
.unwrap();
let r = a.slice_axis(1, 1..3).unwrap();
assert_eq!(r.shape(), &[2, 2]);
assert_eq!(r.as_slice(), &["0,1", "0,2", "1,1", "1,2"]);
}
#[test]
fn slice_axis_axis_out_of_bounds() {
let a = array(&["a", "b"]).unwrap();
assert!(a.slice_axis(5, 0..1).is_err());
}
#[test]
fn slice_axis_range_too_large_errors() {
let a = array(&["a", "b", "c"]).unwrap();
assert!(a.slice_axis(0, 0..10).is_err());
}
// --- constructors ---
#[test]
fn create_from_slice() {
let a = array(&["hello", "world"]).unwrap();
assert_eq!(a.shape(), &[2]);
assert_eq!(a.len(), 2);
assert_eq!(a.as_slice()[0], "hello");
assert_eq!(a.as_slice()[1], "world");
}
#[test]
fn create_from_vec() {
let a = StringArray1::from_vec(Ix1::new([3]), vec!["a".into(), "b".into(), "c".into()])
.unwrap();
assert_eq!(a.shape(), &[3]);
}
// A shape of 5 with only 2 elements must be rejected.
#[test]
fn shape_mismatch_error() {
let res = StringArray1::from_vec(Ix1::new([5]), vec!["a".into(), "b".into()]);
assert!(res.is_err());
}
// `empty` yields the requested number of empty strings.
#[test]
fn empty_array() {
let a = StringArray1::empty(Ix1::new([4])).unwrap();
assert_eq!(a.len(), 4);
assert!(a.as_slice().iter().all(std::string::String::is_empty));
}
// --- mapping ---
#[test]
fn map_strings() {
let a = array(&["hello", "world"]).unwrap();
let b = a.map(str::to_uppercase).unwrap();
assert_eq!(b.as_slice()[0], "HELLO");
assert_eq!(b.as_slice()[1], "WORLD");
}
// --- 2-D construction from rows ---
#[test]
fn from_rows_2d() {
let a = StringArray2::from_rows(&[&["a", "b"], &["c", "d"]]).unwrap();
assert_eq!(a.shape(), &[2, 2]);
assert_eq!(a.as_slice(), &["a", "b", "c", "d"]);
}
#[test]
fn from_rows_ragged_error() {
let res = StringArray2::from_rows(&[&["a", "b"], &["c"]]);
assert!(res.is_err());
}
// --- equality ---
#[test]
fn equality() {
let a = array(&["x", "y"]).unwrap();
let b = array(&["x", "y"]).unwrap();
let c = array(&["x", "z"]).unwrap();
assert_eq!(a, b);
assert_ne!(a, c);
}
// --- broadcasting ---
// A length-1 operand is reused for every output element.
#[test]
fn broadcast_binary_scalar() {
let a = array(&["hello", "world"]).unwrap();
let b = array(&["!"]).unwrap();
let (shape, pairs) = broadcast_binary(&a, &b).unwrap();
assert_eq!(shape, vec![2]);
let collected: Vec<(usize, usize)> = pairs.collect();
assert_eq!(collected, vec![(0, 0), (1, 0)]);
}
#[test]
fn broadcast_binary_same_shape() {
let a = array(&["a", "b", "c"]).unwrap();
let b = array(&["x", "y", "z"]).unwrap();
let (shape, pairs) = broadcast_binary(&a, &b).unwrap();
assert_eq!(shape, vec![3]);
let collected: Vec<(usize, usize)> = pairs.collect();
assert_eq!(collected, vec![(0, 0), (1, 1), (2, 2)]);
}
#[test]
fn broadcast_binary_iter_size_hint() {
let a = array(&["hello", "world"]).unwrap();
let b = array(&["!"]).unwrap();
let (_shape, pairs) = broadcast_binary(&a, &b).unwrap();
assert_eq!(pairs.size_hint(), (2, Some(2)));
assert_eq!(pairs.len(), 2);
}
// --- conversions and reshaping ---
#[test]
fn into_vec() {
let a = array(&["a", "b"]).unwrap();
let v = a.into_vec();
assert_eq!(v, vec!["a".to_string(), "b".to_string()]);
}
// Reshaping keeps the flat element order.
#[test]
fn reshape_1d_to_2d() {
let a = array(&["a", "b", "c", "d", "e", "f"]).unwrap();
let b = a.reshape(Ix2::new([2, 3])).unwrap();
assert_eq!(b.shape(), &[2, 3]);
assert_eq!(b.as_slice(), &["a", "b", "c", "d", "e", "f"]);
}
#[test]
fn reshape_wrong_size_errors() {
let a = array(&["a", "b", "c"]).unwrap();
assert!(a.reshape(Ix2::new([2, 2])).is_err());
}
#[test]
fn flatten_2d_to_1d() {
let a = StringArray2::from_rows(&[&["a", "b"], &["c", "d"]]).unwrap();
let f = a.flatten();
assert_eq!(f.shape(), &[4]);
assert_eq!(f.as_slice(), &["a", "b", "c", "d"]);
}
#[test]
fn into_dyn_preserves_shape() {
let a = StringArray2::from_rows(&[&["x", "y"], &["z", "w"]]).unwrap();
let d = a.into_dyn();
assert_eq!(d.shape(), &[2, 2]);
assert_eq!(d.as_slice(), &["x", "y", "z", "w"]);
}
// --- transpose ---
#[test]
fn transpose_2x3() {
let a = StringArray2::from_rows(&[&["a", "b", "c"], &["d", "e", "f"]]).unwrap();
let t = a.transpose().unwrap();
assert_eq!(t.shape(), &[3, 2]);
assert_eq!(t.as_slice(), &["a", "d", "b", "e", "c", "f"]);
}
// Transposing twice restores the original element order.
#[test]
fn transpose_square_is_involution() {
let a = StringArray2::from_rows(&[&["1", "2"], &["3", "4"]]).unwrap();
let t = a.transpose().unwrap();
let tt = t.transpose().unwrap();
assert_eq!(tt.as_slice(), a.as_slice());
}
// --- multi-dimensional `get` ---
#[test]
fn get_1d() {
let a = array(&["zero", "one", "two"]).unwrap();
assert_eq!(a.get(&[0]).unwrap(), "zero");
assert_eq!(a.get(&[1]).unwrap(), "one");
assert_eq!(a.get(&[2]).unwrap(), "two");
// Out-of-range index, then an index with the wrong number of components.
assert_eq!(a.get(&[3]), None); assert_eq!(a.get(&[0, 0]), None); }
#[test]
fn get_2d() {
let a = StringArray2::from_rows(&[&["a", "b", "c"], &["d", "e", "f"]]).unwrap();
assert_eq!(a.get(&[0, 0]).unwrap(), "a");
assert_eq!(a.get(&[0, 2]).unwrap(), "c");
assert_eq!(a.get(&[1, 1]).unwrap(), "e");
// Row index out of range, then column index out of range.
assert_eq!(a.get(&[2, 0]), None); assert_eq!(a.get(&[0, 3]), None); }
}