1pub mod json_mode;
27
28pub use ferrum_interfaces::sampler::{
30 GreedySampler, LogitsProcessor, LogitsProcessorChain, MultiSampler, MultinomialSampler,
31 ProcessorPriority, RepetitionPenaltyProcessor, Sampler, SamplingConfig, SamplingConfigBuilder,
32 SamplingContext, SamplingStats, TemperatureProcessor, TopKProcessor, TopPProcessor,
33};
34
35pub use ferrum_types::{Result, SamplingParams, TokenId};
37
38use rand::RngCore;
39use std::collections::HashMap;
40
/// Stateless factory that turns [`SamplingParams`] into samplers, sampling
/// configurations and pipelines.
///
/// The type is a unit struct: it carries no data, so every instance is
/// interchangeable.
#[derive(Clone, Debug, Default)]
pub struct DefaultSamplerFactory;
44
45impl DefaultSamplerFactory {
46 pub fn new() -> Self {
48 Self
49 }
50
51 pub fn build_config(&self, params: &SamplingParams) -> SamplingConfig {
53 SamplingConfig::from_params(params)
54 }
55
56 pub fn create_sampler(&self, params: &SamplingParams) -> Box<dyn Sampler + Send + Sync> {
60 if params.temperature == 0.0 {
61 Box::new(GreedySampler)
62 } else {
63 Box::new(MultinomialSampler)
64 }
65 }
66
67 pub fn build_pipeline(&self, params: &SamplingParams) -> SamplingPipeline {
69 let config = self.build_config(params);
70 SamplingPipeline { config }
71 }
72}
73
/// Wrapper around a [`SamplingConfig`] that exposes token-sampling entry
/// points.
pub struct SamplingPipeline {
    /// Configuration that performs the actual sampling; built from
    /// `SamplingParams` when the pipeline is constructed.
    config: SamplingConfig,
}
81
82impl SamplingPipeline {
83 pub fn new(params: &SamplingParams) -> Self {
85 let config = SamplingConfig::from_params(params);
86 Self { config }
87 }
88
89 pub fn config(&self) -> &SamplingConfig {
91 &self.config
92 }
93
94 pub fn sample_next(
104 &self,
105 step: usize,
106 logits: &mut [f32],
107 previous_tokens: &[TokenId],
108 token_frequencies: &HashMap<TokenId, usize>,
109 sampling_params: &SamplingParams,
110 rng: &mut dyn RngCore,
111 ) -> Result<TokenId> {
112 let vocab_size = logits.len();
113 let ctx = SamplingContext::new(
114 step,
115 sampling_params,
116 logits,
117 previous_tokens,
118 token_frequencies,
119 vocab_size,
120 );
121 self.config.sample(ctx, rng)
122 }
123
124 pub fn sample_simple(&self, logits: &mut [f32], rng: &mut dyn RngCore) -> Result<TokenId> {
126 let params = SamplingParams::default();
127 let empty_tokens = Vec::new();
128 let empty_freqs = HashMap::new();
129 self.sample_next(0, logits, &empty_tokens, &empty_freqs, ¶ms, rng)
130 }
131}
132
133pub fn build_sampling_config(params: &SamplingParams) -> SamplingConfig {
141 SamplingConfig::from_params(params)
142}
143
144pub fn sampler_from_params(params: &SamplingParams) -> Box<dyn Sampler + Send + Sync> {
148 DefaultSamplerFactory::new().create_sampler(params)
149}
150
151pub fn pipeline_from_params(params: &SamplingParams) -> SamplingPipeline {
153 DefaultSamplerFactory::new().build_pipeline(params)
154}
155
156pub fn greedy_sampler() -> Box<dyn Sampler + Send + Sync> {
158 Box::new(GreedySampler)
159}
160
161pub fn multinomial_sampler() -> Box<dyn Sampler + Send + Sync> {
163 Box::new(MultinomialSampler)
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    /// A temperature of exactly zero must yield a deterministic sampler.
    #[test]
    fn test_factory_creates_greedy_for_zero_temp() {
        let factory = DefaultSamplerFactory::new();
        let params = SamplingParams {
            temperature: 0.0,
            ..Default::default()
        };
        let sampler = factory.create_sampler(&params);
        assert!(sampler.is_deterministic());
    }

    /// A non-zero temperature must yield a non-deterministic sampler.
    #[test]
    fn test_factory_creates_multinomial_for_nonzero_temp() {
        let factory = DefaultSamplerFactory::new();
        let params = SamplingParams {
            temperature: 1.0,
            ..Default::default()
        };
        let sampler = factory.create_sampler(&params);
        assert!(!sampler.is_deterministic());
    }

    /// Temperature, top-k, top-p and repetition penalty are expected to
    /// contribute four processors to the chain.
    #[test]
    fn test_build_sampling_config() {
        let params = SamplingParams {
            temperature: 0.8,
            top_k: Some(50),
            top_p: 0.95,
            repetition_penalty: 1.1,
            ..Default::default()
        };
        let config = build_sampling_config(&params);
        assert_eq!(config.processor_chain.processor_names().len(), 4);
    }

    /// With greedy parameters, the largest logit (index 1) must be selected.
    #[test]
    fn test_pipeline_sample_simple() {
        let params = SamplingParams::greedy();
        let pipeline = pipeline_from_params(&params);
        let mut rng = StdRng::seed_from_u64(42);

        let mut logits = vec![1.0, 5.0, 2.0, 0.5];
        let token = pipeline.sample_simple(&mut logits, &mut rng).unwrap();

        assert_eq!(token.get(), 1);
    }

    /// The greedy convenience constructor reports itself as deterministic.
    #[test]
    fn test_greedy_sampler_deterministic() {
        let sampler = greedy_sampler();
        assert!(sampler.is_deterministic());
        assert_eq!(sampler.name(), "greedy");
    }

    /// The multinomial convenience constructor reports itself as stochastic.
    #[test]
    fn test_multinomial_sampler_stochastic() {
        let sampler = multinomial_sampler();
        assert!(!sampler.is_deterministic());
        assert_eq!(sampler.name(), "multinomial");
    }

    /// Sampling with history and a repetition penalty must still return a
    /// token id inside the four-entry vocabulary.
    #[test]
    fn test_pipeline_with_context() {
        let params = SamplingParams {
            temperature: 1.0,
            repetition_penalty: 1.2,
            ..Default::default()
        };
        let pipeline = SamplingPipeline::new(&params);
        let mut rng = StdRng::seed_from_u64(42);

        let mut logits = vec![1.0, 2.0, 3.0, 2.0];
        let previous_tokens = vec![TokenId::new(2)];
        let mut freqs = HashMap::new();
        freqs.insert(TokenId::new(2), 1);

        let token = pipeline
            .sample_next(0, &mut logits, &previous_tokens, &freqs, &params, &mut rng)
            .unwrap();

        assert!(token.get() < 4);
    }
}