Skip to main content

converge_provider/
selection_suggestor.rs

1// Copyright 2024-2026 Reflective Labs
2// SPDX-License-Identifier: MIT
3
4//! Provider selection via bipartite matching.
5//!
6//! Reads a [`ProviderRequest`] from context, matches required capabilities
7//! against the registered backend pool using Hopcroft-Karp, and proposes a
8//! [`ProviderAssignment`] to [`ContextKey::Strategies`].
9
10use std::sync::Arc;
11
12use async_trait::async_trait;
13use converge_optimization::graph::matching::bipartite_matching;
14use converge_pack::{AgentEffect, Context, ContextKey, ProposedFact, Suggestor};
15use converge_provider_api::{
16    Backend, BackendRequirements, CapabilityAssignment, ProviderAssignment, ProviderRequest,
17};
18
19// ── Suggestor ─────────────────────────────────────────────────────────────────
20
/// Id prefix that marks a [`ContextKey::Seeds`] fact as a provider request.
const REQUEST_PREFIX: &str = "provider-request:";
/// Id prefix of the assignment facts proposed to [`ContextKey::Strategies`].
const ASSIGNMENT_PREFIX: &str = "provider-assignment:";
/// Id prefix of the facts proposed to [`ContextKey::Diagnostic`] for requests
/// whose content fails to parse as a [`ProviderRequest`].
const MALFORMED_PREFIX: &str = "provider-request-error:";
24
/// Routes required capabilities to available backends via bipartite matching.
///
/// For every `provider-request:` fact found under [`ContextKey::Seeds`] the
/// suggestor proposes either
///
/// * a `provider-assignment:` fact to [`ContextKey::Strategies`], whose
///   confidence is the assignment's coverage ratio, or
/// * a `provider-request-error:` fact to [`ContextKey::Diagnostic`] when the
///   request content cannot be parsed.
///
/// Requests that carry `backend_requirements` bypass the matching and are
/// served by a single backend that satisfies all of the stated constraints.
///
/// # Construction
///
/// ```rust,ignore
/// let backends: Vec<Arc<dyn Backend>> = vec![
///     Arc::new(AnthropicBackend::from_env()),
///     Arc::new(KongBackend::from_env()),
/// ];
///
/// engine.register_suggestor(ProviderSelectionSuggestor::new(backends));
/// ```
pub struct ProviderSelectionSuggestor {
    /// Backend pool that every request is routed against.
    backends: Vec<Arc<dyn Backend>>,
}
40
impl ProviderSelectionSuggestor {
    /// Creates a suggestor that routes requests against `backends`.
    ///
    /// The pool is stored as given. Its order is significant for requests with
    /// backend requirements: the first backend in the pool that satisfies them
    /// is the one selected.
    pub fn new(backends: Vec<Arc<dyn Backend>>) -> Self {
        Self { backends }
    }
}
46
47#[async_trait]
48impl Suggestor for ProviderSelectionSuggestor {
49    fn name(&self) -> &str {
50        "ProviderSelectionSuggestor"
51    }
52
53    fn dependencies(&self) -> &[ContextKey] {
54        &[ContextKey::Seeds]
55    }
56
57    fn accepts(&self, ctx: &dyn Context) -> bool {
58        ctx.get(ContextKey::Seeds).iter().any(|f| {
59            f.id.starts_with(REQUEST_PREFIX)
60                && match serde_json::from_str::<ProviderRequest>(&f.content) {
61                    Ok(_) => !assignment_exists(ctx, request_id(&f.id)),
62                    Err(_) => !malformed_diagnostic_exists(ctx, &f.id),
63                }
64        })
65    }
66
67    async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
68        let mut proposals = Vec::new();
69
70        for fact in ctx
71            .get(ContextKey::Seeds)
72            .iter()
73            .filter(|f| f.id.starts_with(REQUEST_PREFIX))
74        {
75            match serde_json::from_str::<ProviderRequest>(&fact.content) {
76                Ok(req) => {
77                    if assignment_exists(ctx, request_id(&fact.id)) {
78                        continue;
79                    }
80
81                    let assignment = route(&req, &self.backends);
82                    proposals.push(
83                        ProposedFact::new(
84                            ContextKey::Strategies,
85                            format!("{}{}", ASSIGNMENT_PREFIX, assignment.request_id),
86                            serde_json::to_string(&assignment).unwrap_or_default(),
87                            self.name(),
88                        )
89                        .with_confidence(assignment.coverage_ratio),
90                    );
91                }
92                Err(error) => {
93                    if malformed_diagnostic_exists(ctx, &fact.id) {
94                        continue;
95                    }
96
97                    let diagnostic = serde_json::json!({
98                        "request_fact_id": fact.id,
99                        "message": "malformed provider request ignored",
100                        "error": error.to_string(),
101                    });
102                    proposals.push(
103                        ProposedFact::new(
104                            ContextKey::Diagnostic,
105                            malformed_diagnostic_id(&fact.id),
106                            diagnostic.to_string(),
107                            self.name(),
108                        )
109                        .with_confidence(1.0),
110                    );
111                }
112            }
113        }
114
115        if proposals.is_empty() {
116            AgentEffect::empty()
117        } else {
118            AgentEffect::with_proposals(proposals)
119        }
120    }
121}
122
123// ── Matching logic ────────────────────────────────────────────────────────────
124
125fn route(req: &ProviderRequest, backends: &[Arc<dyn Backend>]) -> ProviderAssignment {
126    if let Some(requirements) = &req.backend_requirements {
127        return route_backend_requirements(req, requirements, backends);
128    }
129
130    // Left = required capability slots (index = position in req.required_capabilities).
131    // Right = backends (index = position in `backends`).
132    // Edge: backends[j].has_capability(req.required_capabilities[i]).
133    let edges: Vec<(usize, usize)> = req
134        .required_capabilities
135        .iter()
136        .enumerate()
137        .flat_map(|(i, cap)| {
138            let cap = cap.clone();
139            backends
140                .iter()
141                .enumerate()
142                .filter(move |(_, b)| b.has_capability(cap.clone()))
143                .map(move |(j, _)| (i, j))
144        })
145        .collect();
146
147    let matching = bipartite_matching(req.required_capabilities.len(), backends.len(), &edges)
148        .unwrap_or_default();
149
150    let mut covered = vec![false; req.required_capabilities.len()];
151    let mut assignments = Vec::with_capacity(matching.size);
152
153    for (cap_idx, backend_idx) in &matching.pairs {
154        assignments.push(CapabilityAssignment {
155            capability: req.required_capabilities[*cap_idx].clone(),
156            backend_name: backends[*backend_idx].name().to_string(),
157        });
158        covered[*cap_idx] = true;
159    }
160
161    let unmatched = req
162        .required_capabilities
163        .iter()
164        .enumerate()
165        .filter(|(i, _)| !covered[*i])
166        .map(|(_, c)| c.clone())
167        .collect::<Vec<_>>();
168
169    let coverage_ratio = if req.required_capabilities.is_empty() {
170        1.0
171    } else {
172        matching.size as f64 / req.required_capabilities.len() as f64
173    };
174
175    ProviderAssignment {
176        request_id: req.id.clone(),
177        assignments,
178        unmatched,
179        coverage_ratio,
180    }
181}
182
183fn route_backend_requirements(
184    req: &ProviderRequest,
185    requirements: &BackendRequirements,
186    backends: &[Arc<dyn Backend>],
187) -> ProviderAssignment {
188    let required_capabilities = if requirements.required_capabilities.is_empty() {
189        req.required_capabilities.clone()
190    } else {
191        requirements.required_capabilities.clone()
192    };
193
194    let matched_backend = backends.iter().find(|backend| {
195        backend.kind() == requirements.kind
196            && required_capabilities
197                .iter()
198                .all(|capability| backend.has_capability(capability.clone()))
199            && (!requirements.requires_replay || backend.supports_replay())
200            && (!requirements.requires_offline || !backend.requires_network())
201    });
202
203    if let Some(backend) = matched_backend {
204        let assignments = required_capabilities
205            .iter()
206            .cloned()
207            .map(|capability| CapabilityAssignment {
208                capability,
209                backend_name: backend.name().to_string(),
210            })
211            .collect::<Vec<_>>();
212        return ProviderAssignment {
213            request_id: req.id.clone(),
214            assignments,
215            unmatched: Vec::new(),
216            coverage_ratio: 1.0,
217        };
218    }
219
220    let coverage_ratio = if required_capabilities.is_empty() {
221        1.0
222    } else {
223        0.0
224    };
225    ProviderAssignment {
226        request_id: req.id.clone(),
227        assignments: Vec::new(),
228        unmatched: required_capabilities,
229        coverage_ratio,
230    }
231}
232
233// ── Helpers ───────────────────────────────────────────────────────────────────
234
235fn request_id(fact_id: &str) -> &str {
236    fact_id.trim_start_matches(REQUEST_PREFIX)
237}
238
239fn assignment_exists(ctx: &dyn Context, request_id: &str) -> bool {
240    let assignment_id = format!("{}{}", ASSIGNMENT_PREFIX, request_id);
241    ctx.get(ContextKey::Strategies)
242        .iter()
243        .any(|f| f.id == assignment_id)
244}
245
246fn malformed_diagnostic_id(fact_id: &str) -> String {
247    format!("{MALFORMED_PREFIX}{fact_id}")
248}
249
250fn malformed_diagnostic_exists(ctx: &dyn Context, fact_id: &str) -> bool {
251    let diagnostic_id = malformed_diagnostic_id(fact_id);
252    ctx.get(ContextKey::Diagnostic)
253        .iter()
254        .any(|fact| fact.id == diagnostic_id)
255}
256
257// ── Tests ─────────────────────────────────────────────────────────────────────
258
#[cfg(test)]
mod tests {
    use super::*;
    use converge_core::{ContextState, Engine};
    use converge_provider_api::{BackendKind, Capability};

