use super::{SyncTransport, TransportError, InMemoryTransport, WebSocketTransport};
use super::leptos_ws_pro_transport::LeptosWsProTransport;
use super::compatibility_layer::CompatibilityTransport;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Identifies a kind of transport.
///
/// Values of this type are the keys under which concrete transports are
/// registered with [`MultiTransport`], and they name the primary/fallback
/// transports in [`MultiTransportConfig`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TransportType {
    /// WebSocket transport.
    WebSocket,
    /// Transport provided by the `leptos_ws_pro_transport` module.
    LeptosWsPro,
    /// HTTP transport.
    // NOTE(review): no `TransportEnum` variant in this file corresponds to
    // this type — confirm how it is meant to be registered.
    Http,
    /// WebRTC transport.
    // NOTE(review): no `TransportEnum` variant in this file corresponds to
    // this type — confirm how it is meant to be registered.
    WebRTC,
    /// In-memory transport.
    Memory,
}
/// Configuration for a [`MultiTransport`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiTransportConfig {
    /// Transport type that [`MultiTransport::new`] selects as the initial
    /// current transport.
    pub primary: TransportType,
    /// Alternative transport types.
    // NOTE(review): not read by `MultiTransport` in this file; presumably the
    // order is the preferred fallback order — confirm against callers.
    pub fallbacks: Vec<TransportType>,
    /// Whether automatic switching between transports is desired.
    // NOTE(review): not read by `MultiTransport` in this file — confirm where
    // (or whether) this flag is honoured.
    pub auto_switch: bool,
    /// Timeout value; the field name indicates milliseconds.
    // NOTE(review): not read by `MultiTransport` in this file — confirm which
    // operation this timeout applies to.
    pub timeout_ms: u64,
}
impl Default for MultiTransportConfig {
    /// Builds the default configuration: `LeptosWsPro` as primary, with
    /// `WebSocket`, `Http` and `Memory` as fallbacks (in that order),
    /// `auto_switch` enabled and a `timeout_ms` of 5000.
    fn default() -> Self {
        let fallbacks = vec![
            TransportType::WebSocket,
            TransportType::Http,
            TransportType::Memory,
        ];

        Self {
            primary: TransportType::LeptosWsPro,
            fallbacks,
            auto_switch: true,
            timeout_ms: 5000,
        }
    }
}
/// Closed set of concrete transport implementations.
///
/// Wrapping the implementations in one enum lets [`MultiTransport`] store
/// different transports in a single map; the [`SyncTransport`] impl for this
/// enum forwards every call to the wrapped value.
#[derive(Clone)]
pub enum TransportEnum {
    /// Wraps a [`WebSocketTransport`].
    WebSocket(WebSocketTransport),
    /// Wraps a [`LeptosWsProTransport`].
    LeptosWsPro(LeptosWsProTransport),
    /// Wraps a [`CompatibilityTransport`].
    Compatibility(CompatibilityTransport),
    /// Wraps an [`InMemoryTransport`].
    InMemory(InMemoryTransport),
}
/// Forwards every [`SyncTransport`] operation to the wrapped transport.
///
/// Results from the `WebSocket` and `InMemory` variants are returned as-is;
/// errors from the `LeptosWsPro` and `Compatibility` variants are converted
/// into [`TransportError`] through their `Into` implementation.
impl SyncTransport for TransportEnum {
    type Error = TransportError;

    /// Sends `data` through the wrapped transport.
    fn send<'a>(&'a self, data: &'a [u8]) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), Self::Error>> + Send + 'a>> {
        Box::pin(async move {
            match self {
                Self::WebSocket(inner) => inner.send(data).await,
                Self::LeptosWsPro(inner) => inner.send(data).await.map_err(Into::into),
                Self::Compatibility(inner) => inner.send(data).await.map_err(Into::into),
                Self::InMemory(inner) => inner.send(data).await,
            }
        })
    }

