Skip to main content

bones_triage/schedule/
fallback.rs

1//! Constrained optimisation fallback scheduler for multi-agent assignment.
2//!
3//! Used when the Whittle indexability gate fires (e.g., dependency-graph
4//! cycles). Provides greedy min-cost-style assignment with fairness and
5//! anti-duplication guarantees.
6//!
7//! # Algorithm
8//!
9//! 1. Sort items by composite score descending.
//! 2. Assign each item to the **least-loaded** agent that has not previously
//!    skipped that item (history-aware) and is below the load cap.  If every
//!    such agent has skipped the item or is at the cap, fall back to the
//!    least-loaded agent below the cap, and finally to the globally
//!    least-loaded agent (anti-starvation).
//! 3. **Fairness**: when the number of *distinct* items is `>= agent_count`,
//!    every agent receives at least one item.  A configurable `max_load_skew`
//!    (default 1) caps each agent at the floored per-agent share
//!    (`distinct_items / agent_count`) plus the skew; the cap is exceeded only
//!    when every agent has already reached it.
16//! 4. **Anti-duplicate**: each item appears in at most one assignment.
17//!
18//! # Regime Reporting
19//!
20//! The [`ScheduleRegime`] enum is used by `bn plan --explain` to surface which
21//! scheduler (Whittle or fallback) was active and why.
22
23use std::collections::{HashMap, HashSet};
24
25// ---------------------------------------------------------------------------
26// Public types
27// ---------------------------------------------------------------------------
28
/// One (agent, work item) pairing emitted by the fallback scheduler.
///
/// The same shape is accepted as *history* input, where each entry records
/// a pairing that was attempted earlier and skipped.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Assignment {
    /// Position of the agent in the roster (0-based).
    pub agent_idx: usize,
    /// Identifier of the work item handed to that agent.
    pub item_id: String,
}
37
/// The scheduler that produced a plan, as reported by `bn plan --explain`.
#[derive(Debug, Clone, PartialEq)]
pub enum ScheduleRegime {
    /// The normal path: scheduling by Whittle Index.
    Whittle {
        /// Aggregate indexability score; 1.0 means fully indexable and 0.0
        /// means not indexable.
        indexability_score: f64,
    },
    /// The constrained-optimisation fallback scheduler ran instead.
    Fallback {
        /// Human-readable explanation of why Whittle was bypassed.
        reason: String,
    },
}

impl ScheduleRegime {
    /// `true` when this regime is [`ScheduleRegime::Whittle`].
    #[must_use]
    pub const fn is_whittle(&self) -> bool {
        match self {
            Self::Whittle { .. } => true,
            Self::Fallback { .. } => false,
        }
    }

    /// `true` when this regime is [`ScheduleRegime::Fallback`].
    #[must_use]
    pub const fn is_fallback(&self) -> bool {
        match self {
            Self::Whittle { .. } => false,
            Self::Fallback { .. } => true,
        }
    }

    /// One-line, human-readable summary for CLI output.
    ///
    /// The Whittle indexability score is rendered with three decimal places.
    #[must_use]
    pub fn explain(&self) -> String {
        match self {
            Self::Whittle {
                indexability_score: score,
            } => format!("Whittle Index (indexability score: {score:.3})"),
            Self::Fallback { reason } => format!("Fallback scheduler — {reason}"),
        }
    }
}
79
/// Tuning knobs for the fallback scheduler.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FallbackConfig {
    /// How many items an agent may hold beyond the floored per-agent share
    /// (`distinct_items / agent_count`).  Keeps a single agent from hoarding
    /// the work; the scheduler only exceeds this cap when every agent has
    /// already reached it.
    ///
    /// Defaults to `1`.
    pub max_load_skew: usize,
}

