1use std::sync::Arc;
11
12use async_trait::async_trait;
13use converge_optimization::graph::matching::bipartite_matching;
14use converge_pack::{AgentEffect, Context, ContextKey, ProposedFact, Suggestor};
15use converge_provider_api::{Backend, CapabilityAssignment, ProviderAssignment, ProviderRequest};
16
// Fact-id namespaces shared by `accepts`, `execute`, and the lookup helpers.

/// Prefix identifying a seed fact as a provider request (`provider-request:<id>`).
const REQUEST_PREFIX: &str = "provider-request:";
/// Prefix for assignment facts proposed under `ContextKey::Strategies`.
const ASSIGNMENT_PREFIX: &str = "provider-assignment:";
/// Prefix for diagnostics (under `ContextKey::Diagnostic`) about requests that failed to parse.
const MALFORMED_PREFIX: &str = "provider-request-error:";
22
23pub struct ProviderSelectionSuggestor {
36 backends: Vec<Arc<dyn Backend>>,
37}
38
39impl ProviderSelectionSuggestor {
40 pub fn new(backends: Vec<Arc<dyn Backend>>) -> Self {
41 Self { backends }
42 }
43}
44
45#[async_trait]
46impl Suggestor for ProviderSelectionSuggestor {
47 fn name(&self) -> &str {
48 "ProviderSelectionSuggestor"
49 }
50
51 fn dependencies(&self) -> &[ContextKey] {
52 &[ContextKey::Seeds]
53 }
54
55 fn accepts(&self, ctx: &dyn Context) -> bool {
56 ctx.get(ContextKey::Seeds).iter().any(|f| {
57 f.id.starts_with(REQUEST_PREFIX)
58 && match serde_json::from_str::<ProviderRequest>(&f.content) {
59 Ok(_) => !assignment_exists(ctx, request_id(&f.id)),
60 Err(_) => !malformed_diagnostic_exists(ctx, &f.id),
61 }
62 })
63 }
64
65 async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
66 let mut proposals = Vec::new();
67
68 for fact in ctx
69 .get(ContextKey::Seeds)
70 .iter()
71 .filter(|f| f.id.starts_with(REQUEST_PREFIX))
72 {
73 match serde_json::from_str::<ProviderRequest>(&fact.content) {
74 Ok(req) => {
75 if assignment_exists(ctx, request_id(&fact.id)) {
76 continue;
77 }
78
79 let assignment = route(&req, &self.backends);
80 proposals.push(
81 ProposedFact::new(
82 ContextKey::Strategies,
83 format!("{}{}", ASSIGNMENT_PREFIX, assignment.request_id),
84 serde_json::to_string(&assignment).unwrap_or_default(),
85 self.name(),
86 )
87 .with_confidence(assignment.coverage_ratio),
88 );
89 }
90 Err(error) => {
91 if malformed_diagnostic_exists(ctx, &fact.id) {
92 continue;
93 }
94
95 let diagnostic = serde_json::json!({
96 "request_fact_id": fact.id,
97 "message": "malformed provider request ignored",
98 "error": error.to_string(),
99 });
100 proposals.push(
101 ProposedFact::new(
102 ContextKey::Diagnostic,
103 malformed_diagnostic_id(&fact.id),
104 diagnostic.to_string(),
105 self.name(),
106 )
107 .with_confidence(1.0),
108 );
109 }
110 }
111 }
112
113 if proposals.is_empty() {
114 AgentEffect::empty()
115 } else {
116 AgentEffect::with_proposals(proposals)
117 }
118 }
119}
120
121fn route(req: &ProviderRequest, backends: &[Arc<dyn Backend>]) -> ProviderAssignment {
124 let edges: Vec<(usize, usize)> = req
128 .required_capabilities
129 .iter()
130 .enumerate()
131 .flat_map(|(i, cap)| {
132 let cap = cap.clone();
133 backends
134 .iter()
135 .enumerate()
136 .filter(move |(_, b)| b.has_capability(cap.clone()))
137 .map(move |(j, _)| (i, j))
138 })
139 .collect();
140
141 let matching = bipartite_matching(req.required_capabilities.len(), backends.len(), &edges)
142 .unwrap_or_default();
143
144 let mut covered = vec![false; req.required_capabilities.len()];
145 let mut assignments = Vec::with_capacity(matching.size);
146
147 for (cap_idx, backend_idx) in &matching.pairs {
148 assignments.push(CapabilityAssignment {
149 capability: req.required_capabilities[*cap_idx].clone(),
150 backend_name: backends[*backend_idx].name().to_string(),
151 });
152 covered[*cap_idx] = true;
153 }
154
155 let unmatched = req
156 .required_capabilities
157 .iter()
158 .enumerate()
159 .filter(|(i, _)| !covered[*i])
160 .map(|(_, c)| c.clone())
161 .collect::<Vec<_>>();
162
163 let coverage_ratio = if req.required_capabilities.is_empty() {
164 1.0
165 } else {
166 matching.size as f64 / req.required_capabilities.len() as f64
167 };
168
169 ProviderAssignment {
170 request_id: req.id.clone(),
171 assignments,
172 unmatched,
173 coverage_ratio,
174 }
175}
176
177fn request_id(fact_id: &str) -> &str {
180 fact_id.trim_start_matches(REQUEST_PREFIX)
181}
182
183fn assignment_exists(ctx: &dyn Context, request_id: &str) -> bool {
184 let assignment_id = format!("{}{}", ASSIGNMENT_PREFIX, request_id);
185 ctx.get(ContextKey::Strategies)
186 .iter()
187 .any(|f| f.id == assignment_id)
188}
189
190fn malformed_diagnostic_id(fact_id: &str) -> String {
191 format!("{MALFORMED_PREFIX}{fact_id}")
192}
193
194fn malformed_diagnostic_exists(ctx: &dyn Context, fact_id: &str) -> bool {
195 let diagnostic_id = malformed_diagnostic_id(fact_id);
196 ctx.get(ContextKey::Diagnostic)
197 .iter()
198 .any(|fact| fact.id == diagnostic_id)
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use converge_core::{ContextState, Engine};
    use converge_provider_api::{BackendKind, Capability};

    /// In-memory backend with a fixed name and capability list.
    struct MockBackend {
        label: &'static str,
        caps: Vec<Capability>,
    }

    impl Backend for MockBackend {
        fn name(&self) -> &str {
            self.label
        }
        fn kind(&self) -> BackendKind {
            BackendKind::Llm
        }
        fn capabilities(&self) -> Vec<Capability> {
            self.caps.clone()
        }
        fn supports_replay(&self) -> bool {
            false
        }
        fn requires_network(&self) -> bool {
            true
        }
    }

    /// Wraps a `MockBackend` as a trait object for the routing pool.
    fn backend(name: &'static str, caps: &[Capability]) -> Arc<dyn Backend> {
        let mock = MockBackend {
            label: name,
            caps: caps.to_vec(),
        };
        Arc::new(mock)
    }

    /// Builds a request with the given id and required capabilities.
    fn request(id: &str, caps: &[Capability]) -> ProviderRequest {
        ProviderRequest {
            required_capabilities: caps.to_vec(),
            id: id.to_owned(),
        }
    }

    #[test]
    fn full_coverage_when_all_capabilities_available() {
        let wanted = [
            Capability::Reasoning,
            Capability::AccessControl,
            Capability::FullTextSearch,
        ];
        let pool = vec![
            backend("anthropic", &[Capability::Reasoning]),
            backend("kong", &[Capability::AccessControl]),
            backend("elastic", &[Capability::FullTextSearch]),
        ];

        let result = route(&request("req-1", &wanted), &pool);

        assert_eq!(result.assignments.len(), 3);
        assert!(result.unmatched.is_empty());
        assert!((result.coverage_ratio - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn partial_coverage_when_capability_missing() {
        let pool = vec![backend("anthropic", &[Capability::Reasoning])];
        let wanted = [Capability::Reasoning, Capability::AccessControl];

        let result = route(&request("req-2", &wanted), &pool);

        assert_eq!(result.assignments.len(), 1);
        assert_eq!(result.unmatched, vec![Capability::AccessControl]);
        assert!((result.coverage_ratio - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn no_double_booking_with_two_same_capability_slots() {
        let pool = vec![
            backend("anthropic", &[Capability::Reasoning]),
            backend("openai", &[Capability::Reasoning]),
        ];
        let wanted = [Capability::Reasoning, Capability::Reasoning];

        let result = route(&request("req-3", &wanted), &pool);

        assert_eq!(result.assignments.len(), 2);
        let used: std::collections::HashSet<&String> = result
            .assignments
            .iter()
            .map(|slot| &slot.backend_name)
            .collect();
        assert_eq!(used.len(), 2);
    }

    #[test]
    fn multi_capability_backend_can_only_fill_one_slot() {
        let pool = vec![backend(
            "all-in-one",
            &[Capability::Reasoning, Capability::AccessControl],
        )];
        let wanted = [Capability::Reasoning, Capability::AccessControl];

        let result = route(&request("req-4", &wanted), &pool);

        assert_eq!(result.assignments.len(), 1);
        assert_eq!(result.unmatched.len(), 1);
    }

    #[test]
    fn empty_pool_yields_zero_coverage() {
        let result = route(&request("req-5", &[Capability::Reasoning]), &[]);

        assert_eq!(result.coverage_ratio, 0.0);
        assert_eq!(result.unmatched, vec![Capability::Reasoning]);
    }

    #[test]
    fn empty_request_yields_full_coverage() {
        let pool = vec![backend("anthropic", &[Capability::Reasoning])];

        let result = route(&request("req-6", &[]), &pool);

        assert!((result.coverage_ratio - 1.0).abs() < f64::EPSILON);
        assert!(result.assignments.is_empty());
    }

    #[test]
    fn repeated_routing_is_deterministic_for_equal_candidates() {
        let pool = vec![
            backend("reasoner-a", &[Capability::Reasoning]),
            backend("reasoner-b", &[Capability::Reasoning]),
            backend("policy-a", &[Capability::AccessControl]),
        ];
        let wanted = [
            Capability::Reasoning,
            Capability::Reasoning,
            Capability::AccessControl,
        ];
        let req = request("req-7", &wanted);

        let first = route(&req, &pool);
        let second = route(&req, &pool);

        assert_eq!(first.assignments, second.assignments);
        assert_eq!(first.unmatched, second.unmatched);
        assert_eq!(first.coverage_ratio, second.coverage_ratio);
    }

    #[tokio::test]
    async fn malformed_request_emits_diagnostic_once() {
        // Each run gets a fresh engine with the same single-backend suggestor.
        let make_engine = || {
            let mut engine = Engine::new();
            engine.register_suggestor(ProviderSelectionSuggestor::new(vec![backend(
                "anthropic",
                &[Capability::Reasoning],
            )]));
            engine
        };

        let mut seeded = ContextState::new();
        seeded
            .add_input(ContextKey::Seeds, "provider-request:broken", "{")
            .expect("seed should stage");

        let mut engine = make_engine();
        let first = engine.run(seeded).await.expect("run should converge");
        let diagnostics = first.context.get(ContextKey::Diagnostic);
        assert_eq!(diagnostics.len(), 1);
        assert_eq!(
            diagnostics[0].id,
            "provider-request-error:provider-request:broken"
        );
        assert!(!first.context.has(ContextKey::Strategies));

        // Re-running over the converged context must not duplicate the diagnostic.
        let mut rerun_engine = make_engine();
        let second = rerun_engine
            .run(first.context.clone())
            .await
            .expect("rerun should converge");
        assert_eq!(second.context.get(ContextKey::Diagnostic).len(), 1);
    }
}
390}