oxigdal_streaming/transformations/
partition.rs1use crate::core::stream::StreamElement;
4use crate::error::{Result, StreamingError};
5use ahash::AHasher;
6use serde::{Deserialize, Serialize};
7use std::hash::{Hash, Hasher};
8use std::sync::Arc;
9use std::sync::atomic::{AtomicUsize, Ordering};
10
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
13pub enum PartitionStrategy {
14 Hash,
16
17 Range,
19
20 RoundRobin,
22
23 Random,
25
26 Broadcast,
28}
29
30pub trait KeySelector: Send + Sync {
32 fn select_key(&self, element: &StreamElement) -> Vec<u8>;
34}
35
36pub struct ElementKeySelector;
38
39impl KeySelector for ElementKeySelector {
40 fn select_key(&self, element: &StreamElement) -> Vec<u8> {
41 element.key.clone().unwrap_or_default()
42 }
43}
44
45pub trait Partitioner: Send + Sync {
47 fn partition(&self, element: &StreamElement, num_partitions: usize) -> Result<usize>;
49
50 fn strategy(&self) -> PartitionStrategy;
52}
53
54pub struct HashPartitioner<K>
56where
57 K: KeySelector,
58{
59 key_selector: Arc<K>,
60}
61
62impl<K> HashPartitioner<K>
63where
64 K: KeySelector,
65{
66 pub fn new(key_selector: K) -> Self {
68 Self {
69 key_selector: Arc::new(key_selector),
70 }
71 }
72}
73
74impl<K> Partitioner for HashPartitioner<K>
75where
76 K: KeySelector,
77{
78 fn partition(&self, element: &StreamElement, num_partitions: usize) -> Result<usize> {
79 if num_partitions == 0 {
80 return Err(StreamingError::PartitionError(
81 "Number of partitions must be greater than 0".to_string(),
82 ));
83 }
84
85 let key = self.key_selector.select_key(element);
86 let mut hasher = AHasher::default();
87 key.hash(&mut hasher);
88 let hash = hasher.finish();
89
90 Ok((hash as usize) % num_partitions)
91 }
92
93 fn strategy(&self) -> PartitionStrategy {
94 PartitionStrategy::Hash
95 }
96}
97
98pub struct RangePartitioner<K>
100where
101 K: KeySelector,
102{
103 key_selector: Arc<K>,
104 boundaries: Vec<Vec<u8>>,
105}
106
107impl<K> RangePartitioner<K>
108where
109 K: KeySelector,
110{
111 pub fn new(key_selector: K, boundaries: Vec<Vec<u8>>) -> Self {
113 Self {
114 key_selector: Arc::new(key_selector),
115 boundaries,
116 }
117 }
118}
119
120impl<K> Partitioner for RangePartitioner<K>
121where
122 K: KeySelector,
123{
124 fn partition(&self, element: &StreamElement, num_partitions: usize) -> Result<usize> {
125 if num_partitions == 0 {
126 return Err(StreamingError::PartitionError(
127 "Number of partitions must be greater than 0".to_string(),
128 ));
129 }
130
131 let key = self.key_selector.select_key(element);
132
133 for (i, boundary) in self.boundaries.iter().enumerate() {
134 if &key < boundary {
135 return Ok(i.min(num_partitions - 1));
136 }
137 }
138
139 Ok(num_partitions - 1)
140 }
141
142 fn strategy(&self) -> PartitionStrategy {
143 PartitionStrategy::Range
144 }
145}
146
/// Partitioner that hands out partitions in rotation, ignoring element
/// contents.
pub struct RoundRobinPartitioner {
    /// Ticket counter; each `partition` call takes the next value (wrapping
    /// at `usize::MAX`).
    counter: Arc<AtomicUsize>,
}

impl RoundRobinPartitioner {
    /// Creates a round-robin partitioner whose counter starts at zero.
    pub fn new() -> Self {
        let counter = Arc::new(AtomicUsize::default());
        Self { counter }
    }
}

impl Default for RoundRobinPartitioner {
    /// Same as [`RoundRobinPartitioner::new`].
    fn default() -> Self {
        Self::new()
    }
}
166
167impl Partitioner for RoundRobinPartitioner {
168 fn partition(&self, _element: &StreamElement, num_partitions: usize) -> Result<usize> {
169 if num_partitions == 0 {
170 return Err(StreamingError::PartitionError(
171 "Number of partitions must be greater than 0".to_string(),
172 ));
173 }
174
175 let partition = self.counter.fetch_add(1, Ordering::Relaxed) % num_partitions;
176 Ok(partition)
177 }
178
179 fn strategy(&self) -> PartitionStrategy {
180 PartitionStrategy::RoundRobin
181 }
182}
183
/// Zero-sized partitioner tagged with the broadcast strategy.
///
/// NOTE(review): `Partitioner::partition` can only return a single index, and
/// this type always returns 0. Fan-out to every partition is presumably done
/// by a caller that inspects `strategy()`; confirm.
pub struct BroadcastPartitioner;

impl BroadcastPartitioner {
    /// Creates a broadcast partitioner.
    pub fn new() -> Self {
        BroadcastPartitioner
    }
}

