wax-core 0.1.0

Core inference engine for wax, a small Candle-based local LLM runner
Documentation
use candle_core::Tensor;
use candle_transformers::generation::{LogitsProcessor, Sampling};

use crate::{Result, WaxError};

/// Decoding parameters used to build a [`Sampler`].
///
/// Call [`SamplingConfig::validate`] (building a [`Sampler`] does this for you) before use;
/// the per-field docs list the ranges it enforces.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SamplingConfig {
    /// Softmax temperature. Must be finite and `>= 0`; `0.0` selects greedy (arg-max) decoding.
    pub temperature: f64,
    /// Restrict sampling to the `k` most likely tokens. `Some(0)` is rejected.
    pub top_k: Option<usize>,
    /// Nucleus (top-p) threshold. Must be finite and within `0.0..=1.0`.
    pub top_p: Option<f64>,
    /// Penalty applied to logits of recently seen tokens. Must be finite and `> 0`;
    /// `1.0` disables the penalty.
    pub repetition_penalty: f32,
    /// Number of most recent context tokens the repetition penalty considers.
    pub repeat_last_n: usize,
    /// Seed for the sampling RNG so that sampled runs are reproducible.
    pub seed: u64,
}

impl Default for SamplingConfig {
    fn default() -> Self {
        Self {
            temperature: 0.0,
            top_k: None,
            top_p: None,
            repetition_penalty: 1.0,
            repeat_last_n: 128,
            seed: 299_792_458,
        }
    }
}

impl SamplingConfig {
    /// Temperatures below this threshold are decoded greedily.
    ///
    /// Dividing logits by a vanishingly small temperature can overflow `f32` and turn the
    /// softmax into NaNs; candle's own `LogitsProcessor::new` uses the same `1e-7` cut-off.
    const GREEDY_TEMPERATURE_THRESHOLD: f64 = 1e-7;

    /// Checks that every field lies within its accepted range.
    ///
    /// # Errors
    ///
    /// Returns [`WaxError::InvalidRequest`] when `temperature` is not finite or is negative,
    /// `top_k` is `Some(0)`, `top_p` is not finite or lies outside `0.0..=1.0`, or
    /// `repetition_penalty` is not finite or is `<= 0`.
    pub fn validate(&self) -> Result<()> {
        if !self.temperature.is_finite() || self.temperature < 0.0 {
            return Err(WaxError::InvalidRequest(
                "temperature must be finite and >= 0".to_string(),
            ));
        }
        if matches!(self.top_k, Some(0)) {
            return Err(WaxError::InvalidRequest("top-k must be > 0".to_string()));
        }
        if let Some(top_p) = self.top_p {
            if !top_p.is_finite() || !(0.0..=1.0).contains(&top_p) {
                return Err(WaxError::InvalidRequest(
                    "top-p must be finite and between 0 and 1".to_string(),
                ));
            }
        }
        if !self.repetition_penalty.is_finite() || self.repetition_penalty <= 0.0 {
            return Err(WaxError::InvalidRequest(
                "repetition penalty must be finite and > 0".to_string(),
            ));
        }
        Ok(())
    }

    /// Validates the configuration and builds a seeded candle [`LogitsProcessor`] for it.
    ///
    /// # Errors
    ///
    /// Returns the same errors as [`SamplingConfig::validate`].
    pub fn processor(&self) -> Result<LogitsProcessor> {
        self.validate()?;
        Ok(LogitsProcessor::from_sampling(self.seed, self.sampling()))
    }

    /// Maps this configuration onto candle's [`Sampling`] strategy.
    ///
    /// Assumes the configuration has already been validated.
    fn sampling(&self) -> Sampling {
        let temperature = self.temperature;
        if temperature < Self::GREEDY_TEMPERATURE_THRESHOLD {
            return Sampling::ArgMax;
        }

        // A top-p of 0 or 1 keeps the whole distribution (this is how candle's
        // `Sampling::TopP` treats such values). Dropping it here gives the same meaning when
        // it is combined with top-k, instead of handing a degenerate `p` to
        // `TopKThenTopP`.
        let top_p = self.top_p.filter(|p| *p > 0.0 && *p < 1.0);

        match (self.top_k, top_p) {
            (None, None) => Sampling::All { temperature },
            (Some(k), None) => Sampling::TopK { k, temperature },
            (None, Some(p)) => Sampling::TopP { p, temperature },
            (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
        }
    }
}

/// Stateful next-token sampler.
///
/// Optionally applies a repetition penalty, then draws a token with a seeded candle
/// [`LogitsProcessor`] built from its [`SamplingConfig`].
pub struct Sampler {
    config: SamplingConfig,
    processor: LogitsProcessor,
}

impl Sampler {
    /// Validates `config` and builds the underlying logits processor.
    ///
    /// # Errors
    ///
    /// Returns the error produced by [`SamplingConfig::processor`], i.e. validation failures.
    pub fn new(config: SamplingConfig) -> Result<Self> {
        let processor = config.processor()?;
        Ok(Self { config, processor })
    }

