use crate::router::{
error::{Error, Result},
forest::{self, RandomForest, ForestParams},
strategies::RoutingStrategy,
types::{ModelInfo, RoutingDecision},
};
/// Routing strategy that uses a trained random forest to choose between two
/// fixed model/provider pairs based on a query embedding.
///
/// The forest's `predict_proba` output is interpreted as the probability that
/// `model_a` should handle the query: queries with a probability at or above
/// `threshold` go to `model_a`, all others go to `model_b`.
pub struct RoRFStrategy {
    /// Trained classifier queried with the embedding on every `route` call.
    forest: RandomForest,
    /// Decision cut-off on P(model_a); `new` asserts it lies in `[0, 1]`.
    threshold: f32,
    /// Model selected when P(model_a) >= `threshold`.
    model_a: String,
    /// Provider paired with `model_a`.
    provider_a: String,
    /// Model selected when P(model_a) < `threshold`.
    model_b: String,
    /// Provider paired with `model_b`.
    provider_b: String,
}
impl RoRFStrategy {
pub fn new(
forest: RandomForest,
threshold: f32,
model_a: impl Into<String>,
provider_a: impl Into<String>,
model_b: impl Into<String>,
provider_b: impl Into<String>,
) -> Self {
assert!(
(0.0..=1.0).contains(&threshold),
"threshold must be in [0, 1]"
);
Self {
forest,
threshold,
model_a: model_a.into(),
provider_a: provider_a.into(),
model_b: model_b.into(),
provider_b: provider_b.into(),
}
}
pub fn train(
features: &[Vec<f32>],
labels: &[u8],
threshold: f32,
model_a: impl Into<String>,
provider_a: impl Into<String>,
model_b: impl Into<String>,
provider_b: impl Into<String>,
params: ForestParams,
) -> Result<Self> {
let trained = forest::train(features, labels, ¶ms)?;
Ok(Self::new(
trained,
threshold,
model_a,
provider_a,
model_b,
provider_b,
))
}
fn confidence(&self, prob_a: f32) -> f32 {
((prob_a - self.threshold).abs() * 2.0).clamp(0.0, 1.0)
}
pub fn forest(&self) -> &RandomForest {
&self.forest
}
pub fn threshold(&self) -> f32 {
self.threshold
}
}
impl RoutingStrategy for RoRFStrategy {
fn name(&self) -> &'static str {
"rorf"
}
fn route(
&self,
_content: &str,
embedding: Option<&[f32]>,
_models: &[ModelInfo],
) -> Result<RoutingDecision> {
let emb = embedding.ok_or_else(|| {
Error::config("RoRF strategy requires a query embedding")
})?;
let prob_a = self.forest.predict_proba(emb)?;
let prob_b = 1.0 - prob_a;
let (model, provider) = if prob_a >= self.threshold {
(&self.model_a, &self.provider_a)
} else {
(&self.model_b, &self.provider_b)
};
Ok(RoutingDecision::new(model, provider)
.with_reasoning(format!(
"RoRF routing (P(model_a)={:.3}, threshold={:.3})",
prob_a, self.threshold
))
.with_confidence(self.confidence(prob_a))
.with_meta("prob_model_a", prob_a)
.with_meta("prob_model_b", prob_b)
.with_meta("threshold", self.threshold))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::router::forest::ForestParams;
    use crate::router::forest::tree::TreeParams;

    /// Number of evenly spaced points in the toy training set.
    const N_POINTS: usize = 20;

    /// Toy 1-D dataset: `x = i / 20` for `i in 0..20`, labelled `0` for the
    /// lower half (`i < 10`) and `1` for the upper half.
    fn simple_data() -> (Vec<Vec<f32>>, Vec<u8>) {
        (0..N_POINTS)
            .map(|i| {
                let x = i as f32 / N_POINTS as f32;
                let label = u8::from(i >= N_POINTS / 2);
                (vec![x], label)
            })
            .unzip()
    }

    /// Trains a small forest (10 trees, max depth 5) on `simple_data`.
    /// "strong-model" is model A and "weak-model" is model B.
    fn trained_strategy(threshold: f32) -> RoRFStrategy {
        let (features, labels) = simple_data();
        let tree = TreeParams {
            max_depth: Some(5),
            ..Default::default()
        };
        let params = ForestParams {
            n_estimators: 10,
            tree,
            ..Default::default()
        };
        RoRFStrategy::train(
            &features,
            &labels,
            threshold,
            "strong-model",
            "provider-a",
            "weak-model",
            "provider-b",
            params,
        )
        .unwrap()
    }

    /// Trains a strategy at `threshold` and routes one embedding through it.
    fn decide(threshold: f32, embedding: &[f32]) -> Result<RoutingDecision> {
        trained_strategy(threshold).route("q", Some(embedding), &[])
    }

    #[test]
    fn routes_simple_region_to_model_a() {
        let decision = decide(0.5, &[0.1]).unwrap();
        assert_eq!(decision.model, "strong-model");
    }

    #[test]
    fn routes_complex_region_to_model_b() {
        let decision = decide(0.5, &[0.9]).unwrap();
        assert_eq!(decision.model, "weak-model");
    }

    #[test]
    fn no_embedding_returns_error() {
        let strategy = trained_strategy(0.5);
        let err = strategy.route("q", None, &[]).unwrap_err();
        assert!(matches!(err, Error::Configuration(_)));
    }

    #[test]
    fn wrong_dimension_returns_error() {
        let err = decide(0.5, &[0.5, 0.5]).unwrap_err();
        assert!(matches!(err, Error::Forest(_)));
    }

    #[test]
    fn confidence_is_set() {
        let decision = decide(0.5, &[0.1]).unwrap();
        assert!(decision.confidence.is_some());
    }

    #[test]
    fn metadata_contains_probabilities() {
        let decision = decide(0.5, &[0.1]).unwrap();
        for key in ["prob_model_a", "prob_model_b"] {
            assert!(decision.metadata.contains_key(key));
        }
    }

    #[test]
    fn high_threshold_routes_to_model_b_more() {
        let decision = decide(0.99, &[0.1]).unwrap();
        assert_eq!(decision.model, "weak-model");
    }

    #[test]
    fn forest_accessor() {
        assert_eq!(trained_strategy(0.5).forest().n_features, 1);
    }

    #[test]
    fn name_is_rorf() {
        assert_eq!(trained_strategy(0.5).name(), "rorf");
    }
}