use crate::alloc::Allocator;
use crate::config::Config;
use crate::error::SmuxError;
use crate::frame::{Cmd, Frame, RawHeader, UpdHeader, HEADER_SIZE, UPD_SIZE};
use crate::shaper::{ClassId, ShaperQueue, WriteRequest};
use crate::stream::{Stream, StreamInner};
use bytes::Bytes;
use std::collections::HashMap;
use std::io;
use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU32, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::{mpsc, Mutex, Notify};
use tokio::time::{sleep, Duration as TokioDuration};
/// Capacity of the channel that hands peer-opened streams (SYN frames) to `accept_stream`.
const DEFAULT_ACCEPT_BACKLOG: usize = 1024;
/// Data-class writes only wake the send loop once at least this many requests are queued;
/// control-class writes always wake it. Otherwise queued data waits for the send loop's
/// periodic wake-up.
const MIN_SHAPER_NOTIFY_SIZE: usize = 16;
/// Timeout passed along with frames queued by `write_control_frame`.
/// NOTE(review): `write_frame_internal` currently ignores its `_timeout` argument.
const OPEN_CLOSE_TIMEOUT: Duration = Duration::from_secs(30);
/// A multiplexed session carrying many logical streams over one connection.
///
/// `new` spawns background tasks that read frames (`recv_loop`), write queued frames
/// (`send_loop`) and, unless disabled in the config, run keepalive checks (`keepalive`).
pub struct Session {
/// Read half of the connection passed to `new`; `recv_loop` locks it to read frames.
conn_read: Arc<Mutex<Box<dyn AsyncRead + Send + Unpin>>>,
/// Write half of the connection; `send_loop` locks it to write encoded frames.
conn_write: Arc<Mutex<Box<dyn AsyncWrite + Send + Unpin>>>,
/// Session configuration (frame size, receive buffer, keepalive settings, version).
pub config: Config,
/// Last stream id handed out by `open_stream`; initialised to 1 (client) or 0 (server)
/// and advanced by 2 per new stream.
next_stream_id: Mutex<u32>,
/// Whether this side is the client; selects the initial `next_stream_id`.
is_client: bool,
/// Receive-window tokens: reduced by each received PSH payload, refilled by
/// `return_tokens`. `recv_loop` pauses reading while this is <= 0.
bucket: AtomicI32,
/// Wakes `recv_loop` when tokens are returned (and on each keepalive ping).
bucket_notify: Arc<Notify>,
/// Streams known to this session, keyed by stream id.
streams: Arc<Mutex<HashMap<u32, Arc<StreamInner>>>>,
/// Set by `close`; polled by the background loops and public entry points.
die: Arc<AtomicBool>,
/// The I/O error that stopped `recv_loop`, if any.
socket_read_error: Arc<Mutex<Option<io::Error>>>,
/// The I/O error that stopped `send_loop`, if any.
socket_write_error: Arc<Mutex<Option<io::Error>>>,
/// Signalled (`notify_waiters`) after `socket_read_error` is stored.
ch_socket_read_error: Arc<Notify>,
/// Signalled (`notify_waiters`) after `socket_write_error` is stored.
ch_socket_write_error: Arc<Notify>,
/// Protocol error (malformed header or version mismatch) that stopped `recv_loop`.
proto_error: Arc<Mutex<Option<SmuxError>>>,
/// Signalled (`notify_waiters`) after `proto_error` is stored.
ch_proto_error: Arc<Notify>,
/// Receiving side of the queue of peer-opened streams, consumed by `accept_stream`.
ch_accepts: Arc<Mutex<mpsc::Receiver<Arc<StreamInner>>>>,
/// Sending side used by `recv_loop` on SYN; bounded by `DEFAULT_ACCEPT_BACKLOG`.
ch_accepts_tx: mpsc::Sender<Arc<StreamInner>>,
/// Set whenever a frame header is read; `keepalive` clears it to detect a silent peer.
data_ready: AtomicBool,
/// Set by `open_stream` when it refuses to hand out further stream ids.
go_away: AtomicBool,
/// Optional deadline for `accept_stream`, set via `set_deadline`.
deadline: Arc<Mutex<Option<Instant>>>,
/// Counter used as the `seq` of each queued `WriteRequest`.
request_id: AtomicU32,
/// Queue of pending writes drained by `send_loop`.
shaper: Arc<Mutex<ShaperQueue>>,
/// Wakes `send_loop` when writes are queued.
ch_shaper_pending: Arc<Notify>,
/// Buffer allocator used for PSH payloads.
allocator: Arc<Allocator>,
/// Optional local address string; `new` leaves it `None`.
local_addr: Option<String>,
/// Optional remote address string; `new` leaves it `None`.
remote_addr: Option<String>,
}
/// A connection that can be both read from and written to, usable as a trait object.
///
/// Boxed values of this trait are what `Session::new` accepts and splits into halves.
pub trait AsyncReadWrite: AsyncRead + AsyncWrite + Send + Unpin {}

