dynamo_async_openai/types/embedding.rs

// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0

use base64::engine::{Engine, general_purpose};
use derive_builder::Builder;
use serde::{Deserialize, Serialize};

use crate::error::OpenAIError;

#[derive(Debug, Serialize, Clone, PartialEq, Deserialize)]
#[serde(untagged)]
pub enum EmbeddingInput {
    String(String),
    StringArray(Vec<String>),
    // Minimum value is 0, maximum value is 100257 (inclusive).
    IntegerArray(Vec<u32>),
    ArrayOfIntegerArray(Vec<Vec<u32>>),
}
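
// Illustrative sketch (not part of the upstream source): because `EmbeddingInput`
// is `#[serde(untagged)]`, each variant serializes to its bare JSON shape (a plain
// string, an array of strings, or an array of token ids), and deserialization picks
// the first variant whose shape matches. The assertions below assume `serde_json`
// is available as a dev-dependency.
#[cfg(test)]
mod embedding_input_serde_tests {
    use super::EmbeddingInput;

    #[test]
    fn untagged_variants_use_bare_json_shapes() {
        let single = EmbeddingInput::String("hello".to_string());
        assert_eq!(serde_json::to_string(&single).unwrap(), r#""hello""#);

        let tokens = EmbeddingInput::IntegerArray(vec![15339, 1917]);
        assert_eq!(serde_json::to_string(&tokens).unwrap(), "[15339,1917]");

        // Deserialization tries variants in declaration order and keeps the
        // first match, so a JSON string array becomes `StringArray`.
        let parsed: EmbeddingInput = serde_json::from_str(r#"["a","b"]"#).unwrap();
        assert_eq!(
            parsed,
            EmbeddingInput::StringArray(vec!["a".to_string(), "b".to_string()])
        );
    }
}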

#[derive(Debug, Serialize, Default, Clone, PartialEq, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum EncodingFormat {
    #[default]
    Float,
    Base64,
}
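
// Small sketch (assumes `serde_json` as a dev-dependency): with
// `rename_all = "lowercase"` the wire values are "float" and "base64", and
// `Float` is the default when a request leaves `encoding_format` unset.
#[cfg(test)]
mod encoding_format_serde_tests {
    use super::EncodingFormat;

    #[test]
    fn variants_serialize_lowercase() {
        assert_eq!(
            serde_json::to_string(&EncodingFormat::Float).unwrap(),
            r#""float""#
        );
        assert_eq!(
            serde_json::to_string(&EncodingFormat::Base64).unwrap(),
            r#""base64""#
        );
        assert_eq!(EncodingFormat::default(), EncodingFormat::Float);
    }
}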

#[derive(Debug, Serialize, Default, Clone, Builder, PartialEq, Deserialize)]
#[builder(name = "CreateEmbeddingRequestArgs")]
#[builder(pattern = "mutable")]
#[builder(setter(into, strip_option), default)]
#[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))]
pub struct CreateEmbeddingRequest {
    /// ID of the model to use. You can use the
    /// [List models](https://platform.openai.com/docs/api-reference/models/list)
    /// API to see all of your available models, or see our
    /// [Model overview](https://platform.openai.com/docs/models/overview)
    /// for descriptions of them.
    pub model: String,

    /// Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a
    /// single request, pass an array of strings or array of token arrays. The input must not exceed
    /// the max input tokens for the model (8192 tokens for `text-embedding-ada-002`), cannot be an
    /// empty string, and any array must be 2048 dimensions or less.
    /// [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken)
    /// for counting tokens.
    pub input: EmbeddingInput,

    /// The format to return the embeddings in. Can be either `float` or
    /// [`base64`](https://pypi.org/project/pybase64/). Defaults to `float`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub encoding_format: Option<EncodingFormat>,

    /// A unique identifier representing your end-user, which will help OpenAI
    /// to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/usage-policies/end-user-ids).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,

    /// The number of dimensions the resulting output embeddings should have.
    /// Only supported in `text-embedding-3` and later models.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dimensions: Option<u32>,
}
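
// Illustrative usage (not in the upstream source): the derive_builder-generated
// `CreateEmbeddingRequestArgs` uses `setter(into, strip_option)`, so `&str` and
// owned values can be passed directly, and fields left unset are dropped from the
// JSON body by `skip_serializing_if`. The model name is an arbitrary placeholder,
// and `serde_json` is assumed as a dev-dependency.
#[cfg(test)]
mod create_embedding_request_builder_tests {
    use super::{CreateEmbeddingRequestArgs, EmbeddingInput};

    #[test]
    fn builder_omits_unset_optional_fields() {
        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-3-small")
            .input(EmbeddingInput::String(
                "The food was delicious and the waiter was friendly.".to_string(),
            ))
            .dimensions(256u32)
            .build()
            .expect("all required fields were provided");

        let json = serde_json::to_value(&request).unwrap();
        assert_eq!(json["model"], "text-embedding-3-small");
        assert_eq!(json["dimensions"], 256);
        // `encoding_format` and `user` were never set, so they are omitted
        // from the serialized request body.
        assert!(json.get("encoding_format").is_none());
        assert!(json.get("user").is_none());
    }
}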

/// Represents an embedding vector returned by the embedding endpoint.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Embedding {
    /// The index of the embedding in the list of embeddings.
    pub index: u32,
    /// The object type, which is always "embedding".
    pub object: String,
    /// The embedding vector, which is a list of floats. The length of the vector
    /// depends on the model as listed in the [embedding guide](https://platform.openai.com/docs/guides/embeddings).
    pub embedding: Vec<f32>,
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Base64EmbeddingVector(pub String);

impl From<Base64EmbeddingVector> for Vec<f32> {
    fn from(value: Base64EmbeddingVector) -> Self {
        let bytes = general_purpose::STANDARD
            .decode(value.0)
            .expect("openai base64 encoding to be valid");
        let chunks = bytes.chunks_exact(4);
        chunks
            .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
            .collect()
    }
}
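
// Sketch of the decoding round trip (the `base64` crate is already imported above;
// everything else here is standard library): the endpoint returns base64-encoded
// little-endian f32 bytes, so encoding a known vector and converting back through
// `From<Base64EmbeddingVector>` reproduces it exactly.
#[cfg(test)]
mod base64_embedding_vector_tests {
    use super::Base64EmbeddingVector;
    use base64::engine::{Engine, general_purpose};

    #[test]
    fn decodes_little_endian_f32_chunks() {
        let original = vec![0.25_f32, -1.0, 3.5];
        let bytes: Vec<u8> = original.iter().flat_map(|f| f.to_le_bytes()).collect();
        let encoded = Base64EmbeddingVector(general_purpose::STANDARD.encode(bytes));

        let decoded: Vec<f32> = encoded.into();
        assert_eq!(decoded, original);
    }
}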

/// Represents a base64-encoded embedding vector returned by the embedding endpoint.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Base64Embedding {
    /// The index of the embedding in the list of embeddings.
    pub index: u32,
    /// The object type, which is always "embedding".
    pub object: String,
    /// The embedding vector, encoded in base64.
    pub embedding: Base64EmbeddingVector,
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct EmbeddingUsage {
    /// The number of tokens used by the prompt.
    pub prompt_tokens: u32,
    /// The total number of tokens used by the request.
    pub total_tokens: u32,
}

#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
pub struct CreateEmbeddingResponse {
    pub object: String,
    /// The name of the model used to generate the embedding.
    pub model: String,
    /// The list of embeddings generated by the model.
    pub data: Vec<Embedding>,
    /// The usage information for the request.
    pub usage: EmbeddingUsage,
}
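
// Illustrative deserialization sketch (assumes `serde_json` as a dev-dependency;
// the JSON below mirrors the documented OpenAI embeddings response shape): a
// `float`-encoded response maps directly onto `CreateEmbeddingResponse`.
#[cfg(test)]
mod create_embedding_response_tests {
    use super::CreateEmbeddingResponse;

    #[test]
    fn deserializes_float_encoded_response() {
        let body = r#"{
            "object": "list",
            "model": "text-embedding-3-small",
            "data": [
                { "index": 0, "object": "embedding", "embedding": [0.01, -0.02, 0.03] }
            ],
            "usage": { "prompt_tokens": 5, "total_tokens": 5 }
        }"#;

        let response: CreateEmbeddingResponse = serde_json::from_str(body).unwrap();
        assert_eq!(response.object, "list");
        assert_eq!(response.data.len(), 1);
        assert_eq!(response.data[0].embedding.len(), 3);
        assert_eq!(response.usage.total_tokens, 5);
    }
}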

#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
pub struct CreateBase64EmbeddingResponse {
    pub object: String,
    /// The name of the model used to generate the embedding.
    pub model: String,
    /// The list of embeddings generated by the model.
    pub data: Vec<Base64Embedding>,
    /// The usage information for the request.
    pub usage: EmbeddingUsage,
}