dynamo_async_openai/types/embedding.rs
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Based on https://github.com/64bit/async-openai/ by Himanshu Neema
// Original Copyright (c) 2022 Himanshu Neema
// Licensed under MIT License (see ATTRIBUTIONS-Rust.md)
//
// Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
// Licensed under Apache 2.0

use base64::engine::{Engine, general_purpose};
use derive_builder::Builder;
use serde::{Deserialize, Serialize};

use crate::error::OpenAIError;

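/// Input for an embeddings request: a single string, an array of strings,
/// an array of token ids, or an array of token-id arrays. The enum is
/// `#[serde(untagged)]`, so each variant serializes directly to the matching
/// JSON shape.
///
/// Illustrative sketch of the resulting JSON (assumes `serde_json` is
/// available and that the type is re-exported from this crate's `types`
/// module; adjust the `use` path to match the actual re-export):
///
/// ```ignore
/// use dynamo_async_openai::types::EmbeddingInput; // assumed re-export path
///
/// let single = EmbeddingInput::String("hello".to_string());
/// assert_eq!(serde_json::to_string(&single).unwrap(), r#""hello""#);
///
/// let tokens = EmbeddingInput::IntegerArray(vec![1212, 318, 257]);
/// assert_eq!(serde_json::to_string(&tokens).unwrap(), "[1212,318,257]");
/// ```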
#[derive(Debug, Serialize, Clone, PartialEq, Deserialize)]
#[serde(untagged)]
pub enum EmbeddingInput {
    String(String),
    StringArray(Vec<String>),
    // Minimum value is 0, maximum value is 100257 (inclusive).
    IntegerArray(Vec<u32>),
    ArrayOfIntegerArray(Vec<Vec<u32>>),
}

#[derive(Debug, Serialize, Default, Clone, PartialEq, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum EncodingFormat {
    #[default]
    Float,
    Base64,
}

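/// Request body for creating embeddings.
///
/// A minimal sketch of building a request through the `derive_builder`-generated
/// `CreateEmbeddingRequestArgs`; the `use` path below is an assumption about how
/// this crate re-exports its types:
///
/// ```ignore
/// use dynamo_async_openai::types::{CreateEmbeddingRequestArgs, EmbeddingInput}; // assumed path
///
/// let request = CreateEmbeddingRequestArgs::default()
///     .model("text-embedding-3-small")
///     .input(EmbeddingInput::String("The food was delicious.".to_string()))
///     .dimensions(256u32)
///     .build()
///     .expect("every field is defaulted, so build should not fail");
///
/// // `encoding_format` and `user` were not set, so they are omitted from the
/// // serialized JSON via `skip_serializing_if`.
/// ```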
#[derive(Debug, Serialize, Default, Clone, Builder, PartialEq, Deserialize)]
#[builder(name = "CreateEmbeddingRequestArgs")]
#[builder(pattern = "mutable")]
#[builder(setter(into, strip_option), default)]
#[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))]
pub struct CreateEmbeddingRequest {
    /// ID of the model to use. You can use the
    /// [List models](https://platform.openai.com/docs/api-reference/models/list)
    /// API to see all of your available models, or see our
    /// [Model overview](https://platform.openai.com/docs/models/overview)
    /// for descriptions of them.
    pub model: String,

    /// Input text to embed, encoded as a string or array of tokens. To embed multiple
    /// inputs in a single request, pass an array of strings or array of token arrays.
    /// The input must not exceed the max input tokens for the model (8192 tokens for
    /// `text-embedding-ada-002`), cannot be an empty string, and any array must be
    /// 2048 dimensions or less.
    /// [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken)
    /// for counting tokens.
    pub input: EmbeddingInput,

    /// The format to return the embeddings in. Can be either `float` or
    /// [`base64`](https://pypi.org/project/pybase64/). Defaults to `float`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub encoding_format: Option<EncodingFormat>,

    /// A unique identifier representing your end-user, which will help OpenAI
    /// to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/usage-policies/end-user-ids).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,

    /// The number of dimensions the resulting output embeddings should have.
    /// Only supported in `text-embedding-3` and later models.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dimensions: Option<u32>,
}

/// Represents an embedding vector returned by the embedding endpoint.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Embedding {
    /// The index of the embedding in the list of embeddings.
    pub index: u32,
    /// The object type, which is always "embedding".
    pub object: String,
    /// The embedding vector, which is a list of floats. The length of the vector
    /// depends on the model, as listed in the [embedding guide](https://platform.openai.com/docs/guides/embeddings).
    pub embedding: Vec<f32>,
}

/// An embedding vector encoded as a base64 string, as returned when the
/// requested `encoding_format` is `base64`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Base64EmbeddingVector(pub String);

impl From<Base64EmbeddingVector> for Vec<f32> {
    fn from(value: Base64EmbeddingVector) -> Self {
        // The API encodes embeddings as packed little-endian f32 values; decode
        // the base64 payload and reinterpret each 4-byte chunk as an f32.
        let bytes = general_purpose::STANDARD
            .decode(value.0)
            .expect("openai base64 encoding to be valid");
        let chunks = bytes.chunks_exact(4);
        chunks
            .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
            .collect()
    }
}

/// Represents a base64-encoded embedding vector returned by the embedding endpoint.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct Base64Embedding {
    /// The index of the embedding in the list of embeddings.
    pub index: u32,
    /// The object type, which is always "embedding".
    pub object: String,
    /// The embedding vector, encoded in base64.
    pub embedding: Base64EmbeddingVector,
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
pub struct EmbeddingUsage {
    /// The number of tokens used by the prompt.
    pub prompt_tokens: u32,
    /// The total number of tokens used by the request.
    pub total_tokens: u32,
}

#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
pub struct CreateEmbeddingResponse {
    pub object: String,
    /// The name of the model used to generate the embedding.
    pub model: String,
    /// The list of embeddings generated by the model.
    pub data: Vec<Embedding>,
    /// The usage information for the request.
    pub usage: EmbeddingUsage,
}

#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
pub struct CreateBase64EmbeddingResponse {
    pub object: String,
    /// The name of the model used to generate the embedding.
    pub model: String,
    /// The list of embeddings generated by the model.
    pub data: Vec<Base64Embedding>,
    /// The usage information for the request.
    pub usage: EmbeddingUsage,
}
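
// Illustrative checks, written as a sketch rather than an exhaustive test suite.
// They assume `serde_json` is available to this crate as a (dev-)dependency; the
// JSON body below is hand-written to match the documented response shape, not
// captured API output.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn base64_embedding_vector_decodes_packed_le_f32() {
        // Encode a known vector the same way the API does: packed little-endian
        // f32 bytes, then standard base64.
        let floats = [0.25f32, -1.5, 3.0];
        let mut bytes = Vec::with_capacity(floats.len() * 4);
        for f in floats {
            bytes.extend_from_slice(&f.to_le_bytes());
        }
        let encoded = general_purpose::STANDARD.encode(&bytes);

        let decoded: Vec<f32> = Base64EmbeddingVector(encoded).into();
        assert_eq!(decoded, vec![0.25, -1.5, 3.0]);
    }

    #[test]
    fn create_embedding_response_deserializes() {
        let body = r#"{
            "object": "list",
            "model": "text-embedding-3-small",
            "data": [
                { "index": 0, "object": "embedding", "embedding": [0.25, -0.5, 1.0] }
            ],
            "usage": { "prompt_tokens": 5, "total_tokens": 5 }
        }"#;

        let response: CreateEmbeddingResponse =
            serde_json::from_str(body).expect("response JSON should deserialize");
        assert_eq!(response.data.len(), 1);
        assert_eq!(response.data[0].embedding, vec![0.25, -0.5, 1.0]);
        assert_eq!(response.usage.total_tokens, 5);
    }
}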