use std::sync::Arc;
use synwire_core::{BoxFuture, SamplingError, SamplingProvider, SamplingRequest, SamplingResponse};
/// Type-erased, shareable sampling callback.
///
/// The callback takes a [`SamplingRequest`] and returns a boxed `'static`
/// future that resolves to a [`SamplingResponse`] or a [`SamplingError`].
/// The `Send + Sync` bounds on the closure make the surrounding `Arc`
/// itself `Send + Sync`.
pub type SamplingFn = Arc<
    dyn Fn(SamplingRequest) -> BoxFuture<'static, Result<SamplingResponse, SamplingError>>
        + Send
        + Sync
        + 'static,
>;
/// A [`SamplingProvider`] that forwards requests to an in-process closure.
///
/// Holds an optional [`SamplingFn`]. When no closure is configured,
/// `is_available` reports `false` and `sample` resolves to
/// [`SamplingError::NotAvailable`].
///
/// Cloning is cheap: the closure sits behind an `Arc`, so clones share it
/// rather than duplicating it.
#[derive(Clone)]
pub struct DirectModelSampling {
    /// The sampling callback, or `None` when sampling is disabled.
    invoke: Option<SamplingFn>,
}

// Manual impl because `dyn Fn` has no `Debug`. Only report whether a
// callback is configured.
impl std::fmt::Debug for DirectModelSampling {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DirectModelSampling")
            .field("available", &self.invoke.is_some())
            .finish()
    }
}
impl DirectModelSampling {
#[must_use]
pub fn new(
invoke: impl Fn(SamplingRequest) -> BoxFuture<'static, Result<SamplingResponse, SamplingError>>
+ Send
+ Sync
+ 'static,
) -> Self {
Self {
invoke: Some(Arc::new(invoke)),
}
}
#[must_use]
pub const fn unavailable() -> Self {
Self { invoke: None }
}
}
impl SamplingProvider for DirectModelSampling {
    /// Returns `true` exactly when a sampling callback was configured.
    fn is_available(&self) -> bool {
        matches!(self.invoke, Some(_))
    }

    /// Passes `request` to the configured callback and returns its future.
    ///
    /// Without a callback, returns a future that resolves immediately to
    /// [`SamplingError::NotAvailable`].
    fn sample(
        &self,
        request: SamplingRequest,
    ) -> BoxFuture<'_, Result<SamplingResponse, SamplingError>> {
        let Some(callback) = self.invoke.as_ref() else {
            return Box::pin(async { Err(SamplingError::NotAvailable) });
        };
        callback(request)
    }
}