use std::future::Ready;
use crate::{client::MqttClient, error::MqttError, MqttPacket};
use futures::Stream;
use mqtt_format::v3::{
packet::{MPacket, MPublish, MPubrel},
qos::MQualityOfService,
};
use tracing::{debug, error, trace};
/// Proof value returned by an [`AckHandler`] future once the handler has finished
/// with a packet.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Acknowledge;
/// Builder for a [`PacketStream`] that borrows an [`MqttClient`].
///
/// `H` is the acknowledgement handler that gets moved into the built stream.
pub struct PacketStreamBuilder<'client, H> {
    /// Client whose connection the stream will read from and acknowledge on.
    client: &'client MqttClient,
    /// Handler invoked for packets that the stream acknowledges.
    ack_fn: H,
}
/// Hook that a [`PacketStream`] invokes with a copy of each packet it acknowledges.
///
/// Implemented by [`NoOPAck`] and, through a blanket impl, by any `Send` closure
/// `Fn(MqttPacket) -> Fut` where `Fut` resolves to [`Acknowledge`].
pub trait AckHandler: Send {
    /// Future produced by [`AckHandler::handle`]; it is `Send` and yields [`Acknowledge`].
    type Future: std::future::Future<Output = Acknowledge> + Send;

    /// Creates the future that processes `packet`.
    fn handle(&self, packet: MqttPacket) -> Self::Future;
}
/// Lets any `Send` function or closure `Fn(MqttPacket) -> FUT` act as an [`AckHandler`].
///
/// The original bound was written `for<'s> Fn(MqttPacket) -> FUT`. The binder `'s`
/// never appeared in the signature, so it was dead syntax and has been dropped.
impl<FUT, H> AckHandler for H
where
    FUT: std::future::Future<Output = Acknowledge> + Send,
    H: Send + Fn(MqttPacket) -> FUT,
{
    type Future = FUT;

    /// Calls the wrapped function with `packet` and returns its future unchanged.
    fn handle(&self, packet: MqttPacket) -> Self::Future {
        self(packet)
    }
}
/// Default [`AckHandler`] that ignores every packet and acknowledges immediately.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NoOPAck;
/// Ignores the packet; the returned future is already resolved to [`Acknowledge`].
impl AckHandler for NoOPAck {
    type Future = Ready<Acknowledge>;

    fn handle(&self, _: MqttPacket) -> Self::Future {
        let ack = Acknowledge;
        std::future::ready(ack)
    }
}
impl<'client, HANDLER> PacketStreamBuilder<'client, HANDLER>
where
    HANDLER: AckHandler,
{
    /// Creates a builder for `client` that uses [`NoOPAck`] as its handler.
    ///
    /// `HANDLER` plays no role here: the returned builder is always
    /// `PacketStreamBuilder<'client, NoOPAck>`.
    pub(crate) fn new(client: &'client MqttClient) -> PacketStreamBuilder<'client, NoOPAck> {
        PacketStreamBuilder {
            client,
            ack_fn: NoOPAck,
        }
    }

    /// Replaces the current handler with `f`; the previous handler is dropped.
    ///
    /// The concrete handler type is returned, not an opaque `impl AckHandler`. This
    /// lets callers name the resulting builder/stream type, e.g. to store it in a
    /// struct field.
    #[must_use]
    pub fn with_custom_ack_fn<NEWHANDLER: AckHandler>(
        self,
        f: NEWHANDLER,
    ) -> PacketStreamBuilder<'client, NEWHANDLER> {
        PacketStreamBuilder {
            client: self.client,
            ack_fn: f,
        }
    }

    /// Finishes the builder, moving the client reference and handler into a [`PacketStream`].
    #[must_use]
    pub fn build(self) -> PacketStream<'client, HANDLER> {
        PacketStream {
            client: self.client,
            ack_fn: self.ack_fn,
        }
    }
}
/// Stream of incoming packets for a borrowed [`MqttClient`], produced by
/// [`PacketStreamBuilder::build`].
///
/// `A` is the [`AckHandler`] used for packets that require acknowledgement.
pub struct PacketStream<'client, A: AckHandler> {
    /// Client providing both the receiving and the sending half of the connection.
    client: &'client MqttClient,
    /// Handler invoked for packets this stream acknowledges.
    ack_fn: A,
}
impl<'client, ACK: AckHandler> PacketStream<'client, ACK> {
pub fn stream(&self) -> impl Stream<Item = Result<MqttPacket, MqttError>> + '_ {
futures::stream::try_unfold((), |()| async {
let client = self.client;
loop {
let next_message = {
let mut mutex = client.client_receiver().lock().await;
let client_stream = match mutex.as_mut() {
Some(cs) => cs,
None => return Err(MqttError::ConnectionClosed),
};
crate::read_one_packet(client_stream).await?
};
let packet = next_message.get_packet();
match packet {
MPacket::Publish(MPublish {
qos, id: Some(id), ..
}) => {
match qos {
MQualityOfService::AtMostOnce => {}
MQualityOfService::AtLeastOnce => {
self.ack_fn.handle(next_message.clone());
let mut mutex = client.client_sender().lock().await;
let client_stream = match mutex.as_mut() {
Some(cs) => cs,
None => return Err(MqttError::ConnectionClosed),
};
MqttClient::acknowledge_packet(client_stream, packet).await?;
}
MQualityOfService::ExactlyOnce => {
if client.received_packets().contains(&id.0) {
debug!(?packet, "Received duplicate packet");
continue;
}
self.ack_fn.handle(next_message.clone());
trace!(?packet, "Inserting packet into received");
client.received_packets().insert(id.0);
let mut mutex = client.client_sender().lock().await;
let client_stream = match mutex.as_mut() {
Some(cs) => cs,
None => return Err(MqttError::ConnectionClosed),
};
MqttClient::acknowledge_packet(client_stream, packet).await?;
}
}
}
MPacket::Pubrel(MPubrel { id }) => {
if client.received_packets().contains(&id.0) {
self.ack_fn.handle(next_message.clone());
let mut mutex = client.client_sender().lock().await;
let client_stream = match mutex.as_mut() {
Some(cs) => cs,
None => return Err(MqttError::ConnectionClosed),
};
MqttClient::acknowledge_packet(client_stream, packet).await?;
} else {
error!("Received a pubrel for a packet we did not expect");
return Ok(None);
}
}
_ => (),
}
return Ok(Some((next_message, ())));
}
})
}
}
#[cfg(test)]
mod tests {
    use futures::StreamExt;

    use crate::{client::MqttClient, packet_stream::Acknowledge, MqttPacket};

    /// Compile-only check that a closure ack handler plugs into the builder and
    /// that the resulting stream yields `MqttPacket`s.
    ///
    /// It has no `#[test]` attribute, so it is type-checked but never executed.
    // `todo!()` stands in for a real client, so everything after it is unreachable.
    // The lints are silenced because only type-checking matters here.
    #[allow(unreachable_code, unused, clippy::diverging_sub_expression)]
    async fn check_making_stream_builder() {
        let client: MqttClient = todo!();

        let stream_source = client
            .build_packet_stream()
            .with_custom_ack_fn(|packet| async move {
                println!("ACKing packet {packet:?}");
                Acknowledge
            })
            .build();

        let mut packets = Box::pin(stream_source.stream());
        loop {
            while let Some(Ok(packet)) = packets.next().await {
                // Type ascription: proves the stream item is an `MqttPacket`.
                let packet: MqttPacket = packet;
                println!("Received: {packet:#?}");
            }
        }
    }
}