// synthizer/custom_streams.rs

//! Infrastructure for custom streams.
//!
//! Custom streams are implemented directly on top of `Read` and `Seek`.
4use crate::internal_prelude::*;
5use std::borrow::Borrow;
6use std::ffi::{c_void, CString};
7use std::io::{Read, Seek};
8use std::os::raw::{c_char, c_int, c_longlong, c_ulonglong};
9use std::slice::from_raw_parts_mut;

/// Lets a custom stream run fallible cleanup when Synthizer closes it.
///
/// The standard library's I/O traits have no explicit close operation, and
/// `Drop` has no way to report failure. Implementors return an error whose
/// `Display` text is handed back to Synthizer when closing fails.
pub trait CloseStream {
    fn close(&mut self) -> std::result::Result<(), Box<dyn std::fmt::Display>>;
}

/// Marker trait for non-seekable streams.
///
/// Every `Read + CloseStream + Send + 'static` type gets this automatically
/// through a blanket impl.
pub trait Stream: Read + CloseStream + Send + 'static {}

impl<S> Stream for S where S: Read + CloseStream + Send + 'static {}

/// A [Stream] that additionally supports [Seek].
///
/// Every type implementing both [Stream] and [Seek] gets this automatically
/// through a blanket impl.
pub trait SeekableStream: Stream + Seek {}

impl<S> SeekableStream for S where S: Stream + Seek {}

/// The value boxed behind a custom stream's C `userdata` pointer: the user's
/// stream plus a buffer that keeps the last error message alive for Synthizer.
pub(crate) struct CustomStreamData<T> {
    /// Owns the most recent error message whose pointer was handed out by
    /// `conv_err`.
    pub(crate) err_msg: CString,
    pub(crate) stream: T,
}

impl<T> CustomStreamData<T> {
    /// Render `data` into `self.err_msg` and return a pointer to it.
    ///
    /// The pointer stays valid until the next call to this method or until
    /// `self` is dropped.
    ///
    /// This runs inside `extern "C"` callbacks, so it must not panic.
    /// `CString::new` rejects interior NUL bytes, so those are stripped from
    /// the message instead of aborting the callback.
    fn conv_err(&mut self, data: &dyn std::fmt::Display) -> *const c_char {
        let mut bytes = data.to_string().into_bytes();
        bytes.retain(|&b| b != 0);
        self.err_msg = CString::new(bytes).expect("all NUL bytes were removed above");
        self.err_msg.as_ptr()
    }
}

46pub(crate) extern "C" fn stream_read_cb<T: Read>(
47    read: *mut c_ulonglong,
48    requested: c_ulonglong,
49    destination: *mut c_char,
50    userdata: *mut c_void,
51    err_msg: *mut *const c_char,
52) -> c_int {
53    let data = unsafe { &mut *(userdata as *mut CustomStreamData<T>) };
54
55    let dest = unsafe { from_raw_parts_mut(destination as *mut u8, requested as usize) };
56
57    let mut got_so_far = 0;
58
59    while got_so_far < requested {
60        match data.stream.read(&mut dest[got_so_far as usize..]) {
61            Ok(0) => {
62                break;
63            }
64            Ok(d) => {
65                got_so_far += d as u64;
66            }
67            Err(e) => {
68                unsafe { *err_msg = data.conv_err(&e) };
69                return 1;
70            }
71        }
72    }
73
74    unsafe { *read = got_so_far as c_ulonglong };
75    0
76}
77
78pub(crate) extern "C" fn stream_seek_cb<T: Seek>(
79    pos: c_ulonglong,
80    userdata: *mut c_void,
81    err_msg: *mut *const c_char,
82) -> c_int {
83    let data = unsafe { &mut *(userdata as *mut CustomStreamData<T>) };
84
85    match data.stream.seek(std::io::SeekFrom::Start(pos as u64)) {
86        Ok(_) => 0,
87        Err(e) => {
88            unsafe { *err_msg = data.conv_err(&e) };
89            1
90        }
91    }
92}
93
94extern "C" fn close_cb<T: CloseStream>(
95    userdata: *mut c_void,
96    err_msg: *mut *const c_char,
97) -> c_int {
98    let data = unsafe { &mut *(userdata as *mut CustomStreamData<T>) };
99
100    match data.stream.close() {
101        Ok(_) => 0,
102        Err(e) => {
103            unsafe { *err_msg = data.conv_err(&e) };
104            1
105        }
106    }
107}
108
109extern "C" fn destroy_cb<T>(userdata: *mut c_void) {
110    // Build a box and immediately drop it.
111    unsafe { Box::from_raw(userdata as *mut CustomStreamData<T>) };
112}
113
/// Frees a `Box<T>` that was converted into a `*mut c_void`.
///
/// Used as part of [StreamHandle] consumption. `ptr` must come from
/// `Box::<T>::into_raw` (or an equivalent leak) and must not be used afterwards.
fn drop_cb<T>(ptr: *mut c_void) {
    let owned: Box<T> = unsafe { Box::from_raw(ptr.cast::<T>()) };
    drop(owned);
}

