Skip to main content

converge_optimization/suggestors/
assignment.rs

// Copyright 2024-2026 Reflective Labs
// SPDX-License-Identifier: MIT

//! Optimal assignment via the Hungarian algorithm (O(n³)).
//!
//! Reads [`AssignmentRequest`] seeds from [`ContextKey::Seeds`], solves the
//! linear-sum assignment problem, and proposes an [`AssignmentPlan`] to
//! [`ContextKey::Strategies`].
//!
//! # Formation role
//!
//! Seed a request once; every downstream suggestor that needs to know who
//! does what reads the plan from `ContextKey::Strategies`. If cost estimates
//! change (e.g. a capacity suggestor updates constraints), re-seed with a new
//! request id — the suggestor reacts and the formation re-converges.
16
17use async_trait::async_trait;
18use converge_pack::{AgentEffect, Context, ContextKey, ProposedFact, Suggestor};
19use serde::{Deserialize, Serialize};
20
21use crate::assignment::{AssignmentProblem, hungarian};
22
23// ── Request ───────────────────────────────────────────────────────────────────
24
25/// Seed this under [`ContextKey::Seeds`] with id prefix `"assignment-request:"`.
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct AssignmentRequest {
28    /// Stable identifier for idempotency.
29    pub id: String,
30    /// Labels for the agents (rows). Length must equal `costs.len()`.
31    pub agents: Vec<String>,
32    /// Labels for the tasks (columns). Length must equal `costs[i].len()`.
33    pub tasks: Vec<String>,
34    /// Cost matrix: `costs[agent][task]`. Must be square (n×n).
35    pub costs: Vec<Vec<i64>>,
36}
37
38// ── Plan (output) ─────────────────────────────────────────────────────────────
39
40/// The optimal assignment produced by the suggestor.
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct AssignmentPlan {
43    pub request_id: String,
44    /// `(agent_label, task_label)` pairs, one per matched agent.
45    pub assignments: Vec<(String, String)>,
46    pub total_cost: i64,
47    /// `assignments.len() / agents.len()` — 1.0 means fully matched.
48    pub utilization: f64,
49}
50
// ── Suggestor ─────────────────────────────────────────────────────────────────

