use serde::{Deserialize, Serialize};
use crate::backend::BackendKind;
use crate::capability::Capability;
use crate::error::BackendError;
/// Describes what a caller needs from a backend.
///
/// A value of this type is handed by reference to
/// `BackendSelector::select`. Instances are normally created with
/// `BackendRequirements::new` and refined through the `with_*` builder
/// methods or one of the preset constructors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BackendRequirements {
/// Category of backend being requested.
pub kind: BackendKind,
/// Capabilities accumulated through `with_capability`, in insertion order.
/// Duplicates are not removed.
pub required_capabilities: Vec<Capability>,
/// Highest acceptable cost class. `new` starts at `CostClass::VeryHigh`.
pub max_cost_class: CostClass,
/// Latency bound in milliseconds. `new` initialises this to `0`.
/// NOTE(review): presumably `0` means "no limit" — confirm against the
/// selector implementations, which are not visible here.
pub max_latency_ms: u32,
/// Data-residency constraint. `new` starts at `DataSovereignty::Any`.
pub data_sovereignty: DataSovereignty,
/// Compliance constraint. `new` starts at `ComplianceLevel::None`.
pub compliance: ComplianceLevel,
/// Set to `true` by `with_replay`; `false` otherwise.
/// NOTE(review): the exact replay guarantee is defined by the backends —
/// confirm.
pub requires_replay: bool,
/// Set to `true` by `with_offline`; `false` otherwise.
pub requires_offline: bool,
}
impl BackendRequirements {
    /// Builds the starting requirements for a backend of `kind`.
    ///
    /// No capabilities are required, the cost ceiling is
    /// [`CostClass::VeryHigh`], sovereignty is [`DataSovereignty::Any`],
    /// compliance is [`ComplianceLevel::None`], both `requires_*` flags are
    /// `false`, and `max_latency_ms` is `0`.
    #[must_use]
    pub fn new(kind: BackendKind) -> Self {
        Self {
            kind,
            required_capabilities: Vec::new(),
            requires_replay: false,
            requires_offline: false,
            max_latency_ms: 0,
            max_cost_class: CostClass::VeryHigh,
            data_sovereignty: DataSovereignty::Any,
            compliance: ComplianceLevel::None,
        }
    }

    /// Appends `capability` to the required capabilities.
    ///
    /// Calling this repeatedly accumulates entries; nothing is deduplicated.
    #[must_use]
    pub fn with_capability(mut self, capability: Capability) -> Self {
        let capabilities = &mut self.required_capabilities;
        capabilities.push(capability);
        self
    }

    /// Replaces the cost ceiling with `cost`.
    #[must_use]
    pub fn with_max_cost(self, cost: CostClass) -> Self {
        Self {
            max_cost_class: cost,
            ..self
        }
    }

    /// Replaces the latency bound with `ms` milliseconds.
    #[must_use]
    pub fn with_max_latency_ms(self, ms: u32) -> Self {
        Self {
            max_latency_ms: ms,
            ..self
        }
    }

    /// Replaces the data-sovereignty constraint with `sovereignty`.
    #[must_use]
    pub fn with_data_sovereignty(self, sovereignty: DataSovereignty) -> Self {
        Self {
            data_sovereignty: sovereignty,
            ..self
        }
    }

    /// Replaces the compliance constraint with `compliance`.
    #[must_use]
    pub fn with_compliance(self, compliance: ComplianceLevel) -> Self {
        Self { compliance, ..self }
    }

    /// Switches the `requires_replay` flag on.
    #[must_use]
    pub fn with_replay(self) -> Self {
        Self {
            requires_replay: true,
            ..self
        }
    }

    /// Switches the `requires_offline` flag on.
    #[must_use]
    pub fn with_offline(self) -> Self {
        Self {
            requires_offline: true,
            ..self
        }
    }

    /// Preset: an LLM backend with text generation, cost capped at
    /// [`CostClass::Low`] and a 2000 ms latency bound.
    #[must_use]
    pub fn fast_llm() -> Self {
        let text_llm = Self::new(BackendKind::Llm).with_capability(Capability::TextGeneration);
        text_llm
            .with_max_cost(CostClass::Low)
            .with_max_latency_ms(2000)
    }

    /// Preset: an LLM backend with text generation and reasoning, cost capped
    /// at [`CostClass::High`] and a 30 000 ms latency bound.
    #[must_use]
    pub fn reasoning_llm() -> Self {
        let reasoning_llm = Self::new(BackendKind::Llm)
            .with_capability(Capability::TextGeneration)
            .with_capability(Capability::Reasoning);
        reasoning_llm
            .with_max_cost(CostClass::High)
            .with_max_latency_ms(30_000)
    }

    /// Preset: a policy backend with access control and a 100 ms latency
    /// bound.
    #[must_use]
    pub fn access_policy() -> Self {
        let policy = Self::new(BackendKind::Policy).with_capability(Capability::AccessControl);
        policy.with_max_latency_ms(100)
    }

    /// Preset: an optimization backend with constraint solving.
    #[must_use]
    pub fn constraint_solver() -> Self {
        let base = Self::new(BackendKind::Optimization);
        base.with_capability(Capability::ConstraintSolving)
    }

    /// Preset: an analytics backend with embedding support.
    #[must_use]
    pub fn embedding_pipeline() -> Self {
        let base = Self::new(BackendKind::Analytics);
        base.with_capability(Capability::Embedding)
    }

