use crate::{ModelDescriptor, ModelMetaProbe, ProbeError};
/// Pairs a `primary` probe with a `fallback` probe so the two can be queried
/// one after the other.
///
/// The `tolerant` flag selects what happens when the primary probe returns an
/// error: either the error is handed back to the caller, or the fallback is
/// consulted instead.
#[derive(Clone, Debug)]
pub struct ChainedProbe<P, F> {
    /// Probe that is always asked first.
    primary: P,
    /// Probe that is asked when the primary has no answer.
    fallback: F,
    /// `true` when an error from the primary should lead to the fallback
    /// being consulted rather than the error being returned.
    tolerant: bool,
}
impl<P, F> ChainedProbe<P, F> {
    /// Creates a strict chain (`tolerant` is `false`): an error from the
    /// primary probe is returned to the caller as-is.
    pub fn new(primary: P, fallback: F) -> Self {
        Self::with_tolerance(primary, fallback, false)
    }

    /// Creates a tolerant chain (`tolerant` is `true`): an error from the
    /// primary probe causes the fallback probe to be consulted.
    pub fn tolerant(primary: P, fallback: F) -> Self {
        Self::with_tolerance(primary, fallback, true)
    }

    /// Shared constructor body for [`Self::new`] and [`Self::tolerant`].
    fn with_tolerance(primary: P, fallback: F, tolerant: bool) -> Self {
        Self {
            primary,
            fallback,
            tolerant,
        }
    }
}
impl<P, F> ModelMetaProbe for ChainedProbe<P, F>
where
    P: ModelMetaProbe,
    F: ModelMetaProbe,
{
    /// Looks up `model` through the primary probe, then the fallback.
    ///
    /// * A descriptor from the primary is returned immediately.
    /// * `Ok(None)` from the primary defers to the fallback's result.
    /// * An error from the primary is returned unchanged by a strict chain;
    ///   a tolerant chain logs it at `warn` level and returns whatever the
    ///   fallback produces.
    async fn describe(&self, model: &str) -> Result<Option<ModelDescriptor>, ProbeError> {
        let primary_err = match self.primary.describe(model).await {
            Ok(Some(descriptor)) => return Ok(Some(descriptor)),
            Ok(None) => return self.fallback.describe(model).await,
            Err(failure) => failure,
        };

        // Strict chains surface the primary's failure without touching the fallback.
        if !self.tolerant {
            return Err(primary_err);
        }

        tracing::warn!(
            target: "rig_model_meta::chained",
            error = %primary_err,
            model = model,
            "primary probe failed; consulting fallback"
        );
        self.fallback.describe(model).await
    }
}
// Unit tests for `ChainedProbe`; unwrapping is allowed here because a failed
// unwrap is simply a failed test.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic, clippy::indexing_slicing)]
mod tests {
    use super::*;
    use crate::{ModelDescriptor, StubProbe};

    /// Stub probe that only knows the `"local"` model (context window 8192).
    fn primary() -> StubProbe {
        let descriptor = ModelDescriptor::builder("ollama", "local")
            .context_window(8192)
            .build();
        StubProbe::new([("local", descriptor)])
    }

    /// Stub probe that only knows the `"cloud"` model (context window 128 000).
    fn fallback() -> StubProbe {
        let descriptor = ModelDescriptor::builder("openai", "cloud")
            .context_window(128_000)
            .build();
        StubProbe::new([("cloud", descriptor)])
    }

    #[tokio::test]
    async fn primary_match_wins() {
        let chain = ChainedProbe::new(primary(), fallback());
        let outcome = chain.describe("local").await;
        let found = outcome.unwrap().unwrap();
        assert_eq!(found.context_window, Some(8192));
    }

    #[tokio::test]
    async fn falls_back_when_primary_returns_none() {
        let chain = ChainedProbe::new(primary(), fallback());
        let outcome = chain.describe("cloud").await;
        let found = outcome.unwrap().unwrap();
        assert_eq!(found.context_window, Some(128_000));
    }

    #[tokio::test]
    async fn both_miss_returns_none() {
        let chain = ChainedProbe::new(primary(), fallback());
        let outcome = chain.describe("nope").await.unwrap();
        assert!(outcome.is_none());
    }

    /// Probe whose every lookup fails with a transport error.
    #[derive(Clone)]
    struct AlwaysFails;

    impl ModelMetaProbe for AlwaysFails {
        async fn describe(&self, _model: &str) -> Result<Option<ModelDescriptor>, ProbeError> {
            Err(ProbeError::Transport(String::from("nope")))
        }
    }

    #[tokio::test]
    async fn strict_chain_propagates_primary_error() {
        let chain = ChainedProbe::new(AlwaysFails, fallback());
        let outcome = chain.describe("cloud").await;
        let err = outcome.unwrap_err();
        assert!(matches!(err, ProbeError::Transport(_)));
    }

    #[tokio::test]
    async fn tolerant_chain_consults_fallback_after_primary_error() {
        let chain = ChainedProbe::tolerant(AlwaysFails, fallback());
        let outcome = chain.describe("cloud").await;
        let found = outcome.unwrap().unwrap();
        assert_eq!(found.context_window, Some(128_000));
    }
}