1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
use std::{ffi::OsStr, path::Path};

use windows::{
    core::Result,
    Win32::{
        Foundation::{ERROR_BUFFER_OVERFLOW, E_FAIL, E_INVALIDARG, E_OUTOFMEMORY, MAX_PATH, PWSTR},
        Storage::FileSystem::{
            GetTempFileNameW, GetTempPathW, FILE_ATTRIBUTE_TEMPORARY, FILE_FLAG_DELETE_ON_CLOSE,
        },
        System::Com::{
            IStream,
            StructuredStorage::{STGM_CREATE, STGM_READWRITE, STGM_SHARE_EXCLUSIVE},
            STREAM_SEEK_SET,
        },
        UI::Shell::SHCreateStreamOnFileEx,
    },
};

/// Builder for a read/write implementation of the [`windows`] crate's [`IStream`] interface
/// backed by a temp file on disk. The temp file is created with [`SHCreateStreamOnFileEx`], using
/// [`FILE_ATTRIBUTE_TEMPORARY`] and [`FILE_FLAG_DELETE_ON_CLOSE`] so it will be deleted by the OS
/// as soon as the last reference to the [`IStream`] is dropped.
///
/// The builder only borrows its inputs, so it is cheap to copy and can be reused to create
/// several independent streams.
///
/// # Example
///
/// ```
/// use tempfile_istream::Builder;
///
/// let stream = Builder::new("prefix")
///     .with_content(b"binary content")
///     .build()
///     .expect("creates the stream");
/// ```
#[derive(Debug, Clone, Copy)]
#[must_use = "a Builder does nothing until `build` is called"]
pub struct Builder<'a> {
    /// Filename prefix handed to `GetTempFileNameW`; only its first 3 characters are used.
    prefix: &'a str,
    /// Optional initial content written to the stream before `build` returns it.
    content: Option<&'a [u8]>,
}

impl<'a> Builder<'a> {
    /// Create a new [`Builder`] for an empty [`IStream`] backed by a temp file on disk with the
    /// specified filename prefix. Only the first 3 characters of the `prefix` parameter will
    /// be used in the filename, but the entire string must match a valid [`std::path::Path`]
    /// `file_stem` or the call to `build` will fail.
    ///
    /// # Example
    ///
    /// ```
    /// use windows::Win32::System::Com::STREAM_SEEK_END;
    /// use tempfile_istream::Builder;
    ///
    /// let stream = Builder::new("prefix")
    ///     .build()
    ///     .expect("creates an empty stream");
    ///
    /// let end_pos = unsafe {
    ///     stream.Seek(0, STREAM_SEEK_END)
    /// }
    /// .expect("end position");
    ///
    /// assert_eq!(0, end_pos, "stream should be empty");
    /// ```
    ///
    /// # See also
    ///
    /// Parameter
    /// [requirements](https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-gettempfilenamew#parameters)
    /// for the `prefix` argument.
    pub fn new(prefix: &'a str) -> Self {
        Self {
            prefix,
            content: None,
        }
    }

    /// Initialize the stream with a [`u8`] slice of bytes and leave the [`IStream`] cursor at the
    /// beginning of the stream so that a consumer can immediately begin reading it back.
    ///
    /// # Example
    ///
    /// ```
    /// use std::mem;
    /// use tempfile_istream::Builder;
    ///
    /// const CONTENT: &[u8] = b"binary content";
    /// const CONTENT_LEN: usize = CONTENT.len();
    ///
    /// let stream = Builder::new("prefix")
    ///     .with_content(CONTENT)
    ///     .build()
    ///     .expect("creates a stream with content");
    ///
    /// let mut buf = [0_u8; CONTENT_LEN];
    /// let mut read_len = 0;
    /// unsafe {
    ///     stream.Read(
    ///         mem::transmute(buf.as_mut_ptr()),
    ///         buf.len() as u32,
    ///         &mut read_len,
    ///     )
    /// }
    /// .expect("read bytes");
    ///
    /// assert_eq!(buf, CONTENT, "should match the initial content");
    /// ```
    pub fn with_content(self, content: &'a [u8]) -> Self {
        Self {
            content: Some(content),
            ..self
        }
    }

