Skip to main content

converge_kernel/
provider_selection.rs

1// Copyright 2024-2026 Reflective Labs
2// SPDX-License-Identifier: MIT
3
4//! Provider selection via bipartite matching.
5//!
6//! Reads a [`ProviderRequest`] from context, matches required capabilities
7//! against the registered backend pool using Hopcroft-Karp, and proposes a
8//! [`ProviderAssignment`] to [`ContextKey::Strategies`].
9
10use std::sync::Arc;
11
12use async_trait::async_trait;
13use converge_optimization::graph::matching::bipartite_matching;
14use converge_pack::{AgentEffect, Context, ContextKey, ProposedFact, Suggestor};
15use converge_provider::{
16    Backend, BackendRequirements, CapabilityAssignment, ProviderAssignment, ProviderRequest,
17};
18
19// ── Suggestor ─────────────────────────────────────────────────────────────────
20
/// Id prefix that marks a [`ContextKey::Seeds`] fact as a provider request.
const REQUEST_PREFIX: &str = "provider-request:";
/// Id prefix of the [`ContextKey::Strategies`] fact that carries a request's assignment.
const ASSIGNMENT_PREFIX: &str = "provider-assignment:";
/// Id prefix of the [`ContextKey::Diagnostic`] fact emitted for a request fact whose
/// content does not parse as a [`ProviderRequest`].
const MALFORMED_PREFIX: &str = "provider-request-error:";
24
/// Routes required capabilities to available backends via bipartite matching.
///
/// For every [`ContextKey::Seeds`] fact whose id starts with `provider-request:` and whose
/// content parses as a [`ProviderRequest`], this suggestor proposes a [`ProviderAssignment`]
/// to [`ContextKey::Strategies`]. Plain requests are matched slot-by-slot so no backend is
/// double-booked. Requests that carry [`BackendRequirements`] are routed to the first backend
/// in the pool that satisfies them. A request fact whose content fails to parse yields one
/// [`ContextKey::Diagnostic`] fact instead.
///
/// # Construction
///
/// ```rust,ignore
/// let backends: Vec<Arc<dyn Backend>> = product_extension_backends();
///
/// engine.register_suggestor(ProviderSelectionSuggestor::new(backends));
/// ```
pub struct ProviderSelectionSuggestor {
    /// Backend pool that every request is routed against.
    backends: Vec<Arc<dyn Backend>>,
}
37
impl ProviderSelectionSuggestor {
    /// Creates a suggestor that routes provider requests against `backends`.
    pub fn new(backends: Vec<Arc<dyn Backend>>) -> Self {
        Self { backends }
    }
}
43
44#[async_trait]
45impl Suggestor for ProviderSelectionSuggestor {
46    fn name(&self) -> &str {
47        "ProviderSelectionSuggestor"
48    }
49
50    fn dependencies(&self) -> &[ContextKey] {
51        &[ContextKey::Seeds]
52    }
53
54    fn accepts(&self, ctx: &dyn Context) -> bool {
55        ctx.get(ContextKey::Seeds).iter().any(|f| {
56            f.id().as_str().starts_with(REQUEST_PREFIX)
57                && match serde_json::from_str::<ProviderRequest>(f.content()) {
58                    Ok(_) => !assignment_exists(ctx, request_id(f.id().as_str())),
59                    Err(_) => !malformed_diagnostic_exists(ctx, f.id().as_str()),
60                }
61        })
62    }
63
64    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
65        let mut proposals = Vec::new();
66
67        for fact in ctx
68            .get(ContextKey::Seeds)
69            .iter()
70            .filter(|f| f.id().as_str().starts_with(REQUEST_PREFIX))
71        {
72            match serde_json::from_str::<ProviderRequest>(fact.content()) {
73                Ok(req) => {
74                    if assignment_exists(ctx, request_id(fact.id().as_str())) {
75                        continue;
76                    }
77
78                    let assignment = route(&req, &self.backends);
79                    proposals.push(
80                        ProposedFact::new(
81                            ContextKey::Strategies,
82                            format!("{}{}", ASSIGNMENT_PREFIX, assignment.request_id),
83                            serde_json::to_string(&assignment).unwrap_or_default(),
84                            self.name(),
85                        )
86                        .with_confidence(assignment.coverage_ratio),
87                    );
88                }
89                Err(error) => {
90                    if malformed_diagnostic_exists(ctx, fact.id().as_str()) {
91                        continue;
92                    }
93
94                    let diagnostic = serde_json::json!({
95                        "request_fact_id": fact.id(),
96                        "message": "malformed provider request ignored",
97                        "error": error.to_string(),
98                    });
99                    proposals.push(
100                        ProposedFact::new(
101                            ContextKey::Diagnostic,
102                            malformed_diagnostic_id(fact.id().as_str()),
103                            diagnostic.to_string(),
104                            self.name(),
105                        )
106                        .with_confidence(1.0),
107                    );
108                }
109            }
110        }
111
112        if proposals.is_empty() {
113            AgentEffect::empty()
114        } else {
115            AgentEffect::with_proposals(proposals)
116        }
117    }
118}
119
120// ── Matching logic ────────────────────────────────────────────────────────────
121
122fn route(req: &ProviderRequest, backends: &[Arc<dyn Backend>]) -> ProviderAssignment {
123    if let Some(requirements) = &req.backend_requirements {
124        return route_backend_requirements(req, requirements, backends);
125    }
126
127    // Left = required capability slots (index = position in req.required_capabilities).
128    // Right = backends (index = position in `backends`).
129    // Edge: backends[j].has_capability(req.required_capabilities[i]).
130    let edges: Vec<(usize, usize)> = req
131        .required_capabilities
132        .iter()
133        .enumerate()
134        .flat_map(|(i, cap)| {
135            let cap = cap.clone();
136            backends
137                .iter()
138                .enumerate()
139                .filter(move |(_, b)| b.has_capability(cap.clone()))
140                .map(move |(j, _)| (i, j))
141        })
142        .collect();
143
144    let matching = bipartite_matching(req.required_capabilities.len(), backends.len(), &edges)
145        .unwrap_or_default();
146
147    let mut covered = vec![false; req.required_capabilities.len()];
148    let mut assignments = Vec::with_capacity(matching.size);
149
150    for (cap_idx, backend_idx) in &matching.pairs {
151        assignments.push(CapabilityAssignment {
152            capability: req.required_capabilities[*cap_idx].clone(),
153            backend_name: backends[*backend_idx].name().to_string(),
154        });
155        covered[*cap_idx] = true;
156    }
157
158    let unmatched = req
159        .required_capabilities
160        .iter()
161        .enumerate()
162        .filter(|(i, _)| !covered[*i])
163        .map(|(_, c)| c.clone())
164        .collect::<Vec<_>>();
165
166    let coverage_ratio = if req.required_capabilities.is_empty() {
167        1.0
168    } else {
169        matching.size as f64 / req.required_capabilities.len() as f64
170    };
171
172    ProviderAssignment {
173        request_id: req.id.clone(),
174        assignments,
175        unmatched,
176        coverage_ratio,
177    }
178}
179
180fn route_backend_requirements(
181    req: &ProviderRequest,
182    requirements: &BackendRequirements,
183    backends: &[Arc<dyn Backend>],
184) -> ProviderAssignment {
185    let required_capabilities = if requirements.required_capabilities.is_empty() {
186        req.required_capabilities.clone()
187    } else {
188        requirements.required_capabilities.clone()
189    };
190
191    let matched_backend = backends.iter().find(|backend| {
192        backend.kind() == requirements.kind
193            && required_capabilities
194                .iter()
195                .all(|capability| backend.has_capability(capability.clone()))
196            && (!requirements.requires_replay || backend.supports_replay())
197            && (!requirements.requires_offline || !backend.requires_network())
198    });
199
200    if let Some(backend) = matched_backend {
201        let assignments = required_capabilities
202            .iter()
203            .cloned()
204            .map(|capability| CapabilityAssignment {
205                capability,
206                backend_name: backend.name().to_string(),
207            })
208            .collect::<Vec<_>>();
209        return ProviderAssignment {
210            request_id: req.id.clone(),
211            assignments,
212            unmatched: Vec::new(),
213            coverage_ratio: 1.0,
214        };
215    }
216
217    let coverage_ratio = if required_capabilities.is_empty() {
218        1.0
219    } else {
220        0.0
221    };
222    ProviderAssignment {
223        request_id: req.id.clone(),
224        assignments: Vec::new(),
225        unmatched: required_capabilities,
226        coverage_ratio,
227    }
228}
229
230// ── Helpers ───────────────────────────────────────────────────────────────────
231
232fn request_id(fact_id: &str) -> &str {
233    fact_id.trim_start_matches(REQUEST_PREFIX)
234}
235
236fn assignment_exists(ctx: &dyn Context, request_id: &str) -> bool {
237    let assignment_id = format!("{}{}", ASSIGNMENT_PREFIX, request_id);
238    ctx.get(ContextKey::Strategies)
239        .iter()
240        .any(|f| f.id().as_str() == assignment_id)
241}
242
243fn malformed_diagnostic_id(fact_id: &str) -> String {
244    format!("{MALFORMED_PREFIX}{fact_id}")
245}
246
247fn malformed_diagnostic_exists(ctx: &dyn Context, fact_id: &str) -> bool {
248    let diagnostic_id = malformed_diagnostic_id(fact_id);
249    ctx.get(ContextKey::Diagnostic)
250        .iter()
251        .any(|fact| fact.id().as_str() == diagnostic_id)
252}
253
254// ── Tests ─────────────────────────────────────────────────────────────────────
255
#[cfg(test)]
mod tests {
    use super::*;
    use converge_core::{ContextState, Engine};
    use converge_provider::{BackendKind, Capability};

