Skip to main content

oxibonsai_runtime/
builders.rs

1//! Builder patterns for ergonomic OxiBonsai setup.
2//!
3//! Three builders are provided for validating and constructing the main
4//! runtime objects:
5//!
6//! - [`SamplerBuilder`] — validates and creates a [`Sampler`]
7//! - [`ConfigBuilder`] — validates and creates an [`OxiBonsaiConfig`]
8//! - [`EngineBuilder`] — orchestrates config + sampler together
9
10use crate::config::OxiBonsaiConfig;
11use crate::error::{RuntimeError, RuntimeResult};
12use crate::sampling::{Sampler, SamplingParams};
13
/// Builder for sampling parameters with validation.
///
/// All setters are infallible; range checks happen once in
/// [`SamplerBuilder::build`].
///
/// # Example
///
/// ```
/// use oxibonsai_runtime::builders::SamplerBuilder;
///
/// let sampler = SamplerBuilder::new()
///     .temperature(0.5)
///     .top_k(50)
///     .top_p(0.95)
///     .repetition_penalty(1.2)
///     .seed(123)
///     .build()
///     .expect("valid parameters");
///
/// let params = sampler.params();
/// assert!((params.temperature - 0.5).abs() < f32::EPSILON);
/// assert_eq!(params.top_k, 50);
/// ```
#[derive(Debug, Clone)]
pub struct SamplerBuilder {
    /// Softmax temperature; validated to be finite and `>= 0.0` at build time.
    temperature: f32,
    /// Top-k cutoff; `0` disables top-k filtering.
    top_k: usize,
    /// Nucleus (top-p) threshold; validated to lie in `[0.0, 1.0]`.
    top_p: f32,
    /// Repetition penalty; validated to be finite and `>= 1.0`.
    repetition_penalty: f32,
    /// Seed handed to the sampler's RNG constructor.
    seed: u64,
}
41
42impl SamplerBuilder {
43    /// Create a new sampler builder with default values.
44    pub fn new() -> Self {
45        Self {
46            temperature: 0.7,
47            top_k: 40,
48            top_p: 0.9,
49            repetition_penalty: 1.1,
50            seed: 42,
51        }
52    }
53
54    /// Set the temperature for softmax scaling. Must be >= 0.
55    pub fn temperature(mut self, t: f32) -> Self {
56        self.temperature = t;
57        self
58    }
59
60    /// Set top-k filtering. 0 = disabled.
61    pub fn top_k(mut self, k: usize) -> Self {
62        self.top_k = k;
63        self
64    }
65
66    /// Set top-p (nucleus) threshold. Must be in [0.0, 1.0].
67    pub fn top_p(mut self, p: f32) -> Self {
68        self.top_p = p;
69        self
70    }
71
72    /// Set repetition penalty. Must be >= 1.0.
73    pub fn repetition_penalty(mut self, rp: f32) -> Self {
74        self.repetition_penalty = rp;
75        self
76    }
77
78    /// Set the random seed.
79    pub fn seed(mut self, s: u64) -> Self {
80        self.seed = s;
81        self
82    }
83
84    /// Validate parameters and build the [`Sampler`].
85    pub fn build(self) -> RuntimeResult<Sampler> {
86        if self.temperature < 0.0 {
87            return Err(RuntimeError::Config(format!(
88                "temperature must be >= 0.0, got {}",
89                self.temperature
90            )));
91        }
92        if self.top_p < 0.0 || self.top_p > 1.0 {
93            return Err(RuntimeError::Config(format!(
94                "top_p must be in [0.0, 1.0], got {}",
95                self.top_p
96            )));
97        }
98        if self.repetition_penalty < 1.0 {
99            return Err(RuntimeError::Config(format!(
100                "repetition_penalty must be >= 1.0, got {}",
101                self.repetition_penalty
102            )));
103        }
104
105        let params = SamplingParams {
106            temperature: self.temperature,
107            top_k: self.top_k,
108            top_p: self.top_p,
109            repetition_penalty: self.repetition_penalty,
110            max_tokens: SamplingParams::default().max_tokens,
111        };
112
113        Ok(Sampler::new(params, self.seed))
114    }
115}
116
117impl Default for SamplerBuilder {
118    fn default() -> Self {
119        Self::new()
120    }
121}
122
123/// Builder for the [`OxiBonsaiConfig`].
124pub struct ConfigBuilder {
125    config: OxiBonsaiConfig,
126}
127
128impl ConfigBuilder {
129    /// Create a new config builder with default values.
130    pub fn new() -> Self {
131        Self {
132            config: OxiBonsaiConfig::default(),
133        }
134    }
135
136    /// Set the path to the GGUF model file.
137    pub fn model_path(mut self, path: impl Into<String>) -> Self {
138        self.config.model.model_path = Some(path.into());
139        self
140    }
141
142    /// Set the path to the tokenizer.json file.
143    pub fn tokenizer_path(mut self, path: impl Into<String>) -> Self {
144        self.config.model.tokenizer_path = Some(path.into());
145        self
146    }
147
148    /// Set the maximum sequence length (prompt + generated).
149    pub fn max_seq_len(mut self, len: usize) -> Self {
150        self.config.model.max_seq_len = len;
151        self
152    }
153
154    /// Set the server bind host address.
155    pub fn host(mut self, h: impl Into<String>) -> Self {
156        self.config.server.host = h.into();
157        self
158    }
159
160    /// Set the server bind port.
161    pub fn port(mut self, p: u16) -> Self {
162        self.config.server.port = p;
163        self
164    }
165
166    /// Set the log level filter (e.g. "info", "debug", "warn").
167    pub fn log_level(mut self, level: impl Into<String>) -> Self {
168        self.config.observability.log_level = level.into();
169        self
170    }
171
172    /// Enable or disable JSON-formatted logs.
173    pub fn json_logs(mut self, enabled: bool) -> Self {
174        self.config.observability.json_logs = enabled;
175        self
176    }
177
178    /// Set the sampling temperature.
179    pub fn temperature(mut self, t: f32) -> Self {
180        self.config.sampling.temperature = t;
181        self
182    }
183
184    /// Set the top-k sampling parameter.
185    pub fn top_k(mut self, k: usize) -> Self {
186        self.config.sampling.top_k = k;
187        self
188    }
189
190    /// Set the top-p (nucleus) sampling parameter.
191    pub fn top_p(mut self, p: f32) -> Self {
192        self.config.sampling.top_p = p;
193        self
194    }
195
196    /// Set the repetition penalty.
197    pub fn repetition_penalty(mut self, rp: f32) -> Self {
198        self.config.sampling.repetition_penalty = rp;
199        self
200    }
201
202    /// Set the maximum tokens to generate.
203    pub fn max_tokens(mut self, n: usize) -> Self {
204        self.config.sampling.max_tokens = n;
205        self
206    }
207
208    /// Validate and build the [`OxiBonsaiConfig`].
209    pub fn build(self) -> RuntimeResult<OxiBonsaiConfig> {
210        self.config.validate()?;
211        Ok(self.config)
212    }
213}
214
215impl Default for ConfigBuilder {
216    fn default() -> Self {
217        Self::new()
218    }
219}
220
221/// Builder for the inference engine (high-level orchestrator).
222///
223/// Validates configuration and sampling parameters together.
224/// Cannot create an actual engine without a GGUF file, but returns
225/// the validated config and sampler ready for engine construction.
226pub struct EngineBuilder {
227    config: Option<OxiBonsaiConfig>,
228    sampler: Option<SamplerBuilder>,
229    kernel_tier: Option<String>,
230}
231
232impl EngineBuilder {
233    /// Create a new engine builder.
234    pub fn new() -> Self {
235        Self {
236            config: None,
237            sampler: None,
238            kernel_tier: None,
239        }
240    }
241
242    /// Set the configuration directly.
243    pub fn config(mut self, config: OxiBonsaiConfig) -> Self {
244        self.config = Some(config);
245        self
246    }
247
248    /// Load configuration from a TOML file.
249    pub fn config_file(mut self, path: &str) -> RuntimeResult<Self> {
250        let config = OxiBonsaiConfig::load(std::path::Path::new(path))?;
251        self.config = Some(config);
252        Ok(self)
253    }
254
255    /// Set a custom sampler builder.
256    pub fn sampler(mut self, builder: SamplerBuilder) -> Self {
257        self.sampler = Some(builder);
258        self
259    }
260
261    /// Set the preferred kernel tier (e.g. "reference", "avx2", "neon").
262    pub fn kernel_tier(mut self, tier: &str) -> Self {
263        self.kernel_tier = Some(tier.to_string());
264        self
265    }
266
267    /// Get the configured kernel tier name, if any.
268    pub fn configured_kernel_tier(&self) -> Option<&str> {
269        self.kernel_tier.as_deref()
270    }
271
272    /// Validate and build the config + sampler pair.
273    ///
274    /// Returns the validated configuration and sampler, ready for
275    /// engine construction once a GGUF file is available.
276    pub fn build(self) -> RuntimeResult<(OxiBonsaiConfig, Sampler)> {
277        let config = self.config.unwrap_or_default();
278        config.validate()?;
279
280        let sampler = match self.sampler {
281            Some(builder) => builder.build()?,
282            None => {
283                // Build sampler from config's sampling parameters
284                SamplerBuilder::new()
285                    .temperature(config.sampling.temperature)
286                    .top_k(config.sampling.top_k)
287                    .top_p(config.sampling.top_p)
288                    .repetition_penalty(config.sampling.repetition_penalty)
289                    .build()?
290            }
291        };
292
293        Ok((config, sampler))
294    }
295}
296
297impl Default for EngineBuilder {
298    fn default() -> Self {
299        Self::new()
300    }
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// Removes the wrapped file on drop, so a failing assertion cannot leave
    /// stale files behind in the shared system temp directory.
    struct TempFileGuard(std::path::PathBuf);

    impl Drop for TempFileGuard {
        fn drop(&mut self) {
            // Best-effort cleanup: the file may not exist if the write failed.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    // ── SamplerBuilder tests ──
    //
    // `expect`/`expect_err` are used directly (instead of a preceding
    // `assert!(x.is_ok())`) so a failure reports the actual error value.

    #[test]
    fn sampler_builder_defaults() {
        let sampler = SamplerBuilder::new()
            .build()
            .expect("default build should succeed");
        let params = sampler.params();
        assert!((params.temperature - 0.7).abs() < f32::EPSILON);
        assert_eq!(params.top_k, 40);
        assert!((params.top_p - 0.9).abs() < f32::EPSILON);
        assert!((params.repetition_penalty - 1.1).abs() < f32::EPSILON);
    }

    #[test]
    fn sampler_builder_chain() {
        let sampler = SamplerBuilder::new()
            .temperature(0.5)
            .top_k(50)
            .top_p(0.95)
            .repetition_penalty(1.2)
            .seed(123)
            .build()
            .expect("chained build should succeed");
        let params = sampler.params();
        assert!((params.temperature - 0.5).abs() < f32::EPSILON);
        assert_eq!(params.top_k, 50);
        assert!((params.top_p - 0.95).abs() < f32::EPSILON);
        assert!((params.repetition_penalty - 1.2).abs() < f32::EPSILON);
    }

    #[test]
    fn sampler_builder_negative_temperature() {
        let err = SamplerBuilder::new()
            .temperature(-0.1)
            .build()
            .expect_err("negative temperature should fail");
        assert!(err.to_string().contains("temperature"));
    }

    #[test]
    fn sampler_builder_invalid_top_p_high() {
        let err = SamplerBuilder::new()
            .top_p(1.5)
            .build()
            .expect_err("top_p > 1 should fail");
        assert!(err.to_string().contains("top_p"));
    }

    #[test]
    fn sampler_builder_invalid_top_p_low() {
        let result = SamplerBuilder::new().top_p(-0.1).build();
        assert!(result.is_err());
    }

    #[test]
    fn sampler_builder_invalid_repetition_penalty() {
        let err = SamplerBuilder::new()
            .repetition_penalty(0.5)
            .build()
            .expect_err("rep_pen < 1 should fail");
        assert!(err.to_string().contains("repetition_penalty"));
    }

    #[test]
    fn sampler_builder_zero_temperature() {
        let result = SamplerBuilder::new().temperature(0.0).build();
        assert!(result.is_ok());
    }

    #[test]
    fn sampler_builder_boundary_top_p() {
        // top_p = 0.0 and 1.0 should both be valid
        assert!(SamplerBuilder::new().top_p(0.0).build().is_ok());
        assert!(SamplerBuilder::new().top_p(1.0).build().is_ok());
    }

    #[test]
    fn sampler_builder_default_trait() {
        let builder = SamplerBuilder::default();
        assert!(builder.build().is_ok());
    }

    // ── ConfigBuilder tests ──

    #[test]
    fn config_builder_defaults() {
        let config = ConfigBuilder::new()
            .build()
            .expect("default build should succeed");
        assert_eq!(config.server.host, "0.0.0.0");
        assert_eq!(config.server.port, 8080);
        assert_eq!(config.model.max_seq_len, 4096);
    }

    #[test]
    fn config_builder_chain() {
        let model_path = std::env::temp_dir().join("model.gguf");
        let tokenizer_path = std::env::temp_dir().join("tokenizer.json");
        let config = ConfigBuilder::new()
            .model_path(model_path.display().to_string())
            .tokenizer_path(tokenizer_path.display().to_string())
            .max_seq_len(8192)
            .host("127.0.0.1")
            .port(3000)
            .log_level("debug")
            .json_logs(true)
            .temperature(0.5)
            .top_k(50)
            .top_p(0.95)
            .repetition_penalty(1.2)
            .max_tokens(1024)
            .build()
            .expect("chained build should succeed");
        assert_eq!(
            config.model.model_path.as_deref(),
            Some(model_path.to_str().expect("path is valid UTF-8"))
        );
        assert_eq!(
            config.model.tokenizer_path.as_deref(),
            Some(tokenizer_path.to_str().expect("path is valid UTF-8"))
        );
        assert_eq!(config.model.max_seq_len, 8192);
        assert_eq!(config.server.host, "127.0.0.1");
        assert_eq!(config.server.port, 3000);
        assert_eq!(config.observability.log_level, "debug");
        assert!(config.observability.json_logs);
        assert!((config.sampling.temperature - 0.5).abs() < f32::EPSILON);
        assert_eq!(config.sampling.top_k, 50);
        assert_eq!(config.sampling.max_tokens, 1024);
    }

    #[test]
    fn config_builder_invalid_temperature() {
        let result = ConfigBuilder::new().temperature(-1.0).build();
        assert!(result.is_err());
    }

    #[test]
    fn config_builder_invalid_top_p() {
        let result = ConfigBuilder::new().top_p(2.0).build();
        assert!(result.is_err());
    }

    #[test]
    fn config_builder_invalid_max_seq_len() {
        let result = ConfigBuilder::new().max_seq_len(0).build();
        assert!(result.is_err());
    }

    #[test]
    fn config_builder_default_trait() {
        let builder = ConfigBuilder::default();
        assert!(builder.build().is_ok());
    }

    // ── EngineBuilder tests ──

    #[test]
    fn engine_builder_defaults() {
        let (config, _sampler) = EngineBuilder::new()
            .build()
            .expect("default build should succeed");
        assert_eq!(config.server.port, 8080);
    }

    #[test]
    fn engine_builder_with_config() {
        let config = ConfigBuilder::new()
            .port(9090)
            .build()
            .expect("config build should succeed");
        let (config, _sampler) = EngineBuilder::new()
            .config(config)
            .build()
            .expect("build with config should succeed");
        assert_eq!(config.server.port, 9090);
    }

    #[test]
    fn engine_builder_with_sampler() {
        let sampler_builder = SamplerBuilder::new().temperature(0.3).seed(99);
        let (_config, sampler) = EngineBuilder::new()
            .sampler(sampler_builder)
            .build()
            .expect("build with sampler should succeed");
        assert!((sampler.params().temperature - 0.3).abs() < f32::EPSILON);
    }

    #[test]
    fn engine_builder_with_kernel_tier() {
        let builder = EngineBuilder::new().kernel_tier("reference");
        assert_eq!(builder.configured_kernel_tier(), Some("reference"));
        builder
            .build()
            .expect("build with kernel tier should succeed");
    }

    #[test]
    fn engine_builder_invalid_sampler() {
        let sampler_builder = SamplerBuilder::new().temperature(-1.0);
        let result = EngineBuilder::new().sampler(sampler_builder).build();
        assert!(result.is_err());
    }

    #[test]
    fn engine_builder_config_file_nonexistent() {
        let path = std::env::temp_dir().join("nonexistent_oxibonsai_test_12345.toml");
        let result = EngineBuilder::new().config_file(path.to_str().expect("path is valid UTF-8"));
        assert!(result.is_err());
    }

    #[test]
    fn engine_builder_config_file_valid() {
        // A per-process file name keeps concurrent test runs that share the
        // system temp directory from clobbering each other's config file.
        // The guard deletes the file even if an assertion below panics.
        let config_file = TempFileGuard(std::env::temp_dir().join(format!(
            "oxibonsai_builder_test_{}.toml",
            std::process::id()
        )));
        std::fs::write(
            &config_file.0,
            r#"
[server]
port = 7777
"#,
        )
        .expect("write temp config");

        let path_str = config_file.0.to_string_lossy().into_owned();
        let (config, _) = EngineBuilder::new()
            .config_file(&path_str)
            .expect("should load config file")
            .build()
            .expect("build should succeed");
        assert_eq!(config.server.port, 7777);
    }

    #[test]
    fn engine_builder_default_trait() {
        let builder = EngineBuilder::default();
        assert!(builder.build().is_ok());
    }
}