// decthings_api/tensor/tensor_impl.rs

1use super::{DecthingsElementAudio, DecthingsElementImage, DecthingsElementVideo};
2use byte_slice_cast::{AsByteSlice, AsMutByteSlice, AsSliceOf, FromByteSlice, ToMutByteSlice};
3use ndarray::{Array, ArrayView, CowArray, IxDyn};
4
/// Error produced when a byte buffer cannot be decoded as a serialized tensor.
#[derive(Debug)]
pub enum DeserializeDecthingsTensorError {
    /// The buffer ended before all expected bytes (header, shape, or elements) were read.
    UnexpectedEndOfBytes,
    /// The buffer is malformed; the message describes what was wrong.
    InvalidBytes(String),
}
10
// Wire tags: the first byte of a serialized tensor identifies its element type.

// Floating point.
const TYPE_SPEC_F32: u8 = 1;
const TYPE_SPEC_F64: u8 = 2;

// Signed integers.
const TYPE_SPEC_I8: u8 = 3;
const TYPE_SPEC_I16: u8 = 4;
const TYPE_SPEC_I32: u8 = 5;
const TYPE_SPEC_I64: u8 = 6;

// Unsigned integers.
const TYPE_SPEC_U8: u8 = 7;
const TYPE_SPEC_U16: u8 = 8;
const TYPE_SPEC_U32: u8 = 9;
const TYPE_SPEC_U64: u8 = 10;

// Length-prefixed byte sequences.
const TYPE_SPEC_STRING: u8 = 11;
const TYPE_SPEC_BINARY: u8 = 12;

// One byte per element (0 or 1).
const TYPE_SPEC_BOOLEAN: u8 = 13;

