// covy_core/shard.rs — task/shard-plan schema types and LPT shard planners.
/// Schema version stamped into serialized [`TaskSet`] payloads.
pub const TASK_SCHEMA_VERSION: u16 = 1;
/// Schema version stamped into serialized [`UniversalShardPlan`] payloads.
pub const SHARD_PLAN_SCHEMA_VERSION: u16 = 1;
3
/// A single schedulable unit of work with its duration estimate.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
pub struct Task {
    /// Stable identifier for this task.
    pub id: String,
    /// Selector string used to invoke the task.
    pub selector: String,
    /// Estimated duration in milliseconds.
    pub est_ms: u64,
    /// Free-form tags; empty when absent from the serialized form.
    #[serde(default)]
    pub tags: Vec<String>,
    /// Owning module, if known; `None` when absent from the serialized form.
    #[serde(default)]
    pub module: Option<String>,
    /// Whether this task may be split further; `false` when absent.
    #[serde(default)]
    pub splittable: bool,
}
16
/// A versioned collection of [`Task`]s.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
pub struct TaskSet {
    /// Schema version; filled with `TASK_SCHEMA_VERSION` when absent.
    #[serde(default = "default_task_schema_version")]
    pub schema_version: u16,
    /// The tasks in this set; empty when absent from the serialized form.
    #[serde(default)]
    pub tasks: Vec<Task>,
}
24
25impl Default for TaskSet {
26    fn default() -> Self {
27        Self {
28            schema_version: TASK_SCHEMA_VERSION,
29            tasks: Vec::new(),
30        }
31    }
32}
33
/// A task as recorded inside a [`PlannedShard`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
pub struct PlannedTask {
    /// Stable identifier for this task.
    pub id: String,
    /// Selector string used to invoke the task.
    pub selector: String,
    /// Duration estimate (milliseconds) the planner used for this task.
    pub est_ms: u64,
}
40
/// One shard of a [`UniversalShardPlan`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
pub struct PlannedShard {
    /// Zero-based shard index.
    pub id: usize,
    /// Tasks assigned to this shard; empty when absent from the serialized form.
    #[serde(default)]
    pub tasks: Vec<PlannedTask>,
    /// Planner-predicted total duration for this shard, in milliseconds.
    pub predicted_duration_ms: u64,
}
48
/// A serializable shard plan tagged with the algorithm that produced it.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
pub struct UniversalShardPlan {
    /// Schema version; filled with `SHARD_PLAN_SCHEMA_VERSION` when absent.
    #[serde(default = "default_shard_plan_schema_version")]
    pub schema_version: u16,
    /// Name of the planning algorithm (e.g. "lpt").
    pub algorithm: String,
    /// The planned shards; empty when absent from the serialized form.
    #[serde(default)]
    pub shards: Vec<PlannedShard>,
}
57
58impl Default for UniversalShardPlan {
59    fn default() -> Self {
60        Self {
61            schema_version: SHARD_PLAN_SCHEMA_VERSION,
62            algorithm: "lpt".to_string(),
63            shards: Vec::new(),
64        }
65    }
66}
67
/// Serde default for [`TaskSet::schema_version`].
fn default_task_schema_version() -> u16 {
    TASK_SCHEMA_VERSION
}
71
/// Serde default for [`UniversalShardPlan::schema_version`].
fn default_shard_plan_schema_version() -> u16 {
    SHARD_PLAN_SCHEMA_VERSION
}
75
/// A shard in the internal [`ShardPlan`]: assigned test ids plus the
/// predicted total duration of those tests.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, Default)]
pub struct Shard {
    /// Zero-based shard index.
    pub id: usize,
    /// Ids of the tests assigned to this shard.
    pub tests: Vec<String>,
    /// Sum of the assigned tests' predicted durations, in milliseconds.
    pub predicted_duration_ms: u64,
}
82
/// Output of the shard planners, including load-balance metrics.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, Default)]
pub struct ShardPlan {
    /// The planned shards, one entry per shard id.
    pub shards: Vec<Shard>,
    /// Sum of predicted durations across all shards, in milliseconds.
    pub total_predicted_duration_ms: u64,
    /// Largest single-shard predicted duration, in milliseconds.
    pub makespan_ms: u64,
    /// Makespan divided by the median shard load (0.0 when the median is 0).
    pub imbalance_ratio: f64,
    /// Total work / (makespan * shard count); 1.0 when there is no work.
    pub parallel_efficiency: f64,
    /// Number of jobs whose duration exceeds the whale threshold.
    pub whale_count: usize,
    /// Fraction of total predicted time contributed by the 10 longest jobs.
    pub top_10_share: f64,
}
93
94pub fn build_timed_jobs(
95    test_ids: &[String],
96    timings: &crate::testmap::TestTimingHistory,
97    unknown_test_duration_ms: u64,
98) -> Vec<(String, u64)> {
99    test_ids
100        .iter()
101        .map(|test_id| {
102            let duration = timings
103                .duration_ms
104                .get(test_id)
105                .copied()
106                .unwrap_or(unknown_test_duration_ms);
107            (test_id.clone(), duration)
108        })
109        .collect()
110}
111
112/// Longest-processing-time-first bin-packing planner.
113pub fn plan_shards_lpt(input: &[(String, u64)], shard_count: usize) -> ShardPlan {
114    if shard_count == 0 {
115        return ShardPlan::default();
116    }
117
118    let mut jobs = input.to_vec();
119    jobs.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
120
121    let mut shards: Vec<Shard> = (0..shard_count)
122        .map(|id| Shard {
123            id,
124            tests: Vec::new(),
125            predicted_duration_ms: 0,
126        })
127        .collect();
128
129    for (test_id, duration_ms) in jobs {
130        let target = shards
131            .iter()
132            .min_by(|a, b| {
133                a.predicted_duration_ms
134                    .cmp(&b.predicted_duration_ms)
135                    .then_with(|| a.id.cmp(&b.id))
136            })
137            .map(|s| s.id)
138            .unwrap_or(0);
139
140        if let Some(shard) = shards.get_mut(target) {
141            shard.tests.push(test_id);
142            shard.predicted_duration_ms = shard.predicted_duration_ms.saturating_add(duration_ms);
143        }
144    }
145
146    let total = shards.iter().map(|s| s.predicted_duration_ms).sum();
147    let makespan = shards
148        .iter()
149        .map(|s| s.predicted_duration_ms)
150        .max()
151        .unwrap_or(0);
152
153    let (imbalance_ratio, parallel_efficiency) = compute_load_metrics(&shards, total, makespan);
154    let whale_threshold = compute_whale_threshold_ms(input);
155    ShardPlan {
156        shards,
157        total_predicted_duration_ms: total,
158        makespan_ms: makespan,
159        imbalance_ratio,
160        parallel_efficiency,
161        whale_count: count_whales(input, whale_threshold),
162        top_10_share: compute_top_10_share(input, total),
163    }
164}
165
/// Threshold (in ms) above which a job counts as a "whale": twice the
/// p95 duration of `input`, floored at 30 seconds. Empty input yields the
/// 30-second floor.
pub fn compute_whale_threshold_ms(input: &[(String, u64)]) -> u64 {
    const FLOOR_MS: u64 = 30_000;

    let mut durations: Vec<u64> = input.iter().map(|&(_, d)| d).collect();
    if durations.is_empty() {
        return FLOOR_MS;
    }
    durations.sort_unstable();

    // Nearest-rank p95: rank = ceil(0.95 * n), clamped into bounds.
    let rank = (durations.len() as f64 * 0.95).ceil() as usize;
    let p95 = durations[rank.saturating_sub(1).min(durations.len() - 1)];
    FLOOR_MS.max(p95.saturating_mul(2))
}
177
178pub fn plan_shards_whale_lpt(input: &[(String, u64)], shard_count: usize) -> ShardPlan {
179    if shard_count == 0 {
180        return ShardPlan::default();
181    }
182
183    let mut jobs = input.to_vec();
184    jobs.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
185
186    let whale_threshold = compute_whale_threshold_ms(&jobs);
187    let mut whales = Vec::new();
188    let mut rest = Vec::new();
189    for job in jobs {
190        if job.1 > whale_threshold {
191            whales.push(job);
192        } else {
193            rest.push(job);
194        }
195    }
196
197    let mut shards: Vec<Shard> = (0..shard_count)
198        .map(|id| Shard {
199            id,
200            tests: Vec::new(),
201            predicted_duration_ms: 0,
202        })
203        .collect();
204
205    for (test_id, duration_ms) in whales.into_iter().chain(rest.into_iter()) {
206        let target = shards
207            .iter()
208            .min_by(|a, b| {
209                a.predicted_duration_ms
210                    .cmp(&b.predicted_duration_ms)
211                    .then_with(|| a.id.cmp(&b.id))
212            })
213            .map(|s| s.id)
214            .unwrap_or(0);
215
216        if let Some(shard) = shards.get_mut(target) {
217            shard.tests.push(test_id);
218            shard.predicted_duration_ms = shard.predicted_duration_ms.saturating_add(duration_ms);
219        }
220    }
221
222    let total = shards.iter().map(|s| s.predicted_duration_ms).sum();
223    let makespan = shards
224        .iter()
225        .map(|s| s.predicted_duration_ms)
226        .max()
227        .unwrap_or(0);
228
229    let (imbalance_ratio, parallel_efficiency) = compute_load_metrics(&shards, total, makespan);
230    ShardPlan {
231        shards,
232        total_predicted_duration_ms: total,
233        makespan_ms: makespan,
234        imbalance_ratio,
235        parallel_efficiency,
236        whale_count: count_whales(input, whale_threshold),
237        top_10_share: compute_top_10_share(input, total),
238    }
239}
240
/// Number of jobs whose duration strictly exceeds `whale_threshold`.
fn count_whales(input: &[(String, u64)], whale_threshold: u64) -> usize {
    let mut whales = 0;
    for &(_, duration_ms) in input {
        if duration_ms > whale_threshold {
            whales += 1;
        }
    }
    whales
}
247
/// Fraction of `total` contributed by the ten longest jobs.
/// Returns 0.0 for empty input or a zero total.
fn compute_top_10_share(input: &[(String, u64)], total: u64) -> f64 {
    if input.is_empty() || total == 0 {
        return 0.0;
    }
    let mut durations: Vec<u64> = input.iter().map(|&(_, d)| d).collect();
    // Longest first, then sum the top ten (or all, if fewer).
    durations.sort_unstable_by(|x, y| y.cmp(x));
    let top_sum: u64 = durations.iter().take(10).sum();
    top_sum as f64 / total as f64
}
257
258fn compute_load_metrics(shards: &[Shard], total: u64, makespan: u64) -> (f64, f64) {
259    if shards.is_empty() {
260        return (0.0, 0.0);
261    }
262    let mut loads: Vec<u64> = shards.iter().map(|s| s.predicted_duration_ms).collect();
263    loads.sort_unstable();
264    let median = if loads.len() % 2 == 1 {
265        loads[loads.len() / 2] as f64
266    } else {
267        let hi = loads.len() / 2;
268        let lo = hi - 1;
269        (loads[lo] as f64 + loads[hi] as f64) / 2.0
270    };
271    let imbalance_ratio = if median > 0.0 {
272        makespan as f64 / median
273    } else {
274        0.0
275    };
276    let parallel_efficiency = if makespan > 0 {
277        total as f64 / ((makespan as f64) * (shards.len() as f64))
278    } else {
279        1.0
280    };
281    (imbalance_ratio, parallel_efficiency)
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    // LPT on 4 jobs over 2 shards: total work is preserved and the makespan
    // cannot exceed the (100+70, 90+80) = 170 split LPT produces here.
    #[test]
    fn test_plan_shards_lpt_balances_work() {
        let input = vec![
            ("t1".to_string(), 100),
            ("t2".to_string(), 90),
            ("t3".to_string(), 80),
            ("t4".to_string(), 70),
        ];
        let plan = plan_shards_lpt(&input, 2);
        assert_eq!(plan.shards.len(), 2);
        assert_eq!(plan.total_predicted_duration_ms, 340);
        assert!(plan.makespan_ms <= 170);
        assert!(plan.parallel_efficiency > 0.0);
    }

    // Equal durations exercise the id tie-breakers: repeated runs over the
    // same input must produce identical shard contents.
    #[test]
    fn test_plan_shards_lpt_is_deterministic_on_ties() {
        let input = vec![
            ("b".to_string(), 10),
            ("a".to_string(), 10),
            ("d".to_string(), 10),
            ("c".to_string(), 10),
        ];
        let p1 = plan_shards_lpt(&input, 2);
        let p2 = plan_shards_lpt(&input, 2);
        assert_eq!(p1.shards[0].tests, p2.shards[0].tests);
        assert_eq!(p1.shards[1].tests, p2.shards[1].tests);
    }

    // Known test ids use recorded history; unknown ids get the fallback.
    #[test]
    fn test_build_timed_jobs_uses_fallback_for_unknown_tests() {
        let mut timings = crate::testmap::TestTimingHistory::default();
        timings.duration_ms.insert("known".to_string(), 50);
        let jobs = build_timed_jobs(
            &["known".to_string(), "unknown".to_string()],
            &timings,
            8000,
        );
        assert_eq!(jobs[0], ("known".to_string(), 50));
        assert_eq!(jobs[1], ("unknown".to_string(), 8000));
    }

    // Deserializing without "schema_version" falls back to the current one.
    #[test]
    fn test_taskset_defaults_schema_version() {
        let taskset: TaskSet = serde_json::from_str(r#"{"tasks":[]}"#).unwrap();
        assert_eq!(taskset.schema_version, TASK_SCHEMA_VERSION);
    }

    // Optional Task fields (tags, module, splittable) default when omitted.
    #[test]
    fn test_task_defaults_optional_fields() {
        let task: Task = serde_json::from_str(
            r#"{
                "id":"tests/test_mod.py::test_one",
                "selector":"tests/test_mod.py::test_one",
                "est_ms":1200
            }"#,
        )
        .unwrap();
        assert!(task.tags.is_empty());
        assert!(task.module.is_none());
        assert!(!task.splittable);
    }

    // Same schema-version defaulting for the universal plan envelope.
    #[test]
    fn test_universal_shard_plan_defaults_schema_version() {
        let plan: UniversalShardPlan =
            serde_json::from_str(r#"{"algorithm":"lpt","shards":[]}"#).unwrap();
        assert_eq!(plan.schema_version, SHARD_PLAN_SCHEMA_VERSION);
    }

    // With p95 = 5000ms, 2*p95 = 10_000 is under the 30s floor, so the
    // floor wins.
    #[test]
    fn test_compute_whale_threshold_uses_p95_rule() {
        let jobs = vec![
            ("a".to_string(), 1000),
            ("b".to_string(), 2000),
            ("c".to_string(), 3000),
            ("d".to_string(), 4000),
            ("e".to_string(), 5000),
        ];
        assert_eq!(compute_whale_threshold_ms(&jobs), 30_000);
    }

    // The single 90s outlier must still land on some shard in the plan.
    #[test]
    fn test_plan_shards_whale_lpt_assigns_large_outlier_first() {
        let input = vec![
            ("whale".to_string(), 90_000),
            ("a".to_string(), 10_000),
            ("b".to_string(), 9_000),
            ("c".to_string(), 8_000),
            ("d".to_string(), 7_000),
        ];
        let plan = plan_shards_whale_lpt(&input, 2);
        assert_eq!(plan.shards.len(), 2);
        assert!(plan.shards.iter().any(|s| s.tests.iter().any(|t| t == "whale")));
    }

    // Whale variant must share the LPT planner's determinism on ties.
    #[test]
    fn test_plan_shards_whale_lpt_is_deterministic_on_ties() {
        let input = vec![
            ("b".to_string(), 10),
            ("a".to_string(), 10),
            ("d".to_string(), 10),
            ("c".to_string(), 10),
        ];
        let p1 = plan_shards_whale_lpt(&input, 2);
        let p2 = plan_shards_whale_lpt(&input, 2);
        assert_eq!(p1.shards[0].tests, p2.shards[0].tests);
        assert_eq!(p1.shards[1].tests, p2.shards[1].tests);
    }

    // Sanity bounds on the derived load metrics for a non-trivial plan.
    #[test]
    fn test_plan_metrics_are_computed() {
        let input = vec![
            ("a".to_string(), 50),
            ("b".to_string(), 40),
            ("c".to_string(), 10),
        ];
        let plan = plan_shards_lpt(&input, 2);
        assert!(plan.imbalance_ratio >= 1.0);
        assert!(plan.parallel_efficiency > 0.0);
        assert!(plan.top_10_share > 0.0);
    }
}