use crate::error::Result;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
/// Token-sampling operations over logits tensors, generic over the runtime `R`.
///
/// The trait only declares signatures; every behavioural detail is defined by
/// the implementer for a given runtime.
///
/// NOTE(review): tensor shapes and dtypes are not visible from this trait.
/// Presumably `logits` holds vocabulary scores and `token_ids` /
/// `token_counts` are parallel tensors describing previously seen tokens —
/// confirm against the implementations.
pub trait SamplingOps<R: Runtime> {
    /// Applies repetition, frequency and presence penalties for the tokens
    /// listed in `token_ids` / `token_counts`.
    ///
    /// Returns `Ok(())` on success, so no new tensor is handed back.
    /// NOTE(review): this suggests `logits` is updated in place even though it
    /// is taken by shared reference — confirm with the backend implementations.
    ///
    /// # Errors
    ///
    /// Returns the crate error type when the implementation fails.
    fn apply_sampling_penalties(
        &self,
        logits: &Tensor<R>,
        token_ids: &Tensor<R>,
        token_counts: &Tensor<R>,
        repeat_penalty: f32,
        frequency_penalty: f32,
        presence_penalty: f32,
    ) -> Result<()>;

    /// Selects one token from `logits` using the given `temperature`, `top_k`,
    /// `top_p` and `min_p` settings, and returns it as a `u32`.
    ///
    /// NOTE(review): presumably the returned value is a vocabulary index, and
    /// the neutral / "disabled" value of each setting is implementation
    /// defined — confirm.
    ///
    /// # Errors
    ///
    /// Returns the crate error type when the implementation fails.
    fn sample_token(
        &self,
        logits: &Tensor<R>,
        temperature: f32,
        top_k: usize,
        top_p: f32,
        min_p: f32,
    ) -> Result<u32>;

    /// Takes the penalty inputs of [`SamplingOps::apply_sampling_penalties`]
    /// and the sampling settings of [`SamplingOps::sample_token`] in a single
    /// call, and returns the result as a tensor rather than a `u32`.
    ///
    /// NOTE(review): presumably this is a fused penalties-then-sample path
    /// whose result tensor holds the chosen token; `num_unique` presumably
    /// gives the number of valid entries in `token_ids` / `token_counts`, and
    /// `seed` presumably seeds the random draw (the meaning of `None` is not
    /// visible here) — confirm all three.
    ///
    /// # Errors
    ///
    /// Returns the crate error type when the implementation fails.
    // Twelve parameters plus `self` exceed clippy's default argument limit;
    // the signature is part of the trait contract, so the lint is silenced.
    #[allow(clippy::too_many_arguments)]
    fn logits_to_token(
        &self,
        logits: &Tensor<R>,
        token_ids: &Tensor<R>,
        token_counts: &Tensor<R>,
        num_unique: usize,
        repeat_penalty: f32,
        frequency_penalty: f32,
        presence_penalty: f32,
        temperature: f32,
        top_k: usize,
        top_p: f32,
        min_p: f32,
        seed: Option<u64>,
    ) -> Result<Tensor<R>>;
}