oxibonsai_runtime/
builders.rs1use crate::config::OxiBonsaiConfig;
11use crate::error::{RuntimeError, RuntimeResult};
12use crate::sampling::{Sampler, SamplingParams};
13
/// Fluent builder for a validated [`Sampler`].
///
/// Start from [`SamplerBuilder::new`] (or `Default`), override individual
/// settings with the chained setters, then call [`SamplerBuilder::build`],
/// which rejects out-of-range values.
#[derive(Debug, Clone, PartialEq)]
#[must_use = "a SamplerBuilder does nothing until `build()` is called"]
pub struct SamplerBuilder {
    /// Sampling temperature; `build` requires it to be `>= 0.0`.
    temperature: f32,
    /// Top-k cutoff, passed through unchanged to `SamplingParams::top_k`.
    top_k: usize,
    /// Top-p threshold; `build` requires it to lie in `[0.0, 1.0]`.
    top_p: f32,
    /// Repetition penalty; `build` requires it to be `>= 1.0`.
    repetition_penalty: f32,
    /// Seed handed to `Sampler::new`.
    seed: u64,
}
41
42impl SamplerBuilder {
43 pub fn new() -> Self {
45 Self {
46 temperature: 0.7,
47 top_k: 40,
48 top_p: 0.9,
49 repetition_penalty: 1.1,
50 seed: 42,
51 }
52 }
53
54 pub fn temperature(mut self, t: f32) -> Self {
56 self.temperature = t;
57 self
58 }
59
60 pub fn top_k(mut self, k: usize) -> Self {
62 self.top_k = k;
63 self
64 }
65
66 pub fn top_p(mut self, p: f32) -> Self {
68 self.top_p = p;
69 self
70 }
71
72 pub fn repetition_penalty(mut self, rp: f32) -> Self {
74 self.repetition_penalty = rp;
75 self
76 }
77
78 pub fn seed(mut self, s: u64) -> Self {
80 self.seed = s;
81 self
82 }
83
84 pub fn build(self) -> RuntimeResult<Sampler> {
86 if self.temperature < 0.0 {
87 return Err(RuntimeError::Config(format!(
88 "temperature must be >= 0.0, got {}",
89 self.temperature
90 )));
91 }
92 if self.top_p < 0.0 || self.top_p > 1.0 {
93 return Err(RuntimeError::Config(format!(
94 "top_p must be in [0.0, 1.0], got {}",
95 self.top_p
96 )));
97 }
98 if self.repetition_penalty < 1.0 {
99 return Err(RuntimeError::Config(format!(
100 "repetition_penalty must be >= 1.0, got {}",
101 self.repetition_penalty
102 )));
103 }
104
105 let params = SamplingParams {
106 temperature: self.temperature,
107 top_k: self.top_k,
108 top_p: self.top_p,
109 repetition_penalty: self.repetition_penalty,
110 max_tokens: SamplingParams::default().max_tokens,
111 };
112
113 Ok(Sampler::new(params, self.seed))
114 }
115}
116
117impl Default for SamplerBuilder {
118 fn default() -> Self {
119 Self::new()
120 }
121}
122
123pub struct ConfigBuilder {
125 config: OxiBonsaiConfig,
126}
127
128impl ConfigBuilder {
129 pub fn new() -> Self {
131 Self {
132 config: OxiBonsaiConfig::default(),
133 }
134 }
135
136 pub fn model_path(mut self, path: impl Into<String>) -> Self {
138 self.config.model.model_path = Some(path.into());
139 self
140 }
141
142 pub fn tokenizer_path(mut self, path: impl Into<String>) -> Self {
144 self.config.model.tokenizer_path = Some(path.into());
145 self
146 }
147
148 pub fn max_seq_len(mut self, len: usize) -> Self {
150 self.config.model.max_seq_len = len;
151 self
152 }
153
154 pub fn host(mut self, h: impl Into<String>) -> Self {
156 self.config.server.host = h.into();
157 self
158 }
159
160 pub fn port(mut self, p: u16) -> Self {
162 self.config.server.port = p;
163 self
164 }
165
166 pub fn log_level(mut self, level: impl Into<String>) -> Self {
168 self.config.observability.log_level = level.into();
169 self
170 }
171
172 pub fn json_logs(mut self, enabled: bool) -> Self {
174 self.config.observability.json_logs = enabled;
175 self
176 }
177
178 pub fn temperature(mut self, t: f32) -> Self {
180 self.config.sampling.temperature = t;
181 self
182 }
183
184 pub fn top_k(mut self, k: usize) -> Self {
186 self.config.sampling.top_k = k;
187 self
188 }
189
190 pub fn top_p(mut self, p: f32) -> Self {
192 self.config.sampling.top_p = p;
193 self
194 }
195
196 pub fn repetition_penalty(mut self, rp: f32) -> Self {
198 self.config.sampling.repetition_penalty = rp;
199 self
200 }
201
202 pub fn max_tokens(mut self, n: usize) -> Self {
204 self.config.sampling.max_tokens = n;
205 self
206 }
207
208 pub fn build(self) -> RuntimeResult<OxiBonsaiConfig> {
210 self.config.validate()?;
211 Ok(self.config)
212 }
213}
214
215impl Default for ConfigBuilder {
216 fn default() -> Self {
217 Self::new()
218 }
219}
220
221pub struct EngineBuilder {
227 config: Option<OxiBonsaiConfig>,
228 sampler: Option<SamplerBuilder>,
229 kernel_tier: Option<String>,
230}
231
232impl EngineBuilder {
233 pub fn new() -> Self {
235 Self {
236 config: None,
237 sampler: None,
238 kernel_tier: None,
239 }
240 }
241
242 pub fn config(mut self, config: OxiBonsaiConfig) -> Self {
244 self.config = Some(config);
245 self
246 }
247
248 pub fn config_file(mut self, path: &str) -> RuntimeResult<Self> {
250 let config = OxiBonsaiConfig::load(std::path::Path::new(path))?;
251 self.config = Some(config);
252 Ok(self)
253 }
254
255 pub fn sampler(mut self, builder: SamplerBuilder) -> Self {
257 self.sampler = Some(builder);
258 self
259 }
260
261 pub fn kernel_tier(mut self, tier: &str) -> Self {
263 self.kernel_tier = Some(tier.to_string());
264 self
265 }
266
267 pub fn configured_kernel_tier(&self) -> Option<&str> {
269 self.kernel_tier.as_deref()
270 }
271
272 pub fn build(self) -> RuntimeResult<(OxiBonsaiConfig, Sampler)> {
277 let config = self.config.unwrap_or_default();
278 config.validate()?;
279
280 let sampler = match self.sampler {
281 Some(builder) => builder.build()?,
282 None => {
283 SamplerBuilder::new()
285 .temperature(config.sampling.temperature)
286 .top_k(config.sampling.top_k)
287 .top_p(config.sampling.top_p)
288 .repetition_penalty(config.sampling.repetition_penalty)
289 .build()?
290 }
291 };
292
293 Ok((config, sampler))
294 }
295}
296
297impl Default for EngineBuilder {
298 fn default() -> Self {
299 Self::new()
300 }
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    // ---- SamplerBuilder ----

    #[test]
    fn sampler_builder_defaults() {
        let sampler = SamplerBuilder::new()
            .build()
            .expect("default build should succeed");
        let params = sampler.params();
        assert!((params.temperature - 0.7).abs() < f32::EPSILON);
        assert_eq!(params.top_k, 40);
        assert!((params.top_p - 0.9).abs() < f32::EPSILON);
        assert!((params.repetition_penalty - 1.1).abs() < f32::EPSILON);
    }

    #[test]
    fn sampler_builder_chain() {
        let sampler = SamplerBuilder::new()
            .temperature(0.5)
            .top_k(50)
            .top_p(0.95)
            .repetition_penalty(1.2)
            .seed(123)
            .build()
            .expect("chained build should succeed");
        let params = sampler.params();
        assert!((params.temperature - 0.5).abs() < f32::EPSILON);
        assert_eq!(params.top_k, 50);
        assert!((params.top_p - 0.95).abs() < f32::EPSILON);
        assert!((params.repetition_penalty - 1.2).abs() < f32::EPSILON);
    }

    #[test]
    fn sampler_builder_negative_temperature() {
        let err = SamplerBuilder::new()
            .temperature(-0.1)
            .build()
            .expect_err("negative temperature should fail");
        assert!(err.to_string().contains("temperature"));
    }

    #[test]
    fn sampler_builder_invalid_top_p_high() {
        let err = SamplerBuilder::new()
            .top_p(1.5)
            .build()
            .expect_err("top_p > 1 should fail");
        assert!(err.to_string().contains("top_p"));
    }

    #[test]
    fn sampler_builder_invalid_top_p_low() {
        let err = SamplerBuilder::new()
            .top_p(-0.1)
            .build()
            .expect_err("top_p < 0 should fail");
        assert!(err.to_string().contains("top_p"));
    }

    #[test]
    fn sampler_builder_invalid_repetition_penalty() {
        let err = SamplerBuilder::new()
            .repetition_penalty(0.5)
            .build()
            .expect_err("rep_pen < 1 should fail");
        assert!(err.to_string().contains("repetition_penalty"));
    }

    #[test]
    fn sampler_builder_zero_temperature() {
        let _sampler = SamplerBuilder::new()
            .temperature(0.0)
            .build()
            .expect("temperature 0.0 should be accepted");
    }

    #[test]
    fn sampler_builder_boundary_top_p() {
        // Both ends of [0.0, 1.0] are inclusive.
        let _low = SamplerBuilder::new()
            .top_p(0.0)
            .build()
            .expect("top_p = 0.0 should be accepted");
        let _high = SamplerBuilder::new()
            .top_p(1.0)
            .build()
            .expect("top_p = 1.0 should be accepted");
    }

    #[test]
    fn sampler_builder_default_trait() {
        let _sampler = SamplerBuilder::default()
            .build()
            .expect("default-trait build should succeed");
    }

    // ---- ConfigBuilder ----

    #[test]
    fn config_builder_defaults() {
        let config = ConfigBuilder::new()
            .build()
            .expect("default build should succeed");
        assert_eq!(config.server.host, "0.0.0.0");
        assert_eq!(config.server.port, 8080);
        assert_eq!(config.model.max_seq_len, 4096);
    }

    #[test]
    fn config_builder_chain() {
        let model_path = std::env::temp_dir().join("model.gguf");
        let tokenizer_path = std::env::temp_dir().join("tokenizer.json");
        let config = ConfigBuilder::new()
            .model_path(model_path.display().to_string())
            .tokenizer_path(tokenizer_path.display().to_string())
            .max_seq_len(8192)
            .host("127.0.0.1")
            .port(3000)
            .log_level("debug")
            .json_logs(true)
            .temperature(0.5)
            .top_k(50)
            .top_p(0.95)
            .repetition_penalty(1.2)
            .max_tokens(1024)
            .build()
            .expect("chained build should succeed");
        assert_eq!(
            config.model.model_path.as_deref(),
            Some(model_path.to_str().expect("path is valid UTF-8"))
        );
        assert_eq!(
            config.model.tokenizer_path.as_deref(),
            Some(tokenizer_path.to_str().expect("path is valid UTF-8"))
        );
        assert_eq!(config.model.max_seq_len, 8192);
        assert_eq!(config.server.host, "127.0.0.1");
        assert_eq!(config.server.port, 3000);
        assert_eq!(config.observability.log_level, "debug");
        assert!(config.observability.json_logs);
        assert!((config.sampling.temperature - 0.5).abs() < f32::EPSILON);
        assert_eq!(config.sampling.top_k, 50);
        assert!((config.sampling.top_p - 0.95).abs() < f32::EPSILON);
        assert!((config.sampling.repetition_penalty - 1.2).abs() < f32::EPSILON);
        assert_eq!(config.sampling.max_tokens, 1024);
    }

    #[test]
    fn config_builder_invalid_temperature() {
        let result = ConfigBuilder::new().temperature(-1.0).build();
        assert!(result.is_err());
    }

    #[test]
    fn config_builder_invalid_top_p() {
        let result = ConfigBuilder::new().top_p(2.0).build();
        assert!(result.is_err());
    }

    #[test]
    fn config_builder_invalid_max_seq_len() {
        let result = ConfigBuilder::new().max_seq_len(0).build();
        assert!(result.is_err());
    }

    #[test]
    fn config_builder_default_trait() {
        let _config = ConfigBuilder::default()
            .build()
            .expect("default-trait build should succeed");
    }

    // ---- EngineBuilder ----

    #[test]
    fn engine_builder_defaults() {
        let (config, _sampler) = EngineBuilder::new()
            .build()
            .expect("default build should succeed");
        assert_eq!(config.server.port, 8080);
    }

    #[test]
    fn engine_builder_with_config() {
        let config = ConfigBuilder::new()
            .port(9090)
            .build()
            .expect("config build should succeed");
        let (config, _sampler) = EngineBuilder::new()
            .config(config)
            .build()
            .expect("build with config should succeed");
        assert_eq!(config.server.port, 9090);
    }

    #[test]
    fn engine_builder_with_sampler() {
        let sampler_builder = SamplerBuilder::new().temperature(0.3).seed(99);
        let (_config, sampler) = EngineBuilder::new()
            .sampler(sampler_builder)
            .build()
            .expect("build with sampler should succeed");
        assert!((sampler.params().temperature - 0.3).abs() < f32::EPSILON);
    }

    #[test]
    fn engine_builder_with_kernel_tier() {
        let builder = EngineBuilder::new().kernel_tier("reference");
        assert_eq!(builder.configured_kernel_tier(), Some("reference"));
        let (_config, _sampler) = builder
            .build()
            .expect("build with kernel tier should succeed");
    }

    #[test]
    fn engine_builder_invalid_sampler() {
        let sampler_builder = SamplerBuilder::new().temperature(-1.0);
        let result = EngineBuilder::new().sampler(sampler_builder).build();
        assert!(result.is_err());
    }

    #[test]
    fn engine_builder_config_file_nonexistent() {
        let path = std::env::temp_dir().join("nonexistent_oxibonsai_test_12345.toml");
        let result = EngineBuilder::new().config_file(path.to_str().expect("path is valid UTF-8"));
        assert!(result.is_err());
    }

    #[test]
    fn engine_builder_config_file_valid() {
        // Deletes the fixture when dropped, so it is cleaned up even if an
        // assertion below panics.
        struct TempFile(std::path::PathBuf);

        impl Drop for TempFile {
            fn drop(&mut self) {
                // Best-effort cleanup: a missing file must not fail the test.
                let _ = std::fs::remove_file(&self.0);
            }
        }

        // The process id keeps concurrent test runs from sharing one fixture file.
        let fixture = TempFile(
            std::env::temp_dir()
                .join(format!("oxibonsai_builder_test_{}.toml", std::process::id())),
        );
        std::fs::write(
            &fixture.0,
            r#"
[server]
port = 7777
"#,
        )
        .expect("write temp config");

        let path_str = fixture.0.to_string_lossy().to_string();
        let (config, _) = EngineBuilder::new()
            .config_file(&path_str)
            .expect("should load config file")
            .build()
            .expect("build should succeed");
        assert_eq!(config.server.port, 7777);
    }

    #[test]
    fn engine_builder_default_trait() {
        let (_config, _sampler) = EngineBuilder::default()
            .build()
            .expect("default-trait build should succeed");
    }
}