1use std::sync::Arc;
11
12use async_trait::async_trait;
13use converge_optimization::graph::matching::bipartite_matching;
14use converge_pack::{AgentEffect, Context, ContextKey, ProposedFact, Suggestor};
15use converge_provider_api::{
16 Backend, BackendRequirements, CapabilityAssignment, ProviderAssignment, ProviderRequest,
17};
18
/// Fact-id prefix that marks a [`ContextKey::Seeds`] fact as a provider request.
/// The fact's content is parsed as a JSON-encoded [`ProviderRequest`].
const REQUEST_PREFIX: &str = "provider-request:";
/// Fact-id prefix of the assignment facts this suggestor proposes under
/// [`ContextKey::Strategies`].
const ASSIGNMENT_PREFIX: &str = "provider-assignment:";
/// Fact-id prefix of the diagnostics proposed under [`ContextKey::Diagnostic`] when a
/// request seed's content fails to parse. The full seed fact id follows this prefix.
const MALFORMED_PREFIX: &str = "provider-request-error:";
24
25pub struct ProviderSelectionSuggestor {
38 backends: Vec<Arc<dyn Backend>>,
39}
40
41impl ProviderSelectionSuggestor {
42 pub fn new(backends: Vec<Arc<dyn Backend>>) -> Self {
43 Self { backends }
44 }
45}
46
47#[async_trait]
48impl Suggestor for ProviderSelectionSuggestor {
49 fn name(&self) -> &str {
50 "ProviderSelectionSuggestor"
51 }
52
53 fn dependencies(&self) -> &[ContextKey] {
54 &[ContextKey::Seeds]
55 }
56
57 fn accepts(&self, ctx: &dyn Context) -> bool {
58 ctx.get(ContextKey::Seeds).iter().any(|f| {
59 f.id.starts_with(REQUEST_PREFIX)
60 && match serde_json::from_str::<ProviderRequest>(&f.content) {
61 Ok(_) => !assignment_exists(ctx, request_id(&f.id)),
62 Err(_) => !malformed_diagnostic_exists(ctx, &f.id),
63 }
64 })
65 }
66
67 async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
68 let mut proposals = Vec::new();
69
70 for fact in ctx
71 .get(ContextKey::Seeds)
72 .iter()
73 .filter(|f| f.id.starts_with(REQUEST_PREFIX))
74 {
75 match serde_json::from_str::<ProviderRequest>(&fact.content) {
76 Ok(req) => {
77 if assignment_exists(ctx, request_id(&fact.id)) {
78 continue;
79 }
80
81 let assignment = route(&req, &self.backends);
82 proposals.push(
83 ProposedFact::new(
84 ContextKey::Strategies,
85 format!("{}{}", ASSIGNMENT_PREFIX, assignment.request_id),
86 serde_json::to_string(&assignment).unwrap_or_default(),
87 self.name(),
88 )
89 .with_confidence(assignment.coverage_ratio),
90 );
91 }
92 Err(error) => {
93 if malformed_diagnostic_exists(ctx, &fact.id) {
94 continue;
95 }
96
97 let diagnostic = serde_json::json!({
98 "request_fact_id": fact.id,
99 "message": "malformed provider request ignored",
100 "error": error.to_string(),
101 });
102 proposals.push(
103 ProposedFact::new(
104 ContextKey::Diagnostic,
105 malformed_diagnostic_id(&fact.id),
106 diagnostic.to_string(),
107 self.name(),
108 )
109 .with_confidence(1.0),
110 );
111 }
112 }
113 }
114
115 if proposals.is_empty() {
116 AgentEffect::empty()
117 } else {
118 AgentEffect::with_proposals(proposals)
119 }
120 }
121}
122
123fn route(req: &ProviderRequest, backends: &[Arc<dyn Backend>]) -> ProviderAssignment {
126 if let Some(requirements) = &req.backend_requirements {
127 return route_backend_requirements(req, requirements, backends);
128 }
129
130 let edges: Vec<(usize, usize)> = req
134 .required_capabilities
135 .iter()
136 .enumerate()
137 .flat_map(|(i, cap)| {
138 let cap = cap.clone();
139 backends
140 .iter()
141 .enumerate()
142 .filter(move |(_, b)| b.has_capability(cap.clone()))
143 .map(move |(j, _)| (i, j))
144 })
145 .collect();
146
147 let matching = bipartite_matching(req.required_capabilities.len(), backends.len(), &edges)
148 .unwrap_or_default();
149
150 let mut covered = vec![false; req.required_capabilities.len()];
151 let mut assignments = Vec::with_capacity(matching.size);
152
153 for (cap_idx, backend_idx) in &matching.pairs {
154 assignments.push(CapabilityAssignment {
155 capability: req.required_capabilities[*cap_idx].clone(),
156 backend_name: backends[*backend_idx].name().to_string(),
157 });
158 covered[*cap_idx] = true;
159 }
160
161 let unmatched = req
162 .required_capabilities
163 .iter()
164 .enumerate()
165 .filter(|(i, _)| !covered[*i])
166 .map(|(_, c)| c.clone())
167 .collect::<Vec<_>>();
168
169 let coverage_ratio = if req.required_capabilities.is_empty() {
170 1.0
171 } else {
172 matching.size as f64 / req.required_capabilities.len() as f64
173 };
174
175 ProviderAssignment {
176 request_id: req.id.clone(),
177 assignments,
178 unmatched,
179 coverage_ratio,
180 }
181}
182
183fn route_backend_requirements(
184 req: &ProviderRequest,
185 requirements: &BackendRequirements,
186 backends: &[Arc<dyn Backend>],
187) -> ProviderAssignment {
188 let required_capabilities = if requirements.required_capabilities.is_empty() {
189 req.required_capabilities.clone()
190 } else {
191 requirements.required_capabilities.clone()
192 };
193
194 let matched_backend = backends.iter().find(|backend| {
195 backend.kind() == requirements.kind
196 && required_capabilities
197 .iter()
198 .all(|capability| backend.has_capability(capability.clone()))
199 && (!requirements.requires_replay || backend.supports_replay())
200 && (!requirements.requires_offline || !backend.requires_network())
201 });
202
203 if let Some(backend) = matched_backend {
204 let assignments = required_capabilities
205 .iter()
206 .cloned()
207 .map(|capability| CapabilityAssignment {
208 capability,
209 backend_name: backend.name().to_string(),
210 })
211 .collect::<Vec<_>>();
212 return ProviderAssignment {
213 request_id: req.id.clone(),
214 assignments,
215 unmatched: Vec::new(),
216 coverage_ratio: 1.0,
217 };
218 }
219
220 let coverage_ratio = if required_capabilities.is_empty() {
221 1.0
222 } else {
223 0.0
224 };
225 ProviderAssignment {
226 request_id: req.id.clone(),
227 assignments: Vec::new(),
228 unmatched: required_capabilities,
229 coverage_ratio,
230 }
231}
232
233fn request_id(fact_id: &str) -> &str {
236 fact_id.trim_start_matches(REQUEST_PREFIX)
237}
238
239fn assignment_exists(ctx: &dyn Context, request_id: &str) -> bool {
240 let assignment_id = format!("{}{}", ASSIGNMENT_PREFIX, request_id);
241 ctx.get(ContextKey::Strategies)
242 .iter()
243 .any(|f| f.id == assignment_id)
244}
245
246fn malformed_diagnostic_id(fact_id: &str) -> String {
247 format!("{MALFORMED_PREFIX}{fact_id}")
248}
249
250fn malformed_diagnostic_exists(ctx: &dyn Context, fact_id: &str) -> bool {
251 let diagnostic_id = malformed_diagnostic_id(fact_id);
252 ctx.get(ContextKey::Diagnostic)
253 .iter()
254 .any(|fact| fact.id == diagnostic_id)
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use converge_core::{ContextState, Engine};
    use converge_provider_api::{BackendKind, Capability};

    /// In-memory [`Backend`] whose trait answers come straight from its fields.
    ///
    /// `has_capability` is not overridden, so the trait's provided implementation
    /// is used (presumably derived from `capabilities()` — confirm in
    /// `converge_provider_api`).
    struct MockBackend {
        name: &'static str,
        kind: BackendKind,
        capabilities: Vec<Capability>,
        supports_replay: bool,
        requires_network: bool,
    }

    impl Backend for MockBackend {
        fn name(&self) -> &str {
            self.name
        }
        fn kind(&self) -> BackendKind {
            self.kind.clone()
        }
        fn capabilities(&self) -> Vec<Capability> {
            self.capabilities.clone()
        }
        fn supports_replay(&self) -> bool {
            self.supports_replay
        }
        fn requires_network(&self) -> bool {
            self.requires_network
        }
    }

    /// Builds a [`BackendKind::Llm`] mock with the given capabilities.
    /// It has no replay support and requires the network.
    fn backend(name: &'static str, caps: &[Capability]) -> Arc<dyn Backend> {
        backend_with(name, BackendKind::Llm, caps, false, true)
    }

    /// Builds a mock backend with every attribute specified explicitly.
    fn backend_with(
        name: &'static str,
        kind: BackendKind,
        caps: &[Capability],
        supports_replay: bool,
        requires_network: bool,
    ) -> Arc<dyn Backend> {
        Arc::new(MockBackend {
            name,
            kind,
            capabilities: caps.to_vec(),
            supports_replay,
            requires_network,
        })
    }

    /// Builds a capability-slot request with no role-level backend requirements.
    fn request(id: &str, caps: &[Capability]) -> ProviderRequest {
        ProviderRequest {
            id: id.to_string(),
            required_capabilities: caps.to_vec(),
            backend_requirements: None,
        }
    }

    /// Every slot has a dedicated backend, so all three are assigned and coverage is 1.0.
    #[test]
    fn full_coverage_when_all_capabilities_available() {
        let pool = vec![
            backend("anthropic", &[Capability::Reasoning]),
            backend("kong", &[Capability::AccessControl]),
            backend("elastic", &[Capability::FullTextSearch]),
        ];
        let req = request(
            "req-1",
            &[
                Capability::Reasoning,
                Capability::AccessControl,
                Capability::FullTextSearch,
            ],
        );

        let assignment = route(&req, &pool);

        assert_eq!(assignment.assignments.len(), 3);
        assert!(assignment.unmatched.is_empty());
        assert!((assignment.coverage_ratio - 1.0).abs() < f64::EPSILON);
    }

    /// A capability no backend offers is reported in `unmatched` and halves coverage.
    #[test]
    fn partial_coverage_when_capability_missing() {
        let pool = vec![backend("anthropic", &[Capability::Reasoning])];
        let req = request("req-2", &[Capability::Reasoning, Capability::AccessControl]);

        let assignment = route(&req, &pool);

        assert_eq!(assignment.assignments.len(), 1);
        assert_eq!(assignment.unmatched, vec![Capability::AccessControl]);
        assert!((assignment.coverage_ratio - 0.5).abs() < f64::EPSILON);
    }

    /// Two slots of the same capability must land on two distinct backends.
    #[test]
    fn no_double_booking_with_two_same_capability_slots() {
        let pool = vec![
            backend("anthropic", &[Capability::Reasoning]),
            backend("openai", &[Capability::Reasoning]),
        ];
        let req = request("req-3", &[Capability::Reasoning, Capability::Reasoning]);

        let assignment = route(&req, &pool);

        assert_eq!(assignment.assignments.len(), 2);
        let names: Vec<_> = assignment
            .assignments
            .iter()
            .map(|a| &a.backend_name)
            .collect();
        let unique: std::collections::HashSet<_> = names.iter().collect();
        assert_eq!(unique.len(), 2);
    }

    /// Even a backend offering both capabilities may fill only one slot.
    #[test]
    fn multi_capability_backend_can_only_fill_one_slot() {
        let pool = vec![backend(
            "all-in-one",
            &[Capability::Reasoning, Capability::AccessControl],
        )];
        let req = request("req-4", &[Capability::Reasoning, Capability::AccessControl]);

        let assignment = route(&req, &pool);

        assert_eq!(assignment.assignments.len(), 1);
        // The second slot has no other backend to go to, so it stays unmatched.
        assert_eq!(assignment.unmatched.len(), 1);
    }

    /// With no backends at all, nothing is assigned and coverage is 0.0.
    #[test]
    fn empty_pool_yields_zero_coverage() {
        let req = request("req-5", &[Capability::Reasoning]);
        let assignment = route(&req, &[]);
        assert_eq!(assignment.coverage_ratio, 0.0);
        assert_eq!(assignment.unmatched, vec![Capability::Reasoning]);
    }

    /// A request with no required capabilities is trivially fully covered.
    #[test]
    fn empty_request_yields_full_coverage() {
        let pool = vec![backend("anthropic", &[Capability::Reasoning])];
        let req = request("req-6", &[]);
        let assignment = route(&req, &pool);
        assert!((assignment.coverage_ratio - 1.0).abs() < f64::EPSILON);
        assert!(assignment.assignments.is_empty());
    }

    /// Role-based routing must skip the networked, non-replayable LLM backend.
    /// It selects the offline, replayable policy backend instead.
    // NOTE(review): assumes `BackendRequirements::access_policy()` requires
    // `BackendKind::Policy` plus `Capability::AccessControl` — confirm in
    // converge_provider_api.
    #[test]
    fn backend_requirements_select_one_backend_satisfying_role_constraints() {
        let pool = vec![
            backend("remote-llm", &[Capability::AccessControl]),
            backend_with(
                "local-policy",
                BackendKind::Policy,
                &[Capability::AccessControl],
                true,
                false,
            ),
        ];
        let req = ProviderRequest {
            id: "policy-role".to_string(),
            required_capabilities: vec![],
            backend_requirements: Some(
                BackendRequirements::access_policy()
                    .with_replay()
                    .with_offline(),
            ),
        };

        let assignment = route(&req, &pool);

        assert_eq!(assignment.assignments.len(), 1);
        assert_eq!(assignment.assignments[0].backend_name, "local-policy");
        assert!(assignment.unmatched.is_empty());
        assert!((assignment.coverage_ratio - 1.0).abs() < f64::EPSILON);
    }

    /// Routing the same request against the same pool twice yields identical results.
    #[test]
    fn repeated_routing_is_deterministic_for_equal_candidates() {
        let pool = vec![
            backend("reasoner-a", &[Capability::Reasoning]),
            backend("reasoner-b", &[Capability::Reasoning]),
            backend("policy-a", &[Capability::AccessControl]),
        ];
        let req = request(
            "req-7",
            &[
                Capability::Reasoning,
                Capability::Reasoning,
                Capability::AccessControl,
            ],
        );

        let first = route(&req, &pool);
        let second = route(&req, &pool);

        assert_eq!(first.assignments, second.assignments);
        assert_eq!(first.unmatched, second.unmatched);
        // Exact float equality is intended: the same computation must give the same bits.
        assert_eq!(first.coverage_ratio, second.coverage_ratio);
    }

    /// A seed whose content is not valid JSON yields exactly one diagnostic and no
    /// strategy facts. Re-running a fresh engine over the resulting context must
    /// not add a second diagnostic.
    #[tokio::test]
    async fn malformed_request_emits_diagnostic_once() {
        let mut engine = Engine::new();
        engine.register_suggestor(ProviderSelectionSuggestor::new(vec![backend(
            "anthropic",
            &[Capability::Reasoning],
        )]));

        let mut ctx = ContextState::new();
        ctx.add_input(ContextKey::Seeds, "provider-request:broken", "{")
            .expect("seed should stage");

        let first = engine.run(ctx).await.expect("run should converge");
        let diagnostics = first.context.get(ContextKey::Diagnostic);
        assert_eq!(diagnostics.len(), 1);
        assert_eq!(
            diagnostics[0].id,
            "provider-request-error:provider-request:broken"
        );
        assert!(!first.context.has(ContextKey::Strategies));

        // Second pass: the existing diagnostic must suppress a duplicate.
        let mut rerun_engine = Engine::new();
        rerun_engine.register_suggestor(ProviderSelectionSuggestor::new(vec![backend(
            "anthropic",
            &[Capability::Reasoning],
        )]));
        let second = rerun_engine
            .run(first.context.clone())
            .await
            .expect("rerun should converge");
        assert_eq!(second.context.get(ContextKey::Diagnostic).len(), 1);
    }
}