use crate::task::TaskType;
#[derive(Debug, Clone)]
pub struct EmbeddingConfig {
pub model_id: String,
pub dimensions: usize,
pub batch_size: usize,
pub max_input_length: usize,
pub normalize: bool,
pub task_type: TaskType,
}
impl Default for EmbeddingConfig {
fn default() -> Self {
Self {
model_id: "ggml-org/embeddinggemma-300M-GGUF".to_owned(),
dimensions: 768,
batch_size: 32,
max_input_length: 2048,
normalize: true,
task_type: TaskType::default(),
}
}
}
#[derive(Debug)]
#[must_use]
pub struct EmbeddingConfigBuilder {
config: EmbeddingConfig,
}
impl EmbeddingConfig {
pub fn builder() -> EmbeddingConfigBuilder {
EmbeddingConfigBuilder {
config: Self::default(),
}
}
}
impl EmbeddingConfigBuilder {
pub fn model_id(mut self, model_id: impl Into<String>) -> Self {
self.config.model_id = model_id.into();
self
}
pub fn dimensions(mut self, dimensions: usize) -> Self {
self.config.dimensions = dimensions;
self
}
pub fn batch_size(mut self, batch_size: usize) -> Self {
self.config.batch_size = batch_size;
self
}
pub fn max_input_length(mut self, max_input_length: usize) -> Self {
self.config.max_input_length = max_input_length;
self
}
pub fn normalize(mut self, normalize: bool) -> Self {
self.config.normalize = normalize;
self
}
pub fn task_type(mut self, task_type: TaskType) -> Self {
self.config.task_type = task_type;
self
}
pub fn build(self) -> EmbeddingConfig {
self.config
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn default_config_has_sensible_values() {
let config = EmbeddingConfig::default();
assert_eq!(config.model_id, "ggml-org/embeddinggemma-300M-GGUF");
assert_eq!(config.dimensions, 768);
assert_eq!(config.batch_size, 32);
assert_eq!(config.max_input_length, 2048);
assert!(config.normalize);
assert_eq!(config.task_type, TaskType::SearchResult);
}
#[test]
fn builder_sets_dimensions() {
let config = EmbeddingConfig::builder().dimensions(768).build();
assert_eq!(config.dimensions, 768);
}
#[test]
fn builder_sets_batch_size() {
let config = EmbeddingConfig::builder().batch_size(64).build();
assert_eq!(config.batch_size, 64);
}
#[test]
fn builder_sets_max_input_length() {
let config = EmbeddingConfig::builder().max_input_length(1024).build();
assert_eq!(config.max_input_length, 1024);
}
#[test]
fn builder_sets_normalize() {
let config = EmbeddingConfig::builder().normalize(false).build();
assert!(!config.normalize);
}
#[test]
fn builder_sets_model_id() {
let config = EmbeddingConfig::builder().model_id("custom/model").build();
assert_eq!(config.model_id, "custom/model");
}
#[test]
fn builder_sets_task_type() {
let config = EmbeddingConfig::builder()
.task_type(TaskType::CodeRetrieval)
.build();
assert_eq!(config.task_type, TaskType::CodeRetrieval);
}
#[test]
fn builder_chains_all_fields() {
let config = EmbeddingConfig::builder()
.model_id("custom/model")
.dimensions(512)
.batch_size(16)
.max_input_length(256)
.normalize(false)
.task_type(TaskType::Clustering)
.build();
assert_eq!(config.model_id, "custom/model");
assert_eq!(config.dimensions, 512);
assert_eq!(config.batch_size, 16);
assert_eq!(config.max_input_length, 256);
assert!(!config.normalize);
assert_eq!(config.task_type, TaskType::Clustering);
}
}