use async_trait::async_trait;
use std::time::Duration;
use tungstenite::Message;
use super::traits::{HandshakeReceiver, HandshakeSender, Handshaker};
use crate::context::ConnectionContext;
use crate::error::HandshakeError;
/// Handshaker that sends a single pre-built message and can optionally wait
/// for one incoming message before reporting success.
pub struct SendMessageHandshaker {
    /// Message sent when the handshake starts.
    message: Message,
    /// When `true`, the handshake waits for one incoming message after sending.
    wait_response: bool,
    /// Maximum time spent waiting for that incoming message.
    timeout: Duration,
}

impl SendMessageHandshaker {
    /// Timeout applied unless [`Self::with_timeout`] overrides it.
    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);

    /// Creates a handshaker that sends `message`, does not wait for a
    /// response, and uses a 10 second timeout.
    #[must_use]
    pub const fn new(message: Message) -> Self {
        Self {
            message,
            wait_response: false,
            timeout: Self::DEFAULT_TIMEOUT,
        }
    }

    /// Shorthand for [`Self::new`] with a text message built from `text`.
    #[must_use]
    pub fn text(text: impl Into<String>) -> Self {
        let payload: String = text.into();
        Self::new(Message::Text(payload.into()))
    }

    /// Shorthand for [`Self::new`] with a binary message built from `data`.
    #[must_use]
    pub fn binary(data: impl Into<Vec<u8>>) -> Self {
        let payload: Vec<u8> = data.into();
        Self::new(Message::Binary(payload.into()))
    }

    /// Makes the handshake wait for a response after the message is sent.
    #[must_use]
    pub const fn with_response(mut self) -> Self {
        self.wait_response = true;
        self
    }

    /// Replaces the timeout used while waiting for the response.
    #[must_use]
    pub const fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }
}
#[async_trait]
impl Handshaker for SendMessageHandshaker {
    /// Sends a clone of the configured message. When a response is requested,
    /// waits up to `self.timeout` for one incoming message.
    ///
    /// # Errors
    ///
    /// Propagates send and receive errors, returns `HandshakeError::Timeout`
    /// when no message arrives in time, and `HandshakeError::Failed` when the
    /// receiver yields `None`.
    async fn handshake(
        &self,
        sender: &mut dyn HandshakeSender,
        receiver: &mut dyn HandshakeReceiver,
        _context: &ConnectionContext,
    ) -> Result<(), HandshakeError> {
        sender.send_msg(self.message.clone()).await?;

        // Fire-and-forget mode: nothing more to do once the message is out.
        if !self.wait_response {
            return Ok(());
        }

        // First `?` surfaces the timeout, second `?` surfaces the receive error.
        let reply = tokio::time::timeout(self.timeout, receiver.recv_msg())
            .await
            .map_err(|_elapsed| HandshakeError::Timeout(self.timeout))??;

        match reply {
            Some(_) => Ok(()),
            None => Err(HandshakeError::Failed(
                "Connection closed during handshake".into(),
            )),
        }
    }

    /// Static identifier of this handshaker.
    fn name(&self) -> &'static str {
        "send_message"
    }

    /// Reports the configured timeout.
    fn timeout(&self) -> Option<Duration> {
        Some(self.timeout)
    }
}
/// Handshaker that authenticates by sending a token as a text message and
/// then checking the reply.
pub struct AuthHandshaker {
    /// Credential embedded into the outgoing auth message.
    token: String,
    /// How the token is wrapped before being sent.
    format: AuthFormat,
    /// Maximum time spent waiting for the reply.
    timeout: Duration,
}

/// Wire format used to wrap the authentication token.
#[derive(Debug, Clone)]
pub enum AuthFormat {
    /// The token is sent verbatim.
    Plain,
    /// The token is sent as `{"type":"auth","token":"<token>"}`, with the
    /// token escaped as a JSON string.
    Json,
    /// A template in which every `{}` is replaced by the raw token.
    Custom(String),
}

impl AuthHandshaker {
    /// Creates a handshaker that sends `token` in [`AuthFormat::Plain`] with a
    /// 10 second timeout.
    #[must_use]
    pub fn new(token: impl Into<String>) -> Self {
        Self {
            token: token.into(),
            format: AuthFormat::Plain,
            timeout: Duration::from_secs(10),
        }
    }

    /// Switches the message format to [`AuthFormat::Json`].
    #[must_use]
    pub fn json(mut self) -> Self {
        self.format = AuthFormat::Json;
        self
    }

    /// Switches the message format to [`AuthFormat::Custom`] using `format`
    /// as the template.
    #[must_use]
    pub fn custom_format(mut self, format: impl Into<String>) -> Self {
        self.format = AuthFormat::Custom(format.into());
        self
    }

