tempfile_istream/lib.rs
use std::{ffi::OsStr, mem::MaybeUninit, path::Path};

use windows::{
    core::{Result, PCWSTR},
    Win32::{
        Foundation::{ERROR_BUFFER_OVERFLOW, E_FAIL, E_INVALIDARG, E_OUTOFMEMORY, MAX_PATH},
        Storage::FileSystem::{
            DeleteFileW, GetTempFileNameW, GetTempPathW, FILE_ATTRIBUTE_TEMPORARY,
            FILE_FLAG_DELETE_ON_CLOSE,
        },
        System::Com::{
            IStream,
            StructuredStorage::{STGM_CREATE, STGM_READWRITE, STGM_SHARE_EXCLUSIVE},
            STREAM_SEEK_SET,
        },
        UI::Shell::SHCreateStreamOnFileEx,
    },
};
18
/// Builder for a read/write implementation of the [`windows`] crate's [`IStream`] interface
/// backed by a temp file on disk. The temp file is created with [`SHCreateStreamOnFileEx`], using
/// [`FILE_ATTRIBUTE_TEMPORARY`] and [`FILE_FLAG_DELETE_ON_CLOSE`] so it will be deleted by the OS
/// as soon as the last reference to the [`IStream`] is dropped.
///
/// The builder only borrows its inputs, so it is cheap to copy and can be reused to create
/// several independent streams with the same prefix and content.
///
/// # Example
///
/// ```
/// use tempfile_istream::Builder;
///
/// let stream = Builder::new("prefix")
///     .with_content(b"binary content")
///     .build()
///     .expect("creates the stream");
/// ```
#[derive(Clone, Copy, Debug)]
#[must_use = "a Builder does nothing until `build` is called"]
pub struct Builder<'a> {
    /// Candidate file name prefix; only the first 3 characters reach the file name.
    prefix: &'a str,
    /// Optional initial bytes written to the stream before it is returned.
    content: Option<&'a [u8]>,
}
38
39impl<'a> Builder<'a> {
40    /// Create a new [`Builder`] for an empty [`IStream`] backed by a temp file on disk with the
41    /// specified filename prefix. Only the first 3 characters of the `prefix` parameter will
42    /// be used in the filename, but the entire string must match a valid [`std::path::Path`]
43    /// `file_stem` or the call to `build` will fail.
44    ///
45    /// # Example
46    ///
47    /// ```
48    /// use windows::Win32::System::Com::STREAM_SEEK_END;
49    /// use tempfile_istream::Builder;
50    ///
51    /// let stream = Builder::new("prefix")
52    ///     .build()
53    ///     .expect("creates an empty stream");
54    ///
55    /// let end_pos = unsafe {
56    ///     stream.Seek(0, STREAM_SEEK_END)
57    /// }
58    /// .expect("end position");
59    ///
60    /// assert_eq!(0, end_pos, "stream should be empty");
61    /// ```
62    ///
63    /// # See also
64    ///
65    /// Parameter
66    /// [requirements](https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-gettempfilenamew#parameters)
67    /// for the `prefix` argument.
68    pub fn new(prefix: &'a str) -> Self {
69        Self {
70            prefix,
71            content: None,
72        }
73    }
74
75    /// Initialize the stream with a [`u8`] slice of bytes and leave the [`IStream`] cursor at the
76    /// beginning of the stream so that a consumer can immediately begin reading it back.
77    ///
78    /// # Example
79    ///
80    /// ```
81    /// use std::mem;
82    /// use tempfile_istream::Builder;
83    ///
84    /// const CONTENT: &[u8] = b"binary content";
85    /// const CONTENT_LEN: usize = CONTENT.len();
86    ///
87    /// let stream = Builder::new("prefix")
88    ///     .with_content(CONTENT)
89    ///     .build()
90    ///     .expect("creates a stream with content");
91    ///
92    /// let mut buf = [0_u8; CONTENT_LEN];
93    /// let mut read_len = 0;
94    /// unsafe {
95    ///     stream.Read(
96    ///         mem::transmute(buf.as_mut_ptr()),
97    ///         buf.len() as u32,
98    ///         &mut read_len,
99    ///     )
100    ///     .ok()
101    /// }
102    /// .expect("read bytes");
103    ///
104    /// assert_eq!(buf, CONTENT, "should match the initial content");
105    /// ```
106    pub fn with_content(self, content: &'a [u8]) -> Self {
107        Self {
108            content: Some(content),
109            ..self
110        }
111    }
112
113    /// Create the [`IStream`] backed by a temp file. This will perform parameter validation
114    /// on the `prefix` argument and fail with [`E_INVALIDARG`] if it contains anything other
115    /// than a valid [`std::path::Path`] `file_stem`. Only the first 3 characters of the `prefix`
116    /// will be used.
117    ///
118    /// # Example
119    ///
120    /// ```
121    /// use windows::Win32::System::Com::{STREAM_SEEK_CUR, STREAM_SEEK_END};
122    /// use tempfile_istream::Builder;
123    ///
124    /// const CONTENT: &[u8] = b"binary content";
125    ///
126    /// let stream = Builder::new("prefix")
127    ///     .with_content(CONTENT)
128    ///     .build()
129    ///     .expect("creates a non-empty stream");
130    ///
131    /// let cur_pos = unsafe {
132    ///     stream.Seek(0, STREAM_SEEK_CUR)
133    /// }
134    /// .expect("current position");
135    ///
136    /// assert_eq!(0, cur_pos, "current position should be at the beginning");
137    ///
138    /// let end_pos = unsafe {
139    ///     stream.Seek(0, STREAM_SEEK_END)
140    /// }
141    /// .expect("end position");
142    ///
143    /// assert_eq!(end_pos as usize, CONTENT.len(), "end position should match content length")
144    /// ```
145    ///
146    /// # See also
147    ///
148    /// Parameter
149    /// [requirements](https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-gettempfilenamew#parameters)
150    /// for the `prefix` argument.
151    pub fn build(self) -> Result<IStream> {
152        if !self.prefix.is_empty()
153            && Path::new(self.prefix).file_stem() != Some(OsStr::new(self.prefix))
154        {
155            return Err(E_INVALIDARG.into());
156        }
157
158        let stream = unsafe {
159            const FILE_LEN: usize = MAX_PATH as usize;
160            const DIR_LEN: usize = FILE_LEN - 14;
161
162            let mut dir = [MaybeUninit::<u16>::uninit(); DIR_LEN];
163            let mut file = [MaybeUninit::<u16>::uninit(); FILE_LEN];
164
165            match GetTempPathW(
166                &mut *(std::ptr::slice_from_raw_parts_mut(dir.as_mut_ptr(), dir.len())
167                    as *mut [u16]),
168            ) as usize
169            {
170                0 => Err(windows::core::Error::from_win32()),
171                len if len >= dir.len() => E_OUTOFMEMORY.ok(),
172                _ => Ok(()),
173            }?;
174            match GetTempFileNameW(
175                PCWSTR(std::mem::transmute(dir.as_ptr())),
176                self.prefix,
177                0,
178                &mut *(file.as_mut_ptr() as *mut [u16; FILE_LEN]),
179            ) {
180                unique if unique == ERROR_BUFFER_OVERFLOW.0 => Result::Err(E_OUTOFMEMORY.into()),
181                0 => Result::Err(E_FAIL.into()),
182                _ => Ok(()),
183            }?;
184            SHCreateStreamOnFileEx(
185                PCWSTR(std::mem::transmute(file.as_ptr())),
186                (STGM_CREATE | STGM_READWRITE | STGM_SHARE_EXCLUSIVE).0,
187                (FILE_ATTRIBUTE_TEMPORARY | FILE_FLAG_DELETE_ON_CLOSE).0,
188                true,
189                None,
190            )?
191        };
192
193        if let Some(content) = self.content {
194            unsafe {
195                stream
196                    .Write(
197                        std::mem::transmute(content.as_ptr()),
198                        content.len() as u32,
199                        std::ptr::null_mut(),
200                    )
201                    .ok()?;
202                stream.Seek(0, STREAM_SEEK_SET)?;
203            }
204        }
205
206        Ok(stream)
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use windows::Win32::System::Com::{IStream, STREAM_SEEK_END, STREAM_SEEK_SET};

    const TEST_PREFIX: &str = "test";

    /// Read up to `capacity` bytes from the current position of `stream`.
    ///
    /// Returns the zero-filled buffer together with the byte count the stream reported, so
    /// callers can check both the data and that nothing was read past the expected end.
    fn read_bytes(stream: &IStream, capacity: usize) -> (Vec<u8>, usize) {
        let mut buf = vec![0_u8; capacity];
        let mut read_len = 0_u32;
        // SAFETY: `buf` is valid for `capacity` writable bytes and `read_len` is a valid
        // out-pointer for the duration of the call.
        unsafe { stream.Read(buf.as_mut_ptr().cast(), buf.len() as u32, &mut read_len) }
            .ok()
            .expect("read bytes");
        (buf, read_len as usize)
    }

    #[test]
    fn new_tempfile_stream() {
        Builder::new(TEST_PREFIX).build().expect("create tempfile");
    }

    #[test]
    fn rejects_prefix_that_is_not_a_file_stem() {
        for prefix in ["a.b", "dir\\x", "dir/x"] {
            match Builder::new(prefix).build() {
                Ok(_) => panic!("prefix {:?} should be rejected", prefix),
                Err(err) => assert_eq!(err.code(), E_INVALIDARG, "prefix {:?}", prefix),
            }
        }
    }

    #[test]
    fn with_empty_content() {
        let stream = Builder::new(TEST_PREFIX)
            .with_content(b"")
            .build()
            .expect("create tempfile");
        let end_pos = unsafe { stream.Seek(0, STREAM_SEEK_END) }.expect("end position");
        assert_eq!(0, end_pos, "stream should be empty");
    }

    #[test]
    fn with_bytes_and_read() {
        let text = b"with_bytes_and_read";
        let stream = Builder::new(TEST_PREFIX)
            .with_content(text)
            .build()
            .expect("create tempfile");
        // Ask for one extra byte to prove the stream ends exactly after the content.
        let (buf, read_len) = read_bytes(&stream, text.len() + 1);
        assert_eq!(read_len, text.len());
        assert_eq!(text, &buf[..text.len()]);
        assert_eq!(0, buf[text.len()]);
    }

    #[test]
    fn write_and_read() {
        let text = b"write_and_read";
        let stream = Builder::new(TEST_PREFIX)
            .build()
            .expect("create tempfile");

        let mut write_len = 0_u32;
        // SAFETY: `text` is valid for `text.len()` readable bytes and `write_len` is a valid
        // out-pointer for the duration of the call.
        unsafe { stream.Write(text.as_ptr().cast(), text.len() as u32, &mut write_len) }
            .ok()
            .expect("write bytes");
        let write_len = write_len as usize;
        assert_eq!(write_len, text.len());

        unsafe { stream.Seek(0, STREAM_SEEK_SET) }.expect("seek to beginning");
        let (buf, read_len) = read_bytes(&stream, write_len + 1);
        assert_eq!(read_len, write_len);
        assert_eq!(text, &buf[..write_len]);
        assert_eq!(0, buf[write_len]);
    }
}