data_streams/
std_io.rs

1// SPDX-License-Identifier: Apache-2.0
2
3#![cfg(feature = "std")]
4
5#[cfg(all(feature = "alloc", feature = "utf8"))]
6use alloc::string::String;
7use std::io::{BufRead, BufReader, BufWriter, Cursor, Empty, ErrorKind, Read, Repeat, Seek, Sink, Take, Write};
8use crate::{
9	BufferAccess,
10	DataSink,
11	DataSource,
12	Error,
13	Result,
14	source::default_skip,
15};
16use crate::markers::source::{InfiniteSource, SourceSize};
17
18#[cfg(any(unix, windows, target_os = "wasi"))]
19// Safety: the size is read from the file system metadata.
20unsafe impl SourceSize for &std::fs::File {
21	// Todo: lower bound?
22	
23	fn upper_bound(&self) -> Option<u64> {
24		#[cfg(unix)]
25		let size = std::os::unix::fs::MetadataExt::size;
26		#[cfg(windows)]
27		let size = std::os::windows::fs::MetadataExt::file_size;
28		#[cfg(target_os = "wasi")]
29		let size = std::os::wasi::fs::MetadataExt::size;
30		let pos = (&mut &**self).stream_position().ok()?;
31		self.metadata()
32			.ok()
33			.as_ref()
34			.map(size)
35			.map(|s| s - pos)
36	}
37}
38
39#[cfg(any(unix, windows, target_os = "wasi"))]
40// Safety: the size is read from the file system metadata.
41unsafe impl SourceSize for std::fs::File {
42	fn upper_bound(&self) -> Option<u64> {
43		(&self).upper_bound()
44	}
45}
46
47impl<R: Read + ?Sized> DataSource for BufReader<R> {
48	#[cfg(not(feature = "unstable_specialization"))]
49	fn available(&self) -> usize { self.buffer_count() }
50
51	#[cfg(not(feature = "unstable_specialization"))]
52	fn request(&mut self, count: usize) -> Result<bool> {
53		crate::source::default_request(self, count)
54	}
55
56	fn skip(&mut self, count: usize) -> Result<usize> {
57		Ok(buf_read_skip(self, count))
58	}
59
60	fn read_bytes<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a [u8]> {
61		buf_read_bytes(self, buf)
62	}
63
64	fn read_exact_bytes<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a [u8]> {
65		buf_read_exact_bytes(self, buf)
66	}
67}
68
69impl<R: Read + ?Sized> BufferAccess for BufReader<R> {
70	fn buffer_capacity(&self) -> usize { self.capacity() }
71
72	fn buffer(&self) -> &[u8] { self.buffer() }
73
74	fn fill_buffer(&mut self) -> Result<&[u8]> {
75		Ok(self.fill_buf()?)
76	}
77
78	fn drain_buffer(&mut self, count: usize) {
79		self.consume(count);
80	}
81}
82
83// Safety: the bounds are correct if those returned by `R` are correct.
84unsafe impl<R: Read + SourceSize + ?Sized> SourceSize for BufReader<R> {
85	// Todo: include buffer size?
86	
87	fn lower_bound(&self) -> u64 {
88		self.get_ref().lower_bound()
89	}
90
91	fn upper_bound(&self) -> Option<u64> {
92		self.get_ref().upper_bound()
93	}
94}
95
96impl<W: Write + ?Sized> DataSink for BufWriter<W> {
97	fn write_bytes(&mut self, buf: &[u8]) -> Result {
98		self.write_all(buf)?;
99		Ok(())
100	}
101}
102
103impl<T: AsRef<[u8]>> DataSource for Cursor<T> {
104	#[cfg(not(feature = "unstable_specialization"))]
105	fn available(&self) -> usize { self.buffer_count() }
106
107	fn request(&mut self, count: usize) -> Result<bool> {
108		Ok(self.available() >= count)
109	}
110
111	fn skip(&mut self, mut count: usize) -> Result<usize> {
112		count = count.min(self.available());
113		self.consume(count);
114		Ok(count)
115	}
116
117	fn read_bytes<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a [u8]> {
118		let count = self.read(buf)?;
119		Ok(&buf[..count])
120	}
121
122	fn read_exact_bytes<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a [u8]> {
123		buf_read_exact_bytes(self, buf)
124	}
125}
126
127impl<T: AsRef<[u8]>> BufferAccess for Cursor<T> {
128	fn buffer_capacity(&self) -> usize { cursor_as_slice(self).len() }
129
130	fn buffer_count(&self) -> usize {
131		self.buffer_capacity()
132			.min(self.position() as usize)
133	}
134
135	fn buffer(&self) -> &[u8] {
136		// See Cursor::fill_buf and Cursor::split
137		let slice = cursor_as_slice(self);
138		let start = self.buffer_count();
139		&slice[start..]
140	}
141
142	fn fill_buffer(&mut self) -> Result<&[u8]> {
143		Ok((*self).buffer()) // Nothing to read
144	}
145
146	fn drain_buffer(&mut self, count: usize) {
147		self.consume(count);
148	}
149}
150
151// Safety: the size is the buffer count.
152unsafe impl<T: AsRef<[u8]>> SourceSize for Cursor<T> {
153	fn lower_bound(&self) -> u64 { self.buffer_count() as u64 }
154	fn upper_bound(&self) -> Option<u64> { Some(self.buffer_count() as u64) }
155}
156
157impl<T> DataSink for Cursor<T> where Self: Write {
158	fn write_bytes(&mut self, buf: &[u8]) -> Result {
159		let count = self.write(buf)?;
160		if count < buf.len() {
161			let remaining = buf.len() - count;
162			Err(Error::Overflow { remaining })
163		} else {
164			Ok(())
165		}
166	}
167}
168
/// Borrows the entire byte slice underlying `cursor`, regardless of the cursor's position.
fn cursor_as_slice<T: AsRef<[u8]>>(cursor: &Cursor<T>) -> &[u8] {
	let inner: &T = cursor.get_ref();
	inner.as_ref()
}
172
173impl<T: BufferAccess + BufRead> DataSource for Take<T> {
174	#[cfg(not(feature = "unstable_specialization"))]
175	fn available(&self) -> usize { self.buffer_count() }
176
177	#[cfg(not(feature = "unstable_specialization"))]
178	fn request(&mut self, count: usize) -> Result<bool> {
179		crate::source::default_request(self, count)
180	}
181
182	fn skip(&mut self, count: usize) -> Result<usize> {
183		Ok(buf_read_skip(self, count))
184	}
185
186	fn read_bytes<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a [u8]> {
187		buf_read_bytes(self, buf)
188	}
189
190	fn read_exact_bytes<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a [u8]> {
191		buf_read_exact_bytes(self, buf)
192	}
193}
194
195impl<T: BufferAccess + BufRead> BufferAccess for Take<T> {
196	fn buffer_capacity(&self) -> usize { self.get_ref().buffer_capacity() }
197
198	fn buffer_count(&self) -> usize {
199		self.get_ref()
200			.buffer_count()
201			.min(self.limit() as usize)
202	}
203	
204	fn buffer(&self) -> &[u8] {
205		let buf = self.get_ref().buffer();
206		let len = self.buffer_count();
207		&buf[..len]
208	}
209
210	fn fill_buffer(&mut self) -> Result<&[u8]> {
211		Ok(self.fill_buf()?)
212	}
213
214	fn drain_buffer(&mut self, count: usize) {
215		self.consume(count);
216	}
217}
218
219// Safety: the upper bound is correct if `Take` behaves correctly (produces no more bytes than its
220//  limit).
221unsafe impl<T> SourceSize for Take<T> {
222	fn upper_bound(&self) -> Option<u64> {
223		Some(self.limit())
224	}
225}
226
/// Implements a trait for both a stream type and a shared reference to it, repeating the same
/// item definitions in each impl.
macro_rules! fixed_stream_impl {
	(impl $tr:ident for $ty:ident { $($body:item)+ }) => {
		impl $tr for $ty { $($body)+ }

		impl $tr for &$ty { $($body)+ }
	};
}
240
241fixed_stream_impl! {
242impl DataSource for Empty {
243	fn available(&self) -> usize { 0 }
244
245	fn request(&mut self, _: usize) -> Result<bool> {
246		Ok(false)
247	}
248
249	fn skip(&mut self, _: usize) -> Result<usize> {
250		Ok(0)
251	}
252
253	fn read_bytes<'a>(&mut self, _: &'a mut [u8]) -> Result<&'a [u8]> {
254		Ok(&[])
255	}
256
257	#[cfg(feature = "utf8")]
258	fn read_utf8<'a>(&mut self, _: &'a mut [u8]) -> Result<&'a str> {
259		Ok("")
260	}
261}
262}
263
264// Safety: `Empty` produces no bytes by definition.
265unsafe impl SourceSize for Empty {
266	fn upper_bound(&self) -> Option<u64> { Some(0) }
267}
268// Safety: `Empty` produces no bytes by definition.
269unsafe impl SourceSize for &Empty {
270	fn upper_bound(&self) -> Option<u64> { Some(0) }
271}
272
273impl DataSink for Empty {
274	fn write_bytes(&mut self, _: &[u8]) -> Result { Ok(()) }
275}
276impl DataSink for &Empty {
277	fn write_bytes(&mut self, _: &[u8]) -> Result { Ok(()) }
278}
279
280impl DataSink for Sink {
281	fn write_bytes(&mut self, _: &[u8]) -> Result { Ok(()) }
282}
283impl DataSink for &Sink {
284	fn write_bytes(&mut self, _: &[u8]) -> Result { Ok(()) }
285}
286
/// Reads from `io::Repeat`, a source that repeats a single byte without end: every request
/// succeeds, and every read fills the whole buffer with that byte.
impl DataSource for Repeat {
	/// Reports unlimited availability; the source never ends.
	fn available(&self) -> usize { usize::MAX }