/// Every `Send + Unpin` type implementing both tokio I/O traits qualifies automatically.
impl<T> AsyncReadWrite for T where T: AsyncRead + AsyncWrite + Send + Unpin {}
impl Session {
/// Creates a session over `conn` and starts its background tasks.
///
/// `recv_loop` and `send_loop` are always spawned; `keepalive` is spawned unless
/// `config.keep_alive_disabled` is set. The connection is split into independent read
/// and write halves so the two loops do not contend on one lock.
///
/// # Errors
/// Returns `SmuxError::Io(InvalidInput)` when `config.verify()` rejects the config.
pub(crate) async fn new(
    config: Config,
    conn: Box<dyn AsyncReadWrite + Send + Unpin>,
    is_client: bool,
) -> Result<Arc<Self>, SmuxError> {
    if config.verify().is_err() {
        return Err(SmuxError::Io(io::ErrorKind::InvalidInput));
    }

    // Capture what we need from the config before it is moved into the session.
    let spawn_keepalive = !config.keep_alive_disabled;
    let initial_window = config.max_receive_buffer as i32;

    let (reader, writer) = tokio::io::split(conn);
    let (accept_tx, accept_rx) = mpsc::channel(DEFAULT_ACCEPT_BACKLOG);

    let session = Arc::new(Self {
        conn_read: Arc::new(Mutex::new(Box::new(reader))),
        conn_write: Arc::new(Mutex::new(Box::new(writer))),
        // Clients count from 1 (odd ids), servers from 0 (even ids).
        next_stream_id: Mutex::new(u32::from(is_client)),
        is_client,
        bucket: AtomicI32::new(initial_window),
        bucket_notify: Arc::new(Notify::new()),
        streams: Arc::new(Mutex::new(HashMap::new())),
        die: Arc::new(AtomicBool::new(false)),
        socket_read_error: Arc::new(Mutex::new(None)),
        socket_write_error: Arc::new(Mutex::new(None)),
        ch_socket_read_error: Arc::new(Notify::new()),
        ch_socket_write_error: Arc::new(Notify::new()),
        proto_error: Arc::new(Mutex::new(None)),
        ch_proto_error: Arc::new(Notify::new()),
        ch_accepts: Arc::new(Mutex::new(accept_rx)),
        ch_accepts_tx: accept_tx,
        data_ready: AtomicBool::new(false),
        go_away: AtomicBool::new(false),
        deadline: Arc::new(Mutex::new(None)),
        request_id: AtomicU32::new(0),
        shaper: Arc::new(Mutex::new(ShaperQueue::new())),
        ch_shaper_pending: Arc::new(Notify::new()),
        allocator: Arc::new(Allocator::new()),
        local_addr: None,
        remote_addr: None,
        config,
    });

    // Each loop takes its own strong reference to the session.
    tokio::spawn(Arc::clone(&session).recv_loop());
    tokio::spawn(Arc::clone(&session).send_loop());
    if spawn_keepalive {
        tokio::spawn(Arc::clone(&session).keepalive());
    }

    session.ch_shaper_pending.notify_one();
    Ok(session)
}
/// Opens a new outgoing stream and announces it to the peer with a SYN frame.
///
/// Stream ids advance by two from the counter initialised in `new`, so each side keeps
/// its own parity. When the 32-bit id space wraps around, the session enters the
/// go-away state and no further streams can be opened.
///
/// The stream is registered in the session table *before* the SYN is queued. Frames the
/// peer sends in response therefore always find it. It is unregistered again if
/// queueing the SYN fails.
///
/// # Errors
/// - `SmuxError::Io(BrokenPipe)` if the session is closed.
/// - `SmuxError::GoAway` once stream ids are exhausted.
/// - any error returned while queueing the SYN frame.
pub async fn open_stream(self: &Arc<Self>) -> Result<Stream, SmuxError> {
    if self.is_closed() {
        return Err(SmuxError::Io(io::ErrorKind::BrokenPipe));
    }
    let sid = {
        let mut next_id = self.next_stream_id.lock().await;
        if self.go_away.load(Ordering::Relaxed) {
            return Err(SmuxError::GoAway);
        }
        // `wrapping_add` avoids a debug-build overflow panic. After a wrap the id is 0
        // (server) or 1 (client), which means the id space is exhausted. The old parity
        // test fired for every server-side id.
        *next_id = next_id.wrapping_add(2);
        let sid = *next_id;
        if sid < 2 {
            self.go_away.store(true, Ordering::Relaxed);
            return Err(SmuxError::GoAway);
        }
        sid
    };

    let stream = Arc::new(StreamInner::new(sid, self.config.max_frame_size, Arc::clone(self)));
    self.streams.lock().await.insert(sid, Arc::clone(&stream));

    let syn = Frame::new(self.config.version, Cmd::Syn, sid);
    if let Err(err) = self.write_control_frame(syn).await {
        self.streams.lock().await.remove(&sid);
        return Err(err);
    }
    Ok(Stream::new(stream))
}
pub async fn accept_stream(self: &Arc<Self>) -> Result<Stream, SmuxError> {
let deadline_opt = *self.deadline.lock().await;
loop {
if let Some(err) = self.socket_read_error.lock().await.as_ref() {
return Err(SmuxError::Io(err.kind()));
}
if let Some(err) = self.proto_error.lock().await.as_ref() {
return Err(err.clone());
}
if self.is_closed() {
return Err(SmuxError::Io(io::ErrorKind::BrokenPipe));
}
if let Some(deadline) = deadline_opt {
let now = Instant::now();
if now >= deadline {
return Err(SmuxError::Timeout);
}
let timeout_duration = deadline.duration_since(now);
let mut rx = self.ch_accepts.lock().await;
tokio::select! {
stream = rx.recv() => {
if let Some(stream) = stream {
return Ok(crate::stream::Stream::new(stream));
}
continue;
}
_ = sleep(TokioDuration::from_secs(timeout_duration.as_secs())) => {
return Err(SmuxError::Timeout);
}
_ = self.ch_socket_read_error.notified() => {
continue;
}
_ = self.ch_proto_error.notified() => {
continue;
}
}
} else {
let mut rx = self.ch_accepts.lock().await;
tokio::select! {
stream = rx.recv() => {
if let Some(stream) = stream {
return Ok(crate::stream::Stream::new(stream));
}
continue;
}
_ = self.ch_socket_read_error.notified() => {
continue;
}
_ = self.ch_proto_error.notified() => {
continue;
}
}
};
}
}
pub async fn close(&self) -> io::Result<()> {
if self.die.swap(true, Ordering::Relaxed) {
return Err(io::Error::from(io::ErrorKind::BrokenPipe));
}
let streams = self.streams.lock().await;
for stream in streams.values() {
let _ = stream.close().await;
}
Ok(())
}
/// Returns `true` once `close` has marked the session dead.
pub fn is_closed(&self) -> bool {
    self.die.load(Ordering::Relaxed)
}
/// Number of streams currently registered in the session table.
pub async fn num_streams(&self) -> usize {
    self.streams.lock().await.len()
}
/// Sets, or clears with `None`, the deadline consulted by `accept_stream`.
///
/// `accept_stream` reads the deadline once at entry, so only calls that start after
/// this returns see the new value.
pub async fn set_deadline(&self, deadline: Option<Instant>) {
    let mut slot = self.deadline.lock().await;
    *slot = deadline;
}
/// Local address string, if one has been recorded.
pub fn local_addr(&self) -> Option<&str> {
    self.local_addr.as_deref()
}
/// Remote address string, if one has been recorded.
pub fn remote_addr(&self) -> Option<&str> {
    self.remote_addr.as_deref()
}
/// Records the local address string.
///
/// NOTE(review): this needs `&mut self`, but `new` hands sessions out as `Arc<Self>`.
/// Callers therefore need exclusive access (e.g. `Arc::get_mut`).
pub fn set_local_addr(&mut self, addr: String) {
    self.local_addr = addr.into();
}
/// Records the remote address string.
///
/// NOTE(review): same `&mut self` caveat as `set_local_addr`.
pub fn set_remote_addr(&mut self, addr: String) {
    self.remote_addr = addr.into();
}
pub async fn poll_wait(&self, streams: &[&crate::stream::Stream]) -> Result<usize, SmuxError> {
use tokio::time::{sleep, Duration};
if streams.is_empty() {
return Err(SmuxError::Io(io::ErrorKind::InvalidInput));
}
loop {
if self.is_closed() {
return Err(SmuxError::Io(io::ErrorKind::BrokenPipe));
}
for (idx, stream) in streams.iter().enumerate() {
let inner = &stream.inner;
if inner.has_buffered_data().await {
return Ok(idx);
}
if inner.is_fin().await {
return Ok(idx);
}
}
let (tx, mut rx) = tokio::sync::mpsc::channel(streams.len());
let mut handles = Vec::new();
for (idx, stream) in streams.iter().enumerate() {
let notifier = stream.inner.get_read_event_notifier();
let tx_clone = tx.clone();
let handle = tokio::spawn(async move {
notifier.notified().await;
let _ = tx_clone.send(idx).await;
});
handles.push(handle);
}
drop(tx);
let result = tokio::select! {
Some(idx) = rx.recv() => {
Ok(idx)
}
_ = sleep(Duration::from_millis(50)) => {
continue;
}
_ = self.ch_socket_read_error.notified() => {
Err(SmuxError::Io(io::ErrorKind::BrokenPipe))
}
};
for handle in handles {
handle.abort();
}
if let Ok(idx) = result {
return Ok(idx);
} else if let Err(e) = result {
return Err(e);
}
}
}
pub(crate) fn return_tokens(&self, n: usize) {
let new_bucket = self.bucket.fetch_add(n as i32, Ordering::Relaxed) + n as i32;
if new_bucket > 0 {
self.bucket_notify.notify_one();
}
}
fn notify_bucket(&self) {
self.bucket_notify.notify_one();
}
/// Unregisters stream `sid` and returns any receive-window tokens it still held.
///
/// Unknown ids are ignored. The stream is removed from the table under the lock, but
/// the lock is released before awaiting `recycle_tokens`. Other tasks (including
/// `recv_loop`) are therefore not blocked on the whole stream map while one stream is
/// drained. Once removed, `recv_loop` can no longer push data into it, so no tokens are
/// lost.
pub(crate) async fn stream_closed(&self, sid: u32) {
    let removed = self.streams.lock().await.remove(&sid);
    if let Some(stream) = removed {
        let reclaimed = stream.recycle_tokens().await;
        if reclaimed > 0 {
            self.return_tokens(reclaimed);
        }
    }
}
/// Queues `frame` in the control class, passing `OPEN_CLOSE_TIMEOUT` as its timeout.
///
/// NOTE(review): `write_frame_internal` does not currently use its timeout argument.
async fn write_control_frame(&self, frame: Frame) -> Result<usize, SmuxError> {
    let timeout = Some(OPEN_CLOSE_TIMEOUT);
    self.write_frame_internal(frame, timeout, ClassId::Ctrl)
        .await
}
/// Queues `frame` for sending, in the control class when `is_control` is set and in the
/// data class otherwise. No timeout is passed.
pub(crate) async fn write_frame(&self, frame: Frame, is_control: bool) -> Result<usize, SmuxError> {
    let class = match is_control {
        true => ClassId::Ctrl,
        false => ClassId::Data,
    };
    self.write_frame_internal(frame, None, class).await
}
async fn write_frame_internal(
&self,
frame: Frame,
_timeout: Option<Duration>,
class: ClassId,
) -> Result<usize, SmuxError> {
let seq = self.request_id.fetch_add(1, Ordering::Relaxed);
let data_len = frame.data.len();
let req = WriteRequest {
class,
frame,
seq,
};
{
let mut shaper = self.shaper.lock().await;
shaper.push(req).await;
let len = shaper.len().await;
if class == ClassId::Ctrl || len >= MIN_SHAPER_NOTIFY_SIZE {
self.ch_shaper_pending.notify_one();
}
}
Ok(data_len)
}
/// Background task that reads frames from the connection and dispatches them.
///
/// Each iteration waits while the receive window (`bucket`) is exhausted, unless the
/// session is closed. It then reads a `HEADER_SIZE` header, validates it and handles
/// the command:
/// - `Nop`: nothing beyond marking `data_ready`, as every header does.
/// - `Syn`: registers a stream for an unknown `sid` and queues it for `accept_stream`.
/// - `Fin`: marks the stream as finished.
/// - `Psh`: reads `header.length` payload bytes, pushes them to the stream and charges
///   them against `bucket`.
/// - `Upd`: reads a `UPD_SIZE` window update and forwards it to the stream.
///
/// The task exits when:
/// - the session is closed;
/// - a socket read fails (the error is stored in `socket_read_error`);
/// - a header is malformed or has the wrong version (stored in `proto_error`).
async fn recv_loop(self: Arc<Self>) {
    let mut header_buf = vec![0u8; HEADER_SIZE];
    loop {
        while self.bucket.load(Ordering::Relaxed) <= 0 && !self.is_closed() {
            self.bucket_notify.notified().await;
        }
        if self.is_closed() {
            return;
        }
        if !self.read_exact_or_record(&mut header_buf).await {
            return;
        }
        self.data_ready.store(true, Ordering::Relaxed);

        let header = match RawHeader::from_bytes(&header_buf) {
            Some(h) if h.ver == self.config.version => h,
            _ => {
                self.record_proto_error().await;
                return;
            }
        };
        let sid = header.sid;
        match header.cmd {
            Cmd::Nop => {}
            Cmd::Syn => {
                let accepted = {
                    let mut streams = self.streams.lock().await;
                    if streams.contains_key(&sid) {
                        None
                    } else {
                        let stream = Arc::new(StreamInner::new(
                            sid,
                            self.config.max_frame_size,
                            Arc::clone(&self),
                        ));
                        streams.insert(sid, Arc::clone(&stream));
                        Some(stream)
                    }
                };
                // Hand off outside the table lock. With a full accept backlog this send
                // blocks, and it must not block every other user of the stream map too.
                if let Some(stream) = accepted {
                    let _ = self.ch_accepts_tx.send(stream).await;
                }
            }
            Cmd::Fin => {
                let streams = self.streams.lock().await;
                if let Some(stream) = streams.get(&sid) {
                    stream.fin();
                }
            }
            Cmd::Psh => {
                if header.length == 0 {
                    continue;
                }
                let data_len = header.length as usize;
                let data = match self.allocator.get(data_len).await {
                    Some(mut buf) => {
                        buf.resize(data_len, 0);
                        if !self.read_exact_or_record(&mut buf).await {
                            return;
                        }
                        Bytes::from(buf)
                    }
                    None => {
                        // The payload must still be consumed. Skipping it would make the
                        // next header read start inside this frame's body.
                        let mut buf = vec![0u8; data_len];
                        if !self.read_exact_or_record(&mut buf).await {
                            return;
                        }
                        Bytes::from(buf)
                    }
                };
                let streams = self.streams.lock().await;
                if let Some(stream) = streams.get(&sid) {
                    let written = data.len();
                    stream.push_bytes(data).await;
                    self.bucket.fetch_sub(written as i32, Ordering::Relaxed);
                }
            }
            Cmd::Upd => {
                let mut upd_buf = vec![0u8; UPD_SIZE];
                if !self.read_exact_or_record(&mut upd_buf).await {
                    return;
                }
                let upd = match UpdHeader::from_bytes(&upd_buf) {
                    Some(u) => u,
                    None => continue,
                };
                let streams = self.streams.lock().await;
                if let Some(stream) = streams.get(&sid) {
                    stream.update(upd.consumed, upd.window).await;
                }
            }
        }
    }
}

/// Fills `buf` from the connection's read half.
///
/// On failure, stores the I/O error in `socket_read_error`, wakes its waiters and
/// returns `false`; `recv_loop` then stops.
async fn read_exact_or_record(&self, buf: &mut [u8]) -> bool {
    let result = {
        let mut conn_read = self.conn_read.lock().await;
        conn_read.read_exact(buf).await
    };
    match result {
        Ok(_) => true,
        Err(e) => {
            *self.socket_read_error.lock().await = Some(e);
            self.ch_socket_read_error.notify_waiters();
            false
        }
    }
}

/// Records `SmuxError::InvalidProtocol` and wakes waiters on `ch_proto_error`.
async fn record_proto_error(&self) {
    *self.proto_error.lock().await = Some(SmuxError::InvalidProtocol);
    self.ch_proto_error.notify_waiters();
}
async fn send_loop(self: Arc<Self>) {
loop {
if self.is_closed() {
return;
}
tokio::select! {
_ = self.ch_shaper_pending.notified() => {
}
_ = sleep(TokioDuration::from_millis(10)) => {
}
}
loop {
let req = {
let mut shaper = self.shaper.lock().await;
shaper.pop().await
};
if let Some(req) = req {
let encoded = req.frame.encode();
let write_result = {
let mut conn_write = self.conn_write.lock().await;
conn_write.write_all(&encoded).await
};
match write_result {
Ok(_) => {
}
Err(e) => {
*self.socket_write_error.lock().await = Some(e);
self.ch_socket_write_error.notify_waiters();
return;
}
}
} else {
break;
}
}
}
}
async fn keepalive(self: Arc<Self>) {
let ping_interval = self.config.keep_alive_interval;
let timeout_interval = self.config.keep_alive_timeout;
loop {
sleep(TokioDuration::from_secs(ping_interval.as_secs())).await;
if self.is_closed() {
return;
}
let frame = Frame::new(self.config.version, Cmd::Nop, 0);
let _ = self.write_frame_internal(frame, Some(ping_interval), ClassId::Ctrl).await;
self.notify_bucket();
sleep(TokioDuration::from_secs(
(timeout_interval - ping_interval).as_secs(),
))
.await;
if !self.data_ready.swap(false, Ordering::Relaxed) {
if self.bucket.load(Ordering::Relaxed) > 0 {
let _ = self.close().await;
return;
}
}
}
}
}