1use std::sync::Arc;
11
12use async_trait::async_trait;
13use converge_optimization::graph::matching::bipartite_matching;
14use converge_pack::{AgentEffect, Context, ContextKey, ProposedFact, Suggestor};
15use converge_provider::{
16 Backend, BackendRequirements, CapabilityAssignment, ProviderAssignment, ProviderRequest,
17};
18
/// Id prefix marking a `ContextKey::Seeds` fact as a provider request; the
/// fact's content is parsed as a JSON `ProviderRequest`.
const REQUEST_PREFIX: &str = "provider-request:";
/// Id prefix of the provider-assignment facts proposed into
/// `ContextKey::Strategies`.
const ASSIGNMENT_PREFIX: &str = "provider-assignment:";
/// Id prefix of the diagnostic facts proposed into `ContextKey::Diagnostic`
/// for request seeds whose content fails to parse.
const MALFORMED_PREFIX: &str = "provider-request-error:";
24
25pub struct ProviderSelectionSuggestor {
35 backends: Vec<Arc<dyn Backend>>,
36}
37
38impl ProviderSelectionSuggestor {
39 pub fn new(backends: Vec<Arc<dyn Backend>>) -> Self {
40 Self { backends }
41 }
42}
43
44#[async_trait]
45impl Suggestor for ProviderSelectionSuggestor {
46 fn name(&self) -> &str {
47 "ProviderSelectionSuggestor"
48 }
49
50 fn dependencies(&self) -> &[ContextKey] {
51 &[ContextKey::Seeds]
52 }
53
54 fn accepts(&self, ctx: &dyn Context) -> bool {
55 ctx.get(ContextKey::Seeds).iter().any(|f| {
56 f.id().as_str().starts_with(REQUEST_PREFIX)
57 && match serde_json::from_str::<ProviderRequest>(f.content()) {
58 Ok(_) => !assignment_exists(ctx, request_id(f.id().as_str())),
59 Err(_) => !malformed_diagnostic_exists(ctx, f.id().as_str()),
60 }
61 })
62 }
63
64 async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
65 let mut proposals = Vec::new();
66
67 for fact in ctx
68 .get(ContextKey::Seeds)
69 .iter()
70 .filter(|f| f.id().as_str().starts_with(REQUEST_PREFIX))
71 {
72 match serde_json::from_str::<ProviderRequest>(fact.content()) {
73 Ok(req) => {
74 if assignment_exists(ctx, request_id(fact.id().as_str())) {
75 continue;
76 }
77
78 let assignment = route(&req, &self.backends);
79 proposals.push(
80 ProposedFact::new(
81 ContextKey::Strategies,
82 format!("{}{}", ASSIGNMENT_PREFIX, assignment.request_id),
83 serde_json::to_string(&assignment).unwrap_or_default(),
84 self.name(),
85 )
86 .with_confidence(assignment.coverage_ratio),
87 );
88 }
89 Err(error) => {
90 if malformed_diagnostic_exists(ctx, fact.id().as_str()) {
91 continue;
92 }
93
94 let diagnostic = serde_json::json!({
95 "request_fact_id": fact.id(),
96 "message": "malformed provider request ignored",
97 "error": error.to_string(),
98 });
99 proposals.push(
100 ProposedFact::new(
101 ContextKey::Diagnostic,
102 malformed_diagnostic_id(fact.id().as_str()),
103 diagnostic.to_string(),
104 self.name(),
105 )
106 .with_confidence(1.0),
107 );
108 }
109 }
110 }
111
112 if proposals.is_empty() {
113 AgentEffect::empty()
114 } else {
115 AgentEffect::with_proposals(proposals)
116 }
117 }
118}
119
120fn route(req: &ProviderRequest, backends: &[Arc<dyn Backend>]) -> ProviderAssignment {
123 if let Some(requirements) = &req.backend_requirements {
124 return route_backend_requirements(req, requirements, backends);
125 }
126
127 let edges: Vec<(usize, usize)> = req
131 .required_capabilities
132 .iter()
133 .enumerate()
134 .flat_map(|(i, cap)| {
135 let cap = cap.clone();
136 backends
137 .iter()
138 .enumerate()
139 .filter(move |(_, b)| b.has_capability(cap.clone()))
140 .map(move |(j, _)| (i, j))
141 })
142 .collect();
143
144 let matching = bipartite_matching(req.required_capabilities.len(), backends.len(), &edges)
145 .unwrap_or_default();
146
147 let mut covered = vec![false; req.required_capabilities.len()];
148 let mut assignments = Vec::with_capacity(matching.size);
149
150 for (cap_idx, backend_idx) in &matching.pairs {
151 assignments.push(CapabilityAssignment {
152 capability: req.required_capabilities[*cap_idx].clone(),
153 backend_name: backends[*backend_idx].name().to_string(),
154 });
155 covered[*cap_idx] = true;
156 }
157
158 let unmatched = req
159 .required_capabilities
160 .iter()
161 .enumerate()
162 .filter(|(i, _)| !covered[*i])
163 .map(|(_, c)| c.clone())
164 .collect::<Vec<_>>();
165
166 let coverage_ratio = if req.required_capabilities.is_empty() {
167 1.0
168 } else {
169 matching.size as f64 / req.required_capabilities.len() as f64
170 };
171
172 ProviderAssignment {
173 request_id: req.id.clone(),
174 assignments,
175 unmatched,
176 coverage_ratio,
177 }
178}
179
180fn route_backend_requirements(
181 req: &ProviderRequest,
182 requirements: &BackendRequirements,
183 backends: &[Arc<dyn Backend>],
184) -> ProviderAssignment {
185 let required_capabilities = if requirements.required_capabilities.is_empty() {
186 req.required_capabilities.clone()
187 } else {
188 requirements.required_capabilities.clone()
189 };
190
191 let matched_backend = backends.iter().find(|backend| {
192 backend.kind() == requirements.kind
193 && required_capabilities
194 .iter()
195 .all(|capability| backend.has_capability(capability.clone()))
196 && (!requirements.requires_replay || backend.supports_replay())
197 && (!requirements.requires_offline || !backend.requires_network())
198 });
199
200 if let Some(backend) = matched_backend {
201 let assignments = required_capabilities
202 .iter()
203 .cloned()
204 .map(|capability| CapabilityAssignment {
205 capability,
206 backend_name: backend.name().to_string(),
207 })
208 .collect::<Vec<_>>();
209 return ProviderAssignment {
210 request_id: req.id.clone(),
211 assignments,
212 unmatched: Vec::new(),
213 coverage_ratio: 1.0,
214 };
215 }
216
217 let coverage_ratio = if required_capabilities.is_empty() {
218 1.0
219 } else {
220 0.0
221 };
222 ProviderAssignment {
223 request_id: req.id.clone(),
224 assignments: Vec::new(),
225 unmatched: required_capabilities,
226 coverage_ratio,
227 }
228}
229
230fn request_id(fact_id: &str) -> &str {
233 fact_id.trim_start_matches(REQUEST_PREFIX)
234}
235
236fn assignment_exists(ctx: &dyn Context, request_id: &str) -> bool {
237 let assignment_id = format!("{}{}", ASSIGNMENT_PREFIX, request_id);
238 ctx.get(ContextKey::Strategies)
239 .iter()
240 .any(|f| f.id().as_str() == assignment_id)
241}
242
243fn malformed_diagnostic_id(fact_id: &str) -> String {
244 format!("{MALFORMED_PREFIX}{fact_id}")
245}
246
247fn malformed_diagnostic_exists(ctx: &dyn Context, fact_id: &str) -> bool {
248 let diagnostic_id = malformed_diagnostic_id(fact_id);
249 ctx.get(ContextKey::Diagnostic)
250 .iter()
251 .any(|fact| fact.id().as_str() == diagnostic_id)
252}
253
#[cfg(test)]
mod tests {
    use super::*;
    use converge_core::{ContextState, Engine};
    use converge_provider::{BackendKind, Capability};

