use std::{
cell::UnsafeCell,
ptr::copy_nonoverlapping,
sync::{
Arc, Mutex,
atomic::{AtomicBool, AtomicU32, Ordering},
},
time::Duration,
};
use tokio::sync::Notify;
use crate::{
buffer::{Buf, BufferReader, BufferWriter, linked::LinkedBuffer, slice::BufferSlice},
consts::MAGIC_NUMBER,
error::Error,
protocol::event::{EventType, FallbackDataEvent},
queue::QueueElement,
session::Session,
};
/// Stream state: open in both directions. Initial state set by `Stream::new`.
/// Also sent to the peer as the `status` of queued data elements in `flush`.
pub const STREAM_OPENED: u32 = 0;
/// Stream state: fully closed. Entered via `Stream::close`; also sent to the peer
/// as the `status` of the close queue element.
pub const STREAM_CLOSED: u32 = 1;
/// Stream state: half-closed. Entered only from `STREAM_OPENED` via `Stream::half_close`.
pub const STREAM_HALF_CLOSED: u32 = 2;
/// Cloneable handle to a single stream belonging to a `Session`.
///
/// Clones share the same `StreamInner` through an `Arc`, so buffers, pending data,
/// state and notifications are common to every clone.
#[derive(Debug, Clone)]
pub struct Stream {
    // Shared per-stream state; all clones point at the same allocation.
    inner: Arc<StreamInner>,
    // Stream identifier; used as `seq_id` in queue elements and in close/fallback events.
    id: u32,
    // Owning session: buffer manager, stats, send queue and peer wake-up.
    session: Session,
    // Identifier of the owning session, as supplied to `Stream::new`.
    session_id: usize,
}
/// State shared by all clones of a `Stream`.
#[derive(Debug)]
pub struct StreamInner {
    // Received bytes not yet consumed; handed out as `&mut` by `Stream::recv_buf`
    // without any locking.
    recv_buf: UnsafeCell<LinkedBuffer>,
    // Bytes written by the user and not yet flushed; handed out by `Stream::send_buf`.
    send_buf: UnsafeCell<LinkedBuffer>,
    // Incoming data queued by `fill_data_to_read_buffer`, later moved into `recv_buf`
    // by readers (or released by `clean_pending_data`).
    pending_data: Mutex<Vec<BufferSliceWrapper>>,
    // One of STREAM_OPENED / STREAM_HALF_CLOSED / STREAM_CLOSED.
    state: AtomicU32,
    // Signalled with `notify_waiters` when the stream is half-closed or closed.
    close_notify: Notify,
    // Signalled with `notify_one` when new pending data is queued on a stream that
    // is not closed.
    recv_notify: Notify,
    // Set when a fallback slice is received or a non-shared-memory send buffer is
    // flushed; cleared by `reset`.
    in_fallback_state: AtomicBool,
}
// SAFETY (review): `StreamInner` is shared across threads through `Arc`, yet its two
// `UnsafeCell<LinkedBuffer>` fields are turned into `&mut` by `Stream::recv_buf` /
// `Stream::send_buf` with no synchronisation. This impl is only sound if callers make
// sure at most one task touches each buffer at a time — that invariant is not enforced
// anywhere visible here; TODO confirm and document it at the call sites.
unsafe impl Sync for StreamInner {}
impl Stream {
    /// Builds a new, open stream `id` that belongs to `session`.
    ///
    /// Both buffers are created from the session's buffer manager; the pending-data
    /// queue starts empty and fallback mode starts off.
    pub(crate) fn new(id: u32, session_id: usize, session: Session) -> Self {
        let manager = &session.shared.buffer_manager;
        let shared = StreamInner {
            recv_buf: UnsafeCell::new(LinkedBuffer::new(manager.clone())),
            send_buf: UnsafeCell::new(LinkedBuffer::new(manager.clone())),
            pending_data: Mutex::new(Vec::new()),
            state: AtomicU32::new(STREAM_OPENED),
            close_notify: Notify::new(),
            recv_notify: Notify::new(),
            in_fallback_state: AtomicBool::new(false),
        };
        Self {
            inner: Arc::new(shared),
            id,
            session,
            session_id,
        }
    }

