use crate::BoxFuture;
/// Parameters for a single sampling call.
///
/// Only `prompt` is mandatory. Optional settings start as `None` and are
/// filled in with the chained `with_*` methods, e.g.
/// `SamplingRequest::new("hi").with_max_tokens(64)`.
///
/// The type is `#[non_exhaustive]`, so code outside this crate cannot use a
/// struct literal and must start from [`SamplingRequest::new`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct SamplingRequest {
    /// Optional system prompt; `None` until [`SamplingRequest::with_system`] is called.
    pub system: Option<String>,
    /// Prompt text supplied to [`SamplingRequest::new`].
    pub prompt: String,
    /// Optional token limit; how a provider treats `None` is not defined here.
    pub max_tokens: Option<u32>,
    /// Optional sampling temperature; stored as given, without range validation.
    pub temperature: Option<f32>,
}

impl SamplingRequest {
    /// Creates a request carrying only `prompt`; every optional field is `None`.
    ///
    /// Marked `#[must_use]` like the builder methods below: constructing a
    /// request and dropping it on the floor is always a mistake.
    #[must_use]
    pub fn new(prompt: impl Into<String>) -> Self {
        Self {
            system: None,
            prompt: prompt.into(),
            max_tokens: None,
            temperature: None,
        }
    }

    /// Returns the request with `system` set, replacing any previous value.
    #[must_use]
    pub fn with_system(mut self, system: impl Into<String>) -> Self {
        self.system = Some(system.into());
        self
    }

    /// Returns the request with `max_tokens` set, replacing any previous value.
    #[must_use]
    pub const fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Returns the request with `temperature` set, replacing any previous value.
    ///
    /// The value is stored verbatim; no range or NaN check is performed.
    #[must_use]
    pub const fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }
}
/// Successful outcome of a sampling call.
///
/// The type is `#[non_exhaustive]`, so code outside this crate cannot build
/// it with a struct literal; use [`SamplingResponse::new`] instead.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct SamplingResponse {
    /// Text returned by the provider.
    pub text: String,
    /// Provider-supplied reason the generation stopped; the set of possible
    /// values is not defined in this file.
    pub stop_reason: String,
}

impl SamplingResponse {
    /// Creates a response from its text and stop reason.
    ///
    /// Without this constructor the `#[non_exhaustive]` attribute would make
    /// the type impossible to construct from another crate, so an external
    /// provider could never return `Ok(..)`.
    #[must_use]
    pub fn new(text: impl Into<String>, stop_reason: impl Into<String>) -> Self {
        Self {
            text: text.into(),
            stop_reason: stop_reason.into(),
        }
    }
}
/// Failure modes of a sampling call.
///
/// The `Display` text of each variant comes from its `#[error(...)]`
/// attribute. The enum is `#[non_exhaustive]`, so matches outside this crate
/// need a wildcard arm.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum SamplingError {
/// No sampling provider is configured; this is what `NoopSamplingProvider`
/// returns from `sample`.
#[error("no sampling provider configured")]
NotAvailable,
/// The model refused the request. The payload is appended to the message
/// (presumably the refusal reason; confirm against callers).
#[error("model refused: {0}")]
Refused(String),
/// The sampling call timed out.
#[error("sampling timed out")]
Timeout,
/// Any other failure; the payload is appended to the message.
#[error("sampling error: {0}")]
Other(String),
}
/// Interface for a backend that can execute a [`SamplingRequest`].
///
/// Implementors must be `Send + Sync`. The trait is usable as
/// `dyn SamplingProvider` (the tests in this file call it through a trait
/// object).
pub trait SamplingProvider: Send + Sync {
/// Reports whether this provider can service requests.
/// `NoopSamplingProvider` always returns `false`.
fn is_available(&self) -> bool;
/// Consumes `request` and returns a boxed future resolving to the response
/// or a [`SamplingError`].
///
/// The `'_` lifetime ties the returned future to the borrow of `self`, so
/// the provider must outlive the future.
fn sample(
&self,
request: SamplingRequest,
) -> BoxFuture<'_, Result<SamplingResponse, SamplingError>>;
}
/// Placeholder provider for when no real sampling backend exists.
///
/// Its `SamplingProvider` implementation reports itself as unavailable and
/// fails every `sample` call with `SamplingError::NotAvailable`.
///
/// The type is a stateless unit struct, so it derives the common traits:
/// `Debug` (matching every other public type in this file), `Clone`/`Copy`,
/// and `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoopSamplingProvider;
// No-op implementation: permanently unavailable, and every request fails.
impl SamplingProvider for NoopSamplingProvider {
    /// Always `false`: there is no backend behind this provider.
    fn is_available(&self) -> bool {
        false
    }

    /// Discards the request and resolves to [`SamplingError::NotAvailable`].
    fn sample(
        &self,
        _request: SamplingRequest,
    ) -> BoxFuture<'_, Result<SamplingResponse, SamplingError>> {
        // Build the fixed failure first, then hand it to a future that
        // completes on its first poll.
        let outcome: Result<SamplingResponse, SamplingError> = Err(SamplingError::NotAvailable);
        Box::pin(async move { outcome })
    }
}
#[cfg(test)]
// Test-only code: these lints are relaxed for the whole module.
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;

    // Calling the no-op provider through its concrete type must report
    // unavailability and fail with `NotAvailable`.
    #[tokio::test]
    async fn noop_provider_returns_not_available() {
        let provider = NoopSamplingProvider;
        assert!(!provider.is_available());

        let outcome = provider.sample(SamplingRequest::new("test")).await;
        assert!(matches!(outcome, Err(SamplingError::NotAvailable)));
    }

    // Every builder method must land its value in the matching field.
    #[test]
    fn sampling_request_builder() {
        let built = SamplingRequest::new("hello")
            .with_system("sys")
            .with_max_tokens(100)
            .with_temperature(0.7);

        assert_eq!(built.prompt, "hello");
        assert_eq!(built.system.as_deref(), Some("sys"));
        assert_eq!(built.max_tokens, Some(100));

        // Tolerance-based float comparison; a missing temperature falls back
        // to 0.0 and therefore fails the check.
        let temperature = built.temperature.unwrap_or(0.0);
        assert!((temperature - 0.7).abs() < f32::EPSILON);
    }

    // Same checks as the first test, but through `&dyn SamplingProvider`,
    // which only compiles while the trait can be made into a trait object.
    #[tokio::test]
    async fn noop_provider_is_object_safe() {
        let provider: &dyn SamplingProvider = &NoopSamplingProvider;
        assert!(!provider.is_available());

        let outcome = provider.sample(SamplingRequest::new("test")).await;
        assert!(matches!(outcome, Err(SamplingError::NotAvailable)));
    }
}