1use ferray_core::dimension::{Dimension, Ix1, Ix2, IxDyn};
8use ferray_core::error::{FerrayError, FerrayResult};
9
/// An N-dimensional array of owned `String`s.
///
/// Elements live in a single flat `Vec<String>` in row-major order (the last
/// axis varies fastest, as `get` demonstrates); `dim` records the shape.
/// All constructors route through `from_vec`, which rejects data whose length
/// differs from the element count implied by `dim`.
#[derive(Debug, Clone)]
pub struct StringArray<D: Dimension> {
    /// Flat, row-major element storage.
    data: Vec<String>,
    /// Shape / dimensionality descriptor.
    dim: D,
}
23
/// A one-dimensional [`StringArray`].
pub type StringArray1 = StringArray<Ix1>;

/// A two-dimensional [`StringArray`].
pub type StringArray2 = StringArray<Ix2>;
29
30impl<D: Dimension> StringArray<D> {
31 pub fn from_vec(dim: D, data: Vec<String>) -> FerrayResult<Self> {
37 let expected = dim.size();
38 if data.len() != expected {
39 return Err(FerrayError::shape_mismatch(format!(
40 "data length {} does not match shape {:?} (expected {})",
41 data.len(),
42 dim.as_slice(),
43 expected,
44 )));
45 }
46 Ok(Self { data, dim })
47 }
48
49 pub fn empty(dim: D) -> FerrayResult<Self> {
55 let size = dim.size();
56 let data = vec![String::new(); size];
57 Ok(Self { data, dim })
58 }
59
60 #[inline]
62 pub fn shape(&self) -> &[usize] {
63 self.dim.as_slice()
64 }
65
66 #[inline]
68 pub fn ndim(&self) -> usize {
69 self.dim.ndim()
70 }
71
72 #[inline]
74 pub fn len(&self) -> usize {
75 self.data.len()
76 }
77
78 #[inline]
80 pub fn is_empty(&self) -> bool {
81 self.data.is_empty()
82 }
83
84 #[inline]
86 pub const fn dim(&self) -> &D {
87 &self.dim
88 }
89
90 #[inline]
92 pub fn as_slice(&self) -> &[String] {
93 &self.data
94 }
95
96 #[inline]
98 pub fn as_slice_mut(&mut self) -> &mut [String] {
99 &mut self.data
100 }
101
102 #[inline]
104 pub fn into_vec(self) -> Vec<String> {
105 self.data
106 }
107
108 pub fn map<F>(&self, f: F) -> FerrayResult<Self>
110 where
111 F: Fn(&str) -> String,
112 {
113 let data: Vec<String> = self.data.iter().map(|s| f(s)).collect();
114 Self::from_vec(self.dim.clone(), data)
115 }
116
117 pub fn map_to_vec<T, F>(&self, f: F) -> Vec<T>
122 where
123 F: Fn(&str) -> T,
124 {
125 self.data.iter().map(|s| f(s)).collect()
126 }
127
128 pub fn iter(&self) -> std::slice::Iter<'_, String> {
130 self.data.iter()
131 }
132
133 pub fn reshape<D2: Dimension>(self, new_dim: D2) -> FerrayResult<StringArray<D2>> {
153 StringArray::<D2>::from_vec(new_dim, self.data)
154 }
155
156 pub fn flatten(self) -> StringArray1 {
162 let n = self.data.len();
163 StringArray::<Ix1>::from_vec(Ix1::new([n]), self.data)
164 .expect("flatten: length check is trivially satisfied")
165 }
166
167 pub fn into_dyn(self) -> StringArray<IxDyn> {
171 let shape = self.dim.as_slice().to_vec();
172 StringArray::<IxDyn>::from_vec(IxDyn::new(&shape), self.data)
173 .expect("into_dyn: shape length check is trivially satisfied")
174 }
175
176 pub fn get(&self, idx: &[usize]) -> Option<&String> {
182 let shape = self.dim.as_slice();
183 if idx.len() != shape.len() {
184 return None;
185 }
186 let mut flat = 0usize;
187 let mut stride = 1usize;
188 for (i, (&dim, &k)) in shape.iter().zip(idx.iter()).enumerate().rev() {
190 if k >= dim {
191 return None;
192 }
193 if i == shape.len() - 1 {
194 flat += k;
195 } else {
196 flat += k * stride;
197 }
198 stride *= dim;
199 }
200 self.data.get(flat)
201 }
202
203 #[must_use]
209 pub fn at(&self, idx: usize) -> Option<&String> {
210 self.data.get(idx)
211 }
212
213 pub fn slice_axis(
225 &self,
226 axis: usize,
227 range: std::ops::Range<usize>,
228 ) -> FerrayResult<StringArray<IxDyn>> {
229 let shape = self.dim.as_slice().to_vec();
230 let ndim = shape.len();
231 if axis >= ndim {
232 return Err(ferray_core::error::FerrayError::axis_out_of_bounds(
233 axis, ndim,
234 ));
235 }
236 let axis_len = shape[axis];
237 if range.end > axis_len || range.start > range.end {
238 return Err(ferray_core::error::FerrayError::invalid_value(format!(
239 "slice_axis: range {:?} out of bounds for axis {axis} with size {axis_len}",
240 range
241 )));
242 }
243 let new_axis_len = range.end - range.start;
244 let inner_stride: usize = shape[axis + 1..].iter().product();
245 let block = axis_len * inner_stride;
246 let outer_size: usize = shape[..axis].iter().product();
247 let mut new_shape = shape.clone();
248 new_shape[axis] = new_axis_len;
249 let total: usize = new_shape.iter().product();
250 let mut out: Vec<String> = Vec::with_capacity(total);
251 for o in 0..outer_size {
252 let base = o * block;
253 for i in range.clone() {
254 let row_start = base + i * inner_stride;
255 out.extend_from_slice(&self.data[row_start..row_start + inner_stride]);
256 }
257 }
258 StringArray::<IxDyn>::from_vec(IxDyn::new(&new_shape), out)
259 }
260
261 pub fn get_row(&self, idx: usize) -> FerrayResult<crate::string_array::StringArray1> {
268 let shape = self.dim.as_slice();
269 if shape.len() != 2 {
270 return Err(ferray_core::error::FerrayError::shape_mismatch(format!(
271 "get_row: expected a 2-D StringArray, got {}-D",
272 shape.len()
273 )));
274 }
275 let nrows = shape[0];
276 let ncols = shape[1];
277 if idx >= nrows {
278 return Err(ferray_core::error::FerrayError::index_out_of_bounds(
279 idx as isize,
280 0,
281 nrows,
282 ));
283 }
284 let row: Vec<String> = self.data[idx * ncols..(idx + 1) * ncols].to_vec();
285 crate::string_array::StringArray1::from_vec(ferray_core::dimension::Ix1::new([ncols]), row)
286 }
287}
288
289impl<D: Dimension> PartialEq for StringArray<D> {
290 fn eq(&self, other: &Self) -> bool {
291 self.dim == other.dim && self.data == other.data
292 }
293}
294
295impl<D: Dimension> Eq for StringArray<D> {}
296
297impl<D: Dimension> std::fmt::Display for StringArray<D> {
299 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300 write!(f, "array([")?;
301 for (i, s) in self.data.iter().enumerate() {
302 if i > 0 {
303 write!(f, ", ")?;
304 }
305 write!(f, "{s:?}")?;
306 }
307 write!(f, "])")
308 }
309}
310
311impl<'a, D: Dimension> IntoIterator for &'a StringArray<D> {
313 type Item = &'a String;
314 type IntoIter = std::slice::Iter<'a, String>;
315
316 fn into_iter(self) -> Self::IntoIter {
317 self.data.iter()
318 }
319}
320
321impl<D: Dimension> IntoIterator for StringArray<D> {
323 type Item = String;
324 type IntoIter = std::vec::IntoIter<String>;
325
326 fn into_iter(self) -> Self::IntoIter {
327 self.data.into_iter()
328 }
329}
330
331impl StringArray<Ix1> {
336 pub fn from_slice(items: &[&str]) -> FerrayResult<Self> {
343 let data: Vec<String> = items.iter().map(|s| (*s).to_string()).collect();
344 let dim = Ix1::new([data.len()]);
345 Self::from_vec(dim, data)
346 }
347}
348
349impl StringArray<Ix2> {
350 pub fn transpose(&self) -> FerrayResult<Self> {
356 let shape = self.shape();
357 let (nrows, ncols) = (shape[0], shape[1]);
358 let mut data = Vec::with_capacity(nrows * ncols);
359 for c in 0..ncols {
360 for r in 0..nrows {
361 data.push(self.data[r * ncols + c].clone());
362 }
363 }
364 Self::from_vec(Ix2::new([ncols, nrows]), data)
365 }
366
367 pub fn from_rows(rows: &[&[&str]]) -> FerrayResult<Self> {
372 if rows.is_empty() {
373 return Self::from_vec(Ix2::new([0, 0]), Vec::new());
374 }
375 let ncols = rows[0].len();
376 for (i, row) in rows.iter().enumerate() {
377 if row.len() != ncols {
378 return Err(FerrayError::shape_mismatch(format!(
379 "row {} has length {} but row 0 has length {}",
380 i,
381 row.len(),
382 ncols
383 )));
384 }
385 }
386 let nrows = rows.len();
387 let data: Vec<String> = rows
388 .iter()
389 .flat_map(|row| row.iter().map(|s| (*s).to_string()))
390 .collect();
391 Self::from_vec(Ix2::new([nrows, ncols]), data)
392 }
393}
394
395impl StringArray<IxDyn> {
396 pub fn from_vec_dyn(shape: &[usize], data: Vec<String>) -> FerrayResult<Self> {
398 Self::from_vec(IxDyn::new(shape), data)
399 }
400}
401
402pub fn array(items: &[&str]) -> FerrayResult<StringArray1> {
409 StringArray1::from_slice(items)
410}
411
412use ferray_core::dimension::broadcast::broadcast_shapes;
417
/// Iterator over paired flat indices for a broadcast binary operation.
///
/// Each item `(idx_a, idx_b)` gives the row-major offsets into the two
/// operands' flat data for one element of the broadcast output. Output
/// elements are visited in row-major order.
pub(crate) struct BroadcastIter {
    /// Broadcast result shape.
    out_shape: Vec<usize>,
    /// Shape of the left operand.
    shape_a: Vec<usize>,
    /// Shape of the right operand.
    shape_b: Vec<usize>,
    /// Row-major strides of the left operand (built by `compute_strides`).
    strides_a: Vec<usize>,
    /// Row-major strides of the right operand (built by `compute_strides`).
    strides_b: Vec<usize>,
    /// Total number of output elements (product of `out_shape`).
    out_size: usize,
    /// Flat output position of the next element to yield.
    linear: usize,
}
432
433impl Iterator for BroadcastIter {
434 type Item = (usize, usize);
435
436 fn next(&mut self) -> Option<Self::Item> {
437 if self.linear >= self.out_size {
438 return None;
439 }
440 let multi = linear_to_multi(self.linear, &self.out_shape);
441 let idx_a = multi_to_broadcast_linear(&multi, &self.shape_a, &self.strides_a);
442 let idx_b = multi_to_broadcast_linear(&multi, &self.shape_b, &self.strides_b);
443 self.linear += 1;
444 Some((idx_a, idx_b))
445 }
446
447 fn size_hint(&self) -> (usize, Option<usize>) {
448 let remaining = self.out_size - self.linear;
449 (remaining, Some(remaining))
450 }
451}
452
453impl ExactSizeIterator for BroadcastIter {}
454
455pub(crate) fn broadcast_binary<Da: Dimension, Db: Dimension>(
464 a: &StringArray<Da>,
465 b: &StringArray<Db>,
466) -> FerrayResult<(Vec<usize>, BroadcastIter)> {
467 let shape_a = a.shape().to_vec();
468 let shape_b = b.shape().to_vec();
469 let out_shape = broadcast_shapes(&shape_a, &shape_b)?;
470 let out_size: usize = out_shape.iter().product();
471
472 let strides_a = compute_strides(&shape_a);
473 let strides_b = compute_strides(&shape_b);
474
475 let iter = BroadcastIter {
476 out_shape: out_shape.clone(),
477 shape_a,
478 shape_b,
479 strides_a,
480 strides_b,
481 out_size,
482 linear: 0,
483 };
484 Ok((out_shape, iter))
485}
486
/// Row-major (C-order) element strides for `shape`.
///
/// The last axis has stride 1, and each earlier axis's stride is the product
/// of all later extents. An empty shape yields an empty vector.
fn compute_strides(shape: &[usize]) -> Vec<usize> {
    if shape.is_empty() {
        return Vec::new();
    }
    // Build innermost-first, then flip into axis order.
    let mut strides = Vec::with_capacity(shape.len());
    let mut running = 1usize;
    strides.push(running);
    for &extent in shape[1..].iter().rev() {
        running *= extent;
        strides.push(running);
    }
    strides.reverse();
    strides
}
499
/// Decompose a row-major flat position into per-axis coordinates for `shape`.
///
/// Zero-extent axes are skipped, and their coordinate stays 0.
fn linear_to_multi(linear: usize, shape: &[usize]) -> Vec<usize> {
    let mut coords = vec![0usize; shape.len()];
    let mut remainder = linear;
    // Peel coordinates off from the innermost (fastest-varying) axis outward.
    for (coord, &extent) in coords.iter_mut().zip(shape).rev() {
        if extent == 0 {
            continue;
        }
        *coord = remainder % extent;
        remainder /= extent;
    }
    coords
}
512
/// Map an output multi-index onto a flat offset in a (possibly smaller) source.
///
/// The source shape is right-aligned against `multi`. Any source axis of
/// extent 1 is stretched, so it always contributes coordinate 0.
fn multi_to_broadcast_linear(multi: &[usize], src_shape: &[usize], src_strides: &[usize]) -> usize {
    let leading = multi.len().saturating_sub(src_shape.len());
    (0..src_shape.len()).fold(0usize, |acc, axis| {
        let coord = multi[leading + axis];
        let coord = if src_shape[axis] == 1 { 0 } else { coord };
        acc + coord * src_strides[axis]
    })
}
529
#[cfg(test)]
mod tests {
    use super::*;

