dynamo_async_openai/
embedding.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
5// Original Copyright (c) 2022 Himanshu Neema
6// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
7//
8// Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
9// Licensed under Apache 2.0
10
11use crate::{
12    Client,
13    config::Config,
14    error::OpenAIError,
15    types::{CreateBase64EmbeddingResponse, CreateEmbeddingRequest, CreateEmbeddingResponse},
16};
17
18#[cfg(not(feature = "byot"))]
19use crate::types::EncodingFormat;
20
21/// Get a vector representation of a given input that can be easily
22/// consumed by machine learning models and algorithms.
23///
24/// Related guide: [Embeddings](https://platform.openai.com/docs/guides/embeddings/what-are-embeddings)
25pub struct Embeddings<'c, C: Config> {
26    client: &'c Client<C>,
27}
28
29impl<'c, C: Config> Embeddings<'c, C> {
30    pub fn new(client: &'c Client<C>) -> Self {
31        Self { client }
32    }
33
34    /// Creates an embedding vector representing the input text.
35    ///
36    /// byot: In serialized `request` you must ensure "encoding_format" is not "base64"
37    #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
38    pub async fn create(
39        &self,
40        request: CreateEmbeddingRequest,
41    ) -> Result<CreateEmbeddingResponse, OpenAIError> {
42        #[cfg(not(feature = "byot"))]
43        {
44            if matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
45                return Err(OpenAIError::InvalidArgument(
46                    "When encoding_format is base64, use Embeddings::create_base64".into(),
47                ));
48            }
49        }
50        self.client.post("/embeddings", request).await
51    }
52
53    /// Creates an embedding vector representing the input text.
54    ///
55    /// The response will contain the embedding in base64 format.
56    ///
57    /// byot: In serialized `request` you must ensure "encoding_format" is "base64"
58    #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
59    pub async fn create_base64(
60        &self,
61        request: CreateEmbeddingRequest,
62    ) -> Result<CreateBase64EmbeddingResponse, OpenAIError> {
63        #[cfg(not(feature = "byot"))]
64        {
65            if !matches!(request.encoding_format, Some(EncodingFormat::Base64)) {
66                return Err(OpenAIError::InvalidArgument(
67                    "When encoding_format is not base64, use Embeddings::create".into(),
68                ));
69            }
70        }
71        self.client.post("/embeddings", request).await
72    }
73}