use bytes::Bytes;
use parking_lot::Mutex;
use std::future::Future;
use std::sync::{Arc, Weak};
use std::task::{Context, Poll};
use tokio::sync::{mpsc, oneshot};
use crate::error::{Result, Error};
use super::channel_state::{self, ChannelState, ChannelSendData};
use super::client::Client;
use super::client_state::ClientState;
/// Handle to a channel belonging to a [`Client`] (see [`Channel::client`]).
///
/// The handle holds only weak references to the shared client and channel
/// state, so cloning it is cheap and it keeps neither alive. Most operations
/// fail with [`Error::ClientClosed`] or [`Error::ChannelClosed`] once the
/// corresponding state has been dropped.
// `Debug` is derivable unconditionally: `Weak<T>` implements `Debug` for any `T`.
#[derive(Clone, Debug)]
pub struct Channel {
    pub(super) client_st: Weak<Mutex<ClientState>>,
    pub(super) channel_st: Weak<Mutex<ChannelState>>,
}
impl Channel {
fn upgrade_client(&self) -> Result<Arc<Mutex<ClientState>>> {
self.client_st.upgrade().ok_or(Error::ClientClosed)
}
fn upgrade_channel(&self) -> Result<Arc<Mutex<ChannelState>>> {
self.channel_st.upgrade().ok_or(Error::ChannelClosed)
}
pub fn client(&self) -> Client {
Client { client_st: self.client_st.clone() }
}
pub fn send_request(&self, req: ChannelReq) -> Result<()> {
let st = self.upgrade_client()?;
let channel_st = self.upgrade_channel()?;
channel_state::send_request(&mut st.lock(), &mut channel_st.lock(), req)?;
Ok(())
}
pub async fn send_data(&self, data: Bytes, data_type: DataType) -> Result<()> {
self.send_channel_data(ChannelSendData::Data(data, data_type))?.await
}
pub async fn send_eof(&self) -> Result<()> {
match self.try_send_eof().await {
Ok(_) => Ok(()),
Err(Error::ChannelClosed) => Ok(()),
Err(err) => Err(err),
}
}
async fn try_send_eof(&self) -> Result<()> {
self.send_channel_data(ChannelSendData::Eof)?.await
}
pub fn close(&self) -> Result<()> {
let st = self.upgrade_client()?;
if let Ok(channel_st) = self.upgrade_channel() {
channel_state::close(&mut st.lock(), &mut channel_st.lock());
}
Ok(())
}
fn send_channel_data(&self, data: ChannelSendData) -> Result<impl Future<Output = Result<()>>> {
let st = self.upgrade_client()?;
let channel_st = self.upgrade_channel()?;
let fut = channel_state::send_data(&mut st.lock(), &mut channel_st.lock(), data)?;
Ok(fut)
}
}
/// Receiving side for the events of a channel; yields [`ChannelEvent`]s.
#[derive(Debug)]
pub struct ChannelReceiver {
    /// Incoming event queue. Presumably fed by the client's internal state
    /// machine; confirm against the sender side in `channel_state`.
    pub(super) event_rx: mpsc::Receiver<ChannelEvent>,
}
impl ChannelReceiver {
    /// Waits for the next [`ChannelEvent`].
    ///
    /// Returns `None` exactly when the underlying tokio `mpsc::Receiver::recv`
    /// does, i.e. once the event queue is closed and drained.
    pub async fn recv(&mut self) -> Option<ChannelEvent> {
        let next_event = self.event_rx.recv();
        next_event.await
    }

