1use std::collections::{HashMap, HashSet};
24
/// One scheduling decision: the item `item_id` is handed to agent `agent_idx`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Assignment {
    /// Zero-based index of the receiving agent (`0..agent_count`).
    pub agent_idx: usize,
    /// Identifier of the assigned work item.
    pub item_id: String,
}
37
/// Identifies which scheduling regime is in use: Whittle-index or fallback.
#[derive(Debug, Clone, PartialEq)]
pub enum ScheduleRegime {
    /// Whittle-index scheduling, carrying its indexability score.
    Whittle {
        /// Score reported by [`ScheduleRegime::explain`] to three decimals.
        indexability_score: f64,
    },
    /// The fallback scheduler, with a free-text reason why it was chosen.
    Fallback {
        /// Included verbatim in [`ScheduleRegime::explain`].
        reason: String,
    },
}

impl ScheduleRegime {
    /// `true` for the [`ScheduleRegime::Whittle`] variant.
    #[must_use]
    pub const fn is_whittle(&self) -> bool {
        match self {
            Self::Whittle { .. } => true,
            Self::Fallback { .. } => false,
        }
    }

    /// `true` for the [`ScheduleRegime::Fallback`] variant.
    #[must_use]
    pub const fn is_fallback(&self) -> bool {
        match self {
            Self::Whittle { .. } => false,
            Self::Fallback { .. } => true,
        }
    }

    /// Human-readable, one-line description of the regime.
    #[must_use]
    pub fn explain(&self) -> String {
        match self {
            Self::Whittle {
                indexability_score: score,
            } => format!("Whittle Index (indexability score: {score:.3})"),
            Self::Fallback { reason } => format!("Fallback scheduler — {reason}"),
        }
    }
}
79
/// Tuning knobs for the fallback assigner.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct FallbackConfig {
    /// Slack above the even share (`total_items / agent_count`, rounded down)
    /// that an agent may reach before it stops being a preferred candidate.
    /// Only when every agent is at this cap does an item go to the plain
    /// least-loaded agent.
    pub max_load_skew: usize,
}

