Skip to main content

converge_kernel/
provider_selection.rs

1// Copyright 2024-2026 Reflective Labs
2// SPDX-License-Identifier: MIT
3
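//! Provider selection via bipartite matching.
//!
//! Reads [`ProviderRequestPayload`] facts from [`ContextKey::Seeds`], matches
//! the required capabilities against the registered backend pool using
//! Hopcroft-Karp, and proposes a [`ProviderAssignment`] to
//! [`ContextKey::Strategies`]. Requests that carry backend requirements are
//! routed to a single backend satisfying all of them instead. Seed facts that
//! use the request id prefix but do not hold a valid request payload are
//! reported to [`ContextKey::Diagnostic`].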
9
10use std::sync::Arc;
11
12use async_trait::async_trait;
13use converge_optimization::graph::matching::bipartite_matching;
14use converge_pack::Provenance;
15use converge_pack::ProvenanceSource;
16use converge_pack::{
17    AgentEffect, Context, ContextKey, DiagnosticPayload, FactPayload, ProposedFact, Suggestor,
18};
19use converge_provider::{
20    Backend, BackendRequirements, CapabilityAssignment, ProviderAssignment, ProviderRequest,
21};
22use serde::{Deserialize, Serialize};
23
24/// Typed fact payload for provider selection requests.
25#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
26#[serde(deny_unknown_fields)]
27pub struct ProviderRequestPayload {
28    request: ProviderRequest,
29}
30
31impl ProviderRequestPayload {
32    #[must_use]
33    pub fn new(request: ProviderRequest) -> Self {
34        Self { request }
35    }
36
37    #[must_use]
38    pub fn request(&self) -> &ProviderRequest {
39        &self.request
40    }
41}
42
43impl From<ProviderRequest> for ProviderRequestPayload {
44    fn from(request: ProviderRequest) -> Self {
45        Self::new(request)
46    }
47}
48
49impl FactPayload for ProviderRequestPayload {
50    const FAMILY: &'static str = "converge.kernel.provider.request";
51    const VERSION: u16 = 1;
52}
53
54/// Typed fact payload for provider selection assignments.
55#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
56#[serde(deny_unknown_fields)]
57pub struct ProviderAssignmentPayload {
58    assignment: ProviderAssignment,
59}
60
61impl ProviderAssignmentPayload {
62    #[must_use]
63    pub fn new(assignment: ProviderAssignment) -> Self {
64        Self { assignment }
65    }
66
67    #[must_use]
68    pub fn assignment(&self) -> &ProviderAssignment {
69        &self.assignment
70    }
71}
72
73impl From<ProviderAssignment> for ProviderAssignmentPayload {
74    fn from(assignment: ProviderAssignment) -> Self {
75        Self::new(assignment)
76    }
77}
78
79impl FactPayload for ProviderAssignmentPayload {
80    const FAMILY: &'static str = "converge.kernel.provider.assignment";
81    const VERSION: u16 = 1;
82}
83
84// ── Suggestor ─────────────────────────────────────────────────────────────────
85
/// Id prefix that marks a seed fact as a provider selection request.
const REQUEST_PREFIX: &str = "provider-request:";
/// Id prefix of the assignment facts proposed to `ContextKey::Strategies`.
const ASSIGNMENT_PREFIX: &str = "provider-assignment:";
/// Id prefix of malformed-request diagnostics; the rest of the id is the
/// complete id of the offending seed fact.
const MALFORMED_PREFIX: &str = "provider-request-error:";
89
90/// Routes required capabilities to available backends via bipartite matching.
91///
92/// # Construction
93///
94/// ```rust,ignore
95/// let backends: Vec<Arc<dyn Backend>> = product_extension_backends();
96///
97/// engine.register_suggestor(ProviderSelectionSuggestor::new(backends));
98/// ```
99pub struct ProviderSelectionSuggestor {
100    backends: Vec<Arc<dyn Backend>>,
101}
102
/// Provenance marker shared by the fact-emitting suggestors of
/// `converge-kernel`. At present [`ProviderSelectionSuggestor`] is its only
/// user.
#[derive(Clone, Copy, Debug)]
pub struct ConvergeKernel;
107
108impl ProvenanceSource for ConvergeKernel {
109    fn as_str(&self) -> &'static str {
110        "converge-kernel"
111    }
112}
113
114/// Canonical provenance const for [`ConvergeKernel`].
115pub const CONVERGE_KERNEL_PROVENANCE: ConvergeKernel = ConvergeKernel;
116
117impl ProviderSelectionSuggestor {
118    pub fn new(backends: Vec<Arc<dyn Backend>>) -> Self {
119        Self { backends }
120    }
121}
122
123#[async_trait]
124impl Suggestor for ProviderSelectionSuggestor {
125    fn name(&self) -> &str {
126        "ProviderSelectionSuggestor"
127    }
128
129    fn provenance(&self) -> Provenance {
130        CONVERGE_KERNEL_PROVENANCE.provenance()
131    }
132
133    fn dependencies(&self) -> &[ContextKey] {
134        &[ContextKey::Seeds]
135    }
136
137    fn accepts(&self, ctx: &dyn Context) -> bool {
138        ctx.get(ContextKey::Seeds).iter().any(|f| {
139            f.id().as_str().starts_with(REQUEST_PREFIX)
140                && match f.payload::<ProviderRequestPayload>() {
141                    Some(_) => !assignment_exists(ctx, request_id(f.id().as_str())),
142                    None => !malformed_diagnostic_exists(ctx, f.id().as_str()),
143                }
144        })
145    }
146
147    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
148        let mut proposals = Vec::new();
149
150        for fact in ctx
151            .get(ContextKey::Seeds)
152            .iter()
153            .filter(|f| f.id().as_str().starts_with(REQUEST_PREFIX))
154        {
155            match fact.payload::<ProviderRequestPayload>() {
156                Some(payload) => {
157                    if assignment_exists(ctx, request_id(fact.id().as_str())) {
158                        continue;
159                    }
160
161                    let assignment = route(payload.request(), &self.backends);
162                    proposals.push(
163                        ProposedFact::new(
164                            ContextKey::Strategies,
165                            format!("{}{}", ASSIGNMENT_PREFIX, assignment.request_id),
166                            ProviderAssignmentPayload::new(assignment.clone()),
167                            self.provenance(),
168                        )
169                        .with_confidence(assignment.coverage_ratio),
170                    );
171                }
172                None => {
173                    if malformed_diagnostic_exists(ctx, fact.id().as_str()) {
174                        continue;
175                    }
176
177                    proposals.push(
178                        ProposedFact::new(
179                            ContextKey::Diagnostic,
180                            malformed_diagnostic_id(fact.id().as_str()),
181                            DiagnosticPayload::new(
182                                self.name(),
183                                format!(
184                                    "malformed provider request '{}': expected {} v{} payload",
185                                    fact.id(),
186                                    ProviderRequestPayload::FAMILY,
187                                    ProviderRequestPayload::VERSION
188                                ),
189                            ),
190                            self.provenance(),
191                        )
192                        .with_confidence(1.0),
193                    );
194                }
195            }
196        }
197
198        if proposals.is_empty() {
199            AgentEffect::empty()
200        } else {
201            AgentEffect::with_proposals(proposals)
202        }
203    }
204}
205
206// ── Matching logic ────────────────────────────────────────────────────────────
207
208fn route(req: &ProviderRequest, backends: &[Arc<dyn Backend>]) -> ProviderAssignment {
209    if let Some(requirements) = &req.backend_requirements {
210        return route_backend_requirements(req, requirements, backends);
211    }
212
213    // Left = required capability slots (index = position in req.required_capabilities).
214    // Right = backends (index = position in `backends`).
215    // Edge: backends[j].has_capability(req.required_capabilities[i]).
216    let edges: Vec<(usize, usize)> = req
217        .required_capabilities
218        .iter()
219        .enumerate()
220        .flat_map(|(i, cap)| {
221            let cap = cap.clone();
222            backends
223                .iter()
224                .enumerate()
225                .filter(move |(_, b)| b.has_capability(cap.clone()))
226                .map(move |(j, _)| (i, j))
227        })
228        .collect();
229
230    let matching = bipartite_matching(req.required_capabilities.len(), backends.len(), &edges)
231        .unwrap_or_default();
232
233    let mut covered = vec![false; req.required_capabilities.len()];
234    let mut assignments = Vec::with_capacity(matching.size);
235
236    for (cap_idx, backend_idx) in &matching.pairs {
237        assignments.push(CapabilityAssignment {
238            capability: req.required_capabilities[*cap_idx].clone(),
239            backend_name: backends[*backend_idx].name().to_string(),
240        });
241        covered[*cap_idx] = true;
242    }
243
244    let unmatched = req
245        .required_capabilities
246        .iter()
247        .enumerate()
248        .filter(|(i, _)| !covered[*i])
249        .map(|(_, c)| c.clone())
250        .collect::<Vec<_>>();
251
252    let coverage_ratio = if req.required_capabilities.is_empty() {
253        1.0
254    } else {
255        matching.size as f64 / req.required_capabilities.len() as f64
256    };
257
258    ProviderAssignment {
259        request_id: req.id.clone(),
260        assignments,
261        unmatched,
262        coverage_ratio,
263    }
264}
265
266fn route_backend_requirements(
267    req: &ProviderRequest,
268    requirements: &BackendRequirements,
269    backends: &[Arc<dyn Backend>],
270) -> ProviderAssignment {
271    let required_capabilities = if requirements.required_capabilities.is_empty() {
272        req.required_capabilities.clone()
273    } else {
274        requirements.required_capabilities.clone()
275    };
276
277    let matched_backend = backends.iter().find(|backend| {
278        backend.kind() == requirements.kind
279            && required_capabilities
280                .iter()
281                .all(|capability| backend.has_capability(capability.clone()))
282            && (!requirements.requires_replay || backend.supports_replay())
283            && (!requirements.requires_offline || !backend.requires_network())
284    });
285
286    if let Some(backend) = matched_backend {
287        let assignments = required_capabilities
288            .iter()
289            .cloned()
290            .map(|capability| CapabilityAssignment {
291                capability,
292                backend_name: backend.name().to_string(),
293            })
294            .collect::<Vec<_>>();
295        return ProviderAssignment {
296            request_id: req.id.clone(),
297            assignments,
298            unmatched: Vec::new(),
299            coverage_ratio: 1.0,
300        };
301    }
302
303    let coverage_ratio = if required_capabilities.is_empty() {
304        1.0
305    } else {
306        0.0
307    };
308    ProviderAssignment {
309        request_id: req.id.clone(),
310        assignments: Vec::new(),
311        unmatched: required_capabilities,
312        coverage_ratio,
313    }
314}
315
316// ── Helpers ───────────────────────────────────────────────────────────────────
317
318fn request_id(fact_id: &str) -> &str {
319    fact_id.trim_start_matches(REQUEST_PREFIX)
320}
321
322fn assignment_exists(ctx: &dyn Context, request_id: &str) -> bool {
323    let assignment_id = format!("{}{}", ASSIGNMENT_PREFIX, request_id);
324    ctx.get(ContextKey::Strategies)
325        .iter()
326        .any(|f| f.id().as_str() == assignment_id)
327}
328
329fn malformed_diagnostic_id(fact_id: &str) -> String {
330    format!("{MALFORMED_PREFIX}{fact_id}")
331}
332
333fn malformed_diagnostic_exists(ctx: &dyn Context, fact_id: &str) -> bool {
334    let diagnostic_id = malformed_diagnostic_id(fact_id);
335    ctx.get(ContextKey::Diagnostic)
336        .iter()
337        .any(|fact| fact.id().as_str() == diagnostic_id)
338}
339
340// ── Tests ─────────────────────────────────────────────────────────────────────
341
#[cfg(test)]
mod tests {
    use super::*;
    use converge_core::{ContextState, Engine};
    use converge_pack::TextPayload;
    use converge_provider::{BackendKind, Capability};

    /// In-memory [`Backend`] whose answers are fixed at construction.
    struct MockBackend {
        name: &'static str,
        kind: BackendKind,
        capabilities: Vec<Capability>,
        supports_replay: bool,
        requires_network: bool,
    }

    impl Backend for MockBackend {
        fn name(&self) -> &str {
            self.name
        }

        fn kind(&self) -> BackendKind {
            self.kind.clone()
        }

        fn capabilities(&self) -> Vec<Capability> {
            self.capabilities.clone()
        }

        fn supports_replay(&self) -> bool {
            self.supports_replay
        }

        fn requires_network(&self) -> bool {
            self.requires_network
        }
    }

    /// Builds a mock backend with every property spelled out.
    fn backend_with(
        name: &'static str,
        kind: BackendKind,
        caps: &[Capability],
        supports_replay: bool,
        requires_network: bool,
    ) -> Arc<dyn Backend> {
        let mock = MockBackend {
            name,
            kind,
            capabilities: caps.to_vec(),
            supports_replay,
            requires_network,
        };
        Arc::new(mock)
    }

    /// Builds a networked LLM backend without replay support.
    fn backend(name: &'static str, caps: &[Capability]) -> Arc<dyn Backend> {
        backend_with(name, BackendKind::Llm, caps, false, true)
    }

    /// Builds a capability-only request (no backend requirements).
    fn request(id: &str, caps: &[Capability]) -> ProviderRequest {
        ProviderRequest {
            id: id.to_string(),
            required_capabilities: caps.to_vec(),
            backend_requirements: None,
        }
    }

    /// Asserts that `actual` is within `f64::EPSILON` of `expected`.
    fn assert_ratio(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < f64::EPSILON);
    }

    /// Suggestor backed by a single reasoning-capable backend.
    fn reasoning_suggestor() -> ProviderSelectionSuggestor {
        ProviderSelectionSuggestor::new(vec![backend("anthropic", &[Capability::Reasoning])])
    }

    #[test]
    fn full_coverage_when_all_capabilities_available() {
        let wanted = [
            Capability::Reasoning,
            Capability::AccessControl,
            Capability::FullTextSearch,
        ];
        let pool = vec![
            backend("anthropic", &[Capability::Reasoning]),
            backend("kong", &[Capability::AccessControl]),
            backend("elastic", &[Capability::FullTextSearch]),
        ];

        let result = route(&request("req-1", &wanted), &pool);

        assert_eq!(result.assignments.len(), 3);
        assert!(result.unmatched.is_empty());
        assert_ratio(result.coverage_ratio, 1.0);
    }

    #[test]
    fn partial_coverage_when_capability_missing() {
        let pool = vec![backend("anthropic", &[Capability::Reasoning])];
        let req = request("req-2", &[Capability::Reasoning, Capability::AccessControl]);

        let result = route(&req, &pool);

        assert_eq!(result.assignments.len(), 1);
        assert_eq!(result.unmatched, vec![Capability::AccessControl]);
        assert_ratio(result.coverage_ratio, 0.5);
    }

    #[test]
    fn no_double_booking_with_two_same_capability_slots() {
        let pool = vec![
            backend("anthropic", &[Capability::Reasoning]),
            backend("openai", &[Capability::Reasoning]),
        ];
        let req = request("req-3", &[Capability::Reasoning, Capability::Reasoning]);

        let result = route(&req, &pool);

        assert_eq!(result.assignments.len(), 2);
        let distinct_backends: std::collections::HashSet<&str> = result
            .assignments
            .iter()
            .map(|a| a.backend_name.as_str())
            .collect();
        assert_eq!(distinct_backends.len(), 2);
    }

    #[test]
    fn multi_capability_backend_can_only_fill_one_slot() {
        // The single backend offers both capabilities, yet it may serve only
        // one of the two slots.
        let both = [Capability::Reasoning, Capability::AccessControl];
        let pool = vec![backend("all-in-one", &both)];

        let result = route(&request("req-4", &both), &pool);

        // One slot filled, one left over: no double booking.
        assert_eq!(result.assignments.len(), 1);
        assert_eq!(result.unmatched.len(), 1);
    }

    #[test]
    fn empty_pool_yields_zero_coverage() {
        let req = request("req-5", &[Capability::Reasoning]);

        let result = route(&req, &[]);

        assert_eq!(result.coverage_ratio, 0.0);
        assert_eq!(result.unmatched, vec![Capability::Reasoning]);
    }

    #[test]
    fn empty_request_yields_full_coverage() {
        let pool = vec![backend("anthropic", &[Capability::Reasoning])];

        let result = route(&request("req-6", &[]), &pool);

        assert_ratio(result.coverage_ratio, 1.0);
        assert!(result.assignments.is_empty());
    }

    #[test]
    fn backend_requirements_select_one_backend_satisfying_role_constraints() {
        let remote = backend("remote-llm", &[Capability::AccessControl]);
        let local = backend_with(
            "local-policy",
            BackendKind::Policy,
            &[Capability::AccessControl],
            true,
            false,
        );
        let pool = vec![remote, local];
        let requirements = BackendRequirements::access_policy()
            .with_replay()
            .with_offline();
        let req = ProviderRequest {
            id: "policy-role".to_string(),
            required_capabilities: vec![],
            backend_requirements: Some(requirements),
        };

        let result = route(&req, &pool);

        assert_eq!(result.assignments.len(), 1);
        assert_eq!(result.assignments[0].backend_name, "local-policy");
        assert!(result.unmatched.is_empty());
        assert_ratio(result.coverage_ratio, 1.0);
    }

    #[test]
    fn repeated_routing_is_deterministic_for_equal_candidates() {
        let pool = vec![
            backend("reasoner-a", &[Capability::Reasoning]),
            backend("reasoner-b", &[Capability::Reasoning]),
            backend("policy-a", &[Capability::AccessControl]),
        ];
        let req = request(
            "req-7",
            &[
                Capability::Reasoning,
                Capability::Reasoning,
                Capability::AccessControl,
            ],
        );

        let first = route(&req, &pool);
        let second = route(&req, &pool);

        assert_eq!(first.assignments, second.assignments);
        assert_eq!(first.unmatched, second.unmatched);
        assert_eq!(first.coverage_ratio, second.coverage_ratio);
    }

    #[tokio::test]
    async fn malformed_request_emits_diagnostic_once() {
        let mut engine = Engine::new();
        engine.register_suggestor(reasoning_suggestor());

        // A seed with the request prefix but a payload of the wrong type.
        let broken_seed = ProposedFact::new(
            ContextKey::Seeds,
            "provider-request:broken",
            TextPayload::new("not a provider request"),
            CONVERGE_KERNEL_PROVENANCE.provenance(),
        );
        let mut ctx = ContextState::new();
        ctx.add_proposal(broken_seed).expect("seed should stage");

        let first = engine.run(ctx).await.expect("run should converge");
        let diagnostics = first.context.get(ContextKey::Diagnostic);
        assert_eq!(diagnostics.len(), 1);
        assert_eq!(
            diagnostics[0].id(),
            "provider-request-error:provider-request:broken"
        );
        assert!(!first.context.has(ContextKey::Strategies));

        // A fresh engine over the resulting context must not add a second
        // diagnostic.
        let mut rerun_engine = Engine::new();
        rerun_engine.register_suggestor(reasoning_suggestor());
        let second = rerun_engine
            .run(first.context.clone())
            .await
            .expect("rerun should converge");
        assert_eq!(second.context.get(ContextKey::Diagnostic).len(), 1);
    }
}