    /// Test double for [`Backend`] whose trait answers are all stored fields.
    ///
    /// `has_capability` is not overridden here; presumably the trait's default derives it
    /// from `capabilities()` — confirm in `converge_provider`.
    struct MockBackend {
        name: &'static str,
        kind: BackendKind,
        capabilities: Vec<Capability>,
        supports_replay: bool,
        requires_network: bool,
    }

    impl Backend for MockBackend {
        fn name(&self) -> &str {
            self.name
        }
        fn kind(&self) -> BackendKind {
            self.kind.clone()
        }
        fn capabilities(&self) -> Vec<Capability> {
            self.capabilities.clone()
        }
        fn supports_replay(&self) -> bool {
            self.supports_replay
        }
        fn requires_network(&self) -> bool {
            self.requires_network
        }
    }

    /// Builds a [`BackendKind::Llm`] mock with `caps` that does not support replay and
    /// requires the network.
    fn backend(name: &'static str, caps: &[Capability]) -> Arc<dyn Backend> {
        backend_with(name, BackendKind::Llm, caps, false, true)
    }

    /// Builds a mock backend with every trait answer given explicitly.
    fn backend_with(
        name: &'static str,
        kind: BackendKind,
        caps: &[Capability],
        supports_replay: bool,
        requires_network: bool,
    ) -> Arc<dyn Backend> {
        Arc::new(MockBackend {
            name,
            kind,
            capabilities: caps.to_vec(),
            supports_replay,
            requires_network,
        })
    }

