chroma_types/
base64_decode.rs

1use base64::{engine::general_purpose, Engine as _};
2use chroma_error::{ChromaError, ErrorCodes};
3use serde::{Deserialize, Serialize};
4use thiserror::Error;
5
/// Errors produced while decoding base64-encoded embeddings into `f32` vectors.
#[derive(Error, Debug)]
pub enum Base64DecodeError {
    /// The input was rejected by the base64 decoder (`general_purpose::STANDARD` in this
    /// module).
    #[error("Invalid base64 string: {0}")]
    InvalidBase64(#[from] base64::DecodeError),
    /// The decoded bytes cannot be split into whole 4-byte `f32` values.
    /// `byte_length` is the number of bytes produced by base64 decoding.
    #[error("Invalid byte length: {byte_length} bytes cannot be converted to f32 values (must be divisible by 4)")]
    InvalidByteLength { byte_length: usize },
    /// A 4-byte chunk could not be converted into a `[u8; 4]` array.
    ///
    /// NOTE(review): `decode_base64_embedding` only constructs this inside a
    /// `chunks_exact(4)` loop, where the conversion cannot fail. The value it stores in
    /// `embedding_index` is the position of the float within one embedding, not the
    /// position of the embedding within a batch.
    #[error("Failed to convert embedding {embedding_index} to byte array")]
    EmbeddingConversionFailed { embedding_index: usize },
}
15
16impl ChromaError for Base64DecodeError {
17    fn code(&self) -> ErrorCodes {
18        ErrorCodes::InvalidArgument
19    }
20}
21
/// A batch of embeddings in one of two serialized forms.
///
/// Because of `#[serde(untagged)]`, serde tries the variants in declaration order when
/// deserializing. `JsonArrays` is attempted first, so an empty array `[]` always becomes
/// `JsonArrays(vec![])`. Use [`decode_embeddings`] to obtain plain `f32` vectors from
/// either form.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde(untagged)]
pub enum EmbeddingsPayload {
    /// Each embedding given directly as a list of floats.
    JsonArrays(Vec<Vec<f32>>),
    /// Each embedding given as a base64 string of concatenated little-endian `f32` bytes
    /// (the format [`decode_base64_embedding`] reads).
    Base64Binary(Vec<String>),
}
29
/// A batch of optional embeddings in one of two serialized forms.
///
/// Same layout as [`EmbeddingsPayload`], but every entry may be absent (`None`). The
/// decoders in this module pass absent entries through unchanged. Because of
/// `#[serde(untagged)]`, `JsonArrays` is tried first during deserialization, so an array
/// containing only `null`s decodes as `JsonArrays`.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[serde(untagged)]
pub enum UpdateEmbeddingsPayload {
    /// Each entry is either a list of floats or absent.
    JsonArrays(Vec<Option<Vec<f32>>>),
    /// Each entry is either a base64 string of little-endian `f32` bytes or absent.
    Base64Binary(Vec<Option<String>>),
}
37
38pub fn decode_embeddings(
39    embeddings: EmbeddingsPayload,
40) -> Result<Vec<Vec<f32>>, Base64DecodeError> {
41    match embeddings {
42        EmbeddingsPayload::Base64Binary(base64_strings) => {
43            Ok(decode_base64_embeddings(&base64_strings)?)
44        }
45        EmbeddingsPayload::JsonArrays(arrays) => Ok(arrays),
46    }
47}
48
49pub fn maybe_decode_update_embeddings(
50    embeddings: Option<UpdateEmbeddingsPayload>,
51) -> Result<Option<Vec<Option<Vec<f32>>>>, Base64DecodeError> {
52    match embeddings {
53        Some(UpdateEmbeddingsPayload::Base64Binary(base64_data)) => {
54            Ok(Some(decode_base64_update_embeddings(&base64_data)?))
55        }
56        Some(UpdateEmbeddingsPayload::JsonArrays(arrays)) => Ok(Some(arrays)),
57        None => Ok(None),
58    }
59}
60
61pub fn decode_base64_embeddings(
62    base64_strings: &Vec<String>,
63) -> Result<Vec<Vec<f32>>, Base64DecodeError> {
64    let mut result = Vec::with_capacity(base64_strings.len());
65
66    for base64_str in base64_strings {
67        let floats = decode_base64_embedding(base64_str)?;
68
69        result.push(floats);
70    }
71
72    Ok(result)
73}
74
75pub fn decode_base64_update_embeddings(
76    base64_data: &Vec<Option<String>>,
77) -> Result<Vec<Option<Vec<f32>>>, Base64DecodeError> {
78    let mut result = Vec::with_capacity(base64_data.len());
79
80    for base64_str in base64_data {
81        if let Some(base64_str) = base64_str {
82            let floats = decode_base64_embedding(base64_str)?;
83
84            result.push(Some(floats));
85        } else {
86            result.push(None);
87        }
88    }
89
90    Ok(result)
91}
92
93pub fn decode_base64_embedding(base64_str: &String) -> Result<Vec<f32>, Base64DecodeError> {
94    let bytes = general_purpose::STANDARD.decode(base64_str)?;
95
96    let float_count = bytes.len() / 4;
97    if bytes.len() % 4 != 0 {
98        return Err(Base64DecodeError::InvalidByteLength {
99            byte_length: bytes.len(),
100        });
101    }
102
103    let mut floats = Vec::with_capacity(float_count);
104    for (embedding_index, chunk) in bytes.chunks_exact(4).enumerate() {
105        let float_bytes: [u8; 4] = chunk
106            .try_into()
107            .map_err(|_| Base64DecodeError::EmbeddingConversionFailed { embedding_index })?;
108        // handles little endian encoding
109        floats.push(f32::from_le_bytes(float_bytes));
110    }
111
112    Ok(floats)
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "testing")]
    use proptest::prelude::*;

    /// Encodes `floats` in the format the decoders in this module accept: concatenated
    /// little-endian `f32` bytes wrapped in standard base64.
    fn encode_floats_to_base64(floats: &[f32]) -> String {
        let mut bytes = Vec::with_capacity(floats.len() * 4);
        for &f in floats {
            bytes.extend_from_slice(&f.to_le_bytes());
        }
        base64::Engine::encode(&base64::engine::general_purpose::STANDARD, bytes)
    }

    #[test]
    fn test_invalid_base64_returns_error() {
        let invalid_base64 = "invalid!@#$".to_string();
        let result = decode_base64_embedding(&invalid_base64);
        assert!(matches!(result, Err(Base64DecodeError::InvalidBase64(_))));
    }

    #[test]
    fn test_invalid_byte_length_returns_error() {
        // Valid base64, but it encodes 3 bytes ("abc"), which is not a multiple of 4.
        let invalid_length_base64 = "YWJj".to_string();
        let result = decode_base64_embedding(&invalid_length_base64);
        assert!(matches!(
            result,
            Err(Base64DecodeError::InvalidByteLength { byte_length: 3 })
        ));
    }

    #[test]
    fn test_empty_string_decodes_to_empty_embedding() {
        let empty = String::new();
        let result = decode_base64_embedding(&empty).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_little_endian_layout() {
        // 1.0f32 is 0x3F800000; its little-endian bytes [00, 00, 80, 3F] encode to
        // "AACAPw==". This pins the byte order independently of the encode helper.
        let encoded = String::from("AACAPw==");
        assert_eq!(decode_base64_embedding(&encoded).unwrap(), vec![1.0f32]);
    }

    #[test]
    fn test_get_embeddings_propagates_error() {
        let invalid_embeddings = EmbeddingsPayload::Base64Binary(vec!["invalid!@#$".to_string()]);
        let result = decode_embeddings(invalid_embeddings);

        assert!(matches!(result, Err(Base64DecodeError::InvalidBase64(_))));
    }

    #[test]
    fn test_decode_embeddings_passes_json_arrays_through() {
        let arrays = vec![vec![1.0f32, 2.0], vec![]];
        let result = decode_embeddings(EmbeddingsPayload::JsonArrays(arrays.clone())).unwrap();
        assert_eq!(result, arrays);
    }

    #[test]
    fn test_valid_base64_decoding() {
        // Valid base64 encoding 4 bytes (1 f32).
        let valid_base64 = encode_floats_to_base64(&[1.0]);
        assert_eq!(decode_base64_embedding(&valid_base64).unwrap(), vec![1.0f32]);
    }

    #[test]
    fn test_multiple_embeddings_with_one_invalid() {
        let embeddings = EmbeddingsPayload::Base64Binary(vec![
            encode_floats_to_base64(&[1.0]),
            "invalid!@#$".to_string(),
        ]);

        let result = decode_embeddings(embeddings);
        assert!(matches!(result, Err(Base64DecodeError::InvalidBase64(_))));
    }

    #[test]
    fn test_decode_base64_embedding() {
        // Several floats in one string: checks that order and exact values round-trip.
        let floats = [1.0f32, -2.5, 0.0, f32::MIN_POSITIVE, f32::MAX];
        let encoded = encode_floats_to_base64(&floats);
        assert_eq!(decode_base64_embedding(&encoded).unwrap(), floats.to_vec());
    }

    #[test]
    fn test_decode_base64_update_embeddings() {
        let input: Vec<Option<String>> = vec![
            Some(encode_floats_to_base64(&[1.0])),
            Some(encode_floats_to_base64(&[2.0])),
            None,
            Some(encode_floats_to_base64(&[3.0])),
            None,
        ];
        let result = decode_base64_update_embeddings(&input).unwrap();
        assert_eq!(
            result,
            vec![
                Some(vec![1.0f32]),
                Some(vec![2.0f32]),
                None,
                Some(vec![3.0f32]),
                None,
            ]
        );
    }

    #[test]
    fn test_maybe_decode_update_embeddings() {
        assert!(matches!(maybe_decode_update_embeddings(None), Ok(None)));

        let arrays = vec![Some(vec![1.0f32, 2.0]), None];
        let passthrough = maybe_decode_update_embeddings(Some(
            UpdateEmbeddingsPayload::JsonArrays(arrays.clone()),
        ))
        .unwrap();
        assert_eq!(passthrough, Some(arrays));

        let invalid =
            UpdateEmbeddingsPayload::Base64Binary(vec![None, Some("YWJj".to_string())]);
        assert!(matches!(
            maybe_decode_update_embeddings(Some(invalid)),
            Err(Base64DecodeError::InvalidByteLength { byte_length: 3 })
        ));
    }

    #[test]
    fn test_decode_base64_embeddings() {
        let input: Vec<String> = vec![
            encode_floats_to_base64(&[1.0]),
            encode_floats_to_base64(&[2.0]),
            encode_floats_to_base64(&[3.0]),
        ];
        let result = decode_base64_embeddings(&input).unwrap();
        assert_eq!(result, vec![vec![1.0f32], vec![2.0f32], vec![3.0f32]]);
    }

    #[cfg(feature = "testing")]
    fn embeddings_strategy() -> impl Strategy<Value = Vec<Vec<f32>>> {
        any::<Vec<Vec<f32>>>()
    }

    #[cfg(feature = "testing")]
    proptest! {
        #[test]
        fn test_decode_base64_embeddings_prop(embeddings in embeddings_strategy()) {
            let base64_strings: Vec<String> =
                embeddings.iter().map(|e| encode_floats_to_base64(e)).collect();
            let result = decode_base64_embeddings(&base64_strings).unwrap();
            // Compare the whole batch so a length mismatch cannot slip past a `zip`.
            prop_assert_eq!(&embeddings, &result);
        }
    }
}