use crate::disk::Storage;
use crate::link::local::{LinkError, LinkRx, LinkTx};
use crate::link::network;
use crate::link::network::Network;
use crate::local::Link;
use crate::protocol::{self, Connect, LastWill, Packet, Protocol};
use crate::router::{Event, FilterIdx, Notification};
use crate::{ConnectionId, ConnectionSettings, Offset};
use flume::{Receiver, RecvError, SendError, Sender, TrySendError};
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::Duration;
use std::{fs, io};
use tokio::time::error::Elapsed;
use tokio::{select, time};
use tracing::{error, info, trace};
/// Errors returned by the persistent-session link and its setup helpers.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Underlying I/O failure, e.g. creating the storage directory or reading the first packet.
    #[error("I/O")]
    Io(#[from] io::Error),
    /// The CONNECT packet asked for a keep-alive of 0, which this link rejects.
    #[error("Zero keep alive")]
    ZeroKeepAlive,
    /// The first packet from the client was not CONNECT; carries the packet actually received.
    #[error("Not connect packet")]
    NotConnectPacket(Packet),
    /// Failure reported by the network layer.
    #[error("Network {0}")]
    Network(#[from] network::Error),
    /// A tokio timeout elapsed, e.g. while waiting for the client's first CONNECT.
    #[error("Timeout")]
    Timeout(#[from] Elapsed),
    /// A blocking send on a `(ConnectionId, Event)` channel (the router channel type) failed.
    #[error("Channel send error")]
    Send(#[from] SendError<(ConnectionId, Event)>),
    /// A flume receive failed because every sender was dropped (e.g. the reconnect channel).
    #[error("Channel recv error")]
    Recv(#[from] RecvError),
    /// The client id cannot be used, e.g. a persistent session requested with an empty id.
    #[error("Persistent session requires valid client id")]
    InvalidClientId,
    /// Not constructed in this file; presumably the router replied with something other than
    /// a connection ack — confirm at the construction site.
    #[error("Unexpected router message")]
    NotConnectionAck,
    /// Connection-acknowledgement failure with a description (not constructed in this file).
    #[error("ConnAck error {0}")]
    ConnectionAck(String),
    /// A non-blocking send on a `(ConnectionId, Event)` channel failed (full or disconnected).
    #[error("Channel try send error")]
    TrySend(#[from] TrySendError<(ConnectionId, Event)>),
    /// Error from the router link (`Link::new` or `LinkTx`/`LinkRx` operations).
    #[error("Link error = {0}")]
    Link(#[from] LinkError),
}
/// Connection handler for a client whose publishes are spooled to disk, so that messages
/// routed while the client is disconnected are kept and delivered after it reconnects.
pub struct PersistanceLink<P: Protocol> {
    /// Client id taken from the CONNECT packet.
    pub(crate) client_id: String,
    /// Id assigned to this connection by the router link (`link_rx.id()`).
    pub(crate) connection_id: ConnectionId,
    /// Current client connection; replaced whenever a new one arrives on `network_update_rx`.
    pub(crate) network: Network<P>,
    /// Sending half of the router link (client packets, read-progress acks).
    link_tx: LinkTx,
    /// Receiving half of the router link; fills `notifications` via `exchange`.
    link_rx: LinkRx,
    /// Scratch buffer for notifications received from the router.
    notifications: VecDeque<Notification>,
    /// On-disk queue that publishes pass through before reaching the client.
    disk_handler: DiskHandler<P>,
    /// Delivers a replacement network connection when the client reconnects.
    network_update_rx: Receiver<Network<P>>,
    /// The CONNACK produced at link creation; re-sent on every reconnect.
    connack: Notification,
    /// Staging queue of publish notifications waiting to be written to disk.
    inflight_publishes: VecDeque<Notification>,
}
/// On-disk packet queue for a single persistent client.
pub(super) struct DiskHandler<P: Protocol> {
    /// Backing storage that packets are written to and read back from.
    storage: Storage,
    /// Protocol codec used to encode packets into, and decode them out of, `storage`.
    protocol: P,
}
impl<P: Protocol> DiskHandler<P> {
fn new(client_id: &str, protocol: P) -> Result<Self, Error> {
let path = format!("/tmp/rumqttd/{}", &client_id);
fs::create_dir_all(&path)?;
Ok(DiskHandler {
storage: Storage::new(path, 1024, 30)?,
protocol,
})
}
pub fn write(
&mut self,
notifications: &mut VecDeque<Notification>,
) -> HashMap<FilterIdx, Offset> {
let mut stored_filter_offset_map: HashMap<FilterIdx, Offset> = HashMap::new();
for notif in notifications.drain(..) {
let packet_or_unscheduled = notif.clone().into();
if let Some(packet) = packet_or_unscheduled {
if let Err(e) = self.protocol.write(packet, self.storage.writer()) {
error!("Failed to write to storage: {e}");
continue;
}
if let Err(e) = self.storage.flush_on_overflow() {
error!("Failed to flush storage: {e}");
continue;
}
match ¬if {
Notification::Forward(forward) => {
stored_filter_offset_map
.entry(forward.filter_idx)
.and_modify(|cursor| {
if forward.next_cursor > *cursor {
*cursor = forward.next_cursor
}
})
.or_insert(forward.next_cursor);
}
_ => continue,
}
}
}
stored_filter_offset_map
}
pub fn read(&mut self, buffer: &mut VecDeque<Packet>) {
if let Err(e) = self.storage.reload_on_eof() {
error!("Failed to reload storage: {e}");
}
loop {
match self
.protocol
.read_mut(self.storage.reader(), 10240) {
Ok(packet) => {
buffer.push_back(packet);
let connection_buffer_length = buffer.len();
if connection_buffer_length >= 100 {
return
}
}
Err(protocol::Error::InsufficientBytes(_)) => {
if let Err(e) = self.storage.reload() {
error!("Failed to reload storage: {e}");
return
}
if self.storage.reader().is_empty() {
return
}
},
Err(e) => {
error!("Failed to read from storage: {e}");
return
}
}
}
}
}
impl<P: Protocol> PersistanceLink<P> {
pub async fn new(
config: Arc<ConnectionSettings>,
router_tx: Sender<(ConnectionId, Event)>,
tenant_id: Option<String>,
connect: Connect,
lastwill: Option<LastWill>,
mut network: Network<P>,
) -> Result<(Sender<Network<P>>, PersistanceLink<P>), Error> {
let dynamic_filters = config.dynamic_filters;
if connect.keep_alive == 0 {
return Err(Error::ZeroKeepAlive);
}
let client_id = connect.client_id.clone();
let clean_session = connect.clean_session;
if !clean_session && client_id.is_empty() {
return Err(Error::InvalidClientId);
}
let (link_tx, link_rx, notification) = Link::new(
tenant_id,
&client_id,
router_tx,
clean_session,
lastwill,
dynamic_filters,
true,
)?;
let id = link_rx.id();
network.write(notification.clone()).await?;
let protocol = network.protocol.clone();
let (network_update_tx, network_update_rx) = flume::bounded(1);
Ok((
network_update_tx,
PersistanceLink {
client_id: client_id.to_string(),
connection_id: id,
network,
link_tx,
link_rx,
notifications: VecDeque::with_capacity(100),
disk_handler: DiskHandler::new(&client_id, protocol)?,
network_update_rx,
connack: notification,
inflight_publishes: VecDeque::with_capacity(100),
},
))
}
pub async fn peek_first_connect(
config: Arc<ConnectionSettings>,
network: &mut Network<P>,
) -> Result<(Connect, Option<LastWill>), Error> {
let connection_timeout_ms = config.connection_timeout_ms.into();
let packet = time::timeout(Duration::from_millis(connection_timeout_ms), async {
let packet = network.read().await?;
Ok::<_, io::Error>(packet)
})
.await??;
let (connect, lastwill) = match packet {
Packet::Connect(connect, _, lastwill, ..) => (connect, lastwill),
packet => return Err(Error::NotConnectPacket(packet)),
};
Ok((connect, lastwill))
}
async fn disconnected(&mut self) -> Result<State, Error> {
info!(state = ?State::Disconnected, "Disconnected from persistent connection");
loop {
select! {
network = self.network_update_rx.recv_async() => {
self.network = network?;
self.network.write(self.connack.clone()).await?;
return Ok(State::Normal)
},
o = self.link_rx.exchange(&mut self.notifications) => {
o?;
self.write_to_disconnected_client().await?;
}
}
}
}
async fn write_to_disconnected_client(&mut self) -> Result<(), Error> {
for notif in self.notifications.drain(..) {
match notif {
Notification::Forward(_) | Notification::ForwardWithProperties(_, _) => {
self.inflight_publishes.push_back(notif)
}
_ => continue,
}
}
if !self.inflight_publishes.is_empty() {
let acked_offsets = self.disk_handler.write(&mut self.inflight_publishes);
if let Err(e) = self.link_tx.ack(acked_offsets).await {
error!("Failed to inform router of read progress: {e}")
};
}
Ok(())
}
async fn write_to_active_client(&mut self) -> Result<(), Error> {
let mut unpersisted_messages = VecDeque::new();
for notif in self.notifications.drain(..) {
match notif {
Notification::Forward(_) | Notification::ForwardWithProperties(_, _) => {
self.inflight_publishes.push_back(notif)
}
Notification::AckDone => {
continue;
}
_ => unpersisted_messages.push_back(notif),
}
}
let unscheduled = self.network.writev(&mut unpersisted_messages).await?;
if unscheduled {
self.link_rx.wake().await?;
};
if !self.inflight_publishes.is_empty() {
let acked_offsets = self.disk_handler.write(&mut self.inflight_publishes);
if let Err(e) = self.link_tx.ack(acked_offsets).await {
error!("Failed to inform router of read progress: {e}")
};
}
let mut buffer = VecDeque::new();
self.disk_handler.read(&mut buffer);
let unscheduled = self.network.writev(&mut buffer).await?;
if unscheduled {
self.link_rx.wake().await?;
};
Ok(())
}
async fn read_from_client(&mut self, packet: Packet) -> Result<(), Error> {
let len = {
let mut buffer = self.link_tx.buffer();
buffer.push_back(packet);
self.network.readv(&mut buffer)?;
buffer.len()
};
trace!("Packets read from network, count = {}", len);
self.link_tx.notify().await?;
Ok(())
}
async fn run(&mut self) -> Result<State, Error> {
info!(state = ?State::Normal, "Persistent connection is running in normal mode");
loop {
select! {
o = self.network.read() => {
match o {
Ok(packet) => self.read_from_client(packet).await?,
Err(e) => {
println!("some error while reading from the network? {e:?}");
match e.kind() {
io::ErrorKind::ConnectionAborted | io::ErrorKind::ConnectionReset => return Ok(State::Disconnected),
_ => return Err(e.into())
}
}
};
}
o = self.link_rx.exchange(&mut self.notifications) => {
o?;
self.write_to_active_client().await?;
}
}
}
}
pub async fn start(&mut self) -> Result<(), Error> {
let mut state = State::Normal;
loop {
let next = match state {
State::Normal => self.run().await?,
State::Disconnected => self.disconnected().await?,
};
state = next;
}
}
}
/// Connection state of a persistent link.
#[derive(Debug)]
pub enum State {
    /// A client network connection is attached; handled by `PersistanceLink::run`.
    Normal,
    /// No client is attached; publishes are persisted to disk until a reconnect
    /// (handled by `PersistanceLink::disconnected`).
    Disconnected,
}