    /// `n` owned strings "0", "1", ..., used to fill shaped arrays whose
    /// row-major layout is easy to read off in assertions.
    fn numbered(n: usize) -> Vec<String> {
        (0..n).map(|i| i.to_string()).collect()
    }

    #[test]
    fn at_returns_flat_index() {
        let a = array(&["a", "b", "c"]).unwrap();
        assert_eq!(a.at(0).unwrap(), "a");
        assert_eq!(a.at(2).unwrap(), "c");
        assert!(a.at(3).is_none());
    }

    #[test]
    fn get_row_returns_1d() {
        let a = StringArray::<Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec![
                "a0".into(),
                "a1".into(),
                "a2".into(),
                "b0".into(),
                "b1".into(),
                "b2".into(),
            ],
        )
        .unwrap();
        let row1 = a.get_row(1).unwrap();
        assert_eq!(row1.shape(), &[3]);
        assert_eq!(row1.as_slice(), &["b0", "b1", "b2"]);
    }

    #[test]
    fn get_row_rejects_non_2d() {
        let a = array(&["a", "b", "c"]).unwrap();
        assert!(a.get_row(0).is_err());
    }

    #[test]
    fn get_row_index_out_of_bounds_errors() {
        let a = StringArray::<Ix2>::from_vec(Ix2::new([2, 2]), vec!["x".into(); 4]).unwrap();
        assert!(a.get_row(5).is_err());
    }

    #[test]
    fn slice_axis_rows_2d() {
        let a = StringArray::<Ix2>::from_vec(
            Ix2::new([4, 2]),
            vec![
                "0,0".into(),
                "0,1".into(),
                "1,0".into(),
                "1,1".into(),
                "2,0".into(),
                "2,1".into(),
                "3,0".into(),
                "3,1".into(),
            ],
        )
        .unwrap();
        let r = a.slice_axis(0, 1..3).unwrap();
        assert_eq!(r.shape(), &[2, 2]);
        assert_eq!(r.as_slice(), &["1,0", "1,1", "2,0", "2,1"]);
    }

    #[test]
    fn slice_axis_columns_2d() {
        let a = StringArray::<Ix2>::from_vec(
            Ix2::new([2, 4]),
            vec![
                "0,0".into(),
                "0,1".into(),
                "0,2".into(),
                "0,3".into(),
                "1,0".into(),
                "1,1".into(),
                "1,2".into(),
                "1,3".into(),
            ],
        )
        .unwrap();
        let r = a.slice_axis(1, 1..3).unwrap();
        assert_eq!(r.shape(), &[2, 2]);
        assert_eq!(r.as_slice(), &["0,1", "0,2", "1,1", "1,2"]);
    }

    #[test]
    fn slice_axis_middle_axis_3d() {
        let a = StringArray::<IxDyn>::from_vec_dyn(&[2, 3, 2], numbered(12)).unwrap();
        let r = a.slice_axis(1, 1..3).unwrap();
        assert_eq!(r.shape(), &[2, 2, 2]);
        assert_eq!(r.as_slice(), &["2", "3", "4", "5", "8", "9", "10", "11"]);
    }

    #[test]
    fn slice_axis_empty_range() {
        let a = array(&["a", "b", "c"]).unwrap();
        let r = a.slice_axis(0, 1..1).unwrap();
        assert_eq!(r.shape(), &[0]);
        assert!(r.is_empty());
    }

    #[test]
    fn slice_axis_reversed_range_errors() {
        let a = array(&["a", "b", "c"]).unwrap();
        // Built as a struct literal so the reversed bounds are explicit.
        let reversed = std::ops::Range { start: 2, end: 1 };
        assert!(a.slice_axis(0, reversed).is_err());
    }

    #[test]
    fn slice_axis_axis_out_of_bounds() {
        let a = array(&["a", "b"]).unwrap();
        assert!(a.slice_axis(5, 0..1).is_err());
    }

    #[test]
    fn slice_axis_range_too_large_errors() {
        let a = array(&["a", "b", "c"]).unwrap();
        assert!(a.slice_axis(0, 0..10).is_err());
    }

    #[test]
    fn create_from_slice() {
        let a = array(&["hello", "world"]).unwrap();
        assert_eq!(a.shape(), &[2]);
        assert_eq!(a.len(), 2);
        assert_eq!(a.as_slice()[0], "hello");
        assert_eq!(a.as_slice()[1], "world");
    }

    #[test]
    fn create_from_vec() {
        let a = StringArray1::from_vec(Ix1::new([3]), vec!["a".into(), "b".into(), "c".into()])
            .unwrap();
        assert_eq!(a.shape(), &[3]);
    }

    #[test]
    fn shape_mismatch_error() {
        let res = StringArray1::from_vec(Ix1::new([5]), vec!["a".into(), "b".into()]);
        assert!(res.is_err());
    }

    #[test]
    fn empty_array() {
        let a = StringArray1::empty(Ix1::new([4])).unwrap();
        assert_eq!(a.len(), 4);
        assert!(a.as_slice().iter().all(std::string::String::is_empty));
    }

    #[test]
    fn map_strings() {
        let a = array(&["hello", "world"]).unwrap();
        let b = a.map(str::to_uppercase).unwrap();
        assert_eq!(b.as_slice()[0], "HELLO");
        assert_eq!(b.as_slice()[1], "WORLD");
    }

    #[test]
    fn map_to_vec_collects_lengths() {
        let a = array(&["a", "bb", ""]).unwrap();
        assert_eq!(a.map_to_vec(str::len), vec![1, 2, 0]);
    }

    #[test]
    fn from_rows_2d() {
        let a = StringArray2::from_rows(&[&["a", "b"], &["c", "d"]]).unwrap();
        assert_eq!(a.shape(), &[2, 2]);
        assert_eq!(a.as_slice(), &["a", "b", "c", "d"]);
    }

    #[test]
    fn from_rows_ragged_error() {
        let res = StringArray2::from_rows(&[&["a", "b"], &["c"]]);
        assert!(res.is_err());
    }

    #[test]
    fn from_rows_empty_and_transpose_empty() {
        let a = StringArray2::from_rows(&[]).unwrap();
        assert_eq!(a.shape(), &[0, 0]);
        let t = a.transpose().unwrap();
        assert_eq!(t.shape(), &[0, 0]);
        assert!(t.is_empty());
    }

    #[test]
    fn equality() {
        let a = array(&["x", "y"]).unwrap();
        let b = array(&["x", "y"]).unwrap();
        let c = array(&["x", "z"]).unwrap();
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    #[test]
    fn display_lists_elements_flat() {
        let a = array(&["a", "b"]).unwrap();
        assert_eq!(a.to_string(), "array([\"a\", \"b\"])");
        let e = StringArray1::empty(Ix1::new([0])).unwrap();
        assert_eq!(e.to_string(), "array([])");
    }

    #[test]
    fn into_iter_by_reference_and_by_value() {
        let a = array(&["p", "q"]).unwrap();
        let mut borrowed: Vec<&str> = Vec::new();
        for s in &a {
            borrowed.push(s);
        }
        assert_eq!(borrowed, vec!["p", "q"]);
        let owned: Vec<String> = a.into_iter().collect();
        assert_eq!(owned, vec!["p".to_string(), "q".to_string()]);
    }

    #[test]
    fn broadcast_binary_scalar() {
        let a = array(&["hello", "world"]).unwrap();
        let b = array(&["!"]).unwrap();
        let (shape, pairs) = broadcast_binary(&a, &b).unwrap();
        assert_eq!(shape, vec![2]);
        let collected: Vec<(usize, usize)> = pairs.collect();
        assert_eq!(collected, vec![(0, 0), (1, 0)]);
    }

    #[test]
    fn broadcast_binary_same_shape() {
        let a = array(&["a", "b", "c"]).unwrap();
        let b = array(&["x", "y", "z"]).unwrap();
        let (shape, pairs) = broadcast_binary(&a, &b).unwrap();
        assert_eq!(shape, vec![3]);
        let collected: Vec<(usize, usize)> = pairs.collect();
        assert_eq!(collected, vec![(0, 0), (1, 1), (2, 2)]);
    }

    #[test]
    fn broadcast_binary_2d_with_1d() {
        let a = StringArray2::from_rows(&[&["a", "b", "c"], &["d", "e", "f"]]).unwrap();
        let b = array(&["x", "y", "z"]).unwrap();
        let (shape, pairs) = broadcast_binary(&a, &b).unwrap();
        assert_eq!(shape, vec![2, 3]);
        let collected: Vec<(usize, usize)> = pairs.collect();
        assert_eq!(
            collected,
            vec![(0, 0), (1, 1), (2, 2), (3, 0), (4, 1), (5, 2)]
        );
    }

    #[test]
    fn broadcast_binary_column_with_row() {
        let a = StringArray2::from_rows(&[&["a"], &["b"]]).unwrap();
        let b = array(&["x", "y", "z"]).unwrap();
        let (shape, pairs) = broadcast_binary(&a, &b).unwrap();
        assert_eq!(shape, vec![2, 3]);
        let collected: Vec<(usize, usize)> = pairs.collect();
        assert_eq!(
            collected,
            vec![(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
        );
    }

    #[test]
    fn broadcast_binary_incompatible_shapes_error() {
        let a = array(&["a", "b"]).unwrap();
        let b = array(&["x", "y", "z"]).unwrap();
        assert!(broadcast_binary(&a, &b).is_err());
    }

    #[test]
    fn broadcast_binary_iter_size_hint() {
        let a = array(&["hello", "world"]).unwrap();
        let b = array(&["!"]).unwrap();
        let (_shape, pairs) = broadcast_binary(&a, &b).unwrap();
        assert_eq!(pairs.size_hint(), (2, Some(2)));
        assert_eq!(pairs.len(), 2);
    }

    #[test]
    fn broadcast_iter_stays_exhausted() {
        let a = array(&["a"]).unwrap();
        let (_shape, mut pairs) = broadcast_binary(&a, &a).unwrap();
        assert_eq!(pairs.next(), Some((0, 0)));
        assert_eq!(pairs.next(), None);
        assert_eq!(pairs.next(), None);
    }

    #[test]
    fn into_vec() {
        let a = array(&["a", "b"]).unwrap();
        let v = a.into_vec();
        assert_eq!(v, vec!["a".to_string(), "b".to_string()]);
    }

    #[test]
    fn reshape_1d_to_2d() {
        let a = array(&["a", "b", "c", "d", "e", "f"]).unwrap();
        let b = a.reshape(Ix2::new([2, 3])).unwrap();
        assert_eq!(b.shape(), &[2, 3]);
        assert_eq!(b.as_slice(), &["a", "b", "c", "d", "e", "f"]);
    }

    #[test]
    fn reshape_wrong_size_errors() {
        let a = array(&["a", "b", "c"]).unwrap();
        assert!(a.reshape(Ix2::new([2, 2])).is_err());
    }

    #[test]
    fn flatten_2d_to_1d() {
        let a = StringArray2::from_rows(&[&["a", "b"], &["c", "d"]]).unwrap();
        let f = a.flatten();
        assert_eq!(f.shape(), &[4]);
        assert_eq!(f.as_slice(), &["a", "b", "c", "d"]);
    }

    #[test]
    fn into_dyn_preserves_shape() {
        let a = StringArray2::from_rows(&[&["x", "y"], &["z", "w"]]).unwrap();
        let d = a.into_dyn();
        assert_eq!(d.shape(), &[2, 2]);
        assert_eq!(d.as_slice(), &["x", "y", "z", "w"]);
    }

    #[test]
    fn transpose_2x3() {
        let a = StringArray2::from_rows(&[&["a", "b", "c"], &["d", "e", "f"]]).unwrap();
        let t = a.transpose().unwrap();
        assert_eq!(t.shape(), &[3, 2]);
        assert_eq!(t.as_slice(), &["a", "d", "b", "e", "c", "f"]);
    }

    #[test]
    fn transpose_square_is_involution() {
        let a = StringArray2::from_rows(&[&["1", "2"], &["3", "4"]]).unwrap();
        let t = a.transpose().unwrap();
        let tt = t.transpose().unwrap();
        assert_eq!(tt.as_slice(), a.as_slice());
    }

    #[test]
    fn get_1d() {
        let a = array(&["zero", "one", "two"]).unwrap();
        assert_eq!(a.get(&[0]).unwrap(), "zero");
        assert_eq!(a.get(&[1]).unwrap(), "one");
        assert_eq!(a.get(&[2]).unwrap(), "two");
        assert_eq!(a.get(&[3]), None);
        assert_eq!(a.get(&[0, 0]), None);
    }

    #[test]
    fn get_2d() {
        let a = StringArray2::from_rows(&[&["a", "b", "c"], &["d", "e", "f"]]).unwrap();
        assert_eq!(a.get(&[0, 0]).unwrap(), "a");
        assert_eq!(a.get(&[0, 2]).unwrap(), "c");
        assert_eq!(a.get(&[1, 1]).unwrap(), "e");
        assert_eq!(a.get(&[2, 0]), None);
        assert_eq!(a.get(&[0, 3]), None);
    }

    #[test]
    fn get_3d_row_major() {
        let a = StringArray::<IxDyn>::from_vec_dyn(&[2, 2, 2], numbered(8)).unwrap();
        assert_eq!(a.get(&[1, 0, 1]).unwrap(), "5");
        assert_eq!(a.get(&[1, 1, 1]).unwrap(), "7");
        assert_eq!(a.get(&[0, 2, 0]), None);
    }
}