use std::fmt;
use std::sync::Arc;
use std::time::Duration;
/// Details handed to the connect callback (see `ConnectHook`).
///
/// Field meanings below are inferred from the field names and this module's tests.
/// NOTE(review): confirm against the code that constructs these values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConnectInfo {
    /// API version string for the connection (this module's tests use `"v2"`).
    pub api_version: String,
    /// Numeric identifier of the connection; presumably unique per connection — confirm.
    pub connection_id: u64,
    /// Whether this connect follows an earlier connection (presumably set by reconnect logic).
    pub is_reconnection: bool,
}
/// Reason reported to the disconnect callback (see `DisconnectHook`).
///
/// Variant meanings are inferred from their names.
/// NOTE(review): confirm against the code that produces each variant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DisconnectInfo {
    /// The remote side closed the connection.
    ServerClosed,
    /// A network-level failure; carries a human-readable description.
    NetworkError(String),
    /// An operation timed out (presumably connect or I/O — confirm).
    Timeout,
    /// The connection was shut down (presumably by a local request — confirm).
    Shutdown,
    /// Authentication was rejected.
    AuthFailed,
    /// Heartbeats stopped arriving within the expected window.
    HeartbeatTimeout,
}
/// Describes a subscription event, handed to the subscription callback (see `SubscriptionHook`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SubscriptionInfo {
    /// Name of the channel the subscription targets.
    pub channel: String,
    /// Symbols included in the subscription.
    pub symbols: Vec<String>,
    /// Whether the subscription was accepted.
    pub accepted: bool,
    /// Optional explanation; presumably populated when `accepted` is false — confirm.
    pub reason: Option<String>,
}
/// Details of a checksum mismatch, handed to the checksum callback (see `ChecksumHook`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChecksumInfo {
    /// Symbol whose checksum did not match.
    pub symbol: String,
    /// Checksum value that was expected (presumably supplied by the server — confirm).
    pub expected: u32,
    /// Checksum value that was computed (presumably from local state — confirm).
    pub computed: u32,
}
// Shared callback types stored in `Hooks`.
//
// Every alias has the same shape: a reference-counted closure that may be shared
// across threads (`Send + Sync`) and holds no non-`'static` borrows. `'static` is
// already the default object lifetime for `Arc<dyn ...>`; it is written out here only
// to make that requirement visible.