    /// Test double for [`Backend`] that answers every query from fixed fields.
    struct MockBackend {
        name: &'static str,
        kind: BackendKind,
        capabilities: Vec<Capability>,
        supports_replay: bool,
        requires_network: bool,
    }

    impl Backend for MockBackend {
        fn name(&self) -> &str {
            self.name
        }
        fn kind(&self) -> BackendKind {
            self.kind.clone()
        }
        fn capabilities(&self) -> Vec<Capability> {
            self.capabilities.clone()
        }
        fn supports_replay(&self) -> bool {
            self.supports_replay
        }
        fn requires_network(&self) -> bool {
            self.requires_network
        }
    }

    /// Mock of kind `Llm` with the given capabilities, no replay support, and
    /// a network requirement.
    fn backend(name: &'static str, caps: &[Capability]) -> Arc<dyn Backend> {
        backend_with(name, BackendKind::Llm, caps, false, true)
    }

    /// Mock backend with every attribute spelled out.
    fn backend_with(
        name: &'static str,
        kind: BackendKind,
        caps: &[Capability],
        supports_replay: bool,
        requires_network: bool,
    ) -> Arc<dyn Backend> {
        Arc::new(MockBackend {
            name,
            kind,
            capabilities: caps.to_vec(),
            supports_replay,
            requires_network,
        })
    }