	/// Always succeeds, since any number of bytes can be produced on demand.
	fn request(&mut self, _: usize) -> Result<bool> {
		Ok(true)
	}

	/// Nothing needs to be discarded, so the full `count` is reported as skipped.
	fn skip(&mut self, count: usize) -> Result<usize> {
		Ok(count)
	}

	/// Fills all of `buf` with the repeated byte and returns it.
	fn read_bytes<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a [u8]> {
		// Safety: Repeat doesn't return an error.
		//  `io::repeat` is documented to succeed on every read by filling the entire buffer, which
		//  is also why returning all of `buf` below is correct.
		unsafe {
			Read::read(self, buf).unwrap_unchecked();
		}
		Ok(buf)
	}

	/// Reads into `buf` and returns it as a string. Every byte equals the repeated byte, so the
	/// result is valid UTF-8 exactly when that byte is ASCII; otherwise a UTF-8 error is returned.
	#[cfg(feature = "utf8")]
	fn read_utf8<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a str> {
		// Safety: Repeat doesn't return an error.
		match unsafe { self.read_bytes(buf).unwrap_unchecked() } {
			[] => Ok(""),
			// All bytes equal the first (see `read_bytes`), so checking one byte covers the slice.
			bytes @ [byte, ..] if byte.is_ascii() => Ok(
				// Safety: the byte is valid ASCII, which is valid UTF-8.
				unsafe {
					core::str::from_utf8_unchecked(bytes)
				}
			),
			bytes =>
				// A run of one repeated non-ASCII byte is never valid UTF-8: a lead byte would need
				//  a continuation byte after it, and a continuation byte cannot start a sequence.
				// Use from_utf8 to convert the byte into a UTF-8 error.
				// Safety: Unwrap is safe because non-ASCII bytes are not valid UTF-8.
				// NOTE(review): for some lead bytes, a one-byte slice may be reported as an
				//  incomplete sequence rather than an invalid one — confirm that error shape is
				//  intended.
				Err(unsafe {
					simdutf8::compat::from_utf8(&bytes[..1]).unwrap_err_unchecked().into()
				})
		}
	}

