Skip to main content

converge_kernel/
provider_selection.rs

1// Copyright 2024-2026 Reflective Labs
2// SPDX-License-Identifier: MIT
3
4//! Provider selection via bipartite matching.
5//!
6//! Reads a [`ProviderRequest`] from context, matches required capabilities
7//! against the registered backend pool using Hopcroft-Karp, and proposes a
8//! [`ProviderAssignment`] to [`ContextKey::Strategies`].
9
10use std::sync::Arc;
11
12use async_trait::async_trait;
13use converge_optimization::graph::matching::bipartite_matching;
14use converge_pack::ProvenanceSource;
15use converge_pack::{
16    AgentEffect, Context, ContextKey, DiagnosticPayload, FactPayload, ProposedFact, Suggestor,
17};
18use converge_provider::{
19    Backend, BackendRequirements, CapabilityAssignment, ProviderAssignment, ProviderRequest,
20};
21use serde::{Deserialize, Serialize};
22
23/// Typed fact payload for provider selection requests.
24#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
25#[serde(deny_unknown_fields)]
26pub struct ProviderRequestPayload {
27    request: ProviderRequest,
28}
29
30impl ProviderRequestPayload {
31    #[must_use]
32    pub fn new(request: ProviderRequest) -> Self {
33        Self { request }
34    }
35
36    #[must_use]
37    pub fn request(&self) -> &ProviderRequest {
38        &self.request
39    }
40}
41
42impl From<ProviderRequest> for ProviderRequestPayload {
43    fn from(request: ProviderRequest) -> Self {
44        Self::new(request)
45    }
46}
47
48impl FactPayload for ProviderRequestPayload {
49    const FAMILY: &'static str = "converge.kernel.provider.request";
50    const VERSION: u16 = 1;
51}
52
53/// Typed fact payload for provider selection assignments.
54#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
55#[serde(deny_unknown_fields)]
56pub struct ProviderAssignmentPayload {
57    assignment: ProviderAssignment,
58}
59
60impl ProviderAssignmentPayload {
61    #[must_use]
62    pub fn new(assignment: ProviderAssignment) -> Self {
63        Self { assignment }
64    }
65
66    #[must_use]
67    pub fn assignment(&self) -> &ProviderAssignment {
68        &self.assignment
69    }
70}
71
72impl From<ProviderAssignment> for ProviderAssignmentPayload {
73    fn from(assignment: ProviderAssignment) -> Self {
74        Self::new(assignment)
75    }
76}
77
78impl FactPayload for ProviderAssignmentPayload {
79    const FAMILY: &'static str = "converge.kernel.provider.assignment";
80    const VERSION: u16 = 1;
81}
82
// ── Suggestor ─────────────────────────────────────────────────────────────────

/// Id prefix that marks a seed fact as a provider request.
const REQUEST_PREFIX: &str = "provider-request:";

/// Id prefix of the assignment fact proposed to `ContextKey::Strategies`;
/// the rest of the id is the request key.
const ASSIGNMENT_PREFIX: &str = "provider-assignment:";

