1use crate::{
2 as_data_slice_or_err,
3 df::{
4 meta::{align_nbytes, DF_ALIGN},
5 ColumnsDtype, DataFrame, IndexDtype, COLUMNS_NBYTES, INDEX_NBYTES,
6 },
7 toolkit::{
8 array::AFloat,
9 convert::{to_bytes, to_nbytes},
10 },
11};
12use anyhow::Result;
13use bytes::BufMut;
14use core::mem;
15
16fn extract_usize(bytes: &[u8]) -> Result<(&[u8], usize)> {
17 let (target, remain) = bytes.split_at(to_nbytes::<i64>(1));
18 let value = i64::from_le_bytes(target.try_into()?);
19 Ok((remain, value as usize))
20}
/// Splits `bytes` after its first `nbytes` bytes.
///
/// Returns the bytes after the split point together with a raw pointer to the
/// start of the leading `nbytes`-byte section (which is also the start of `bytes`).
///
/// # Panics
///
/// Panics if `nbytes > bytes.len()` (inherited from `slice::split_at`).
fn extract_ptr(bytes: &[u8], nbytes: usize) -> (&[u8], *const u8) {
    let (head, tail) = bytes.split_at(nbytes);
    let head_ptr = head.as_ptr();
    (tail, head_ptr)
}
25
26fn put_aligned_slice(bytes: &mut Vec<u8>, slice: &[u8]) {
27 bytes.put_slice(slice);
28 let remainder = slice.len() % DF_ALIGN;
29 if remainder != 0 {
30 let padding = DF_ALIGN - remainder;
31 bytes.put_slice(&[0u8; DF_ALIGN][..padding]);
32 }
33}
34
35impl<'a, T: AFloat> DataFrame<'a, T> {
36 pub fn to_bytes(&self) -> Result<Vec<u8>> {
37 let index = self.index();
38 let columns = self.columns();
39 let values = self.values();
40 let index_nbytes = to_nbytes::<IndexDtype>(index.len());
41 let columns_nbytes = to_nbytes::<ColumnsDtype>(columns.len());
42 let values_nbytes = to_nbytes::<T>(values.len());
43
44 let index_aligned_nbytes = align_nbytes(index_nbytes);
45 let columns_aligned_nbytes = align_nbytes(columns_nbytes);
46 let values_aligned_nbytes = align_nbytes(values_nbytes);
47 let total_aligned_nbytes =
48 index_aligned_nbytes + columns_aligned_nbytes + values_aligned_nbytes + 16;
49 let aligned_bytes: Vec<[u8; DF_ALIGN]> =
50 Vec::with_capacity(total_aligned_nbytes / DF_ALIGN);
51 let mut bytes: Vec<u8> = unsafe {
52 Vec::from_raw_parts(aligned_bytes.as_ptr() as *mut u8, 0, total_aligned_nbytes)
53 };
54 mem::forget(aligned_bytes);
55
56 bytes.put_i64_le(index_nbytes as i64);
57 bytes.put_i64_le(columns_nbytes as i64);
58 unsafe {
59 put_aligned_slice(&mut bytes, to_bytes(as_data_slice_or_err!(index)));
60 put_aligned_slice(&mut bytes, to_bytes(as_data_slice_or_err!(columns)));
61 put_aligned_slice(&mut bytes, to_bytes(as_data_slice_or_err!(values)));
62 };
63 Ok(bytes)
64 }
65
66 pub unsafe fn from_bytes(bytes: &'a [u8]) -> Result<Self> {
76 let (bytes, index_nbytes) = extract_usize(bytes)?;
77 let (bytes, columns_nbytes) = extract_usize(bytes)?;
78
79 let index_shape = index_nbytes / INDEX_NBYTES;
80 let columns_shape = columns_nbytes / COLUMNS_NBYTES;
81
82 let (bytes, index_ptr) = extract_ptr(bytes, index_nbytes);
83 let (bytes, columns_ptr) = extract_ptr(bytes, columns_nbytes);
84 let values_nbytes = to_nbytes::<T>(index_shape * columns_shape);
85 let (_, values_ptr) = extract_ptr(bytes, values_nbytes);
86
87 Ok(DataFrame::from_ptr(
88 index_ptr,
89 index_shape,
90 columns_ptr,
91 columns_shape,
92 values_ptr,
93 ))
94 }
95}
96
#[cfg(test)]
pub(super) mod tests {
    use super::*;
    use crate::toolkit::convert::from_vec;

    /// Builds a minimal `f32` frame from three byte buffers of `INDEX_NBYTES`,
    /// `COLUMNS_NBYTES` and `size_of::<f32>()` bytes. Each buffer is zero except for
    /// its first byte: 1 for the index, 2 for the columns, 3 for the values.
    pub fn get_test_df<'a>() -> DataFrame<'a, f32> {
        // A zeroed buffer of `nbytes` bytes whose leading byte is `head`.
        fn marked(nbytes: usize, head: u8) -> Vec<u8> {
            let mut buf = vec![0u8; nbytes];
            buf[0] = head;
            buf
        }

        // SAFETY: each buffer is exactly one element of its target type wide.
        // NOTE(review): confirm against `from_vec`'s contract.
        let (index_vec, columns_vec, values_vec) = unsafe {
            (
                from_vec(marked(INDEX_NBYTES, 1)),
                from_vec(marked(COLUMNS_NBYTES, 2)),
                from_vec(marked(mem::size_of::<f32>(), 3)),
            )
        };
        DataFrame::<f32>::from_vec(index_vec, columns_vec, values_vec).unwrap()
    }

    /// Pins the exact `to_bytes` layout for the fixture frame and checks that
    /// `from_bytes` round-trips it.
    #[test]
    fn test_bytes_io() {
        let df = get_test_df();
        let bytes = df.to_bytes().unwrap();

        #[rustfmt::skip]
        let expected: [u8; 64] = [
            // header: index_nbytes = 8, columns_nbytes = 32 (i64, little-endian)
            8, 0, 0, 0, 0, 0, 0, 0,
            32, 0, 0, 0, 0, 0, 0, 0,
            // index section (8 bytes)
            1, 0, 0, 0, 0, 0, 0, 0,
            // columns section (32 bytes)
            2, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0,
            // values section: one f32 (4 bytes) followed by 4 bytes of zero padding
            3, 0, 0, 0, 0, 0, 0, 0,
        ];
        assert_eq!(bytes, expected);

        // SAFETY: `bytes` was produced by `to_bytes` for the same element type `f32`.
        let loaded = unsafe { DataFrame::<f32>::from_bytes(&bytes) }.unwrap();
        assert_eq!(df.index(), loaded.index());
        assert_eq!(df.columns(), loaded.columns());
        assert_eq!(df.values(), loaded.values());
    }
}