1#![forbid(unsafe_code)]
59
60use core::pin::Pin;
61use core::task::{Context, Poll};
62use fixed_buffer::{FixedBuf, MalformedInputError};
63
64mod async_read_write_chain;
65pub use async_read_write_chain::*;
66
67mod async_read_write_take;
68pub use async_read_write_take::*;
69
70#[cfg(test)]
71mod test_utils;
72#[cfg(test)]
73pub use test_utils::*;
74
75pub struct AsyncFixedBuf<const SIZE: usize>(FixedBuf<SIZE>);
84
85impl<const SIZE: usize> AsyncFixedBuf<SIZE> {
86 pub const fn new() -> Self {
92 AsyncFixedBuf(FixedBuf::new())
93 }
94
95 pub fn into_inner(self) -> FixedBuf<SIZE> {
98 self.0
99 }
100
101 pub fn empty(mem: [u8; SIZE]) -> Self {
106 Self(FixedBuf::empty(mem))
107 }
108
109 pub fn filled(mem: [u8; SIZE]) -> Self {
115 Self(FixedBuf::filled(mem))
116 }
117
118 pub async fn copy_once_from<R: tokio::io::AsyncRead + std::marker::Unpin + Send>(
124 &mut self,
125 reader: &mut R,
126 ) -> Result<usize, std::io::Error> {
127 let writable = self.writable();
128 if writable.is_empty() {
129 return Err(std::io::Error::new(
130 std::io::ErrorKind::InvalidData,
131 "no empty space in buffer",
132 ));
133 };
134 let num_read = tokio::io::AsyncReadExt::read(reader, writable).await?;
135 self.wrote(num_read);
136 Ok(num_read)
137 }
138
139 pub async fn read_frame<R, F>(
142 &mut self,
143 reader: &mut R,
144 deframer_fn: F,
145 ) -> Result<Option<&[u8]>, std::io::Error>
146 where
147 R: tokio::io::AsyncRead + std::marker::Unpin + Send,
148 F: Fn(&[u8]) -> Result<Option<(core::ops::Range<usize>, usize)>, MalformedInputError>,
149 {
150 loop {
151 if !self.is_empty() {
152 if let Some(frame_range) = self.deframe(&deframer_fn)? {
153 return Ok(Some(&self.mem()[frame_range]));
154 }
155 }
157 self.shift();
158 let writable = self.writable();
159 if writable.is_empty() {
160 return Err(std::io::Error::new(
161 std::io::ErrorKind::InvalidData,
162 "end of buffer full",
163 ));
164 };
165 let num_read = tokio::io::AsyncReadExt::read(reader, writable).await?;
166 if num_read == 0 {
167 if self.is_empty() {
168 return Ok(None);
169 }
170 return Err(std::io::Error::new(
171 std::io::ErrorKind::UnexpectedEof,
172 "eof after reading part of a frame",
173 ));
174 }
175 self.wrote(num_read);
176 }
177 }
178}
179
180impl<const SIZE: usize> Unpin for AsyncFixedBuf<SIZE> {}
181
182impl<const SIZE: usize> std::ops::Deref for AsyncFixedBuf<SIZE> {
183 type Target = FixedBuf<SIZE>;
184 fn deref(&self) -> &Self::Target {
185 &self.0
186 }
187}
188
189impl<const SIZE: usize> std::ops::DerefMut for AsyncFixedBuf<SIZE> {
190 fn deref_mut(&mut self) -> &mut Self::Target {
191 &mut self.0
192 }
193}
194
195impl<const SIZE: usize> tokio::io::AsyncRead for AsyncFixedBuf<SIZE> {
196 fn poll_read(
197 self: Pin<&mut Self>,
198 _cx: &mut Context<'_>,
199 buf: &mut tokio::io::ReadBuf<'_>,
200 ) -> Poll<Result<(), std::io::Error>> {
201 let num_read = self
202 .get_mut()
203 .0
204 .read_and_copy_bytes(buf.initialize_unfilled());
205 buf.advance(num_read);
206 Poll::Ready(Ok(()))
207 }
208}
209
210impl<const SIZE: usize> tokio::io::AsyncWrite for AsyncFixedBuf<SIZE> {
211 fn poll_write(
212 self: Pin<&mut Self>,
213 _cx: &mut Context<'_>,
214 buf: &[u8],
215 ) -> Poll<Result<usize, std::io::Error>> {
216 Poll::Ready(self.get_mut().0.write_bytes(buf).map_err(|_| {
217 std::io::Error::new(std::io::ErrorKind::InvalidData, "no space in buffer")
218 }))
219 }
220
221 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
222 Poll::Ready(Ok(()))
223 }
224
225 fn poll_shutdown(
226 self: Pin<&mut Self>,
227 _cx: &mut Context<'_>,
228 ) -> Poll<Result<(), std::io::Error>> {
229 Poll::Ready(Ok(()))
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use fixed_buffer::*;

    /// Test deframer: fails with `MalformedInputError` if the data contains `x` or `X`,
    /// otherwise delegates to `deframe_line`.
    fn deframe_line_reject_xs(
        data: &[u8],
    ) -> Result<Option<(core::ops::Range<usize>, usize)>, MalformedInputError> {
        if data.contains(&b'x') || data.contains(&b'X') {
            return Err(MalformedInputError::new(String::from("err1")));
        }
        deframe_line(data)
    }

    // ---- read_frame, starting from an empty buffer ----

    #[tokio::test]
    async fn test_read_frame_empty_to_eof() {
        let mut buf: AsyncFixedBuf<8> = AsyncFixedBuf::new();
        let mut reader = std::io::Cursor::new(b"");
        assert_eq!(
            None,
            buf.read_frame(&mut reader, deframe_line_reject_xs)
                .await
                .unwrap()
        );
        assert_eq!("", escape_ascii(buf.readable()));
    }

    #[tokio::test]
    async fn test_read_frame_empty_to_incomplete() {
        let mut buf: AsyncFixedBuf<8> = AsyncFixedBuf::new();
        let mut reader = std::io::Cursor::new(b"abc");
        assert_eq!(
            std::io::ErrorKind::UnexpectedEof,
            buf.read_frame(&mut reader, deframe_line_reject_xs)
                .await
                .unwrap_err()
                .kind()
        );
        assert_eq!("abc", escape_ascii(buf.readable()));
    }

    #[tokio::test]
    async fn test_read_frame_empty_to_complete() {
        let mut buf: AsyncFixedBuf<8> = AsyncFixedBuf::new();
        let mut reader = std::io::Cursor::new(b"abc\n");
        assert_eq!(
            "abc",
            escape_ascii(
                buf.read_frame(&mut reader, deframe_line_reject_xs)
                    .await
                    .unwrap()
                    .unwrap()
            )
        );
        assert_eq!("", escape_ascii(buf.readable()));
    }

    #[tokio::test]
    async fn test_read_frame_empty_to_complete_with_leftover() {
        let mut buf: AsyncFixedBuf<8> = AsyncFixedBuf::new();
        let mut reader = std::io::Cursor::new(b"abc\nde");
        assert_eq!(
            "abc",
            escape_ascii(
                buf.read_frame(&mut reader, deframe_line_reject_xs)
                    .await
                    .unwrap()
                    .unwrap()
            )
        );
        assert_eq!("de", escape_ascii(buf.readable()));
    }

    #[tokio::test]
    async fn test_read_frame_empty_to_invalid() {
        let mut buf: AsyncFixedBuf<8> = AsyncFixedBuf::new();
        let mut reader = std::io::Cursor::new(b"x");
        assert_eq!(
            std::io::ErrorKind::InvalidData,
            buf.read_frame(&mut reader, deframe_line_reject_xs)
                .await
                .unwrap_err()
                .kind()
        );
        assert_eq!("x", escape_ascii(buf.readable()));
    }

    // ---- read_frame, starting from a buffer holding a partial frame ----

    #[tokio::test]
    async fn test_read_frame_incomplete_to_eof() {
        let mut buf: AsyncFixedBuf<8> = AsyncFixedBuf::new();
        buf.write_str("a").unwrap();
        let mut reader = std::io::Cursor::new(b"");
        assert_eq!(
            std::io::ErrorKind::UnexpectedEof,
            buf.read_frame(&mut reader, deframe_line_reject_xs)
                .await
                .unwrap_err()
                .kind()
        );
        assert_eq!("a", escape_ascii(buf.readable()));
    }

    #[tokio::test]
    async fn test_read_frame_incomplete_to_incomplete() {
        let mut buf: AsyncFixedBuf<8> = AsyncFixedBuf::new();
        buf.write_str("a").unwrap();
        let mut reader = std::io::Cursor::new(b"bc");
        assert_eq!(
            std::io::ErrorKind::UnexpectedEof,
            buf.read_frame(&mut reader, deframe_line_reject_xs)
                .await
                .unwrap_err()
                .kind()
        );
        assert_eq!("abc", escape_ascii(buf.readable()));
    }

    #[tokio::test]
    async fn test_read_frame_incomplete_to_complete() {
        let mut buf: AsyncFixedBuf<8> = AsyncFixedBuf::new();
        buf.write_str("a").unwrap();
        let mut reader = std::io::Cursor::new(b"bc\n");
        assert_eq!(
            "abc",
            escape_ascii(
                buf.read_frame(&mut reader, deframe_line_reject_xs)
                    .await
                    .unwrap()
                    .unwrap()
            )
        );
        assert_eq!("", escape_ascii(buf.readable()));
    }

    #[tokio::test]
    async fn test_read_frame_incomplete_to_complete_with_leftover() {
        let mut buf: AsyncFixedBuf<8> = AsyncFixedBuf::new();
        buf.write_str("a").unwrap();
        let mut reader = std::io::Cursor::new(b"bc\nde");
        assert_eq!(
            "abc",
            escape_ascii(
                buf.read_frame(&mut reader, deframe_line_reject_xs)
                    .await
                    .unwrap()
                    .unwrap()
            )
        );
        assert_eq!("de", escape_ascii(buf.readable()));
    }

    // ---- read_frame must not touch the reader when buffered data decides the result ----

    #[tokio::test]
    async fn test_read_frame_complete_doesnt_read() {
        let mut buf: AsyncFixedBuf<8> = AsyncFixedBuf::new();
        buf.write_str("abc\n").unwrap();
        assert_eq!(
            "abc",
            escape_ascii(
                buf.read_frame(&mut FakeAsyncReadWriter::empty(), deframe_line_reject_xs)
                    .await
                    .unwrap()
                    .unwrap()
            )
        );
        assert_eq!("", escape_ascii(buf.readable()));
    }

    #[tokio::test]
    async fn test_read_frame_complete_leaves_leftovers() {
        let mut buf: AsyncFixedBuf<8> = AsyncFixedBuf::new();
        buf.write_str("abc\nde").unwrap();
        assert_eq!(
            "abc",
            escape_ascii(
                buf.read_frame(&mut FakeAsyncReadWriter::empty(), deframe_line_reject_xs)
                    .await
                    .unwrap()
                    .unwrap()
            )
        );
        assert_eq!("de", escape_ascii(buf.readable()));
    }

    #[tokio::test]
    async fn test_read_frame_invalid_doesnt_read() {
        let mut buf: AsyncFixedBuf<8> = AsyncFixedBuf::new();
        buf.write_str("x").unwrap();
        assert_eq!(
            std::io::ErrorKind::InvalidData,
            buf.read_frame(&mut FakeAsyncReadWriter::empty(), deframe_line_reject_xs)
                .await
                .unwrap_err()
                .kind()
        );
        assert_eq!("x", escape_ascii(buf.readable()));
    }

    #[tokio::test]
    async fn test_read_frame_buffer_full() {
        let mut buf: AsyncFixedBuf<8> = AsyncFixedBuf::new();
        buf.write_str("abcdefgh").unwrap();
        let mut reader = std::io::Cursor::new(b"bc\nde");
        assert_eq!(
            std::io::ErrorKind::InvalidData,
            buf.read_frame(&mut reader, deframe_line_reject_xs)
                .await
                .unwrap_err()
                .kind()
        );
        assert_eq!("abcdefgh", escape_ascii(buf.readable()));
    }

    // ---- AsyncRead / AsyncWrite impls ----

    #[tokio::test]
    async fn test_async_read() {
        let mut buf: AsyncFixedBuf<16> = AsyncFixedBuf::new();
        let mut data = [b'.'; 16];
        assert_eq!(
            0,
            tokio::io::AsyncReadExt::read(&mut buf, &mut data)
                .await
                .unwrap()
        );
        assert_eq!("..........", escape_ascii(&data[..10]));
        buf.write_str("abc").unwrap();
        assert_eq!(
            3,
            tokio::io::AsyncReadExt::read(&mut buf, &mut data)
                .await
                .unwrap()
        );
        assert_eq!("abc.......", escape_ascii(&data[..10]));
        assert_eq!(
            0,
            tokio::io::AsyncReadExt::read(&mut buf, &mut data)
                .await
                .unwrap()
        );
        let many_bs = "b".repeat(16);
        buf.write_str(&many_bs).unwrap();
        assert_eq!(
            16,
            tokio::io::AsyncReadExt::read(&mut buf, &mut data)
                .await
                .unwrap()
        );
        assert_eq!(many_bs, escape_ascii(&data[..]));
        assert_eq!(
            0,
            tokio::io::AsyncReadExt::read(&mut buf, &mut data)
                .await
                .unwrap()
        );
    }

    #[tokio::test]
    async fn test_async_write() {
        let mut buf: AsyncFixedBuf<16> = AsyncFixedBuf::new();
        tokio::io::AsyncWriteExt::write_all(&mut buf, b"abc")
            .await
            .unwrap();
        assert_eq!("abc", escape_ascii(buf.readable()));
        tokio::io::AsyncWriteExt::write_all(&mut buf, b"def")
            .await
            .unwrap();
        assert_eq!("abcdef", escape_ascii(buf.readable()));
        buf.read_bytes(1);
        tokio::io::AsyncWriteExt::write_all(&mut buf, b"g")
            .await
            .unwrap();
        assert_eq!("bcdefg", escape_ascii(buf.readable()));
        tokio::io::AsyncWriteExt::write_all(&mut buf, "h".repeat(8).as_bytes())
            .await
            .unwrap();
        tokio::io::AsyncWriteExt::write_all(&mut buf, b"i")
            .await
            .unwrap();
        // 15 readable bytes in a 16-byte buffer: at most 1 free byte remains.
        assert_eq!(
            format!("bcdefg{}i", "h".repeat(8)),
            escape_ascii(buf.readable())
        );
        assert_eq!(
            std::io::ErrorKind::InvalidData,
            tokio::io::AsyncWriteExt::write_all(&mut buf, b"def")
                .await
                .unwrap_err()
                .kind()
        );
    }
}