/// Callback invoked with a [`ConnectInfo`].
pub type ConnectHook = Arc<dyn Fn(&ConnectInfo) + Send + Sync + 'static>;
/// Callback invoked with a [`DisconnectInfo`].
pub type DisconnectHook = Arc<dyn Fn(&DisconnectInfo) + Send + Sync + 'static>;
/// Callback invoked with an attempt counter and a delay
/// (named `attempt` and `delay` by `Hooks::invoke_reconnect_attempt`).
pub type ReconnectAttemptHook = Arc<dyn Fn(u32, Duration) + Send + Sync + 'static>;
/// Callback invoked with a [`SubscriptionInfo`].
pub type SubscriptionHook = Arc<dyn Fn(&SubscriptionInfo) + Send + Sync + 'static>;
/// Callback invoked with a [`ChecksumInfo`].
pub type ChecksumHook = Arc<dyn Fn(&ChecksumInfo) + Send + Sync + 'static>;
/// Callback invoked with a size value (named `size` by `Hooks::invoke_message`;
/// presumably the message length in bytes — confirm).
pub type MessageHook = Arc<dyn Fn(usize) + Send + Sync + 'static>;
/// Callback invoked with an error message string.
pub type ErrorHook = Arc<dyn Fn(&str) + Send + Sync + 'static>;
/// Optional user callbacks for connection lifecycle and message events.
///
/// Each slot holds at most one callback; `None` means nothing is registered for that
/// event, and the matching `invoke_*` method does nothing. Callbacks are stored behind
/// `Arc`, so cloning a `Hooks` shares the registered closures rather than copying them.
/// Build one with `Hooks::new()` followed by the `on_*` builder methods.
pub struct Hooks {
/// Set by `on_connect`, called by `invoke_connect`.
pub(crate) on_connect: Option<ConnectHook>,
/// Set by `on_disconnect`, called by `invoke_disconnect`.
pub(crate) on_disconnect: Option<DisconnectHook>,
/// Set by `on_reconnect_attempt`, called by `invoke_reconnect_attempt`.
pub(crate) on_reconnect_attempt: Option<ReconnectAttemptHook>,
/// Set by `on_subscription`, called by `invoke_subscription`.
pub(crate) on_subscription: Option<SubscriptionHook>,
/// Set by `on_checksum_mismatch`, called by `invoke_checksum_mismatch`.
pub(crate) on_checksum_mismatch: Option<ChecksumHook>,
/// Set by `on_message`, called by `invoke_message`.
pub(crate) on_message: Option<MessageHook>,
/// Set by `on_error`, called by `invoke_error`.
pub(crate) on_error: Option<ErrorHook>,
}
impl Default for Hooks {
fn default() -> Self {
Self::new()
}
}
impl fmt::Debug for Hooks {
    /// Shows which hooks are registered without printing the closures.
    ///
    /// A registered slot renders as `Some("...")` and an empty slot as `None`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Closures do not implement `Debug`, so each slot is reduced to a presence marker.
        fn marker<T>(slot: &Option<T>) -> Option<&'static str> {
            slot.as_ref().and(Some("..."))
        }

        let mut out = f.debug_struct("Hooks");
        out.field("on_connect", &marker(&self.on_connect));
        out.field("on_disconnect", &marker(&self.on_disconnect));
        out.field("on_reconnect_attempt", &marker(&self.on_reconnect_attempt));
        out.field("on_subscription", &marker(&self.on_subscription));
        out.field("on_checksum_mismatch", &marker(&self.on_checksum_mismatch));
        out.field("on_message", &marker(&self.on_message));
        out.field("on_error", &marker(&self.on_error));
        out.finish()
    }
}
impl Clone for Hooks {
fn clone(&self) -> Self {
Self {
on_connect: self.on_connect.clone(),
on_disconnect: self.on_disconnect.clone(),
on_reconnect_attempt: self.on_reconnect_attempt.clone(),
on_subscription: self.on_subscription.clone(),
on_checksum_mismatch: self.on_checksum_mismatch.clone(),
on_message: self.on_message.clone(),
on_error: self.on_error.clone(),
}
}
}
impl Hooks {
    /// Creates a `Hooks` with no callbacks registered.
    ///
    /// Every `invoke_*` method is a no-op until the matching `on_*` builder has been
    /// called. `const` so an empty hook set can be built in constant contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            on_connect: None,
            on_disconnect: None,
            on_reconnect_attempt: None,
            on_subscription: None,
            on_checksum_mismatch: None,
            on_message: None,
            on_error: None,
        }
    }

    /// Registers `f` as the connect callback, replacing any previously registered one.
    ///
    /// `f` receives the [`ConnectInfo`] passed to `invoke_connect`.
    #[must_use = "builder methods return the updated Hooks; dropping it discards the callback"]
    pub fn on_connect<F>(mut self, f: F) -> Self
    where
        F: Fn(&ConnectInfo) + Send + Sync + 'static,
    {
        self.on_connect = Some(Arc::new(f));
        self
    }

    /// Registers `f` as the disconnect callback, replacing any previously registered one.
    ///
    /// `f` receives the [`DisconnectInfo`] passed to `invoke_disconnect`.
    #[must_use = "builder methods return the updated Hooks; dropping it discards the callback"]
    pub fn on_disconnect<F>(mut self, f: F) -> Self
    where
        F: Fn(&DisconnectInfo) + Send + Sync + 'static,
    {
        self.on_disconnect = Some(Arc::new(f));
        self
    }

    /// Registers `f` as the reconnect-attempt callback, replacing any previously
    /// registered one.
    ///
    /// `f` receives the `attempt` counter and `delay` passed to `invoke_reconnect_attempt`.
    #[must_use = "builder methods return the updated Hooks; dropping it discards the callback"]
    pub fn on_reconnect_attempt<F>(mut self, f: F) -> Self
    where
        F: Fn(u32, Duration) + Send + Sync + 'static,
    {
        self.on_reconnect_attempt = Some(Arc::new(f));
        self
    }

    /// Registers `f` as the subscription callback, replacing any previously registered one.
    ///
    /// `f` receives the [`SubscriptionInfo`] passed to `invoke_subscription`.
    #[must_use = "builder methods return the updated Hooks; dropping it discards the callback"]
    pub fn on_subscription<F>(mut self, f: F) -> Self
    where
        F: Fn(&SubscriptionInfo) + Send + Sync + 'static,
    {
        self.on_subscription = Some(Arc::new(f));
        self
    }

    /// Registers `f` as the checksum-mismatch callback, replacing any previously
    /// registered one.
    ///
    /// `f` receives the [`ChecksumInfo`] passed to `invoke_checksum_mismatch`.
    #[must_use = "builder methods return the updated Hooks; dropping it discards the callback"]
    pub fn on_checksum_mismatch<F>(mut self, f: F) -> Self
    where
        F: Fn(&ChecksumInfo) + Send + Sync + 'static,
    {
        self.on_checksum_mismatch = Some(Arc::new(f));
        self
    }

    /// Registers `f` as the message callback, replacing any previously registered one.
    ///
    /// `f` receives the `size` value passed to `invoke_message`.
    #[must_use = "builder methods return the updated Hooks; dropping it discards the callback"]
    pub fn on_message<F>(mut self, f: F) -> Self
    where
        F: Fn(usize) + Send + Sync + 'static,
    {
        self.on_message = Some(Arc::new(f));
        self
    }

    /// Registers `f` as the error callback, replacing any previously registered one.
    ///
    /// `f` receives the message text passed to `invoke_error`.
    #[must_use = "builder methods return the updated Hooks; dropping it discards the callback"]
    pub fn on_error<F>(mut self, f: F) -> Self
    where
        F: Fn(&str) + Send + Sync + 'static,
    {
        self.on_error = Some(Arc::new(f));
        self
    }

    // The invoke_* methods below are crate-internal dispatch points. Each one calls
    // its registered callback directly, on the caller's thread, or does nothing when
    // no callback is registered.
    //
    // `dead_code` is allowed on each of them because not every invoke_* has a caller
    // visible in this file (the unit tests exercise only some).
    // NOTE(review): drop the allows once the connection code calls all of them.

    /// Calls the connect callback with `info`, if one is registered.
    #[allow(dead_code)]
    pub(crate) fn invoke_connect(&self, info: &ConnectInfo) {
        if let Some(hook) = &self.on_connect {
            hook(info);
        }
    }

    /// Calls the disconnect callback with `info`, if one is registered.
    #[allow(dead_code)]
    pub(crate) fn invoke_disconnect(&self, info: &DisconnectInfo) {
        if let Some(hook) = &self.on_disconnect {
            hook(info);
        }
    }

    /// Calls the reconnect-attempt callback with `attempt` and `delay`, if one is registered.
    #[allow(dead_code)]
    pub(crate) fn invoke_reconnect_attempt(&self, attempt: u32, delay: Duration) {
        if let Some(hook) = &self.on_reconnect_attempt {
            hook(attempt, delay);
        }
    }

    /// Calls the subscription callback with `info`, if one is registered.
    #[allow(dead_code)]
    pub(crate) fn invoke_subscription(&self, info: &SubscriptionInfo) {
        if let Some(hook) = &self.on_subscription {
            hook(info);
        }
    }

    /// Calls the checksum-mismatch callback with `info`, if one is registered.
    #[allow(dead_code)]
    pub(crate) fn invoke_checksum_mismatch(&self, info: &ChecksumInfo) {
        if let Some(hook) = &self.on_checksum_mismatch {
            hook(info);
        }
    }

    /// Calls the message callback with `size`, if one is registered.
    #[allow(dead_code)]
    pub(crate) fn invoke_message(&self, size: usize) {
        if let Some(hook) = &self.on_message {
            hook(size);
        }
    }

    /// Calls the error callback with `msg`, if one is registered.
    #[allow(dead_code)]
    pub(crate) fn invoke_error(&self, msg: &str) {
        if let Some(hook) = &self.on_error {
            hook(msg);
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for hook registration, dispatch, cloning and debug formatting.

    use super::*;
    use std::sync::atomic::{AtomicU32, AtomicUsize, Ordering};
    use std::sync::Mutex;

    /// Builds a `ConnectInfo` with fixed test values.
    fn sample_connect() -> ConnectInfo {
        ConnectInfo {
            api_version: "v2".to_string(),
            connection_id: 123,
            is_reconnection: false,
        }
    }

    #[test]
    fn test_hooks_builder() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);
        let hooks = Hooks::new().on_connect(move |_| {
            counter_clone.fetch_add(1, Ordering::SeqCst);
        });
        hooks.invoke_connect(&sample_connect());
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_hooks_clone() {
        let hooks = Hooks::new().on_connect(|_| {}).on_disconnect(|_| {});
        let cloned = hooks.clone();
        assert!(cloned.on_connect.is_some());
        assert!(cloned.on_disconnect.is_some());
        assert!(cloned.on_message.is_none());
    }

    #[test]
    fn test_hooks_clone_shares_callbacks() {
        // Both the original and the clone must dispatch to the same closure.
        let total = Arc::new(AtomicUsize::new(0));
        let total_clone = Arc::clone(&total);
        let hooks = Hooks::new().on_message(move |size| {
            total_clone.fetch_add(size, Ordering::SeqCst);
        });
        let cloned = hooks.clone();
        hooks.invoke_message(2);
        cloned.invoke_message(3);
        assert_eq!(total.load(Ordering::SeqCst), 5);
    }

    #[test]
    fn test_hooks_default() {
        let hooks = Hooks::default();
        assert!(hooks.on_connect.is_none());
        assert!(hooks.on_disconnect.is_none());
        assert!(hooks.on_reconnect_attempt.is_none());
        assert!(hooks.on_subscription.is_none());
        assert!(hooks.on_checksum_mismatch.is_none());
        assert!(hooks.on_message.is_none());
        assert!(hooks.on_error.is_none());

        // With nothing registered, every invoke_* must be a silent no-op.
        hooks.invoke_connect(&sample_connect());
        hooks.invoke_disconnect(&DisconnectInfo::Shutdown);
        hooks.invoke_reconnect_attempt(1, Duration::from_secs(1));
        hooks.invoke_subscription(&SubscriptionInfo {
            channel: "book".to_string(),
            symbols: Vec::new(),
            accepted: true,
            reason: None,
        });
        hooks.invoke_checksum_mismatch(&ChecksumInfo {
            symbol: "BTC/USD".to_string(),
            expected: 0,
            computed: 0,
        });
        hooks.invoke_message(0);
        hooks.invoke_error("ignored");
    }

    #[test]
    fn test_every_hook_receives_its_arguments() {
        let log: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let hooks = Hooks::new()
            .on_connect({
                let log = Arc::clone(&log);
                move |info: &ConnectInfo| {
                    log.lock().unwrap().push(format!("connect:{}", info.connection_id));
                }
            })
            .on_disconnect({
                let log = Arc::clone(&log);
                move |info: &DisconnectInfo| {
                    log.lock().unwrap().push(format!("disconnect:{:?}", info));
                }
            })
            .on_reconnect_attempt({
                let log = Arc::clone(&log);
                move |attempt: u32, delay: Duration| {
                    log.lock()
                        .unwrap()
                        .push(format!("reconnect:{}:{}", attempt, delay.as_millis()));
                }
            })
            .on_subscription({
                let log = Arc::clone(&log);
                move |info: &SubscriptionInfo| {
                    log.lock()
                        .unwrap()
                        .push(format!("subscription:{}:{}", info.channel, info.accepted));
                }
            })
            .on_checksum_mismatch({
                let log = Arc::clone(&log);
                move |info: &ChecksumInfo| {
                    log.lock().unwrap().push(format!(
                        "checksum:{}:{}:{}",
                        info.symbol, info.expected, info.computed
                    ));
                }
            })
            .on_message({
                let log = Arc::clone(&log);
                move |size: usize| {
                    log.lock().unwrap().push(format!("message:{}", size));
                }
            })
            .on_error({
                let log = Arc::clone(&log);
                move |msg: &str| {
                    log.lock().unwrap().push(format!("error:{}", msg));
                }
            });

        hooks.invoke_connect(&ConnectInfo {
            api_version: "v2".to_string(),
            connection_id: 7,
            is_reconnection: true,
        });
        hooks.invoke_disconnect(&DisconnectInfo::NetworkError("reset".to_string()));
        hooks.invoke_reconnect_attempt(3, Duration::from_millis(250));
        hooks.invoke_subscription(&SubscriptionInfo {
            channel: "book".to_string(),
            symbols: vec!["BTC/USD".to_string()],
            accepted: false,
            reason: Some("limit".to_string()),
        });
        hooks.invoke_checksum_mismatch(&ChecksumInfo {
            symbol: "BTC/USD".to_string(),
            expected: 1,
            computed: 2,
        });
        hooks.invoke_message(42);
        hooks.invoke_error("boom");

        let entries = log.lock().unwrap();
        assert_eq!(
            *entries,
            vec![
                "connect:7".to_string(),
                "disconnect:NetworkError(\"reset\")".to_string(),
                "reconnect:3:250".to_string(),
                "subscription:book:false".to_string(),
                "checksum:BTC/USD:1:2".to_string(),
                "message:42".to_string(),
                "error:boom".to_string(),
            ]
        );
    }

    #[test]
    fn test_registering_twice_replaces_previous_hook() {
        let counter = Arc::new(AtomicU32::new(0));
        let first = Arc::clone(&counter);
        let second = Arc::clone(&counter);
        let hooks = Hooks::new()
            .on_message(move |_| {
                first.fetch_add(1, Ordering::SeqCst);
            })
            .on_message(move |_| {
                second.fetch_add(100, Ordering::SeqCst);
            });
        hooks.invoke_message(0);
        assert_eq!(counter.load(Ordering::SeqCst), 100);
    }

    #[test]
    fn test_hooks_debug_hides_closures() {
        let rendered = format!("{:?}", Hooks::new().on_error(|_| {}));
        assert!(rendered.contains("on_error: Some(\"...\")"));
        assert!(rendered.contains("on_connect: None"));
    }
}