121fn fillout_read<T: Stream>(dest: &mut syz_CustomStreamDef) {
122    dest.read_cb = Some(stream_read_cb::<T>);
123    dest.close_cb = Some(close_cb::<T>);
124    dest.destroy_cb = Some(destroy_cb::<T>);
125    dest.length = -1;
126}
127
128fn fillout_seekable<T: SeekableStream>(
129    dest: &mut syz_CustomStreamDef,
130    val: &mut T,
131) -> std::io::Result<()> {
132    dest.seek_cb = Some(stream_seek_cb::<T>);
133
134    val.seek(std::io::SeekFrom::End(0))?;
135    dest.length = val.stream_position()? as c_longlong;
136    val.seek(std::io::SeekFrom::Start(0))?;
137    Ok(())
138}
139
140fn fillout_userdata<T>(dest: &mut syz_CustomStreamDef, val: T) {
141    dest.userdata = Box::into_raw(Box::new(CustomStreamData {
142        stream: val,
143        err_msg: Default::default(),
144    })) as *mut c_void;
145}
146
147/// A definition for a custom stream.  This can come from a variety of places
148/// and is consumed by e.g. [StreamingGenerator::from_stream_handle], or
149/// returned as a result of the callback passed to [register_stream_protocol].
150pub struct CustomStreamDef {
151    def: syz_CustomStreamDef,
152    /// Has this handle been used yet? If not, the drop impl needs to do cleanup
153    /// so that the user's failure to consume the value with [CustomStreamDef]
154    /// does not leak their value.
155    used: bool,
156}
157
158impl CustomStreamDef {
159    /// Convert a [Read] to a stream.
160    pub fn from_reader<T: Stream>(value: T) -> CustomStreamDef {
161        let mut ret = CustomStreamDef {
162            def: Default::default(),
163            used: false,
164        };
165
166        fillout_read::<T>(&mut ret.def);
167        fillout_userdata(&mut ret.def, value);
168        ret
169    }
170
171    /// Build a stream from something seekable.
172    pub fn from_seekable<T: SeekableStream>(mut value: T) -> std::io::Result<CustomStreamDef> {
173        let mut ret = CustomStreamDef {
174            def: Default::default(),
175            used: false,
176        };
177        fillout_read::<T>(&mut ret.def);
178        fillout_seekable(&mut ret.def, &mut value)?;
179        fillout_userdata(&mut ret.def, value);
180        Ok(ret)
181    }
182}
183
184impl Drop for CustomStreamDef {
185    fn drop(&mut self) {
186        let mut err_msg: *const c_char = std::ptr::null();
187        if !self.used {
188            unsafe {
189                if let Some(cb) = self.def.close_cb {
190                    cb(self.def.userdata, &mut err_msg as *mut *const c_char);
191                }
192                if let Some(cb) = self.def.destroy_cb {
193                    cb(self.def.userdata);
194                }
195            }
196        }
197    }
198}
199
200/// A `StreamHandle` binds Synthizer custom streams, as well as other kinds of
201/// streaming functionality.
202#[derive(Debug, Eq, Ord, PartialEq, PartialOrd, Hash)]
203pub struct StreamHandle {
204    handle: syz_Handle,
205    // If set, this stream will move the given value into Synthizer userdata for freeing later.
206    needs_drop: Option<(std::ptr::NonNull<c_void>, fn(*mut c_void))>,
207}
208
209impl StreamHandle {
210    pub fn from_stream_def(mut def: CustomStreamDef) -> Result<StreamHandle> {
211        let mut h = Default::default();
212        check_error(unsafe {
213            syz_createStreamHandleFromCustomStream(
214                &mut h as *mut syz_Handle,
215                &mut def.def as *mut syz_CustomStreamDef,
216                std::ptr::null_mut(),
217                None,
218            )
219        })?;
220        def.used = true;
221
222        Ok(StreamHandle {
223            handle: h,
224            needs_drop: None,
225        })
226    }
227
228    /// Create a stream handle which is backed by memory.
229    pub fn from_vec(data: Vec<u8>) -> Result<StreamHandle> {
230        if data.is_empty() {
231            return Err(Error::rust_error("Cannot create streams from empty vecs"));
232        };
233        let mut h = Default::default();
234        check_error(unsafe {
235            let ptr = &data[0] as *const u8 as *const i8;
236            syz_createStreamHandleFromMemory(
237                &mut h as *mut syz_Handle,
238                data.len() as u64,
239                ptr,
240                std::ptr::null_mut(),
241                None,
242            )
243        })?;
244
245        Ok(StreamHandle {
246            handle: h,
247            needs_drop: Some((
248                unsafe {
249                    std::ptr::NonNull::new_unchecked(Box::into_raw(Box::new(data)) as *mut c_void)
250                },
251                drop_cb::<Vec<u8>>,
252            )),
253        })
254    }
255
256    pub fn from_stream_params(protocol: &str, path: &str, param: usize) -> Result<StreamHandle> {
257        // The below transmute uses the fact that `usize` is the size of a
258        // pointer on all common platforms.
259        let mut h = Default::default();
260        let protocol_c = std::ffi::CString::new(protocol)
261            .map_err(|_| Error::rust_error("Unable to convert protocol to a C string"))?;
262        let path_c = std::ffi::CString::new(path)
263            .map_err(|_| Error::rust_error("Unable to convert path to a C string"))?;
264        let protocol_ptr = protocol_c.as_ptr();
265        let path_ptr = path_c.as_ptr();
266        check_error(unsafe {
267            syz_createStreamHandleFromStreamParams(
268                &mut h as *mut syz_Handle,
269                protocol_ptr as *const c_char,
270                path_ptr as *const c_char,
271                std::mem::transmute(param),
272                std::ptr::null_mut(),
273                None,
274            )
275        })?;
276        Ok(StreamHandle {
277            handle: h,
278            needs_drop: None,
279        })
280    }
281
282    pub(crate) fn get_handle(&self) -> syz_Handle {
283        self.handle
284    }
285
286    fn get_userdata(mut self) -> UserdataBox {
287        // Be sure to take here so that Drop doesn't try to double free.
288        let ret = if let Some((ud, free_cb)) = self.needs_drop.take() {
289            UserdataBox::from_streaming_userdata(ud, free_cb)
290        } else {
291            UserdataBox::new()
292        };
293        ret
294    }
295
296    /// Wrap getting userdata and also make sure to free the handle once the
297    /// closure ends, regardless of if it succeeded.
298    // The closure gets the stream handle, as well as the userdata pointer and
299    // free callback.
300    pub(crate) fn with_userdata<T>(
301        mut self,
302        mut closure: impl (FnMut(syz_Handle, *mut c_void, extern "C" fn(*mut c_void)) -> Result<T>),
303    ) -> Result<T> {
304        let sh = self.handle;
305        // Take the handle.
306        self.handle = 0;
307        let ud = self.get_userdata();
308        ud.consume(move |ud, cb| closure(sh, ud, cb))
309    }
310}
311
312impl Drop for StreamHandle {
313    fn drop(&mut self) {
314        unsafe { syz_handleDecRef(self.handle) };
315        if let Some((ud, cb)) = self.needs_drop {
316            cb(ud.as_ptr());
317        }
318    }
319}
320
321static mut STREAM_ERR_CONSTANT: *const c_char = std::ptr::null();
322
323extern "C" fn stream_open_callback<
324    E,
325    T: 'static + Send + Sync + Fn(&str, &str, usize) -> std::result::Result<CustomStreamDef, E>,
326>(
327    out: *mut syz_CustomStreamDef,
328    protocol: *const c_char,
329    path: *const c_char,
330    param: *mut c_void,
331    userdata: *mut c_void,
332    err_msg: *mut *const c_char,
333) -> c_int {
334    static ONCE: std::sync::Once = std::sync::Once::new();
335    ONCE.call_once(|| {
336        let cstr = std::ffi::CString::new("Unable to create stream").unwrap();
337        unsafe { STREAM_ERR_CONSTANT = cstr.into_raw() };
338    });
339
340    let protocol = unsafe { std::ffi::CStr::from_ptr(protocol) };
341    let path = unsafe { std::ffi::CStr::from_ptr(path) };
342    let protocol = protocol.to_string_lossy();
343    let path = path.to_string_lossy();
344    let param: usize = unsafe { std::mem::transmute(param) };
345
346    let cb: Box<T> = unsafe { Box::from_raw(userdata as *mut T) };
347    let res = cb(protocol.borrow(), path.borrow(), param);
348    // Be sure not to drop the callback.
349    Box::into_raw(cb);
350
351    match res {
352        Ok(mut s) => {
353            unsafe { *out = s.def };
354            s.used = true;
355            0
356        }
357        Err(_) => {
358            unsafe { *err_msg = STREAM_ERR_CONSTANT };
359            1
360        }
361    }
362}
363
364/// register a custom protocol.
365///
366/// The callback here must return a [CustomStreamDef] which represents the
367/// custom stream.  Synthizer is also not safely reentrant, and the callback
368/// must not call back into Synthizer.
369pub fn register_stream_protocol<
370    E,
371    T: 'static + Send + Sync + Fn(&str, &str, usize) -> std::result::Result<CustomStreamDef, E>,
372>(
373    protocol: &str,
374    callback: T,
375) -> Result<()> {
376    let protocol_c = std::ffi::CString::new(protocol)
377        .map_err(|_| Error::rust_error("Unable to convert protocol to a C string"))?;
378    let leaked = Box::into_raw(Box::new(callback));
379    let res = check_error(unsafe {
380        syz_registerStreamProtocol(
381            protocol_c.as_ptr(),
382            Some(stream_open_callback::<E, T>),
383            leaked as *mut c_void,
384        )
385    });
386    match res {
387        Ok(_) => Ok(()),
388        Err(e) => {
389            unsafe { Box::from_raw(leaked) };
390            Err(e)
391        }
392    }
393}