1extern crate alloc;
2
3use alloc::vec::Vec;
4use corevm_host::OutputStream;
5use jam_types::SEGMENT_LEN;
6
7#[derive(Default, Debug)]
13pub struct OutputBuffers {
14 buffers: [Vec<u8>; OutputStream::COUNT],
15}
16
17impl OutputBuffers {
18 #[must_use]
22 pub fn pre_allocate(&mut self, i: OutputStream, len: usize) -> &mut [u8] {
23 if len == 0 {
24 return &mut [];
25 }
26 let buf = self.get_mut(i);
27 let offset = buf.len();
28 buf.resize(offset + len, 0_u8);
29 &mut buf[offset..]
30 }
31
32 pub fn export_segments<F, E>(&mut self, mut export: F) -> Result<(), E>
39 where
40 F: FnMut(&RawSegment) -> Result<(), E>,
41 {
42 let mut output = SegmentedOutput::new();
43 for i in OutputStream::ALL {
44 output.write(self.get(i), &mut export)?;
45 }
46 output.finish(&mut export)?;
47 Ok(())
48 }
49
50 pub fn stream_len(&self) -> [u32; OutputStream::COUNT] {
52 [
53 self.buffers[0].len() as u32,
54 self.buffers[1].len() as u32,
55 self.buffers[2].len() as u32,
56 self.buffers[3].len() as u32,
57 ]
58 }
59
60 pub fn segment_count(&self) -> usize {
62 self.total_len().div_ceil(SEGMENT_LEN)
63 }
64
65 pub fn segment_count_after(&self, i: OutputStream, len: usize) -> Option<usize> {
68 self.total_len_after(i, len).map(|len| len.div_ceil(SEGMENT_LEN))
69 }
70
71 pub fn total_len(&self) -> usize {
75 self.buffers.iter().map(|buf| buf.len()).sum()
76 }
77
78 fn total_len_after(&self, i: OutputStream, len: usize) -> Option<usize> {
80 let buf = self.get(i);
81 let old_len = buf.len();
82 let new_len = old_len.checked_add(len)?;
83 (self.total_len() - old_len).checked_add(new_len)
84 }
85
86 fn get(&self, i: OutputStream) -> &[u8] {
88 &self.buffers[i as usize - 1]
89 }
90
91 fn get_mut(&mut self, i: OutputStream) -> &mut Vec<u8> {
93 &mut self.buffers[i as usize - 1]
94 }
95}
96
97struct SegmentedOutput {
104 segment: RawSegment,
105 offset: usize,
106}
107
108impl SegmentedOutput {
109 fn new() -> Self {
110 Self { segment: [0; SEGMENT_LEN], offset: 0 }
111 }
112
113 fn write<F, E>(&mut self, mut bytes: &[u8], export: &mut F) -> Result<(), E>
114 where
115 F: FnMut(&RawSegment) -> Result<(), E>,
116 {
117 while !bytes.is_empty() {
118 let n = (SEGMENT_LEN - self.offset).min(bytes.len());
119 self.segment[self.offset..self.offset + n].copy_from_slice(&bytes[..n]);
120 self.offset += n;
121 bytes = &bytes[n..];
122 if self.offset == SEGMENT_LEN {
123 export(&self.segment)?;
124 self.offset = 0;
125 }
126 }
127 Ok(())
128 }
129
130 fn finish<F, E>(&mut self, export: &mut F) -> Result<(), E>
131 where
132 F: FnMut(&RawSegment) -> Result<(), E>,
133 {
134 if self.offset == 0 {
135 return Ok(());
137 }
138 for i in self.offset..SEGMENT_LEN {
140 self.segment[i] = 0;
141 }
142 export(&self.segment)?;
143 self.offset = 0;
144 Ok(())
145 }
146}
147
148type RawSegment = [u8; SEGMENT_LEN];
149
#[cfg(test)]
mod tests {
    use super::*;
    use core::convert::Infallible;
    use jam_types::Segment;
    use rand::{seq::IndexedRandom, Rng, RngCore};

    /// Fills every stream with random data and exports it as segments. Feeds those segments
    /// to `corevm_host::OutputBuffers::append_from_segments` and checks that each stream
    /// comes back byte-for-byte.
    ///
    /// Also checks that `stream_len` reports exactly what was written and that
    /// `segment_count` matches the number of segments actually exported.
    #[test]
    fn output_buffers_work() {
        let mut rng = rand::rng();
        for _ in 0..1000 {
            let mut guest_buffers = OutputBuffers::default();
            let mut stream_len = [0; OutputStream::COUNT];
            for stream in OutputStream::ALL {
                let num_bytes: u32 = rng.random_range(0..=100);
                let slice = guest_buffers.pre_allocate(stream, num_bytes as usize);
                rng.fill_bytes(slice);
                stream_len[stream as usize - 1] = num_bytes;
            }
            // The per-stream lengths handed to the host must match what was written.
            assert_eq!(guest_buffers.stream_len(), stream_len);
            let mut segments: Vec<Segment> = Vec::new();
            guest_buffers
                .export_segments(|s| {
                    segments.push(s.to_vec().try_into().unwrap());
                    Ok::<(), Infallible>(())
                })
                .unwrap();
            assert_eq!(segments.len(), guest_buffers.segment_count());
            let mut host_buffers = corevm_host::OutputBuffers::default();
            host_buffers.append_from_segments(&segments[..], &stream_len);
            for i in OutputStream::ALL {
                let guest_bytes = guest_buffers.get(i);
                assert_eq!(guest_bytes, host_buffers.get(i));
            }
        }
    }

    /// Checks that `total_len_after` and `segment_count_after`, queried before each append,
    /// agree with the total length and with the number of segments actually exported after
    /// the append.
    #[test]
    fn segment_count_after_works() {
        let mut rng = rand::rng();
        for _ in 0..1000 {
            let num_appends = rng.random_range(0..=20);
            let mut guest_buffers = OutputBuffers::default();
            for _ in 0..num_appends {
                let stream = *OutputStream::ALL.choose(&mut rng).unwrap();
                let num_bytes: usize = rng.random_range(0..=100);
                let expected_total_len = guest_buffers.total_len_after(stream, num_bytes);
                let expected_segment_count = guest_buffers.segment_count_after(stream, num_bytes);
                let slice = guest_buffers.pre_allocate(stream, num_bytes);
                rng.fill_bytes(slice);
                assert_eq!(expected_total_len, Some(guest_buffers.total_len()));
                let mut actual_segment_count = 0;
                guest_buffers
                    .export_segments(|_| {
                        actual_segment_count += 1;
                        Ok::<(), Infallible>(())
                    })
                    .unwrap();
                assert_eq!(
                    expected_segment_count,
                    Some(actual_segment_count),
                    "expected segment count = {expected_segment_count:?}, \
                     actual segment count = {actual_segment_count}, \
                     total len = {expected_total_len:?}"
                );
            }
        }
    }
}