/// Id prefix of the diagnostic emitted for a request seed whose payload does
/// not decode; the rest of the id is the full seed fact id.
const MALFORMED_PREFIX: &str = "provider-request-error:";
88
89/// Routes required capabilities to available backends via bipartite matching.
90///
91/// # Construction
92///
93/// ```rust,ignore
94/// let backends: Vec<Arc<dyn Backend>> = product_extension_backends();
95///
96/// engine.register_suggestor(ProviderSelectionSuggestor::new(backends));
97/// ```
98pub struct ProviderSelectionSuggestor {
99    backends: Vec<Arc<dyn Backend>>,
100}
101
102/// Canonical provenance marker for `converge-kernel` fact-emitting
103/// suggestors. Currently used by [`ProviderSelectionSuggestor`].
104#[derive(Copy, Clone, Debug)]
105pub struct ConvergeKernel;
106
107impl ProvenanceSource for ConvergeKernel {
108    fn as_str(&self) -> &'static str {
109        "converge-kernel"
110    }
111}
112
113/// Canonical provenance const for [`ConvergeKernel`].
114pub const CONVERGE_KERNEL_PROVENANCE: ConvergeKernel = ConvergeKernel;
115
116impl ProviderSelectionSuggestor {
117    pub fn new(backends: Vec<Arc<dyn Backend>>) -> Self {
118        Self { backends }
119    }
120}
121
122#[async_trait]
123impl Suggestor for ProviderSelectionSuggestor {
124    fn name(&self) -> &str {
125        "ProviderSelectionSuggestor"
126    }
127
128    fn provenance(&self) -> &'static str {
129        CONVERGE_KERNEL_PROVENANCE.as_str()
130    }
131
132    fn dependencies(&self) -> &[ContextKey] {
133        &[ContextKey::Seeds]
134    }
135
136    fn accepts(&self, ctx: &dyn Context) -> bool {
137        ctx.get(ContextKey::Seeds).iter().any(|f| {
138            f.id().as_str().starts_with(REQUEST_PREFIX)
139                && match f.payload::<ProviderRequestPayload>() {
140                    Some(_) => !assignment_exists(ctx, request_id(f.id().as_str())),
141                    None => !malformed_diagnostic_exists(ctx, f.id().as_str()),
142                }
143        })
144    }
145
146    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
147        let mut proposals = Vec::new();
148
149        for fact in ctx
150            .get(ContextKey::Seeds)
151            .iter()
152            .filter(|f| f.id().as_str().starts_with(REQUEST_PREFIX))
153        {
154            match fact.payload::<ProviderRequestPayload>() {
155                Some(payload) => {
156                    if assignment_exists(ctx, request_id(fact.id().as_str())) {
157                        continue;
158                    }
159
160                    let assignment = route(payload.request(), &self.backends);
161                    proposals.push(
162                        ProposedFact::new(
163                            ContextKey::Strategies,
164                            format!("{}{}", ASSIGNMENT_PREFIX, assignment.request_id),
165                            ProviderAssignmentPayload::new(assignment.clone()),
166                            self.name().to_string(),
167                        )
168                        .with_confidence(assignment.coverage_ratio),
169                    );
170                }
171                None => {
172                    if malformed_diagnostic_exists(ctx, fact.id().as_str()) {
173                        continue;
174                    }
175
176                    proposals.push(
177                        ProposedFact::new(
178                            ContextKey::Diagnostic,
179                            malformed_diagnostic_id(fact.id().as_str()),
180                            DiagnosticPayload::new(
181                                self.name(),
182                                format!(
183                                    "malformed provider request '{}': expected {} v{} payload",
184                                    fact.id(),
185                                    ProviderRequestPayload::FAMILY,
186                                    ProviderRequestPayload::VERSION
187                                ),
188                            ),
189                            self.name().to_string(),
190                        )
191                        .with_confidence(1.0),
192                    );
193                }
194            }
195        }
196
197        if proposals.is_empty() {
198            AgentEffect::empty()
199        } else {
200            AgentEffect::with_proposals(proposals)
201        }
202    }
203}
204
205// ── Matching logic ────────────────────────────────────────────────────────────
206
207fn route(req: &ProviderRequest, backends: &[Arc<dyn Backend>]) -> ProviderAssignment {
208    if let Some(requirements) = &req.backend_requirements {
209        return route_backend_requirements(req, requirements, backends);
210    }
211
212    // Left = required capability slots (index = position in req.required_capabilities).
213    // Right = backends (index = position in `backends`).
214    // Edge: backends[j].has_capability(req.required_capabilities[i]).
215    let edges: Vec<(usize, usize)> = req
216        .required_capabilities
217        .iter()
218        .enumerate()
219        .flat_map(|(i, cap)| {
220            let cap = cap.clone();
221            backends
222                .iter()
223                .enumerate()
224                .filter(move |(_, b)| b.has_capability(cap.clone()))
225                .map(move |(j, _)| (i, j))
226        })
227        .collect();
228
229    let matching = bipartite_matching(req.required_capabilities.len(), backends.len(), &edges)
230        .unwrap_or_default();
231
232    let mut covered = vec![false; req.required_capabilities.len()];
233    let mut assignments = Vec::with_capacity(matching.size);
234
235    for (cap_idx, backend_idx) in &matching.pairs {
236        assignments.push(CapabilityAssignment {
237            capability: req.required_capabilities[*cap_idx].clone(),
238            backend_name: backends[*backend_idx].name().to_string(),
239        });
240        covered[*cap_idx] = true;
241    }
242
243    let unmatched = req
244        .required_capabilities
245        .iter()
246        .enumerate()
247        .filter(|(i, _)| !covered[*i])
248        .map(|(_, c)| c.clone())
249        .collect::<Vec<_>>();
250
251    let coverage_ratio = if req.required_capabilities.is_empty() {
252        1.0
253    } else {
254        matching.size as f64 / req.required_capabilities.len() as f64
255    };
256
257    ProviderAssignment {
258        request_id: req.id.clone(),
259        assignments,
260        unmatched,
261        coverage_ratio,
262    }
263}
264
265fn route_backend_requirements(
266    req: &ProviderRequest,
267    requirements: &BackendRequirements,
268    backends: &[Arc<dyn Backend>],
269) -> ProviderAssignment {
270    let required_capabilities = if requirements.required_capabilities.is_empty() {
271        req.required_capabilities.clone()
272    } else {
273        requirements.required_capabilities.clone()
274    };
275
276    let matched_backend = backends.iter().find(|backend| {
277        backend.kind() == requirements.kind
278            && required_capabilities
279                .iter()
280                .all(|capability| backend.has_capability(capability.clone()))
281            && (!requirements.requires_replay || backend.supports_replay())
282            && (!requirements.requires_offline || !backend.requires_network())
283    });
284
285    if let Some(backend) = matched_backend {
286        let assignments = required_capabilities
287            .iter()
288            .cloned()
289            .map(|capability| CapabilityAssignment {
290                capability,
291                backend_name: backend.name().to_string(),
292            })
293            .collect::<Vec<_>>();
294        return ProviderAssignment {
295            request_id: req.id.clone(),
296            assignments,
297            unmatched: Vec::new(),
298            coverage_ratio: 1.0,
299        };
300    }
301
302    let coverage_ratio = if required_capabilities.is_empty() {
303        1.0
304    } else {
305        0.0
306    };
307    ProviderAssignment {
308        request_id: req.id.clone(),
309        assignments: Vec::new(),
310        unmatched: required_capabilities,
311        coverage_ratio,
312    }
313}
314
315// ── Helpers ───────────────────────────────────────────────────────────────────
316
317fn request_id(fact_id: &str) -> &str {
318    fact_id.trim_start_matches(REQUEST_PREFIX)
319}
320
321fn assignment_exists(ctx: &dyn Context, request_id: &str) -> bool {
322    let assignment_id = format!("{}{}", ASSIGNMENT_PREFIX, request_id);
323    ctx.get(ContextKey::Strategies)
324        .iter()
325        .any(|f| f.id().as_str() == assignment_id)
326}
327
328fn malformed_diagnostic_id(fact_id: &str) -> String {
329    format!("{MALFORMED_PREFIX}{fact_id}")
330}
331
332fn malformed_diagnostic_exists(ctx: &dyn Context, fact_id: &str) -> bool {
333    let diagnostic_id = malformed_diagnostic_id(fact_id);
334    ctx.get(ContextKey::Diagnostic)
335        .iter()
336        .any(|fact| fact.id().as_str() == diagnostic_id)
337}
338
339// ── Tests ─────────────────────────────────────────────────────────────────────
340
#[cfg(test)]
mod tests {
    // Unit tests for `route` (capability matching and backend-requirements
    // routing) plus one engine-level test for malformed request handling.
    use super::*;
    use converge_core::{ContextState, Engine};
    use converge_pack::TextPayload;
    use converge_provider::{BackendKind, Capability};

