Skip to main content

converge_provider/
selection_suggestor.rs

1// Copyright 2024-2026 Reflective Labs
2// SPDX-License-Identifier: MIT
3
4//! Provider selection via bipartite matching.
5//!
6//! Reads a [`ProviderRequest`] from context, matches required capabilities
7//! against the registered backend pool using Hopcroft-Karp, and proposes a
8//! [`ProviderAssignment`] to [`ContextKey::Strategies`].
9
10use std::sync::Arc;
11
12use async_trait::async_trait;
13use converge_optimization::graph::matching::bipartite_matching;
14use converge_pack::{AgentEffect, Context, ContextKey, ProposedFact, Suggestor};
15use converge_provider_api::{Backend, CapabilityAssignment, ProviderAssignment, ProviderRequest};
16
17// ── Suggestor ─────────────────────────────────────────────────────────────────
18
/// Id prefix of the seed facts whose content is parsed as a [`ProviderRequest`].
const REQUEST_PREFIX: &str = "provider-request:";
/// Id prefix of the assignment facts proposed to [`ContextKey::Strategies`].
const ASSIGNMENT_PREFIX: &str = "provider-assignment:";
/// Id prefix of the diagnostic facts proposed for seeds that fail to parse.
const MALFORMED_PREFIX: &str = "provider-request-error:";
22
23/// Routes required capabilities to available backends via bipartite matching.
24///
25/// # Construction
26///
27/// ```rust,ignore
28/// let backends: Vec<Arc<dyn Backend>> = vec![
29///     Arc::new(AnthropicBackend::from_env()),
30///     Arc::new(KongBackend::from_env()),
31/// ];
32///
33/// engine.register_suggestor(ProviderSelectionSuggestor::new(backends));
34/// ```
35pub struct ProviderSelectionSuggestor {
36    backends: Vec<Arc<dyn Backend>>,
37}
38
39impl ProviderSelectionSuggestor {
40    pub fn new(backends: Vec<Arc<dyn Backend>>) -> Self {
41        Self { backends }
42    }
43}
44
45#[async_trait]
46impl Suggestor for ProviderSelectionSuggestor {
47    fn name(&self) -> &str {
48        "ProviderSelectionSuggestor"
49    }
50
51    fn dependencies(&self) -> &[ContextKey] {
52        &[ContextKey::Seeds]
53    }
54
55    fn accepts(&self, ctx: &dyn Context) -> bool {
56        ctx.get(ContextKey::Seeds).iter().any(|f| {
57            f.id.starts_with(REQUEST_PREFIX)
58                && match serde_json::from_str::<ProviderRequest>(&f.content) {
59                    Ok(_) => !assignment_exists(ctx, request_id(&f.id)),
60                    Err(_) => !malformed_diagnostic_exists(ctx, &f.id),
61                }
62        })
63    }
64
65    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
66        let mut proposals = Vec::new();
67
68        for fact in ctx
69            .get(ContextKey::Seeds)
70            .iter()
71            .filter(|f| f.id.starts_with(REQUEST_PREFIX))
72        {
73            match serde_json::from_str::<ProviderRequest>(&fact.content) {
74                Ok(req) => {
75                    if assignment_exists(ctx, request_id(&fact.id)) {
76                        continue;
77                    }
78
79                    let assignment = route(&req, &self.backends);
80                    proposals.push(
81                        ProposedFact::new(
82                            ContextKey::Strategies,
83                            format!("{}{}", ASSIGNMENT_PREFIX, assignment.request_id),
84                            serde_json::to_string(&assignment).unwrap_or_default(),
85                            self.name(),
86                        )
87                        .with_confidence(assignment.coverage_ratio),
88                    );
89                }
90                Err(error) => {
91                    if malformed_diagnostic_exists(ctx, &fact.id) {
92                        continue;
93                    }
94
95                    let diagnostic = serde_json::json!({
96                        "request_fact_id": fact.id,
97                        "message": "malformed provider request ignored",
98                        "error": error.to_string(),
99                    });
100                    proposals.push(
101                        ProposedFact::new(
102                            ContextKey::Diagnostic,
103                            malformed_diagnostic_id(&fact.id),
104                            diagnostic.to_string(),
105                            self.name(),
106                        )
107                        .with_confidence(1.0),
108                    );
109                }
110            }
111        }
112
113        if proposals.is_empty() {
114            AgentEffect::empty()
115        } else {
116            AgentEffect::with_proposals(proposals)
117        }
118    }
119}
120
121// ── Matching logic ────────────────────────────────────────────────────────────
122
123fn route(req: &ProviderRequest, backends: &[Arc<dyn Backend>]) -> ProviderAssignment {
124    // Left = required capability slots (index = position in req.required_capabilities).
125    // Right = backends (index = position in `backends`).
126    // Edge: backends[j].has_capability(req.required_capabilities[i]).
127    let edges: Vec<(usize, usize)> = req
128        .required_capabilities
129        .iter()
130        .enumerate()
131        .flat_map(|(i, cap)| {
132            let cap = cap.clone();
133            backends
134                .iter()
135                .enumerate()
136                .filter(move |(_, b)| b.has_capability(cap.clone()))
137                .map(move |(j, _)| (i, j))
138        })
139        .collect();
140
141    let matching = bipartite_matching(req.required_capabilities.len(), backends.len(), &edges)
142        .unwrap_or_default();
143
144    let mut covered = vec![false; req.required_capabilities.len()];
145    let mut assignments = Vec::with_capacity(matching.size);
146
147    for (cap_idx, backend_idx) in &matching.pairs {
148        assignments.push(CapabilityAssignment {
149            capability: req.required_capabilities[*cap_idx].clone(),
150            backend_name: backends[*backend_idx].name().to_string(),
151        });
152        covered[*cap_idx] = true;
153    }
154
155    let unmatched = req
156        .required_capabilities
157        .iter()
158        .enumerate()
159        .filter(|(i, _)| !covered[*i])
160        .map(|(_, c)| c.clone())
161        .collect::<Vec<_>>();
162
163    let coverage_ratio = if req.required_capabilities.is_empty() {
164        1.0
165    } else {
166        matching.size as f64 / req.required_capabilities.len() as f64
167    };
168
169    ProviderAssignment {
170        request_id: req.id.clone(),
171        assignments,
172        unmatched,
173        coverage_ratio,
174    }
175}
176
177// ── Helpers ───────────────────────────────────────────────────────────────────
178
179fn request_id(fact_id: &str) -> &str {
180    fact_id.trim_start_matches(REQUEST_PREFIX)
181}
182
183fn assignment_exists(ctx: &dyn Context, request_id: &str) -> bool {
184    let assignment_id = format!("{}{}", ASSIGNMENT_PREFIX, request_id);
185    ctx.get(ContextKey::Strategies)
186        .iter()
187        .any(|f| f.id == assignment_id)
188}
189
190fn malformed_diagnostic_id(fact_id: &str) -> String {
191    format!("{MALFORMED_PREFIX}{fact_id}")
192}
193
194fn malformed_diagnostic_exists(ctx: &dyn Context, fact_id: &str) -> bool {
195    let diagnostic_id = malformed_diagnostic_id(fact_id);
196    ctx.get(ContextKey::Diagnostic)
197        .iter()
198        .any(|fact| fact.id == diagnostic_id)
199}
200
201// ── Tests ─────────────────────────────────────────────────────────────────────
202
#[cfg(test)]
mod tests {
    use super::*;
    use converge_core::{ContextState, Engine};
    use converge_provider_api::{BackendKind, Capability};

