1use std::task::{Poll, ready};
2
3use bytes::{Bytes, BytesMut};
4
5use crate::{Error, Result};
6
/// Metadata for a single frame: the exact number of payload bytes it will carry.
///
/// A producer built from this frame must write exactly `size` bytes before `finish` succeeds.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Frame {
    /// Declared payload length in bytes.
    pub size: u64,
}
16
17impl Frame {
18 pub fn produce(self) -> FrameProducer {
20 FrameProducer::new(self)
21 }
22}
23
24impl From<usize> for Frame {
25 fn from(size: usize) -> Self {
26 Self { size: size as u64 }
27 }
28}
29
30impl From<u64> for Frame {
31 fn from(size: u64) -> Self {
32 Self { size }
33 }
34}
35
36impl From<u32> for Frame {
37 fn from(size: u32) -> Self {
38 Self { size: size as u64 }
39 }
40}
41
42impl From<u16> for Frame {
43 fn from(size: u16) -> Self {
44 Self { size: size as u64 }
45 }
46}
47
48#[derive(Default, Debug)]
49struct FrameState {
50 chunks: Vec<Bytes>,
52
53 remaining: u64,
55
56 abort: Option<Error>,
58}
59
60impl FrameState {
61 fn write_chunk(&mut self, chunk: Bytes) -> Result<()> {
62 if let Some(err) = &self.abort {
63 return Err(err.clone());
64 }
65
66 self.remaining = self.remaining.checked_sub(chunk.len() as u64).ok_or(Error::WrongSize)?;
67 self.chunks.push(chunk);
68 Ok(())
69 }
70
71 fn poll_read_chunk(&self, index: usize) -> Poll<Result<Option<Bytes>>> {
72 if let Some(chunk) = self.chunks.get(index).cloned() {
73 Poll::Ready(Ok(Some(chunk)))
74 } else if self.remaining == 0 {
75 Poll::Ready(Ok(None))
76 } else if let Some(err) = &self.abort {
77 Poll::Ready(Err(err.clone()))
78 } else {
79 Poll::Pending
80 }
81 }
82
83 fn poll_read_chunks(&self, index: usize) -> Poll<Result<Vec<Bytes>>> {
84 if index >= self.chunks.len() && self.remaining == 0 {
85 Poll::Ready(Ok(Vec::new()))
86 } else if self.remaining == 0 {
87 Poll::Ready(Ok(self.chunks[index..].to_vec()))
88 } else if let Some(err) = &self.abort {
89 Poll::Ready(Err(err.clone()))
90 } else {
91 Poll::Pending
92 }
93 }
94
95 fn poll_read_all(&self, index: usize) -> Poll<Result<Bytes>> {
96 let chunks = ready!(self.poll_read_all_chunks(index)?);
97
98 Poll::Ready(Ok(match chunks.len() {
99 0 => Bytes::new(),
100 1 => chunks[0].clone(),
101 _ => {
102 let size = chunks.iter().map(Bytes::len).sum();
103 let mut buf = BytesMut::with_capacity(size);
104 for chunk in chunks {
105 buf.extend_from_slice(chunk.as_ref());
106 }
107 buf.freeze()
108 }
109 }))
110 }
111
112 fn poll_read_all_chunks(&self, index: usize) -> Poll<Result<&[Bytes]>> {
113 if self.remaining > 0 {
114 Poll::Pending
115 } else if let Some(err) = &self.abort {
116 Poll::Ready(Err(err.clone()))
117 } else if index < self.chunks.len() {
118 Poll::Ready(Ok(&self.chunks[index..]))
119 } else {
120 Poll::Ready(Ok(&[]))
121 }
122 }
123}
124
125pub struct FrameProducer {
130 pub info: Frame,
132
133 state: conducer::Producer<FrameState>,
135}
136
137impl FrameProducer {
138 pub fn new(info: Frame) -> Self {
140 let state = FrameState {
141 chunks: Vec::new(),
142 remaining: info.size,
143 abort: None,
144 };
145 Self {
146 info,
147 state: conducer::Producer::new(state),
148 }
149 }
150
151 pub fn write<B: Into<Bytes>>(&mut self, chunk: B) -> Result<()> {
155 let chunk = chunk.into();
156 let mut state = self.modify()?;
157 state.write_chunk(chunk)
158 }
159
160 #[deprecated(note = "use write(chunk) instead")]
164 pub fn write_chunk<B: Into<Bytes>>(&mut self, chunk: B) -> Result<()> {
165 self.write(chunk)
166 }
167
168 pub fn finish(&mut self) -> Result<()> {
172 let state = self.modify()?;
173 if state.remaining != 0 {
174 return Err(Error::WrongSize);
175 }
176 Ok(())
177 }
178
179 pub fn abort(&mut self, err: Error) -> Result<()> {
181 let mut guard = self.modify()?;
182 guard.abort = Some(err);
183 guard.close();
184 Ok(())
185 }
186
187 pub fn consume(&self) -> FrameConsumer {
189 FrameConsumer {
190 info: self.info.clone(),
191 state: self.state.consume(),
192 index: 0,
193 }
194 }
195
196 pub async fn unused(&self) -> Result<()> {
198 self.state
199 .unused()
200 .await
201 .map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
202 }
203
204 fn modify(&mut self) -> Result<conducer::Mut<'_, FrameState>> {
205 self.state
206 .write()
207 .map_err(|r| r.abort.clone().unwrap_or(Error::Dropped))
208 }
209}
210
211impl Clone for FrameProducer {
212 fn clone(&self) -> Self {
213 Self {
214 info: self.info.clone(),
215 state: self.state.clone(),
216 }
217 }
218}
219
220impl From<Frame> for FrameProducer {
221 fn from(info: Frame) -> Self {
222 FrameProducer::new(info)
223 }
224}
225
226#[derive(Clone)]
228pub struct FrameConsumer {
229 pub info: Frame,
231
232 state: conducer::Consumer<FrameState>,
234
235 index: usize,
238}
239
240impl FrameConsumer {
241 fn poll<F, R>(&self, waiter: &conducer::Waiter, f: F) -> Poll<Result<R>>
243 where
244 F: Fn(&conducer::Ref<'_, FrameState>) -> Poll<Result<R>>,
245 {
246 Poll::Ready(match ready!(self.state.poll(waiter, f)) {
247 Ok(res) => res,
248 Err(state) => Err(state.abort.clone().unwrap_or(Error::Dropped)),
250 })
251 }
252
253 pub fn poll_read_all(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Bytes>> {
255 let data = ready!(self.poll(waiter, |state| state.poll_read_all(self.index))?);
256 self.index = usize::MAX;
257 Poll::Ready(Ok(data))
258 }
259
260 pub async fn read_all(&mut self) -> Result<Bytes> {
262 conducer::wait(|waiter| self.poll_read_all(waiter)).await
263 }
264
265 pub fn poll_read_all_chunks(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Vec<Bytes>>> {
267 let chunks = ready!(self.poll(waiter, |state| {
268 state
270 .poll_read_all_chunks(self.index)
271 .map(|res| res.map(|chunks| chunks.to_vec()))
272 })?);
273 self.index += chunks.len();
274
275 Poll::Ready(Ok(chunks))
276 }
277
278 pub fn poll_read_chunk(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Option<Bytes>>> {
280 let Some(chunk) = ready!(self.poll(waiter, |state| state.poll_read_chunk(self.index))?) else {
281 return Poll::Ready(Ok(None));
282 };
283 self.index += 1;
284 Poll::Ready(Ok(Some(chunk)))
285 }
286
287 pub async fn read_chunk(&mut self) -> Result<Option<Bytes>> {
289 conducer::wait(|waiter| self.poll_read_chunk(waiter)).await
290 }
291
292 pub fn poll_read_chunks(&mut self, waiter: &conducer::Waiter) -> Poll<Result<Vec<Bytes>>> {
294 let chunks = ready!(self.poll(waiter, |state| state.poll_read_chunks(self.index))?);
295 self.index += chunks.len();
296 Poll::Ready(Ok(chunks))
297 }
298
299 pub async fn read_chunks(&mut self) -> Result<Vec<Bytes>> {
301 conducer::wait(|waiter| self.poll_read_chunks(waiter)).await
302 }
303}
304
#[cfg(test)]
mod test {
    use super::*;
    use futures::FutureExt;

    #[test]
    fn single_chunk_roundtrip() {
        let mut producer = Frame { size: 5 }.produce();
        producer.write(Bytes::from_static(b"hello")).unwrap();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let data = consumer.read_all().now_or_never().unwrap().unwrap();
        assert_eq!(data, Bytes::from_static(b"hello"));
    }

    #[test]
    fn multi_chunk_read_all() {
        let mut producer = Frame { size: 10 }.produce();
        producer.write(Bytes::from_static(b"hello")).unwrap();
        producer.write(Bytes::from_static(b"world")).unwrap();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let data = consumer.read_all().now_or_never().unwrap().unwrap();
        assert_eq!(data, Bytes::from_static(b"helloworld"));
    }

    #[test]
    fn read_all_drains_consumer() {
        let mut producer = Frame { size: 5 }.produce();
        producer.write(Bytes::from_static(b"hello")).unwrap();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let data = consumer.read_all().now_or_never().unwrap().unwrap();
        assert_eq!(data, Bytes::from_static(b"hello"));

        // The first read_all consumed everything; further reads see nothing left.
        let again = consumer.read_all().now_or_never().unwrap().unwrap();
        assert_eq!(again, Bytes::new());
        let next = consumer.read_chunk().now_or_never().unwrap().unwrap();
        assert_eq!(next, None);
    }

    #[test]
    fn read_chunk_sequential() {
        let mut producer = Frame { size: 10 }.produce();
        producer.write(Bytes::from_static(b"hello")).unwrap();
        producer.write(Bytes::from_static(b"world")).unwrap();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let c1 = consumer.read_chunk().now_or_never().unwrap().unwrap();
        assert_eq!(c1, Some(Bytes::from_static(b"hello")));
        let c2 = consumer.read_chunk().now_or_never().unwrap().unwrap();
        assert_eq!(c2, Some(Bytes::from_static(b"world")));
        let c3 = consumer.read_chunk().now_or_never().unwrap().unwrap();
        assert_eq!(c3, None);
    }

    #[test]
    fn read_chunks_after_finish() {
        let mut producer = Frame { size: 10 }.produce();
        producer.write(Bytes::from_static(b"hello")).unwrap();
        producer.write(Bytes::from_static(b"world")).unwrap();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let chunks = consumer.read_chunks().now_or_never().unwrap().unwrap();
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0], Bytes::from_static(b"hello"));
        assert_eq!(chunks[1], Bytes::from_static(b"world"));
    }

    #[test]
    fn read_chunks_waits_for_completion() {
        let mut producer = Frame { size: 10 }.produce();
        let mut consumer = producer.consume();

        producer.write(Bytes::from_static(b"hello")).unwrap();
        // Only half of the frame has arrived, so read_chunks must not resolve yet.
        assert!(consumer.read_chunks().now_or_never().is_none());

        producer.write(Bytes::from_static(b"world")).unwrap();
        producer.finish().unwrap();

        let chunks = consumer.read_chunks().now_or_never().unwrap().unwrap();
        assert_eq!(
            chunks,
            vec![Bytes::from_static(b"hello"), Bytes::from_static(b"world")]
        );
    }

    #[test]
    fn finish_checks_remaining() {
        let mut producer = Frame { size: 5 }.produce();
        producer.write(Bytes::from_static(b"hi")).unwrap();
        let err = producer.finish().unwrap_err();
        assert!(matches!(err, Error::WrongSize));
    }

    #[test]
    fn write_too_many_bytes() {
        let mut producer = Frame { size: 3 }.produce();
        let err = producer.write(Bytes::from_static(b"toolong")).unwrap_err();
        assert!(matches!(err, Error::WrongSize));
    }

    #[test]
    fn oversized_write_leaves_budget_intact() {
        let mut producer = Frame { size: 5 }.produce();
        let err = producer.write(Bytes::from_static(b"toolong")).unwrap_err();
        assert!(matches!(err, Error::WrongSize));

        // The rejected write must not have consumed any of the remaining budget.
        producer.write(Bytes::from_static(b"hello")).unwrap();
        producer.finish().unwrap();
    }

    #[test]
    fn abort_propagates() {
        let mut producer = Frame { size: 5 }.produce();
        let mut consumer = producer.consume();
        producer.abort(Error::Cancel).unwrap();

        let err = consumer.read_all().now_or_never().unwrap().unwrap_err();
        assert!(matches!(err, Error::Cancel));
    }

    #[test]
    fn abort_after_partial_read() {
        let mut producer = Frame { size: 5 }.produce();
        let mut consumer = producer.consume();

        producer.write(Bytes::from_static(b"hi")).unwrap();
        let first = consumer.read_chunk().now_or_never().unwrap().unwrap();
        assert_eq!(first, Some(Bytes::from_static(b"hi")));

        producer.abort(Error::Cancel).unwrap();
        let err = consumer.read_chunk().now_or_never().unwrap().unwrap_err();
        assert!(matches!(err, Error::Cancel));
    }

    #[test]
    fn write_after_abort_fails() {
        let mut producer = Frame { size: 5 }.produce();
        producer.abort(Error::Cancel).unwrap();

        let err = producer.write(Bytes::from_static(b"hi")).unwrap_err();
        assert!(matches!(err, Error::Cancel));
    }

    #[test]
    fn empty_frame() {
        let mut producer = Frame { size: 0 }.produce();
        producer.finish().unwrap();

        let mut consumer = producer.consume();
        let data = consumer.read_all().now_or_never().unwrap().unwrap();
        assert_eq!(data, Bytes::new());
    }

    // Plain #[test]: nothing here awaits, so no async runtime is needed.
    #[test]
    fn pending_then_ready() {
        let mut producer = Frame { size: 5 }.produce();
        let mut consumer = producer.consume();

        assert!(consumer.read_all().now_or_never().is_none());

        producer.write(Bytes::from_static(b"hello")).unwrap();
        producer.finish().unwrap();

        let data = consumer.read_all().now_or_never().unwrap().unwrap();
        assert_eq!(data, Bytes::from_static(b"hello"));
    }
}