llama_crab/sampling/
chain.rs1use super::LlamaSampler;
4
5#[derive(Debug, Default)]
20pub struct SamplerChain {
21 samplers: Vec<LlamaSampler>,
22 no_perf: bool,
23}
24
25impl SamplerChain {
26 #[must_use]
28 pub const fn new() -> Self {
29 Self {
30 samplers: Vec::new(),
31 no_perf: false,
32 }
33 }
34
35 #[must_use]
37 pub const fn with_no_perf(mut self, yes: bool) -> Self {
38 self.no_perf = yes;
39 self
40 }
41
42 #[must_use]
44 pub fn temp(mut self, t: f32) -> Self {
45 if let Some(s) = LlamaSampler::temp(t) {
46 self.samplers.push(s);
47 }
48 self
49 }
50
51 #[must_use]
53 pub fn top_k(mut self, k: i32) -> Self {
54 if let Some(s) = LlamaSampler::top_k(k) {
55 self.samplers.push(s);
56 }
57 self
58 }
59
60 #[must_use]
62 pub fn top_p(mut self, p: f32, min_keep: usize) -> Self {
63 if let Some(s) = LlamaSampler::top_p(p, min_keep) {
64 self.samplers.push(s);
65 }
66 self
67 }
68
69 #[must_use]
71 pub fn min_p(mut self, p: f32, min_keep: usize) -> Self {
72 if let Some(s) = LlamaSampler::min_p(p, min_keep) {
73 self.samplers.push(s);
74 }
75 self
76 }
77
78 #[must_use]
80 pub fn penalties(mut self, last_n: i32, repeat: f32, freq: f32, present: f32) -> Self {
81 if let Some(s) = LlamaSampler::penalties(last_n, repeat, freq, present) {
82 self.samplers.push(s);
83 }
84 self
85 }
86
87 #[must_use]
89 pub fn greedy(mut self) -> Self {
90 if let Some(s) = LlamaSampler::greedy() {
91 self.samplers.push(s);
92 }
93 self
94 }
95
96 #[must_use]
98 pub fn build(self) -> Option<LlamaSampler> {
99 LlamaSampler::chain(self.samplers, self.no_perf)
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    /// A builder with no stages must still produce a chain.
    #[test]
    fn empty_chain_builds() {
        assert!(SamplerChain::new().build().is_some());
    }

    /// Building succeeds with `no_perf` enabled and a single greedy stage.
    ///
    /// Only the build result is checked; the flag itself has no accessor here,
    /// so its propagation is not directly observable in this test.
    #[test]
    fn with_no_perf_propagates() {
        let builder = SamplerChain::new().with_no_perf(true).greedy();
        assert!(builder.build().is_some());
    }

    /// Several stages added through the fluent API build into one chain.
    #[test]
    fn fluent_chain_with_multiple_stages() {
        let builder = SamplerChain::new()
            .temp(0.8)
            .top_p(0.95, 1)
            .penalties(64, 1.1, 0.0, 0.0);
        assert!(builder.build().is_some());
    }
}