    /// Preset: a search backend with vector search.
    #[must_use]
    pub fn vector_search() -> Self {
        let base = Self::new(BackendKind::Search);
        base.with_capability(Capability::VectorSearch)
    }
}
/// Strategy for choosing a backend that fits a set of requirements.
///
/// Implementors must be `Send + Sync`.
pub trait BackendSelector: Send + Sync {
/// Picks a backend for `requirements`.
///
/// Returns a `String` on success — presumably an identifier or name of the
/// chosen backend; confirm against implementors.
///
/// # Errors
///
/// Returns a `BackendError` when selection fails. NOTE(review): the concrete
/// failure conditions are decided by each implementation and are not visible
/// here.
fn select(&self, requirements: &BackendRequirements) -> Result<String, BackendError>;
}
/// Coarse cost tier of a backend.
///
/// Declaration order is significant: the derived `PartialOrd`/`Ord` compare
/// variants by their position, so `Free` is the smallest and `VeryHigh` the
/// largest. Reordering the variants changes every comparison.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum CostClass {
/// Lowest tier.
Free,
/// Tier above `Free`.
VeryLow,
/// Tier above `VeryLow`.
Low,
/// Tier above `Low`.
Medium,
/// Tier above `Medium`.
High,
/// Highest tier.
VeryHigh,
}
impl CostClass {
    /// Lists every cost class that does not exceed `self`.
    ///
    /// The result is ordered from `Free` upwards and always includes `self`.
    #[must_use]
    pub fn allowed_classes(self) -> Vec<CostClass> {
        // All tiers, cheapest first; the comparison below relies on the
        // derived ordering of the enum.
        const ASCENDING: [CostClass; 6] = [
            CostClass::Free,
            CostClass::VeryLow,
            CostClass::Low,
            CostClass::Medium,
            CostClass::High,
            CostClass::VeryHigh,
        ];

        let mut permitted = Vec::new();
        for &tier in &ASCENDING {
            if tier <= self {
                permitted.push(tier);
            }
        }
        permitted
    }
}
/// Data-residency constraint attached to a `BackendRequirements`.
///
/// NOTE(review): how each variant is matched against a backend is decided by
/// the selector implementations, which are not visible here; the per-variant
/// descriptions below are inferred from the names — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DataSovereignty {
/// The value `BackendRequirements::new` starts with; presumably "no
/// residency restriction".
Any,
/// Presumably: data must stay within the European Union.
EU,
/// Presumably: data must stay within the United States.
US,
/// Presumably: data must stay within Switzerland.
Switzerland,
/// Presumably: data must stay within China.
China,
/// Presumably: data must stay on infrastructure operated by the caller.
OnPremises,
}
/// Compliance constraint attached to a `BackendRequirements`.
///
/// Exactly one variant is stored per requirements value. NOTE(review): how a
/// variant is enforced is decided by the selector implementations, which are
/// not visible here; the per-variant descriptions are inferred from the names
/// — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ComplianceLevel {
/// The value `BackendRequirements::new` starts with; presumably "no
/// compliance requirement".
None,
/// Presumably: EU General Data Protection Regulation.
GDPR,
/// Presumably: US Health Insurance Portability and Accountability Act.
HIPAA,
/// Presumably: SOC 2 attestation.
SOC2,
/// Presumably: the backend must be able to explain its outputs.
HighExplainability,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every cost class compares strictly below the next one in the list.
    #[test]
    fn cost_class_ordering() {
        let ascending = [
            CostClass::Free,
            CostClass::VeryLow,
            CostClass::Low,
            CostClass::Medium,
            CostClass::High,
            CostClass::VeryHigh,
        ];
        for pair in ascending.windows(2) {
            assert!(pair[0] < pair[1]);
        }
    }

    /// `allowed_classes` returns the tiers up to and including the receiver.
    #[test]
    fn allowed_classes_correct() {
        let only_free = CostClass::Free.allowed_classes();
        assert_eq!(only_free, vec![CostClass::Free]);

        let up_to_low = CostClass::Low.allowed_classes();
        let expected_low = vec![CostClass::Free, CostClass::VeryLow, CostClass::Low];
        assert_eq!(up_to_low, expected_low);

        let everything = CostClass::VeryHigh.allowed_classes();
        assert_eq!(everything.len(), 6);
    }

    /// Builder calls end up in the corresponding fields.
    #[test]
    fn requirements_builder() {
        let BackendRequirements {
            kind,
            required_capabilities,
            max_cost_class,
            max_latency_ms,
            ..
        } = BackendRequirements::new(BackendKind::Llm)
            .with_capability(Capability::TextGeneration)
            .with_capability(Capability::Reasoning)
            .with_max_cost(CostClass::Medium)
            .with_max_latency_ms(5000);

        assert_eq!(kind, BackendKind::Llm);
        assert_eq!(required_capabilities.len(), 2);
        assert_eq!(max_cost_class, CostClass::Medium);
        assert_eq!(max_latency_ms, 5000);
    }

    /// Presets carry the expected kind, cost ceiling and capabilities.
    #[test]
    fn preset_constructors() {
        let fast = BackendRequirements::fast_llm();
        assert_eq!(fast.kind, BackendKind::Llm);
        assert_eq!(fast.max_cost_class, CostClass::Low);

        let policy = BackendRequirements::access_policy();
        assert_eq!(policy.kind, BackendKind::Policy);
        let has_access_control = policy
            .required_capabilities
            .iter()
            .any(|capability| *capability == Capability::AccessControl);
        assert!(has_access_control);

        let solver = BackendRequirements::constraint_solver();
        assert_eq!(solver.kind, BackendKind::Optimization);
    }
}