    /// Create the [`IStream`] backed by a temp file. This will perform parameter validation
    /// on the `prefix` argument and fail with [`E_INVALIDARG`] if it contains anything other
    /// than a valid [`std::path::Path`] `file_stem`. Only the first 3 characters of the `prefix`
    /// will be used.
    ///
    /// # Example
    ///
    /// ```
    /// use windows::Win32::System::Com::{STREAM_SEEK_CUR, STREAM_SEEK_END};
    /// use tempfile_istream::Builder;
    ///
    /// const CONTENT: &[u8] = b"binary content";
    ///
    /// let stream = Builder::new("prefix")
    ///     .with_content(CONTENT)
    ///     .build()
    ///     .expect("creates a non-empty stream");
    ///
    /// let cur_pos = unsafe {
    ///     stream.Seek(0, STREAM_SEEK_CUR)
    /// }
    /// .expect("current position");
    ///
    /// assert_eq!(0, cur_pos, "current position should be at the beginning");
    ///
    /// let end_pos = unsafe {
    ///     stream.Seek(0, STREAM_SEEK_END)
    /// }
    /// .expect("end position");
    ///
    /// assert_eq!(end_pos as usize, CONTENT.len(), "end position should match content length")
    /// ```
    ///
    /// # Errors
    ///
    /// - [`E_INVALIDARG`] if `prefix` is non-empty and is not exactly a bare file stem.
    /// - [`E_OUTOFMEMORY`] if the temp directory path is longer than `MAX_PATH - 14` characters.
    /// - [`E_FAIL`] if `GetTempPathW` or `GetTempFileNameW` fails, or if the initial content
    ///   could not be written in full.
    /// - Any error reported by [`SHCreateStreamOnFileEx`], `IStream::Write` or `IStream::Seek`.
    ///
    /// # See also
    ///
    /// Parameter
    /// [requirements](https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-gettempfilenamew#parameters)
    /// for the `prefix` argument.
    pub fn build(self) -> Result<IStream> {
        // Only accept a bare file stem: no directory components and no extension.
        if !self.prefix.is_empty()
            && Path::new(self.prefix).file_stem() != Some(OsStr::new(self.prefix))
        {
            return Err(E_INVALIDARG.into());
        }

        // `GetTempFileNameW` limits the directory to `MAX_PATH - 14` characters, so the directory
        // buffer (terminator included) is sized to enforce that limit up front. The file name
        // buffer holds a full `MAX_PATH` path plus its terminator.
        const FILE_LEN: usize = (MAX_PATH + 1) as usize;
        const DIR_LEN: usize = FILE_LEN - 14;
        let mut dir = [0_u16; DIR_LEN];
        let mut file = [0_u16; FILE_LEN];

        // SAFETY: `dir` is a writable buffer of exactly `dir.len()` UTF-16 units.
        let dir_len = unsafe { GetTempPathW(dir.len() as u32, PWSTR(dir.as_mut_ptr())) };
        let dir_len = dir_len as usize;
        if dir_len == 0 {
            // Zero is the documented failure return; `dir` holds no usable path.
            return Err(E_FAIL.into());
        }
        if dir_len > dir.len() {
            // On overflow the return value is the buffer size that would have been required.
            return Err(E_OUTOFMEMORY.into());
        }

        // SAFETY: `dir` holds the NUL-terminated temp directory, and `file` has room for
        // `MAX_PATH` characters plus a terminator, as `GetTempFileNameW` requires.
        let unique = unsafe {
            GetTempFileNameW(
                PWSTR(dir.as_mut_ptr()),
                self.prefix,
                0,
                PWSTR(file.as_mut_ptr()),
            )
        };
        // Zero is the only failure return. Any other value is the unique number embedded in the
        // file name. It can coincide numerically with an error code such as
        // `ERROR_BUFFER_OVERFLOW`, so it must not be treated as a failure. The overflow error
        // itself is reported through `GetLastError`, and the `dir` buffer size rules it out.
        if unique == 0 {
            return Err(E_FAIL.into());
        }

        // SAFETY: `file` holds the NUL-terminated path produced by `GetTempFileNameW`.
        let created = unsafe {
            SHCreateStreamOnFileEx(
                PWSTR(file.as_mut_ptr()),
                (STGM_CREATE | STGM_READWRITE | STGM_SHARE_EXCLUSIVE).0,
                (FILE_ATTRIBUTE_TEMPORARY | FILE_FLAG_DELETE_ON_CLOSE).0,
                true,
                None,
            )
        };
        let stream = match created {
            Ok(stream) => stream,
            Err(error) => {
                // With `uUnique == 0`, `GetTempFileNameW` has already created an empty file.
                // No delete-on-close handle was opened for it, so remove it here instead of
                // leaving it behind in the temp directory.
                Self::remove_temp_file(&file);
                return Err(error);
            }
        };

        if let Some(content) = self.content {
            // `IStream::Write` takes a `u32` byte count, so feed it `u32`-sized chunks instead
            // of truncating the length of oversized content with a bare `as` cast.
            for chunk in content.chunks(u32::MAX as usize) {
                // SAFETY: `chunk` is a valid, initialized slice whose length fits in a `u32`.
                let written = unsafe { stream.Write(chunk.as_ptr().cast(), chunk.len() as u32)? };
                if written as usize != chunk.len() {
                    // A short write would leave the stream holding truncated content. Dropping
                    // `stream` on return closes the file, which is then deleted on close.
                    return Err(E_FAIL.into());
                }
            }
            // Rewind so a consumer can read the content back immediately.
            // SAFETY: `stream` is a live interface pointer; seeking only moves its cursor.
            unsafe { stream.Seek(0, STREAM_SEEK_SET)? };
        }

        Ok(stream)
    }