/// Fact-id prefix identifying request seeds under `ContextKey::Seeds`.
const REQUEST_PREFIX: &str = "assignment-request:";
/// Fact-id prefix of diagnostics emitted for unparseable request seeds.
const ERROR_PREFIX: &str = "assignment-request-error:";
/// Fact-id prefix of plans proposed under `ContextKey::Strategies`.
const PLAN_PREFIX: &str = "assignment-plan:";
56
/// Solves a linear-sum assignment problem using the Hungarian algorithm.
///
/// Registers as a zero-configuration unit — no injected state required.
/// Each seed whose fact id starts with `"assignment-request:"` is answered
/// with an [`AssignmentPlan`], or — if its JSON does not parse as an
/// [`AssignmentRequest`] — with a diagnostic.
#[derive(Debug, Clone, Copy, Default)]
pub struct AssignmentSuggestor;
61
62#[async_trait]
63impl Suggestor for AssignmentSuggestor {
64    fn name(&self) -> &str {
65        "AssignmentSuggestor"
66    }
67
68    fn dependencies(&self) -> &[ContextKey] {
69        &[ContextKey::Seeds]
70    }
71
72    fn complexity_hint(&self) -> Option<&'static str> {
73        Some("O(n³) Hungarian algorithm — n = agents = tasks; practical for n ≤ 500")
74    }
75
76    fn accepts(&self, ctx: &dyn Context) -> bool {
77        ctx.get(ContextKey::Seeds).iter().any(|f| {
78            f.id.starts_with(REQUEST_PREFIX)
79                && match serde_json::from_str::<AssignmentRequest>(&f.content) {
80                    Ok(_) => !plan_exists(ctx, req_id(&f.id)),
81                    Err(_) => !error_exists(ctx, &f.id),
82                }
83        })
84    }
85
86    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
87        let mut proposals = Vec::new();
88
89        for fact in ctx
90            .get(ContextKey::Seeds)
91            .iter()
92            .filter(|f| f.id.starts_with(REQUEST_PREFIX))
93        {
94            match serde_json::from_str::<AssignmentRequest>(&fact.content) {
95                Ok(req) => {
96                    if plan_exists(ctx, req_id(&fact.id)) {
97                        continue;
98                    }
99                    let plan = solve(&req);
100                    proposals.push(
101                        ProposedFact::new(
102                            ContextKey::Strategies,
103                            format!("{}{}", PLAN_PREFIX, plan.request_id),
104                            serde_json::to_string(&plan).unwrap_or_default(),
105                            self.name(),
106                        )
107                        .with_confidence(plan.utilization),
108                    );
109                }
110                Err(e) => {
111                    if error_exists(ctx, &fact.id) {
112                        continue;
113                    }
114                    let diag = serde_json::json!({
115                        "request_fact_id": fact.id,
116                        "message": "malformed assignment request",
117                        "error": e.to_string(),
118                    });
119                    proposals.push(
120                        ProposedFact::new(
121                            ContextKey::Diagnostic,
122                            format!("{}{}", ERROR_PREFIX, fact.id),
123                            diag.to_string(),
124                            self.name(),
125                        )
126                        .with_confidence(1.0),
127                    );
128                }
129            }
130        }
131
132        if proposals.is_empty() {
133            AgentEffect::empty()
134        } else {
135            AgentEffect::with_proposals(proposals)
136        }
137    }
138}
139
140// ── Core logic ────────────────────────────────────────────────────────────────
141
142fn solve(req: &AssignmentRequest) -> AssignmentPlan {
143    if req.agents.is_empty() {
144        return AssignmentPlan {
145            request_id: req.id.clone(),
146            assignments: vec![],
147            total_cost: 0,
148            utilization: 1.0,
149        };
150    }
151
152    let problem = AssignmentProblem::from_costs(req.costs.clone());
153    if problem.validate().is_err() {
154        return AssignmentPlan {
155            request_id: req.id.clone(),
156            assignments: vec![],
157            total_cost: 0,
158            utilization: 0.0,
159        };
160    }
161
162    match hungarian::solve(&problem) {
163        Ok(sol) => {
164            let assignments = sol
165                .assignments
166                .iter()
167                .enumerate()
168                .map(|(agent_idx, &task_idx)| {
169                    (
170                        req.agents.get(agent_idx).cloned().unwrap_or_default(),
171                        req.tasks.get(task_idx).cloned().unwrap_or_default(),
172                    )
173                })
174                .collect::<Vec<_>>();
175            let n = assignments.len();
176            AssignmentPlan {
177                request_id: req.id.clone(),
178                assignments,
179                total_cost: sol.total_cost,
180                utilization: n as f64 / req.agents.len() as f64,
181            }
182        }
183        Err(_) => AssignmentPlan {
184            request_id: req.id.clone(),
185            assignments: vec![],
186            total_cost: 0,
187            utilization: 0.0,
188        },
189    }
190}
191
192// ── Helpers ───────────────────────────────────────────────────────────────────
193
194fn req_id(fact_id: &str) -> &str {
195    fact_id.trim_start_matches(REQUEST_PREFIX)
196}
197
198fn plan_exists(ctx: &dyn Context, request_id: &str) -> bool {
199    let id = format!("{}{}", PLAN_PREFIX, request_id);
200    ctx.get(ContextKey::Strategies).iter().any(|f| f.id == id)
201}
202
203fn error_exists(ctx: &dyn Context, fact_id: &str) -> bool {
204    let id = format!("{}{}", ERROR_PREFIX, fact_id);
205    ctx.get(ContextKey::Diagnostic).iter().any(|f| f.id == id)
206}
207
// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use converge_core::{ContextState, Engine};

    /// Taha's 3×3 textbook instance; its optimal total cost is 9.
    fn taha_costs() -> Vec<Vec<i64>> {
        vec![vec![9, 2, 7], vec![6, 4, 3], vec![5, 8, 1]]
    }

    /// Serializes an n×n request labelled `agent-0..` / `task-0..`.
    fn req_json(id: &str, costs: Vec<Vec<i64>>) -> String {
        let size = costs.len();
        let labels = |kind: &str| -> Vec<String> {
            (0..size).map(|i| format!("{kind}-{i}")).collect()
        };
        let request = AssignmentRequest {
            id: id.to_string(),
            agents: labels("agent"),
            tasks: labels("task"),
            costs,
        };
        serde_json::to_string(&request).unwrap()
    }

    #[tokio::test]
    async fn textbook_3x3_finds_optimal_cost() {
        let mut seeds = ContextState::new();
        seeds
            .add_input(
                ContextKey::Seeds,
                "assignment-request:r1",
                req_json("r1", taha_costs()),
            )
            .unwrap();

        let mut engine = Engine::new();
        engine.register_suggestor(AssignmentSuggestor);
        let outcome = engine.run(seeds).await.unwrap();

        let strategies = outcome.context.get(ContextKey::Strategies);
        assert_eq!(strategies.len(), 1);

        let plan: AssignmentPlan = serde_json::from_str(&strategies[0].content).unwrap();
        assert_eq!(plan.total_cost, 9, "optimal cost = 9");
        assert_eq!(plan.assignments.len(), 3);
        assert!((plan.utilization - 1.0).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn result_is_idempotent() {
        let mut seeds = ContextState::new();
        seeds
            .add_input(
                ContextKey::Seeds,
                "assignment-request:r1",
                req_json("r1", taha_costs()),
            )
            .unwrap();

        let mut first_engine = Engine::new();
        first_engine.register_suggestor(AssignmentSuggestor);
        let first = first_engine.run(seeds).await.unwrap();

        let mut second_engine = Engine::new();
        second_engine.register_suggestor(AssignmentSuggestor);
        let second = second_engine.run(first.context.clone()).await.unwrap();

        assert_eq!(
            second.context.get(ContextKey::Strategies).len(),
            first.context.get(ContextKey::Strategies).len(),
        );
    }

    #[tokio::test]
    async fn malformed_request_emits_diagnostic() {
        let mut seeds = ContextState::new();
        seeds
            .add_input(ContextKey::Seeds, "assignment-request:bad", "{")
            .unwrap();

        let mut engine = Engine::new();
        engine.register_suggestor(AssignmentSuggestor);
        let outcome = engine.run(seeds).await.unwrap();

        assert_eq!(outcome.context.get(ContextKey::Diagnostic).len(), 1);
        assert!(!outcome.context.has(ContextKey::Strategies));
    }
}