    /// Replaces the timeout used while waiting for the reply.
    #[must_use]
    pub const fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Escapes `raw` so it can be placed between the quotes of a JSON string.
    ///
    /// Quotes, backslashes and all characters below U+0020 are escaped; every
    /// other character is copied unchanged.
    fn escape_json(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '"' => escaped.push_str("\\\""),
                '\\' => escaped.push_str("\\\\"),
                '\n' => escaped.push_str("\\n"),
                '\r' => escaped.push_str("\\r"),
                '\t' => escaped.push_str("\\t"),
                '\u{08}' => escaped.push_str("\\b"),
                '\u{0C}' => escaped.push_str("\\f"),
                // Remaining control characters have no short form.
                control @ '\u{00}'..='\u{1F}' => {
                    escaped.push_str(&format!("\\u{:04x}", u32::from(control)));
                }
                other => escaped.push(other),
            }
        }
        escaped
    }

    /// Renders the text payload for the configured format.
    ///
    /// In JSON mode the token is escaped first, so a token containing `"` or
    /// `\` cannot break out of the string literal or inject extra fields.
    fn format_message(&self) -> String {
        match &self.format {
            AuthFormat::Plain => self.token.clone(),
            AuthFormat::Json => format!(
                r#"{{"type":"auth","token":"{}"}}"#,
                Self::escape_json(&self.token)
            ),
            AuthFormat::Custom(fmt) => fmt.replace("{}", &self.token),
        }
    }
}
#[async_trait]
impl Handshaker for AuthHandshaker {
    /// Sends the formatted auth message and waits up to `self.timeout` for
    /// one reply.
    ///
    /// A text reply whose lowercased content contains `error`,
    /// `unauthorized` or `denied` is reported as an authentication failure;
    /// any other reply counts as success.
    ///
    /// # Errors
    ///
    /// Propagates send and receive errors, returns `HandshakeError::Timeout`
    /// when no reply arrives in time, `HandshakeError::Failed` when the
    /// receiver yields `None`, and `HandshakeError::AuthFailed` carrying the
    /// reply text when the reply matches one of the rejection keywords.
    async fn handshake(
        &self,
        sender: &mut dyn HandshakeSender,
        receiver: &mut dyn HandshakeReceiver,
        _context: &ConnectionContext,
    ) -> Result<(), HandshakeError> {
        let payload = self.format_message();
        tracing::debug!("Sending auth message");
        sender.send_msg(Message::Text(payload.into())).await?;

        // First `?` surfaces the timeout, second `?` surfaces the receive error.
        let reply = tokio::time::timeout(self.timeout, receiver.recv_msg())
            .await
            .map_err(|_elapsed| HandshakeError::Timeout(self.timeout))??
            .ok_or_else(|| HandshakeError::Failed("Connection closed during auth".into()))?;

        // Only text replies are inspected for rejection keywords.
        if let Message::Text(body) = &reply {
            let body = body.to_string();
            let lowered = body.to_lowercase();
            let rejected = ["error", "unauthorized", "denied"]
                .iter()
                .any(|keyword| lowered.contains(*keyword));
            if rejected {
                return Err(HandshakeError::AuthFailed(body));
            }
        }

        tracing::debug!("Auth successful");
        Ok(())
    }

    /// Every error except `HandshakeError::AuthFailed` is retryable.
    fn is_retryable(&self, error: &HandshakeError) -> bool {
        match error {
            HandshakeError::AuthFailed(_) => false,
            _ => true,
        }
    }

    /// Static identifier of this handshaker.
    fn name(&self) -> &'static str {
        "auth"
    }

    /// Reports the configured timeout.
    fn timeout(&self) -> Option<Duration> {
        Some(self.timeout)
    }
}
/// Handshaker that sends one subscribe message per configured channel.
pub struct SubscribeHandshaker {
    /// Channel names, subscribed to in order.
    channels: Vec<String>,
    /// How each channel name is wrapped before being sent.
    format: SubscribeFormat,
    /// Maximum time spent waiting for each confirmation.
    timeout: Duration,
    /// When `true`, one incoming message is awaited after every subscribe.
    wait_confirmation: bool,
}

/// Wire format used to wrap a channel name.
#[derive(Debug, Clone)]
pub enum SubscribeFormat {
    /// The channel is sent as `{"type":"subscribe","channel":"<channel>"}`,
    /// with the channel name escaped as a JSON string.
    Json,
    /// A template in which every `{}` is replaced by the raw channel name.
    Custom(String),
}

impl SubscribeHandshaker {
    /// Creates a handshaker for `channels` using [`SubscribeFormat::Json`], a
    /// 5 second timeout, and no waiting for confirmations.
    #[must_use]
    pub fn new(channels: impl IntoIterator<Item = impl Into<String>>) -> Self {
        Self {
            channels: channels.into_iter().map(Into::into).collect(),
            format: SubscribeFormat::Json,
            timeout: Duration::from_secs(5),
            wait_confirmation: false,
        }
    }

