// erio_embedding/config.rs

//! Configuration for embedding engines.

use crate::task::TaskType;

5/// Configuration for an embedding engine.
6#[derive(Debug, Clone)]
7pub struct EmbeddingConfig {
8    /// `HuggingFace` model identifier (e.g. "ggml-org/embeddinggemma-300M-GGUF").
9    pub model_id: String,
10    /// Output vector dimensions.
11    pub dimensions: usize,
12    /// Maximum number of texts to embed in a single batch.
13    pub batch_size: usize,
14    /// Maximum input text length in tokens.
15    pub max_input_length: usize,
16    /// Whether to L2-normalize output vectors to unit length.
17    pub normalize: bool,
18    /// Task type for prompt formatting.
19    pub task_type: TaskType,
20}
21
22impl Default for EmbeddingConfig {
23    fn default() -> Self {
24        Self {
25            model_id: "ggml-org/embeddinggemma-300M-GGUF".to_owned(),
26            dimensions: 768,
27            batch_size: 32,
28            max_input_length: 2048,
29            normalize: true,
30            task_type: TaskType::default(),
31        }
32    }
33}
34
35/// Builder for `EmbeddingConfig`.
36#[derive(Debug)]
37#[must_use]
38pub struct EmbeddingConfigBuilder {
39    config: EmbeddingConfig,
40}
41
42impl EmbeddingConfig {
43    /// Creates a builder with default values.
44    pub fn builder() -> EmbeddingConfigBuilder {
45        EmbeddingConfigBuilder {
46            config: Self::default(),
47        }
48    }
49}
50
51impl EmbeddingConfigBuilder {
52    pub fn model_id(mut self, model_id: impl Into<String>) -> Self {
53        self.config.model_id = model_id.into();
54        self
55    }
56
57    pub fn dimensions(mut self, dimensions: usize) -> Self {
58        self.config.dimensions = dimensions;
59        self
60    }
61
62    pub fn batch_size(mut self, batch_size: usize) -> Self {
63        self.config.batch_size = batch_size;
64        self
65    }
66
67    pub fn max_input_length(mut self, max_input_length: usize) -> Self {
68        self.config.max_input_length = max_input_length;
69        self
70    }
71
72    pub fn normalize(mut self, normalize: bool) -> Self {
73        self.config.normalize = normalize;
74        self
75    }
76
77    pub fn task_type(mut self, task_type: TaskType) -> Self {
78        self.config.task_type = task_type;
79        self
80    }
81
82    pub fn build(self) -> EmbeddingConfig {
83        self.config
84    }
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins every default value so an accidental change to `Default` is caught.
    #[test]
    fn default_config_has_sensible_values() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.model_id, "ggml-org/embeddinggemma-300M-GGUF");
        assert_eq!(config.dimensions, 768);
        assert_eq!(config.batch_size, 32);
        assert_eq!(config.max_input_length, 2048);
        assert!(config.normalize);
        assert_eq!(config.task_type, TaskType::SearchResult);
    }

    /// A builder with no overrides must yield the default configuration.
    #[test]
    fn builder_without_overrides_matches_default() {
        let built = EmbeddingConfig::builder().build();
        let expected = EmbeddingConfig::default();
        assert_eq!(built.model_id, expected.model_id);
        assert_eq!(built.dimensions, expected.dimensions);
        assert_eq!(built.batch_size, expected.batch_size);
        assert_eq!(built.max_input_length, expected.max_input_length);
        assert_eq!(built.normalize, expected.normalize);
        assert_eq!(built.task_type, expected.task_type);
    }

    #[test]
    fn builder_sets_dimensions() {
        // 256 differs from the 768 default, so a setter that silently does
        // nothing cannot pass this test.
        let config = EmbeddingConfig::builder().dimensions(256).build();
        assert_eq!(config.dimensions, 256);
    }

    #[test]
    fn builder_sets_batch_size() {
        let config = EmbeddingConfig::builder().batch_size(64).build();
        assert_eq!(config.batch_size, 64);
    }

    #[test]
    fn builder_sets_max_input_length() {
        let config = EmbeddingConfig::builder().max_input_length(1024).build();
        assert_eq!(config.max_input_length, 1024);
    }

    #[test]
    fn builder_sets_normalize() {
        let config = EmbeddingConfig::builder().normalize(false).build();
        assert!(!config.normalize);
    }

    #[test]
    fn builder_sets_model_id() {
        let config = EmbeddingConfig::builder().model_id("custom/model").build();
        assert_eq!(config.model_id, "custom/model");
    }

    #[test]
    fn builder_sets_task_type() {
        let config = EmbeddingConfig::builder()
            .task_type(TaskType::CodeRetrieval)
            .build();
        assert_eq!(config.task_type, TaskType::CodeRetrieval);
    }

    /// Every value here differs from its default, so each setter is exercised.
    #[test]
    fn builder_chains_all_fields() {
        let config = EmbeddingConfig::builder()
            .model_id("custom/model")
            .dimensions(512)
            .batch_size(16)
            .max_input_length(256)
            .normalize(false)
            .task_type(TaskType::Clustering)
            .build();
        assert_eq!(config.model_id, "custom/model");
        assert_eq!(config.dimensions, 512);
        assert_eq!(config.batch_size, 16);
        assert_eq!(config.max_input_length, 256);
        assert!(!config.normalize);
        assert_eq!(config.task_type, TaskType::Clustering);
    }
}