    /// Test double for [`Backend`] whose answers are read straight from its
    /// fields.
    struct MockBackend {
        name: &'static str,
        kind: BackendKind,
        capabilities: Vec<Capability>,
        supports_replay: bool,
        requires_network: bool,
    }

    impl Backend for MockBackend {
        fn name(&self) -> &str {
            self.name
        }
        fn kind(&self) -> BackendKind {
            self.kind.clone()
        }
        fn capabilities(&self) -> Vec<Capability> {
            self.capabilities.clone()
        }
        fn supports_replay(&self) -> bool {
            self.supports_replay
        }
        fn requires_network(&self) -> bool {
            self.requires_network
        }
    }

    /// `BackendKind::Llm` backend with no replay support that requires
    /// network access, exposing `caps`.
    fn backend(name: &'static str, caps: &[Capability]) -> Arc<dyn Backend> {
        backend_with(name, BackendKind::Llm, caps, false, true)
    }

    /// Fully configurable mock backend.
    fn backend_with(
        name: &'static str,
        kind: BackendKind,
        caps: &[Capability],
        supports_replay: bool,
        requires_network: bool,
    ) -> Arc<dyn Backend> {
        Arc::new(MockBackend {
            name,
            kind,
            capabilities: caps.to_vec(),
            supports_replay,
            requires_network,
        })
    }

