//! Builder for embeddings requests (`openai_ergonomic/builders/embeddings.rs`).

use openai_client_base::models::{
    create_embedding_request::EncodingFormat, CreateEmbeddingRequest, CreateEmbeddingRequestInput,
};

use crate::{Builder, Error, Result};

/// Payload to embed, in one of the shapes an embeddings request accepts.
///
/// Every variant is converted into the matching form of
/// [`CreateEmbeddingRequestInput`] when the request is built.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EmbeddingInput {
    /// One string.
    Text(String),
    /// Several strings sent together in a single request.
    TextArray(Vec<String>),
    /// One sequence of integer tokens.
    Tokens(Vec<i32>),
    /// Several integer-token sequences sent together in a single request.
    TokensBatch(Vec<Vec<i32>>),
}

26impl EmbeddingInput {
27 fn into_request_input(self) -> CreateEmbeddingRequestInput {
28 match self {
29 Self::Text(value) => CreateEmbeddingRequestInput::new_text(value),
30 Self::TextArray(values) => CreateEmbeddingRequestInput::new_arrayofstrings(values),
31 Self::Tokens(values) => CreateEmbeddingRequestInput::new_arrayofintegers(values),
32 Self::TokensBatch(values) => {
33 CreateEmbeddingRequestInput::new_arrayofintegerarrays(values)
34 }
35 }
36 }
37}
38
39#[derive(Debug, Clone)]
56pub struct EmbeddingsBuilder {
57 model: String,
58 input: Option<EmbeddingInput>,
59 encoding_format: Option<EncodingFormat>,
60 dimensions: Option<i32>,
61 user: Option<String>,
62}
63
64impl EmbeddingsBuilder {
65 #[must_use]
67 pub fn new(model: impl Into<String>) -> Self {
68 Self {
69 model: model.into(),
70 input: None,
71 encoding_format: None,
72 dimensions: None,
73 user: None,
74 }
75 }
76
77 #[must_use]
79 pub fn input(mut self, input: EmbeddingInput) -> Self {
80 self.input = Some(input);
81 self
82 }
83
84 #[must_use]
86 pub fn input_text(mut self, text: impl Into<String>) -> Self {
87 self.input = Some(EmbeddingInput::Text(text.into()));
88 self
89 }
90
91 #[must_use]
93 pub fn input_texts<I, S>(mut self, texts: I) -> Self
94 where
95 I: IntoIterator<Item = S>,
96 S: Into<String>,
97 {
98 let collected = texts.into_iter().map(Into::into).collect();
99 self.input = Some(EmbeddingInput::TextArray(collected));
100 self
101 }
102
103 #[must_use]
105 pub fn input_tokens<I>(mut self, tokens: I) -> Self
106 where
107 I: IntoIterator<Item = i32>,
108 {
109 self.input = Some(EmbeddingInput::Tokens(tokens.into_iter().collect()));
110 self
111 }
112
113 #[must_use]
115 pub fn input_token_batches<I, J>(mut self, batches: I) -> Self
116 where
117 I: IntoIterator<Item = J>,
118 J: IntoIterator<Item = i32>,
119 {
120 let collected = batches
121 .into_iter()
122 .map(|batch| batch.into_iter().collect())
123 .collect();
124 self.input = Some(EmbeddingInput::TokensBatch(collected));
125 self
126 }
127
128 #[must_use]
130 pub fn encoding_format(mut self, format: EncodingFormat) -> Self {
131 self.encoding_format = Some(format);
132 self
133 }
134
135 #[must_use]
137 pub fn dimensions(mut self, dimensions: i32) -> Self {
138 self.dimensions = Some(dimensions);
139 self
140 }
141
142 #[must_use]
144 pub fn user(mut self, user: impl Into<String>) -> Self {
145 self.user = Some(user.into());
146 self
147 }
148
149 #[must_use]
151 pub fn model(&self) -> &str {
152 &self.model
153 }
154
155 #[must_use]
157 pub fn input_ref(&self) -> Option<&EmbeddingInput> {
158 self.input.as_ref()
159 }
160
161 #[must_use]
163 pub fn encoding_format_ref(&self) -> Option<EncodingFormat> {
164 self.encoding_format
165 }
166
167 #[must_use]
169 pub fn dimensions_ref(&self) -> Option<i32> {
170 self.dimensions
171 }
172
173 #[must_use]
175 pub fn user_ref(&self) -> Option<&str> {
176 self.user.as_deref()
177 }
178
179 fn validate(&self) -> Result<()> {
180 if let Some(dimensions) = self.dimensions {
181 if dimensions <= 0 {
182 return Err(Error::InvalidRequest(
183 "Embedding dimensions must be positive".to_string(),
184 ));
185 }
186 }
187 Ok(())
188 }
189}
190
191impl Builder<CreateEmbeddingRequest> for EmbeddingsBuilder {
192 fn build(self) -> Result<CreateEmbeddingRequest> {
193 self.validate()?;
194
195 let Self {
196 model,
197 input,
198 encoding_format,
199 dimensions,
200 user,
201 } = self;
202
203 let request_input = input
204 .ok_or_else(|| Error::InvalidRequest("Embeddings input is required".to_string()))?
205 .into_request_input();
206
207 let mut request = CreateEmbeddingRequest::new(request_input, model);
208 request.encoding_format = encoding_format;
209 request.dimensions = dimensions;
210 request.user = user;
211
212 Ok(request)
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builds_text_input_request() {
        let builder = EmbeddingsBuilder::new("text-embedding-3-small").input_text("hello world");
        let request = builder.build().expect("builder should succeed");

        assert_eq!(request.model, "text-embedding-3-small");
        assert!(matches!(
            *request.input,
            CreateEmbeddingRequestInput::String(ref value) if value == "hello world"
        ));
    }

    #[test]
    fn builds_multiple_texts_request() {
        let builder =
            EmbeddingsBuilder::new("text-embedding-3-large").input_texts(["foo", "bar", "baz"]);
        let request = builder.build().expect("builder should succeed");

        match *request.input {
            CreateEmbeddingRequestInput::ArrayOfStrings(values) => {
                assert_eq!(values, vec!["foo", "bar", "baz"]);
            }
            other => panic!("unexpected input variant: {other:?}"),
        }
    }

    #[test]
    fn builds_token_batch_request() {
        let builder = EmbeddingsBuilder::new("text-embedding-3-small")
            .input_token_batches([vec![1, 2, 3], vec![4, 5, 6]]);
        let request = builder.build().expect("builder should succeed");

        match *request.input {
            CreateEmbeddingRequestInput::ArrayOfIntegerArrays(values) => {
                assert_eq!(values, vec![vec![1, 2, 3], vec![4, 5, 6]]);
            }
            other => panic!("unexpected input variant: {other:?}"),
        }
    }

    /// `input_tokens` and every accessor reflect what was configured.
    #[test]
    fn records_token_input_and_exposes_accessors() {
        let builder = EmbeddingsBuilder::new("text-embedding-3-small")
            .input_tokens([1, 2, 3])
            .encoding_format(EncodingFormat::Base64)
            .dimensions(256)
            .user("user-456");

        assert_eq!(builder.model(), "text-embedding-3-small");
        assert_eq!(builder.input_ref(), Some(&EmbeddingInput::Tokens(vec![1, 2, 3])));
        assert_eq!(builder.encoding_format_ref(), Some(EncodingFormat::Base64));
        assert_eq!(builder.dimensions_ref(), Some(256));
        assert_eq!(builder.user_ref(), Some("user-456"));

        let request = builder.build().expect("builder should succeed");
        assert_eq!(request.model, "text-embedding-3-small");
    }

    /// Setting an input again replaces the previous one rather than merging.
    #[test]
    fn later_input_replaces_earlier_input() {
        let builder = EmbeddingsBuilder::new("text-embedding-3-small")
            .input_text("first")
            .input(EmbeddingInput::TextArray(vec!["second".to_string()]));

        assert_eq!(
            builder.input_ref(),
            Some(&EmbeddingInput::TextArray(vec!["second".to_string()]))
        );
    }

    #[test]
    fn validates_dimensions_positive() {
        let builder = EmbeddingsBuilder::new("text-embedding-3-small")
            .input_text("test")
            .dimensions(0);
        let error = builder.build().expect_err("dimensions should be validated");
        assert!(matches!(error, Error::InvalidRequest(message) if message.contains("positive")));
    }

    #[test]
    fn rejects_negative_dimensions() {
        let builder = EmbeddingsBuilder::new("text-embedding-3-small")
            .input_text("test")
            .dimensions(-8);
        let error = builder
            .build()
            .expect_err("negative dimensions should be rejected");
        assert!(matches!(error, Error::InvalidRequest(message) if message.contains("positive")));
    }

    #[test]
    fn requires_input() {
        let builder = EmbeddingsBuilder::new("text-embedding-3-small");
        let error = builder.build().expect_err("input is required");
        assert!(matches!(error, Error::InvalidRequest(message) if message.contains("input")));
    }

    #[test]
    fn propagates_encoding_and_user() {
        let builder = EmbeddingsBuilder::new("text-embedding-3-small")
            .input_text("hello")
            .encoding_format(EncodingFormat::Base64)
            .dimensions(512)
            .user("user-123");
        let request = builder.build().expect("builder should succeed");

        assert_eq!(request.encoding_format, Some(EncodingFormat::Base64));
        assert_eq!(request.dimensions, Some(512));
        assert_eq!(request.user.as_deref(), Some("user-123"));
    }

    /// Optional settings that were never configured stay unset on the request.
    #[test]
    fn leaves_unset_options_empty() {
        let request = EmbeddingsBuilder::new("text-embedding-3-small")
            .input_text("hello")
            .build()
            .expect("builder should succeed");

        assert_eq!(request.encoding_format, None);
        assert_eq!(request.dimensions, None);
        assert!(request.user.is_none());
    }
}