// Source path: task_partitioner/lib.rs

// Copyright (c) The nextest Contributors
// Copyright (c) Kaspar Schleiser <kaspar@schleiser.de>
// SPDX-License-Identifier: MIT OR Apache-2.0

//! Support for partitioning task runs across several machines.
//!
//! At the moment this only supports simple hash-based and count-based sharding. In the future it
//! could potentially be made smarter: e.g. using data to pick different sets of binaries and tests
//! to run, with an aim to minimize total build and test times.

use std::{
    borrow::Cow,
    fmt,
    hash::{Hash, Hasher},
    str::FromStr,
};

use thiserror::Error;
use twox_hash::XxHash64;

/// A builder for creating `Partitioner` instances.
///
/// The relationship between `PartitionerBuilder` and `Partitioner` is similar to that between
/// `std`'s `BuildHasher` and `Hasher`.
///
/// Builders are normally obtained by parsing a string of the form `hash:M/N` or `count:M/N`,
/// which guarantees `1 <= shard <= total_shards`. The fields are public, so a value constructed
/// by hand is not validated; callers doing so are responsible for upholding that invariant.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum PartitionerBuilder {
    /// Partition based on counting test numbers.
    Count {
        /// The shard this is in, counting up from 1.
        shard: u64,

        /// The total number of shards.
        total_shards: u64,
    },

    /// Partition based on hashing. Individual partitions are stateless.
    Hash {
        /// The shard this is in, counting up from 1.
        shard: u64,

        /// The total number of shards.
        total_shards: u64,
    },
}

/// Represents an individual partitioner, typically scoped to a test binary.
///
/// The method takes `&mut self` because an implementation may need to update internal state on
/// every query (for example, a running counter).
pub trait Partitioner: fmt::Debug {
    /// Returns true if the given task name matches the partition.
    fn task_matches(&mut self, task_name: &str) -> bool;
}

53impl PartitionerBuilder {
54    /// Creates a new `Partitioner` from this `PartitionerBuilder`.
55    pub fn build(&self) -> Box<dyn Partitioner> {
56        // Note we don't use test_binary at the moment but might in the future.
57        match self {
58            PartitionerBuilder::Count {
59                shard,
60                total_shards,
61            } => Box::new(CountPartitioner::new(*shard, *total_shards)),
62            PartitionerBuilder::Hash {
63                shard,
64                total_shards,
65            } => Box::new(HashPartitioner::new(*shard, *total_shards)),
66        }
67    }
68}
69
70impl FromStr for PartitionerBuilder {
71    type Err = PartitionerBuilderParseError;
72
73    fn from_str(s: &str) -> Result<Self, Self::Err> {
74        // Parse the string: it looks like "hash:<shard>/<total_shards>".
75        if let Some(input) = s.strip_prefix("hash:") {
76            let (shard, total_shards) = parse_shards(input, "hash:M/N")?;
77
78            Ok(PartitionerBuilder::Hash {
79                shard,
80                total_shards,
81            })
82        } else if let Some(input) = s.strip_prefix("count:") {
83            let (shard, total_shards) = parse_shards(input, "count:M/N")?;
84
85            Ok(PartitionerBuilder::Count {
86                shard,
87                total_shards,
88            })
89        } else {
90            Err(PartitionerBuilderParseError::new(
91                None,
92                format!(
93                    "partition input '{}' must begin with \"hash:\" or \"count:\"",
94                    s
95                ),
96            ))
97        }
98    }
99}
100
101fn parse_shards(
102    input: &str,
103    expected_format: &'static str,
104) -> Result<(u64, u64), PartitionerBuilderParseError> {
105    let mut split = input.splitn(2, '/');
106    // First "next" always returns a value.
107    let shard_str = split.next().expect("split should have at least 1 element");
108    // Second "next" may or may not return a value.
109    let total_shards_str = split.next().ok_or_else(|| {
110        PartitionerBuilderParseError::new(
111            Some(expected_format),
112            format!("expected input '{}' to be in the format M/N", input),
113        )
114    })?;
115
116    let shard: u64 = shard_str.parse().map_err(|err| {
117        PartitionerBuilderParseError::new(
118            Some(expected_format),
119            format!("failed to parse shard '{}' as u64: {}", shard_str, err),
120        )
121    })?;
122
123    let total_shards: u64 = total_shards_str.parse().map_err(|err| {
124        PartitionerBuilderParseError::new(
125            Some(expected_format),
126            format!(
127                "failed to parse total_shards '{}' as u64: {}",
128                total_shards_str, err
129            ),
130        )
131    })?;
132
133    // Check that shard > 0 and <= total_shards.
134    if !(1..=total_shards).contains(&shard) {
135        return Err(PartitionerBuilderParseError::new(
136            Some(expected_format),
137            format!(
138                "shard {} must be a number between 1 and total shards {}, inclusive",
139                shard, total_shards
140            ),
141        ));
142    }
143
144    Ok((shard, total_shards))
145}
146
/// Partitioner that hands out tasks to shards in turn, based purely on the order in which tasks
/// are queried.
#[derive(Clone, Debug)]
struct CountPartitioner {
    /// Zero-based index of the shard this partitioner represents (`shard - 1`).
    shard_minus_one: u64,
    /// Total number of shards; the counter wraps around at this value.
    total_shards: u64,
    /// Zero-based shard index that the next queried task belongs to.
    curr: u64,
}

impl CountPartitioner {
    /// Creates a partitioner for the 1-based `shard` out of `total_shards`, with the counter
    /// positioned at the first shard.
    ///
    /// `shard` must be at least 1: `shard - 1` overflows for 0, which panics when overflow checks
    /// are enabled.
    fn new(shard: u64, total_shards: u64) -> Self {
        Self {
            shard_minus_one: shard - 1,
            total_shards,
            curr: 0,
        }
    }
}