    /// Builds a request without [`BackendRequirements`], so [`route`] takes the
    /// bipartite-matching path.
    fn request(id: &str, caps: &[Capability]) -> ProviderRequest {
        ProviderRequest {
            id: id.to_string(),
            required_capabilities: caps.to_vec(),
            backend_requirements: None,
        }
    }

    // One specialised backend per capability: every slot is matched.
    #[test]
    fn full_coverage_when_all_capabilities_available() {
        let pool = vec![
            backend("anthropic", &[Capability::Reasoning]),
            backend("kong", &[Capability::AccessControl]),
            backend("elastic", &[Capability::FullTextSearch]),
        ];
        let req = request(
            "req-1",
            &[
                Capability::Reasoning,
                Capability::AccessControl,
                Capability::FullTextSearch,
            ],
        );

        let assignment = route(&req, &pool);

        assert_eq!(assignment.assignments.len(), 3);
        assert!(assignment.unmatched.is_empty());
        assert!((assignment.coverage_ratio - 1.0).abs() < f64::EPSILON);
    }

    // A capability nobody offers lands in `unmatched` and halves the coverage.
    #[test]
    fn partial_coverage_when_capability_missing() {
        let pool = vec![backend("anthropic", &[Capability::Reasoning])];
        let req = request("req-2", &[Capability::Reasoning, Capability::AccessControl]);

        let assignment = route(&req, &pool);

        assert_eq!(assignment.assignments.len(), 1);
        assert_eq!(assignment.unmatched, vec![Capability::AccessControl]);
        assert!((assignment.coverage_ratio - 0.5).abs() < f64::EPSILON);
    }

    // Two slots for the same capability must be filled by two different backends.
    #[test]
    fn no_double_booking_with_two_same_capability_slots() {
        let pool = vec![
            backend("anthropic", &[Capability::Reasoning]),
            backend("openai", &[Capability::Reasoning]),
        ];
        let req = request("req-3", &[Capability::Reasoning, Capability::Reasoning]);

        let assignment = route(&req, &pool);

        assert_eq!(assignment.assignments.len(), 2);
        let names: Vec<_> = assignment
            .assignments
            .iter()
            .map(|a| &a.backend_name)
            .collect();
        let unique: std::collections::HashSet<_> = names.iter().collect();
        assert_eq!(unique.len(), 2);
    }

