openai_ergonomic/builders/
embeddings.rs

1//! Embeddings API builders.
2//!
3//! Provides high-level builders for creating `OpenAI` embeddings requests
4//! covering text inputs, tokenized inputs, and configuration options such as
5//! encoding format and dimensionality.
6
7use openai_client_base::models::{
8    create_embedding_request::EncodingFormat, CreateEmbeddingRequest, CreateEmbeddingRequestInput,
9};
10
11use crate::{Builder, Error, Result};
12
/// The different shapes of input accepted by the embeddings endpoint.
///
/// The text variants carry raw strings. The token variants carry input that
/// has already been tokenized into integer token IDs.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum EmbeddingInput {
    /// One piece of text, embedded as a single input.
    Text(String),
    /// Several pieces of text sent together in one request.
    TextArray(Vec<String>),
    /// One pre-tokenized input, given as a sequence of integer tokens.
    Tokens(Vec<i32>),
    /// Several pre-tokenized inputs, each a sequence of integer tokens.
    TokensBatch(Vec<Vec<i32>>),
}
25
26impl EmbeddingInput {
27    fn into_request_input(self) -> CreateEmbeddingRequestInput {
28        match self {
29            Self::Text(value) => CreateEmbeddingRequestInput::new_text(value),
30            Self::TextArray(values) => CreateEmbeddingRequestInput::new_arrayofstrings(values),
31            Self::Tokens(values) => CreateEmbeddingRequestInput::new_arrayofintegers(values),
32            Self::TokensBatch(values) => {
33                CreateEmbeddingRequestInput::new_arrayofintegerarrays(values)
34            }
35        }
36    }
37}
38
39/// Builder for creating embedding requests.
40///
41/// # Examples
42///
43/// ```rust
44/// use openai_ergonomic::{Builder, EmbeddingsBuilder};
45///
46/// let request = EmbeddingsBuilder::new("text-embedding-3-small")
47///     .input_text("hello world")
48///     .dimensions(256)
49///     .build()
50///     .unwrap();
51///
52/// assert_eq!(request.model, "text-embedding-3-small");
53/// assert_eq!(request.dimensions, Some(256));
54/// ```
55#[derive(Debug, Clone)]
56pub struct EmbeddingsBuilder {
57    model: String,
58    input: Option<EmbeddingInput>,
59    encoding_format: Option<EncodingFormat>,
60    dimensions: Option<i32>,
61    user: Option<String>,
62}
63
64impl EmbeddingsBuilder {
65    /// Create a new embeddings builder for the specified model.
66    #[must_use]
67    pub fn new(model: impl Into<String>) -> Self {
68        Self {
69            model: model.into(),
70            input: None,
71            encoding_format: None,
72            dimensions: None,
73            user: None,
74        }
75    }
76
77    /// Provide the request input explicitly.
78    #[must_use]
79    pub fn input(mut self, input: EmbeddingInput) -> Self {
80        self.input = Some(input);
81        self
82    }
83
84    /// Embed a single string input.
85    #[must_use]
86    pub fn input_text(mut self, text: impl Into<String>) -> Self {
87        self.input = Some(EmbeddingInput::Text(text.into()));
88        self
89    }
90
91    /// Embed multiple string inputs in one request.
92    #[must_use]
93    pub fn input_texts<I, S>(mut self, texts: I) -> Self
94    where
95        I: IntoIterator<Item = S>,
96        S: Into<String>,
97    {
98        let collected = texts.into_iter().map(Into::into).collect();
99        self.input = Some(EmbeddingInput::TextArray(collected));
100        self
101    }
102
103    /// Embed a single tokenized input.
104    #[must_use]
105    pub fn input_tokens<I>(mut self, tokens: I) -> Self
106    where
107        I: IntoIterator<Item = i32>,
108    {
109        self.input = Some(EmbeddingInput::Tokens(tokens.into_iter().collect()));
110        self
111    }
112
113    /// Embed multiple tokenized inputs.
114    #[must_use]
115    pub fn input_token_batches<I, J>(mut self, batches: I) -> Self
116    where
117        I: IntoIterator<Item = J>,
118        J: IntoIterator<Item = i32>,
119    {
120        let collected = batches
121            .into_iter()
122            .map(|batch| batch.into_iter().collect())
123            .collect();
124        self.input = Some(EmbeddingInput::TokensBatch(collected));
125        self
126    }
127
128    /// Set the encoding format for the embeddings response.
129    #[must_use]
130    pub fn encoding_format(mut self, format: EncodingFormat) -> Self {
131        self.encoding_format = Some(format);
132        self
133    }
134
135    /// Set the output dimensions for supported models.
136    #[must_use]
137    pub fn dimensions(mut self, dimensions: i32) -> Self {
138        self.dimensions = Some(dimensions);
139        self
140    }
141
142    /// Associate a user identifier with the request.
143    #[must_use]
144    pub fn user(mut self, user: impl Into<String>) -> Self {
145        self.user = Some(user.into());
146        self
147    }
148
149    /// Access the configured model name.
150    #[must_use]
151    pub fn model(&self) -> &str {
152        &self.model
153    }
154
155    /// Access the configured input, if set.
156    #[must_use]
157    pub fn input_ref(&self) -> Option<&EmbeddingInput> {
158        self.input.as_ref()
159    }
160
161    /// Access the configured encoding format, if set.
162    #[must_use]
163    pub fn encoding_format_ref(&self) -> Option<EncodingFormat> {
164        self.encoding_format
165    }
166
167    /// Access the configured dimensions, if set.
168    #[must_use]
169    pub fn dimensions_ref(&self) -> Option<i32> {
170        self.dimensions
171    }
172
173    /// Access the configured user identifier, if set.
174    #[must_use]
175    pub fn user_ref(&self) -> Option<&str> {
176        self.user.as_deref()
177    }
178
179    fn validate(&self) -> Result<()> {
180        if let Some(dimensions) = self.dimensions {
181            if dimensions <= 0 {
182                return Err(Error::InvalidRequest(
183                    "Embedding dimensions must be positive".to_string(),
184                ));
185            }
186        }
187        Ok(())
188    }
189}
190
191impl Builder<CreateEmbeddingRequest> for EmbeddingsBuilder {
192    fn build(self) -> Result<CreateEmbeddingRequest> {
193        self.validate()?;
194
195        let Self {
196            model,
197            input,
198            encoding_format,
199            dimensions,
200            user,
201        } = self;
202
203        let request_input = input
204            .ok_or_else(|| Error::InvalidRequest("Embeddings input is required".to_string()))?
205            .into_request_input();
206
207        let mut request = CreateEmbeddingRequest::new(request_input, model);
208        request.encoding_format = encoding_format;
209        request.dimensions = dimensions;
210        request.user = user;
211
212        Ok(request)
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builds_text_input_request() {
        let builder = EmbeddingsBuilder::new("text-embedding-3-small").input_text("hello world");
        let request = builder.build().expect("builder should succeed");

        assert_eq!(request.model, "text-embedding-3-small");
        assert!(matches!(
            *request.input,
            CreateEmbeddingRequestInput::String(ref value) if value == "hello world"
        ));
    }

    #[test]
    fn builds_multiple_texts_request() {
        let builder =
            EmbeddingsBuilder::new("text-embedding-3-large").input_texts(["foo", "bar", "baz"]);
        let request = builder.build().expect("builder should succeed");

        match *request.input {
            CreateEmbeddingRequestInput::ArrayOfStrings(values) => {
                assert_eq!(values, vec!["foo", "bar", "baz"]);
            }
            other => panic!("unexpected input variant: {other:?}"),
        }
    }

    #[test]
    fn builds_token_batch_request() {
        let builder = EmbeddingsBuilder::new("text-embedding-3-small")
            .input_token_batches([vec![1, 2, 3], vec![4, 5, 6]]);
        let request = builder.build().expect("builder should succeed");

        match *request.input {
            CreateEmbeddingRequestInput::ArrayOfIntegerArrays(values) => {
                assert_eq!(values, vec![vec![1, 2, 3], vec![4, 5, 6]]);
            }
            other => panic!("unexpected input variant: {other:?}"),
        }
    }

    #[test]
    fn explicit_input_is_forwarded() {
        let request = EmbeddingsBuilder::new("text-embedding-3-small")
            .input(EmbeddingInput::TextArray(vec![
                "a".to_string(),
                "b".to_string(),
            ]))
            .build()
            .expect("builder should succeed");

        match *request.input {
            CreateEmbeddingRequestInput::ArrayOfStrings(values) => {
                assert_eq!(values, vec!["a", "b"]);
            }
            other => panic!("unexpected input variant: {other:?}"),
        }
    }

    #[test]
    fn later_input_replaces_earlier_input() {
        let builder = EmbeddingsBuilder::new("text-embedding-3-small")
            .input_text("first")
            .input_tokens([1, 2, 3]);

        assert_eq!(
            builder.input_ref(),
            Some(&EmbeddingInput::Tokens(vec![1, 2, 3]))
        );
    }

    #[test]
    fn accessors_reflect_configuration() {
        let builder = EmbeddingsBuilder::new("text-embedding-3-small");
        assert_eq!(builder.model(), "text-embedding-3-small");
        assert_eq!(builder.input_ref(), None);
        assert_eq!(builder.encoding_format_ref(), None);
        assert_eq!(builder.dimensions_ref(), None);
        assert_eq!(builder.user_ref(), None);

        let builder = builder
            .input_text("hello")
            .encoding_format(EncodingFormat::Base64)
            .dimensions(64)
            .user("user-1");
        assert_eq!(
            builder.input_ref(),
            Some(&EmbeddingInput::Text("hello".to_string()))
        );
        assert_eq!(builder.encoding_format_ref(), Some(EncodingFormat::Base64));
        assert_eq!(builder.dimensions_ref(), Some(64));
        assert_eq!(builder.user_ref(), Some("user-1"));
    }

    #[test]
    fn validates_dimensions_positive() {
        let builder = EmbeddingsBuilder::new("text-embedding-3-small")
            .input_text("test")
            .dimensions(0);
        let error = builder.build().expect_err("dimensions should be validated");
        assert!(matches!(error, Error::InvalidRequest(message) if message.contains("positive")));
    }

    #[test]
    fn rejects_negative_dimensions() {
        let error = EmbeddingsBuilder::new("text-embedding-3-small")
            .input_text("test")
            .dimensions(-8)
            .build()
            .expect_err("negative dimensions should be rejected");
        assert!(matches!(error, Error::InvalidRequest(message) if message.contains("positive")));
    }

    #[test]
    fn requires_input() {
        let builder = EmbeddingsBuilder::new("text-embedding-3-small");
        let error = builder.build().expect_err("input is required");
        assert!(matches!(error, Error::InvalidRequest(message) if message.contains("input")));
    }

    #[test]
    fn propagates_encoding_and_user() {
        let builder = EmbeddingsBuilder::new("text-embedding-3-small")
            .input_text("hello")
            .encoding_format(EncodingFormat::Base64)
            .dimensions(512)
            .user("user-123");
        let request = builder.build().expect("builder should succeed");

        assert_eq!(request.encoding_format, Some(EncodingFormat::Base64));
        assert_eq!(request.dimensions, Some(512));
        assert_eq!(request.user.as_deref(), Some("user-123"));
    }
}
289}