1use crate::internal_prelude::*;
5use std::borrow::Borrow;
6use std::ffi::{c_void, CString};
7use std::io::{Read, Seek};
8use std::os::raw::{c_char, c_int, c_longlong, c_ulonglong};
9use std::slice::from_raw_parts_mut;
10
/// Something that can be explicitly closed.
///
/// The error side of the result only has to be printable: the FFI glue in
/// this file formats it with `Display` and hands the text to the C caller.
pub trait CloseStream {
    /// Close the underlying resource, reporting a displayable error on failure.
    fn close(&mut self) -> std::result::Result<(), Box<dyn std::fmt::Display>>;
}
18
19pub trait Stream: Read + CloseStream + Send + 'static {}
23
24impl<T: Read + CloseStream + Send + 'static> Stream for T {}
25
26pub trait SeekableStream: Stream + Seek {}
30impl<T: Stream + Seek> SeekableStream for T {}
31
/// Heap-allocated state handed to the C side as the `userdata` pointer of a
/// custom stream.
pub(crate) struct CustomStreamData<T> {
    /// Storage for the most recent error message. Pointers returned by
    /// `conv_err` point into this buffer, so they stay valid until the next
    /// `conv_err` call or until this struct is dropped.
    pub(crate) err_msg: CString,
    /// The wrapped Rust stream.
    pub(crate) stream: T,
}