    #[test]
    fn multi_capability_backend_can_only_fill_one_slot() {
        // One backend that has both capabilities but should only fill one slot.
        let pool = vec![backend(
            "all-in-one",
            &[Capability::Reasoning, Capability::AccessControl],
        )];
        let req = request("req-4", &[Capability::Reasoning, Capability::AccessControl]);

        let assignment = route(&req, &pool);

        // Only one slot filled — backend can't be double-booked.
        assert_eq!(assignment.assignments.len(), 1);
        assert_eq!(assignment.unmatched.len(), 1);
    }

    // With no backends at all, every requested capability is unmatched.
    #[test]
    fn empty_pool_yields_zero_coverage() {
        let req = request("req-5", &[Capability::Reasoning]);
        let assignment = route(&req, &[]);
        assert_eq!(assignment.coverage_ratio, 0.0);
        assert_eq!(assignment.unmatched, vec![Capability::Reasoning]);
    }

    // A request that asks for nothing is trivially fully covered.
    #[test]
    fn empty_request_yields_full_coverage() {
        let pool = vec![backend("anthropic", &[Capability::Reasoning])];
        let req = request("req-6", &[]);
        let assignment = route(&req, &pool);
        assert!((assignment.coverage_ratio - 1.0).abs() < f64::EPSILON);
        assert!(assignment.assignments.is_empty());
    }

    // The remote LLM backend also offers AccessControl, but only the local Policy backend
    // satisfies the kind, replay and offline constraints, so it must be chosen.
    #[test]
    fn backend_requirements_select_one_backend_satisfying_role_constraints() {
        let pool = vec![
            backend("remote-llm", &[Capability::AccessControl]),
            backend_with(
                "local-policy",
                BackendKind::Policy,
                &[Capability::AccessControl],
                true,
                false,
            ),
        ];
        let req = ProviderRequest {
            id: "policy-role".to_string(),
            required_capabilities: vec![],
            backend_requirements: Some(
                BackendRequirements::access_policy()
                    .with_replay()
                    .with_offline(),
            ),
        };

        let assignment = route(&req, &pool);

        assert_eq!(assignment.assignments.len(), 1);
        assert_eq!(assignment.assignments[0].backend_name, "local-policy");
        assert!(assignment.unmatched.is_empty());
        assert!((assignment.coverage_ratio - 1.0).abs() < f64::EPSILON);
    }

    // Interchangeable candidates must not make the result vary between runs.
    #[test]
    fn repeated_routing_is_deterministic_for_equal_candidates() {
        let pool = vec![
            backend("reasoner-a", &[Capability::Reasoning]),
            backend("reasoner-b", &[Capability::Reasoning]),
            backend("policy-a", &[Capability::AccessControl]),
        ];
        let req = request(
            "req-7",
            &[
                Capability::Reasoning,
                Capability::Reasoning,
                Capability::AccessControl,
            ],
        );

        let first = route(&req, &pool);
        let second = route(&req, &pool);

        assert_eq!(first.assignments, second.assignments);
        assert_eq!(first.unmatched, second.unmatched);
        assert_eq!(first.coverage_ratio, second.coverage_ratio);
    }

    // A request seed with unparseable content yields exactly one diagnostic and no
    // strategy. Re-running a fresh engine over the resulting context must not add a second
    // diagnostic.
    #[tokio::test]
    async fn malformed_request_emits_diagnostic_once() {
        let mut engine = Engine::new();
        engine.register_suggestor(ProviderSelectionSuggestor::new(vec![backend(
            "anthropic",
            &[Capability::Reasoning],
        )]));

        let mut ctx = ContextState::new();
        ctx.add_input(ContextKey::Seeds, "provider-request:broken", "{")
            .expect("seed should stage");

        let first = engine.run(ctx).await.expect("run should converge");
        let diagnostics = first.context.get(ContextKey::Diagnostic);
        assert_eq!(diagnostics.len(), 1);
        assert_eq!(
            diagnostics[0].id(),
            "provider-request-error:provider-request:broken"
        );
        assert!(!first.context.has(ContextKey::Strategies));

        // Second pass: the existing diagnostic must suppress a duplicate.
        let mut rerun_engine = Engine::new();
        rerun_engine.register_suggestor(ProviderSelectionSuggestor::new(vec![backend(
            "anthropic",
            &[Capability::Reasoning],
        )]));
        let second = rerun_engine
            .run(first.context.clone())
            .await
            .expect("rerun should converge");
        assert_eq!(second.context.get(ContextKey::Diagnostic).len(), 1);
    }
}