use crate::{Model, ModelParams, MullamaError};
use std::sync::Arc;
#[cfg(feature = "async")]
use crate::async_support::AsyncModel;
/// Fluent builder that configures and loads a [`Model`].
///
/// Start from [`ModelBuilder::new`] (or `ModelBuilder::default()`), chain the
/// configuration methods, then call `build` (or `build_async` when the
/// `async` feature is enabled) to load the model. A model path is mandatory;
/// every other setting has a default.
#[derive(Debug, Clone)]
#[must_use = "a ModelBuilder does nothing until `build` is called"]
pub struct ModelBuilder {
    /// Path of the model file to load; `build` fails if this is unset.
    path: Option<String>,
    /// Forwarded to `ModelParams::n_gpu_layers`.
    gpu_layers: i32,
    /// Requested context size.
    /// NOTE(review): the builder never copies this into `ModelParams`;
    /// confirm where (if anywhere) it is consumed.
    context_size: Option<u32>,
    /// Forwarded to `ModelParams::use_mmap`.
    use_mmap: bool,
    /// Forwarded to `ModelParams::use_mlock`.
    use_mlock: bool,
    /// Forwarded to `ModelParams::check_tensors`.
    check_tensors: bool,
    /// Forwarded to `ModelParams::vocab_only`.
    vocab_only: bool,
}
impl ModelBuilder {
pub fn new() -> Self {
Self {
path: None,
gpu_layers: 0,
context_size: None,
use_mmap: true,
use_mlock: false,
check_tensors: true,
vocab_only: false,
}
}
pub fn path(mut self, path: impl Into<String>) -> Self {
self.path = Some(path.into());
self
}
pub fn gpu_layers(mut self, layers: i32) -> Self {
self.gpu_layers = layers;
self
}
pub fn context_size(mut self, size: u32) -> Self {
self.context_size = Some(size);
self
}
pub fn memory_mapping(mut self, enable: bool) -> Self {
self.use_mmap = enable;
self
}
pub fn memory_locking(mut self, enable: bool) -> Self {
self.use_mlock = enable;
self
}
pub fn tensor_validation(mut self, enable: bool) -> Self {
self.check_tensors = enable;
self
}
pub fn vocabulary_only(mut self, vocab_only: bool) -> Self {
self.vocab_only = vocab_only;
self
}
pub fn preset<F>(self, preset: F) -> Self
where
F: FnOnce(Self) -> Self,
{
preset(self)
}
pub fn build(self) -> Result<Arc<Model>, MullamaError> {
let path = self
.path
.ok_or_else(|| MullamaError::ConfigError("Model path is required".to_string()))?;
let params = ModelParams {
n_gpu_layers: self.gpu_layers,
use_mmap: self.use_mmap,
use_mlock: self.use_mlock,
check_tensors: self.check_tensors,
vocab_only: self.vocab_only,
..Default::default()
};
let model = Model::load_with_params(&path, params)?;
Ok(Arc::new(model))
}
#[cfg(feature = "async")]
pub async fn build_async(self) -> Result<AsyncModel, MullamaError> {
let path = self
.path
.ok_or_else(|| MullamaError::ConfigError("Model path is required".to_string()))?;
let params = ModelParams {
n_gpu_layers: self.gpu_layers,
use_mmap: self.use_mmap,
use_mlock: self.use_mlock,
check_tensors: self.check_tensors,
vocab_only: self.vocab_only,
..Default::default()
};
AsyncModel::load_with_params(path, params).await
}
}
impl Default for ModelBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_builder() {
        let builder = ModelBuilder::new()
            .path("test.gguf")
            .gpu_layers(16)
            .context_size(2048);
        assert_eq!(builder.path, Some("test.gguf".to_string()));
        assert_eq!(builder.gpu_layers, 16);
        assert_eq!(builder.context_size, Some(2048));
    }

    /// `Default` must match `new()` and the documented defaults.
    #[test]
    fn test_model_builder_defaults() {
        let builder = ModelBuilder::default();
        assert_eq!(builder.path, None);
        assert_eq!(builder.gpu_layers, 0);
        assert_eq!(builder.context_size, None);
        assert!(builder.use_mmap);
        assert!(!builder.use_mlock);
        assert!(builder.check_tensors);
        assert!(!builder.vocab_only);
    }

    /// Every boolean setter flips its own field, and `preset` applies its closure.
    #[test]
    fn test_model_builder_flags_and_preset() {
        let builder = ModelBuilder::new()
            .memory_mapping(false)
            .memory_locking(true)
            .tensor_validation(false)
            .vocabulary_only(true)
            .preset(|b| b.gpu_layers(-1));
        assert!(!builder.use_mmap);
        assert!(builder.use_mlock);
        assert!(!builder.check_tensors);
        assert!(builder.vocab_only);
        assert_eq!(builder.gpu_layers, -1);
    }

    /// A missing path is reported as a configuration error before any loading happens.
    #[test]
    fn test_build_without_path_is_config_error() {
        let err = ModelBuilder::new()
            .build()
            .err()
            .expect("building without a path must fail");
        assert!(matches!(err, MullamaError::ConfigError(_)));
    }
}