1use std::sync::Arc;
11
12use async_trait::async_trait;
13use converge_optimization::graph::matching::bipartite_matching;
14use converge_pack::Provenance;
15use converge_pack::ProvenanceSource;
16use converge_pack::{
17 AgentEffect, Context, ContextKey, DiagnosticPayload, FactPayload, ProposedFact, Suggestor,
18};
19use converge_provider::{
20 Backend, BackendRequirements, CapabilityAssignment, ProviderAssignment, ProviderRequest,
21};
22use serde::{Deserialize, Serialize};
23
24#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
26#[serde(deny_unknown_fields)]
27pub struct ProviderRequestPayload {
28 request: ProviderRequest,
29}
30
31impl ProviderRequestPayload {
32 #[must_use]
33 pub fn new(request: ProviderRequest) -> Self {
34 Self { request }
35 }
36
37 #[must_use]
38 pub fn request(&self) -> &ProviderRequest {
39 &self.request
40 }
41}
42
43impl From<ProviderRequest> for ProviderRequestPayload {
44 fn from(request: ProviderRequest) -> Self {
45 Self::new(request)
46 }
47}
48
49impl FactPayload for ProviderRequestPayload {
50 const FAMILY: &'static str = "converge.kernel.provider.request";
51 const VERSION: u16 = 1;
52}
53
54#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
56#[serde(deny_unknown_fields)]
57pub struct ProviderAssignmentPayload {
58 assignment: ProviderAssignment,
59}
60
61impl ProviderAssignmentPayload {
62 #[must_use]
63 pub fn new(assignment: ProviderAssignment) -> Self {
64 Self { assignment }
65 }
66
67 #[must_use]
68 pub fn assignment(&self) -> &ProviderAssignment {
69 &self.assignment
70 }
71}
72
73impl From<ProviderAssignment> for ProviderAssignmentPayload {
74 fn from(assignment: ProviderAssignment) -> Self {
75 Self::new(assignment)
76 }
77}
78
79impl FactPayload for ProviderAssignmentPayload {
80 const FAMILY: &'static str = "converge.kernel.provider.assignment";
81 const VERSION: u16 = 1;
82}
83
/// Id prefix that marks a `Seeds` fact as a provider request.
const REQUEST_PREFIX: &str = "provider-request:";
/// Id prefix of the diagnostic emitted for a request seed whose payload does
/// not decode as a `ProviderRequestPayload`; the full seed id follows it.
const MALFORMED_PREFIX: &str = "provider-request-error:";
/// Id prefix of the assignment facts proposed under `ContextKey::Strategies`.
const ASSIGNMENT_PREFIX: &str = "provider-assignment:";
89
/// Suggestor that turns `provider-request:` seed facts into provider
/// assignments (proposed under `ContextKey::Strategies`), or into a
/// diagnostic when a seed's payload is not a valid provider request.
pub struct ProviderSelectionSuggestor {
    // Candidate backends. Routing scans them in this order, and role-based
    // routing picks the first one that qualifies.
    backends: Vec<Arc<dyn Backend>>,
}
102
103#[derive(Copy, Clone, Debug)]
106pub struct ConvergeKernel;
107
108impl ProvenanceSource for ConvergeKernel {
109 fn as_str(&self) -> &'static str {
110 "converge-kernel"
111 }
112}
113
114pub const CONVERGE_KERNEL_PROVENANCE: ConvergeKernel = ConvergeKernel;
116
117impl ProviderSelectionSuggestor {
118 pub fn new(backends: Vec<Arc<dyn Backend>>) -> Self {
119 Self { backends }
120 }
121}
122
123#[async_trait]
124impl Suggestor for ProviderSelectionSuggestor {
125 fn name(&self) -> &str {
126 "ProviderSelectionSuggestor"
127 }
128
129 fn provenance(&self) -> Provenance {
130 CONVERGE_KERNEL_PROVENANCE.provenance()
131 }
132
133 fn dependencies(&self) -> &[ContextKey] {
134 &[ContextKey::Seeds]
135 }
136
137 fn accepts(&self, ctx: &dyn Context) -> bool {
138 ctx.get(ContextKey::Seeds).iter().any(|f| {
139 f.id().as_str().starts_with(REQUEST_PREFIX)
140 && match f.payload::<ProviderRequestPayload>() {
141 Some(_) => !assignment_exists(ctx, request_id(f.id().as_str())),
142 None => !malformed_diagnostic_exists(ctx, f.id().as_str()),
143 }
144 })
145 }
146
147 async fn execute(&self, ctx: &dyn Context) -> AgentEffect {
148 let mut proposals = Vec::new();
149
150 for fact in ctx
151 .get(ContextKey::Seeds)
152 .iter()
153 .filter(|f| f.id().as_str().starts_with(REQUEST_PREFIX))
154 {
155 match fact.payload::<ProviderRequestPayload>() {
156 Some(payload) => {
157 if assignment_exists(ctx, request_id(fact.id().as_str())) {
158 continue;
159 }
160
161 let assignment = route(payload.request(), &self.backends);
162 proposals.push(
163 ProposedFact::new(
164 ContextKey::Strategies,
165 format!("{}{}", ASSIGNMENT_PREFIX, assignment.request_id),
166 ProviderAssignmentPayload::new(assignment.clone()),
167 self.provenance(),
168 )
169 .with_confidence(assignment.coverage_ratio),
170 );
171 }
172 None => {
173 if malformed_diagnostic_exists(ctx, fact.id().as_str()) {
174 continue;
175 }
176
177 proposals.push(
178 ProposedFact::new(
179 ContextKey::Diagnostic,
180 malformed_diagnostic_id(fact.id().as_str()),
181 DiagnosticPayload::new(
182 self.name(),
183 format!(
184 "malformed provider request '{}': expected {} v{} payload",
185 fact.id(),
186 ProviderRequestPayload::FAMILY,
187 ProviderRequestPayload::VERSION
188 ),
189 ),
190 self.provenance(),
191 )
192 .with_confidence(1.0),
193 );
194 }
195 }
196 }
197
198 if proposals.is_empty() {
199 AgentEffect::empty()
200 } else {
201 AgentEffect::with_proposals(proposals)
202 }
203 }
204}
205
206fn route(req: &ProviderRequest, backends: &[Arc<dyn Backend>]) -> ProviderAssignment {
209 if let Some(requirements) = &req.backend_requirements {
210 return route_backend_requirements(req, requirements, backends);
211 }
212
213 let edges: Vec<(usize, usize)> = req
217 .required_capabilities
218 .iter()
219 .enumerate()
220 .flat_map(|(i, cap)| {
221 let cap = cap.clone();
222 backends
223 .iter()
224 .enumerate()
225 .filter(move |(_, b)| b.has_capability(cap.clone()))
226 .map(move |(j, _)| (i, j))
227 })
228 .collect();
229
230 let matching = bipartite_matching(req.required_capabilities.len(), backends.len(), &edges)
231 .unwrap_or_default();
232
233 let mut covered = vec![false; req.required_capabilities.len()];
234 let mut assignments = Vec::with_capacity(matching.size);
235
236 for (cap_idx, backend_idx) in &matching.pairs {
237 assignments.push(CapabilityAssignment {
238 capability: req.required_capabilities[*cap_idx].clone(),
239 backend_name: backends[*backend_idx].name().to_string(),
240 });
241 covered[*cap_idx] = true;
242 }
243
244 let unmatched = req
245 .required_capabilities
246 .iter()
247 .enumerate()
248 .filter(|(i, _)| !covered[*i])
249 .map(|(_, c)| c.clone())
250 .collect::<Vec<_>>();
251
252 let coverage_ratio = if req.required_capabilities.is_empty() {
253 1.0
254 } else {
255 matching.size as f64 / req.required_capabilities.len() as f64
256 };
257
258 ProviderAssignment {
259 request_id: req.id.clone(),
260 assignments,
261 unmatched,
262 coverage_ratio,
263 }
264}
265
266fn route_backend_requirements(
267 req: &ProviderRequest,
268 requirements: &BackendRequirements,
269 backends: &[Arc<dyn Backend>],
270) -> ProviderAssignment {
271 let required_capabilities = if requirements.required_capabilities.is_empty() {
272 req.required_capabilities.clone()
273 } else {
274 requirements.required_capabilities.clone()
275 };
276
277 let matched_backend = backends.iter().find(|backend| {
278 backend.kind() == requirements.kind
279 && required_capabilities
280 .iter()
281 .all(|capability| backend.has_capability(capability.clone()))
282 && (!requirements.requires_replay || backend.supports_replay())
283 && (!requirements.requires_offline || !backend.requires_network())
284 });
285
286 if let Some(backend) = matched_backend {
287 let assignments = required_capabilities
288 .iter()
289 .cloned()
290 .map(|capability| CapabilityAssignment {
291 capability,
292 backend_name: backend.name().to_string(),
293 })
294 .collect::<Vec<_>>();
295 return ProviderAssignment {
296 request_id: req.id.clone(),
297 assignments,
298 unmatched: Vec::new(),
299 coverage_ratio: 1.0,
300 };
301 }
302
303 let coverage_ratio = if required_capabilities.is_empty() {
304 1.0
305 } else {
306 0.0
307 };
308 ProviderAssignment {
309 request_id: req.id.clone(),
310 assignments: Vec::new(),
311 unmatched: required_capabilities,
312 coverage_ratio,
313 }
314}
315
316fn request_id(fact_id: &str) -> &str {
319 fact_id.trim_start_matches(REQUEST_PREFIX)
320}
321
322fn assignment_exists(ctx: &dyn Context, request_id: &str) -> bool {
323 let assignment_id = format!("{}{}", ASSIGNMENT_PREFIX, request_id);
324 ctx.get(ContextKey::Strategies)
325 .iter()
326 .any(|f| f.id().as_str() == assignment_id)
327}
328
329fn malformed_diagnostic_id(fact_id: &str) -> String {
330 format!("{MALFORMED_PREFIX}{fact_id}")
331}
332
333fn malformed_diagnostic_exists(ctx: &dyn Context, fact_id: &str) -> bool {
334 let diagnostic_id = malformed_diagnostic_id(fact_id);
335 ctx.get(ContextKey::Diagnostic)
336 .iter()
337 .any(|fact| fact.id().as_str() == diagnostic_id)
338}
339
#[cfg(test)]
mod tests {
    //! Routing tests.
    //!
    //! - Direct `route` calls cover capability matching and role-based
    //!   backend requirements.
    //! - An `Engine` run checks that a malformed request seed yields exactly
    //!   one diagnostic, including after a rerun.