impl<T> CustomStreamData<T> {
    /// Formats `data` with `Display`, stores the text in `self.err_msg` and
    /// returns a pointer to the stored NUL-terminated string.
    ///
    /// Interior NUL bytes in the formatted text are removed rather than
    /// treated as a fatal error: this method is called from `extern "C"`
    /// callbacks, where panicking is not an acceptable way to report a
    /// malformed message.
    fn conv_err(&mut self, data: &dyn std::fmt::Display) -> *const c_char {
        let mut bytes = data.to_string().into_bytes();
        // A C string cannot contain NUL before its terminator.
        bytes.retain(|&b| b != 0);
        self.err_msg =
            CString::new(bytes).expect("interior NUL bytes were stripped just above");
        self.err_msg.as_ptr()
    }
}
45
46pub(crate) extern "C" fn stream_read_cb<T: Read>(
47 read: *mut c_ulonglong,
48 requested: c_ulonglong,
49 destination: *mut c_char,
50 userdata: *mut c_void,
51 err_msg: *mut *const c_char,
52) -> c_int {
53 let data = unsafe { &mut *(userdata as *mut CustomStreamData<T>) };
54
55 let dest = unsafe { from_raw_parts_mut(destination as *mut u8, requested as usize) };
56
57 let mut got_so_far = 0;
58
59 while got_so_far < requested {
60 match data.stream.read(&mut dest[got_so_far as usize..]) {
61 Ok(0) => {
62 break;
63 }
64 Ok(d) => {
65 got_so_far += d as u64;
66 }
67 Err(e) => {
68 unsafe { *err_msg = data.conv_err(&e) };
69 return 1;
70 }
71 }
72 }
73
74 unsafe { *read = got_so_far as c_ulonglong };
75 0
76}
77
78pub(crate) extern "C" fn stream_seek_cb<T: Seek>(
79 pos: c_ulonglong,
80 userdata: *mut c_void,
81 err_msg: *mut *const c_char,
82) -> c_int {
83 let data = unsafe { &mut *(userdata as *mut CustomStreamData<T>) };
84
85 match data.stream.seek(std::io::SeekFrom::Start(pos as u64)) {
86 Ok(_) => 0,
87 Err(e) => {
88 unsafe { *err_msg = data.conv_err(&e) };
89 1
90 }
91 }
92}
93
94extern "C" fn close_cb<T: CloseStream>(
95 userdata: *mut c_void,
96 err_msg: *mut *const c_char,
97) -> c_int {
98 let data = unsafe { &mut *(userdata as *mut CustomStreamData<T>) };
99
100 match data.stream.close() {
101 Ok(_) => 0,
102 Err(e) => {
103 unsafe { *err_msg = data.conv_err(&e) };
104 1
105 }
106 }
107}
108
109extern "C" fn destroy_cb<T>(userdata: *mut c_void) {
110 unsafe { Box::from_raw(userdata as *mut CustomStreamData<T>) };
112}
113
/// Type-erased destructor: treats `ptr` as a `Box<T>` released earlier with
/// `Box::into_raw` and drops it.
///
/// The caller must guarantee `ptr` really originated from a `Box<T>` and is
/// not used again afterwards.
fn drop_cb<T>(ptr: *mut c_void) {
    let reclaimed: Box<T> = unsafe { Box::from_raw(ptr as *mut T) };
    drop(reclaimed);
}
120
121fn fillout_read<T: Stream>(dest: &mut syz_CustomStreamDef) {
122 dest.read_cb = Some(stream_read_cb::<T>);
123 dest.close_cb = Some(close_cb::<T>);
124 dest.destroy_cb = Some(destroy_cb::<T>);
125 dest.length = -1;
126}
127
128fn fillout_seekable<T: SeekableStream>(
129 dest: &mut syz_CustomStreamDef,
130 val: &mut T,
131) -> std::io::Result<()> {
132 dest.seek_cb = Some(stream_seek_cb::<T>);
133
134 val.seek(std::io::SeekFrom::End(0))?;
135 dest.length = val.stream_position()? as c_longlong;
136 val.seek(std::io::SeekFrom::Start(0))?;
137 Ok(())
138}
139
140fn fillout_userdata<T>(dest: &mut syz_CustomStreamDef, val: T) {
141 dest.userdata = Box::into_raw(Box::new(CustomStreamData {
142 stream: val,
143 err_msg: Default::default(),
144 })) as *mut c_void;
145}
146
147pub struct CustomStreamDef {
151 def: syz_CustomStreamDef,
152 used: bool,
156}
157
158impl CustomStreamDef {
159 pub fn from_reader<T: Stream>(value: T) -> CustomStreamDef {
161 let mut ret = CustomStreamDef {
162 def: Default::default(),
163 used: false,
164 };
165
166 fillout_read::<T>(&mut ret.def);
167 fillout_userdata(&mut ret.def, value);
168 ret
169 }
170
171 pub fn from_seekable<T: SeekableStream>(mut value: T) -> std::io::Result<CustomStreamDef> {
173 let mut ret = CustomStreamDef {
174 def: Default::default(),
175 used: false,
176 };
177 fillout_read::<T>(&mut ret.def);
178 fillout_seekable(&mut ret.def, &mut value)?;
179 fillout_userdata(&mut ret.def, value);
180 Ok(ret)
181 }
182}
183
184impl Drop for CustomStreamDef {
185 fn drop(&mut self) {
186 let mut err_msg: *const c_char = std::ptr::null();
187 if !self.used {
188 unsafe {
189 if let Some(cb) = self.def.close_cb {
190 cb(self.def.userdata, &mut err_msg as *mut *const c_char);
191 }
192 if let Some(cb) = self.def.destroy_cb {
193 cb(self.def.userdata);
194 }
195 }
196 }
197 }
198}
199
200#[derive(Debug, Eq, Ord, PartialEq, PartialOrd, Hash)]
203pub struct StreamHandle {
204 handle: syz_Handle,
205 needs_drop: Option<(std::ptr::NonNull<c_void>, fn(*mut c_void))>,
207}
208
209impl StreamHandle {
210 pub fn from_stream_def(mut def: CustomStreamDef) -> Result<StreamHandle> {
211 let mut h = Default::default();
212 check_error(unsafe {
213 syz_createStreamHandleFromCustomStream(
214 &mut h as *mut syz_Handle,
215 &mut def.def as *mut syz_CustomStreamDef,
216 std::ptr::null_mut(),
217 None,
218 )
219 })?;
220 def.used = true;
221
222 Ok(StreamHandle {
223 handle: h,
224 needs_drop: None,
225 })
226 }
227
228 pub fn from_vec(data: Vec<u8>) -> Result<StreamHandle> {
230 if data.is_empty() {
231 return Err(Error::rust_error("Cannot create streams from empty vecs"));
232 };
233 let mut h = Default::default();
234 check_error(unsafe {
235 let ptr = &data[0] as *const u8 as *const i8;
236 syz_createStreamHandleFromMemory(
237 &mut h as *mut syz_Handle,
238 data.len() as u64,
239 ptr,
240 std::ptr::null_mut(),
241 None,
242 )
243 })?;
244
245 Ok(StreamHandle {
246 handle: h,
247 needs_drop: Some((
248 unsafe {
249 std::ptr::NonNull::new_unchecked(Box::into_raw(Box::new(data)) as *mut c_void)
250 },
251 drop_cb::<Vec<u8>>,
252 )),
253 })
254 }
255
256 pub fn from_stream_params(protocol: &str, path: &str, param: usize) -> Result<StreamHandle> {
257 let mut h = Default::default();
260 let protocol_c = std::ffi::CString::new(protocol)
261 .map_err(|_| Error::rust_error("Unable to convert protocol to a C string"))?;
262 let path_c = std::ffi::CString::new(path)
263 .map_err(|_| Error::rust_error("Unable to convert path to a C string"))?;
264 let protocol_ptr = protocol_c.as_ptr();
265 let path_ptr = path_c.as_ptr();
266 check_error(unsafe {
267 syz_createStreamHandleFromStreamParams(
268 &mut h as *mut syz_Handle,
269 protocol_ptr as *const c_char,
270 path_ptr as *const c_char,
271 std::mem::transmute(param),
272 std::ptr::null_mut(),
273 None,
274 )
275 })?;
276 Ok(StreamHandle {
277 handle: h,
278 needs_drop: None,
279 })
280 }
281
282 pub(crate) fn get_handle(&self) -> syz_Handle {
283 self.handle
284 }
285
286 fn get_userdata(mut self) -> UserdataBox {
287 let ret = if let Some((ud, free_cb)) = self.needs_drop.take() {
289 UserdataBox::from_streaming_userdata(ud, free_cb)
290 } else {
291 UserdataBox::new()
292 };
293 ret
294 }
295
296 pub(crate) fn with_userdata<T>(
301 mut self,
302 mut closure: impl (FnMut(syz_Handle, *mut c_void, extern "C" fn(*mut c_void)) -> Result<T>),
303 ) -> Result<T> {
304 let sh = self.handle;
305 self.handle = 0;
307 let ud = self.get_userdata();
308 ud.consume(move |ud, cb| closure(sh, ud, cb))
309 }
310}
311
312impl Drop for StreamHandle {
313 fn drop(&mut self) {
314 unsafe { syz_handleDecRef(self.handle) };
315 if let Some((ud, cb)) = self.needs_drop {
316 cb(ud.as_ptr());
317 }
318 }
319}
320
321static mut STREAM_ERR_CONSTANT: *const c_char = std::ptr::null();
322
323extern "C" fn stream_open_callback<
324 E,
325 T: 'static + Send + Sync + Fn(&str, &str, usize) -> std::result::Result<CustomStreamDef, E>,
326>(
327 out: *mut syz_CustomStreamDef,
328 protocol: *const c_char,
329 path: *const c_char,
330 param: *mut c_void,
331 userdata: *mut c_void,
332 err_msg: *mut *const c_char,
333) -> c_int {
334 static ONCE: std::sync::Once = std::sync::Once::new();
335 ONCE.call_once(|| {
336 let cstr = std::ffi::CString::new("Unable to create stream").unwrap();
337 unsafe { STREAM_ERR_CONSTANT = cstr.into_raw() };
338 });
339
340 let protocol = unsafe { std::ffi::CStr::from_ptr(protocol) };
341 let path = unsafe { std::ffi::CStr::from_ptr(path) };
342 let protocol = protocol.to_string_lossy();
343 let path = path.to_string_lossy();
344 let param: usize = unsafe { std::mem::transmute(param) };
345
346 let cb: Box<T> = unsafe { Box::from_raw(userdata as *mut T) };
347 let res = cb(protocol.borrow(), path.borrow(), param);
348 Box::into_raw(cb);
350
351 match res {
352 Ok(mut s) => {
353 unsafe { *out = s.def };
354 s.used = true;
355 0
356 }
357 Err(_) => {
358 unsafe { *err_msg = STREAM_ERR_CONSTANT };
359 1
360 }
361 }
362}
363
364pub fn register_stream_protocol<
370 E,
371 T: 'static + Send + Sync + Fn(&str, &str, usize) -> std::result::Result<CustomStreamDef, E>,
372>(
373 protocol: &str,
374 callback: T,
375) -> Result<()> {
376 let protocol_c = std::ffi::CString::new(protocol)
377 .map_err(|_| Error::rust_error("Unable to convert protocol to a C string"))?;
378 let leaked = Box::into_raw(Box::new(callback));
379 let res = check_error(unsafe {
380 syz_registerStreamProtocol(
381 protocol_c.as_ptr(),
382 Some(stream_open_callback::<E, T>),
383 leaked as *mut c_void,
384 )
385 });
386 match res {
387 Ok(_) => Ok(()),
388 Err(e) => {
389 unsafe { Box::from_raw(leaked) };
390 Err(e)
391 }
392 }
393}