// Media elements (format tag + payload).
const TYPE_SPEC_IMAGE: u8 = 14;
const TYPE_SPEC_AUDIO: u8 = 15;
const TYPE_SPEC_VIDEO: u8 = 16;
27
28#[derive(Debug, Clone)]
29pub enum DecthingsTensor<'a> {
30    F32(CowArray<'a, f32, IxDyn>),
31    F64(CowArray<'a, f64, IxDyn>),
32    I8(CowArray<'a, i8, IxDyn>),
33    I16(CowArray<'a, i16, IxDyn>),
34    I32(CowArray<'a, i32, IxDyn>),
35    I64(CowArray<'a, i64, IxDyn>),
36    U8(CowArray<'a, u8, IxDyn>),
37    U16(CowArray<'a, u16, IxDyn>),
38    U32(CowArray<'a, u32, IxDyn>),
39    U64(CowArray<'a, u64, IxDyn>),
40    String(CowArray<'a, &'a str, IxDyn>),
41    Binary(CowArray<'a, &'a [u8], IxDyn>),
42    Boolean(CowArray<'a, bool, IxDyn>),
43    Image(CowArray<'a, DecthingsElementImage<'a>, IxDyn>),
44    Audio(CowArray<'a, DecthingsElementAudio<'a>, IxDyn>),
45    Video(CowArray<'a, DecthingsElementVideo<'a>, IxDyn>),
46}
47
48impl<'a> DecthingsTensor<'a> {
49    pub fn view(&'a self) -> Self {
50        match self {
51            Self::F32(inner) => DecthingsTensor::F32(inner.view().into()),
52            Self::F64(inner) => DecthingsTensor::F64(inner.view().into()),
53            Self::I8(inner) => DecthingsTensor::I8(inner.view().into()),
54            Self::I16(inner) => DecthingsTensor::I16(inner.view().into()),
55            Self::I32(inner) => DecthingsTensor::I32(inner.view().into()),
56            Self::I64(inner) => DecthingsTensor::I64(inner.view().into()),
57            Self::U8(inner) => DecthingsTensor::U8(inner.view().into()),
58            Self::U16(inner) => DecthingsTensor::U16(inner.view().into()),
59            Self::U32(inner) => DecthingsTensor::U32(inner.view().into()),
60            Self::U64(inner) => DecthingsTensor::U64(inner.view().into()),
61            Self::String(inner) => DecthingsTensor::String(inner.view().into()),
62            Self::Binary(inner) => DecthingsTensor::Binary(inner.view().into()),
63            Self::Boolean(inner) => DecthingsTensor::Boolean(inner.view().into()),
64            Self::Image(inner) => DecthingsTensor::Image(inner.view().into()),
65            Self::Audio(inner) => DecthingsTensor::Audio(inner.view().into()),
66            Self::Video(inner) => DecthingsTensor::Video(inner.view().into()),
67        }
68    }
69
70    pub fn typ(&self) -> super::DecthingsElementType {
71        match self {
72            Self::F32(_) => super::DecthingsElementType::F32,
73            Self::F64(_) => super::DecthingsElementType::F64,
74            Self::I8(_) => super::DecthingsElementType::I8,
75            Self::I16(_) => super::DecthingsElementType::I16,
76            Self::I32(_) => super::DecthingsElementType::I32,
77            Self::I64(_) => super::DecthingsElementType::I64,
78            Self::U8(_) => super::DecthingsElementType::U8,
79            Self::U16(_) => super::DecthingsElementType::U16,
80            Self::U32(_) => super::DecthingsElementType::U32,
81            Self::U64(_) => super::DecthingsElementType::U64,
82            Self::String(_) => super::DecthingsElementType::String,
83            Self::Binary(_) => super::DecthingsElementType::Binary,
84            Self::Boolean(_) => super::DecthingsElementType::Boolean,
85            Self::Image(_) => super::DecthingsElementType::Image,
86            Self::Audio(_) => super::DecthingsElementType::Audio,
87            Self::Video(_) => super::DecthingsElementType::Video,
88        }
89    }
90
91    pub fn shape(&self) -> &[usize] {
92        match self {
93            Self::F32(inner) => inner.shape(),
94            Self::F64(inner) => inner.shape(),
95            Self::I8(inner) => inner.shape(),
96            Self::I16(inner) => inner.shape(),
97            Self::I32(inner) => inner.shape(),
98            Self::I64(inner) => inner.shape(),
99            Self::U8(inner) => inner.shape(),
100            Self::U16(inner) => inner.shape(),
101            Self::U32(inner) => inner.shape(),
102            Self::U64(inner) => inner.shape(),
103            Self::String(inner) => inner.shape(),
104            Self::Binary(inner) => inner.shape(),
105            Self::Boolean(inner) => inner.shape(),
106            Self::Image(inner) => inner.shape(),
107            Self::Audio(inner) => inner.shape(),
108            Self::Video(inner) => inner.shape(),
109        }
110    }
111
112    pub fn len(&self) -> usize {
113        self.shape().iter().product()
114    }
115
116    pub fn is_empty(&self) -> bool {
117        self.len() == 0
118    }
119
120    /// If this is a numeric type (f32, f64, u8, u16, u32, u64, i8, i16, i32 or i64), casts it to a float
121    /// array.
122    ///
123    /// Returns an array that is either owned or not. If the data was of type f64, the returned
124    /// array is borrowed. Otherwise a new array is created, so the returned array is owned.
125    pub fn as_f64(&'a self) -> Option<CowArray<'a, f64, IxDyn>> {
126        match self {
127            Self::F32(val) => Some(CowArray::from(val.map(|x| (*x).into()))),
128            Self::F64(val) => Some(val.into()),
129            Self::I8(val) => Some(CowArray::from(val.map(|x| (*x).into()))),
130            Self::I16(val) => Some(CowArray::from(val.map(|x| (*x).into()))),
131            Self::I32(val) => Some(CowArray::from(val.map(|x| (*x).into()))),
132            Self::I64(val) => Some(CowArray::from(val.map(|x| (*x) as f64))),
133            Self::U8(val) => Some(CowArray::from(val.map(|x| (*x).into()))),
134            Self::U16(val) => Some(CowArray::from(val.map(|x| (*x).into()))),
135            Self::U32(val) => Some(CowArray::from(val.map(|x| (*x).into()))),
136            Self::U64(val) => Some(CowArray::from(val.map(|x| (*x) as f64))),
137            _ => None,
138        }
139    }
140
141    /// If this is a numeric type (f32, f64, u8, u16, u32, u64, i8, i16, i32 or i64) with length 1, casts the
142    /// single element to an f64 and returns it.
143    pub fn as_f64_item(&self) -> Option<f64> {
144        if self.len() != 1 {
145            return None;
146        }
147        match self {
148            Self::F32(val) => Some((*val.first().unwrap()).into()),
149            Self::F64(val) => Some(*val.first().unwrap()),
150            Self::I8(val) => Some((*val.first().unwrap()).into()),
151            Self::I16(val) => Some((*val.first().unwrap()).into()),
152            Self::I32(val) => Some((*val.first().unwrap()).into()),
153            Self::I64(val) => Some(*val.first().unwrap() as f64),
154            Self::U8(val) => Some((*val.first().unwrap()).into()),
155            Self::U16(val) => Some((*val.first().unwrap()).into()),
156            Self::U32(val) => Some((*val.first().unwrap()).into()),
157            Self::U64(val) => Some(*val.first().unwrap() as f64),
158            _ => None,
159        }
160    }
161
162    /// If this is a numeric type (f32, f64, u8, u16, u32, u64, i8, i16, i32 or i64), cast it to an i64
163    /// array.
164    ///
165    /// Returns an array that is either owned or not. If the data was of type i64, the returned
166    /// array is borrowed. Otherwise a new array is created, so the returned array is owned.
167    pub fn as_i64(&'a self) -> Option<CowArray<'a, i64, IxDyn>> {
168        match self {
169            Self::F32(val) => Some(CowArray::from(val.map(|x| *x as i64))),
170            Self::F64(val) => Some(CowArray::from(val.map(|x| *x as i64))),
171            Self::I8(val) => Some(CowArray::from(val.map(|x| (*x).into()))),
172            Self::I16(val) => Some(CowArray::from(val.map(|x| (*x).into()))),
173            Self::I32(val) => Some(CowArray::from(val.map(|x| (*x).into()))),
174            Self::I64(val) => Some(val.into()),
175            Self::U8(val) => Some(CowArray::from(val.map(|x| (*x).into()))),
176            Self::U16(val) => Some(CowArray::from(val.map(|x| (*x).into()))),
177            Self::U32(val) => Some(CowArray::from(val.map(|x| (*x).into()))),
178            Self::U64(val) => Some(CowArray::from(
179                val.map(|x| (*x).try_into().unwrap_or(i64::MAX)),
180            )),
181            _ => None,
182        }
183    }
184
185    /// If this is a numeric type (f32, f64, u8, u16, u32, u64, i8, i16, i32 or i64) with length 1, casts the
186    /// single element to an i64 and returns it.
187    pub fn as_i64_item(&self) -> Option<i64> {
188        if self.len() != 1 {
189            return None;
190        }
191        match self {
192            Self::F32(val) => Some(*val.first().unwrap() as i64),
193            Self::F64(val) => Some(*val.first().unwrap() as i64),
194            Self::I8(val) => Some((*val.first().unwrap()).into()),
195            Self::I16(val) => Some((*val.first().unwrap()).into()),
196            Self::I32(val) => Some((*val.first().unwrap()).into()),
197            Self::I64(val) => Some(*val.first().unwrap()),
198            Self::U8(val) => Some((*val.first().unwrap()).into()),
199            Self::U16(val) => Some((*val.first().unwrap()).into()),
200            Self::U32(val) => Some((*val.first().unwrap()).into()),
201            Self::U64(val) => Some((*val.first().unwrap()).try_into().unwrap_or(i64::MAX)),
202            _ => None,
203        }
204    }
205
206    /// If this is a numeric type (f32, f64, u8, u16, u32, u64, i8, i16, i32 or i64), casts it to a u64.
207    /// If the value is i8, i16, i32 or i64 and negative, None is returned.
208    pub fn as_u64(&'a self) -> Option<CowArray<'a, u64, IxDyn>> {
209        match self {
210            Self::F32(val) => Some(CowArray::from(val.map(|x| *x as u64))),
211            Self::F64(val) => Some(CowArray::from(val.map(|x| *x as u64))),
212            Self::U8(val) => Some(CowArray::from(val.map(|x| (*x).into()))),
213            Self::U16(val) => Some(CowArray::from(val.map(|x| (*x).into()))),
214            Self::U32(val) => Some(CowArray::from(val.map(|x| (*x).into()))),
215            Self::U64(val) => Some(val.into()),
216            Self::I8(val) => Some(CowArray::from(val.map(|x| (*x).try_into().unwrap_or(0)))),
217            Self::I16(val) => Some(CowArray::from(val.map(|x| (*x).try_into().unwrap_or(0)))),
218            Self::I32(val) => Some(CowArray::from(val.map(|x| (*x).try_into().unwrap_or(0)))),
219            Self::I64(val) => Some(CowArray::from(val.map(|x| (*x).try_into().unwrap_or(0)))),
220            _ => None,
221        }
222    }
223
224    /// If this is a numeric type (f32, f64, u8, u16, u32, u64, i8, i16, i32 or i64) with length 1, casts the
225    /// single element to an u64 and returns it.
226    pub fn as_u64_item(&self) -> Option<u64> {
227        if self.len() != 1 {
228            return None;
229        }
230        match self {
231            Self::F32(val) => Some(*val.first().unwrap() as u64),
232            Self::F64(val) => Some(*val.first().unwrap() as u64),
233            Self::I8(val) => Some((*val.first().unwrap()).try_into().unwrap_or(0)),
234            Self::I16(val) => Some((*val.first().unwrap()).try_into().unwrap_or(0)),
235            Self::I32(val) => Some((*val.first().unwrap()).try_into().unwrap_or(0)),
236            Self::I64(val) => Some((*val.first().unwrap()).try_into().unwrap_or(0)),
237            Self::U8(val) => Some((*val.first().unwrap()).into()),
238            Self::U16(val) => Some((*val.first().unwrap()).into()),
239            Self::U32(val) => Some((*val.first().unwrap()).into()),
240            Self::U64(val) => Some(*val.first().unwrap()),
241            _ => None,
242        }
243    }
244
245    /// If this is a string type with length 1, returns the string.
246    pub fn as_str_item(&self) -> Option<&str> {
247        if self.len() != 1 {
248            return None;
249        }
250        match self {
251            Self::String(val) => Some(val.first().unwrap()),
252            _ => None,
253        }
254    }
255
256    /// If this is a binary type with length 1, returns the binary.
257    pub fn as_binary_item(&self) -> Option<&[u8]> {
258        if self.len() != 1 {
259            return None;
260        }
261        match self {
262            Self::Binary(val) => Some(val.first().unwrap()),
263            _ => None,
264        }
265    }
266
267    /// If this is a boolean type with length 1, returns the boolean.
268    pub fn as_boolean_item(&self) -> Option<bool> {
269        if self.len() != 1 {
270            return None;
271        }
272        match self {
273            Self::Boolean(val) => Some(*val.first().unwrap()),
274            _ => None,
275        }
276    }
277
278    /// If this is an image type with length 1, returns the image.
279    pub fn as_image_item(&self) -> Option<&DecthingsElementImage> {
280        if self.len() != 1 {
281            return None;
282        }
283        match self {
284            Self::Image(val) => Some(val.first().unwrap()),
285            _ => None,
286        }
287    }
288
289    /// If this is a audio type with length 1, returns the audio.
290    pub fn as_audio_item(&self) -> Option<&DecthingsElementAudio> {
291        if self.len() != 1 {
292            return None;
293        }
294        match self {
295            Self::Audio(val) => Some(val.first().unwrap()),
296            _ => None,
297        }
298    }
299
300    /// If this is a video type with length 1, returns the video.
301    pub fn as_video_item(&self) -> Option<&DecthingsElementVideo> {
302        if self.len() != 1 {
303            return None;
304        }
305        match self {
306            Self::Video(val) => Some(val.first().unwrap()),
307            _ => None,
308        }
309    }
310
311    pub(crate) fn serialized_len(&self) -> usize {
312        let size_from_elements = match self {
313            Self::F32(inner) => inner.len() * std::mem::size_of::<f32>(),
314            Self::F64(inner) => inner.len() * std::mem::size_of::<f64>(),
315            Self::I8(inner) => inner.len() * std::mem::size_of::<i8>(),
316            Self::I16(inner) => inner.len() * std::mem::size_of::<i16>(),
317            Self::I32(inner) => inner.len() * std::mem::size_of::<i32>(),
318            Self::I64(inner) => inner.len() * std::mem::size_of::<i64>(),
319            Self::U8(inner) => inner.len() * std::mem::size_of::<u8>(),
320            Self::U16(inner) => inner.len() * std::mem::size_of::<u16>(),
321            Self::U32(inner) => inner.len() * std::mem::size_of::<u32>(),
322            Self::U64(inner) => inner.len() * std::mem::size_of::<u64>(),
323            Self::String(inner) => inner
324                .iter()
325                .map(|x| {
326                    crate::varint::get_varint_u64_len(x.len().try_into().unwrap()) as usize
327                        + x.len()
328                })
329                .sum::<usize>(),
330            Self::Binary(inner) => inner
331                .iter()
332                .map(|x| {
333                    crate::varint::get_varint_u64_len(x.len().try_into().unwrap()) as usize
334                        + x.len()
335                })
336                .sum::<usize>(),
337            Self::Boolean(inner) => inner.len() * std::mem::size_of::<u8>(),
338            Self::Image(inner) => inner
339                .iter()
340                .map(|x| {
341                    let len: u64 = x.data.len().try_into().unwrap();
342                    crate::varint::get_varint_u64_len(1 + x.format().len() as u64 + len) as usize
343                        + 1
344                        + x.format().len()
345                        + x.data.len()
346                })
347                .sum::<usize>(),
348            Self::Audio(inner) => inner
349                .iter()
350                .map(|x| {
351                    let len: u64 = x.data.len().try_into().unwrap();
352                    crate::varint::get_varint_u64_len(1 + x.format().len() as u64 + len) as usize
353                        + 1
354                        + x.format().len()
355                        + x.data.len()
356                })
357                .sum::<usize>(),
358            Self::Video(inner) => inner
359                .iter()
360                .map(|x| {
361                    let len: u64 = x.data.len().try_into().unwrap();
362                    crate::varint::get_varint_u64_len(1 + x.format().len() as u64 + len) as usize
363                        + 1
364                        + x.format().len()
365                        + x.data.len()
366                })
367                .sum::<usize>(),
368        };
369
370        let shape = self.shape();
371        let size_from_shape = 1 + shape
372            .iter()
373            .map(|x| crate::varint::get_varint_u64_len(*x as u64) as usize)
374            .sum::<usize>();
375
376        1 + size_from_shape + size_from_elements
377    }
378
379    pub(crate) fn serialize_append(&self, res: &mut Vec<u8>) {
380        let first_byte = match self {
381            Self::F32(_) => TYPE_SPEC_F32,
382            Self::F64(_) => TYPE_SPEC_F64,
383            Self::I8(_) => TYPE_SPEC_I8,
384            Self::I16(_) => TYPE_SPEC_I16,
385            Self::I32(_) => TYPE_SPEC_I32,
386            Self::I64(_) => TYPE_SPEC_I64,
387            Self::U8(_) => TYPE_SPEC_U8,
388            Self::U16(_) => TYPE_SPEC_U16,
389            Self::U32(_) => TYPE_SPEC_U32,
390            Self::U64(_) => TYPE_SPEC_U64,
391            Self::String(_) => TYPE_SPEC_STRING,
392            Self::Binary(_) => TYPE_SPEC_BINARY,
393            Self::Boolean(_) => TYPE_SPEC_BOOLEAN,
394            Self::Image(_) => TYPE_SPEC_IMAGE,
395            Self::Audio(_) => TYPE_SPEC_AUDIO,
396            Self::Video(_) => TYPE_SPEC_VIDEO,
397        };
398
399        res.push(first_byte);
400
401        let shape = self.shape();
402        res.push(
403            shape
404                .len()
405                .try_into()
406                .expect("The data cannot contain more than 255 dimensions."),
407        );
408
409        let mut written_bytes = 2;
410
411        for dim in shape {
412            crate::varint::append_varint_u64((*dim).try_into().unwrap(), res);
413            written_bytes += crate::varint::get_varint_u64_len((*dim).try_into().unwrap()) as usize;
414        }
415
416        #[cfg(not(target_endian = "little"))]
417        use byteorder::{LittleEndian, WriteBytesExt};
418
419        match self {
420            Self::F32(inner) => {
421                #[cfg(target_endian = "little")]
422                res.extend_from_slice(
423                    inner
424                        .as_standard_layout()
425                        .as_slice()
426                        .unwrap()
427                        .as_byte_slice(),
428                );
429                #[cfg(not(target_endian = "little"))]
430                for &val in inner.as_standard_layout().as_slice().unwrap() {
431                    res.write_f32::<LittleEndian>(val).unwrap();
432                }
433            }
434            Self::F64(inner) => {
435                #[cfg(target_endian = "little")]
436                res.extend_from_slice(
437                    inner
438                        .as_standard_layout()
439                        .as_slice()
440                        .unwrap()
441                        .as_byte_slice(),
442                );
443                #[cfg(not(target_endian = "little"))]
444                for &val in inner.as_standard_layout().as_slice().unwrap() {
445                    res.write_f64::<LittleEndian>(val).unwrap();
446                }
447            }
448            Self::I8(inner) => {
449                res.extend_from_slice(
450                    inner
451                        .as_standard_layout()
452                        .as_slice()
453                        .unwrap()
454                        .as_byte_slice(),
455                );
456            }
457            Self::I16(inner) => {
458                #[cfg(target_endian = "little")]
459                res.extend_from_slice(
460                    inner
461                        .as_standard_layout()
462                        .as_slice()
463                        .unwrap()
464                        .as_byte_slice(),
465                );
466                #[cfg(not(target_endian = "little"))]
467                for &val in inner.as_standard_layout().as_slice().unwrap() {
468                    res.write_i16::<LittleEndian>(val).unwrap();
469                }
470            }
471            Self::I32(inner) => {
472                #[cfg(target_endian = "little")]
473                res.extend_from_slice(
474                    inner
475                        .as_standard_layout()
476                        .as_slice()
477                        .unwrap()
478                        .as_byte_slice(),
479                );
480                #[cfg(not(target_endian = "little"))]
481                for &val in inner.as_standard_layout().as_slice().unwrap() {
482                    res.write_i32::<LittleEndian>(val).unwrap();
483                }
484            }
485            Self::I64(inner) => {
486                #[cfg(target_endian = "little")]
487                res.extend_from_slice(
488                    inner
489                        .as_standard_layout()
490                        .as_slice()
491                        .unwrap()
492                        .as_byte_slice(),
493                );
494                #[cfg(not(target_endian = "little"))]
495                for &val in inner.as_standard_layout().as_slice().unwrap() {
496                    res.write_i64::<LittleEndian>(val).unwrap();
497                }
498            }
499            Self::U8(inner) => {
500                res.extend_from_slice(
501                    inner
502                        .as_standard_layout()
503                        .as_slice()
504                        .unwrap()
505                        .as_byte_slice(),
506                );
507            }
508            Self::U16(inner) => {
509                #[cfg(target_endian = "little")]
510                res.extend_from_slice(
511                    inner
512                        .as_standard_layout()
513                        .as_slice()
514                        .unwrap()
515                        .as_byte_slice(),
516                );
517                #[cfg(not(target_endian = "little"))]
518                for &val in inner.as_standard_layout().as_slice().unwrap() {
519                    res.write_u16::<LittleEndian>(val).unwrap();
520                }
521            }
522            Self::U32(inner) => {
523                #[cfg(target_endian = "little")]
524                res.extend_from_slice(
525                    inner
526                        .as_standard_layout()
527                        .as_slice()
528                        .unwrap()
529                        .as_byte_slice(),
530                );
531                #[cfg(not(target_endian = "little"))]
532                for &val in inner.as_standard_layout().as_slice().unwrap() {
533                    res.write_u32::<LittleEndian>(val).unwrap();
534                }
535            }
536            Self::U64(inner) => {
537                #[cfg(target_endian = "little")]
538                res.extend_from_slice(
539                    inner
540                        .as_standard_layout()
541                        .as_slice()
542                        .unwrap()
543                        .as_byte_slice(),
544                );
545                #[cfg(not(target_endian = "little"))]
546                for &val in inner.as_standard_layout().as_slice().unwrap() {
547                    res.write_u64::<LittleEndian>(val).unwrap();
548                }
549            }
550            Self::String(inner) => {
551                crate::varint::append_varint_u64(
552                    (self.serialized_len() - written_bytes) as u64,
553                    res,
554                );
555                inner.iter().for_each(|x| {
556                    crate::varint::append_varint_u64(x.len().try_into().unwrap(), res);
557                    res.extend_from_slice(x.as_bytes())
558                })
559            }
560            Self::Binary(inner) => {
561                crate::varint::append_varint_u64(
562                    (self.serialized_len() - written_bytes) as u64,
563                    res,
564                );
565                inner.iter().for_each(|x| {
566                    crate::varint::append_varint_u64(x.len().try_into().unwrap(), res);
567                    res.extend_from_slice(x)
568                })
569            }
570            Self::Boolean(inner) => res.extend_from_slice(
571                inner
572                    .map(|x| if *x { 1u8 } else { 0 })
573                    .as_standard_layout()
574                    .as_slice()
575                    .unwrap(),
576            ),
577            Self::Image(inner) => {
578                crate::varint::append_varint_u64(
579                    (self.serialized_len() - written_bytes) as u64,
580                    res,
581                );
582                inner.iter().for_each(|x| {
583                    let len: u64 = x.data.len().try_into().unwrap();
584                    crate::varint::append_varint_u64(1 + x.format().len() as u64 + len, res);
585                    res.push(x.format().len().try_into().unwrap());
586                    res.extend_from_slice(x.format().as_bytes());
587                    res.extend_from_slice(&x.data)
588                })
589            }
590            Self::Audio(inner) => {
591                crate::varint::append_varint_u64(
592                    (self.serialized_len() - written_bytes) as u64,
593                    res,
594                );
595                inner.iter().for_each(|x| {
596                    let len: u64 = x.data.len().try_into().unwrap();
597                    crate::varint::append_varint_u64(1 + x.format().len() as u64 + len, res);
598                    res.push(x.format().len().try_into().unwrap());
599                    res.extend_from_slice(x.format().as_bytes());
600                    res.extend_from_slice(&x.data)
601                })
602            }
603            Self::Video(inner) => {
604                crate::varint::append_varint_u64(
605                    (self.serialized_len() - written_bytes) as u64,
606                    res,
607                );
608                inner.iter().for_each(|x| {
609                    let len: u64 = x.data.len().try_into().unwrap();
610                    crate::varint::append_varint_u64(1 + x.format().len() as u64 + len, res);
611                    res.push(x.format().len().try_into().unwrap());
612                    res.extend_from_slice(x.format().as_bytes());
613                    res.extend_from_slice(&x.data)
614                })
615            }
616        }
617    }
618
619    pub fn serialize(&self) -> Vec<u8> {
620        let mut res = Vec::with_capacity(self.serialized_len());
621        self.serialize_append(&mut res);
622        res
623    }
624}
625
626#[derive(Clone)]
627pub struct OwnedDecthingsTensor {
628    pub(crate) data: bytes::Bytes,
629}
630
631impl OwnedDecthingsTensor {
632    pub fn from_bytes(data: bytes::Bytes) -> Result<Self, DeserializeDecthingsTensorError> {
633        let Some(&first_byte) = data.first() else {
634            return Err(DeserializeDecthingsTensorError::UnexpectedEndOfBytes);
635        };
636
637        let Some(&num_dims) = data.get(1) else {
638            return Err(DeserializeDecthingsTensorError::UnexpectedEndOfBytes);
639        };
640
641        let mut shape: Vec<usize> = Vec::with_capacity(num_dims.into());
642        let mut pos = 2;
643
644        for _ in 0..num_dims {
645            if data.len() < pos + 1 {
646                return Err(DeserializeDecthingsTensorError::UnexpectedEndOfBytes);
647            };
648            if data.len()
649                < pos + crate::varint::get_serialized_varint_u64_len(&data[pos..]) as usize
650            {
651                return Err(DeserializeDecthingsTensorError::UnexpectedEndOfBytes);
652            }
653            let (dim, varint_len) = crate::varint::deserialize_varint_u64(&data[pos..]);
654            pos += varint_len as usize;
655            shape.push(dim.try_into().unwrap());
656        }
657
658        let numel = shape.iter().fold(1usize, |a, b| a * (*b));
659
660        let element_size = match first_byte {
661            TYPE_SPEC_F32 | TYPE_SPEC_I32 | TYPE_SPEC_U32 => Some(4),
662            TYPE_SPEC_F64 | TYPE_SPEC_I64 | TYPE_SPEC_U64 => Some(8),
663            TYPE_SPEC_BOOLEAN | TYPE_SPEC_I8 | TYPE_SPEC_U8 => Some(1),
664            TYPE_SPEC_I16 | TYPE_SPEC_U16 => Some(2),
665            TYPE_SPEC_STRING | TYPE_SPEC_BINARY | TYPE_SPEC_IMAGE | TYPE_SPEC_AUDIO
666            | TYPE_SPEC_VIDEO => None,
667            _ => {
668                return Err(DeserializeDecthingsTensorError::InvalidBytes(format!(
669                    "Unexected first byte {first_byte}"
670                )));
671            }
672        };
673
674        match element_size {
675            Some(element_size) => {
676                if data.len() < pos + numel * element_size {
677                    return Err(DeserializeDecthingsTensorError::UnexpectedEndOfBytes);
678                }
679                pos += numel * element_size;
680            }
681            None => {
682                if data.len() < pos + 1 {
683                    return Err(DeserializeDecthingsTensorError::UnexpectedEndOfBytes);
684                };
685                if data.len()
686                    < pos + crate::varint::get_serialized_varint_u64_len(&data[pos..]) as usize
687                {
688                    return Err(DeserializeDecthingsTensorError::UnexpectedEndOfBytes);
689                }
690                pos += crate::varint::get_serialized_varint_u64_len(&data[pos..]) as usize;
691
692                for _ in 0..numel {
693                    if data.len() < pos + 1 {
694                        return Err(DeserializeDecthingsTensorError::UnexpectedEndOfBytes);
695                    };
696                    if data.len()
697                        < pos + crate::varint::get_serialized_varint_u64_len(&data[pos..]) as usize
698                    {
699                        return Err(DeserializeDecthingsTensorError::UnexpectedEndOfBytes);
700                    }
701                    let (len, varint_len) = crate::varint::deserialize_varint_u64(&data[pos..]);
702                    let len: usize = len.try_into().unwrap();
703                    pos += varint_len as usize;
704                    if matches!(first_byte, TYPE_SPEC_STRING) {
705                        if let Err(e) = std::str::from_utf8(&data[pos..pos + len]) {
706                            return Err(DeserializeDecthingsTensorError::InvalidBytes(format!(
707                                "The string was not UTF-8: {e:?}"
708                            )));
709                        }
710                    }
711                    if matches!(first_byte, TYPE_SPEC_IMAGE) {
712                        if len < 1 {
713                            return Err(DeserializeDecthingsTensorError::InvalidBytes(format!(
714                                "Unexpected end of bytes while parsing image format"
715                            )));
716                        }
717                        let format_length = data[pos] as usize;
718                        if len < 1 + format_length {
719                            return Err(DeserializeDecthingsTensorError::InvalidBytes(format!(
720                                "Unexpected end of bytes while parsing image format"
721                            )));
722                        }
723                        if let Err(e) = std::str::from_utf8(&data[pos + 1..pos + 1 + format_length])
724                        {
725                            return Err(DeserializeDecthingsTensorError::InvalidBytes(format!(
726                                "The image format was not UTF-8: {e:?}"
727                            )));
728                        }
729                    }
730                    if matches!(first_byte, TYPE_SPEC_AUDIO) {
731                        if len < 1 {
732                            return Err(DeserializeDecthingsTensorError::InvalidBytes(format!(
733                                "Unexpected end of bytes while parsing audio format"
734                            )));
735                        }
736                        let format_length = data[pos] as usize;
737                        if len < 1 + format_length {
738                            return Err(DeserializeDecthingsTensorError::InvalidBytes(format!(
739                                "Unexpected end of bytes while parsing audio format"
740                            )));
741                        }
742                        if let Err(e) = std::str::from_utf8(&data[pos + 1..pos + 1 + format_length])
743                        {
744                            return Err(DeserializeDecthingsTensorError::InvalidBytes(format!(
745                                "The audio format was not UTF-8: {e:?}"
746                            )));
747                        }
748                    }
749                    if matches!(first_byte, TYPE_SPEC_VIDEO) {
750                        if len < 1 {
751                            return Err(DeserializeDecthingsTensorError::InvalidBytes(format!(
752                                "Unexpected end of bytes while parsing video format"
753                            )));
754                        }
755                        let format_length = data[pos] as usize;
756                        if len < 1 + format_length {
757                            return Err(DeserializeDecthingsTensorError::InvalidBytes(format!(
758                                "Unexpected end of bytes while parsing video format"
759                            )));
760                        }
761                        if let Err(e) = std::str::from_utf8(&data[pos + 1..pos + 1 + format_length])
762                        {
763                            return Err(DeserializeDecthingsTensorError::InvalidBytes(format!(
764                                "The video format was not UTF-8: {e:?}"
765                            )));
766                        }
767                    }
768                    pos += len;
769                }
770            }
771        }
772
773        Ok(Self {
774            data: data.slice(0..pos),
775        })
776    }
777
778    pub fn byte_size(&self) -> usize {
779        self.data.len()
780    }
781
782    pub fn tensor(&self) -> DecthingsTensor<'_> {
783        let first_byte = self.data[0];
784        let num_dims = self.data[1];
785
786        let mut shape: Vec<usize> = Vec::with_capacity(num_dims.into());
787        let mut pos = 2;
788
789        for _ in 0..num_dims {
790            let (dim, varint_len) = crate::varint::deserialize_varint_u64(&self.data[pos..]);
791            pos += varint_len as usize;
792            shape.push(dim.try_into().unwrap());
793        }
794
795        let numel = shape.iter().fold(1usize, |a, b| a * (*b));
796
797        fn sized_into_tensor<'a, T: Clone + Default + FromByteSlice + ToMutByteSlice + 'a>(
798            shape: &[usize],
799            data: &'a [u8],
800            pos: usize,
801            numel: usize,
802            f: impl FnOnce(CowArray<'a, T, IxDyn>) -> DecthingsTensor<'a>,
803        ) -> DecthingsTensor<'a> {
804            let slice = &data[pos..pos + numel * std::mem::size_of::<T>()];
805
806            #[cfg(target_endian = "little")]
807            if let Ok(val) = slice.as_slice_of::<T>() {
808                return f(ArrayView::from(val)
809                    .into_shape(IxDyn(shape))
810                    .unwrap()
811                    .into());
812            }
813
814            // We are either big-endian, or got an alignment error, in which case we need to copy.
815
816            #[cfg(target_endian = "little")]
817            {
818                let mut res: Vec<T> = vec![T::default(); numel];
819                res.as_mut_byte_slice().copy_from_slice(slice);
820                f(Array::from(res).into_shape(IxDyn(shape)).unwrap().into())
821            }
822
823            #[cfg(not(target_endian = "little"))]
824            {
825                let mut res: Vec<T> = Vec::with_capacity(numel);
826                let cursor = std::io::Cursor::new(slice);
827                for _ in 0..numel {
828                    res.push();
829                }
830                f(Array::from(res).into_shape(IxDyn(shape)).unwrap().into())
831            }
832        }
833
834        match first_byte {
835            TYPE_SPEC_F32 => {
836                return sized_into_tensor(&shape, &self.data, pos, numel, DecthingsTensor::F32);
837            }
838            TYPE_SPEC_F64 => {
839                return sized_into_tensor(&shape, &self.data, pos, numel, DecthingsTensor::F64);
840            }
841            TYPE_SPEC_I8 => {
842                return sized_into_tensor(&shape, &self.data, pos, numel, DecthingsTensor::I8);
843            }
844            TYPE_SPEC_I16 => {
845                return sized_into_tensor(&shape, &self.data, pos, numel, DecthingsTensor::I16);
846            }
847            TYPE_SPEC_I32 => {
848                return sized_into_tensor(&shape, &self.data, pos, numel, DecthingsTensor::I32);
849            }
850            TYPE_SPEC_I64 => {
851                return sized_into_tensor(&shape, &self.data, pos, numel, DecthingsTensor::I64);
852            }
853            TYPE_SPEC_U8 => {
854                return sized_into_tensor(&shape, &self.data, pos, numel, DecthingsTensor::U8);
855            }
856            TYPE_SPEC_U16 => {
857                return sized_into_tensor(&shape, &self.data, pos, numel, DecthingsTensor::U16);
858            }
859            TYPE_SPEC_U32 => {
860                return sized_into_tensor(&shape, &self.data, pos, numel, DecthingsTensor::U32);
861            }
862            TYPE_SPEC_U64 => {
863                return sized_into_tensor(&shape, &self.data, pos, numel, DecthingsTensor::U64);
864            }
865            TYPE_SPEC_STRING => {
866                pos += crate::varint::get_serialized_varint_u64_len(&self.data[pos..]) as usize;
867
868                let mut strings = Vec::with_capacity(numel);
869                for _ in 0..numel {
870                    let (len, varint_len) =
871                        crate::varint::deserialize_varint_u64(&self.data[pos..]);
872                    let len: usize = len.try_into().unwrap();
873                    pos += varint_len as usize;
874                    strings.push(std::str::from_utf8(&self.data[pos..pos + len]).unwrap());
875                    pos += len;
876                }
877                return DecthingsTensor::String(
878                    Array::from_vec(strings)
879                        .into_shape(IxDyn(&shape))
880                        .unwrap()
881                        .into(),
882                );
883            }
884            TYPE_SPEC_BINARY => {
885                pos += crate::varint::get_serialized_varint_u64_len(&self.data[pos..]) as usize;
886
887                let mut binaries = Vec::with_capacity(numel);
888                for _ in 0..numel {
889                    let (len, varint_len) =
890                        crate::varint::deserialize_varint_u64(&self.data[pos..]);
891                    let len: usize = len.try_into().unwrap();
892                    pos += varint_len as usize;
893                    binaries.push(&self.data[pos..pos + len]);
894                    pos += len;
895                }
896                return DecthingsTensor::Binary(
897                    Array::from_vec(binaries)
898                        .into_shape(IxDyn(&shape))
899                        .unwrap()
900                        .into(),
901                );
902            }
903            TYPE_SPEC_BOOLEAN => {
904                let converted = self.data[pos..pos + numel * std::mem::size_of::<u8>()]
905                    .as_slice_of::<u8>()
906                    .unwrap();
907                return DecthingsTensor::Boolean(
908                    ArrayView::from(converted)
909                        .into_shape(IxDyn(&shape))
910                        .unwrap()
911                        .map(|&x| x != 0)
912                        .into(),
913                );
914            }
915            TYPE_SPEC_IMAGE => {
916                pos += crate::varint::get_serialized_varint_u64_len(&self.data[pos..]) as usize;
917
918                let mut images = Vec::with_capacity(numel);
919                for _ in 0..numel {
920                    let (len, varint_len) =
921                        crate::varint::deserialize_varint_u64(&self.data[pos..]);
922                    let len: usize = len.try_into().unwrap();
923                    pos += varint_len as usize;
924                    let format_length = self.data[pos] as usize;
925                    let format =
926                        std::str::from_utf8(&self.data[pos + 1..pos + 1 + format_length]).unwrap();
927                    images.push(
928                        DecthingsElementImage::new(
929                            format,
930                            &self.data[pos + 1 + format_length..pos + len],
931                        )
932                        .unwrap(),
933                    );
934                    pos += len;
935                }
936                return DecthingsTensor::Image(
937                    Array::from_vec(images)
938                        .into_shape(IxDyn(&shape))
939                        .unwrap()
940                        .into(),
941                );
942            }
943            TYPE_SPEC_AUDIO => {
944                pos += crate::varint::get_serialized_varint_u64_len(&self.data[pos..]) as usize;
945
946                let mut audios = Vec::with_capacity(numel);
947                for _ in 0..numel {
948                    let (len, varint_len) =
949                        crate::varint::deserialize_varint_u64(&self.data[pos..]);
950                    let len: usize = len.try_into().unwrap();
951                    pos += varint_len as usize;
952                    let format_length = self.data[pos] as usize;
953                    let format =
954                        std::str::from_utf8(&self.data[pos + 1..pos + 1 + format_length]).unwrap();
955                    audios.push(
956                        DecthingsElementAudio::new(
957                            format,
958                            &self.data[pos + 1 + format_length..pos + len],
959                        )
960                        .unwrap(),
961                    );
962                    pos += len;
963                }
964                return DecthingsTensor::Audio(
965                    Array::from_vec(audios)
966                        .into_shape(IxDyn(&shape))
967                        .unwrap()
968                        .into(),
969                );
970            }
971            TYPE_SPEC_VIDEO => {
972                pos += crate::varint::get_serialized_varint_u64_len(&self.data[pos..]) as usize;
973
974                let mut videos = Vec::with_capacity(numel);
975                for _ in 0..numel {
976                    let (len, varint_len) =
977                        crate::varint::deserialize_varint_u64(&self.data[pos..]);
978                    let len: usize = len.try_into().unwrap();
979                    pos += varint_len as usize;
980                    let format_length = self.data[pos] as usize;
981                    let format =
982                        std::str::from_utf8(&self.data[pos + 1..pos + 1 + format_length]).unwrap();
983                    videos.push(
984                        DecthingsElementVideo::new(
985                            format,
986                            &self.data[pos + 1 + format_length..pos + len],
987                        )
988                        .unwrap(),
989                    );
990                    pos += len;
991                }
992                DecthingsTensor::Video(
993                    Array::from_vec(videos)
994                        .into_shape(IxDyn(&shape))
995                        .unwrap()
996                        .into(),
997                )
998            }
999            _ => {
1000                unreachable!()
1001            }
1002        }
1003    }
1004
1005    pub fn serialize(&self) -> bytes::Bytes {
1006        self.data.clone()
1007    }
1008}
1009
1010impl std::fmt::Debug for OwnedDecthingsTensor {
1011    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1012        write!(f, "OwnedDecthingsTensor({:?})", self.tensor())
1013    }
1014}
1015
1016impl<'a, T: AsRef<OwnedDecthingsTensor>> From<&'a T> for DecthingsTensor<'a> {
1017    fn from(value: &'a T) -> Self {
1018        value.as_ref().tensor()
1019    }
1020}
1021
1022impl<'a> From<DecthingsTensor<'a>> for OwnedDecthingsTensor {
1023    fn from(value: DecthingsTensor<'a>) -> Self {
1024        OwnedDecthingsTensor {
1025            data: value.serialize().into(),
1026        }
1027    }
1028}