	/// Reads into `buf` and returns it as ASCII characters if the repeated byte is ASCII;
	/// otherwise reports the first byte as invalid.
	#[cfg(feature = "unstable_ascii_char")]
	fn read_ascii<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a [core::ascii::Char]> {
		// Safety: Repeat doesn't return an error.
		match unsafe { self.read_bytes(buf).unwrap_unchecked() } {
			[] => Ok(&[]),
			// All bytes equal the first (see `read_bytes`), so checking one byte covers the slice.
			bytes @ [byte, ..] if byte.is_ascii() => Ok(
				// Safety: the byte is valid ASCII.
				unsafe {
					bytes.as_ascii_unchecked()
				}
			),
			// NOTE(review): assumes `invalid_ascii` takes (byte, position, length) — confirm
			//  against its definition.
			bytes @ &[byte, ..] => Err(Error::invalid_ascii(byte, 0, bytes.len()))
		}
	}
}
341
342// Safety: the source repeats one byte forever.
343unsafe impl InfiniteSource for Repeat { }
344
345fn buf_read_skip(source: &mut (impl BufferAccess + ?Sized), count: usize) -> usize {
346	let mut skip_count = 0;
347	while skip_count < count {
348		let cur_skip_count = default_skip(&mut *source, count);
349		skip_count += cur_skip_count;
350
351		if cur_skip_count == 0 {
352			break
353		}
354	}
355	skip_count
356}
357
358fn buf_read_bytes<'a>(source: &mut (impl Read + ?Sized), buf: &'a mut [u8]) -> Result<&'a [u8]> {
359	use ErrorKind::Interrupted;
360
361	let mut count = 0;
362	loop {
363		match source.read(buf) {
364			Ok(0) => break Ok(&buf[..count]),
365			Ok(cur_count) => count += cur_count,
366			Err(err) if err.kind() == Interrupted => { }
367			Err(err) => break Err(err.into())
368		}
369	}
370}
371
372fn buf_read_exact_bytes<'a>(source: &mut (impl Read + ?Sized), buf: &'a mut [u8]) -> Result<&'a [u8]> {
373	match source.read_exact(&mut *buf) {
374		Ok(()) => Ok(buf),
375		Err(error) if error.kind() == ErrorKind::UnexpectedEof =>
376			Err(Error::End { required_count: buf.len() }),
377		Err(error) => Err(error.into())
378	}
379}
380
/// Reads all remaining bytes from `source` and appends them to `buf` through
/// `crate::source::append_utf8`, which is expected to validate the new bytes as UTF-8.
/// NOTE(review): whether the returned `&str` is the whole string or only the appended part is
///  determined by `append_utf8` — confirm there.
#[cfg(all(feature = "alloc", feature = "utf8"))]
// No caller is visible in this file, hence the `dead_code` allowance.
#[allow(dead_code)]
fn buf_read_utf8_to_end<'a>(source: &mut (impl Read + ?Sized), buf: &'a mut String) -> Result<&'a str> {
	// Safety: this function only modifies the string's bytes if the new bytes are found to be
	//  valid UTF-8.
	unsafe {
		crate::source::append_utf8(buf, |b|
			// `read_to_end` appends to `b`; I/O errors are converted into the crate error via `?`.
			Ok(source.read_to_end(b)?)
		)
	}
}