use std::marker::PhantomData;
use std::sync::Arc;
use tokio::sync::mpsc;
use crate::envelope::Envelope;
use crate::error::RelayError;
use crate::tracker::CompletionTracker;
/// A typed receiving end over a channel of [`Envelope`]s.
///
/// Only envelopes whose payload downcasts to `T` are yielded by
/// [`recv`](Subscription::recv) and [`try_recv`](Subscription::try_recv);
/// envelopes of other types are skipped.
pub struct Subscription<T> {
    /// Channel the envelopes are read from.
    rx: mpsc::Receiver<Envelope>,
    /// Tracker attached to the most recently yielded message, if any.
    current_tracker: Option<Arc<CompletionTracker>>,
    /// Id of the most recently yielded message, if any.
    current_msg_id: Option<u64>,
    /// When true, skipped (non-`T`) envelopes have their tracker completed,
    /// and dropping the subscription while a tracker is held fails it.
    tracked: bool,
    /// Ties the subscription to `T` without storing a `T`.
    _marker: PhantomData<T>,
}
impl<T: 'static + Send + Sync> Subscription<T> {
    /// Creates an untracked subscription reading envelopes from `rx`.
    ///
    /// Envelopes whose payload is not a `T` are dropped without touching
    /// their tracker.
    pub(crate) fn new(rx: mpsc::Receiver<Envelope>) -> Self {
        Self::with_tracking(rx, false)
    }

    /// Creates a tracked subscription reading envelopes from `rx`.
    ///
    /// Envelopes whose payload is not a `T` have `complete_one()` called on
    /// their tracker (if they carry one) as they are skipped. Dropping the
    /// subscription while it holds a current tracker fails that tracker.
    pub(crate) fn new_tracked(rx: mpsc::Receiver<Envelope>) -> Self {
        Self::with_tracking(rx, true)
    }

    /// Shared constructor behind [`new`](Self::new) and
    /// [`new_tracked`](Self::new_tracked).
    fn with_tracking(rx: mpsc::Receiver<Envelope>, tracked: bool) -> Self {
        Self {
            rx,
            current_tracker: None,
            current_msg_id: None,
            tracked,
            _marker: PhantomData,
        }
    }

    /// Returns a clone of the tracker attached to the most recently received
    /// message, if any.
    pub fn current_tracker(&self) -> Option<Arc<CompletionTracker>> {
        self.current_tracker.clone()
    }

    /// Returns the id of the most recently received message, if any.
    pub fn current_msg_id(&self) -> Option<u64> {
        self.current_msg_id
    }

    /// Forgets the current message's tracker and id.
    ///
    /// In tracked mode this also means a later drop of the subscription will
    /// not fail that tracker.
    pub fn clear_tracker(&mut self) {
        self.current_tracker = None;
        self.current_msg_id = None;
    }

    /// Waits for the next message whose payload is a `T`.
    ///
    /// The current tracker and id are cleared first. Envelopes of other types
    /// are skipped (see [`accept_envelope`](Self::accept_envelope)).
    /// Returns `None` once the channel is closed and no buffered envelopes remain.
    ///
    /// NOTE(review): the previous message's tracker is discarded here without
    /// being completed or failed. Confirm that callers complete it before
    /// calling `recv` again.
    pub async fn recv(&mut self) -> Option<Arc<T>> {
        self.clear_tracker();
        while let Some(env) = self.rx.recv().await {
            if let Some(value) = self.accept_envelope(env) {
                return Some(value);
            }
        }
        None
    }

    /// Non-blocking counterpart of [`recv`](Self::recv).
    ///
    /// Drains immediately available envelopes until one carries a `T`.
    /// Returns `None` when the channel is empty or disconnected; the two cases
    /// are not distinguished.
    pub fn try_recv(&mut self) -> Option<Arc<T>> {
        self.clear_tracker();
        while let Ok(env) = self.rx.try_recv() {
            if let Some(value) = self.accept_envelope(env) {
                return Some(value);
            }
        }
        None
    }

    /// Processes a single envelope for both receive paths.
    ///
    /// If the payload is a `T`, records the envelope's tracker and message id
    /// as current and returns the value. Otherwise the envelope is skipped.
    /// In tracked mode its tracker (if any) gets `complete_one()`, so that
    /// skipped messages still count as handled by this subscriber.
    fn accept_envelope(&mut self, env: Envelope) -> Option<Arc<T>> {
        if let Some(value) = env.downcast::<T>() {
            self.current_tracker = env.tracker();
            self.current_msg_id = Some(env.msg_id());
            return Some(value);
        }
        if self.tracked {
            if let Some(tracker) = env.tracker() {
                tracker.complete_one();
            }
        }
        None
    }
}
/// In tracked mode, a subscription dropped while it still holds the tracker
/// of a received message fails that tracker with an `Interrupted` error.
impl<T> Drop for Subscription<T> {
    fn drop(&mut self) {
        // Untracked subscriptions never report failures.
        if !self.tracked {
            return;
        }
        if let Some(pending) = self.current_tracker.take() {
            // Fall back to 0 when no message id was recorded.
            let msg_id = self.current_msg_id.unwrap_or(0);
            let cause = std::io::Error::new(
                std::io::ErrorKind::Interrupted,
                "subscription dropped while processing message",
            );
            pending.fail(RelayError::new(msg_id, cause, "subscription"));
        }
    }
}
/// Debug output shows the payload type name, the tracking mode, and whether a
/// current tracker is held. The channel and message id are not shown.
impl<T> std::fmt::Debug for Subscription<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let payload_type = std::any::type_name::<T>();
        let holds_tracker = self.current_tracker.is_some();

        let mut out = f.debug_struct("Subscription");
        out.field("type", &payload_type);
        out.field("tracked", &self.tracked);
        out.field("has_current_tracker", &holds_tracker);
        out.finish()
    }
}