    /// Capability-only request (no `backend_requirements`), so `route` uses
    /// bipartite matching.
    fn request(id: &str, caps: &[Capability]) -> ProviderRequest {
        ProviderRequest {
            id: id.to_string(),
            required_capabilities: caps.to_vec(),
            backend_requirements: None,
        }
    }

    /// One backend per capability: every slot is filled.
    #[test]
    fn full_coverage_when_all_capabilities_available() {
        let pool = vec![
            backend("anthropic", &[Capability::Reasoning]),
            backend("kong", &[Capability::AccessControl]),
            backend("elastic", &[Capability::FullTextSearch]),
        ];
        let req = request(
            "req-1",
            &[
                Capability::Reasoning,
                Capability::AccessControl,
                Capability::FullTextSearch,
            ],
        );

        let assignment = route(&req, &pool);

        assert_eq!(assignment.assignments.len(), 3);
        assert!(assignment.unmatched.is_empty());
        assert!((assignment.coverage_ratio - 1.0).abs() < f64::EPSILON);
    }

    /// A capability no backend offers is reported unmatched; coverage 1/2.
    #[test]
    fn partial_coverage_when_capability_missing() {
        let pool = vec![backend("anthropic", &[Capability::Reasoning])];
        let req = request("req-2", &[Capability::Reasoning, Capability::AccessControl]);

        let assignment = route(&req, &pool);

        assert_eq!(assignment.assignments.len(), 1);
        assert_eq!(assignment.unmatched, vec![Capability::AccessControl]);
        assert!((assignment.coverage_ratio - 0.5).abs() < f64::EPSILON);
    }

    /// Two slots asking for the same capability land on two distinct backends.
    #[test]
    fn no_double_booking_with_two_same_capability_slots() {
        let pool = vec![
            backend("anthropic", &[Capability::Reasoning]),
            backend("openai", &[Capability::Reasoning]),
        ];
        let req = request("req-3", &[Capability::Reasoning, Capability::Reasoning]);

        let assignment = route(&req, &pool);

        assert_eq!(assignment.assignments.len(), 2);
        let names: Vec<_> = assignment
            .assignments
            .iter()
            .map(|a| &a.backend_name)
            .collect();
        let unique: std::collections::HashSet<_> = names.iter().collect();
        assert_eq!(unique.len(), 2);
    }

    /// A single backend with both capabilities still fills only one slot.
    #[test]
    fn multi_capability_backend_can_only_fill_one_slot() {
        // One backend that has both capabilities but should only fill one slot.
        let pool = vec![backend(
            "all-in-one",
            &[Capability::Reasoning, Capability::AccessControl],
        )];
        let req = request("req-4", &[Capability::Reasoning, Capability::AccessControl]);

        let assignment = route(&req, &pool);

        // Only one slot filled — backend can't be double-booked.
        assert_eq!(assignment.assignments.len(), 1);
        assert_eq!(assignment.unmatched.len(), 1);
    }

    /// No backends: nothing is matched and coverage is zero.
    #[test]
    fn empty_pool_yields_zero_coverage() {
        let req = request("req-5", &[Capability::Reasoning]);
        let assignment = route(&req, &[]);
        assert_eq!(assignment.coverage_ratio, 0.0);
        assert_eq!(assignment.unmatched, vec![Capability::Reasoning]);
    }