    // Test double: every `Backend` answer is read straight from these fields.
    struct MockBackend {
        name: &'static str,
        kind: BackendKind,
        capabilities: Vec<Capability>,
        supports_replay: bool,
        requires_network: bool,
    }

    impl Backend for MockBackend {
        fn name(&self) -> &str {
            self.name
        }
        fn kind(&self) -> BackendKind {
            self.kind.clone()
        }
        fn capabilities(&self) -> Vec<Capability> {
            self.capabilities.clone()
        }
        fn supports_replay(&self) -> bool {
            self.supports_replay
        }
        fn requires_network(&self) -> bool {
            self.requires_network
        }
    }

    // Shorthand: an `Llm` backend without replay support that requires network.
    fn backend(name: &'static str, caps: &[Capability]) -> Arc<dyn Backend> {
        backend_with(name, BackendKind::Llm, caps, false, true)
    }

    // Builds a mock backend with every property chosen by the caller.
    fn backend_with(
        name: &'static str,
        kind: BackendKind,
        caps: &[Capability],
        supports_replay: bool,
        requires_network: bool,
    ) -> Arc<dyn Backend> {
        Arc::new(MockBackend {
            name,
            kind,
            capabilities: caps.to_vec(),
            supports_replay,
            requires_network,
        })
    }

