1use crate::MapletResult;
7use std::collections::HashSet;
8use std::hash::Hash;
9
10pub trait MergeOperator<V>: Clone + Send + Sync {
12 fn merge(&self, left: V, right: V) -> MapletResult<V>;
18
19 fn identity(&self) -> V;
21
22 fn is_associative(&self) -> bool {
24 true }
26
27 fn is_commutative(&self) -> bool {
29 true }
31}
32
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
35pub struct CounterOperator;
36
37impl MergeOperator<u64> for CounterOperator {
38 fn merge(&self, left: u64, right: u64) -> MapletResult<u64> {
39 Ok(left.saturating_add(right))
40 }
41
42 fn identity(&self) -> u64 {
43 0
44 }
45}
46
47impl MergeOperator<u32> for CounterOperator {
48 fn merge(&self, left: u32, right: u32) -> MapletResult<u32> {
49 Ok(left.saturating_add(right))
50 }
51
52 fn identity(&self) -> u32 {
53 0
54 }
55}
56
57impl MergeOperator<i64> for CounterOperator {
58 fn merge(&self, left: i64, right: i64) -> MapletResult<i64> {
59 Ok(left.saturating_add(right))
60 }
61
62 fn identity(&self) -> i64 {
63 0
64 }
65}
66
67#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
69pub struct SetOperator;
70
71impl<T: Clone + Hash + Eq> MergeOperator<HashSet<T>> for SetOperator {
72 fn merge(&self, mut left: HashSet<T>, right: HashSet<T>) -> MapletResult<HashSet<T>> {
73 left.extend(right);
74 Ok(left)
75 }
76
77 fn identity(&self) -> HashSet<T> {
78 HashSet::new()
79 }
80}
81
82#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
84pub struct StringOperator;
85
86impl MergeOperator<String> for StringOperator {
87 fn merge(&self, _left: String, right: String) -> MapletResult<String> {
88 Ok(right)
90 }
91
92 fn identity(&self) -> String {
93 String::new()
94 }
95}
96
97#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
99pub struct MaxOperator;
100
101impl MergeOperator<u64> for MaxOperator {
102 fn merge(&self, left: u64, right: u64) -> MapletResult<u64> {
103 Ok(left.max(right))
104 }
105
106 fn identity(&self) -> u64 {
107 0
108 }
109}
110
111impl MergeOperator<f64> for MaxOperator {
112 fn merge(&self, left: f64, right: f64) -> MapletResult<f64> {
113 Ok(left.max(right))
114 }
115
116 fn identity(&self) -> f64 {
117 f64::NEG_INFINITY
118 }
119}
120
121#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
123pub struct MinOperator;
124
125impl MergeOperator<u64> for MinOperator {
126 fn merge(&self, left: u64, right: u64) -> MapletResult<u64> {
127 Ok(left.min(right))
128 }
129
130 fn identity(&self) -> u64 {
131 u64::MAX
132 }
133}
134
135impl MergeOperator<f64> for MinOperator {
136 fn merge(&self, left: f64, right: f64) -> MapletResult<f64> {
137 Ok(left.min(right))
138 }
139
140 fn identity(&self) -> f64 {
141 f64::INFINITY
142 }
143}
144
/// Wrapper around a user-supplied merge function of type `F`.
///
/// The wrapped function is reachable through [`merge_fn`](Self::merge_fn).
#[derive(Clone)]
pub struct CustomOperator<F> {
    merge_fn: F,
}

impl<F> CustomOperator<F> {
    /// Wraps `merge_fn`.
    pub const fn new(merge_fn: F) -> Self {
        Self { merge_fn }
    }

    /// Returns a reference to the wrapped merge function.
    ///
    /// Without this accessor the stored function was write-only and required
    /// `#[allow(dead_code)]` to compile silently.
    pub const fn merge_fn(&self) -> &F {
        &self.merge_fn
    }
}

