use crate::config::OxiBonsaiConfig;
use crate::error::{RuntimeError, RuntimeResult};
use crate::sampling::{Sampler, SamplingParams};
pub struct SamplerBuilder {
temperature: f32,
top_k: usize,
top_p: f32,
repetition_penalty: f32,
seed: u64,
}
impl SamplerBuilder {
pub fn new() -> Self {
Self {
temperature: 0.7,
top_k: 40,
top_p: 0.9,
repetition_penalty: 1.1,
seed: 42,
}
}
pub fn temperature(mut self, t: f32) -> Self {
self.temperature = t;
self
}
pub fn top_k(mut self, k: usize) -> Self {
self.top_k = k;
self
}
pub fn top_p(mut self, p: f32) -> Self {
self.top_p = p;
self
}
pub fn repetition_penalty(mut self, rp: f32) -> Self {
self.repetition_penalty = rp;
self
}
pub fn seed(mut self, s: u64) -> Self {
self.seed = s;
self
}
pub fn build(self) -> RuntimeResult<Sampler> {
if self.temperature < 0.0 {
return Err(RuntimeError::Config(format!(
"temperature must be >= 0.0, got {}",
self.temperature
)));
}
if self.top_p < 0.0 || self.top_p > 1.0 {
return Err(RuntimeError::Config(format!(
"top_p must be in [0.0, 1.0], got {}",
self.top_p
)));
}
if self.repetition_penalty < 1.0 {
return Err(RuntimeError::Config(format!(
"repetition_penalty must be >= 1.0, got {}",
self.repetition_penalty
)));
}
let params = SamplingParams {
temperature: self.temperature,
top_k: self.top_k,
top_p: self.top_p,
repetition_penalty: self.repetition_penalty,
max_tokens: SamplingParams::default().max_tokens,
};
Ok(Sampler::new(params, self.seed))
}
}
impl Default for SamplerBuilder {
fn default() -> Self {
Self::new()
}
}
pub struct ConfigBuilder {
config: OxiBonsaiConfig,
}
impl ConfigBuilder {
pub fn new() -> Self {
Self {
config: OxiBonsaiConfig::default(),
}
}
pub fn model_path(mut self, path: impl Into<String>) -> Self {
self.config.model.model_path = Some(path.into());
self
}
pub fn tokenizer_path(mut self, path: impl Into<String>) -> Self {
self.config.model.tokenizer_path = Some(path.into());
self
}
pub fn max_seq_len(mut self, len: usize) -> Self {
self.config.model.max_seq_len = len;
self
}
pub fn host(mut self, h: impl Into<String>) -> Self {
self.config.server.host = h.into();
self
}
pub fn port(mut self, p: u16) -> Self {
self.config.server.port = p;
self
}
pub fn log_level(mut self, level: impl Into<String>) -> Self {
self.config.observability.log_level = level.into();
self
}
pub fn json_logs(mut self, enabled: bool) -> Self {
self.config.observability.json_logs = enabled;
self
}
pub fn temperature(mut self, t: f32) -> Self {
self.config.sampling.temperature = t;
self
}
pub fn top_k(mut self, k: usize) -> Self {
self.config.sampling.top_k = k;
self
}
pub fn top_p(mut self, p: f32) -> Self {
self.config.sampling.top_p = p;
self
}
pub fn repetition_penalty(mut self, rp: f32) -> Self {
self.config.sampling.repetition_penalty = rp;
self
}
pub fn max_tokens(mut self, n: usize) -> Self {
self.config.sampling.max_tokens = n;
self
}
pub fn build(self) -> RuntimeResult<OxiBonsaiConfig> {
self.config.validate()?;
Ok(self.config)
}
}
impl Default for ConfigBuilder {
fn default() -> Self {
Self::new()
}
}
pub struct EngineBuilder {
config: Option<OxiBonsaiConfig>,
sampler: Option<SamplerBuilder>,
kernel_tier: Option<String>,
}
impl EngineBuilder {
pub fn new() -> Self {
Self {
config: None,
sampler: None,
kernel_tier: None,
}
}
pub fn config(mut self, config: OxiBonsaiConfig) -> Self {
self.config = Some(config);
self
}
pub fn config_file(mut self, path: &str) -> RuntimeResult<Self> {
let config = OxiBonsaiConfig::load(std::path::Path::new(path))?;
self.config = Some(config);
Ok(self)
}
pub fn sampler(mut self, builder: SamplerBuilder) -> Self {
self.sampler = Some(builder);
self
}
pub fn kernel_tier(mut self, tier: &str) -> Self {
self.kernel_tier = Some(tier.to_string());
self
}
pub fn configured_kernel_tier(&self) -> Option<&str> {
self.kernel_tier.as_deref()
}
pub fn build(self) -> RuntimeResult<(OxiBonsaiConfig, Sampler)> {
let config = self.config.unwrap_or_default();
config.validate()?;
let sampler = match self.sampler {
Some(builder) => builder.build()?,
None => {
SamplerBuilder::new()
.temperature(config.sampling.temperature)
.top_k(config.sampling.top_k)
.top_p(config.sampling.top_p)
.repetition_penalty(config.sampling.repetition_penalty)
.build()?
}
};
Ok((config, sampler))
}
}
impl Default for EngineBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn sampler_builder_defaults() {
let sampler = SamplerBuilder::new().build();
assert!(sampler.is_ok());
let sampler = sampler.expect("default build should succeed");
let params = sampler.params();
assert!((params.temperature - 0.7).abs() < f32::EPSILON);
assert_eq!(params.top_k, 40);
assert!((params.top_p - 0.9).abs() < f32::EPSILON);
assert!((params.repetition_penalty - 1.1).abs() < f32::EPSILON);
}
#[test]
fn sampler_builder_chain() {
let sampler = SamplerBuilder::new()
.temperature(0.5)
.top_k(50)
.top_p(0.95)
.repetition_penalty(1.2)
.seed(123)
.build();
assert!(sampler.is_ok());
let sampler = sampler.expect("chained build should succeed");
let params = sampler.params();
assert!((params.temperature - 0.5).abs() < f32::EPSILON);
assert_eq!(params.top_k, 50);
assert!((params.top_p - 0.95).abs() < f32::EPSILON);
assert!((params.repetition_penalty - 1.2).abs() < f32::EPSILON);
}
#[test]
fn sampler_builder_negative_temperature() {
let result = SamplerBuilder::new().temperature(-0.1).build();
assert!(result.is_err());
let err = result.expect_err("negative temperature should fail");
assert!(err.to_string().contains("temperature"));
}
#[test]
fn sampler_builder_invalid_top_p_high() {
let result = SamplerBuilder::new().top_p(1.5).build();
assert!(result.is_err());
let err = result.expect_err("top_p > 1 should fail");
assert!(err.to_string().contains("top_p"));
}
#[test]
fn sampler_builder_invalid_top_p_low() {
let result = SamplerBuilder::new().top_p(-0.1).build();
assert!(result.is_err());
}
#[test]
fn sampler_builder_invalid_repetition_penalty() {
let result = SamplerBuilder::new().repetition_penalty(0.5).build();
assert!(result.is_err());
let err = result.expect_err("rep_pen < 1 should fail");
assert!(err.to_string().contains("repetition_penalty"));
}
#[test]
fn sampler_builder_zero_temperature() {
let result = SamplerBuilder::new().temperature(0.0).build();
assert!(result.is_ok());
}
#[test]
fn sampler_builder_boundary_top_p() {
assert!(SamplerBuilder::new().top_p(0.0).build().is_ok());
assert!(SamplerBuilder::new().top_p(1.0).build().is_ok());
}
#[test]
fn sampler_builder_default_trait() {
let builder = SamplerBuilder::default();
assert!(builder.build().is_ok());
}
#[test]
fn config_builder_defaults() {
let config = ConfigBuilder::new().build();
assert!(config.is_ok());
let config = config.expect("default build should succeed");
assert_eq!(config.server.host, "0.0.0.0");
assert_eq!(config.server.port, 8080);
assert_eq!(config.model.max_seq_len, 4096);
}
#[test]
fn config_builder_chain() {
let model_path = std::env::temp_dir().join("model.gguf");
let tokenizer_path = std::env::temp_dir().join("tokenizer.json");
let config = ConfigBuilder::new()
.model_path(model_path.display().to_string())
.tokenizer_path(tokenizer_path.display().to_string())
.max_seq_len(8192)
.host("127.0.0.1")
.port(3000)
.log_level("debug")
.json_logs(true)
.temperature(0.5)
.top_k(50)
.top_p(0.95)
.repetition_penalty(1.2)
.max_tokens(1024)
.build();
assert!(config.is_ok());
let config = config.expect("chained build should succeed");
assert_eq!(
config.model.model_path.as_deref(),
Some(model_path.to_str().expect("path is valid UTF-8"))
);
assert_eq!(
config.model.tokenizer_path.as_deref(),
Some(tokenizer_path.to_str().expect("path is valid UTF-8"))
);
assert_eq!(config.model.max_seq_len, 8192);
assert_eq!(config.server.host, "127.0.0.1");
assert_eq!(config.server.port, 3000);
assert_eq!(config.observability.log_level, "debug");
assert!(config.observability.json_logs);
assert!((config.sampling.temperature - 0.5).abs() < f32::EPSILON);
assert_eq!(config.sampling.top_k, 50);
assert_eq!(config.sampling.max_tokens, 1024);
}
#[test]
fn config_builder_invalid_temperature() {
let result = ConfigBuilder::new().temperature(-1.0).build();
assert!(result.is_err());
}
#[test]
fn config_builder_invalid_top_p() {
let result = ConfigBuilder::new().top_p(2.0).build();
assert!(result.is_err());
}
#[test]
fn config_builder_invalid_max_seq_len() {
let result = ConfigBuilder::new().max_seq_len(0).build();
assert!(result.is_err());
}
#[test]
fn config_builder_default_trait() {
let builder = ConfigBuilder::default();
assert!(builder.build().is_ok());
}
#[test]
fn engine_builder_defaults() {
let result = EngineBuilder::new().build();
assert!(result.is_ok());
let (config, _sampler) = result.expect("default build should succeed");
assert_eq!(config.server.port, 8080);
}
#[test]
fn engine_builder_with_config() {
let config = ConfigBuilder::new()
.port(9090)
.build()
.expect("config build should succeed");
let result = EngineBuilder::new().config(config).build();
assert!(result.is_ok());
let (config, _sampler) = result.expect("build with config should succeed");
assert_eq!(config.server.port, 9090);
}
#[test]
fn engine_builder_with_sampler() {
let sampler_builder = SamplerBuilder::new().temperature(0.3).seed(99);
let result = EngineBuilder::new().sampler(sampler_builder).build();
assert!(result.is_ok());
let (_config, sampler) = result.expect("build with sampler should succeed");
assert!((sampler.params().temperature - 0.3).abs() < f32::EPSILON);
}
#[test]
fn engine_builder_with_kernel_tier() {
let builder = EngineBuilder::new().kernel_tier("reference");
assert_eq!(builder.configured_kernel_tier(), Some("reference"));
let result = builder.build();
assert!(result.is_ok());
}
#[test]
fn engine_builder_invalid_sampler() {
let sampler_builder = SamplerBuilder::new().temperature(-1.0);
let result = EngineBuilder::new().sampler(sampler_builder).build();
assert!(result.is_err());
}
#[test]
fn engine_builder_config_file_nonexistent() {
let path = std::env::temp_dir().join("nonexistent_oxibonsai_test_12345.toml");
let result = EngineBuilder::new().config_file(path.to_str().expect("path is valid UTF-8"));
assert!(result.is_err());
}
#[test]
fn engine_builder_config_file_valid() {
let dir = std::env::temp_dir();
let path = dir.join("oxibonsai_builder_test.toml");
std::fs::write(
&path,
r#"
[server]
port = 7777
"#,
)
.expect("write temp config");
let path_str = path.to_string_lossy().to_string();
let result = EngineBuilder::new()
.config_file(&path_str)
.expect("should load config file")
.build();
assert!(result.is_ok());
let (config, _) = result.expect("build should succeed");
assert_eq!(config.server.port, 7777);
let _ = std::fs::remove_file(&path);
}
#[test]
fn engine_builder_default_trait() {
let builder = EngineBuilder::default();
assert!(builder.build().is_ok());
}
}