tract_data/
blob.rs

1use num_traits::Zero;
2
3use crate::{TractError, TractResult};
4use std::alloc::*;
5use std::fmt::Display;
6use std::hash::Hash;
7use std::ptr::null_mut;
8
9#[derive(Eq)]
10pub struct Blob {
11    layout: std::alloc::Layout,
12    data: *mut u8,
13}
14
15impl Default for Blob {
16    #[inline]
17    fn default() -> Blob {
18        Blob::from_bytes(&[]).unwrap()
19    }
20}
21
22impl Clone for Blob {
23    #[inline]
24    fn clone(&self) -> Self {
25        Blob::from_bytes_alignment(self, self.layout.align()).unwrap()
26    }
27}
28
29impl Drop for Blob {
30    #[inline]
31    fn drop(&mut self) {
32        if !self.data.is_null() {
33            unsafe { dealloc(self.data, self.layout) }
34        }
35    }
36}
37
38impl PartialEq for Blob {
39    #[inline]
40    fn eq(&self, other: &Self) -> bool {
41        self.layout == other.layout && self.as_bytes() == other.as_bytes()
42    }
43}
44
45impl Hash for Blob {
46    #[inline]
47    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
48        self.layout.align().hash(state);
49        self.as_bytes().hash(state);
50    }
51}
52
53impl Blob {
54    #[inline]
55    pub unsafe fn new_for_size_and_align(size: usize, align: usize) -> Blob {
56        unsafe { Self::for_layout(Layout::from_size_align_unchecked(size, align)) }
57    }
58
59    #[inline]
60    pub unsafe fn ensure_size_and_align(&mut self, size: usize, align: usize) {
61        if size > self.layout.size() || align > self.layout.align() {
62            if !self.data.is_null() {
63                unsafe { std::alloc::dealloc(self.data as _, self.layout) };
64            }
65            self.layout = unsafe { Layout::from_size_align_unchecked(size, align) };
66            self.data = unsafe { std::alloc::alloc(self.layout) };
67            assert!(!self.data.is_null());
68        }
69    }
70
71    #[inline]
72    pub unsafe fn for_layout(layout: Layout) -> Blob {
73        let mut data = null_mut();
74        if layout.size() > 0 {
75            data = unsafe { alloc(layout) };
76            assert!(!data.is_null(), "failed to allocate {layout:?}");
77        }
78        Blob { layout, data }
79    }
80
81    #[inline]
82    pub fn from_bytes(s: &[u8]) -> TractResult<Blob> {
83        Self::from_bytes_alignment(s, 128)
84    }
85
86    #[inline]
87    pub fn as_bytes(&self) -> &[u8] {
88        if self.data.is_null() {
89            &[]
90        } else {
91            unsafe { std::slice::from_raw_parts(self.data, self.layout.size()) }
92        }
93    }
94
95    #[inline]
96    pub fn as_bytes_mut(&mut self) -> &mut [u8] {
97        if self.data.is_null() {
98            &mut []
99        } else {
100            unsafe { std::slice::from_raw_parts_mut(self.data, self.layout.size()) }
101        }
102    }
103
104    #[inline]
105    pub fn from_bytes_alignment(s: &[u8], alignment: usize) -> TractResult<Blob> {
106        unsafe {
107            let layout = Layout::from_size_align(s.len(), alignment)?;
108            let blob = Self::for_layout(layout);
109            if s.len() > 0 {
110                std::ptr::copy_nonoverlapping(s.as_ptr(), blob.data, s.len());
111            }
112            Ok(blob)
113        }
114    }
115
116    #[inline]
117    pub fn layout(&self) -> &Layout {
118        &self.layout
119    }
120}
121
122impl std::ops::Deref for Blob {
123    type Target = [u8];
124    #[inline]
125    fn deref(&self) -> &[u8] {
126        self.as_bytes()
127    }
128}
129
130impl std::ops::DerefMut for Blob {
131    #[inline]
132    fn deref_mut(&mut self) -> &mut [u8] {
133        self.as_bytes_mut()
134    }
135}
136
137impl std::fmt::Display for Blob {
138    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
139        assert!(self.data.is_null() == self.layout.size().is_zero());
140        write!(
141            fmt,
142            "Blob of {} bytes (align @{}): {} {}",
143            self.len(),
144            self.layout.align(),
145            String::from_utf8(
146                self.iter()
147                    .take(20)
148                    .copied()
149                    .flat_map(std::ascii::escape_default)
150                    .collect::<Vec<u8>>()
151            )
152            .unwrap(),
153            if self.len() >= 20 { "[...]" } else { "" }
154        )
155    }
156}
157
158impl std::fmt::Debug for Blob {
159    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
160        <Self as Display>::fmt(self, fmt)
161    }
162}
163
164impl TryFrom<&[u8]> for Blob {
165    type Error = TractError;
166    #[inline]
167    fn try_from(s: &[u8]) -> Result<Blob, Self::Error> {
168        Blob::from_bytes(s)
169    }
170}
171
172unsafe impl Send for Blob {}
173unsafe impl Sync for Blob {}