async_openai/types/embedding.rs

use base64::engine::{general_purpose, Engine};
use derive_builder::Builder;
use serde::{Deserialize, Serialize};

use crate::error::OpenAIError;

#[derive(Debug, Serialize, Clone, PartialEq, Deserialize)]
#[serde(untagged)]
pub enum EmbeddingInput {
    String(String),
    StringArray(Vec<String>),
    // Minimum value is 0, maximum value is 100257 (inclusive).
    IntegerArray(Vec<u32>),
    ArrayOfIntegerArray(Vec<Vec<u32>>),
}
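
// A minimal sketch (not part of the original file) of how `#[serde(untagged)]`
// serializes `EmbeddingInput`: the `String` variant becomes a bare JSON string,
// while the array variants become plain JSON arrays, with no enum tag in the
// output. Assumes `serde_json` is available to the test; the sample values are
// arbitrary and only illustrative.
#[cfg(test)]
mod embedding_input_serde_sketch {
    use super::EmbeddingInput;

    #[test]
    fn untagged_serialization() {
        let single = EmbeddingInput::String("hello world".to_string());
        assert_eq!(serde_json::to_string(&single).unwrap(), r#""hello world""#);

        let tokens = EmbeddingInput::IntegerArray(vec![15339, 1917]);
        assert_eq!(serde_json::to_string(&tokens).unwrap(), "[15339,1917]");
    }
}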

#[derive(Debug, Serialize, Default, Clone, PartialEq, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum EncodingFormat {
    #[default]
    Float,
    Base64,
}

#[derive(Debug, Serialize, Default, Clone, Builder, PartialEq, Deserialize)]
#[builder(name = "CreateEmbeddingRequestArgs")]
#[builder(pattern = "mutable")]
#[builder(setter(into, strip_option), default)]
#[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))]
pub struct CreateEmbeddingRequest {
    /// ID of the model to use. You can use the
    /// [List models](https://platform.openai.com/docs/api-reference/models/list)
    /// API to see all of your available models, or see our
    /// [Model overview](https://platform.openai.com/docs/models/overview)
    /// for descriptions of them.
    pub model: String,

    /// Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for `text-embedding-ada-002`), cannot be an empty string, and any array must be 2048 dimensions or less. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.
    pub input: EmbeddingInput,

    /// The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/). Defaults to `float`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub encoding_format: Option<EncodingFormat>,

    /// A unique identifier representing your end-user, which will help OpenAI
    /// to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/usage-policies/end-user-ids).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,

    /// The number of dimensions the resulting output embeddings should have. Only supported in `text-embedding-3` and later models.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dimensions: Option<u32>,
}
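
// A hedged usage sketch (not part of the original file): with
// `#[builder(name = "CreateEmbeddingRequestArgs")]`, `pattern = "mutable"`, and
// `setter(into, strip_option)`, `derive_builder` generates a chainable builder
// roughly as used below. The model name and input text are only illustrative.
#[cfg(test)]
mod create_embedding_request_builder_sketch {
    use super::{CreateEmbeddingRequestArgs, EmbeddingInput};

    #[test]
    fn build_request() {
        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-3-small")
            .input(EmbeddingInput::String(
                "The food was delicious and the waiter was friendly.".into(),
            ))
            .dimensions(256u32)
            .build()
            .unwrap();

        assert_eq!(request.model, "text-embedding-3-small");
        assert_eq!(request.dimensions, Some(256));
        assert!(matches!(request.input, EmbeddingInput::String(_)));
    }
}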

/// Represents an embedding vector returned by the embeddings endpoint.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Embedding {
    /// The index of the embedding in the list of embeddings.
    pub index: u32,
    /// The object type, which is always "embedding".
    pub object: String,
    /// The embedding vector, which is a list of floats. The length of the vector
    /// depends on the model as listed in the [embedding guide](https://platform.openai.com/docs/guides/embeddings).
    pub embedding: Vec<f32>,
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Base64EmbeddingVector(pub String);

impl From<Base64EmbeddingVector> for Vec<f32> {
    fn from(value: Base64EmbeddingVector) -> Self {
        let bytes = general_purpose::STANDARD
            .decode(value.0)
            .expect("openai base64 encoding to be valid");
        let chunks = bytes.chunks_exact(4);
        chunks
            .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
            .collect()
    }
}
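
// A small round-trip sketch (not part of the original file) illustrating the
// conversion above: the base64 payload is treated as the little-endian
// IEEE-754 bytes of an f32 vector, so encoding a known vector and converting
// it back should reproduce the original values exactly.
#[cfg(test)]
mod base64_embedding_decode_sketch {
    use base64::engine::{general_purpose, Engine};

    use super::Base64EmbeddingVector;

    #[test]
    fn decodes_little_endian_f32() {
        let original = vec![0.25f32, -1.5, 3.0];
        let bytes: Vec<u8> = original.iter().flat_map(|f| f.to_le_bytes()).collect();
        let encoded = Base64EmbeddingVector(general_purpose::STANDARD.encode(bytes));

        let decoded: Vec<f32> = encoded.into();
        assert_eq!(decoded, original);
    }
}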

/// Represents a base64-encoded embedding vector returned by the embeddings endpoint.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Base64Embedding {
    /// The index of the embedding in the list of embeddings.
    pub index: u32,
    /// The object type, which is always "embedding".
    pub object: String,
    /// The embedding vector, encoded in base64.
    pub embedding: Base64EmbeddingVector,
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct EmbeddingUsage {
    /// The number of tokens used by the prompt.
    pub prompt_tokens: u32,
    /// The total number of tokens used by the request.
    pub total_tokens: u32,
}

#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
pub struct CreateEmbeddingResponse {
    pub object: String,
    /// The name of the model used to generate the embedding.
    pub model: String,
    /// The list of embeddings generated by the model.
    pub data: Vec<Embedding>,
    /// The usage information for the request.
    pub usage: EmbeddingUsage,
}

#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
pub struct CreateBase64EmbeddingResponse {
    pub object: String,
    /// The name of the model used to generate the embedding.
    pub model: String,
    /// The list of embeddings generated by the model.
    pub data: Vec<Base64Embedding>,
    /// The usage information for the request.
    pub usage: EmbeddingUsage,
}
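
// An illustrative deserialization sketch (not part of the original file): the
// JSON below is hand-written to mirror the shape of the structs above, not a
// captured API response. It shows how a float-format embeddings payload maps
// onto `CreateEmbeddingResponse`, assuming `serde_json` is available.
#[cfg(test)]
mod create_embedding_response_sketch {
    use super::CreateEmbeddingResponse;

    #[test]
    fn deserializes_float_response() {
        let json = r#"{
            "object": "list",
            "model": "text-embedding-3-small",
            "data": [
                { "index": 0, "object": "embedding", "embedding": [0.1, -0.2, 0.3] }
            ],
            "usage": { "prompt_tokens": 5, "total_tokens": 5 }
        }"#;

        let response: CreateEmbeddingResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.data.len(), 1);
        assert_eq!(response.data[0].embedding.len(), 3);
        assert_eq!(response.usage.total_tokens, 5);
    }
}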