    /// Request for `caps` without backend requirements, so `route` takes the
    /// matching path.
    fn request(id: &str, caps: &[Capability]) -> ProviderRequest {
        ProviderRequest {
            id: id.to_string(),
            required_capabilities: caps.to_vec(),
            backend_requirements: None,
        }
    }

    /// Three distinct capabilities, each offered by a different backend.
    #[test]
    fn full_coverage_when_all_capabilities_available() {
        let pool = vec![
            backend("anthropic", &[Capability::Reasoning]),
            backend("kong", &[Capability::AccessControl]),
            backend("elastic", &[Capability::FullTextSearch]),
        ];
        let req = request(
            "req-1",
            &[
                Capability::Reasoning,
                Capability::AccessControl,
                Capability::FullTextSearch,
            ],
        );

        let assignment = route(&req, &pool);

        assert_eq!(assignment.assignments.len(), 3);
        assert!(assignment.unmatched.is_empty());
        assert!((assignment.coverage_ratio - 1.0).abs() < f64::EPSILON);
    }

    /// A capability nobody offers is reported as unmatched and lowers coverage.
    #[test]
    fn partial_coverage_when_capability_missing() {
        let pool = vec![backend("anthropic", &[Capability::Reasoning])];
        let req = request("req-2", &[Capability::Reasoning, Capability::AccessControl]);

        let assignment = route(&req, &pool);

        assert_eq!(assignment.assignments.len(), 1);
        assert_eq!(assignment.unmatched, vec![Capability::AccessControl]);
        assert!((assignment.coverage_ratio - 0.5).abs() < f64::EPSILON);
    }

    /// Two slots for the same capability must land on two different backends.
    #[test]
    fn no_double_booking_with_two_same_capability_slots() {
        let pool = vec![
            backend("anthropic", &[Capability::Reasoning]),
            backend("openai", &[Capability::Reasoning]),
        ];
        let req = request("req-3", &[Capability::Reasoning, Capability::Reasoning]);

        let assignment = route(&req, &pool);

        assert_eq!(assignment.assignments.len(), 2);
        let names: Vec<_> = assignment
            .assignments
            .iter()
            .map(|a| &a.backend_name)
            .collect();
        let unique: std::collections::HashSet<_> = names.iter().collect();
        assert_eq!(unique.len(), 2);
    }