    /// Mutable access to the receive buffer, obtained from a shared `&self`.
    ///
    /// Nothing here prevents two live `&mut` to the same buffer; callers must
    /// guarantee exclusive use.
    // The `&self -> &mut` shape is intentional (interior mutability through `UnsafeCell`).
    #[allow(clippy::mut_from_ref)]
    #[inline]
    pub fn recv_buf(&self) -> &mut LinkedBuffer {
        let cell = &self.inner.recv_buf;
        // SAFETY (review): relies on callers never holding two references to the
        // receive buffer at the same time; not enforced by this type.
        unsafe { &mut *cell.get() }
    }

    /// Mutable access to the send buffer, obtained from a shared `&self`.
    ///
    /// Same aliasing caveat as [`Stream::recv_buf`].
    // The `&self -> &mut` shape is intentional (interior mutability through `UnsafeCell`).
    #[allow(clippy::mut_from_ref)]
    #[inline]
    pub fn send_buf(&self) -> &mut LinkedBuffer {
        let cell = &self.inner.send_buf;
        // SAFETY (review): relies on callers never holding two references to the
        // send buffer at the same time; not enforced by this type.
        unsafe { &mut *cell.get() }
    }

    /// Identifier of this stream, as passed to [`Stream::new`].
    pub const fn stream_id(&self) -> u32 {
        self.id
    }
}
impl Stream {
async fn read_more(&self, min_size: usize, buf: &mut LinkedBuffer) -> Result<(), Error> {
self.move_pending_data(buf);
let recv_len = buf.len();
if recv_len >= min_size {
return Ok(());
}
if recv_len == 0 && self.inner.state.load(Ordering::SeqCst) != STREAM_OPENED {
return Err(Error::EndOfStream);
}
loop {
let recv_notified = self.inner.recv_notify.notified();
let close_notified = self.inner.close_notify.notified();
match futures::future::select(
std::pin::pin!(recv_notified),
std::pin::pin!(close_notified),
)
.await
{
futures::future::Either::Left(_) => {
self.move_pending_data(buf);
if buf.len() >= min_size {
return Ok(());
}
}
futures::future::Either::Right(_) => {
self.move_pending_data(buf);
if buf.len() >= min_size {
return Ok(());
}
if self.inner.state.load(Ordering::SeqCst) == STREAM_HALF_CLOSED {
return Err(Error::EndOfStream);
}
return Err(Error::StreamClosed);
}
}
}
}
/// Moves every queued `BufferSliceWrapper` into `buf`.
///
/// Fallback entries are appended as-is and switch the stream into fallback state.
/// Other entries are resolved through the session's buffer manager, following each
/// buffer's `next_buffer_offset` while its header reports `has_next`. A failed
/// `read_buffer_slice` is logged and ends that entry's chain. The number of bytes
/// appended is added to the session's `in_flow_bytes` counter.
fn move_pending_data(&self, buf: &mut LinkedBuffer) {
    let mut queued = self.inner.pending_data.lock().unwrap();
    if queued.is_empty() {
        return;
    }
    let before = buf.len();
    let manager = &self.session.shared.buffer_manager;
    for entry in queued.drain(..) {
        if let Some(fallback) = entry.fallback_slice {
            buf.append_buffer_slice(fallback);
            self.inner.in_fallback_state.store(true, Ordering::SeqCst);
            continue;
        }
        let mut cursor = entry.offset;
        loop {
            let slice = match manager.read_buffer_slice(cursor) {
                Ok(found) => found,
                Err(err) => {
                    tracing::error!("read_buffer_slice error {err}");
                    break;
                }
            };
            // Read the link before `slice` is moved into `buf`.
            let next = slice
                .buffer_header
                .as_ref()
                .filter(|header| header.has_next())
                .map(|header| header.next_buffer_offset());
            buf.append_buffer_slice(slice);
            match next {
                Some(next_offset) => cursor = next_offset,
                None => break,
            }
        }
    }
    let moved = buf.len() - before;
    self.session
        .shared
        .stats
        .in_flow_bytes
        .fetch_add(moved as u64, Ordering::SeqCst);
}
/// Sends the contents of `send_buf` through the session's `wait_for_send` path
/// instead of the send queue, prefixed with a 16-byte `FallbackDataEvent` header.
///
/// `err` is only used as the logged reason. `send_buf` is always recycled. The
/// session circuit breaker is opened and `fallback_write_count` is incremented
/// before sending.
///
/// # Errors
/// Propagates the result of `Session::wait_for_send`.
pub async fn write_fallback(
    &self,
    stream_status: u32,
    err: Error,
    send_buf: &mut LinkedBuffer,
) -> Result<(), Error> {
    // Size in bytes of the `FallbackDataEvent` header that precedes the payload.
    const HEADER_LEN: usize = 16;

    tracing::warn!(
        "session {} stream fallback seqID:{} len:{} reason:{}, send_buf.is_from_share_memory: \
         {}",
        self.session.shared.name,
        self.id,
        send_buf.len(),
        err,
        send_buf.is_from_share_memory()
    );
    // `FallbackDataEvent` only keeps a raw pointer to its header bytes, so the storage
    // must be a named local that outlives `event`. The former
    // `FallbackDataEvent([0u8; 16].as_mut_ptr())` pointed into a temporary array that
    // died at the end of that `let`, so `encode`/`as_slice` used a dangling pointer.
    let mut header = [0u8; HEADER_LEN];
    let mut event = FallbackDataEvent(header.as_mut_ptr());
    event.encode(
        send_buf.len() as u32 + HEADER_LEN as u32,
        self.session.shared.communication_version,
        self.id,
        stream_status,
    );
    let mut data = Vec::with_capacity(send_buf.len() + HEADER_LEN);
    data.extend_from_slice(event.as_slice());
    // Copy the unread bytes of each slice, stopping after the current write slice.
    let mut slice = send_buf.slice_list().front();
    while let Some(s) = slice {
        // SAFETY (review): assumes `s.data` is the base of an allocation holding at
        // least `s.write_index` bytes, with the unread payload in
        // `[read_index, write_index)`. Starting at `data + read_index` copies exactly
        // that range; the length was already computed that way.
        let unread = unsafe {
            std::slice::from_raw_parts(s.data.add(s.read_index), s.write_index - s.read_index)
        };
        data.extend_from_slice(unread);
        if send_buf
            .slice_list()
            .write()
            .map(|ws| ws == s)
            .unwrap_or(false)
        {
            break;
        }
        slice = s.next();
    }
    send_buf.recycle();
    self.session.open_circuit_breaker().await;
    self.session
        .shared
        .stats
        .fallback_write_count
        .fetch_add(1, Ordering::SeqCst);
    self.session.wait_for_send(None, data).await
}
/// Reports the close to the session and releases everything the stream still holds:
/// queued pending data, then the receive buffer, then the send buffer.
fn clean(&self) {
    let current_state = self.inner.state.load(Ordering::SeqCst);
    self.session.on_stream_close(self.id, current_state);
    self.clean_pending_data();
    for buffer in [self.recv_buf(), self.send_buf()] {
        buffer.recycle();
    }
}
/// Releases every queued `BufferSliceWrapper` without delivering it.
///
/// Heap-backed fallback slices are freed. Fallback slices flagged as shared memory
/// are only logged. Other entries are resolved by offset and handed back to the
/// buffer manager. A failed lookup is logged and skipped, so the remaining entries
/// are still released; previously the loop stopped at the first failure and leaked
/// everything after it.
fn clean_pending_data(&self) {
    let mut pending_data = self.inner.pending_data.lock().unwrap();
    let manager = &self.session.shared.buffer_manager;
    for data in pending_data.drain(..) {
        if let Some(fallback_slice) = data.fallback_slice {
            if fallback_slice.is_from_shm {
                tracing::warn!(
                    "fallback slice is from shm, offset:{}",
                    fallback_slice.offset_in_shm
                );
            } else {
                // SAFETY (review): assumes a non-shm fallback slice owns a heap allocation
                // created as a `Vec<u8>` with capacity `cap` — confirm at the producer.
                // Length 0 avoids asserting that all `cap` bytes are initialised; freeing
                // only needs the pointer and the capacity.
                drop(unsafe {
                    Vec::from_raw_parts(fallback_slice.data, 0, fallback_slice.cap as usize)
                });
            }
            continue;
        }
        match manager.read_buffer_slice(data.offset) {
            Ok(slice) => {
                manager.recycle_buffers(slice);
            }
            Err(err) => {
                tracing::error!("read_buffer_slice error {}", err);
            }
        }
    }
}
pub fn reset(&self) -> Result<(), Error> {
if self.inner.state.load(Ordering::SeqCst) != STREAM_OPENED {
return Err(Error::StreamClosed);
}
let unread_size = self.recv_buf().len();
if unread_size > 0 {
return Err(Error::StreamHasUnreadData(unread_size));
}
let pending_data_len = self.inner.pending_data.lock().unwrap().len();
if pending_data_len > 0 {
return Err(Error::StreamHasPendingData(pending_data_len));
}
self.inner.in_fallback_state.store(false, Ordering::SeqCst);
Ok(())
}
/// Releases already-read data from the receive buffer. If that leaves it empty with
/// exactly one slice, swaps the receive and send buffers so the slice can be reused
/// for writing.
pub fn release_read_and_reuse(&self) {
    let (recv, send) = (self.recv_buf(), self.send_buf());
    recv.release_previous_read_and_reserve();
    let reusable = recv.is_empty() && recv.slice_list().size() == 1;
    if reusable {
        std::mem::swap(recv, send);
    }
}
/// Queues incoming data for this stream.
///
/// On a closed stream the queue and receive buffer are released immediately.
/// Otherwise one waiting reader is woken. Always returns `Ok(())`.
pub fn fill_data_to_read_buffer(&self, buf: BufferSliceWrapper) -> Result<(), Error> {
    self.inner.pending_data.lock().unwrap().push(buf);
    let closed = self.inner.state.load(Ordering::SeqCst) == STREAM_CLOSED;
    if closed {
        self.clean_pending_data();
        self.recv_buf().recycle();
    } else {
        self.inner.recv_notify.notify_one();
    }
    Ok(())
}
/// Returns `true` while the stream is neither half-closed nor closed.
pub fn is_open(&self) -> bool {
    let current = self.inner.state.load(Ordering::SeqCst);
    current == STREAM_OPENED
}
/// Wakes every task currently waiting on this stream's close notification.
pub fn safe_close_notify(&self) {
    let notifier = &self.inner.close_notify;
    notifier.notify_waiters();
}
/// Moves an open stream to `STREAM_HALF_CLOSED` and wakes close waiters.
/// Does nothing if the stream is not currently open.
pub fn half_close(&self) {
    let transitioned = self
        .inner
        .state
        .compare_exchange(
            STREAM_OPENED,
            STREAM_HALF_CLOSED,
            Ordering::SeqCst,
            Ordering::SeqCst,
        )
        .is_ok();
    if transitioned {
        self.safe_close_notify();
    }
}
/// Identifier of the owning session, as passed to `Stream::new`.
pub const fn session_id(&self) -> usize {
    self.session_id
}
/// Returns whether the stream is currently flagged as using the fallback path.
pub fn fallback_state(&self) -> bool {
    let inner = &self.inner;
    inner.in_fallback_state.load(Ordering::SeqCst)
}
/// Hands a clone of this stream back to the session via `put_or_close_stream`.
pub async fn reuse(&self) {
    let handle = self.clone();
    self.session.put_or_close_stream(handle).await;
}
/// Closes the stream, releasing its buffers, and tells the peer if it was open.
///
/// - Idempotent: a second call on a closed stream returns `Ok(())` at once.
/// - Close waiters are woken on every transition into `STREAM_CLOSED`.
/// - A close message is sent only when the stream was fully open and the session is
///   not shutting down.
/// - The message goes through the send queue. If `put` fails, it is sent as a
///   12-byte stream-close event via `wait_for_send`.
///
/// # Errors
/// Propagates errors from `wake_up_peer` or `wait_for_send`.
pub async fn close(&mut self) -> Result<(), Error> {
    // SeqCst like every other access to `state` (this swap used to be `Release`,
    // whose load half is only `Relaxed`).
    let old_state = self.inner.state.swap(STREAM_CLOSED, Ordering::SeqCst);
    if old_state == STREAM_CLOSED {
        return Ok(());
    }
    self.clean();
    // Wake waiters on every transition into CLOSED, including HALF_CLOSED -> CLOSED,
    // so nobody keeps waiting on buffers that `clean` has just recycled.
    self.safe_close_notify();
    if old_state != STREAM_OPENED {
        // Coming from half-closed: no close message is sent to the peer.
        return Ok(());
    }
    if self.session.shared.shutdown.load(Ordering::SeqCst) == 1 {
        return Ok(());
    }
    let close_element = QueueElement {
        seq_id: self.id,
        offset_in_shm_buf: 0,
        status: STREAM_CLOSED,
    };
    if self
        .session
        .shared
        .queue_manager
        .send_queue
        .put(close_element)
        .is_ok()
    {
        return self.session.wake_up_peer().await;
    }
    self.session
        .shared
        .stats
        .queue_full_error_count
        .fetch_add(1, Ordering::SeqCst);
    // Stream-close event, big-endian: 4-byte length (12), 2 bytes of MAGIC_NUMBER,
    // 1-byte communication version, 1-byte event type, 4-byte stream id.
    let mut event = vec![0u8; 12];
    // SAFETY: `event` owns 12 bytes and every write stays inside [0, 12):
    // 4 bytes at 0, 2 at 4, 1 at 6, 1 at 7, 4 at 8. Sources are temporaries that live
    // until the end of each statement. Assumes MAGIC_NUMBER is at least 2 bytes wide.
    unsafe {
        let ptr = event.as_mut_ptr();
        copy_nonoverlapping(12_u32.to_be_bytes().as_ptr(), ptr, 4);
        copy_nonoverlapping(MAGIC_NUMBER.to_be_bytes().as_ptr(), ptr.add(4), 2);
        *ptr.add(6) = self.session.shared.communication_version;
        *ptr.add(7) = EventType::TYPE_STREAM_CLOSE.inner();
        copy_nonoverlapping(self.id.to_be_bytes().as_ptr(), ptr.add(8), 4);
    }
    self.session.wait_for_send(None, event).await
}
/// Returns everything currently readable, waiting for at least one byte if the
/// receive buffer is empty.
///
/// # Errors
/// Propagates errors from `read_more` and `LinkedBuffer::read_bytes`.
pub async fn read(&mut self) -> Result<Buf<'_>, Error> {
    let recv = self.recv_buf();
    let nothing_buffered = recv.is_empty();
    if nothing_buffered {
        tracing::debug!("read_bytes seqID:{}", self.id);
        self.read_more(1, recv).await?;
    }
    let available = recv.len();
    recv.read_bytes(available)
}
/// Reads exactly `size` bytes, waiting for more data if fewer are buffered.
///
/// # Errors
/// Propagates errors from `read_more` and `LinkedBuffer::read_bytes`.
pub async fn read_bytes(&mut self, size: usize) -> Result<Buf<'_>, Error> {
    let recv = self.recv_buf();
    let buffered = recv.len();
    if buffered < size {
        tracing::debug!(
            "read_bytes seqID:{} len:{} size:{}",
            self.id,
            buffered,
            size
        );
        self.read_more(size, recv).await?;
    }
    recv.read_bytes(size)
}
/// Returns the next `size` bytes without consuming them, waiting for more data if
/// fewer are buffered.
///
/// # Errors
/// Propagates errors from `read_more` and `LinkedBuffer::peek`.
pub async fn peek(&mut self, size: usize) -> Result<Buf<'_>, Error> {
    let recv = self.recv_buf();
    let short = recv.len() < size;
    if short {
        self.read_more(size, recv).await?;
    }
    recv.peek(size)
}
/// Skips `size` bytes, waiting for more data if fewer are buffered.
///
/// # Errors
/// Propagates errors from `read_more` and `LinkedBuffer::discard`.
pub async fn discard(&mut self, size: usize) -> Result<usize, Error> {
    let recv = self.recv_buf();
    let needs_more = size > recv.len();
    if needs_more {
        self.read_more(size, recv).await?;
    }
    recv.discard(size)
}
/// Reserves `size` writable bytes in the send buffer (delegates to `LinkedBuffer::reserve`).
pub fn reserve(&mut self, size: usize) -> Result<&mut [u8], Error> {
    let outgoing = self.send_buf();
    outgoing.reserve(size)
}
/// Appends `data` to the send buffer (delegates to `LinkedBuffer::write_bytes`).
pub fn write_bytes(&mut self, data: &[u8]) -> Result<usize, Error> {
    let outgoing = self.send_buf();
    outgoing.write_bytes(data)
}
pub async fn flush(&mut self, end_stream: bool) -> Result<(), Error> {
let send_buf = self.send_buf();
if send_buf.is_empty() {
return Ok(());
}
self.session
.shared
.stats
.out_flow_bytes
.fetch_add(send_buf.len() as u64, Ordering::SeqCst);
let state = self.inner.state.load(Ordering::SeqCst);
if state != STREAM_OPENED {
send_buf.recycle();
return Err(Error::StreamClosed);
}
send_buf.done(end_stream);
if !send_buf.is_from_share_memory() {
self.inner.in_fallback_state.store(true, Ordering::SeqCst);
}
if self.inner.in_fallback_state.load(Ordering::SeqCst) {
let ret = self
.write_fallback(state, Error::NoMoreBuffer, send_buf)
.await;
send_buf.clean();
return ret;
}
match self
.session
.shared
.queue_manager
.send_queue
.put(QueueElement {
seq_id: self.id,
offset_in_shm_buf: send_buf.root_buf_offset(),
status: state,
}) {
Ok(_) => {
let ret = self.session.wake_up_peer().await;
send_buf.clean();
return ret;
}
Err(Error::QueueFull) => {}
Err(err) => {
send_buf.recycle();
return Err(err);
}
}
self.session
.shared
.stats
.queue_full_error_count
.fetch_add(1, Ordering::SeqCst);
for _ in 0..10 {
if tokio::time::timeout(
Duration::from_millis(10),
self.inner.close_notify.notified(),
)
.await
.is_ok()
{
send_buf.recycle();
return Err(Error::StreamClosed);
}
match self
.session
.shared
.queue_manager
.send_queue
.put(QueueElement {
seq_id: self.id,
offset_in_shm_buf: send_buf.root_buf_offset(),
status: state,
}) {
Ok(_) => {
let ret = self.session.wake_up_peer().await;
send_buf.clean();
return ret;
}
Err(Error::QueueFull) => continue,
Err(err) => {
send_buf.recycle();
return Err(err);
}
}
}
Ok(())
}
}
/// One unit of incoming data queued for a stream by `Stream::fill_data_to_read_buffer`.
///
/// Either carries a ready-made fallback slice (`fallback_slice` is `Some`), or refers
/// by `offset` to the first of a chain of buffers owned by the session's buffer manager.
#[derive(Debug)]
pub struct BufferSliceWrapper {
    // Fallback data; when `Some`, the stream uses it directly and ignores `offset`.
    pub(crate) fallback_slice: Option<BufferSlice>,
    // Offset passed to `read_buffer_slice` to locate the first buffer of the chain.
    pub(crate) offset: u32,
}