1use ferray_core::dimension::{Dimension, IxDyn};
6use ferray_core::error::FerrayResult;
7
8use crate::string_array::{StringArray, broadcast_binary};
9
10pub fn add<Da: Dimension, Db: Dimension>(
21 a: &StringArray<Da>,
22 b: &StringArray<Db>,
23) -> FerrayResult<StringArray<IxDyn>> {
24 let (out_shape, pairs) = broadcast_binary(a, b)?;
25 let a_data = a.as_slice();
26 let b_data = b.as_slice();
27
28 let data: Vec<String> = pairs
29 .iter()
30 .map(|&(ia, ib)| format!("{}{}", a_data[ia], b_data[ib]))
31 .collect();
32
33 StringArray::from_vec(IxDyn::new(&out_shape), data)
34}
35
36pub fn multiply<D: Dimension>(a: &StringArray<D>, n: usize) -> FerrayResult<StringArray<D>> {
41 a.map(|s| s.repeat(n))
42}
43
#[cfg(test)]
mod tests {
    use super::*;
    use crate::string_array::array;

    /// Equal-length inputs are concatenated pairwise.
    #[test]
    fn test_add_same_shape() {
        let left = array(&["hello", "foo"]).unwrap();
        let right = array(&[" world", " bar"]).unwrap();
        let joined = add(&left, &right).unwrap();
        assert_eq!(joined.as_slice(), &["hello world", "foo bar"]);
    }

    /// A length-1 right operand is broadcast across every left element.
    #[test]
    fn test_add_broadcast_scalar() {
        let words = array(&["hello", "world"]).unwrap();
        let suffix = array(&["!"]).unwrap();
        let joined = add(&words, &suffix).unwrap();
        assert_eq!(joined.as_slice(), &["hello!", "world!"]);
    }

    /// A length-1 left operand is broadcast across every right element.
    #[test]
    fn test_add_broadcast_scalar_left() {
        let prefix = array(&[">> "]).unwrap();
        let words = array(&["hello", "world"]).unwrap();
        let joined = add(&prefix, &words).unwrap();
        assert_eq!(joined.as_slice(), &[">> hello", ">> world"]);
    }

    /// Lengths 3 and 2 cannot be broadcast together, so `add` must fail.
    #[test]
    fn test_add_incompatible_shapes() {
        let three = array(&["a", "b", "c"]).unwrap();
        let two = array(&["x", "y"]).unwrap();
        let outcome = add(&three, &two);
        assert!(outcome.is_err());
    }

    /// Each element is repeated the requested number of times.
    #[test]
    fn test_multiply() {
        let source = array(&["ab", "cd"]).unwrap();
        let repeated = multiply(&source, 3).unwrap();
        assert_eq!(repeated.as_slice(), &["ababab", "cdcdcd"]);
    }

    /// Zero repetitions yield empty strings.
    #[test]
    fn test_multiply_zero() {
        let source = array(&["hello"]).unwrap();
        let repeated = multiply(&source, 0).unwrap();
        assert_eq!(repeated.as_slice(), &[""]);
    }

    /// A single repetition leaves the element unchanged.
    #[test]
    fn test_multiply_one() {
        let source = array(&["hello"]).unwrap();
        let repeated = multiply(&source, 1).unwrap();
        assert_eq!(repeated.as_slice(), &["hello"]);
    }
}