1use std::{
12 borrow::Cow,
13 fmt,
14 hash::{Hash, Hasher},
15 str::FromStr,
16};
17
18use thiserror::Error;
19use twox_hash::XxHash64;
20
/// Describes how a set of tasks is split into shards, and which shard to select.
///
/// A builder can be parsed from a string of the form `hash:M/N` or `count:M/N` via its
/// `FromStr` implementation, and turned into a [`Partitioner`] with `build`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum PartitionerBuilder {
    /// Partition by position: tasks are assigned to shards in rotation, in the order in
    /// which they are presented to the partitioner.
    Count {
        /// The shard to select, counting up from 1.
        shard: u64,

        /// The total number of shards.
        total_shards: u64,
    },

    /// Partition by name: each task is assigned to a shard based on a hash of its name.
    Hash {
        /// The shard to select, counting up from 1.
        shard: u64,

        /// The total number of shards.
        total_shards: u64,
    },
}
46
/// Decides, task by task, whether a task belongs to the selected shard.
///
/// The receiver is `&mut self` because an implementation may keep state between calls
/// (for example, a running position counter).
pub trait Partitioner: fmt::Debug {
    /// Returns `true` if the task with the given name belongs to this partition.
    fn task_matches(&mut self, task_name: &str) -> bool;
}
52
53impl PartitionerBuilder {
54 pub fn build(&self) -> Box<dyn Partitioner> {
56 match self {
58 PartitionerBuilder::Count {
59 shard,
60 total_shards,
61 } => Box::new(CountPartitioner::new(*shard, *total_shards)),
62 PartitionerBuilder::Hash {
63 shard,
64 total_shards,
65 } => Box::new(HashPartitioner::new(*shard, *total_shards)),
66 }
67 }
68}
69
70impl FromStr for PartitionerBuilder {
71 type Err = PartitionerBuilderParseError;
72
73 fn from_str(s: &str) -> Result<Self, Self::Err> {
74 if let Some(input) = s.strip_prefix("hash:") {
76 let (shard, total_shards) = parse_shards(input, "hash:M/N")?;
77
78 Ok(PartitionerBuilder::Hash {
79 shard,
80 total_shards,
81 })
82 } else if let Some(input) = s.strip_prefix("count:") {
83 let (shard, total_shards) = parse_shards(input, "count:M/N")?;
84
85 Ok(PartitionerBuilder::Count {
86 shard,
87 total_shards,
88 })
89 } else {
90 Err(PartitionerBuilderParseError::new(
91 None,
92 format!(
93 "partition input '{}' must begin with \"hash:\" or \"count:\"",
94 s
95 ),
96 ))
97 }
98 }
99}
100
101fn parse_shards(
102 input: &str,
103 expected_format: &'static str,
104) -> Result<(u64, u64), PartitionerBuilderParseError> {
105 let mut split = input.splitn(2, '/');
106 let shard_str = split.next().expect("split should have at least 1 element");
108 let total_shards_str = split.next().ok_or_else(|| {
110 PartitionerBuilderParseError::new(
111 Some(expected_format),
112 format!("expected input '{}' to be in the format M/N", input),
113 )
114 })?;
115
116 let shard: u64 = shard_str.parse().map_err(|err| {
117 PartitionerBuilderParseError::new(
118 Some(expected_format),
119 format!("failed to parse shard '{}' as u64: {}", shard_str, err),
120 )
121 })?;
122
123 let total_shards: u64 = total_shards_str.parse().map_err(|err| {
124 PartitionerBuilderParseError::new(
125 Some(expected_format),
126 format!(
127 "failed to parse total_shards '{}' as u64: {}",
128 total_shards_str, err
129 ),
130 )
131 })?;
132
133 if !(1..=total_shards).contains(&shard) {
135 return Err(PartitionerBuilderParseError::new(
136 Some(expected_format),
137 format!(
138 "shard {} must be a number between 1 and total shards {}, inclusive",
139 shard, total_shards
140 ),
141 ));
142 }
143
144 Ok((shard, total_shards))
145}
146
/// Partitioner that assigns tasks to shards in rotation, by the order in which they are seen.
#[derive(Clone, Debug)]
struct CountPartitioner {
    /// Zero-based index of the selected shard (the 1-based shard number minus one).
    shard_minus_one: u64,
    /// Total number of shards; always at least 1.
    total_shards: u64,
    /// Zero-based shard index that the next task will be assigned to.
    curr: u64,
}

impl CountPartitioner {
    /// Creates a partitioner selecting the 1-based `shard` out of `total_shards`.
    ///
    /// A `shard` greater than `total_shards` is accepted and yields a partitioner that never
    /// matches any task.
    ///
    /// # Panics
    ///
    /// Panics if `total_shards` or `shard` is 0. Parsed input never triggers this, but the
    /// builder's fields are public and can be filled in without validation. Checking here
    /// avoids a later remainder-by-zero, and avoids `shard - 1` panicking in debug builds but
    /// silently wrapping in release builds.
    fn new(shard: u64, total_shards: u64) -> Self {
        assert!(total_shards >= 1, "total_shards must be at least 1");
        assert!(
            shard >= 1,
            "shard must be at least 1 (shards are 1-indexed)"
        );

        Self {
            shard_minus_one: shard - 1,
            total_shards,
            curr: 0,
        }
    }
}
164
165impl Partitioner for CountPartitioner {
166 fn task_matches(&mut self, _task_name: &str) -> bool {
167 let matches = self.curr == self.shard_minus_one;
168 self.curr = (self.curr + 1) % self.total_shards;
169 matches
170 }
171}
172
/// Partitioner that assigns each task to a shard based on a hash of the task's name.
#[derive(Clone, Debug)]
struct HashPartitioner {
    /// Zero-based index of the selected shard (the 1-based shard number minus one).
    shard_minus_one: u64,
    /// Total number of shards; always at least 1.
    total_shards: u64,
}