impl<F> std::fmt::Debug for CustomOperator<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Closures and function items are generally not `Debug`, so the
        // wrapped function is elided rather than requiring `F: Debug`.
        f.debug_struct("CustomOperator").finish_non_exhaustive()
    }
}
158
159#[derive(Debug, Clone, Copy, PartialEq, Eq)]
161pub struct StringConcatOperator;
162
163impl MergeOperator<String> for StringConcatOperator {
164 fn merge(&self, left: String, right: String) -> MapletResult<String> {
165 Ok(format!("{left}{right}"))
166 }
167
168 fn identity(&self) -> String {
169 String::new()
170 }
171}
172
173#[derive(Debug, Clone, Copy, PartialEq, Eq)]
175pub struct VectorConcatOperator;
176
177impl<T: Clone> MergeOperator<Vec<T>> for VectorConcatOperator {
178 fn merge(&self, mut left: Vec<T>, right: Vec<T>) -> MapletResult<Vec<T>> {
179 left.extend(right);
180 Ok(left)
181 }
182
183 fn identity(&self) -> Vec<T> {
184 Vec::new()
185 }
186}
187
188#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
190pub struct VectorOperator;
191
192impl MergeOperator<Vec<f64>> for VectorOperator {
193 fn merge(&self, left: Vec<f64>, right: Vec<f64>) -> MapletResult<Vec<f64>> {
194 if left.len() != right.len() {
195 return Err(crate::MapletError::Internal(format!(
196 "Vector length mismatch: {} != {}",
197 left.len(),
198 right.len()
199 )));
200 }
201 Ok(left.into_iter().zip(right).map(|(l, r)| l + r).collect())
202 }
203
204 fn identity(&self) -> Vec<f64> {
205 Vec::new()
206 }
207}
208
209impl MergeOperator<Vec<f32>> for VectorOperator {
210 fn merge(&self, left: Vec<f32>, right: Vec<f32>) -> MapletResult<Vec<f32>> {
211 if left.len() != right.len() {
212 return Err(crate::MapletError::Internal(format!(
213 "Vector length mismatch: {} != {}",
214 left.len(),
215 right.len()
216 )));
217 }
218 Ok(left.into_iter().zip(right).map(|(l, r)| l + r).collect())
219 }
220
221 fn identity(&self) -> Vec<f32> {
222 Vec::new()
223 }
224}
225
226#[derive(Debug, Clone, Copy, PartialEq, Eq)]
228pub struct BoolOrOperator;
229
230impl MergeOperator<bool> for BoolOrOperator {
231 fn merge(&self, left: bool, right: bool) -> MapletResult<bool> {
232 Ok(left || right)
233 }
234
235 fn identity(&self) -> bool {
236 false
237 }
238}
239
240#[derive(Debug, Clone, Copy, PartialEq, Eq)]
242pub struct BoolAndOperator;
243
244impl MergeOperator<bool> for BoolAndOperator {
245 fn merge(&self, left: bool, right: bool) -> MapletResult<bool> {
246 Ok(left && right)
247 }
248
249 fn identity(&self) -> bool {
250 true
251 }
252}
253
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn test_counter_operator() {
        let op = CounterOperator;

        assert_eq!(op.merge(5u64, 3u64).unwrap(), 8);
        assert_eq!(op.merge(0u64, 10u64).unwrap(), 10);

        // Saturates at the type bounds instead of overflowing.
        assert_eq!(op.merge(u64::MAX, 1).unwrap(), u64::MAX);
        assert_eq!(op.merge(u32::MAX, 1u32).unwrap(), u32::MAX);
        assert_eq!(op.merge(i64::MIN, -1i64).unwrap(), i64::MIN);
        assert_eq!(<CounterOperator as MergeOperator<u64>>::identity(&op), 0);
    }

    #[test]
    fn test_set_operator() {
        let op = SetOperator;

        let set1: HashSet<String> = ["a", "b"].iter().map(|s| s.to_string()).collect();
        let set2: HashSet<String> = ["b", "c"].iter().map(|s| s.to_string()).collect();

        let result = op.merge(set1, set2).unwrap();
        assert_eq!(result.len(), 3);
        assert!(result.contains("a"));
        assert!(result.contains("b"));
        assert!(result.contains("c"));

        let empty: HashSet<String> = op.identity();
        assert!(empty.is_empty());
    }

    #[test]
    fn test_string_operator() {
        let op = StringOperator;

        // Last write wins.
        assert_eq!(op.merge("first".to_string(), "second".to_string()).unwrap(), "second");
        assert_eq!(op.identity(), "");
    }

    #[test]
    fn test_max_operator() {
        let op = MaxOperator;

        assert_eq!(op.merge(5u64, 3u64).unwrap(), 5);
        assert_eq!(op.merge(3u64, 5u64).unwrap(), 5);
        assert_eq!(op.merge(5.0, 3.0).unwrap(), 5.0);
        assert_eq!(<MaxOperator as MergeOperator<u64>>::identity(&op), 0);
        assert_eq!(
            <MaxOperator as MergeOperator<f64>>::identity(&op),
            f64::NEG_INFINITY
        );
    }

    #[test]
    fn test_min_operator() {
        let op = MinOperator;

        assert_eq!(op.merge(5u64, 3u64).unwrap(), 3);
        assert_eq!(op.merge(3u64, 5u64).unwrap(), 3);
        assert_eq!(op.merge(5.0, 3.0).unwrap(), 3.0);
        assert_eq!(<MinOperator as MergeOperator<u64>>::identity(&op), u64::MAX);
        assert_eq!(<MinOperator as MergeOperator<f64>>::identity(&op), f64::INFINITY);
    }

    #[test]
    fn test_string_concat_operator() {
        let op = StringConcatOperator;

        assert_eq!(
            op.merge("hello".to_string(), "world".to_string()).unwrap(),
            "helloworld"
        );
        assert_eq!(op.identity(), "");
    }

    #[test]
    fn test_vector_concat_operator() {
        let op = VectorConcatOperator;

        let result = op.merge(vec![1, 2, 3], vec![4, 5, 6]).unwrap();
        assert_eq!(result, vec![1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_vector_operator() {
        let op = VectorOperator;

        assert_eq!(
            op.merge(vec![1.0f64, 2.0], vec![3.0f64, 4.0]).unwrap(),
            vec![4.0, 6.0]
        );
        assert_eq!(op.merge(vec![1.5f32], vec![2.5f32]).unwrap(), vec![4.0f32]);
        // Two non-empty vectors of different lengths cannot be summed.
        assert!(op.merge(vec![1.0f64], vec![1.0f64, 2.0]).is_err());
    }

    #[test]
    fn test_bool_operators() {
        let or_op = BoolOrOperator;
        let and_op = BoolAndOperator;

        assert!(or_op.merge(false, true).unwrap());
        assert!(!or_op.merge(false, false).unwrap());
        assert!(!or_op.identity());

        assert!(!and_op.merge(true, false).unwrap());
        assert!(and_op.merge(true, true).unwrap());
        assert!(and_op.identity());
    }
}
341}