    /// Receives pending messages from the wrapped transport.
    fn receive(&self) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Vec<Vec<u8>>, Self::Error>> + Send + '_>> {
        Box::pin(async move {
            match self {
                Self::WebSocket(inner) => inner.receive().await,
                Self::LeptosWsPro(inner) => inner.receive().await.map_err(Into::into),
                Self::Compatibility(inner) => inner.receive().await.map_err(Into::into),
                Self::InMemory(inner) => inner.receive().await,
            }
        })
    }

    /// Reports the connection state of the wrapped transport.
    fn is_connected(&self) -> bool {
        match self {
            Self::WebSocket(inner) => inner.is_connected(),
            Self::LeptosWsPro(inner) => inner.is_connected(),
            Self::Compatibility(inner) => inner.is_connected(),
            Self::InMemory(inner) => inner.is_connected(),
        }
    }
}
/// A transport that routes every operation through one of several registered
/// transports, selected by a switchable "current" [`TransportType`].
pub struct MultiTransport {
    /// Configuration supplied at construction; exposed through `config()`.
    config: MultiTransportConfig,
    /// Registered transports, keyed by the type they were registered under.
    transports: HashMap<TransportType, TransportEnum>,
    /// Type of the transport currently used for send/receive. Kept behind an
    /// async `RwLock` so it can be changed through a shared reference.
    current_transport: Arc<RwLock<TransportType>>,
}
impl MultiTransport {
    /// Creates a multi-transport with no registered transports.
    ///
    /// The current transport type starts out as `config.primary`, even though
    /// nothing is registered under it yet.
    pub fn new(config: MultiTransportConfig) -> Self {
        let current_transport = Arc::new(RwLock::new(config.primary.clone()));
        Self {
            config,
            transports: HashMap::new(),
            current_transport,
        }
    }

    /// Registers `transport` under `transport_type`.
    ///
    /// A transport previously registered under the same type is replaced.
    pub fn register_transport(&mut self, transport_type: TransportType, transport: TransportEnum) {
        self.transports.insert(transport_type, transport);
    }

    /// Returns the currently selected transport type.
    pub async fn current_transport(&self) -> TransportType {
        let selected = self.current_transport.read().await;
        (*selected).clone()
    }

    /// Makes `transport_type` the current transport.
    ///
    /// # Errors
    ///
    /// Returns [`TransportError::ConnectionFailed`] when no transport has been
    /// registered under `transport_type`; the current selection is then left
    /// unchanged.
    pub async fn switch_transport(&self, transport_type: TransportType) -> Result<(), TransportError> {
        if self.has_transport(&transport_type) {
            *self.current_transport.write().await = transport_type;
            Ok(())
        } else {
            Err(TransportError::ConnectionFailed(format!(
                "Transport {:?} not registered",
                transport_type
            )))
        }
    }

    /// Returns the types of all registered transports, in the map's
    /// (unspecified) iteration order.
    pub fn available_transports(&self) -> Vec<TransportType> {
        let registered = self.transports.keys();
        registered.cloned().collect()
    }

    /// Returns the configuration this instance was created with.
    pub fn config(&self) -> &MultiTransportConfig {
        &self.config
    }

    /// Returns `true` when a transport is registered under `transport_type`.
    pub fn has_transport(&self, transport_type: &TransportType) -> bool {
        self.transports.contains_key(transport_type)
    }

    /// Returns the number of registered transports.
    pub fn transport_count(&self) -> usize {
        self.transports.len()
    }
}
/// Routes [`SyncTransport`] operations to the transport registered under the
/// currently selected [`TransportType`].
impl SyncTransport for MultiTransport {
    type Error = TransportError;

    /// Sends `data` through the currently selected transport.
    ///
    /// # Errors
    ///
    /// Returns [`TransportError::SendFailed`] when no transport is registered
    /// under the current type; otherwise propagates the selected transport's
    /// own result.
    fn send<'a>(&'a self, data: &'a [u8]) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<(), Self::Error>> + Send + 'a>> {
        Box::pin(async move {
            // Clone the selection so the read guard is released before the
            // (potentially slow) transport call is awaited.
            let current_type = self.current_transport.read().await.clone();
            match self.transports.get(&current_type) {
                Some(transport) => transport.send(data).await,
                None => Err(TransportError::SendFailed(format!(
                    "No transport available for {:?}",
                    current_type
                ))),
            }
        })
    }