    // Request without `backend_requirements`, so `route` takes the
    // bipartite-matching path.
    fn request(id: &str, caps: &[Capability]) -> ProviderRequest {
        ProviderRequest {
            id: id.to_string(),
            required_capabilities: caps.to_vec(),
            backend_requirements: None,
        }
    }

    // Each requested capability has its own dedicated backend.
    #[test]
    fn full_coverage_when_all_capabilities_available() {
        let pool = vec![
            backend("anthropic", &[Capability::Reasoning]),
            backend("kong", &[Capability::AccessControl]),
            backend("elastic", &[Capability::FullTextSearch]),
        ];
        let req = request(
            "req-1",
            &[
                Capability::Reasoning,
                Capability::AccessControl,
                Capability::FullTextSearch,
            ],
        );

        let assignment = route(&req, &pool);

        assert_eq!(assignment.assignments.len(), 3);
        assert!(assignment.unmatched.is_empty());
        assert!((assignment.coverage_ratio - 1.0).abs() < f64::EPSILON);
    }

    // One of two capabilities has no provider: half coverage, and the
    // missing capability is listed as unmatched.
    #[test]
    fn partial_coverage_when_capability_missing() {
        let pool = vec![backend("anthropic", &[Capability::Reasoning])];
        let req = request("req-2", &[Capability::Reasoning, Capability::AccessControl]);

        let assignment = route(&req, &pool);

        assert_eq!(assignment.assignments.len(), 1);
        assert_eq!(assignment.unmatched, vec![Capability::AccessControl]);
        assert!((assignment.coverage_ratio - 0.5).abs() < f64::EPSILON);
    }

    // Two slots for the same capability must land on two different backends.
    #[test]
    fn no_double_booking_with_two_same_capability_slots() {
        let pool = vec![
            backend("anthropic", &[Capability::Reasoning]),
            backend("openai", &[Capability::Reasoning]),
        ];
        let req = request("req-3", &[Capability::Reasoning, Capability::Reasoning]);

        let assignment = route(&req, &pool);

        assert_eq!(assignment.assignments.len(), 2);
        let names: Vec<_> = assignment
            .assignments
            .iter()
            .map(|a| &a.backend_name)
            .collect();
        let unique: std::collections::HashSet<_> = names.iter().collect();
        assert_eq!(unique.len(), 2);
    }

    // A backend is matched to at most one slot, even if it offers several
    // of the requested capabilities.
    #[test]
    fn multi_capability_backend_can_only_fill_one_slot() {
        let pool = vec![backend(
            // Single backend advertising both requested capabilities.
            "all-in-one",
            &[Capability::Reasoning, Capability::AccessControl],
        )];
        let req = request("req-4", &[Capability::Reasoning, Capability::AccessControl]);

        let assignment = route(&req, &pool);

        assert_eq!(assignment.assignments.len(), 1);
        // The other capability is left unmatched.
        assert_eq!(assignment.unmatched.len(), 1);
    }