impl Default for FallbackConfig {
    /// A configuration with `max_load_skew` set to `1`.
    fn default() -> Self {
        const DEFAULT_MAX_LOAD_SKEW: usize = 1;
        Self {
            max_load_skew: DEFAULT_MAX_LOAD_SKEW,
        }
    }
}
94
95// ---------------------------------------------------------------------------
96// Core assignment function
97// ---------------------------------------------------------------------------
98
99/// Assign work items to agents using a greedy constrained-optimisation
100/// approach.
101///
102/// # Arguments
103///
104/// * `items` — Item IDs to assign. Duplicates are silently deduplicated.
105/// * `agent_count` — Number of agents available. Must be ≥ 1.
106/// * `scores` — Composite scores keyed by item ID. Missing items get `0.0`.
107/// * `history` — Previously attempted assignments (`agent_idx`, `item_id`) that
108///   were **not completed** (i.e., skipped). The scheduler avoids re-pairing
109///   the same (agent, item) when possible.
110///
111/// # Returns
112///
113/// A `Vec<Assignment>` in score-descending order (highest-priority items
114/// listed first). Every item appears **at most once**. When
115/// `items.len() >= agent_count`, every agent receives at least one item.
116///
117/// # Panics
118///
119/// Panics if `agent_count == 0`.
120#[must_use]
121#[allow(clippy::implicit_hasher)]
122pub fn assign_fallback(
123    items: &[String],
124    agent_count: usize,
125    scores: &HashMap<String, f64>,
126    history: &[Assignment],
127) -> Vec<Assignment> {
128    assign_fallback_with_config(
129        items,
130        agent_count,
131        scores,
132        history,
133        &FallbackConfig::default(),
134    )
135}
136
137/// Like [`assign_fallback`] but accepts explicit [`FallbackConfig`].
138///
139/// # Panics
140///
141/// Panics if `agent_count == 0`.
142#[must_use]
143#[allow(clippy::implicit_hasher)]
144pub fn assign_fallback_with_config(
145    items: &[String],
146    agent_count: usize,
147    scores: &HashMap<String, f64>,
148    history: &[Assignment],
149    config: &FallbackConfig,
150) -> Vec<Assignment> {
151    assert!(agent_count >= 1, "agent_count must be at least 1");
152
153    // Deduplicate items while preserving the first occurrence order.
154    let unique_items: Vec<String> = {
155        let mut seen: HashSet<&str> = HashSet::new();
156        items
157            .iter()
158            .filter(|id| seen.insert(id.as_str()))
159            .cloned()
160            .collect()
161    };
162
163    if unique_items.is_empty() {
164        return Vec::new();
165    }
166
167    // Sort items by score descending; ties broken by item ID for determinism.
168    let mut sorted: Vec<&str> = unique_items.iter().map(String::as_str).collect();
169    sorted.sort_by(|&a, &b| {
170        let sa = scores.get(a).copied().unwrap_or(0.0);
171        let sb = scores.get(b).copied().unwrap_or(0.0);
172        sb.partial_cmp(&sa)
173            .unwrap_or(std::cmp::Ordering::Equal)
174            .then_with(|| a.cmp(b))
175    });
176
177    // Build a skip-set from history: (agent_idx, item_id) pairs to avoid.
178    let skip_set: HashSet<(usize, &str)> = history
179        .iter()
180        .filter(|a| a.agent_idx < agent_count)
181        .map(|a| (a.agent_idx, a.item_id.as_str()))
182        .collect();
183
184    // Per-agent load counters.
185    let mut load: Vec<usize> = vec![0; agent_count];
186
187    // Compute per-agent max load cap: floor(items / agents) + max_load_skew.
188    // This prevents any single agent from accumulating too much more than
189    // their fair share.  We recompute after each assignment.
190    let total_items = sorted.len();
191
192    let mut assignments: Vec<Assignment> = Vec::with_capacity(total_items);
193
194    for &item_id in &sorted {
195        // Preferred agent: least-loaded among those who haven't skipped this item.
196        let preferred = pick_agent(&load, agent_count, config, total_items, |ag_idx| {
197            !skip_set.contains(&(ag_idx, item_id))
198        });
199
200        // If no preferred agent found (all have skipped or all are at cap),
201        // fall back to absolute least-loaded without the skip filter.
202        let agent_idx = preferred.unwrap_or_else(|| {
203            pick_agent(&load, agent_count, config, total_items, |_| true)
204                .unwrap_or_else(|| least_loaded_agent(&load))
205        });
206
207        load[agent_idx] += 1;
208        assignments.push(Assignment {
209            agent_idx,
210            item_id: item_id.to_string(),
211        });
212    }
213
214    // Fairness pass: ensure every agent has at least one item when
215    // items >= agent_count. Steal the last (lowest-priority) item from the
216    // most-loaded agent and re-assign it to any starved agent.
217    if total_items >= agent_count {
218        enforce_fairness(&mut assignments, &mut load, agent_count, scores);
219    }
220
221    assignments
222}
223
224// ---------------------------------------------------------------------------
225// Internal helpers
226// ---------------------------------------------------------------------------
227
228/// Pick the best agent index subject to a predicate and load cap.
229///
230/// Returns `None` if no agent satisfies the predicate within the cap.
231fn pick_agent(
232    load: &[usize],
233    agent_count: usize,
234    config: &FallbackConfig,
235    total_items: usize,
236    predicate: impl Fn(usize) -> bool,
237) -> Option<usize> {
238    // Fair-share cap: base per-agent allocation + skew allowance.
239    let base = total_items / agent_count;
240    let cap = base + config.max_load_skew;
241
242    (0..agent_count)
243        .filter(|&ag| predicate(ag) && load[ag] < cap)
244        .min_by_key(|&ag| load[ag])
245}
246
/// Index of the agent currently carrying the fewest items.
///
/// When several agents share the minimum, the lowest index wins. An empty
/// `load` slice yields `0`.
fn least_loaded_agent(load: &[usize]) -> usize {
    let mut best = 0;
    for (idx, &count) in load.iter().enumerate().skip(1) {
        // Strict `<` keeps the earliest index on ties.
        if count < load[best] {
            best = idx;
        }
    }
    best
}
254
255/// Enforce fairness: steal the lowest-priority item from over-loaded agents
256/// and give it to any agent with zero items, when items >= `agent_count`.
257fn enforce_fairness(
258    assignments: &mut [Assignment],
259    load: &mut [usize],
260    agent_count: usize,
261    scores: &HashMap<String, f64>,
262) {
263    for starved_agent in 0..agent_count {
264        if load[starved_agent] > 0 {
265            continue;
266        }
267
268        // Find the most-loaded agent with more than 1 item (has items to spare).
269        let donor = (0..agent_count)
270            .filter(|&ag| load[ag] > 1)
271            .max_by_key(|&ag| load[ag]);
272
273        let Some(donor_idx) = donor else {
274            break; // Cannot fix starvation — not enough items to redistribute.
275        };
276
277        // Find the lowest-scoring assignment belonging to the donor.
278        let steal_pos = assignments
279            .iter()
280            .enumerate()
281            .filter(|(_, a)| a.agent_idx == donor_idx)
282            .min_by(|(_, a1), (_, a2)| {
283                let s1 = scores.get(a1.item_id.as_str()).copied().unwrap_or(0.0);
284                let s2 = scores.get(a2.item_id.as_str()).copied().unwrap_or(0.0);
285                s1.partial_cmp(&s2)
286                    .unwrap_or(std::cmp::Ordering::Equal)
287                    .then_with(|| a2.item_id.cmp(&a1.item_id))
288            })
289            .map(|(pos, _)| pos);
290
291        if let Some(pos) = steal_pos {
292            load[donor_idx] -= 1;
293            load[starved_agent] += 1;
294            assignments[pos].agent_idx = starved_agent;
295        }
296    }
297}
298
299// ---------------------------------------------------------------------------
300// Tests
301// ---------------------------------------------------------------------------
302
#[cfg(test)]
mod tests {
    use super::*;

