/// Schema version used for [`TaskSet::schema_version`] when it is absent from
/// serialized input or when a `TaskSet` is built via `Default`.
pub const TASK_SCHEMA_VERSION: u16 = 1;
/// Schema version used for [`UniversalShardPlan::schema_version`] when it is
/// absent from serialized input or when a plan is built via `Default`.
pub const SHARD_PLAN_SCHEMA_VERSION: u16 = 1;
3
/// A single schedulable unit of work (for example, one test) in a [`TaskSet`].
///
/// Only `id`, `selector`, and `est_ms` are required when deserializing. The
/// remaining fields fall back to their `Default` values (empty, `None`,
/// `false`).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
pub struct Task {
    /// Identifier of the task.
    pub id: String,
    /// Selector handed to the runner to execute this task.
    /// NOTE(review): semantics are defined by the consumer; presumably a
    /// test path / node id — confirm.
    pub selector: String,
    /// Estimated duration in milliseconds.
    pub est_ms: u64,
    /// Free-form labels; empty when omitted.
    #[serde(default)]
    pub tags: Vec<String>,
    /// Owning module, if known; `None` when omitted.
    #[serde(default)]
    pub module: Option<String>,
    /// Whether the task may be split further; `false` when omitted.
    #[serde(default)]
    pub splittable: bool,
}
16
/// A schema-versioned collection of [`Task`]s.
///
/// Both fields are optional when deserializing. `schema_version` falls back to
/// [`TASK_SCHEMA_VERSION`] and `tasks` to an empty list.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
pub struct TaskSet {
    /// Format version of this document.
    #[serde(default = "default_task_schema_version")]
    pub schema_version: u16,
    /// The tasks, in document order.
    #[serde(default)]
    pub tasks: Vec<Task>,
}
24
25impl Default for TaskSet {
26 fn default() -> Self {
27 Self {
28 schema_version: TASK_SCHEMA_VERSION,
29 tasks: Vec::new(),
30 }
31 }
32}
33
/// A task as placed into a [`PlannedShard`]: its id, selector, and estimated
/// duration.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
pub struct PlannedTask {
    /// Identifier of the planned task.
    /// NOTE(review): presumably mirrors [`Task::id`] — confirm with the producer.
    pub id: String,
    /// Selector used to run the task.
    /// NOTE(review): presumably copied from [`Task::selector`].
    pub selector: String,
    /// Estimated duration in milliseconds.
    pub est_ms: u64,
}
40
/// One shard inside a [`UniversalShardPlan`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
pub struct PlannedShard {
    /// Shard index.
    /// NOTE(review): presumably 0-based like [`Shard::id`] — confirm.
    pub id: usize,
    /// Tasks assigned to this shard; empty when omitted from the input.
    #[serde(default)]
    pub tasks: Vec<PlannedTask>,
    /// Predicted duration of this shard in milliseconds.
    /// NOTE(review): presumably the sum of `tasks[*].est_ms`; not enforced here.
    pub predicted_duration_ms: u64,
}
48
/// Schema-versioned, serializable shard plan.
///
/// When deserializing:
/// - a missing `schema_version` defaults to [`SHARD_PLAN_SCHEMA_VERSION`];
/// - a missing `shards` defaults to an empty list;
/// - `algorithm` is required.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
pub struct UniversalShardPlan {
    /// Format version of this document.
    #[serde(default = "default_shard_plan_schema_version")]
    pub schema_version: u16,
    /// Name of the algorithm that produced the plan (`Default` uses `"lpt"`).
    pub algorithm: String,
    /// The planned shards.
    #[serde(default)]
    pub shards: Vec<PlannedShard>,
}
57
58impl Default for UniversalShardPlan {
59 fn default() -> Self {
60 Self {
61 schema_version: SHARD_PLAN_SCHEMA_VERSION,
62 algorithm: "lpt".to_string(),
63 shards: Vec::new(),
64 }
65 }
66}
67
68fn default_task_schema_version() -> u16 {
69 TASK_SCHEMA_VERSION
70}
71
72fn default_shard_plan_schema_version() -> u16 {
73 SHARD_PLAN_SCHEMA_VERSION
74}
75
/// One shard produced by the shard planners in this file.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, Default)]
pub struct Shard {
    /// Index of the shard; the planners create ids `0..shard_count`.
    pub id: usize,
    /// Test ids assigned to this shard, in the order they were placed.
    pub tests: Vec<String>,
    /// Sum of the assigned tests' durations in milliseconds. The planners
    /// accumulate it with `saturating_add`.
    pub predicted_duration_ms: u64,
}
82
/// Result of a shard-planning run, plus load-balance diagnostics.
///
/// `ShardPlan::default()` is what the planners return for `shard_count == 0`.
/// It has no shards, and every numeric field is zero.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, Default)]
pub struct ShardPlan {
    /// The shards; `shards[i].id == i`.
    pub shards: Vec<Shard>,
    /// Sum of all shard loads in milliseconds.
    pub total_predicted_duration_ms: u64,
    /// Largest single-shard load in milliseconds.
    pub makespan_ms: u64,
    /// `makespan_ms` divided by the median shard load.
    /// `1.0` when the largest load equals the median; `0.0` when the median is zero.
    pub imbalance_ratio: f64,
    /// `total / (makespan * shard count)`; `1.0` when the makespan is zero.
    pub parallel_efficiency: f64,
    /// Number of input jobs strictly longer than [`compute_whale_threshold_ms`]
    /// of the same input.
    pub whale_count: usize,
    /// Fraction of the total contributed by the ten longest input jobs.
    pub top_10_share: f64,
}
93
94pub fn build_timed_jobs(
95 test_ids: &[String],
96 timings: &crate::testmap::TestTimingHistory,
97 unknown_test_duration_ms: u64,
98) -> Vec<(String, u64)> {
99 test_ids
100 .iter()
101 .map(|test_id| {
102 let duration = timings
103 .duration_ms
104 .get(test_id)
105 .copied()
106 .unwrap_or(unknown_test_duration_ms);
107 (test_id.clone(), duration)
108 })
109 .collect()
110}
111
112pub fn plan_shards_lpt(input: &[(String, u64)], shard_count: usize) -> ShardPlan {
114 if shard_count == 0 {
115 return ShardPlan::default();
116 }
117
118 let mut jobs = input.to_vec();
119 jobs.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
120
121 let mut shards: Vec<Shard> = (0..shard_count)
122 .map(|id| Shard {
123 id,
124 tests: Vec::new(),
125 predicted_duration_ms: 0,
126 })
127 .collect();
128
129 for (test_id, duration_ms) in jobs {
130 let target = shards
131 .iter()
132 .min_by(|a, b| {
133 a.predicted_duration_ms
134 .cmp(&b.predicted_duration_ms)
135 .then_with(|| a.id.cmp(&b.id))
136 })
137 .map(|s| s.id)
138 .unwrap_or(0);
139
140 if let Some(shard) = shards.get_mut(target) {
141 shard.tests.push(test_id);
142 shard.predicted_duration_ms = shard.predicted_duration_ms.saturating_add(duration_ms);
143 }
144 }
145
146 let total = shards.iter().map(|s| s.predicted_duration_ms).sum();
147 let makespan = shards
148 .iter()
149 .map(|s| s.predicted_duration_ms)
150 .max()
151 .unwrap_or(0);
152
153 let (imbalance_ratio, parallel_efficiency) = compute_load_metrics(&shards, total, makespan);
154 let whale_threshold = compute_whale_threshold_ms(input);
155 ShardPlan {
156 shards,
157 total_predicted_duration_ms: total,
158 makespan_ms: makespan,
159 imbalance_ratio,
160 parallel_efficiency,
161 whale_count: count_whales(input, whale_threshold),
162 top_10_share: compute_top_10_share(input, total),
163 }
164}
165
/// Returns the duration (ms) above which a job counts as a "whale".
///
/// The threshold is twice the nearest-rank 95th-percentile duration of
/// `input` (the `ceil(0.95 * n)`-th smallest value), but never less than
/// 30 000 ms. An empty input yields the 30 000 ms floor.
///
/// NOTE(review): with fewer than 20 jobs, `ceil(0.95 * n) == n`, so the "p95"
/// is the maximum duration. The threshold is then at least twice every
/// duration, and no job can ever be a whale. Confirm this is intended.
pub fn compute_whale_threshold_ms(input: &[(String, u64)]) -> u64 {
    const FLOOR_MS: u64 = 30_000;

    let len = input.len();
    if len == 0 {
        return FLOOR_MS;
    }

    let rank = (len as f64 * 0.95).ceil() as usize;
    let idx = rank.saturating_sub(1).min(len - 1);

    let mut durations: Vec<u64> = input.iter().map(|&(_, d)| d).collect();
    // Only the element at `idx` is needed, so a selection replaces the full
    // sort. It yields the same value the sorted vector would hold at `idx`.
    let (_, &mut p95, _) = durations.select_nth_unstable(idx);

    FLOOR_MS.max(p95.saturating_mul(2))
}
177
178pub fn plan_shards_whale_lpt(input: &[(String, u64)], shard_count: usize) -> ShardPlan {
179 if shard_count == 0 {
180 return ShardPlan::default();
181 }
182
183 let mut jobs = input.to_vec();
184 jobs.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
185
186 let whale_threshold = compute_whale_threshold_ms(&jobs);
187 let mut whales = Vec::new();
188 let mut rest = Vec::new();
189 for job in jobs {
190 if job.1 > whale_threshold {
191 whales.push(job);
192 } else {
193 rest.push(job);
194 }
195 }
196
197 let mut shards: Vec<Shard> = (0..shard_count)
198 .map(|id| Shard {
199 id,
200 tests: Vec::new(),
201 predicted_duration_ms: 0,
202 })
203 .collect();
204
205 for (test_id, duration_ms) in whales.into_iter().chain(rest.into_iter()) {
206 let target = shards
207 .iter()
208 .min_by(|a, b| {
209 a.predicted_duration_ms
210 .cmp(&b.predicted_duration_ms)
211 .then_with(|| a.id.cmp(&b.id))
212 })
213 .map(|s| s.id)
214 .unwrap_or(0);
215
216 if let Some(shard) = shards.get_mut(target) {
217 shard.tests.push(test_id);
218 shard.predicted_duration_ms = shard.predicted_duration_ms.saturating_add(duration_ms);
219 }
220 }
221
222 let total = shards.iter().map(|s| s.predicted_duration_ms).sum();
223 let makespan = shards
224 .iter()
225 .map(|s| s.predicted_duration_ms)
226 .max()
227 .unwrap_or(0);
228
229 let (imbalance_ratio, parallel_efficiency) = compute_load_metrics(&shards, total, makespan);
230 ShardPlan {
231 shards,
232 total_predicted_duration_ms: total,
233 makespan_ms: makespan,
234 imbalance_ratio,
235 parallel_efficiency,
236 whale_count: count_whales(input, whale_threshold),
237 top_10_share: compute_top_10_share(input, total),
238 }
239}
240
/// Counts jobs whose duration is strictly greater than `whale_threshold` (ms).
fn count_whales(input: &[(String, u64)], whale_threshold: u64) -> usize {
    input.iter().fold(0, |count, &(_, duration_ms)| {
        if duration_ms > whale_threshold {
            count + 1
        } else {
            count
        }
    })
}
247
/// Fraction of `total` (ms) contributed by the ten longest jobs in `input`.
///
/// Returns `0.0` when `total` is zero or `input` is empty.
///
/// The top-ten sum saturates at `u64::MAX` rather than overflowing. An
/// unchecked sum would panic in debug builds; saturating matches the
/// arithmetic the planners use for shard loads.
fn compute_top_10_share(input: &[(String, u64)], total: u64) -> f64 {
    const TOP_N: usize = 10;

    if total == 0 || input.is_empty() {
        return 0.0;
    }

    let mut durations: Vec<u64> = input.iter().map(|(_, d)| *d).collect();
    let k = durations.len().min(TOP_N);
    if k < durations.len() {
        // Partition (descending) so the `k` largest durations occupy
        // `durations[..k]`. Their relative order is irrelevant for a sum.
        durations.select_nth_unstable_by(k, |a, b| b.cmp(a));
    }
    let top_sum = durations[..k]
        .iter()
        .fold(0u64, |acc, &d| acc.saturating_add(d));
    top_sum as f64 / total as f64
}
257
258fn compute_load_metrics(shards: &[Shard], total: u64, makespan: u64) -> (f64, f64) {
259 if shards.is_empty() {
260 return (0.0, 0.0);
261 }
262 let mut loads: Vec<u64> = shards.iter().map(|s| s.predicted_duration_ms).collect();
263 loads.sort_unstable();
264 let median = if loads.len() % 2 == 1 {
265 loads[loads.len() / 2] as f64
266 } else {
267 let hi = loads.len() / 2;
268 let lo = hi - 1;
269 (loads[lo] as f64 + loads[hi] as f64) / 2.0
270 };
271 let imbalance_ratio = if median > 0.0 {
272 makespan as f64 / median
273 } else {
274 0.0
275 };
276 let parallel_efficiency = if makespan > 0 {
277 total as f64 / ((makespan as f64) * (shards.len() as f64))
278 } else {
279 1.0
280 };
281 (imbalance_ratio, parallel_efficiency)
282}
283
// Unit tests for the task/plan schema types and the LPT shard planners.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_plan_shards_lpt_balances_work() {
        let input = vec![
            ("t1".to_string(), 100),
            ("t2".to_string(), 90),
            ("t3".to_string(), 80),
            ("t4".to_string(), 70),
        ];
        let plan = plan_shards_lpt(&input, 2);
        assert_eq!(plan.shards.len(), 2);
        assert_eq!(plan.total_predicted_duration_ms, 340);
        // LPT placement: t1 -> 0, t2 -> 1, t3 -> 1 (90 < 100), t4 -> 0 (100 < 170).
        assert_eq!(plan.makespan_ms, 170);
        assert_eq!(plan.shards[0].tests, vec!["t1", "t4"]);
        assert_eq!(plan.shards[1].tests, vec!["t2", "t3"]);
        assert_eq!(plan.parallel_efficiency, 1.0);
    }

    #[test]
    fn test_plan_shards_lpt_returns_empty_plan_for_zero_shards() {
        let input = vec![("t1".to_string(), 100)];
        let plan = plan_shards_lpt(&input, 0);
        assert!(plan.shards.is_empty());
        assert_eq!(plan.total_predicted_duration_ms, 0);
        assert_eq!(plan.makespan_ms, 0);
    }

    #[test]
    fn test_plan_shards_lpt_is_deterministic_on_ties() {
        let input = vec![
            ("b".to_string(), 10),
            ("a".to_string(), 10),
            ("d".to_string(), 10),
            ("c".to_string(), 10),
        ];
        let plan = plan_shards_lpt(&input, 2);
        // Equal durations are ordered by id (a, b, c, d), and equal loads go
        // to the lowest shard id, so placements alternate between shards.
        assert_eq!(plan.shards[0].tests, vec!["a", "c"]);
        assert_eq!(plan.shards[1].tests, vec!["b", "d"]);

        // The result must not depend on the order of the input.
        let reversed: Vec<(String, u64)> = input.iter().rev().cloned().collect();
        let plan_rev = plan_shards_lpt(&reversed, 2);
        assert_eq!(plan.shards[0].tests, plan_rev.shards[0].tests);
        assert_eq!(plan.shards[1].tests, plan_rev.shards[1].tests);
    }

    #[test]
    fn test_build_timed_jobs_uses_fallback_for_unknown_tests() {
        let mut timings = crate::testmap::TestTimingHistory::default();
        timings.duration_ms.insert("known".to_string(), 50);
        let jobs = build_timed_jobs(
            &["known".to_string(), "unknown".to_string()],
            &timings,
            8000,
        );
        assert_eq!(jobs[0], ("known".to_string(), 50));
        assert_eq!(jobs[1], ("unknown".to_string(), 8000));
    }

    #[test]
    fn test_taskset_defaults_schema_version() {
        let taskset: TaskSet = serde_json::from_str(r#"{"tasks":[]}"#).unwrap();
        assert_eq!(taskset.schema_version, TASK_SCHEMA_VERSION);
    }

    #[test]
    fn test_task_defaults_optional_fields() {
        let task: Task = serde_json::from_str(
            r#"{
 "id":"tests/test_mod.py::test_one",
 "selector":"tests/test_mod.py::test_one",
 "est_ms":1200
 }"#,
        )
        .unwrap();
        assert!(task.tags.is_empty());
        assert!(task.module.is_none());
        assert!(!task.splittable);
    }

    #[test]
    fn test_universal_shard_plan_defaults_schema_version() {
        let plan: UniversalShardPlan =
            serde_json::from_str(r#"{"algorithm":"lpt","shards":[]}"#).unwrap();
        assert_eq!(plan.schema_version, SHARD_PLAN_SCHEMA_VERSION);
    }

    #[test]
    fn test_compute_whale_threshold_uses_p95_rule() {
        let jobs = vec![
            ("a".to_string(), 1000),
            ("b".to_string(), 2000),
            ("c".to_string(), 3000),
            ("d".to_string(), 4000),
            ("e".to_string(), 5000),
        ];
        // Short suites are clamped to the 30 s floor.
        assert_eq!(compute_whale_threshold_ms(&jobs), 30_000);
        assert_eq!(compute_whale_threshold_ms(&[]), 30_000);

        // With 20 jobs the nearest-rank p95 is the 19th smallest (19 000 ms),
        // so the threshold is 2 * 19 000.
        let jobs: Vec<(String, u64)> = (1..=20u64)
            .map(|i| (format!("t{}", i), i * 1000))
            .collect();
        assert_eq!(compute_whale_threshold_ms(&jobs), 38_000);
    }

    #[test]
    fn test_plan_counts_whales_above_threshold() {
        // 19 x 1 s plus one 100 s outlier: p95 is 1 s, so the threshold is the
        // 30 s floor and only the outlier is a whale.
        let mut input: Vec<(String, u64)> = (0..19).map(|i| (format!("t{}", i), 1_000)).collect();
        input.push(("whale".to_string(), 100_000));
        let plan = plan_shards_lpt(&input, 4);
        assert_eq!(plan.whale_count, 1);
    }

    #[test]
    fn test_plan_shards_whale_lpt_assigns_large_outlier_first() {
        let input = vec![
            ("whale".to_string(), 90_000),
            ("a".to_string(), 10_000),
            ("b".to_string(), 9_000),
            ("c".to_string(), 8_000),
            ("d".to_string(), 7_000),
        ];
        let plan = plan_shards_whale_lpt(&input, 2);
        assert_eq!(plan.shards.len(), 2);
        // The outlier lands alone on shard 0; every other job fits on shard 1.
        assert_eq!(plan.shards[0].tests, vec!["whale"]);
        assert_eq!(plan.shards[1].tests, vec!["a", "b", "c", "d"]);
        assert_eq!(plan.makespan_ms, 90_000);
    }

    #[test]
    fn test_plan_shards_whale_lpt_is_deterministic_on_ties() {
        let input = vec![
            ("b".to_string(), 10),
            ("a".to_string(), 10),
            ("d".to_string(), 10),
            ("c".to_string(), 10),
        ];
        let plan = plan_shards_whale_lpt(&input, 2);
        assert_eq!(plan.shards[0].tests, vec!["a", "c"]);
        assert_eq!(plan.shards[1].tests, vec!["b", "d"]);
    }

    #[test]
    fn test_plan_metrics_are_computed() {
        let input = vec![
            ("a".to_string(), 50),
            ("b".to_string(), 40),
            ("c".to_string(), 10),
        ];
        let plan = plan_shards_lpt(&input, 2);
        // a -> 0 (50); b -> 1 (40); c -> 1 (50): perfectly balanced.
        assert_eq!(plan.imbalance_ratio, 1.0);
        assert_eq!(plan.parallel_efficiency, 1.0);
        assert_eq!(plan.top_10_share, 1.0);
        assert_eq!(plan.whale_count, 0);
    }
}