Skip to main content

converge_optimization/suggestors/
assignment.rs

1// Copyright 2024-2026 Reflective Labs
2// SPDX-License-Identifier: MIT
3
4//! Optimal assignment via the Hungarian algorithm (O(n³)).
5//!
6//! Reads an [`AssignmentRequest`] from context, solves the linear-sum
7//! assignment problem, and proposes an [`AssignmentPlan`] to
8//! [`ContextKey::Strategies`].
9//!
10//! # Formation role
11//!
12//! Seed a request once; every downstream suggestor that needs to know who
13//! does what reads the plan from `ContextKey::Strategies`. If cost estimates
14//! change (e.g. a capacity suggestor updates constraints), re-seed with a new
15//! request id — the suggestor reacts and the formation re-converges.
16
17use async_trait::async_trait;
18use converge_pack::{AgentEffect, Context, ContextKey, ProposedFact, Suggestor};
19use serde::{Deserialize, Serialize};
20
21use crate::assignment::{AssignmentProblem, hungarian};
22
23// ── Request ───────────────────────────────────────────────────────────────────
24
/// Seed this under [`ContextKey::Seeds`] with id prefix `"assignment-request:"`.
///
/// The seed fact's content is this struct serialized as JSON. Content that
/// fails to deserialize is not solved; the suggestor reports it as a
/// diagnostic under [`ContextKey::Diagnostic`] instead.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AssignmentRequest {
    /// Stable identifier for idempotency; copied verbatim into
    /// [`AssignmentPlan::request_id`] of the resulting plan.
    pub id: String,
    /// Labels for the agents (rows). Length must equal `costs.len()`.
    pub agents: Vec<String>,
    /// Labels for the tasks (columns). Length must equal `costs[i].len()`.
    pub tasks: Vec<String>,
    /// Cost matrix: `costs[agent][task]`. Must be square (n×n).
    pub costs: Vec<Vec<i64>>,
}
37
38// ── Plan (output) ─────────────────────────────────────────────────────────────
39
/// The optimal assignment produced by the suggestor.
///
/// Serialized as JSON into the content of an `assignment-plan:` fact proposed
/// under [`ContextKey::Strategies`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AssignmentPlan {
    /// The [`AssignmentRequest::id`] this plan answers.
    pub request_id: String,
    /// `(agent_label, task_label)` pairs, one per matched agent.
    pub assignments: Vec<(String, String)>,
    /// Total cost as reported by the Hungarian solver; `0` when nothing was
    /// assigned.
    pub total_cost: i64,
    /// `assignments.len() / agents.len()` — 1.0 means fully matched.
    ///
    /// An empty request (nothing to assign) reports 1.0. A request that is
    /// rejected or cannot be solved reports 0.0 with empty `assignments`.
    /// The suggestor also uses this value as the proposed fact's confidence.
    pub utilization: f64,
}
50
51// ── Suggestor ─────────────────────────────────────────────────────────────────
52
/// Fact-id prefix under [`ContextKey::Seeds`] that marks an [`AssignmentRequest`].
const REQUEST_PREFIX: &str = "assignment-request:";
/// Fact-id prefix for plans proposed under [`ContextKey::Strategies`].
const PLAN_PREFIX: &str = "assignment-plan:";
/// Fact-id prefix for diagnostics about unparseable requests under
/// [`ContextKey::Diagnostic`]. The full id is this prefix followed by the
/// offending seed's complete fact id.
const ERROR_PREFIX: &str = "assignment-request-error:";
56
/// Solves a linear-sum assignment problem using the Hungarian algorithm.
///
/// Watches [`ContextKey::Seeds`] for facts whose id starts with
/// `"assignment-request:"`. Each request that parses yields one
/// [`AssignmentPlan`] proposed to [`ContextKey::Strategies`]. Each request
/// whose JSON fails to parse yields one diagnostic under
/// [`ContextKey::Diagnostic`] instead.
///
/// Registers as a zero-configuration unit — no injected state required.
pub struct AssignmentSuggestor;
61
62#[async_trait]
63impl Suggestor for AssignmentSuggestor {
64    fn name(&self) -> &str {
65        "AssignmentSuggestor"
66    }
67
68    fn dependencies(&self) -> &[ContextKey] {
69        &[ContextKey::Seeds]
70    }
71
72    fn complexity_hint(&self) -> Option<&'static str> {
73        Some("O(n³) Hungarian algorithm — n = agents = tasks; practical for n ≤ 500")
74    }
75
76    fn accepts(&self, ctx: &dyn Context) -> bool {
77        ctx.get(ContextKey::Seeds).iter().any(|f| {
78            f.id().as_str().starts_with(REQUEST_PREFIX)
79                && match serde_json::from_str::<AssignmentRequest>(f.content()) {
80                    Ok(_) => !plan_exists(ctx, req_id(f.id().as_str())),
81                    Err(_) => !error_exists(ctx, f.id().as_str()),
82                }
83        })
84    }
85
86    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
87        let mut proposals = Vec::new();
88
89        for fact in ctx
90            .get(ContextKey::Seeds)
91            .iter()
92            .filter(|f| f.id().as_str().starts_with(REQUEST_PREFIX))
93        {
94            match serde_json::from_str::<AssignmentRequest>(fact.content()) {
95                Ok(req) => {
96                    if plan_exists(ctx, req_id(fact.id().as_str())) {
97                        continue;
98                    }
99                    let plan = solve(&req);
100                    proposals.push(
101                        ProposedFact::new(
102                            ContextKey::Strategies,
103                            format!("{}{}", PLAN_PREFIX, plan.request_id),
104                            serde_json::to_string(&plan).unwrap_or_default(),
105                            self.name(),
106                        )
107                        .with_confidence(plan.utilization),
108                    );
109                }
110                Err(e) => {
111                    if error_exists(ctx, fact.id().as_str()) {
112                        continue;
113                    }
114                    let diag = serde_json::json!({
115                        "request_fact_id": fact.id(),
116                        "message": "malformed assignment request",
117                        "error": e.to_string(),
118                    });
119                    proposals.push(
120                        ProposedFact::new(
121                            ContextKey::Diagnostic,
122                            format!("{}{}", ERROR_PREFIX, fact.id()),
123                            diag.to_string(),
124                            self.name(),
125                        )
126                        .with_confidence(1.0),
127                    );
128                }
129            }
130        }
131
132        if proposals.is_empty() {
133            AgentEffect::empty()
134        } else {
135            AgentEffect::with_proposals(proposals)
136        }
137    }
138}
139
140// ── Core logic ────────────────────────────────────────────────────────────────
141
142fn solve(req: &AssignmentRequest) -> AssignmentPlan {
143    if req.agents.is_empty() {
144        return AssignmentPlan {
145            request_id: req.id.clone(),
146            assignments: vec![],
147            total_cost: 0,
148            utilization: 1.0,
149        };
150    }
151
152    let problem = AssignmentProblem::from_costs(req.costs.clone());
153    if problem.validate().is_err() {
154        return AssignmentPlan {
155            request_id: req.id.clone(),
156            assignments: vec![],
157            total_cost: 0,
158            utilization: 0.0,
159        };
160    }
161
162    match hungarian::solve(&problem) {
163        Ok(sol) => {
164            let assignments = sol
165                .assignments
166                .iter()
167                .enumerate()
168                .map(|(agent_idx, &task_idx)| {
169                    (
170                        req.agents.get(agent_idx).cloned().unwrap_or_default(),
171                        req.tasks.get(task_idx).cloned().unwrap_or_default(),
172                    )
173                })
174                .collect::<Vec<_>>();
175            let n = assignments.len();
176            AssignmentPlan {
177                request_id: req.id.clone(),
178                assignments,
179                total_cost: sol.total_cost,
180                utilization: n as f64 / req.agents.len() as f64,
181            }
182        }
183        Err(_) => AssignmentPlan {
184            request_id: req.id.clone(),
185            assignments: vec![],
186            total_cost: 0,
187            utilization: 0.0,
188        },
189    }
190}
191
192// ── Helpers ───────────────────────────────────────────────────────────────────
193
194fn req_id(fact_id: &str) -> &str {
195    fact_id.trim_start_matches(REQUEST_PREFIX)
196}
197
198fn plan_exists(ctx: &dyn Context, request_id: &str) -> bool {
199    let id = format!("{}{}", PLAN_PREFIX, request_id);
200    ctx.get(ContextKey::Strategies)
201        .iter()
202        .any(|f| f.id().as_str() == id)
203}
204
205fn error_exists(ctx: &dyn Context, fact_id: &str) -> bool {
206    let id = format!("{}{}", ERROR_PREFIX, fact_id);
207    ctx.get(ContextKey::Diagnostic)
208        .iter()
209        .any(|f| f.id().as_str() == id)
210}
211
212// ── Tests ─────────────────────────────────────────────────────────────────────
213
#[cfg(test)]
mod tests {
    use super::*;
    use converge_core::{ContextState, Engine};

