use crate::io::AsyncWrite;
use std::pin::Pin;
use std::task::{Context, Poll};
/// Write adapter around an inner writer `W`.
///
/// Its `AsyncWrite` implementation forwards writes to `inner`. On Windows
/// (and in test builds) an oversized buffer is first shortened: it is capped
/// at `crate::io::blocking::MAX_BUF` bytes and, when its beginning looks like
/// UTF-8, trimmed so that it does not end in the middle of a character.
/// On every other target the data is handed to `inner` unchanged.
#[derive(Debug)]
pub(crate) struct SplitByUtf8BoundaryIfWindows<W> {
    /// The wrapped writer that receives the (possibly shortened) data.
    inner: W,
}

impl<W> SplitByUtf8BoundaryIfWindows<W> {
    /// Wraps `inner` in the adapter; no I/O is performed here.
    pub(crate) fn new(inner: W) -> Self {
        SplitByUtf8BoundaryIfWindows { inner }
    }
}
/// Largest number of bytes a single character occupies in UTF-8.
/// `poll_write` scans at most this many trailing bytes when looking for the
/// start of the last character.
const MAX_BYTES_PER_CHAR: usize = 4;
/// Heuristic tuning value: `poll_write` inspects the first
/// `MAX_BYTES_PER_CHAR * MAGIC_CONST` bytes of a buffer to guess whether the
/// data is UTF-8 text.
const MAGIC_CONST: usize = 8;
impl<W> crate::io::AsyncWrite for SplitByUtf8BoundaryIfWindows<W>
where
    W: AsyncWrite + Unpin,
{
    /// Forwards `buf` to the inner writer, possibly shortened.
    ///
    /// The buffer is passed through untouched when the build targets neither
    /// Windows nor tests, or when it is no longer than
    /// `crate::io::blocking::MAX_BUF`. Otherwise it is cut to `MAX_BUF` bytes
    /// and, if its first `MAX_BYTES_PER_CHAR * MAGIC_CONST` bytes look like
    /// UTF-8, trimmed further so that it ends right before the last
    /// character-start byte found in its tail.
    ///
    /// Returns whatever the inner writer reports for the forwarded slice, so
    /// the reported byte count may be smaller than `buf.len()`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        let inner = Pin::new(&mut self.get_mut().inner);
        let limit = crate::io::blocking::MAX_BUF;

        // Splitting is only compiled in for Windows and for test builds, and
        // only matters for buffers that exceed the limit.
        let may_split = cfg!(any(target_os = "windows", test));
        if !may_split || buf.len() <= limit {
            return inner.poll_write(cx, buf);
        }

        let buf = &buf[..limit];

        // Guess whether the data is text by decoding a short prefix. A prefix
        // that decodes fully leaves zero undecoded bytes; one that fails only
        // within its last few bytes (fewer than one maximum-size character)
        // is still treated as text.
        let probe = &buf[..MAX_BYTES_PER_CHAR * MAGIC_CONST];
        let undecoded = std::str::from_utf8(probe)
            .err()
            .map_or(0, |err| probe.len() - err.valid_up_to());
        let looks_like_text = undecoded < MAX_BYTES_PER_CHAR;

        let outgoing = if looks_like_text {
            // UTF-8 continuation bytes have the bit pattern `10xx_xxxx`; any
            // other byte starts a character. Walk back over at most
            // `MAX_BYTES_PER_CHAR` bytes to the nearest such start byte and
            // drop it together with everything after it. If none is found,
            // a single byte is dropped.
            let tail_len = buf
                .iter()
                .rev()
                .take(MAX_BYTES_PER_CHAR)
                .position(|byte| (*byte & 0b1100_0000) != 0b1000_0000)
                .unwrap_or(0)
                + 1;
            &buf[..buf.len() - tail_len]
        } else {
            buf
        };

        inner.poll_write(cx, outgoing)
    }

    /// Delegates flushing to the inner writer.
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_flush(cx)
    }

    /// Delegates shutdown to the inner writer.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_shutdown(cx)
    }
}
#[cfg(test)]
#[cfg(not(loom))]
mod tests {
    use crate::io::AsyncWriteExt;
    use std::io;
    use std::pin::Pin;
    use std::task::Context;
    use std::task::Poll;

    /// Largest chunk the mock writers accept in a single `poll_write`.
    const MAX_BUF: usize = 16 * 1024;

    /// Drives `fut` to completion on a fresh current-thread runtime.
    fn block_on<F: std::future::Future>(fut: F) -> F::Output {
        crate::runtime::Builder::new_current_thread()
            .build()
            .unwrap()
            .block_on(fut)
    }

    /// Mock writer that panics unless every chunk is at most `MAX_BUF` bytes
    /// long and valid UTF-8; it reports each chunk as fully written.
    struct TextMockWriter;

    impl crate::io::AsyncWrite for TextMockWriter {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let len = buf.len();
            assert!(len <= MAX_BUF);
            assert!(std::str::from_utf8(buf).is_ok());
            Poll::Ready(Ok(len))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    /// Mock writer that records the size of every chunk it receives and
    /// panics on chunks larger than `MAX_BUF`.
    struct LoggingMockWriter {
        write_history: Vec<usize>,
    }

    impl LoggingMockWriter {
        fn new() -> Self {
            Self {
                write_history: Vec::new(),
            }
        }
    }

    impl crate::io::AsyncWrite for LoggingMockWriter {
        fn poll_write(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            let len = buf.len();
            assert!(len <= MAX_BUF);
            self.write_history.push(len);
            Poll::Ready(Ok(len))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            Poll::Ready(Ok(()))
        }
    }

    /// Writes a long run of a multi-byte character through the splitter; the
    /// mock panics if any forwarded chunk is too large or not valid UTF-8.
    #[test]
    fn test_splitter() {
        let data = "█".repeat(MAX_BUF);
        let mut wr = super::SplitByUtf8BoundaryIfWindows::new(TextMockWriter);
        block_on(async move {
            wr.write_all(data.as_bytes()).await.unwrap();
        });
    }

    /// Feeds data whose inspected prefix is ASCII but whose remainder consists
    /// solely of continuation-pattern bytes, then checks how it was chunked.
    #[test]
    fn test_pseudo_text() {
        let checked_count = super::MAGIC_CONST * super::MAX_BYTES_PER_CHAR;

        // `checked_count` ASCII bytes followed by filler up to `MAX_BUF + 1`
        // bytes in total, one byte more than a single chunk may hold.
        let mut data: Vec<u8> = vec![b'a'; checked_count];
        data.resize(MAX_BUF + 1, 0b1010_1010);

        let mut writer = LoggingMockWriter::new();
        let mut splitter = super::SplitByUtf8BoundaryIfWindows::new(&mut writer);
        block_on(async {
            splitter.write_all(&data).await.unwrap();
        });

        let history = &writer.write_history;
        // Everything must arrive in at most two chunks.
        assert!(history.len() <= 2);
        assert_eq!(history.iter().sum::<usize>(), data.len());
        // The first chunk may fall short of the whole input only by a few bytes.
        assert!(data.len() - history[0] <= super::MAX_BYTES_PER_CHAR + 1);
    }
}