    // With no backends, nothing matches and coverage is zero.
    #[test]
    fn empty_pool_yields_zero_coverage() {
        let req = request("req-5", &[Capability::Reasoning]);
        let assignment = route(&req, &[]);
        assert_eq!(assignment.coverage_ratio, 0.0);
        assert_eq!(assignment.unmatched, vec![Capability::Reasoning]);
    }

    // A request for nothing is trivially fully covered.
    #[test]
    fn empty_request_yields_full_coverage() {
        let pool = vec![backend("anthropic", &[Capability::Reasoning])];
        let req = request("req-6", &[]);
        let assignment = route(&req, &pool);
        assert!((assignment.coverage_ratio - 1.0).abs() < f64::EPSILON);
        assert!(assignment.assignments.is_empty());
    }

    // `remote-llm` is an `Llm` that requires network and lacks replay, so it
    // should fail the role constraints. `local-policy` is a `Policy` backend
    // with replay and no network, so it should be picked. (Assumes
    // `BackendRequirements::access_policy()` targets `BackendKind::Policy` and
    // requires `AccessControl` — TODO confirm in converge_provider.)
    #[test]
    fn backend_requirements_select_one_backend_satisfying_role_constraints() {
        let pool = vec![
            backend("remote-llm", &[Capability::AccessControl]),
            backend_with(
                "local-policy",
                BackendKind::Policy,
                &[Capability::AccessControl],
                true,
                false,
            ),
        ];
        let req = ProviderRequest {
            id: "policy-role".to_string(),
            required_capabilities: vec![],
            backend_requirements: Some(
                BackendRequirements::access_policy()
                    .with_replay()
                    .with_offline(),
            ),
        };

        let assignment = route(&req, &pool);

        assert_eq!(assignment.assignments.len(), 1);
        assert_eq!(assignment.assignments[0].backend_name, "local-policy");
        assert!(assignment.unmatched.is_empty());
        assert!((assignment.coverage_ratio - 1.0).abs() < f64::EPSILON);
    }

    // Routing the same request over the same pool twice gives identical results.
    #[test]
    fn repeated_routing_is_deterministic_for_equal_candidates() {
        let pool = vec![
            backend("reasoner-a", &[Capability::Reasoning]),
            backend("reasoner-b", &[Capability::Reasoning]),
            backend("policy-a", &[Capability::AccessControl]),
        ];
        let req = request(
            "req-7",
            &[
                Capability::Reasoning,
                Capability::Reasoning,
                Capability::AccessControl,
            ],
        );

        let first = route(&req, &pool);
        let second = route(&req, &pool);

        assert_eq!(first.assignments, second.assignments);
        assert_eq!(first.unmatched, second.unmatched);
        assert_eq!(first.coverage_ratio, second.coverage_ratio);
    }

    // End to end through the engine:
    // - the first run turns an unparseable seed into exactly one diagnostic
    //   and no strategies;
    // - a second run with a fresh engine over the resulting context must not
    //   add a duplicate diagnostic.
    #[tokio::test]
    async fn malformed_request_emits_diagnostic_once() {
        let mut engine = Engine::new();
        engine.register_suggestor(ProviderSelectionSuggestor::new(vec![backend(
            "anthropic",
            &[Capability::Reasoning],
        )]));

        let mut ctx = ContextState::new();
        ctx.add_input(ContextKey::Seeds, "provider-request:broken", "{")
            .expect("seed should stage");

        let first = engine.run(ctx).await.expect("run should converge");
        let diagnostics = first.context.get(ContextKey::Diagnostic);
        assert_eq!(diagnostics.len(), 1);
        assert_eq!(
            diagnostics[0].id(),
            "provider-request-error:provider-request:broken"
        );
        assert!(!first.context.has(ContextKey::Strategies));

        let mut rerun_engine = Engine::new();
        rerun_engine.register_suggestor(ProviderSelectionSuggestor::new(vec![backend(
            "anthropic",
            &[Capability::Reasoning],
        )]));
        let second = rerun_engine
            .run(first.context.clone())
            .await
            .expect("rerun should converge");
        assert_eq!(second.context.get(ContextKey::Diagnostic).len(), 1);
    }
}