    #[test]
    fn multi_capability_backend_can_only_fill_one_slot() {
        // One backend that has both capabilities but should only fill one slot.
        let pool = vec![backend(
            "all-in-one",
            &[Capability::Reasoning, Capability::AccessControl],
        )];
        let req = request("req-4", &[Capability::Reasoning, Capability::AccessControl]);

        let assignment = route(&req, &pool);

        // Only one slot filled — backend can't be double-booked.
        assert_eq!(assignment.assignments.len(), 1);
        assert_eq!(assignment.unmatched.len(), 1);
    }

    /// With no backends at all, every capability is unmatched.
    #[test]
    fn empty_pool_yields_zero_coverage() {
        let req = request("req-5", &[Capability::Reasoning]);
        let assignment = route(&req, &[]);
        assert_eq!(assignment.coverage_ratio, 0.0);
        assert_eq!(assignment.unmatched, vec![Capability::Reasoning]);
    }

    /// A request that asks for nothing counts as fully covered.
    #[test]
    fn empty_request_yields_full_coverage() {
        let pool = vec![backend("anthropic", &[Capability::Reasoning])];
        let req = request("req-6", &[]);
        let assignment = route(&req, &pool);
        assert!((assignment.coverage_ratio - 1.0).abs() < f64::EPSILON);
        assert!(assignment.assignments.is_empty());
    }

    /// Backend requirements select the one backend that satisfies the kind,
    /// replay, and offline constraints, even though another backend in the pool
    /// offers the same capability.
    #[test]
    fn backend_requirements_select_one_backend_satisfying_role_constraints() {
        let pool = vec![
            backend("remote-llm", &[Capability::AccessControl]),
            backend_with(
                "local-policy",
                BackendKind::Policy,
                &[Capability::AccessControl],
                true,
                false,
            ),
        ];
        let req = ProviderRequest {
            id: "policy-role".to_string(),
            required_capabilities: vec![],
            backend_requirements: Some(
                BackendRequirements::access_policy()
                    .with_replay()
                    .with_offline(),
            ),
        };

        let assignment = route(&req, &pool);

        assert_eq!(assignment.assignments.len(), 1);
        assert_eq!(assignment.assignments[0].backend_name, "local-policy");
        assert!(assignment.unmatched.is_empty());
        assert!((assignment.coverage_ratio - 1.0).abs() < f64::EPSILON);
    }

    /// Routing the same request twice against the same pool gives the same
    /// result, even when several backends are interchangeable.
    #[test]
    fn repeated_routing_is_deterministic_for_equal_candidates() {
        let pool = vec![
            backend("reasoner-a", &[Capability::Reasoning]),
            backend("reasoner-b", &[Capability::Reasoning]),
            backend("policy-a", &[Capability::AccessControl]),
        ];
        let req = request(
            "req-7",
            &[
                Capability::Reasoning,
                Capability::Reasoning,
                Capability::AccessControl,
            ],
        );

        let first = route(&req, &pool);
        let second = route(&req, &pool);

        assert_eq!(first.assignments, second.assignments);
        assert_eq!(first.unmatched, second.unmatched);
        assert_eq!(first.coverage_ratio, second.coverage_ratio);
    }

    /// A seed whose content is not valid JSON produces exactly one diagnostic
    /// and no assignment; running a fresh engine on the resulting context does
    /// not add a second diagnostic.
    #[tokio::test]
    async fn malformed_request_emits_diagnostic_once() {
        let mut engine = Engine::new();
        engine.register_suggestor(ProviderSelectionSuggestor::new(vec![backend(
            "anthropic",
            &[Capability::Reasoning],
        )]));

        let mut ctx = ContextState::new();
        ctx.add_input(ContextKey::Seeds, "provider-request:broken", "{")
            .expect("seed should stage");

        let first = engine.run(ctx).await.expect("run should converge");
        let diagnostics = first.context.get(ContextKey::Diagnostic);
        assert_eq!(diagnostics.len(), 1);
        assert_eq!(
            diagnostics[0].id,
            "provider-request-error:provider-request:broken"
        );
        assert!(!first.context.has(ContextKey::Strategies));

        // Second run starts from the context the first run produced.
        let mut rerun_engine = Engine::new();
        rerun_engine.register_suggestor(ProviderSelectionSuggestor::new(vec![backend(
            "anthropic",
            &[Capability::Reasoning],
        )]));
        let second = rerun_engine
            .run(first.context.clone())
            .await
            .expect("rerun should converge");
        assert_eq!(second.context.get(ContextKey::Diagnostic).len(), 1);
    }
}