    use super::*;
    use converge_core::{ContextState, Engine};
    use converge_pack::TextPayload;
    use converge_provider::{BackendKind, Capability};

    /// Test double whose `Backend` answers come straight from its fields.
    struct MockBackend {
        name: &'static str,
        kind: BackendKind,
        capabilities: Vec<Capability>,
        supports_replay: bool,
        requires_network: bool,
    }

    impl Backend for MockBackend {
        fn name(&self) -> &str {
            self.name
        }
        fn kind(&self) -> BackendKind {
            self.kind.clone()
        }
        fn capabilities(&self) -> Vec<Capability> {
            self.capabilities.clone()
        }
        fn supports_replay(&self) -> bool {
            self.supports_replay
        }
        fn requires_network(&self) -> bool {
            self.requires_network
        }
    }

    /// LLM backend with the given capabilities, no replay support, network required.
    fn backend(name: &'static str, caps: &[Capability]) -> Arc<dyn Backend> {
        backend_with(name, BackendKind::Llm, caps, false, true)
    }

    /// Fully configurable mock backend.
    fn backend_with(
        name: &'static str,
        kind: BackendKind,
        caps: &[Capability],
        supports_replay: bool,
        requires_network: bool,
    ) -> Arc<dyn Backend> {
        Arc::new(MockBackend {
            name,
            kind,
            capabilities: caps.to_vec(),
            supports_replay,
            requires_network,
        })
    }

