1use crate::task::TaskType;
4
5#[derive(Debug, Clone)]
7pub struct EmbeddingConfig {
8 pub model_id: String,
10 pub dimensions: usize,
12 pub batch_size: usize,
14 pub max_input_length: usize,
16 pub normalize: bool,
18 pub task_type: TaskType,
20}
21
22impl Default for EmbeddingConfig {
23 fn default() -> Self {
24 Self {
25 model_id: "ggml-org/embeddinggemma-300M-GGUF".to_owned(),
26 dimensions: 768,
27 batch_size: 32,
28 max_input_length: 2048,
29 normalize: true,
30 task_type: TaskType::default(),
31 }
32 }
33}
34
35#[derive(Debug)]
37#[must_use]
38pub struct EmbeddingConfigBuilder {
39 config: EmbeddingConfig,
40}
41
42impl EmbeddingConfig {
43 pub fn builder() -> EmbeddingConfigBuilder {
45 EmbeddingConfigBuilder {
46 config: Self::default(),
47 }
48 }
49}
50
51impl EmbeddingConfigBuilder {
52 pub fn model_id(mut self, model_id: impl Into<String>) -> Self {
53 self.config.model_id = model_id.into();
54 self
55 }
56
57 pub fn dimensions(mut self, dimensions: usize) -> Self {
58 self.config.dimensions = dimensions;
59 self
60 }
61
62 pub fn batch_size(mut self, batch_size: usize) -> Self {
63 self.config.batch_size = batch_size;
64 self
65 }
66
67 pub fn max_input_length(mut self, max_input_length: usize) -> Self {
68 self.config.max_input_length = max_input_length;
69 self
70 }
71
72 pub fn normalize(mut self, normalize: bool) -> Self {
73 self.config.normalize = normalize;
74 self
75 }
76
77 pub fn task_type(mut self, task_type: TaskType) -> Self {
78 self.config.task_type = task_type;
79 self
80 }
81
82 pub fn build(self) -> EmbeddingConfig {
83 self.config
84 }
85}
86
#[cfg(test)]
mod tests {
    use super::*;

    // Pins the documented defaults so an accidental change is caught.
    #[test]
    fn default_config_has_sensible_values() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.model_id, "ggml-org/embeddinggemma-300M-GGUF");
        assert_eq!(config.dimensions, 768);
        assert_eq!(config.batch_size, 32);
        assert_eq!(config.max_input_length, 2048);
        assert!(config.normalize);
        assert_eq!(config.task_type, TaskType::SearchResult);
    }

    // A builder with no overrides must yield exactly the default values.
    #[test]
    fn builder_without_overrides_matches_default() {
        let built = EmbeddingConfig::builder().build();
        let expected = EmbeddingConfig::default();
        assert_eq!(built.model_id, expected.model_id);
        assert_eq!(built.dimensions, expected.dimensions);
        assert_eq!(built.batch_size, expected.batch_size);
        assert_eq!(built.max_input_length, expected.max_input_length);
        assert_eq!(built.normalize, expected.normalize);
        assert_eq!(built.task_type, expected.task_type);
    }

    #[test]
    fn builder_sets_dimensions() {
        // Deliberately not 768: that is the default, so using it would let
        // this test pass even if the setter did nothing.
        let config = EmbeddingConfig::builder().dimensions(384).build();
        assert_eq!(config.dimensions, 384);
        assert_ne!(config.dimensions, EmbeddingConfig::default().dimensions);
    }

    #[test]
    fn builder_sets_batch_size() {
        let config = EmbeddingConfig::builder().batch_size(64).build();
        assert_eq!(config.batch_size, 64);
    }

    #[test]
    fn builder_sets_max_input_length() {
        let config = EmbeddingConfig::builder().max_input_length(1024).build();
        assert_eq!(config.max_input_length, 1024);
    }

    #[test]
    fn builder_sets_normalize() {
        let config = EmbeddingConfig::builder().normalize(false).build();
        assert!(!config.normalize);
    }

    #[test]
    fn builder_sets_model_id() {
        let config = EmbeddingConfig::builder().model_id("custom/model").build();
        assert_eq!(config.model_id, "custom/model");
    }

    #[test]
    fn builder_sets_task_type() {
        let config = EmbeddingConfig::builder()
            .task_type(TaskType::CodeRetrieval)
            .build();
        assert_eq!(config.task_type, TaskType::CodeRetrieval);
    }

    // Every setter applied in one chain; each value differs from its default.
    #[test]
    fn builder_chains_all_fields() {
        let config = EmbeddingConfig::builder()
            .model_id("custom/model")
            .dimensions(512)
            .batch_size(16)
            .max_input_length(256)
            .normalize(false)
            .task_type(TaskType::Clustering)
            .build();
        assert_eq!(config.model_id, "custom/model");
        assert_eq!(config.dimensions, 512);
        assert_eq!(config.batch_size, 16);
        assert_eq!(config.max_input_length, 256);
        assert!(!config.normalize);
        assert_eq!(config.task_type, TaskType::Clustering);
    }
}