impl Default for FallbackConfig {
    /// One item of slack above the even share.
    fn default() -> Self {
        FallbackConfig { max_load_skew: 1 }
    }
}
94
95#[must_use]
121#[allow(clippy::implicit_hasher)]
122pub fn assign_fallback(
123 items: &[String],
124 agent_count: usize,
125 scores: &HashMap<String, f64>,
126 history: &[Assignment],
127) -> Vec<Assignment> {
128 assign_fallback_with_config(
129 items,
130 agent_count,
131 scores,
132 history,
133 &FallbackConfig::default(),
134 )
135}
136
137#[must_use]
143#[allow(clippy::implicit_hasher)]
144pub fn assign_fallback_with_config(
145 items: &[String],
146 agent_count: usize,
147 scores: &HashMap<String, f64>,
148 history: &[Assignment],
149 config: &FallbackConfig,
150) -> Vec<Assignment> {
151 assert!(agent_count >= 1, "agent_count must be at least 1");
152
153 let unique_items: Vec<String> = {
155 let mut seen: HashSet<&str> = HashSet::new();
156 items
157 .iter()
158 .filter(|id| seen.insert(id.as_str()))
159 .cloned()
160 .collect()
161 };
162
163 if unique_items.is_empty() {
164 return Vec::new();
165 }
166
167 let mut sorted: Vec<&str> = unique_items.iter().map(String::as_str).collect();
169 sorted.sort_by(|&a, &b| {
170 let sa = scores.get(a).copied().unwrap_or(0.0);
171 let sb = scores.get(b).copied().unwrap_or(0.0);
172 sb.partial_cmp(&sa)
173 .unwrap_or(std::cmp::Ordering::Equal)
174 .then_with(|| a.cmp(b))
175 });
176
177 let skip_set: HashSet<(usize, &str)> = history
179 .iter()
180 .filter(|a| a.agent_idx < agent_count)
181 .map(|a| (a.agent_idx, a.item_id.as_str()))
182 .collect();
183
184 let mut load: Vec<usize> = vec![0; agent_count];
186
187 let total_items = sorted.len();
191
192 let mut assignments: Vec<Assignment> = Vec::with_capacity(total_items);
193
194 for &item_id in &sorted {
195 let preferred = pick_agent(&load, agent_count, config, total_items, |ag_idx| {
197 !skip_set.contains(&(ag_idx, item_id))
198 });
199
200 let agent_idx = preferred.unwrap_or_else(|| {
203 pick_agent(&load, agent_count, config, total_items, |_| true)
204 .unwrap_or_else(|| least_loaded_agent(&load))
205 });
206
207 load[agent_idx] += 1;
208 assignments.push(Assignment {
209 agent_idx,
210 item_id: item_id.to_string(),
211 });
212 }
213
214 if total_items >= agent_count {
218 enforce_fairness(&mut assignments, &mut load, agent_count, scores);
219 }
220
221 assignments
222}
223
224fn pick_agent(
232 load: &[usize],
233 agent_count: usize,
234 config: &FallbackConfig,
235 total_items: usize,
236 predicate: impl Fn(usize) -> bool,
237) -> Option<usize> {
238 let base = total_items / agent_count;
240 let cap = base + config.max_load_skew;
241
242 (0..agent_count)
243 .filter(|&ag| predicate(ag) && load[ag] < cap)
244 .min_by_key(|&ag| load[ag])
245}
246
/// Index of the agent with the smallest load; the lowest index wins ties.
/// An empty slice yields index 0.
fn least_loaded_agent(load: &[usize]) -> usize {
    let mut best_idx = 0;
    for (idx, &current) in load.iter().enumerate().skip(1) {
        if current < load[best_idx] {
            best_idx = idx;
        }
    }
    best_idx
}
254
255fn enforce_fairness(
258 assignments: &mut [Assignment],
259 load: &mut [usize],
260 agent_count: usize,
261 scores: &HashMap<String, f64>,
262) {
263 for starved_agent in 0..agent_count {
264 if load[starved_agent] > 0 {
265 continue;
266 }
267
268 let donor = (0..agent_count)
270 .filter(|&ag| load[ag] > 1)
271 .max_by_key(|&ag| load[ag]);
272
273 let Some(donor_idx) = donor else {
274 break; };
276
277 let steal_pos = assignments
279 .iter()
280 .enumerate()
281 .filter(|(_, a)| a.agent_idx == donor_idx)
282 .min_by(|(_, a1), (_, a2)| {
283 let s1 = scores.get(a1.item_id.as_str()).copied().unwrap_or(0.0);
284 let s2 = scores.get(a2.item_id.as_str()).copied().unwrap_or(0.0);
285 s1.partial_cmp(&s2)
286 .unwrap_or(std::cmp::Ordering::Equal)
287 .then_with(|| a2.item_id.cmp(&a1.item_id))
288 })
289 .map(|(pos, _)| pos);
290
291 if let Some(pos) = steal_pos {
292 load[donor_idx] -= 1;
293 load[starved_agent] += 1;
294 assignments[pos].agent_idx = starved_agent;
295 }
296 }
297}
298
/// Unit tests for the fallback assigner and `ScheduleRegime`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a score map from `(item_id, score)` pairs.
    fn scores(pairs: &[(&str, f64)]) -> HashMap<String, f64> {
        pairs.iter().map(|(k, v)| (k.to_string(), *v)).collect()
    }

    /// Builds an owned item-id list.
    fn items(ids: &[&str]) -> Vec<String> {
        ids.iter().map(|s| s.to_string()).collect()
    }

    /// Builds a history of `(agent_idx, item_id)` assignments.
    fn history(pairs: &[(usize, &str)]) -> Vec<Assignment> {
        pairs
            .iter()
            .map(|(ag, id)| Assignment {
                agent_idx: *ag,
                item_id: id.to_string(),
            })
            .collect()
    }

    #[test]
    fn assigns_single_item_to_single_agent() {
        let s = scores(&[("bn-a", 5.0)]);
        let result = assign_fallback(&items(&["bn-a"]), 1, &s, &[]);

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].agent_idx, 0);
        assert_eq!(result[0].item_id, "bn-a");
    }

    #[test]
    fn assigns_multiple_items_to_multiple_agents() {
        let s = scores(&[("bn-a", 3.0), ("bn-b", 5.0), ("bn-c", 1.0)]);
        let result = assign_fallback(&items(&["bn-a", "bn-b", "bn-c"]), 2, &s, &[]);

        assert_eq!(result.len(), 3);
        let assigned: HashSet<&str> = result.iter().map(|a| a.item_id.as_str()).collect();
        assert!(assigned.contains("bn-a"));
        assert!(assigned.contains("bn-b"));
        assert!(assigned.contains("bn-c"));
    }

    #[test]
    fn highest_score_assigned_first() {
        let s = scores(&[("bn-a", 3.0), ("bn-b", 9.0), ("bn-c", 1.0)]);
        let result = assign_fallback(&items(&["bn-a", "bn-b", "bn-c"]), 2, &s, &[]);

        assert_eq!(result[0].item_id, "bn-b", "highest score first");
    }

    #[test]
    fn equal_scores_ranked_by_item_id() {
        let s = scores(&[("bn-c", 1.0), ("bn-a", 1.0), ("bn-b", 1.0)]);
        let result = assign_fallback(&items(&["bn-c", "bn-a", "bn-b"]), 1, &s, &[]);

        let order: Vec<&str> = result.iter().map(|a| a.item_id.as_str()).collect();
        assert_eq!(order, vec!["bn-a", "bn-b", "bn-c"]);
    }

    #[test]
    fn empty_items_returns_empty() {
        let s = scores(&[]);
        let result = assign_fallback(&[], 3, &s, &[]);
        assert!(result.is_empty());
    }

    #[test]
    fn single_agent_gets_all_items() {
        let s = scores(&[("bn-a", 2.0), ("bn-b", 5.0), ("bn-c", 1.0)]);
        let result = assign_fallback(&items(&["bn-a", "bn-b", "bn-c"]), 1, &s, &[]);

        assert_eq!(result.len(), 3);
        assert!(result.iter().all(|a| a.agent_idx == 0));
    }

    #[test]
    fn no_item_assigned_twice() {
        let s = scores(&[("bn-a", 1.0), ("bn-b", 2.0), ("bn-c", 3.0)]);
        let result = assign_fallback(&items(&["bn-a", "bn-b", "bn-c"]), 2, &s, &[]);

        let ids: Vec<&str> = result.iter().map(|a| a.item_id.as_str()).collect();
        let unique: HashSet<&str> = ids.iter().copied().collect();
        assert_eq!(ids.len(), unique.len(), "no item appears twice");
    }

    #[test]
    fn duplicate_input_items_deduplicated() {
        let s = scores(&[("bn-a", 5.0)]);
        let result = assign_fallback(&items(&["bn-a", "bn-a", "bn-a"]), 2, &s, &[]);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].item_id, "bn-a");
    }

    #[test]
    fn fairness_every_agent_gets_one_item_when_items_gte_agents() {
        let s = scores(&[("bn-a", 3.0), ("bn-b", 5.0), ("bn-c", 1.0)]);
        let result = assign_fallback(&items(&["bn-a", "bn-b", "bn-c"]), 3, &s, &[]);

        assert_eq!(result.len(), 3);
        let mut per_agent = vec![0usize; 3];
        for a in &result {
            per_agent[a.agent_idx] += 1;
        }
        for (ag, &count) in per_agent.iter().enumerate() {
            assert_eq!(count, 1, "agent {ag} should have exactly 1 item");
        }
    }

    #[test]
    fn fairness_no_agent_starved_with_four_items_three_agents() {
        let s = scores(&[("bn-a", 4.0), ("bn-b", 3.0), ("bn-c", 2.0), ("bn-d", 1.0)]);
        let result = assign_fallback(&items(&["bn-a", "bn-b", "bn-c", "bn-d"]), 3, &s, &[]);

        let mut per_agent = vec![0usize; 3];
        for a in &result {
            per_agent[a.agent_idx] += 1;
        }
        for (ag, &count) in per_agent.iter().enumerate() {
            assert!(
                count >= 1,
                "agent {ag} should have at least 1 item (got {count})"
            );
        }
    }

    #[test]
    fn fairness_ok_when_items_less_than_agents() {
        let s = scores(&[("bn-a", 5.0), ("bn-b", 3.0)]);
        let result = assign_fallback(&items(&["bn-a", "bn-b"]), 3, &s, &[]);

        assert_eq!(result.len(), 2);
    }

    #[test]
    fn fairness_pass_overrides_history_rather_than_starve_an_agent() {
        // Agent 1 skipped both items, so the greedy pass gives both to agent 0;
        // the fairness pass then hands the lower-scored item to agent 1 anyway.
        let s = scores(&[("bn-a", 5.0), ("bn-b", 3.0)]);
        let h = history(&[(1, "bn-a"), (1, "bn-b")]);

        let result = assign_fallback(&items(&["bn-a", "bn-b"]), 2, &s, &h);

        let by_item: HashMap<&str, usize> = result
            .iter()
            .map(|a| (a.item_id.as_str(), a.agent_idx))
            .collect();
        assert_eq!(by_item["bn-a"], 0);
        assert_eq!(by_item["bn-b"], 1);
    }

    #[test]
    fn zero_skew_overflow_goes_to_least_loaded_agent() {
        // cap = 3 / 2 + 0 = 1, so the third item exceeds every agent's cap and
        // falls through to the least-loaded agent (lowest index on ties).
        let s = scores(&[("bn-a", 3.0), ("bn-b", 2.0), ("bn-c", 1.0)]);
        let config = FallbackConfig { max_load_skew: 0 };

        let result = assign_fallback_with_config(
            &items(&["bn-a", "bn-b", "bn-c"]),
            2,
            &s,
            &[],
            &config,
        );

        let placed: Vec<(&str, usize)> = result
            .iter()
            .map(|a| (a.item_id.as_str(), a.agent_idx))
            .collect();
        assert_eq!(placed, vec![("bn-a", 0), ("bn-b", 1), ("bn-c", 0)]);
    }

    #[test]
    fn history_avoids_previous_skip_assignment() {
        let s = scores(&[("bn-a", 5.0), ("bn-b", 3.0)]);
        let h = history(&[(0, "bn-a")]);

        let result = assign_fallback(&items(&["bn-a", "bn-b"]), 2, &s, &h);

        let bn_a = result.iter().find(|a| a.item_id == "bn-a").unwrap();
        assert_eq!(
            bn_a.agent_idx, 1,
            "bn-a should not go to agent 0 (who skipped it)"
        );
    }

    #[test]
    fn history_falls_back_when_all_agents_skipped() {
        let s = scores(&[("bn-a", 5.0)]);
        let h = history(&[(0, "bn-a"), (1, "bn-a")]);

        let result = assign_fallback(&items(&["bn-a"]), 2, &s, &h);

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].item_id, "bn-a");
    }

    #[test]
    fn history_with_unknown_agent_idx_is_ignored() {
        let s = scores(&[("bn-a", 5.0)]);
        let h = history(&[(99, "bn-a")]);

        let result = assign_fallback(&items(&["bn-a"]), 2, &s, &h);
        assert_eq!(result.len(), 1);
    }

    #[test]
    fn regime_whittle_explain() {
        let r = ScheduleRegime::Whittle {
            indexability_score: 0.95,
        };
        assert!(r.is_whittle());
        assert!(!r.is_fallback());
        let s = r.explain();
        assert!(s.contains("Whittle"), "explain: {s}");
        assert!(s.contains("0.950"), "explain: {s}");
    }

    #[test]
    fn regime_fallback_explain() {
        let r = ScheduleRegime::Fallback {
            reason: "dependency cycle detected".to_string(),
        };
        assert!(r.is_fallback());
        assert!(!r.is_whittle());
        let s = r.explain();
        assert!(s.contains("Fallback"), "explain: {s}");
        assert!(s.contains("dependency cycle"), "explain: {s}");
    }

    #[test]
    fn assignment_is_deterministic() {
        let s = scores(&[("bn-a", 5.0), ("bn-b", 5.0), ("bn-c", 5.0)]);
        let result1 = assign_fallback(&items(&["bn-a", "bn-b", "bn-c"]), 2, &s, &[]);
        let result2 = assign_fallback(&items(&["bn-a", "bn-b", "bn-c"]), 2, &s, &[]);

        let r1: Vec<(&str, usize)> = result1
            .iter()
            .map(|a| (a.item_id.as_str(), a.agent_idx))
            .collect();
        let r2: Vec<(&str, usize)> = result2
            .iter()
            .map(|a| (a.item_id.as_str(), a.agent_idx))
            .collect();
        assert_eq!(r1, r2, "assignment must be deterministic");
    }

    #[test]
    fn missing_score_defaults_to_zero() {
        let s = scores(&[]);
        let result = assign_fallback(&items(&["bn-a", "bn-b"]), 2, &s, &[]);
        assert_eq!(result.len(), 2);
    }
}
547}