    /// Plain capability request with no backend requirements, so `route`
    /// takes the bipartite-matching path.
    fn request(id: &str, caps: &[Capability]) -> ProviderRequest {
        ProviderRequest {
            id: id.to_string(),
            required_capabilities: caps.to_vec(),
            backend_requirements: None,
        }
    }

    #[test]
    fn full_coverage_when_all_capabilities_available() {
        // One dedicated backend per requested capability.
        let pool = vec![
            backend("anthropic", &[Capability::Reasoning]),
            backend("kong", &[Capability::AccessControl]),
            backend("elastic", &[Capability::FullTextSearch]),
        ];
        let req = request(
            "req-1",
            &[
                Capability::Reasoning,
                Capability::AccessControl,
                Capability::FullTextSearch,
            ],
        );

        let assignment = route(&req, &pool);

        assert_eq!(assignment.assignments.len(), 3);
        assert!(assignment.unmatched.is_empty());
        assert!((assignment.coverage_ratio - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn partial_coverage_when_capability_missing() {
        // No backend offers AccessControl, so half the slots stay unmatched.
        let pool = vec![backend("anthropic", &[Capability::Reasoning])];
        let req = request("req-2", &[Capability::Reasoning, Capability::AccessControl]);

        let assignment = route(&req, &pool);

        assert_eq!(assignment.assignments.len(), 1);
        assert_eq!(assignment.unmatched, vec![Capability::AccessControl]);
        assert!((assignment.coverage_ratio - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn no_double_booking_with_two_same_capability_slots() {
        // Two Reasoning slots must land on two different backends.
        let pool = vec![
            backend("anthropic", &[Capability::Reasoning]),
            backend("openai", &[Capability::Reasoning]),
        ];
        let req = request("req-3", &[Capability::Reasoning, Capability::Reasoning]);

        let assignment = route(&req, &pool);

        assert_eq!(assignment.assignments.len(), 2);
        let names: Vec<_> = assignment
            .assignments
            .iter()
            .map(|a| &a.backend_name)
            .collect();
        let unique: std::collections::HashSet<_> = names.iter().collect();
        assert_eq!(unique.len(), 2);
    }

    #[test]
    fn multi_capability_backend_can_only_fill_one_slot() {
        // A single backend offering both capabilities still fills only one slot.
        let pool = vec![backend(
            "all-in-one",
            &[Capability::Reasoning, Capability::AccessControl],
        )];
        let req = request("req-4", &[Capability::Reasoning, Capability::AccessControl]);

        let assignment = route(&req, &pool);

        assert_eq!(assignment.assignments.len(), 1);
        assert_eq!(assignment.unmatched.len(), 1);
    }

    #[test]
    fn empty_pool_yields_zero_coverage() {
        let req = request("req-5", &[Capability::Reasoning]);
        let assignment = route(&req, &[]);
        assert_eq!(assignment.coverage_ratio, 0.0);
        assert_eq!(assignment.unmatched, vec![Capability::Reasoning]);
    }

    #[test]
    fn empty_request_yields_full_coverage() {
        // Nothing requested counts as fully covered.
        let pool = vec![backend("anthropic", &[Capability::Reasoning])];
        let req = request("req-6", &[]);
        let assignment = route(&req, &pool);
        assert!((assignment.coverage_ratio - 1.0).abs() < f64::EPSILON);
        assert!(assignment.assignments.is_empty());
    }

    #[test]
    fn backend_requirements_select_one_backend_satisfying_role_constraints() {
        // `remote-llm` is an Llm that needs the network and lacks replay.
        // `local-policy` is a Policy backend that supports replay and runs
        // offline, matching `.with_replay().with_offline()`.
        let pool = vec![
            backend("remote-llm", &[Capability::AccessControl]),
            backend_with(
                "local-policy",
                BackendKind::Policy,
                &[Capability::AccessControl],
                true,
                false,
            ),
        ];
        // The request lists no capabilities of its own. Per the assertions
        // below, the capability comes from `access_policy()`; presumably it
        // supplies AccessControl (TODO confirm against BackendRequirements).
        let req = ProviderRequest {
            id: "policy-role".to_string(),
            required_capabilities: vec![],
            backend_requirements: Some(
                BackendRequirements::access_policy()
                    .with_replay()
                    .with_offline(),
            ),
        };

        let assignment = route(&req, &pool);

        assert_eq!(assignment.assignments.len(), 1);
        assert_eq!(assignment.assignments[0].backend_name, "local-policy");
        assert!(assignment.unmatched.is_empty());
        assert!((assignment.coverage_ratio - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn repeated_routing_is_deterministic_for_equal_candidates() {
        // Two interchangeable reasoners: routing twice must pick identically.
        let pool = vec![
            backend("reasoner-a", &[Capability::Reasoning]),
            backend("reasoner-b", &[Capability::Reasoning]),
            backend("policy-a", &[Capability::AccessControl]),
        ];
        let req = request(
            "req-7",
            &[
                Capability::Reasoning,
                Capability::Reasoning,
                Capability::AccessControl,
            ],
        );

        let first = route(&req, &pool);
        let second = route(&req, &pool);

        assert_eq!(first.assignments, second.assignments);
        assert_eq!(first.unmatched, second.unmatched);
        assert_eq!(first.coverage_ratio, second.coverage_ratio);
    }

    #[tokio::test]
    async fn malformed_request_emits_diagnostic_once() {
        let mut engine = Engine::new();
        engine.register_suggestor(ProviderSelectionSuggestor::new(vec![backend(
            "anthropic",
            &[Capability::Reasoning],
        )]));

        // The seed id carries the request prefix, but its payload is plain text
        // rather than a ProviderRequestPayload.
        let mut ctx = ContextState::new();
        ctx.add_proposal(ProposedFact::new(
            ContextKey::Seeds,
            "provider-request:broken",
            TextPayload::new("not a provider request"),
            CONVERGE_KERNEL_PROVENANCE.provenance(),
        ))
        .expect("seed should stage");

        // First run: exactly one diagnostic, keyed on the full seed id, and no
        // assignment.
        let first = engine.run(ctx).await.expect("run should converge");
        let diagnostics = first.context.get(ContextKey::Diagnostic);
        assert_eq!(diagnostics.len(), 1);
        assert_eq!(
            diagnostics[0].id(),
            "provider-request-error:provider-request:broken"
        );
        assert!(!first.context.has(ContextKey::Strategies));

        // Rerun on the converged context with a fresh engine: the existing
        // diagnostic must suppress a duplicate.
        let mut rerun_engine = Engine::new();
        rerun_engine.register_suggestor(ProviderSelectionSuggestor::new(vec![backend(
            "anthropic",
            &[Capability::Reasoning],
        )]));
        let second = rerun_engine
            .run(first.context.clone())
            .await
            .expect("rerun should converge");
        assert_eq!(second.context.get(ContextKey::Diagnostic).len(), 1);
    }
}