use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use crate::{ModelDescriptor, ModelMetaProbe, ProbeError};
/// Type-erased future returned by [`ModelMetaProbeDyn::describe_boxed`].
///
/// The future is heap-allocated and pinned so it can be returned from a
/// method callable through `dyn`. It is `Send` so it may be moved between
/// threads. It is bounded by `'a` because it may borrow both the probe and
/// the model name it was asked about.
pub type ProbeFuture<'a> = Pin<
    Box<dyn Future<Output = Result<Option<ModelDescriptor>, ProbeError>> + Send + 'a>,
>;
/// Dyn-compatible (object-safe) counterpart of [`ModelMetaProbe`].
///
/// `ModelMetaProbe::describe` produces an opaque future, which keeps that
/// trait from being used as a trait object. This trait exposes the same
/// lookup through a boxed [`ProbeFuture`] instead. That lets heterogeneous
/// probes be stored as `Box<dyn ModelMetaProbeDyn>` or
/// `Arc<dyn ModelMetaProbeDyn>`.
///
/// Implementors of [`ModelMetaProbe`] receive this trait automatically via a
/// blanket impl.
pub trait ModelMetaProbeDyn
where
    Self: Send + Sync,
{
    /// Describes `model`, returning the lookup as a boxed, pinned future.
    ///
    /// The returned future may borrow both `self` and `model` for `'a`.
    fn describe_boxed<'a>(&'a self, model: &'a str) -> ProbeFuture<'a>;
}
/// Blanket bridge from the static trait to the dynamic one.
///
/// Any [`ModelMetaProbe`], sized or not, becomes a [`ModelMetaProbeDyn`] by
/// pinning its `describe` future on the heap.
impl<P: ModelMetaProbe + ?Sized> ModelMetaProbeDyn for P {
    fn describe_boxed<'a>(&'a self, model: &'a str) -> ProbeFuture<'a> {
        let pending = <P as ModelMetaProbe>::describe(self, model);
        // Unsized coercion: Pin<Box<concrete future>> -> ProbeFuture<'a>.
        Box::pin(pending)
    }
}
/// Cloneable, type-erased [`ModelMetaProbe`].
///
/// Wraps an `Arc<dyn ModelMetaProbeDyn>` and implements [`ModelMetaProbe`]
/// itself. A probe chosen at runtime can therefore be passed anywhere a
/// statically typed probe is expected, for example into `ChainedProbe`.
/// Cloning only bumps the `Arc` reference count.
#[derive(Clone)]
pub struct DynProbe {
    inner: Arc<dyn ModelMetaProbeDyn + 'static>,
}
impl DynProbe {
    /// Wraps a probe that has already been erased to a trait object.
    #[must_use]
    pub fn new(inner: Arc<dyn ModelMetaProbeDyn>) -> Self {
        Self { inner }
    }

    /// Erases a concrete probe and wraps it.
    ///
    /// Accepts any [`ModelMetaProbeDyn`], which includes every
    /// [`ModelMetaProbe`] implementor through the blanket impl. Callers no
    /// longer need to build the `Arc<dyn ModelMetaProbeDyn>` by hand before
    /// calling [`DynProbe::new`].
    #[must_use]
    pub fn from_probe<P>(probe: P) -> Self
    where
        P: ModelMetaProbeDyn + 'static,
    {
        // `Arc<P>` unsizes to `Arc<dyn ModelMetaProbeDyn>` at the call site.
        Self::new(Arc::new(probe))
    }
}
/// Opaque `Debug` output.
///
/// The wrapped probe is a trait object whose trait has no `Debug` bound, so
/// only the type name is printed (rendered as `DynProbe { .. }`).
impl std::fmt::Debug for DynProbe {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("DynProbe");
        builder.finish_non_exhaustive()
    }
}
/// Lifts the erased probe back into the static trait.
///
/// Each call forwards to [`ModelMetaProbeDyn::describe_boxed`] on the
/// wrapped probe, which lets `DynProbe` compose with code generic over
/// [`ModelMetaProbe`].
impl ModelMetaProbe for DynProbe {
    async fn describe(&self, model: &str) -> Result<Option<ModelDescriptor>, ProbeError> {
        let pending = self.inner.describe_boxed(model);
        pending.await
    }
}
#[cfg(test)]
// Tests report failure by panicking, so unwrapping and indexing are acceptable here.
#[allow(clippy::unwrap_used, clippy::panic, clippy::indexing_slicing)]
mod tests {
    use super::*;
    use crate::{ChainedProbe, ModelDescriptor, StubProbe};

    /// Stub probe seeded with a single entry: `gpt-4o` (provider `openai`)
    /// with a 128 000-token context window.
    fn gpt4o_stub() -> StubProbe {
        let descriptor = ModelDescriptor::builder("openai", "gpt-4o")
            .context_window(128_000)
            .build();
        StubProbe::new([("gpt-4o", descriptor)])
    }

    /// The blanket impl lets a concrete probe sit behind `Box<dyn ModelMetaProbeDyn>`.
    #[tokio::test]
    async fn blanket_impl_satisfies_dyn() {
        let probes: Vec<Box<dyn ModelMetaProbeDyn>> = vec![Box::new(gpt4o_stub())];
        let found = probes[0].describe_boxed("gpt-4o").await.unwrap();
        let desc = found.unwrap();
        assert_eq!(desc.context_window, Some(128_000));
    }

    /// An erased probe wrapped in `DynProbe` answers through the static trait.
    #[tokio::test]
    async fn dyn_probe_round_trip_through_static_trait() {
        let lifted = DynProbe::new(Arc::new(gpt4o_stub()));
        let found = lifted.describe("gpt-4o").await.unwrap();
        let desc = found.unwrap();
        assert_eq!(desc.context_window, Some(128_000));
    }

    /// `DynProbe` values plug into `ChainedProbe`. The empty primary stub
    /// presumably yields no entry, so the fallback supplies the descriptor.
    #[tokio::test]
    async fn dyn_probe_composes_with_chained() {
        let chained = ChainedProbe::new(
            DynProbe::new(Arc::new(StubProbe::default())),
            DynProbe::new(Arc::new(gpt4o_stub())),
        );
        let found = chained.describe("gpt-4o").await.unwrap();
        let desc = found.unwrap();
        assert_eq!(desc.context_window, Some(128_000));
    }
}