impl Default for BroadcastPartitioner {
    /// Same as [`BroadcastPartitioner::new`].
    fn default() -> Self {
        Self::new()
    }
}
199
200impl Partitioner for BroadcastPartitioner {
201 fn partition(&self, _element: &StreamElement, num_partitions: usize) -> Result<usize> {
202 if num_partitions == 0 {
203 return Err(StreamingError::PartitionError(
204 "Number of partitions must be greater than 0".to_string(),
205 ));
206 }
207
208 Ok(0)
209 }
210
211 fn strategy(&self) -> PartitionStrategy {
212 PartitionStrategy::Broadcast
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;

    #[test]
    fn test_hash_partitioner() {
        let partitioner = HashPartitioner::new(ElementKeySelector);

        let elem = StreamElement::new(vec![1, 2, 3], Utc::now()).with_key(vec![1]);
        let partition = partitioner
            .partition(&elem, 4)
            .expect("Failed to partition element with hash partitioner");

        assert!(partition < 4);
    }

    #[test]
    fn test_hash_partitioner_consistency() {
        let partitioner = HashPartitioner::new(ElementKeySelector);

        let elem = StreamElement::new(vec![1, 2, 3], Utc::now()).with_key(vec![1]);
        let p1 = partitioner
            .partition(&elem, 4)
            .expect("Failed to partition element for consistency test (first call)");
        let p2 = partitioner
            .partition(&elem, 4)
            .expect("Failed to partition element for consistency test (second call)");

        assert_eq!(p1, p2);
    }

    #[test]
    fn test_range_partitioner() {
        let boundaries = vec![vec![5], vec![10], vec![15]];
        let partitioner = RangePartitioner::new(ElementKeySelector, boundaries);

        let elem1 = StreamElement::new(vec![1, 2, 3], Utc::now()).with_key(vec![3]);
        let elem2 = StreamElement::new(vec![1, 2, 3], Utc::now()).with_key(vec![7]);
        let elem3 = StreamElement::new(vec![1, 2, 3], Utc::now()).with_key(vec![12]);

        assert_eq!(
            partitioner
                .partition(&elem1, 4)
                .expect("Failed to partition element 1 with range partitioner"),
            0
        );
        assert_eq!(
            partitioner
                .partition(&elem2, 4)
                .expect("Failed to partition element 2 with range partitioner"),
            1
        );
        assert_eq!(
            partitioner
                .partition(&elem3, 4)
                .expect("Failed to partition element 3 with range partitioner"),
            2
        );
    }

    #[test]
    fn test_range_partitioner_edges() {
        let partitioner =
            RangePartitioner::new(ElementKeySelector, vec![vec![5], vec![10], vec![15]]);
        let route = |key: u8, num_partitions: usize| {
            let elem = StreamElement::new(vec![1, 2, 3], Utc::now()).with_key(vec![key]);
            partitioner
                .partition(&elem, num_partitions)
                .expect("Failed to partition element with range partitioner")
        };

        // Boundaries are exclusive upper bounds.
        assert_eq!(route(5, 4), 1);
        // Keys at or beyond the last boundary land in the last partition.
        assert_eq!(route(15, 4), 3);
        assert_eq!(route(200, 4), 3);
        // More ranges than partitions: the index is clamped to the last partition.
        assert_eq!(route(12, 2), 1);

        // With no boundaries every key goes to the last partition.
        let unbounded = RangePartitioner::new(ElementKeySelector, Vec::new());
        let elem = StreamElement::new(vec![1, 2, 3], Utc::now()).with_key(vec![3]);
        assert_eq!(
            unbounded
                .partition(&elem, 4)
                .expect("Failed to partition element with unbounded range partitioner"),
            3
        );
    }

    #[test]
    fn test_round_robin_partitioner() {
        let partitioner = RoundRobinPartitioner::new();

        let elem = StreamElement::new(vec![1, 2, 3], Utc::now());

        let mut partitions = Vec::new();
        for _ in 0..8 {
            partitions.push(
                partitioner
                    .partition(&elem, 4)
                    .expect("Failed to partition element with round-robin partitioner"),
            );
        }

        assert_eq!(partitions, vec![0, 1, 2, 3, 0, 1, 2, 3]);
    }

    #[test]
    fn test_broadcast_partitioner() {
        let partitioner = BroadcastPartitioner::new();

        let elem = StreamElement::new(vec![1, 2, 3], Utc::now());
        let partition = partitioner
            .partition(&elem, 4)
            .expect("Failed to partition element with broadcast partitioner");

        assert_eq!(partition, 0);
        assert_eq!(partitioner.strategy(), PartitionStrategy::Broadcast);
    }

    fn all_partitioners() -> Vec<Box<dyn Partitioner>> {
        vec![
            Box::new(HashPartitioner::new(ElementKeySelector)),
            Box::new(RangePartitioner::new(
                ElementKeySelector,
                vec![vec![5], vec![10]],
            )),
            Box::new(RoundRobinPartitioner::new()),
            Box::new(BroadcastPartitioner::new()),
        ]
    }

    #[test]
    fn test_partitioners_reject_zero_partitions() {
        let elem = StreamElement::new(vec![1, 2, 3], Utc::now()).with_key(vec![1]);

        for partitioner in all_partitioners() {
            assert!(
                matches!(
                    partitioner.partition(&elem, 0),
                    Err(StreamingError::PartitionError(_))
                ),
                "{:?} partitioner accepted zero partitions",
                partitioner.strategy()
            );
        }
    }

    #[test]
    fn test_single_partition_always_zero() {
        let elem = StreamElement::new(vec![1, 2, 3], Utc::now()).with_key(vec![42]);

        for partitioner in all_partitioners() {
            for _ in 0..3 {
                assert_eq!(
                    partitioner
                        .partition(&elem, 1)
                        .expect("Failed to partition element into a single partition"),
                    0
                );
            }
        }
    }

    #[test]
    fn test_partitioner_strategies() {
        assert_eq!(
            HashPartitioner::new(ElementKeySelector).strategy(),
            PartitionStrategy::Hash
        );
        assert_eq!(
            RangePartitioner::new(ElementKeySelector, Vec::new()).strategy(),
            PartitionStrategy::Range
        );
        assert_eq!(
            RoundRobinPartitioner::new().strategy(),
            PartitionStrategy::RoundRobin
        );
    }
}
307}