165impl Partitioner for CountPartitioner {
166    fn task_matches(&mut self, _task_name: &str) -> bool {
167        let matches = self.curr == self.shard_minus_one;
168        self.curr = (self.curr + 1) % self.total_shards;
169        matches
170    }
171}
172
/// Partitioner that decides membership from a hash of the task name, without keeping any state
/// between queries.
#[derive(Clone, Debug)]
struct HashPartitioner {
    /// Zero-based index of the shard this partitioner represents (`shard - 1`).
    shard_minus_one: u64,
    /// Total number of shards; task hashes are reduced modulo this value.
    total_shards: u64,
}

impl HashPartitioner {
    /// Creates a partitioner for the 1-based `shard` out of `total_shards`.
    ///
    /// `shard` must be at least 1: `shard - 1` overflows for 0, which panics when overflow checks
    /// are enabled.
    fn new(shard: u64, total_shards: u64) -> Self {
        Self {
            shard_minus_one: shard - 1,
            total_shards,
        }
    }
}

189impl Partitioner for HashPartitioner {
190    fn task_matches(&mut self, task_name: &str) -> bool {
191        let mut hasher = XxHash64::default();
192        task_name.hash(&mut hasher);
193        hasher.finish() % self.total_shards == self.shard_minus_one
194    }
195}
196
197/// An error that occurs while parsing a
198/// [`PartitionerBuilder`](crate::PartitionerBuilder) input.
199#[derive(Clone, Debug, Error)]
200pub struct PartitionerBuilderParseError {
201    expected_format: Option<&'static str>,
202    message: Cow<'static, str>,
203}
204
205impl PartitionerBuilderParseError {
206    pub(crate) fn new(
207        expected_format: Option<&'static str>,
208        message: impl Into<Cow<'static, str>>,
209    ) -> Self {
210        Self {
211            expected_format,
212            message: message.into(),
213        }
214    }
215}
216
217impl fmt::Display for PartitionerBuilderParseError {
218    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
219        match self.expected_format {
220            Some(format) => {
221                write!(
222                    f,
223                    "partition must be in the format \"{}\":\n{}",
224                    format, self.message
225                )
226            }
227            None => write!(f, "{}", self.message),
228        }
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// Well-formed `hash:` and `count:` inputs parse to the expected builder; malformed inputs
    /// are rejected.
    #[test]
    fn partitioner_builder_from_str() {
        let hash = |shard, total_shards| PartitionerBuilder::Hash {
            shard,
            total_shards,
        };
        let count = |shard, total_shards| PartitionerBuilder::Count {
            shard,
            total_shards,
        };

        let successes = vec![
            ("hash:1/2", hash(1, 2)),
            ("hash:1/1", hash(1, 1)),
            ("hash:99/200", hash(99, 200)),
            ("count:1/2", count(1, 2)),
            ("count:1/1", count(1, 1)),
            ("count:99/200", count(99, 200)),
        ];

        let failures = vec![
            "foo",
            "hash",
            "hash:",
            "hash:1",
            "hash:1/",
            "hash:0/2",
            "hash:3/2",
            "hash:m/2",
            "hash:1/n",
            "hash:1/2/3",
            "count",
            "count:",
            "count:1",
            "count:1/",
            "count:0/2",
            "count:3/2",
            "count:m/2",
            "count:1/n",
            "count:1/2/3",
        ];

        for (input, expected) in successes {
            let parsed = PartitionerBuilder::from_str(input).unwrap_or_else(|err| {
                panic!(
                    "expected input '{}' to succeed, failed with: {}",
                    input, err
                )
            });
            assert_eq!(parsed, expected, "success case '{}' matches", input);
        }

        for input in failures {
            assert!(
                PartitionerBuilder::from_str(input).is_err(),
                "expected input '{}' to fail",
                input
            );
        }
    }

    /// With count-based partitioning, the k-th queried task belongs to shard `k % total_shards`
    /// (zero-based) and to no other shard.
    #[test]
    fn count_partitioner_cycles_through_shards() {
        let total_shards = 3;
        let mut partitioners: Vec<_> = (1..=total_shards)
            .map(|shard| {
                PartitionerBuilder::Count {
                    shard,
                    total_shards,
                }
                .build()
            })
            .collect();

        for task_index in 0..7usize {
            let mut matching = Vec::new();
            for (shard_index, partitioner) in partitioners.iter_mut().enumerate() {
                if partitioner.task_matches("task") {
                    matching.push(shard_index);
                }
            }
            assert_eq!(
                matching,
                vec![task_index % 3],
                "task #{} is assigned to exactly one shard, in rotation",
                task_index
            );
        }
    }

    /// With hash-based partitioning, every task name belongs to exactly one shard, and asking
    /// again gives the same answer.
    #[test]
    fn hash_partitioner_assigns_each_task_to_exactly_one_shard() {
        let total_shards = 4;
        let mut partitioners: Vec<_> = (1..=total_shards)
            .map(|shard| {
                PartitionerBuilder::Hash {
                    shard,
                    total_shards,
                }
                .build()
            })
            .collect();

        for task in ["alpha", "beta", "gamma", "delta", ""].iter() {
            let first: Vec<bool> = partitioners
                .iter_mut()
                .map(|partitioner| partitioner.task_matches(task))
                .collect();
            let second: Vec<bool> = partitioners
                .iter_mut()
                .map(|partitioner| partitioner.task_matches(task))
                .collect();

            assert_eq!(
                first.iter().filter(|&&matched| matched).count(),
                1,
                "task '{}' matches exactly one shard",
                task
            );
            assert_eq!(first, second, "task '{}' matches are repeatable", task);
        }
    }
}