async_openai/types/
embedding.rs1use base64::engine::{general_purpose, Engine};
2use derive_builder::Builder;
3use serde::{Deserialize, Serialize};
4
5use crate::error::OpenAIError;
6
7#[derive(Debug, Serialize, Clone, PartialEq, Deserialize)]
8#[serde(untagged)]
9pub enum EmbeddingInput {
10 String(String),
11 StringArray(Vec<String>),
12 IntegerArray(Vec<u32>),
14 ArrayOfIntegerArray(Vec<Vec<u32>>),
15}
16
17#[derive(Debug, Serialize, Default, Clone, PartialEq, Deserialize)]
18#[serde(rename_all = "lowercase")]
19pub enum EncodingFormat {
20 #[default]
21 Float,
22 Base64,
23}
24
25#[derive(Debug, Serialize, Default, Clone, Builder, PartialEq, Deserialize)]
26#[builder(name = "CreateEmbeddingRequestArgs")]
27#[builder(pattern = "mutable")]
28#[builder(setter(into, strip_option), default)]
29#[builder(derive(Debug))]
30#[builder(build_fn(error = "OpenAIError"))]
31pub struct CreateEmbeddingRequest {
32 pub model: String,
38
39 pub input: EmbeddingInput,
41
42 #[serde(skip_serializing_if = "Option::is_none")]
44 pub encoding_format: Option<EncodingFormat>,
45
46 #[serde(skip_serializing_if = "Option::is_none")]
49 pub user: Option<String>,
50
51 #[serde(skip_serializing_if = "Option::is_none")]
53 pub dimensions: Option<u32>,
54}
55
56#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
58pub struct Embedding {
59 pub index: u32,
61 pub object: String,
63 pub embedding: Vec<f32>,
66}
67
68#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
69pub struct Base64EmbeddingVector(pub String);
70
71impl From<Base64EmbeddingVector> for Vec<f32> {
72 fn from(value: Base64EmbeddingVector) -> Self {
73 let bytes = general_purpose::STANDARD
74 .decode(value.0)
75 .expect("openai base64 encoding to be valid");
76 let chunks = bytes.chunks_exact(4);
77 chunks
78 .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
79 .collect()
80 }
81}
82
83#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
85pub struct Base64Embedding {
86 pub index: u32,
88 pub object: String,
90 pub embedding: Base64EmbeddingVector,
92}
93
94#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)]
95pub struct EmbeddingUsage {
96 pub prompt_tokens: u32,
98 pub total_tokens: u32,
100}
101
102#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
103pub struct CreateEmbeddingResponse {
104 pub object: String,
105 pub model: String,
107 pub data: Vec<Embedding>,
109 pub usage: EmbeddingUsage,
111}
112
113#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
114pub struct CreateBase64EmbeddingResponse {
115 pub object: String,
116 pub model: String,
118 pub data: Vec<Base64Embedding>,
120 pub usage: EmbeddingUsage,
122}