    /// In-memory [`Backend`] that advertises a fixed list of capabilities.
    struct MockBackend {
        name: &'static str,
        capabilities: Vec<Capability>,
    }

    impl Backend for MockBackend {
        fn name(&self) -> &str {
            self.name
        }
        fn kind(&self) -> BackendKind {
            BackendKind::Llm
        }
        fn capabilities(&self) -> Vec<Capability> {
            self.capabilities.clone()
        }
        fn supports_replay(&self) -> bool {
            false
        }
        fn requires_network(&self) -> bool {
            true
        }
    }

    /// Wraps a [`MockBackend`] with the given name and capabilities in an `Arc`.
    fn backend(name: &'static str, caps: &[Capability]) -> Arc<dyn Backend> {
        let capabilities = caps.to_vec();
        Arc::new(MockBackend { name, capabilities })
    }

    /// Builds a request with the given id and required capability slots.
    fn request(id: &str, caps: &[Capability]) -> ProviderRequest {
        ProviderRequest {
            id: id.to_owned(),
            required_capabilities: caps.to_vec(),
        }
    }

    /// Suggestor backed by a single backend that only offers reasoning.
    fn reasoning_only_suggestor() -> ProviderSelectionSuggestor {
        ProviderSelectionSuggestor::new(vec![backend("anthropic", &[Capability::Reasoning])])
    }