impl HashPartitioner {
    /// Creates a partitioner selecting the 1-based `shard` out of `total_shards`.
    ///
    /// A `shard` greater than `total_shards` is accepted and yields a partitioner that never
    /// matches any task.
    ///
    /// # Panics
    ///
    /// Panics if `total_shards` or `shard` is 0. Parsed input never triggers this, but the
    /// builder's fields are public and can be filled in without validation. Checking here
    /// avoids a later remainder-by-zero, and avoids `shard - 1` panicking in debug builds but
    /// silently wrapping in release builds.
    fn new(shard: u64, total_shards: u64) -> Self {
        assert!(total_shards >= 1, "total_shards must be at least 1");
        assert!(
            shard >= 1,
            "shard must be at least 1 (shards are 1-indexed)"
        );

        Self {
            shard_minus_one: shard - 1,
            total_shards,
        }
    }
}
188
189impl Partitioner for HashPartitioner {
190 fn task_matches(&mut self, task_name: &str) -> bool {
191 let mut hasher = XxHash64::default();
192 task_name.hash(&mut hasher);
193 hasher.finish() % self.total_shards == self.shard_minus_one
194 }
195}
196
197#[derive(Clone, Debug, Error)]
200pub struct PartitionerBuilderParseError {
201 expected_format: Option<&'static str>,
202 message: Cow<'static, str>,
203}
204
205impl PartitionerBuilderParseError {
206 pub(crate) fn new(
207 expected_format: Option<&'static str>,
208 message: impl Into<Cow<'static, str>>,
209 ) -> Self {
210 Self {
211 expected_format,
212 message: message.into(),
213 }
214 }
215}
216
217impl fmt::Display for PartitionerBuilderParseError {
218 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
219 match self.expected_format {
220 Some(format) => {
221 write!(
222 f,
223 "partition must be in the format \"{}\":\n{}",
224 format, self.message
225 )
226 }
227 None => write!(f, "{}", self.message),
228 }
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks which partition specifiers parse, and what they parse to, for both prefixes.
    #[test]
    fn partitioner_builder_from_str() {
        let successes = vec![
            (
                "hash:1/2",
                PartitionerBuilder::Hash {
                    shard: 1,
                    total_shards: 2,
                },
            ),
            (
                "hash:1/1",
                PartitionerBuilder::Hash {
                    shard: 1,
                    total_shards: 1,
                },
            ),
            (
                "hash:99/200",
                PartitionerBuilder::Hash {
                    shard: 99,
                    total_shards: 200,
                },
            ),
            (
                "count:1/2",
                PartitionerBuilder::Count {
                    shard: 1,
                    total_shards: 2,
                },
            ),
            (
                "count:3/3",
                PartitionerBuilder::Count {
                    shard: 3,
                    total_shards: 3,
                },
            ),
        ];

        let failures = vec![
            "foo",
            "hash",
            "hash:",
            "hash:1",
            "hash:1/",
            "hash:0/2",
            "hash:3/2",
            "hash:m/2",
            "hash:1/n",
            "hash:1/2/3",
            // No shard can be valid when there are zero shards in total.
            "hash:0/0",
            "hash:1/0",
            "hash:-1/2",
            "count",
            "count:",
            "count:1",
            "count:1/",
            "count:0/2",
            "count:3/2",
            "count:1/0",
            "count:1/2/3",
        ];

        for (input, output) in successes {
            let parsed = input.parse::<PartitionerBuilder>().unwrap_or_else(|err| {
                panic!(
                    "expected input '{}' to succeed, failed with: {}",
                    input, err
                )
            });
            assert_eq!(parsed, output, "success case '{}' matches", input);
        }

        for input in failures {
            // The failure message is only built if the input unexpectedly parses.
            if let Ok(builder) = input.parse::<PartitionerBuilder>() {
                panic!(
                    "expected input '{}' to fail, but it parsed as {:?}",
                    input, builder
                );
            }
        }
    }

    /// A count partitioner selects every `total_shards`-th task, offset by its shard.
    #[test]
    fn count_partitioner_round_robin() {
        let mut partitioner = PartitionerBuilder::Count {
            shard: 2,
            total_shards: 3,
        }
        .build();

        let matches: Vec<bool> = (0..6).map(|_| partitioner.task_matches("task")).collect();
        assert_eq!(matches, vec![false, true, false, false, true, false]);
    }

    /// Across all hash shards, every task name is selected by exactly one shard.
    #[test]
    fn hash_partitioner_assigns_each_task_to_one_shard() {
        let total_shards = 4;
        let mut partitioners: Vec<_> = (1..=total_shards)
            .map(|shard| {
                PartitionerBuilder::Hash {
                    shard,
                    total_shards,
                }
                .build()
            })
            .collect();

        for index in 0..32 {
            let name = format!("task-{}", index);
            let selected = partitioners
                .iter_mut()
                .map(|partitioner| partitioner.task_matches(&name))
                .filter(|&matched| matched)
                .count();
            assert_eq!(selected, 1, "task '{}' is in exactly one shard", name);
        }
    }
}