use std::io::Result;
use std::io::IoSlice;
use std::task::{Poll, ready};
use std::marker::PhantomData;
use super::min_len;
use super::super::{Stream, RoleHelper};
use super::super::state::{WriteState, HeadStore};
use crate::frame::FrameHead;
use crate::frame::{Fin, OpCode, PayloadLen};
pub fn write_some<F, IO, Role, Guard>(
stream: &mut Stream<IO, Role, Guard>,
mut write: F,
buf: &[u8],
) -> Poll<Result<usize>>
where
F: FnMut(&mut IO, &[IoSlice]) -> Poll<Result<usize>>,
Role: RoleHelper,
{
match stream.write_state {
WriteState::WriteZero => Poll::Ready(Ok(0)),
WriteState::WriteHead(mut head_store) => {
let frame_len = buf.len();
if head_store.is_empty() {
WriteFrameHead::<Role>::write_data_frame(&mut head_store, &mut stream.role, buf);
}
let iovec = [IoSlice::new(head_store.read()), IoSlice::new(buf)];
let write_n = ready!(write(&mut stream.io, &iovec))?;
let head_len = head_store.rd_left();
if write_n == 0 {
stream.write_state = WriteState::WriteZero;
return Poll::Ready(Ok(0));
}
if write_n < head_len {
head_store.advance_rd_pos(write_n);
stream.write_state = WriteState::WriteHead(head_store);
return Poll::Ready(Ok(0));
}
let write_n = write_n - head_len;
if write_n == frame_len {
stream.write_state = WriteState::new();
} else {
stream.write_state = WriteState::WriteData((frame_len - write_n) as u64);
}
Poll::Ready(Ok(write_n))
}
WriteState::WriteData(next) => {
let len = min_len(buf.len(), next);
let write_n = ready!(write(&mut stream.io, &[IoSlice::new(&buf[..len])]))?;
if write_n == 0 {
stream.write_state = WriteState::WriteZero;
return Poll::Ready(Ok(0));
}
if next == write_n as u64 {
stream.write_state = WriteState::new()
} else {
stream.write_state = WriteState::WriteData(next - write_n as u64)
}
Poll::Ready(Ok(write_n))
}
}
}
struct WriteFrameHead<Role: RoleHelper> {
_marker: PhantomData<Role>,
}
trait WriteFrameHeadTrait<R> {
fn write_data_frame(_: &mut HeadStore, _: &mut R, _: &[u8]) {}
}
impl<Role: RoleHelper> WriteFrameHeadTrait<Role> for WriteFrameHead<Role> {
#[inline]
default fn write_data_frame(store: &mut HeadStore, role: &mut Role, buf: &[u8]) {
let head = FrameHead::new(
Fin::Y,
OpCode::Binary,
role.mask_key(),
PayloadLen::from_num(buf.len() as u64),
);
let n = unsafe { head.encode_unchecked(store.as_mut()) };
store.set_wr_pos(n);
}
}
cfg_if::cfg_if! {
if #[cfg(feature = "unsafe_auto_mask_write")] {
use crate::role::AutoMaskClientRole;
use crate::bleed::const_cast;
use crate::frame::{Mask, new_mask_key, apply_mask4};
}
}
#[cfg(feature = "unsafe_auto_mask_write")]
impl<Role: AutoMaskClientRole> WriteFrameHeadTrait<Role> for WriteFrameHead<Role> {
#[inline]
fn write_data_frame(store: &mut HeadStore, role: &mut Role, buf: &[u8]) {
let key = if Role::UPDATE_MASK_KEY {
let key = new_mask_key();
role.set_mask_key(key);
key
} else {
role.mask_key().to_key()
};
unsafe {
let buf = const_cast(buf);
let buf = std::hint::black_box(buf);
apply_mask4(key, buf);
}
let head = FrameHead::new(
Fin::Y,
OpCode::Binary,
Mask::Key(key),
PayloadLen::from_num(buf.len() as u64),
);
let n = unsafe { head.encode_unchecked(store.as_mut()) };
store.set_wr_pos(n);
}
}
#[cfg(all(test, feature = "unsafe_auto_mask_write"))]
mod test {
use super::*;
use crate::bleed::Store;
use crate::frame::mask::*;
use crate::role::*;
fn auto_mask<R: RoleHelper>(role: &mut R, buf: &[u8]) {
let mut store = Store::new();
WriteFrameHead::<R>::write_data_frame(&mut store, role, buf)
}
#[test]
fn auto_mask_active() {
for i in 0..4096 {
let mut buf: Vec<u8> = std::iter::repeat(rand::random::<u8>()).take(i).collect();
let buf2 = buf.clone();
assert_eq!(buf.len(), i);
let mut role = StandardClient::new();
for _ in 0..8 {
auto_mask(&mut role, &buf2);
let key = role.mask_key().to_key();
apply_mask4(key, &mut buf);
assert_eq!(buf, buf2);
}
}
}
#[test]
fn auto_mask_active2() {
for i in 0..4096 {
let mut buf: Vec<u8> = std::iter::repeat(rand::random::<u8>()).take(i).collect();
let buf2 = buf.clone();
assert_eq!(buf.len(), i);
let mut role = FixedMaskClient::new();
let key = role.mask_key().to_key();
for _ in 0..8 {
auto_mask(&mut role, &buf2);
assert_eq!(key, role.mask_key().to_key());
apply_mask4(key, &mut buf);
assert_eq!(buf, buf2);
}
}
}
#[test]
fn auto_mask_inactive() {
for i in 0..4096 {
let buf: Vec<u8> = std::iter::repeat(rand::random::<u8>()).take(i).collect();
let buf2 = buf.clone();
assert_eq!(buf.len(), i);
let mut client = Client::new();
let mut server = Server::new();
for _ in 0..8 {
auto_mask(&mut client, &buf2);
assert_eq!(buf, buf2);
}
for _ in 0..8 {
auto_mask(&mut server, &buf2);
assert_eq!(buf, buf2);
}
}
}
}