    /// Receives pending messages from the currently selected transport.
    ///
    /// # Errors
    ///
    /// Returns [`TransportError::ReceiveFailed`] when no transport is
    /// registered under the current type; otherwise propagates the selected
    /// transport's own result.
    fn receive(&self) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Vec<Vec<u8>>, Self::Error>> + Send + '_>> {
        Box::pin(async move {
            // Same pattern as `send`: do not hold the lock across the await.
            let current_type = self.current_transport.read().await.clone();
            match self.transports.get(&current_type) {
                Some(transport) => transport.receive().await,
                None => Err(TransportError::ReceiveFailed(format!(
                    "No transport available for {:?}",
                    current_type
                ))),
            }
        })
    }

    /// Reports whether the currently selected transport is connected.
    ///
    /// Returns `false` when no transport is registered under the current
    /// type. Because this method is synchronous it can only *try* to take the
    /// async lock; if the lock cannot be acquired immediately (a transport
    /// switch is in progress) it also returns `false` instead of panicking.
    fn is_connected(&self) -> bool {
        match self.current_transport.try_read() {
            Ok(current_type) => self
                .transports
                .get(&*current_type)
                .map_or(false, |transport| transport.is_connected()),
            // The selection is being changed right now; report "not
            // connected" rather than crashing the caller.
            Err(_) => false,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Configuration shared by the tests that use WebSocket as primary and the
    // in-memory transport as the only fallback.
    fn websocket_primary_config() -> MultiTransportConfig {
        MultiTransportConfig {
            primary: TransportType::WebSocket,
            fallbacks: vec![TransportType::Memory],
            auto_switch: true,
            timeout_ms: 5000,
        }
    }

    // A fresh instance starts on the configured primary with nothing registered.
    #[tokio::test]
    async fn test_multi_transport_creation() {
        let multi = MultiTransport::new(MultiTransportConfig::default());

        assert_eq!(multi.current_transport().await, TransportType::LeptosWsPro);
        assert!(multi.available_transports().is_empty());
    }

    // Registered transports are listed and can be switched to.
    #[tokio::test]
    async fn test_register_and_switch_transports() {
        let mut multi = MultiTransport::new(websocket_primary_config());

        let websocket = TransportEnum::WebSocket(WebSocketTransport::new("ws://test".to_string()));
        let in_memory = TransportEnum::InMemory(InMemoryTransport::new());
        multi.register_transport(TransportType::WebSocket, websocket);
        multi.register_transport(TransportType::Memory, in_memory);

        let registered = multi.available_transports();
        assert_eq!(registered.len(), 2);
        assert!(registered.contains(&TransportType::WebSocket));
        assert!(registered.contains(&TransportType::Memory));

        multi.switch_transport(TransportType::Memory).await.unwrap();
        assert_eq!(multi.current_transport().await, TransportType::Memory);
    }

    // Send/receive go through whatever is registered under the current type.
    #[tokio::test]
    async fn test_transport_operations() {
        let mut multi = MultiTransport::new(MultiTransportConfig::default());

        // An in-memory transport stands in for the default primary type.
        let stand_in = TransportEnum::InMemory(InMemoryTransport::new());
        multi.register_transport(TransportType::LeptosWsPro, stand_in);

        multi.send(b"test data").await.unwrap();
        let received = multi.receive().await.unwrap();
        assert_eq!(received.len(), 1);
        assert_eq!(received[0], b"test data");
    }

    // After switching away from a disconnected transport, operations succeed.
    #[tokio::test]
    async fn test_transport_failure_handling() {
        let mut multi = MultiTransport::new(websocket_primary_config());

        let unreachable_ws =
            TransportEnum::WebSocket(WebSocketTransport::new("ws://invalid-url".to_string()));
        multi.register_transport(TransportType::WebSocket, unreachable_ws);
        assert!(!multi.is_connected());

        let in_memory = TransportEnum::InMemory(InMemoryTransport::new());
        multi.register_transport(TransportType::Memory, in_memory);
        multi.switch_transport(TransportType::Memory).await.unwrap();

        assert!(multi.is_connected());
        assert!(multi.send(b"test").await.is_ok());
        assert!(multi.receive().await.is_ok());
    }

    // Switching to a type that was never registered is rejected.
    #[tokio::test]
    async fn test_switch_to_unregistered_transport() {
        let multi = MultiTransport::new(MultiTransportConfig::default());

        let outcome = multi.switch_transport(TransportType::WebRTC).await;
        assert!(outcome.is_err());
    }

    // With nothing registered, every operation fails or reports disconnected.
    #[tokio::test]
    async fn test_operations_without_registered_transport() {
        let multi = MultiTransport::new(MultiTransportConfig::default());

        assert!(multi.send(b"test").await.is_err());
        assert!(multi.receive().await.is_err());
        assert!(!multi.is_connected());
    }

    // Bookkeeping helpers reflect registrations and expose the configuration.
    #[tokio::test]
    async fn test_multi_transport_utility_methods() {
        let mut multi = MultiTransport::new(MultiTransportConfig::default());

        assert_eq!(multi.transport_count(), 0);
        assert!(!multi.has_transport(&TransportType::WebSocket));

        let in_memory = TransportEnum::InMemory(InMemoryTransport::new());
        multi.register_transport(TransportType::WebSocket, in_memory);

        assert_eq!(multi.transport_count(), 1);
        assert!(multi.has_transport(&TransportType::WebSocket));
        assert!(!multi.has_transport(&TransportType::Memory));

        let stored_config = multi.config();
        assert_eq!(stored_config.primary, TransportType::LeptosWsPro);
        assert!(stored_config.auto_switch);
    }
}