1use ferray_core::dimension::{Dimension, IxDyn};
6use ferray_core::error::FerrayResult;
7
8use crate::string_array::{StringArray, broadcast_binary};
9
10pub fn add<Da: Dimension, Db: Dimension>(
21 a: &StringArray<Da>,
22 b: &StringArray<Db>,
23) -> FerrayResult<StringArray<IxDyn>> {
24 let (out_shape, pairs) = broadcast_binary(a, b)?;
25 let a_data = a.as_slice();
26 let b_data = b.as_slice();
27
28 let data: Vec<String> = pairs
29 .map(|(ia, ib)| format!("{}{}", a_data[ia], b_data[ib]))
30 .collect();
31
32 StringArray::from_vec(IxDyn::new(&out_shape), data)
33}
34
35pub fn add_same<D: Dimension>(
46 a: &StringArray<D>,
47 b: &StringArray<D>,
48) -> FerrayResult<StringArray<D>> {
49 if a.shape() != b.shape() {
50 return Err(ferray_core::error::FerrayError::shape_mismatch(format!(
51 "add_same: shapes {:?} and {:?} must be identical",
52 a.shape(),
53 b.shape()
54 )));
55 }
56 let data: Vec<String> = a
57 .iter()
58 .zip(b.iter())
59 .map(|(x, y)| format!("{x}{y}"))
60 .collect();
61 StringArray::from_vec(a.dim().clone(), data)
62}
63
64pub fn multiply<D: Dimension>(a: &StringArray<D>, n: usize) -> FerrayResult<StringArray<D>> {
69 a.map(|s| s.repeat(n))
70}
71
#[cfg(test)]
mod tests {
    // Unit tests for string concatenation (`add`, `add_same`) and
    // repetition (`multiply`).
    use super::*;
    use crate::string_array::array;

    /// `add` with equally shaped operands (no broadcasting needed).
    #[test]
    fn test_add_same_shape() {
        let a = array(&["hello", "foo"]).unwrap();
        let b = array(&[" world", " bar"]).unwrap();
        let c = add(&a, &b).unwrap();
        assert_eq!(c.as_slice(), &["hello world", "foo bar"]);
    }

    #[test]
    fn test_add_broadcast_scalar() {
        let a = array(&["hello", "world"]).unwrap();
        let b = array(&["!"]).unwrap();
        let c = add(&a, &b).unwrap();
        assert_eq!(c.as_slice(), &["hello!", "world!"]);
    }

    #[test]
    fn test_add_broadcast_scalar_left() {
        let a = array(&[">> "]).unwrap();
        let b = array(&["hello", "world"]).unwrap();
        let c = add(&a, &b).unwrap();
        assert_eq!(c.as_slice(), &[">> hello", ">> world"]);
    }

    #[test]
    fn test_add_incompatible_shapes() {
        let a = array(&["a", "b", "c"]).unwrap();
        let b = array(&["x", "y"]).unwrap();
        assert!(add(&a, &b).is_err());
    }

    /// `add_same` was previously untested; pin its happy path.
    #[test]
    fn test_add_same_matching_shapes() {
        let a = array(&["hello", "foo"]).unwrap();
        let b = array(&[" world", " bar"]).unwrap();
        let c = add_same(&a, &b).unwrap();
        assert_eq!(c.as_slice(), &["hello world", "foo bar"]);
        // No broadcasting happens, so the input shape is carried through.
        assert_eq!(c.shape(), a.shape());
    }

    /// `add_same` must reject differing shapes instead of truncating via `zip`.
    #[test]
    fn test_add_same_shape_mismatch() {
        let a = array(&["a", "b", "c"]).unwrap();
        let b = array(&["x", "y"]).unwrap();
        assert!(add_same(&a, &b).is_err());
    }

    #[test]
    fn test_multiply() {
        let a = array(&["ab", "cd"]).unwrap();
        let b = multiply(&a, 3).unwrap();
        assert_eq!(b.as_slice(), &["ababab", "cdcdcd"]);
    }

    #[test]
    fn test_multiply_zero() {
        let a = array(&["hello"]).unwrap();
        let b = multiply(&a, 0).unwrap();
        assert_eq!(b.as_slice(), &[""]);
    }

    #[test]
    fn test_multiply_one() {
        let a = array(&["hello"]).unwrap();
        let b = multiply(&a, 1).unwrap();
        assert_eq!(b.as_slice(), &["hello"]);
    }

    #[test]
    fn test_multiply_empty_string() {
        let a = array(&[""]).unwrap();
        let b = multiply(&a, 5).unwrap();
        assert_eq!(b.as_slice(), &[""]);
    }

    /// Repetition works on whole UTF-8 strings, not bytes.
    #[test]
    fn test_multiply_multibyte() {
        let a = array(&["é"]).unwrap();
        let b = multiply(&a, 2).unwrap();
        assert_eq!(b.as_slice(), &["éé"]);
    }
}