    /// Serializes a request over `costs`, labelling rows `agent-0..` and
    /// columns `task-0..` (one label per row for both).
    fn req_json(id: &str, costs: Vec<Vec<i64>>) -> String {
        let n = costs.len();
        let labels = |kind: &str| -> Vec<String> {
            (0..n).map(|i| format!("{kind}-{i}")).collect()
        };
        let request = AssignmentRequest {
            id: id.to_string(),
            agents: labels("agent"),
            tasks: labels("task"),
            costs,
        };
        serde_json::to_string(&request).unwrap()
    }

    /// Taha's 3×3 textbook instance; the optimal total cost is 9.
    fn taha_costs() -> Vec<Vec<i64>> {
        vec![vec![9, 2, 7], vec![6, 4, 3], vec![5, 8, 1]]
    }

    #[tokio::test]
    async fn textbook_3x3_finds_optimal_cost() {
        let mut engine = Engine::new();
        engine.register_suggestor(AssignmentSuggestor);

        let mut seeds = ContextState::new();
        seeds
            .add_input(ContextKey::Seeds, "assignment-request:r1", req_json("r1", taha_costs()))
            .unwrap();

        let outcome = engine.run(seeds).await.unwrap();
        let strategies = outcome.context.get(ContextKey::Strategies);
        assert_eq!(strategies.len(), 1);

        let plan: AssignmentPlan = serde_json::from_str(strategies[0].content()).unwrap();
        assert_eq!(plan.total_cost, 9, "optimal cost = 9");
        assert_eq!(plan.assignments.len(), 3);
        assert!((plan.utilization - 1.0).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn result_is_idempotent() {
        let mut seeds = ContextState::new();
        seeds
            .add_input(ContextKey::Seeds, "assignment-request:r1", req_json("r1", taha_costs()))
            .unwrap();

        let mut first_engine = Engine::new();
        first_engine.register_suggestor(AssignmentSuggestor);
        let first = first_engine.run(seeds).await.unwrap();

        // Re-running a fresh engine over the converged context must not add
        // further plans.
        let mut second_engine = Engine::new();
        second_engine.register_suggestor(AssignmentSuggestor);
        let second = second_engine.run(first.context.clone()).await.unwrap();

        let plans_after_second = second.context.get(ContextKey::Strategies).len();
        let plans_after_first = first.context.get(ContextKey::Strategies).len();
        assert_eq!(plans_after_second, plans_after_first);
    }

    #[tokio::test]
    async fn malformed_request_emits_diagnostic() {
        let mut engine = Engine::new();
        engine.register_suggestor(AssignmentSuggestor);

        let mut seeds = ContextState::new();
        seeds
            .add_input(ContextKey::Seeds, "assignment-request:bad", "{")
            .unwrap();

        let outcome = engine.run(seeds).await.unwrap();
        let diagnostics = outcome.context.get(ContextKey::Diagnostic);
        assert_eq!(diagnostics.len(), 1);
        assert!(!outcome.context.has(ContextKey::Strategies));
    }
}