Skip to main content

llama_crab/sampling/
chain.rs

1//! [`SamplerChain`] builder.
2
3use super::LlamaSampler;
4
5/// Builder for a sampler chain.
6///
7/// Convenience wrapper around [`LlamaSampler::chain`] with a fluent API.
8///
9/// # Example
10///
11/// ```no_run
12/// use llama_crab::sampling::SamplerChain;
13/// let chain = SamplerChain::new()
14///     .temp(0.8)
15///     .top_p(0.95, 1)
16///     .build();
17/// # let _ = chain;
18/// ```
19#[derive(Debug, Default)]
20pub struct SamplerChain {
21    samplers: Vec<LlamaSampler>,
22    no_perf: bool,
23}
24
25impl SamplerChain {
26    /// Construct a new empty chain.
27    #[must_use]
28    pub const fn new() -> Self {
29        Self {
30            samplers: Vec::new(),
31            no_perf: false,
32        }
33    }
34
35    /// Disable performance counters in the chain.
36    #[must_use]
37    pub const fn with_no_perf(mut self, yes: bool) -> Self {
38        self.no_perf = yes;
39        self
40    }
41
42    /// Add a temperature stage.
43    #[must_use]
44    pub fn temp(mut self, t: f32) -> Self {
45        if let Some(s) = LlamaSampler::temp(t) {
46            self.samplers.push(s);
47        }
48        self
49    }
50
51    /// Add a top-K stage.
52    #[must_use]
53    pub fn top_k(mut self, k: i32) -> Self {
54        if let Some(s) = LlamaSampler::top_k(k) {
55            self.samplers.push(s);
56        }
57        self
58    }
59
60    /// Add a top-P stage.
61    #[must_use]
62    pub fn top_p(mut self, p: f32, min_keep: usize) -> Self {
63        if let Some(s) = LlamaSampler::top_p(p, min_keep) {
64            self.samplers.push(s);
65        }
66        self
67    }
68
69    /// Add a min-P stage.
70    #[must_use]
71    pub fn min_p(mut self, p: f32, min_keep: usize) -> Self {
72        if let Some(s) = LlamaSampler::min_p(p, min_keep) {
73            self.samplers.push(s);
74        }
75        self
76    }
77
78    /// Add a penalties stage.
79    #[must_use]
80    pub fn penalties(mut self, last_n: i32, repeat: f32, freq: f32, present: f32) -> Self {
81        if let Some(s) = LlamaSampler::penalties(last_n, repeat, freq, present) {
82            self.samplers.push(s);
83        }
84        self
85    }
86
87    /// Add a greedy sampler.
88    #[must_use]
89    pub fn greedy(mut self) -> Self {
90        if let Some(s) = LlamaSampler::greedy() {
91            self.samplers.push(s);
92        }
93        self
94    }
95
96    /// Consume the chain and return a single [`LlamaSampler`].
97    #[must_use]
98    pub fn build(self) -> Option<LlamaSampler> {
99        LlamaSampler::chain(self.samplers, self.no_perf)
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_chain_builds() {
        let chain = SamplerChain::new().build();
        assert!(chain.is_some());
    }

    #[test]
    fn new_matches_default() {
        let from_new = SamplerChain::new();
        let from_default = SamplerChain::default();
        assert!(from_new.samplers.is_empty());
        assert!(from_default.samplers.is_empty());
        assert!(!from_new.no_perf);
        assert_eq!(from_new.no_perf, from_default.no_perf);
    }

    #[test]
    fn with_no_perf_propagates() {
        // Check that the flag actually lands on the builder (and can be reset),
        // not merely that building succeeds.
        let builder = SamplerChain::new().with_no_perf(true);
        assert!(builder.no_perf);
        assert!(!builder.with_no_perf(false).no_perf);

        let chain = SamplerChain::new().with_no_perf(true).greedy().build();
        assert!(chain.is_some());
    }

    #[test]
    fn fluent_chain_with_multiple_stages() {
        // The builder skips any stage whose constructor returns `None`, so the
        // number of surviving stages depends on the backend. This checks that
        // chaining several stages never panics and never queues more stages
        // than were requested, and that the result still builds.
        let builder = SamplerChain::new()
            .temp(0.8)
            .top_p(0.95, 1)
            .penalties(64, 1.1, 0.0, 0.0);
        assert!(builder.samplers.len() <= 3);
        assert!(builder.build().is_some());
    }
}
132}