    /// Switches the message format to [`SubscribeFormat::Custom`] using
    /// `format` as the template.
    #[must_use]
    pub fn custom_format(mut self, format: impl Into<String>) -> Self {
        self.format = SubscribeFormat::Custom(format.into());
        self
    }

    /// Replaces the per-channel confirmation timeout.
    #[must_use]
    pub const fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Makes the handshake wait for one incoming message after each subscribe.
    #[must_use]
    pub const fn wait_confirmation(mut self) -> Self {
        self.wait_confirmation = true;
        self
    }

    /// Escapes `raw` so it can be placed between the quotes of a JSON string.
    ///
    /// Quotes, backslashes and all characters below U+0020 are escaped; every
    /// other character is copied unchanged.
    fn escape_json(raw: &str) -> String {
        let mut escaped = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '"' => escaped.push_str("\\\""),
                '\\' => escaped.push_str("\\\\"),
                '\n' => escaped.push_str("\\n"),
                '\r' => escaped.push_str("\\r"),
                '\t' => escaped.push_str("\\t"),
                '\u{08}' => escaped.push_str("\\b"),
                '\u{0C}' => escaped.push_str("\\f"),
                // Remaining control characters have no short form.
                control @ '\u{00}'..='\u{1F}' => {
                    escaped.push_str(&format!("\\u{:04x}", u32::from(control)));
                }
                other => escaped.push(other),
            }
        }
        escaped
    }

    /// Renders the subscribe payload for `channel` in the configured format.
    ///
    /// In JSON mode the channel name is escaped first, so a name containing
    /// `"` or `\` cannot produce malformed JSON or inject extra fields.
    fn format_message(&self, channel: &str) -> String {
        match &self.format {
            SubscribeFormat::Json => format!(
                r#"{{"type":"subscribe","channel":"{}"}}"#,
                Self::escape_json(channel)
            ),
            SubscribeFormat::Custom(fmt) => fmt.replace("{}", channel),
        }
    }
}
#[async_trait]
impl Handshaker for SubscribeHandshaker {
    /// Sends one subscribe message per channel, in order. When confirmations
    /// are enabled, waits up to `self.timeout` for one incoming message after
    /// each send before moving on to the next channel.
    ///
    /// # Errors
    ///
    /// Stops at the first failure: propagates send and receive errors,
    /// returns `HandshakeError::Timeout` when a confirmation does not arrive
    /// in time, and `HandshakeError::Failed` when the receiver yields `None`.
    async fn handshake(
        &self,
        sender: &mut dyn HandshakeSender,
        receiver: &mut dyn HandshakeReceiver,
        _context: &ConnectionContext,
    ) -> Result<(), HandshakeError> {
        for channel in &self.channels {
            let payload = self.format_message(channel);
            tracing::debug!(channel = %channel, "Subscribing to channel");
            sender.send_msg(Message::Text(payload.into())).await?;

            if !self.wait_confirmation {
                continue;
            }

            // First `?` surfaces the timeout, second `?` surfaces the receive error.
            let confirmation = tokio::time::timeout(self.timeout, receiver.recv_msg())
                .await
                .map_err(|_elapsed| HandshakeError::Timeout(self.timeout))??;
            if confirmation.is_none() {
                return Err(HandshakeError::Failed(format!(
                    "Connection closed while subscribing to {channel}"
                )));
            }
        }

        tracing::debug!(count = self.channels.len(), "Subscriptions complete");
        Ok(())
    }

    /// Static identifier of this handshaker.
    fn name(&self) -> &'static str {
        "subscribe"
    }

    /// Reports the per-channel timeout multiplied by the number of channels,
    /// saturating instead of overflowing.
    fn timeout(&self) -> Option<Duration> {
        let channel_count = u32::try_from(self.channels.len()).unwrap_or(u32::MAX);
        let total = self.timeout.saturating_mul(channel_count);
        Some(total)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain format sends the token verbatim.
    #[test]
    fn test_auth_format_plain() {
        let rendered = AuthHandshaker::new("my-token").format_message();
        assert_eq!(rendered, "my-token");
    }

    /// JSON format wraps the token in an auth object.
    #[test]
    fn test_auth_format_json() {
        let rendered = AuthHandshaker::new("my-token").json().format_message();
        let expected = r#"{"type":"auth","token":"my-token"}"#;
        assert_eq!(rendered, expected);
    }

    /// Custom format substitutes the token for the `{}` placeholder.
    #[test]
    fn test_auth_format_custom() {
        let rendered = AuthHandshaker::new("my-token")
            .custom_format("AUTH {}")
            .format_message();
        assert_eq!(rendered, "AUTH my-token");
    }

    /// Default subscribe format wraps the channel in a subscribe object.
    #[test]
    fn test_subscribe_format() {
        let handshaker = SubscribeHandshaker::new(vec!["channel1"]);
        let rendered = handshaker.format_message("test");
        let expected = r#"{"type":"subscribe","channel":"test"}"#;
        assert_eq!(rendered, expected);
    }
}