    /// Polling form of [`recv`](Self::recv) for hand-written `poll` code;
    /// delegates to `mpsc::Receiver::poll_recv` and registers `cx`'s waker when pending.
    pub fn poll_recv(&mut self, cx: &mut Context) -> Poll<Option<ChannelEvent>> {
        let receiver = &mut self.event_rx;
        receiver.poll_recv(cx)
    }
}
/// An event delivered on a channel, as yielded by [`ChannelReceiver::recv`].
///
/// `#[non_exhaustive]`: code outside this crate must include a wildcard arm
/// when matching.
// `Debug` was missing even though every payload type (`ChannelReq`, `Bytes`,
// `DataType`) implements it, and the other public types in this module derive it.
#[derive(Debug)]
#[non_exhaustive]
pub enum ChannelEvent {
    /// A channel request (see [`ChannelReq`]).
    Request(ChannelReq),
    /// A chunk of channel data, tagged with its [`DataType`].
    Data(Bytes, DataType),
    /// End-of-file was signalled on the channel.
    Eof,
}
/// A channel request: its type name, its payload, and an optional reply slot.
#[derive(Debug)]
pub struct ChannelReq {
    /// Name identifying the kind of request.
    pub request_type: String,
    /// Request-specific payload, kept as raw bytes.
    pub payload: Bytes,
    /// When `Some`, the [`ChannelReply`] sent through this sender answers the
    /// request. `None` presumably means no reply is expected; confirm in `channel_state`.
    pub reply_tx: Option<oneshot::Sender<ChannelReply>>,
}
/// Outcome of a channel request, sent back through `ChannelReq::reply_tx`.
// A fieldless enum: derive the full value-type set (as `DataType` does) so
// replies can be copied, compared with `==` and used as hash keys.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum ChannelReply {
    /// The request succeeded.
    Success,
    /// The request failed.
    Failure,
}
/// Which stream a piece of channel data belongs to: the standard stream, or
/// an extended stream identified by a numeric type code.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DataType {
    /// Ordinary channel data.
    Standard,
    /// Extended data carrying the given type code.
    Extended(u32),
}

/// Shorthand for [`DataType::Standard`].
pub const DATA_STANDARD: DataType = DataType::Standard;

/// Extended data with type code 1, which RFC 4254 (section 5.2) names
/// `SSH_EXTENDED_DATA_STDERR`.
pub const DATA_STDERR: DataType = DataType::Extended(1);
/// Limits applied to the receiving side of a channel.
///
/// Because the struct is `#[non_exhaustive]`, code outside this crate starts
/// from [`ChannelConfig::default`] and adjusts the public fields.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct ChannelConfig {
    /// Requested maximum receive window (presumably in bytes). The accessor
    /// of the same name clamps it to `1000..=u32::MAX`.
    pub recv_window_max: usize,
    /// Requested maximum length of a received packet (presumably in bytes).
    /// The accessor of the same name clamps it to `200..=u32::MAX`.
    pub recv_packet_len_max: usize,
}

impl Default for ChannelConfig {
    /// Receive window of 500 000 and maximum packet length of 100 000.
    fn default() -> Self {
        Self {
            recv_packet_len_max: 100_000,
            recv_window_max: 500_000,
        }
    }
}
impl ChannelConfig {
    /// Applies `f` to this configuration and returns the modified value, for
    /// builder-style chaining such as
    /// `ChannelConfig::default().with(|c| c.recv_window_max = 1_000_000)`.
    // `with` consumes `self`, so discarding its result silently drops the change.
    #[must_use = "`with` returns the modified configuration; discarding it loses the change"]
    pub fn with<F: FnOnce(&mut Self)>(mut self, f: F) -> Self {
        f(&mut self);
        self
    }

    /// Effective receive-window limit: `recv_window_max` clamped to `1000..=u32::MAX`.
    ///
    /// The upper bound presumably exists because the value must fit a 32-bit
    /// protocol field. `u32::MAX as usize` is lossless on 32/64-bit targets; on
    /// narrower targets truncating the all-ones value yields `usize::MAX`, so the
    /// bound still means "no larger than representable".
    pub(super) fn recv_window_max(&self) -> usize {
        self.recv_window_max.clamp(1000, u32::MAX as usize)
    }

    /// Effective maximum received packet length: `recv_packet_len_max` clamped
    /// to `200..=u32::MAX` (same cast reasoning as `recv_window_max`).
    pub(super) fn recv_packet_len_max(&self) -> usize {
        self.recv_packet_len_max.clamp(200, u32::MAX as usize)
    }
}