1use core::{
2 cmp,
3 io::{BorrowedBuf, BorrowedCursor},
4};
5
6use crate::{BufRead, Error, IoBuf, Read, Result, Seek, SeekFrom};
7
/// Reader adapter that limits how many bytes can be read from an inner reader.
///
/// `limit` is the number of bytes that may still be read. `len` is the total
/// size of the current window: the limit given at construction or to the last
/// `set_limit` call. Keeping both lets `Take` report a position
/// (`len - limit`) and support seeking within the window.
#[derive(Debug)]
pub struct Take<T> {
    // The wrapped reader.
    inner: T,
    // Total length of the window; only changed by `new` / `set_limit`.
    len: u64,
    // Bytes remaining before EOF is reported; always `<= len`.
    limit: u64,
}
22
23impl<T> Take<T> {
24 pub(crate) fn new(inner: T, limit: u64) -> Self {
25 Take {
26 inner,
27 len: limit,
28 limit,
29 }
30 }
31
32 pub fn limit(&self) -> u64 {
35 self.limit
36 }
37
38 pub fn position(&self) -> u64 {
40 self.len - self.limit
41 }
42
43 pub fn set_limit(&mut self, limit: u64) {
48 self.len = limit;
49 self.limit = limit;
50 }
51
52 pub fn into_inner(self) -> T {
54 self.inner
55 }
56
57 pub fn get_ref(&self) -> &T {
63 &self.inner
64 }
65
66 pub fn get_mut(&mut self) -> &mut T {
72 &mut self.inner
73 }
74}
75
impl<T: Read> Read for Take<T> {
    /// Reads at most `self.limit` bytes into `buf`, decrementing the limit by
    /// the number of bytes actually read.
    ///
    /// Returns `Ok(0)` once the limit is exhausted, without calling the inner reader.
    ///
    /// # Panics
    ///
    /// Panics if the inner reader returns a byte count greater than `self.limit`,
    /// which would otherwise underflow the limit.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        // Once the limit is exhausted, report EOF without touching the inner
        // reader (an inner read could still block or have side effects).
        if self.limit == 0 {
            return Ok(0);
        }

