chroma_types/
base64_decode.rs1use base64::{engine::general_purpose, Engine as _};
2use chroma_error::{ChromaError, ErrorCodes};
3use serde::{Deserialize, Serialize};
4use thiserror::Error;
5
6#[derive(Error, Debug)]
7pub enum Base64DecodeError {
8 #[error("Invalid base64 string: {0}")]
9 InvalidBase64(#[from] base64::DecodeError),
10 #[error("Invalid byte length: {byte_length} bytes cannot be converted to f32 values (must be divisible by 4)")]
11 InvalidByteLength { byte_length: usize },
12 #[error("Failed to convert embedding {embedding_index} to byte array")]
13 EmbeddingConversionFailed { embedding_index: usize },
14}
15
16impl ChromaError for Base64DecodeError {
17 fn code(&self) -> ErrorCodes {
18 ErrorCodes::InvalidArgument
19 }
20}
21
22#[derive(Serialize, Deserialize, Debug, Clone)]
23#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
24#[serde(untagged)]
25pub enum EmbeddingsPayload {
26 JsonArrays(Vec<Vec<f32>>),
27 Base64Binary(Vec<String>),
28}
29
30#[derive(Serialize, Deserialize, Debug, Clone)]
31#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
32#[serde(untagged)]
33pub enum UpdateEmbeddingsPayload {
34 JsonArrays(Vec<Option<Vec<f32>>>),
35 Base64Binary(Vec<Option<String>>),
36}
37
38pub fn decode_embeddings(
39 embeddings: EmbeddingsPayload,
40) -> Result<Vec<Vec<f32>>, Base64DecodeError> {
41 match embeddings {
42 EmbeddingsPayload::Base64Binary(base64_strings) => {
43 Ok(decode_base64_embeddings(&base64_strings)?)
44 }
45 EmbeddingsPayload::JsonArrays(arrays) => Ok(arrays),
46 }
47}
48
49pub fn maybe_decode_update_embeddings(
50 embeddings: Option<UpdateEmbeddingsPayload>,
51) -> Result<Option<Vec<Option<Vec<f32>>>>, Base64DecodeError> {
52 match embeddings {
53 Some(UpdateEmbeddingsPayload::Base64Binary(base64_data)) => {
54 Ok(Some(decode_base64_update_embeddings(&base64_data)?))
55 }
56 Some(UpdateEmbeddingsPayload::JsonArrays(arrays)) => Ok(Some(arrays)),
57 None => Ok(None),
58 }
59}
60
61pub fn decode_base64_embeddings(
62 base64_strings: &Vec<String>,
63) -> Result<Vec<Vec<f32>>, Base64DecodeError> {
64 let mut result = Vec::with_capacity(base64_strings.len());
65
66 for base64_str in base64_strings {
67 let floats = decode_base64_embedding(base64_str)?;
68
69 result.push(floats);
70 }
71
72 Ok(result)
73}
74
75pub fn decode_base64_update_embeddings(
76 base64_data: &Vec<Option<String>>,
77) -> Result<Vec<Option<Vec<f32>>>, Base64DecodeError> {
78 let mut result = Vec::with_capacity(base64_data.len());
79
80 for base64_str in base64_data {
81 if let Some(base64_str) = base64_str {
82 let floats = decode_base64_embedding(base64_str)?;
83
84 result.push(Some(floats));
85 } else {
86 result.push(None);
87 }
88 }
89
90 Ok(result)
91}
92
93pub fn decode_base64_embedding(base64_str: &String) -> Result<Vec<f32>, Base64DecodeError> {
94 let bytes = general_purpose::STANDARD.decode(base64_str)?;
95
96 let float_count = bytes.len() / 4;
97 if bytes.len() % 4 != 0 {
98 return Err(Base64DecodeError::InvalidByteLength {
99 byte_length: bytes.len(),
100 });
101 }
102
103 let mut floats = Vec::with_capacity(float_count);
104 for (embedding_index, chunk) in bytes.chunks_exact(4).enumerate() {
105 let float_bytes: [u8; 4] = chunk
106 .try_into()
107 .map_err(|_| Base64DecodeError::EmbeddingConversionFailed { embedding_index })?;
108 floats.push(f32::from_le_bytes(float_bytes));
110 }
111
112 Ok(floats)
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "testing")]
    use proptest::prelude::*;

    // Contains characters outside the base64 alphabet, so decoding must fail.
    const MALFORMED: &str = "invalid!@#$";

    // Base64-encodes the little-endian bytes of a single f32.
    fn encode_f32(value: f32) -> String {
        base64::Engine::encode(
            &base64::engine::general_purpose::STANDARD,
            value.to_le_bytes(),
        )
    }

    #[test]
    fn test_invalid_base64_returns_error() {
        let outcome = decode_base64_embedding(&MALFORMED.to_string());
        assert!(matches!(outcome, Err(Base64DecodeError::InvalidBase64(_))));
    }

    #[test]
    fn test_invalid_byte_length_returns_error() {
        // "YWJj" is valid base64 for the three bytes "abc", which is not a
        // multiple of four.
        let three_bytes = String::from("YWJj");
        let outcome = decode_base64_embedding(&three_bytes);
        assert!(matches!(
            outcome,
            Err(Base64DecodeError::InvalidByteLength { byte_length: 3 })
        ));
    }

    #[test]
    fn test_get_embeddings_propagates_error() {
        let payload = EmbeddingsPayload::Base64Binary(vec![MALFORMED.to_string()]);
        let outcome = decode_embeddings(payload);

        assert!(matches!(outcome, Err(Base64DecodeError::InvalidBase64(_))));
    }

    #[test]
    fn test_valid_base64_decoding() {
        let encoded = encode_f32(1.0);
        let decoded = decode_base64_embedding(&encoded).unwrap();
        assert_eq!(decoded, vec![1.0f32]);
    }

    #[test]
    fn test_multiple_embeddings_with_one_invalid() {
        // A valid entry followed by a malformed one: the whole call must fail.
        let payload =
            EmbeddingsPayload::Base64Binary(vec![encode_f32(1.0), MALFORMED.to_string()]);

        let outcome = decode_embeddings(payload);
        assert!(matches!(outcome, Err(Base64DecodeError::InvalidBase64(_))));
    }

    #[test]
    fn test_decode_base64_embedding() {
        let encoded = encode_f32(1.0);
        let decoded = decode_base64_embedding(&encoded).unwrap();
        assert_eq!(decoded, vec![1.0f32]);
    }

    #[test]
    fn test_decode_base64_update_embeddings() {
        let inputs: Vec<Option<String>> = vec![
            Some(encode_f32(1.0)),
            Some(encode_f32(2.0)),
            None,
            Some(encode_f32(3.0)),
            None,
        ];

        let decoded = decode_base64_update_embeddings(&inputs).unwrap();

        // `None` entries must stay in their original positions.
        let expected = vec![
            Some(vec![1.0f32]),
            Some(vec![2.0f32]),
            None,
            Some(vec![3.0f32]),
            None,
        ];
        assert_eq!(decoded, expected);
    }

    #[test]
    fn test_decode_base64_embeddings() {
        let inputs = vec![encode_f32(1.0), encode_f32(2.0), encode_f32(3.0)];

        let decoded = decode_base64_embeddings(&inputs).unwrap();

        assert_eq!(decoded, vec![vec![1.0f32], vec![2.0f32], vec![3.0f32]]);
    }

    // Base64-encodes a whole embedding as consecutive little-endian f32 bytes.
    #[cfg(feature = "testing")]
    fn encode_floats_to_base64(floats: &[f32]) -> String {
        let bytes: Vec<u8> = floats.iter().flat_map(|f| f.to_le_bytes()).collect();
        general_purpose::STANDARD.encode(&bytes)
    }

    // Generates arbitrary batches of embeddings for the round-trip property.
    #[cfg(feature = "testing")]
    fn embeddings_strategy() -> impl Strategy<Value = Vec<Vec<f32>>> {
        any::<Vec<Vec<f32>>>()
    }

    #[cfg(feature = "testing")]
    proptest! {
        // Encoding then decoding must give back the original values.
        #[test]
        fn test_decode_base64_embeddings_prop(embeddings in embeddings_strategy()) {
            let base64_strings: Vec<String> = embeddings
                .iter()
                .map(|embedding| encode_floats_to_base64(embedding))
                .collect();
            let result = decode_base64_embeddings(&base64_strings).unwrap();
            for (original, decoded) in embeddings.iter().zip(result.iter()) {
                prop_assert_eq!(original, decoded);
            }
        }
    }
}