    fn scores(pairs: &[(&str, f64)]) -> HashMap<String, f64> {
        pairs.iter().map(|(k, v)| (k.to_string(), *v)).collect()
    }

    fn items(ids: &[&str]) -> Vec<String> {
        ids.iter().map(|s| s.to_string()).collect()
    }

    fn history(pairs: &[(usize, &str)]) -> Vec<Assignment> {
        pairs
            .iter()
            .map(|(ag, id)| Assignment {
                agent_idx: *ag,
                item_id: id.to_string(),
            })
            .collect()
    }

    /// Flatten assignments into `(item_id, agent_idx)` pairs for comparison.
    fn plan(result: &[Assignment]) -> Vec<(&str, usize)> {
        result
            .iter()
            .map(|a| (a.item_id.as_str(), a.agent_idx))
            .collect()
    }

    /// Number of items each of `agent_count` agents received.
    fn per_agent(result: &[Assignment], agent_count: usize) -> Vec<usize> {
        let mut counts = vec![0usize; agent_count];
        for a in result {
            counts[a.agent_idx] += 1;
        }
        counts
    }

    // -----------------------------------------------------------------------
    // Basic assignment
    // -----------------------------------------------------------------------

    #[test]
    fn assigns_single_item_to_single_agent() {
        let s = scores(&[("bn-a", 5.0)]);
        let result = assign_fallback(&items(&["bn-a"]), 1, &s, &[]);

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].agent_idx, 0);
        assert_eq!(result[0].item_id, "bn-a");
    }

    #[test]
    fn assigns_multiple_items_to_multiple_agents() {
        let s = scores(&[("bn-a", 3.0), ("bn-b", 5.0), ("bn-c", 1.0)]);
        let result = assign_fallback(&items(&["bn-a", "bn-b", "bn-c"]), 2, &s, &[]);

        assert_eq!(result.len(), 3);
        let assigned: HashSet<&str> = result.iter().map(|a| a.item_id.as_str()).collect();
        assert!(assigned.contains("bn-a"));
        assert!(assigned.contains("bn-b"));
        assert!(assigned.contains("bn-c"));
    }

    #[test]
    fn highest_score_assigned_first() {
        // bn-b has the highest score; the output is in score-descending order.
        let s = scores(&[("bn-a", 3.0), ("bn-b", 9.0), ("bn-c", 1.0)]);
        let result = assign_fallback(&items(&["bn-a", "bn-b", "bn-c"]), 2, &s, &[]);

        assert_eq!(result[0].item_id, "bn-b", "highest score first");
        let order: Vec<&str> = result.iter().map(|a| a.item_id.as_str()).collect();
        assert_eq!(order, vec!["bn-b", "bn-a", "bn-c"]);
    }

    #[test]
    fn empty_items_returns_empty() {
        let s = scores(&[]);
        let result = assign_fallback(&[], 3, &s, &[]);
        assert!(result.is_empty());
    }

    #[test]
    fn single_agent_gets_all_items() {
        let s = scores(&[("bn-a", 2.0), ("bn-b", 5.0), ("bn-c", 1.0)]);
        let result = assign_fallback(&items(&["bn-a", "bn-b", "bn-c"]), 1, &s, &[]);

        assert_eq!(result.len(), 3);
        assert!(result.iter().all(|a| a.agent_idx == 0));
    }

    // -----------------------------------------------------------------------
    // Anti-duplicate
    // -----------------------------------------------------------------------

    #[test]
    fn no_item_assigned_twice() {
        let s = scores(&[("bn-a", 1.0), ("bn-b", 2.0), ("bn-c", 3.0)]);
        let result = assign_fallback(&items(&["bn-a", "bn-b", "bn-c"]), 2, &s, &[]);

        let ids: Vec<&str> = result.iter().map(|a| a.item_id.as_str()).collect();
        let unique: HashSet<&str> = ids.iter().copied().collect();
        assert_eq!(ids.len(), unique.len(), "no item appears twice");
    }

    #[test]
    fn duplicate_input_items_deduplicated() {
        let s = scores(&[("bn-a", 5.0)]);
        let result = assign_fallback(&items(&["bn-a", "bn-a", "bn-a"]), 2, &s, &[]);
        // Only one unique item, so only one assignment.
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].item_id, "bn-a");
    }

    // -----------------------------------------------------------------------
    // Fairness constraint
    // -----------------------------------------------------------------------

    #[test]
    fn fairness_every_agent_gets_one_item_when_items_gte_agents() {
        // 3 items, 3 agents → each agent gets exactly 1 item.
        let s = scores(&[("bn-a", 3.0), ("bn-b", 5.0), ("bn-c", 1.0)]);
        let result = assign_fallback(&items(&["bn-a", "bn-b", "bn-c"]), 3, &s, &[]);

        assert_eq!(result.len(), 3);
        assert_eq!(per_agent(&result, 3), vec![1, 1, 1]);
    }

    #[test]
    fn fairness_no_agent_starved_with_four_items_three_agents() {
        // 4 items, 3 agents → all agents get at least 1.
        let s = scores(&[("bn-a", 4.0), ("bn-b", 3.0), ("bn-c", 2.0), ("bn-d", 1.0)]);
        let result = assign_fallback(&items(&["bn-a", "bn-b", "bn-c", "bn-d"]), 3, &s, &[]);

        for (ag, &count) in per_agent(&result, 3).iter().enumerate() {
            assert!(
                count >= 1,
                "agent {ag} should have at least 1 item (got {count})"
            );
        }
    }

    #[test]
    fn fairness_ok_when_items_less_than_agents() {
        // 2 items, 3 agents → 2 assignments (not all agents get work, that's OK).
        let s = scores(&[("bn-a", 5.0), ("bn-b", 3.0)]);
        let result = assign_fallback(&items(&["bn-a", "bn-b"]), 3, &s, &[]);

        assert_eq!(result.len(), 2);
    }

    #[test]
    fn fairness_pass_redistributes_when_history_starves_an_agent() {
        // Agent 1 skipped everything, so the greedy pass gives both items to
        // agent 0; the fairness pass then moves the lowest-priority one.
        let s = scores(&[("bn-a", 5.0), ("bn-b", 1.0)]);
        let h = history(&[(1, "bn-a"), (1, "bn-b")]);
        let result = assign_fallback(&items(&["bn-a", "bn-b"]), 2, &s, &h);

        assert_eq!(plan(&result), vec![("bn-a", 0), ("bn-b", 1)]);
    }

    #[test]
    fn max_load_skew_limits_history_preference() {
        // Agent 1 skipped everything, so agent 0 is preferred until it hits
        // the cap floor(4 / 2) + max_load_skew.
        let s = scores(&[("bn-a", 4.0), ("bn-b", 3.0), ("bn-c", 2.0), ("bn-d", 1.0)]);
        let h = history(&[(1, "bn-a"), (1, "bn-b"), (1, "bn-c"), (1, "bn-d")]);
        let ids = items(&["bn-a", "bn-b", "bn-c", "bn-d"]);

        let loose =
            assign_fallback_with_config(&ids, 2, &s, &h, &FallbackConfig { max_load_skew: 1 });
        assert_eq!(
            plan(&loose),
            vec![("bn-a", 0), ("bn-b", 0), ("bn-c", 0), ("bn-d", 1)]
        );

        let tight =
            assign_fallback_with_config(&ids, 2, &s, &h, &FallbackConfig { max_load_skew: 0 });
        assert_eq!(
            plan(&tight),
            vec![("bn-a", 0), ("bn-b", 0), ("bn-c", 1), ("bn-d", 1)]
        );
    }

    // -----------------------------------------------------------------------
    // History / anti-starvation
    // -----------------------------------------------------------------------

    #[test]
    fn history_avoids_previous_skip_assignment() {
        // Agent 0 previously skipped bn-a. bn-a should go to agent 1.
        let s = scores(&[("bn-a", 5.0), ("bn-b", 3.0)]);
        let h = history(&[(0, "bn-a")]);

        let result = assign_fallback(&items(&["bn-a", "bn-b"]), 2, &s, &h);

        let bn_a = result.iter().find(|a| a.item_id == "bn-a").unwrap();
        assert_eq!(
            bn_a.agent_idx, 1,
            "bn-a should not go to agent 0 (who skipped it)"
        );
    }

    #[test]
    fn history_falls_back_when_all_agents_skipped() {
        // Both agents skipped bn-a — scheduler must still assign it (no panic),
        // falling back to the least-loaded (lowest-index) agent.
        let s = scores(&[("bn-a", 5.0)]);
        let h = history(&[(0, "bn-a"), (1, "bn-a")]);

        let result = assign_fallback(&items(&["bn-a"]), 2, &s, &h);

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].item_id, "bn-a");
        assert_eq!(result[0].agent_idx, 0);
    }

    #[test]
    fn history_with_unknown_agent_idx_is_ignored() {
        // agent_idx = 99 is out of range for a 2-agent run — should not panic.
        let s = scores(&[("bn-a", 5.0)]);
        let h = history(&[(99, "bn-a")]);

        let result = assign_fallback(&items(&["bn-a"]), 2, &s, &h);
        assert_eq!(result.len(), 1);
    }

    // -----------------------------------------------------------------------
    // ScheduleRegime
    // -----------------------------------------------------------------------

    #[test]
    fn regime_whittle_explain() {
        let r = ScheduleRegime::Whittle {
            indexability_score: 0.95,
        };
        assert!(r.is_whittle());
        assert!(!r.is_fallback());
        let s = r.explain();
        assert!(s.contains("Whittle"), "explain: {s}");
        assert!(s.contains("0.950"), "explain: {s}");
    }

    #[test]
    fn regime_fallback_explain() {
        let r = ScheduleRegime::Fallback {
            reason: "dependency cycle detected".to_string(),
        };
        assert!(r.is_fallback());
        assert!(!r.is_whittle());
        let s = r.explain();
        assert!(s.contains("Fallback"), "explain: {s}");
        assert!(s.contains("dependency cycle"), "explain: {s}");
    }

    // -----------------------------------------------------------------------
    // Determinism
    // -----------------------------------------------------------------------

    #[test]
    fn assignment_is_deterministic() {
        let s = scores(&[("bn-a", 5.0), ("bn-b", 5.0), ("bn-c", 5.0)]);
        let result1 = assign_fallback(&items(&["bn-a", "bn-b", "bn-c"]), 2, &s, &[]);
        let result2 = assign_fallback(&items(&["bn-a", "bn-b", "bn-c"]), 2, &s, &[]);

        assert_eq!(plan(&result1), plan(&result2), "assignment must be deterministic");
    }

    #[test]
    fn equal_scores_are_ordered_by_item_id() {
        // Input order is deliberately scrambled; ties resolve by ID ascending
        // and agents alternate by least load (lowest index on ties).
        let s = scores(&[("bn-a", 5.0), ("bn-b", 5.0), ("bn-c", 5.0)]);
        let result = assign_fallback(&items(&["bn-c", "bn-a", "bn-b"]), 2, &s, &[]);

        assert_eq!(
            plan(&result),
            vec![("bn-a", 0), ("bn-b", 1), ("bn-c", 0)]
        );
    }

    // -----------------------------------------------------------------------
    // Missing score defaults
    // -----------------------------------------------------------------------

    #[test]
    fn missing_score_defaults_to_zero() {
        let s = scores(&[]);
        let result = assign_fallback(&items(&["bn-a", "bn-b"]), 2, &s, &[]);
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn missing_score_ranks_as_zero() {
        // bn-b has no score, so it ranks as 0.0: below bn-c (1.0) and above
        // bn-a (-1.0).
        let s = scores(&[("bn-a", -1.0), ("bn-c", 1.0)]);
        let result = assign_fallback(&items(&["bn-a", "bn-b", "bn-c"]), 1, &s, &[]);

        let order: Vec<&str> = result.iter().map(|a| a.item_id.as_str()).collect();
        assert_eq!(order, vec!["bn-c", "bn-b", "bn-a"]);
    }
}