        // Clamp the slice handed to the inner reader to the remaining limit.
        // The minimum is <= buf.len(), so the cast back to usize is lossless.
        let max = cmp::min(buf.len() as u64, self.limit) as usize;
        let n = self.inner.read(&mut buf[..max])?;
        assert!(n as u64 <= self.limit, "number of read bytes exceeds limit");
        self.limit -= n as u64;
        Ok(n)
    }

    /// Reads into `buf` without letting the inner reader fill more than
    /// `self.limit` bytes, decrementing the limit by the number of bytes filled.
    ///
    /// If the cursor has more capacity than the remaining limit, the inner
    /// reader gets a temporary `BorrowedBuf` over only the first `limit` bytes
    /// of the cursor's unfilled region. Otherwise the cursor is passed through
    /// directly.
    ///
    /// In both paths the limit is updated before the inner result is returned,
    /// so bytes filled before an error are still accounted for.
    ///
    /// NOTE(review): `borrowedbuf_init` is presumably a build-script cfg for
    /// toolchains whose `BorrowedBuf` still tracks initialization; confirm.
    fn read_buf(&mut self, mut buf: BorrowedCursor<'_>) -> Result<()> {
        // Same as `read`: don't call into the inner reader at EOF.
        if self.limit == 0 {
            return Ok(());
        }

        if self.limit < buf.capacity() as u64 {
            // `self.limit` is smaller than a `usize` capacity, so it fits in `usize`.
            let limit = self.limit as usize;

            // Carry over how much of the window is already initialized.
            #[cfg(borrowedbuf_init)]
            let extra_init = cmp::min(limit, buf.init_mut().len());

            // SAFETY: `ibuf` is only used to build `sliced_buf` below. A
            // `BorrowedBuf` never writes uninitialized bytes into its storage,
            // so `buf`'s initialization invariants are preserved.
            let ibuf = unsafe { &mut buf.as_mut()[..limit] };

            let mut sliced_buf: BorrowedBuf<'_> = ibuf.into();

            #[cfg(borrowedbuf_init)]
            unsafe {
                // SAFETY: the first `extra_init` bytes of `ibuf` were reported
                // as initialized by the original cursor.
                sliced_buf.set_init(extra_init);
            }

            let mut cursor = sliced_buf.unfilled();
            let result = self.inner.read_buf(cursor.reborrow());

            #[cfg(borrowedbuf_init)]
            let new_init = cursor.init_mut().len();
            let filled = sliced_buf.len();

            // `cursor`, `sliced_buf` and `ibuf` are not used past this point,
            // which ends their borrow of `buf` so it can be advanced below.
            #[cfg(borrowedbuf_init)]
            unsafe {
                // SAFETY: the first `filled` bytes were filled (and therefore
                // initialized) by the inner reader through `sliced_buf`.
                buf.advance_unchecked(filled);
                // SAFETY: after the advance, `buf`'s unfilled region starts where
                // `cursor`'s did, and `cursor` reported `new_init` initialized bytes.
                buf.set_init(new_init);
            }
            #[cfg(not(borrowedbuf_init))]
            unsafe {
                // SAFETY: the first `filled` bytes were filled (and therefore
                // initialized) by the inner reader through `sliced_buf`.
                buf.advance(filled);
            }

            self.limit -= filled as u64;

            result
        } else {
            // The whole cursor fits within the limit: pass it straight through
            // and charge the limit with however many bytes were written.
            let written = buf.written();
            let result = self.inner.read_buf(buf.reborrow());
            self.limit -= (buf.written() - written) as u64;
            result
        }
    }
}
147
148impl<T: BufRead> BufRead for Take<T> {
149 fn fill_buf(&mut self) -> Result<&[u8]> {
150 if self.limit == 0 {
152 return Ok(&[]);
153 }
154
155 let buf = self.inner.fill_buf()?;
156 let cap = cmp::min(buf.len() as u64, self.limit) as usize;
157 Ok(&buf[..cap])
158 }
159
160 fn consume(&mut self, amt: usize) {
161 let amt = cmp::min(amt as u64, self.limit) as usize;
163 self.limit -= amt as u64;
164 self.inner.consume(amt);
165 }
166}
167
168impl<T: Seek> Seek for Take<T> {
169 fn seek(&mut self, pos: SeekFrom) -> Result<u64> {
170 let new_position = match pos {
171 SeekFrom::Start(v) => Some(v),
172 SeekFrom::Current(v) => self.position().checked_add_signed(v),
173 SeekFrom::End(v) => self.len.checked_add_signed(v),
174 };
175 let new_position = match new_position {
176 Some(v) if v <= self.len => v,
177 _ => return Err(Error::InvalidInput),
178 };
179 while new_position != self.position() {
180 if let Some(offset) = new_position.checked_signed_diff(self.position()) {
181 self.inner.seek_relative(offset)?;
182 self.limit = self.limit.wrapping_sub(offset as u64);
183 break;
184 }
185 let offset = if new_position > self.position() {
186 i64::MAX
187 } else {
188 i64::MIN
189 };
190 self.inner.seek_relative(offset)?;
191 self.limit = self.limit.wrapping_sub(offset as u64);
192 }
193 Ok(new_position)
194 }
195
196 fn stream_len(&mut self) -> Result<u64> {
197 Ok(self.len)
198 }
199
200 fn stream_position(&mut self) -> Result<u64> {
201 Ok(self.position())
202 }
203
204 fn seek_relative(&mut self, offset: i64) -> Result<()> {
205 if self
206 .position()
207 .checked_add_signed(offset)
208 .is_none_or(|p| p > self.len)
209 {
210 return Err(Error::InvalidInput);
211 }
212 self.inner.seek_relative(offset)?;
213 self.limit = self.limit.wrapping_sub(offset as u64);
214 Ok(())
215 }
216}
217
218impl<T: IoBuf> IoBuf for Take<T> {
219 fn remaining(&self) -> usize {
220 cmp::min(self.inner.remaining(), self.limit as usize)
221 }
222}