    /// Best-effort deletion of the temp file whose NUL-terminated UTF-16 path is stored in
    /// `path`. Paths that are not valid UTF-16 are skipped, and deletion errors are ignored
    /// because the caller is already returning a more relevant error.
    fn remove_temp_file(path: &[u16]) {
        let len = path.iter().position(|&unit| unit == 0).unwrap_or(path.len());
        if let Ok(path) = String::from_utf16(&path[..len]) {
            let _ = std::fs::remove_file(path);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    const TEST_PREFIX: &str = "test";

    /// Read up to `capacity` bytes from the current position of `stream`. Returns the
    /// zero-initialized buffer and the number of bytes actually read into it.
    fn read_bytes(stream: &IStream, capacity: usize) -> (Vec<u8>, usize) {
        let mut buf = vec![0_u8; capacity];
        let mut read_len = 0_u32;
        unsafe { stream.Read(buf.as_mut_ptr().cast(), buf.len() as u32, &mut read_len) }
            .expect("read bytes");
        (buf, read_len as usize)
    }

    #[test]
    fn new_tempfile_stream() {
        Builder::new(TEST_PREFIX).build().expect("create tempfile");
    }

    #[test]
    fn rejects_prefix_that_is_not_a_file_stem() {
        for &prefix in &["bad.ext", "dir/bad"] {
            let error = Builder::new(prefix)
                .build()
                .err()
                .expect("invalid prefix should be rejected");
            assert_eq!(error.code(), E_INVALIDARG, "prefix {:?}", prefix);
        }
    }

    #[test]
    fn with_bytes_and_read() {
        let text = b"with_bytes_and_read";
        let stream = Builder::new(TEST_PREFIX)
            .with_content(text)
            .build()
            .expect("create tempfile");
        // Ask for one extra byte to prove the stream ends exactly after the content.
        let (buf, read_len) = read_bytes(&stream, text.len() + 1);
        assert_eq!(read_len, text.len());
        assert_eq!(text, &buf[0..text.len()]);
        assert_eq!(0, buf[text.len()]);
    }

    #[test]
    fn write_and_read() {
        let text = b"write_and_read";
        let stream = Builder::new(TEST_PREFIX)
            .build()
            .expect("create tempfile");
        let write_len = unsafe { stream.Write(text.as_ptr().cast(), text.len() as u32) }
            .expect("write bytes") as usize;
        assert_eq!(write_len, text.len());
        unsafe { stream.Seek(0, STREAM_SEEK_SET) }.expect("seek to beginning");
        // Ask for one extra byte to prove the stream ends exactly after the written data.
        let (buf, read_len) = read_bytes(&stream, write_len + 1);
        assert_eq!(read_len, write_len);
        assert_eq!(text, &buf[0..write_len]);
        assert_eq!(0, buf[write_len]);
    }
}