    /// Picks the next token id from `logits`.
    ///
    /// `tokens` is the context so far. Only its last `repeat_last_n` entries are penalised,
    /// and only when `repetition_penalty != 1.0`.
    /// NOTE(review): `apply_repeat_penalty` presumably requires `logits` to be a 1-D,
    /// vocab-sized tensor. Confirm with callers.
    ///
    /// # Errors
    ///
    /// Propagates candle errors from applying the penalty or from sampling.
    pub fn sample(&mut self, logits: &Tensor, tokens: &[u32]) -> Result<u32> {
        let start_at = tokens.len().saturating_sub(self.config.repeat_last_n);
        let window = &tokens[start_at..];

        // No penalty, or an empty window, makes the penalty a no-op. Skip it rather than
        // paying for candle copying the full logits vector out and back on every step.
        if self.config.repetition_penalty == 1.0 || window.is_empty() {
            return Ok(self.processor.sample(logits)?);
        }

        let penalized = candle_transformers::utils::apply_repeat_penalty(
            logits,
            self.config.repetition_penalty,
            window,
        )?;
        Ok(self.processor.sample(&penalized)?)
    }
}

#[cfg(test)]
mod tests {
    use candle_core::{Device, Tensor};

    use super::{Sampler, SamplingConfig};

    #[test]
    fn greedy_selects_argmax() {
        let logits = Tensor::new(&[0.1f32, 4.0, 0.2], &Device::Cpu).unwrap();
        let mut sampler = Sampler::new(SamplingConfig {
            temperature: 0.0,
            ..SamplingConfig::default()
        })
        .unwrap();

        let token = sampler.sample(&logits, &[]).unwrap();

        assert_eq!(token, 1);
    }

    #[test]
    fn seeded_sampling_is_deterministic() {
        let logits = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0], &Device::Cpu).unwrap();
        let config = SamplingConfig {
            temperature: 0.8,
            top_k: Some(3),
            top_p: Some(0.9),
            seed: 42,
            ..SamplingConfig::default()
        };
        let mut left = Sampler::new(config).unwrap();
        let mut right = Sampler::new(config).unwrap();

        let left_token = left.sample(&logits, &[]).unwrap();
        let right_token = right.sample(&logits, &[]).unwrap();

        assert_eq!(left_token, right_token);
    }

    #[test]
    fn repetition_penalty_demotes_recent_tokens() {
        // Greedy would pick token 0, but halving its logit (4.0 -> 2.0) lets token 1 win.
        let logits = Tensor::new(&[4.0f32, 3.9, 0.0], &Device::Cpu).unwrap();
        let mut sampler = Sampler::new(SamplingConfig {
            repetition_penalty: 2.0,
            ..SamplingConfig::default()
        })
        .unwrap();

        let token = sampler.sample(&logits, &[0]).unwrap();

        assert_eq!(token, 1);
    }

    #[test]
    fn repetition_penalty_only_covers_last_n_tokens() {
        // With a window of 1 only token 1 is penalised. If token 0 were penalised too,
        // token 2 (3.8) would win.
        let logits = Tensor::new(&[4.0f32, 3.9, 3.8], &Device::Cpu).unwrap();
        let mut sampler = Sampler::new(SamplingConfig {
            repetition_penalty: 2.0,
            repeat_last_n: 1,
            ..SamplingConfig::default()
        })
        .unwrap();

        let token = sampler.sample(&logits, &[0, 1]).unwrap();

        assert_eq!(token, 0);
    }

    #[test]
    fn default_config_is_valid() {
        assert!(SamplingConfig::default().validate().is_ok());
    }

    #[test]
    fn rejects_invalid_top_k() {
        let err = SamplingConfig {
            top_k: Some(0),
            ..SamplingConfig::default()
        }
        .validate()
        .unwrap_err();

        assert!(err.to_string().contains("top-k"));
    }

    #[test]
    fn rejects_invalid_temperature() {
        for temperature in [-0.5, f64::NAN, f64::INFINITY] {
            let err = SamplingConfig {
                temperature,
                ..SamplingConfig::default()
            }
            .validate()
            .unwrap_err();

            assert!(err.to_string().contains("temperature"));
        }
    }

    #[test]
    fn rejects_invalid_top_p() {
        for top_p in [-0.1, 1.5, f64::NAN] {
            let err = SamplingConfig {
                top_p: Some(top_p),
                ..SamplingConfig::default()
            }
            .validate()
            .unwrap_err();

            assert!(err.to_string().contains("top-p"));
        }
    }

    #[test]
    fn rejects_non_positive_repetition_penalty() {
        for repetition_penalty in [0.0f32, -1.0, f32::NAN] {
            let err = SamplingConfig {
                repetition_penalty,
                ..SamplingConfig::default()
            }
            .validate()
            .unwrap_err();

            assert!(err.to_string().contains("repetition penalty"));
        }
    }
}