async_openai/types/embeddings/embedding.rs
use base64::engine::{general_purpose, Engine};
use derive_builder::Builder;
use serde::{Deserialize, Serialize};

use crate::error::OpenAIError;

#[derive(Debug, Serialize, Clone, PartialEq, Deserialize)]
#[serde(untagged)]
pub enum EmbeddingInput {
    String(String),
    StringArray(Vec<String>),
    // Minimum value is 0, maximum value is 100257 (inclusive).
    IntegerArray(Vec<u32>),
    ArrayOfIntegerArray(Vec<Vec<u32>>),
}
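// Illustrative sketch, not part of the upstream file: because `EmbeddingInput` is
// `#[serde(untagged)]`, each variant serializes to the bare JSON value the API expects
// (a string, an array of strings, or arrays of token ids). Assumes `serde_json` is
// available to the crate's tests; the values below are arbitrary.
#[cfg(test)]
mod embedding_input_serialization_tests {
    use super::EmbeddingInput;

    #[test]
    fn untagged_variants_serialize_to_plain_json_values() {
        let text = EmbeddingInput::String("hello world".to_string());
        assert_eq!(serde_json::to_string(&text).unwrap(), r#""hello world""#);

        let tokens = EmbeddingInput::IntegerArray(vec![1, 2, 3]);
        assert_eq!(serde_json::to_string(&tokens).unwrap(), "[1,2,3]");
    }
}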

#[derive(Debug, Serialize, Default, Clone, PartialEq, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum EncodingFormat {
    #[default]
    Float,
    Base64,
}

#[derive(Debug, Serialize, Default, Clone, Builder, PartialEq, Deserialize)]
#[builder(name = "CreateEmbeddingRequestArgs")]
#[builder(pattern = "mutable")]
#[builder(setter(into, strip_option), default)]
#[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))]
pub struct CreateEmbeddingRequest {
    /// ID of the model to use. You can use the [List models](https://platform.openai.com/docs/api-reference/models/list)
    /// API to see all of your available models, or see our [Model overview](https://platform.openai.com/docs/models)
    /// for descriptions of them.
    pub model: String,

    /// Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single
    /// request, pass an array of strings or array of token arrays. The input must not exceed the max
    /// input tokens for the model (8192 tokens for all embedding models), cannot be an empty string, and
    /// any array must be 2048 dimensions or less. [Example Python
    /// code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.
    /// In addition to the per-input token limit, all embedding models enforce a maximum of 300,000
    /// tokens summed across all inputs in a single request.
    pub input: EmbeddingInput,

    /// The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub encoding_format: Option<EncodingFormat>,

    /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
    /// [Learn more](https://platform.openai.com/docs/guides/safety-best-practices#end-user-ids).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,

    /// The number of dimensions the resulting output embeddings should have. Only supported in `text-embedding-3` and later models.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dimensions: Option<u32>,
}
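// Illustrative sketch, not part of the upstream file: the `Builder` derive above
// generates `CreateEmbeddingRequestArgs`, a mutable builder whose `build()` returns
// `Result<CreateEmbeddingRequest, OpenAIError>`; `strip_option` setters wrap values in
// `Some`, and unset optional fields stay `None`. The model name is only an example.
#[cfg(test)]
mod create_embedding_request_builder_tests {
    use super::{CreateEmbeddingRequestArgs, EmbeddingInput, EncodingFormat};

    #[test]
    fn builds_request_and_leaves_unset_options_as_none() {
        let request = CreateEmbeddingRequestArgs::default()
            .model("text-embedding-3-small")
            .input(EmbeddingInput::String("hello world".to_string()))
            .encoding_format(EncodingFormat::Float)
            .build()
            .expect("model and input are set, so build should succeed");

        assert_eq!(request.model, "text-embedding-3-small");
        assert_eq!(request.encoding_format, Some(EncodingFormat::Float));
        assert!(request.user.is_none());
        assert!(request.dimensions.is_none());
    }
}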

/// Represents an embedding vector returned by the embeddings endpoint.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Embedding {
    /// The index of the embedding in the list of embeddings.
    pub index: u32,
    /// The object type, which is always "embedding".
    pub object: String,
    /// The embedding vector, which is a list of floats. The length of the vector
    /// depends on the model as listed in the [embedding guide](https://platform.openai.com/docs/guides/embeddings).
    pub embedding: Vec<f32>,
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Base64EmbeddingVector(pub String);

impl From<Base64EmbeddingVector> for Vec<f32> {
    fn from(value: Base64EmbeddingVector) -> Self {
        let bytes = general_purpose::STANDARD
            .decode(value.0)
            .expect("openai base64 encoding to be valid");
        let chunks = bytes.chunks_exact(4);
        chunks
            .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
            .collect()
    }
}
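// Illustrative sketch, not part of the upstream file: the conversion above decodes a
// standard-base64 payload into consecutive little-endian `f32` values, so encoding known
// floats and converting back should round-trip exactly. Uses the same `base64` re-exports
// imported at the top of this file.
#[cfg(test)]
mod base64_embedding_vector_tests {
    use base64::engine::{general_purpose, Engine};

    use super::Base64EmbeddingVector;

    #[test]
    fn decodes_little_endian_f32_round_trip() {
        let floats = [0.0_f32, 1.5, -2.25];
        let bytes: Vec<u8> = floats.iter().flat_map(|f| f.to_le_bytes()).collect();
        let encoded = general_purpose::STANDARD.encode(bytes);

        let decoded: Vec<f32> = Base64EmbeddingVector(encoded).into();
        assert_eq!(decoded, floats);
    }
}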

/// Represents a base64-encoded embedding vector returned by the embeddings endpoint.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Base64Embedding {
    /// The index of the embedding in the list of embeddings.
    pub index: u32,
    /// The object type, which is always "embedding".
    pub object: String,
    /// The embedding vector, encoded in base64.
    pub embedding: Base64EmbeddingVector,
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct EmbeddingUsage {
    /// The number of tokens used by the prompt.
    pub prompt_tokens: u32,
    /// The total number of tokens used by the request.
    pub total_tokens: u32,
}

#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
pub struct CreateEmbeddingResponse {
    pub object: String,
    /// The name of the model used to generate the embedding.
    pub model: String,
    /// The list of embeddings generated by the model.
    pub data: Vec<Embedding>,
    /// The usage information for the request.
    pub usage: EmbeddingUsage,
}

#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
pub struct CreateBase64EmbeddingResponse {
    pub object: String,
    /// The name of the model used to generate the embedding.
    pub model: String,
    /// The list of embeddings generated by the model.
    pub data: Vec<Base64Embedding>,
    /// The usage information for the request.
    pub usage: EmbeddingUsage,
}
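// Illustrative sketch, not part of the upstream file: a minimal JSON payload in the shape
// of the embeddings response deserializes into `CreateEmbeddingResponse`. Field names match
// the structs above; the concrete values are made up, and `serde_json` is assumed to be
// available to the crate's tests.
#[cfg(test)]
mod create_embedding_response_tests {
    use super::CreateEmbeddingResponse;

    #[test]
    fn deserializes_minimal_response_payload() {
        let json = r#"{
            "object": "list",
            "model": "text-embedding-3-small",
            "data": [
                { "index": 0, "object": "embedding", "embedding": [0.1, 0.2, 0.3] }
            ],
            "usage": { "prompt_tokens": 5, "total_tokens": 5 }
        }"#;

        let response: CreateEmbeddingResponse =
            serde_json::from_str(json).expect("payload should match the response shape");
        assert_eq!(response.data.len(), 1);
        assert_eq!(response.usage.total_tokens, 5);
    }
}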