1use super::*;
2use std::cell::RefCell;
3use tracing::{trace, trace_span};
4
/// Read/write handle for a single stream.
///
/// The writer borrows the stream's bookkeeping (size, page list) and the shared page
/// allocator mutably, so changes made through it are visible to the owner of those
/// values once the writer is dropped. It also carries a cursor (`pos`) that is used by
/// the `std::io::Seek`, `std::io::Read` and `std::io::Write` implementations.
pub struct StreamWriter<'a, F> {
    /// Identifier of the stream; forwarded to the read path.
    /// NOTE(review): presumably the stream's index in the stream directory — confirm.
    pub(super) stream: u32,

    /// The underlying storage that stream pages are read from and written to.
    pub(super) file: &'a F,

    /// Current size of the stream, in bytes. May hold `NIL_STREAM_SIZE`, which
    /// `len()` reports as zero and `set_len()` normalizes to zero.
    pub(super) size: &'a mut u32,

    /// Allocator used to obtain the page size, scratch page buffers and new pages,
    /// and to record pages released when the stream shrinks.
    pub(super) page_allocator: &'a mut PageAllocator,

    /// Pages that make up the stream, indexed by stream page number.
    pub(super) pages: &'a mut Vec<Page>,

    /// Cursor used by the `Seek` / `Read` / `Write` implementations.
    pub(super) pos: u64,
}
28
29impl<'a, F> StreamWriter<'a, F> {
30 pub fn len(&self) -> u32 {
32 if *self.size == NIL_STREAM_SIZE {
33 0
34 } else {
35 *self.size
36 }
37 }
38
39 pub fn is_empty(&self) -> bool {
41 *self.size == 0 || *self.size == NIL_STREAM_SIZE
42 }
43
44 pub fn set_contents(&mut self, data: &[u8]) -> std::io::Result<()>
46 where
47 F: ReadAt + WriteAt,
48 {
49 let _span = trace_span!("StreamWriter::set_contents").entered();
50
51 if data.len() as u64 >= NIL_STREAM_SIZE as u64 {
52 return Err(std::io::ErrorKind::InvalidInput.into());
53 }
54 let data_len = data.len() as u32;
55
56 if *self.size > data_len {
59 self.set_len(data_len)?;
60 }
61
62 self.write_core(data, 0)?;
63 self.set_len(data.len() as u32)?;
64 Ok(())
65 }
66
67 pub fn write_at_mut(&mut self, buf: &[u8], offset: u64) -> std::io::Result<usize>
72 where
73 F: ReadAt + WriteAt,
74 {
75 let _span = trace_span!("StreamWriter::write_at_mut").entered();
76
77 self.write_core(buf, offset)?;
78 Ok(buf.len())
79 }
80
81 pub fn write_all_at_mut(&mut self, buf: &[u8], offset: u64) -> std::io::Result<()>
86 where
87 F: ReadAt + WriteAt,
88 {
89 self.write_core(buf, offset)
90 }
91
92 pub fn set_len(&mut self, mut len: u32) -> std::io::Result<()>
104 where
105 F: ReadAt + WriteAt,
106 {
107 use std::cmp::Ordering;
108
109 let _span = trace_span!("StreamWriter::set_len").entered();
110 trace!(new_len = len);
111
112 if *self.size == NIL_STREAM_SIZE {
113 trace!("stream changes from nil to non-nil");
114 *self.size = 0;
115 }
116
117 let page_size = self.page_allocator.page_size;
118
119 match Ord::cmp(&len, self.size) {
120 Ordering::Equal => {
121 trace!(len = self.size, "no change in stream size");
122 }
123
124 Ordering::Less => {
125 trace!(old_len = self.size, new_len = len, "reducing stream size");
127
128 let num_pages_old = num_pages_for_stream_size(*self.size, page_size) as usize;
129 let num_pages_new = num_pages_for_stream_size(len, page_size) as usize;
130 assert!(num_pages_new <= num_pages_old);
131
132 for &page in self.pages[num_pages_new..num_pages_old].iter() {
133 self.page_allocator.fpm_freed.set(page as usize, true);
134 }
135
136 self.pages.truncate(num_pages_new);
137 *self.size = len;
138 }
139
140 Ordering::Greater => {
141 trace!(
143 old_len = self.size,
144 new_len = len,
145 "increasing stream size (zero-filling)"
146 );
147
148 let end_phase = offset_within_page(*self.size, page_size);
149 if end_phase != 0 {
150 let total_zx_bytes = len - *self.size;
152
153 let end_spage = *self.size / page_size;
155 let num_zx_bytes = (u32::from(page_size) - end_phase).min(total_zx_bytes);
156
157 let mut page_buffer = self.page_allocator.alloc_page_buffer();
158 self.read_page(end_spage, &mut page_buffer)?;
159 page_buffer[end_phase as usize..].fill(0);
160 self.cow_page_and_write(end_spage, &page_buffer)?;
161
162 *self.size += num_zx_bytes;
163
164 len -= num_zx_bytes;
165 if len == 0 {
166 return Ok(());
168 }
169 }
170
171 assert!(page_size.is_aligned(*self.size));
174
175 let mut page_buffer = self.page_allocator.alloc_page_buffer();
176 page_buffer.fill(0);
177
178 assert!(page_size.is_aligned(*self.size));
179
180 let num_zx_pages_wanted = (len - *self.size).div_round_up(page_size);
182
183 let (first_page, run_len) = self.page_allocator.alloc_pages(num_zx_pages_wanted);
184 assert!(run_len > 0);
185
186 let old_num_pages = self.pages.len() as u32;
187
188 for i in 0..run_len {
189 self.pages.push(first_page + i);
190 }
191
192 *self.size += len;
194
195 for i in 0..run_len {
200 self.write_page(old_num_pages + i, &page_buffer)?;
201 }
202 }
203 }
204
205 Ok(())
206 }
207
208 pub fn into_random(self) -> RandomStreamWriter<'a, F> {
210 RandomStreamWriter {
211 cell: RefCell::new(self),
212 }
213 }
214}
215
216impl<'a, F: ReadAt> std::io::Seek for StreamWriter<'a, F> {
217 fn seek(&mut self, from: SeekFrom) -> std::io::Result<u64> {
218 let new_pos: i64 = match from {
219 SeekFrom::Start(offset) => offset as i64,
220 SeekFrom::End(signed_offset) => signed_offset + *self.size as i64,
221 SeekFrom::Current(signed_offset) => self.pos as i64 + signed_offset,
222 };
223
224 if new_pos < 0 {
225 return Err(std::io::ErrorKind::InvalidInput.into());
226 }
227
228 self.pos = new_pos as u64;
229 Ok(self.pos)
230 }
231}
232
233impl<'a, F: ReadAt> std::io::Read for StreamWriter<'a, F> {
234 fn read(&mut self, dst: &mut [u8]) -> std::io::Result<usize> {
235 let (n, new_pos) = super::read::read_stream_core(
236 self.stream,
237 self.file,
238 self.page_allocator.page_size,
239 *self.size,
240 self.pages,
241 self.pos,
242 dst,
243 )?;
244 self.pos = new_pos;
245 Ok(n)
246 }
247}
248
249impl<'a, F: ReadAt + WriteAt> std::io::Write for StreamWriter<'a, F> {
250 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
251 self.write_core(buf, self.pos)?;
252 self.pos += buf.len() as u64;
253 Ok(buf.len())
254 }
255
256 fn flush(&mut self) -> std::io::Result<()> {
257 Ok(())
258 }
259}
260
/// Adapter that exposes a [`StreamWriter`] through the positional `ReadAt` / `WriteAt`
/// traits.
///
/// Those traits take `&self`, so the writer is kept in a `RefCell` and borrowed for the
/// duration of each call. Values of this type are created by `StreamWriter::into_random`.
pub struct RandomStreamWriter<'a, F> {
    /// The wrapped writer; borrowed shared for reads and mutably for writes.
    cell: RefCell<StreamWriter<'a, F>>,
}
264
265impl<'a, F: ReadAt> ReadAt for RandomStreamWriter<'a, F> {
266 fn read_exact_at(&self, buf: &mut [u8], offset: u64) -> std::io::Result<()> {
267 let sw = self.cell.borrow();
268 let (n, _new_pos) = super::read::read_stream_core(
269 sw.stream,
270 &sw.file,
271 sw.page_allocator.page_size,
272 *sw.size,
273 sw.pages,
274 offset,
275 buf,
276 )?;
277 if n != buf.len() {
278 return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof));
279 }
280 Ok(())
281 }
282
283 fn read_at(&self, buf: &mut [u8], offset: u64) -> std::io::Result<usize> {
284 let sw = self.cell.borrow();
285 let (n, _new_pos) = super::read::read_stream_core(
286 sw.stream,
287 &sw.file,
288 sw.page_allocator.page_size,
289 *sw.size,
290 sw.pages,
291 offset,
292 buf,
293 )?;
294 Ok(n)
295 }
296}
297
298impl<'a, F: ReadAt + WriteAt> WriteAt for RandomStreamWriter<'a, F> {
299 fn write_at(&self, buf: &[u8], offset: u64) -> std::io::Result<usize> {
300 let mut sw = self.cell.borrow_mut();
301 sw.write_core(buf, offset)?;
302 Ok(buf.len())
303 }
304
305 fn write_all_at(&self, buf: &[u8], offset: u64) -> std::io::Result<()> {
306 let mut sw = self.cell.borrow_mut();
307 sw.write_core(buf, offset)
308 }
309}