    /// A request with no required capabilities counts as fully covered.
    #[test]
    fn empty_request_yields_full_coverage() {
        let pool = vec![backend("anthropic", &[Capability::Reasoning])];
        let req = request("req-6", &[]);
        let assignment = route(&req, &pool);
        assert!((assignment.coverage_ratio - 1.0).abs() < f64::EPSILON);
        assert!(assignment.assignments.is_empty());
    }

    /// `backend_requirements` routing skips `remote-llm` (wrong kind, needs
    /// network) and picks `local-policy`.
    #[test]
    fn backend_requirements_select_one_backend_satisfying_role_constraints() {
        let pool = vec![
            backend("remote-llm", &[Capability::AccessControl]),
            backend_with(
                "local-policy",
                BackendKind::Policy,
                &[Capability::AccessControl],
                true,
                false,
            ),
        ];
        // `required_capabilities` is empty here, so the capability list comes
        // from `access_policy()`. It presumably requires `BackendKind::Policy`
        // and `AccessControl`; the single expected assignment relies on it
        // listing exactly one capability. TODO confirm against converge_provider.
        let req = ProviderRequest {
            id: "policy-role".to_string(),
            required_capabilities: vec![],
            backend_requirements: Some(
                BackendRequirements::access_policy()
                    .with_replay()
                    .with_offline(),
            ),
        };

        let assignment = route(&req, &pool);

        assert_eq!(assignment.assignments.len(), 1);
        assert_eq!(assignment.assignments[0].backend_name, "local-policy");
        assert!(assignment.unmatched.is_empty());
        assert!((assignment.coverage_ratio - 1.0).abs() < f64::EPSILON);
    }

    /// Routing the same request twice over the same pool gives identical results.
    #[test]
    fn repeated_routing_is_deterministic_for_equal_candidates() {
        let pool = vec![
            backend("reasoner-a", &[Capability::Reasoning]),
            backend("reasoner-b", &[Capability::Reasoning]),
            backend("policy-a", &[Capability::AccessControl]),
        ];
        let req = request(
            "req-7",
            &[
                Capability::Reasoning,
                Capability::Reasoning,
                Capability::AccessControl,
            ],
        );

        let first = route(&req, &pool);
        let second = route(&req, &pool);

        assert_eq!(first.assignments, second.assignments);
        assert_eq!(first.unmatched, second.unmatched);
        assert_eq!(first.coverage_ratio, second.coverage_ratio);
    }

    /// A request seed with a non-request payload yields exactly one diagnostic
    /// and no assignment, and re-running over the converged context does not
    /// add a second diagnostic.
    #[tokio::test]
    async fn malformed_request_emits_diagnostic_once() {
        let mut engine = Engine::new();
        engine.register_suggestor(ProviderSelectionSuggestor::new(vec![backend(
            "anthropic",
            &[Capability::Reasoning],
        )]));

        let mut ctx = ContextState::new();
        // Correct `provider-request:` prefix, but a text payload that cannot
        // decode as `ProviderRequestPayload`.
        ctx.add_proposal(ProposedFact::new(
            ContextKey::Seeds,
            "provider-request:broken",
            TextPayload::new("not a provider request"),
            "test",
        ))
        .expect("seed should stage");

        let first = engine.run(ctx).await.expect("run should converge");
        let diagnostics = first.context.get(ContextKey::Diagnostic);
        assert_eq!(diagnostics.len(), 1);
        assert_eq!(
            diagnostics[0].id(),
            "provider-request-error:provider-request:broken"
        );
        assert!(!first.context.has(ContextKey::Strategies));

        // Fresh engine over the already-converged context: the existing
        // diagnostic must suppress a duplicate.
        let mut rerun_engine = Engine::new();
        rerun_engine.register_suggestor(ProviderSelectionSuggestor::new(vec![backend(
            "anthropic",
            &[Capability::Reasoning],
        )]));
        let second = rerun_engine
            .run(first.context.clone())
            .await
            .expect("rerun should converge");
        assert_eq!(second.context.get(ContextKey::Diagnostic).len(), 1);
    }
}