    #[test]
    fn full_coverage_when_all_capabilities_available() {
        let wanted = [
            Capability::Reasoning,
            Capability::AccessControl,
            Capability::FullTextSearch,
        ];
        let backends = vec![
            backend("anthropic", &[Capability::Reasoning]),
            backend("kong", &[Capability::AccessControl]),
            backend("elastic", &[Capability::FullTextSearch]),
        ];

        let result = route(&request("req-1", &wanted), &backends);

        assert_eq!(result.assignments.len(), 3);
        assert!(result.unmatched.is_empty());
        assert!((result.coverage_ratio - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn partial_coverage_when_capability_missing() {
        let backends = vec![backend("anthropic", &[Capability::Reasoning])];
        let wanted = [Capability::Reasoning, Capability::AccessControl];

        let result = route(&request("req-2", &wanted), &backends);

        assert_eq!(result.assignments.len(), 1);
        assert_eq!(result.unmatched, vec![Capability::AccessControl]);
        assert!((result.coverage_ratio - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn no_double_booking_with_two_same_capability_slots() {
        let backends = vec![
            backend("anthropic", &[Capability::Reasoning]),
            backend("openai", &[Capability::Reasoning]),
        ];
        let wanted = [Capability::Reasoning, Capability::Reasoning];

        let result = route(&request("req-3", &wanted), &backends);

        assert_eq!(result.assignments.len(), 2);
        // Two slots must be served by two distinct backends.
        let distinct: std::collections::HashSet<&str> = result
            .assignments
            .iter()
            .map(|a| a.backend_name.as_str())
            .collect();
        assert_eq!(distinct.len(), 2);
    }

    #[test]
    fn multi_capability_backend_can_only_fill_one_slot() {
        // A single backend offers both capabilities, yet may take only one slot.
        let both = [Capability::Reasoning, Capability::AccessControl];
        let backends = vec![backend("all-in-one", &both)];

        let result = route(&request("req-4", &both), &backends);

        // One slot is served; the other stays unmatched.
        assert_eq!(result.assignments.len(), 1);
        assert_eq!(result.unmatched.len(), 1);
    }

    #[test]
    fn empty_pool_yields_zero_coverage() {
        let result = route(&request("req-5", &[Capability::Reasoning]), &[]);

        assert_eq!(result.coverage_ratio, 0.0);
        assert_eq!(result.unmatched, vec![Capability::Reasoning]);
    }

    #[test]
    fn empty_request_yields_full_coverage() {
        let backends = vec![backend("anthropic", &[Capability::Reasoning])];

        let result = route(&request("req-6", &[]), &backends);

        assert!((result.coverage_ratio - 1.0).abs() < f64::EPSILON);
        assert!(result.assignments.is_empty());
    }

    #[test]
    fn repeated_routing_is_deterministic_for_equal_candidates() {
        let backends = vec![
            backend("reasoner-a", &[Capability::Reasoning]),
            backend("reasoner-b", &[Capability::Reasoning]),
            backend("policy-a", &[Capability::AccessControl]),
        ];
        let wanted = [
            Capability::Reasoning,
            Capability::Reasoning,
            Capability::AccessControl,
        ];
        let req = request("req-7", &wanted);

        let earlier = route(&req, &backends);
        let later = route(&req, &backends);

        assert_eq!(earlier.assignments, later.assignments);
        assert_eq!(earlier.unmatched, later.unmatched);
        assert_eq!(earlier.coverage_ratio, later.coverage_ratio);
    }

    #[tokio::test]
    async fn malformed_request_emits_diagnostic_once() {
        let mut engine = Engine::new();
        engine.register_suggestor(reasoning_only_suggestor());

        // "{" is not valid JSON, so the seed cannot be parsed as a request.
        let mut seeded = ContextState::new();
        seeded
            .add_input(ContextKey::Seeds, "provider-request:broken", "{")
            .expect("seed should stage");

        let first = engine.run(seeded).await.expect("run should converge");
        let diagnostics = first.context.get(ContextKey::Diagnostic);
        assert_eq!(diagnostics.len(), 1);
        assert_eq!(
            diagnostics[0].id,
            "provider-request-error:provider-request:broken"
        );
        assert!(!first.context.has(ContextKey::Strategies));

        // A fresh engine over the resulting context must not add a second diagnostic.
        let mut rerun_engine = Engine::new();
        rerun_engine.register_suggestor(reasoning_only_suggestor());
        let second = rerun_engine
            .run(first.context.clone())
            .await
            .expect("rerun should converge");
        assert_eq!(second.context.get(ContextKey::Diagnostic).len(), 1);
    }
}