pub mod adapter;
use async_trait::async_trait;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use sh_layer3::generate_short_id;
use std::collections::HashMap;
use std::sync::Arc;
use crate::types::Layer4Result;
/// A bidirectional message transport that can be registered with a [`ChannelGateway`].
///
/// Implementations must be `Send + Sync` because the gateway shares them between callers.
#[async_trait]
pub trait Channel: Send + Sync {
/// Identifier the gateway uses as this channel's registry key.
fn id(&self) -> &str;
/// Kind of transport; reported by `ChannelGateway::list_channels` and in registration logs.
fn channel_type(&self) -> ChannelType;
/// Delivers `message` through this channel.
async fn send(&self, message: &OutboundMessage) -> Layer4Result<()>;
/// Returns one pending inbound message, or `Ok(None)` when there is none.
/// NOTE(review): the gateway calls this in a `while let Some(..)` drain loop, so it is
/// presumably non-blocking — a blocking implementation would stall `receive_all`; confirm.
async fn try_receive(&self) -> Layer4Result<Option<InboundMessage>>;
/// Whether the channel reports itself as connected (not consulted by the gateway in this file).
fn is_connected(&self) -> bool;
/// Shuts the channel down; invoked by the gateway when unregistering or closing all channels.
async fn close(&self) -> Layer4Result<()>;
}
/// The kind of transport behind a [`Channel`].
///
/// Its `Display` form is the lowercase name (`cli`, `http`, `websocket`, `mqtt`, `custom`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ChannelType {
/// Command-line interface.
Cli,
/// HTTP transport.
Http,
/// WebSocket transport.
WebSocket,
/// MQTT transport.
Mqtt,
/// Presumably a catch-all for transports without a dedicated variant — confirm.
Custom,
}
impl std::fmt::Display for ChannelType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Cli => write!(f, "cli"),
Self::Http => write!(f, "http"),
Self::WebSocket => write!(f, "websocket"),
Self::Mqtt => write!(f, "mqtt"),
Self::Custom => write!(f, "custom"),
}
}
}
/// A message received from a user on some [`Channel`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InboundMessage {
/// Unique id; `InboundMessage::new` fills it with `generate_short_id()`.
pub message_id: String,
/// Id of the channel the message arrived on; the gateway maps the sender to it for replies.
pub channel_id: String,
/// Sender; the gateway records `user_id -> channel_id` when it receives the message.
pub user_id: String,
/// Optional session; when set, the gateway records `session_id -> channel_id`.
pub session_id: Option<String>,
/// Message payload.
pub content: String,
/// Payload kind; `InboundMessage::new` sets `MessageType::Text`.
pub message_type: MessageType,
/// Free-form JSON metadata; `Null` unless set via `with_metadata`.
pub metadata: serde_json::Value,
/// Creation time in UTC (set to "now" by `InboundMessage::new`).
pub timestamp: chrono::DateTime<chrono::Utc>,
}
impl InboundMessage {
    /// Builds a [`MessageType::Text`] message received on `channel_id` from `user_id`.
    ///
    /// The id comes from `generate_short_id`, the timestamp is the current UTC time, no
    /// session is attached and `metadata` is JSON `null`. Chain [`Self::with_session`] /
    /// [`Self::with_metadata`] to fill those in.
    pub fn new(
        channel_id: impl Into<String>,
        user_id: impl Into<String>,
        content: impl Into<String>,
    ) -> Self {
        let (channel_id, user_id, content) = (channel_id.into(), user_id.into(), content.into());
        Self {
            message_id: generate_short_id(),
            channel_id,
            user_id,
            session_id: None,
            content,
            message_type: MessageType::Text,
            metadata: serde_json::Value::Null,
            timestamp: chrono::Utc::now(),
        }
    }

    /// Returns the message with its session set to `session_id` (replacing any previous one).
    pub fn with_session(self, session_id: impl Into<String>) -> Self {
        Self {
            session_id: Some(session_id.into()),
            ..self
        }
    }

    /// Returns the message with its metadata replaced by `metadata`.
    pub fn with_metadata(self, metadata: serde_json::Value) -> Self {
        Self { metadata, ..self }
    }
}
/// A message to be delivered through one or more [`Channel`]s.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutboundMessage {
/// Unique id; `OutboundMessage::new` fills it with `generate_short_id()`.
pub message_id: String,
/// Message payload.
pub content: String,
/// Payload kind; `OutboundMessage::new` sets `MessageType::Text`.
pub message_type: MessageType,
/// Intended recipient(s). Note that `ChannelGateway::send_to` takes its target as a
/// separate argument rather than reading this field.
pub target: MessageTarget,
/// Free-form JSON metadata; `Null` by default.
pub metadata: serde_json::Value,
/// Creation time in UTC (set to "now" by `OutboundMessage::new`).
pub timestamp: chrono::DateTime<chrono::Utc>,
}
impl OutboundMessage {
    /// Builds a [`MessageType::Text`] message addressed to `target`.
    ///
    /// The id comes from `generate_short_id`, the timestamp is the current UTC time and
    /// `metadata` is JSON `null`.
    pub fn new(content: impl Into<String>, target: MessageTarget) -> Self {
        let content = content.into();
        Self {
            message_id: generate_short_id(),
            content,
            message_type: MessageType::Text,
            target,
            metadata: serde_json::Value::Null,
            timestamp: chrono::Utc::now(),
        }
    }

    /// Builds a text message addressed to every channel ([`MessageTarget::All`]).
    pub fn broadcast(content: impl Into<String>) -> Self {
        Self::new(content, MessageTarget::All)
    }

    /// Builds a text message addressed to the channel `channel_id`.
    pub fn to_channel(channel_id: impl Into<String>, content: impl Into<String>) -> Self {
        let target = MessageTarget::Channel(channel_id.into());
        Self::new(content, target)
    }

    /// Builds a text message addressed to the user `user_id`.
    pub fn to_user(user_id: impl Into<String>, content: impl Into<String>) -> Self {
        let target = MessageTarget::User(user_id.into());
        Self::new(content, target)
    }
}
/// Who an [`OutboundMessage`] should reach, as resolved by `ChannelGateway::send_to`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MessageTarget {
/// Every registered channel (handled as a broadcast).
All,
/// The channel registered under this id.
Channel(String),
/// The channel this user id was most recently seen on, per the router.
User(String),
/// The channel this session id was most recently seen on, per the router.
Session(String),
}
/// Kind of payload a message carries.
///
/// Both message constructors default to `Text`. NOTE(review): nothing in this file assigns
/// meaning to the other variants; their interpretation is presumably up to channels and
/// consumers — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MessageType {
/// Plain text (the default).
Text,
/// JSON payload.
Json,
/// Binary payload.
Binary,
/// A command.
Command,
/// An event notification.
Event,
/// An error report.
Error,
}
pub struct ChannelGateway {
channels: RwLock<HashMap<String, Box<dyn Channel>>>,
router: MessageRouter,
message_queue: RwLock<Vec<InboundMessage>>,
}
impl ChannelGateway {
pub fn new() -> Self {
Self {
channels: RwLock::new(HashMap::new()),
router: MessageRouter::new(),
message_queue: RwLock::new(Vec::new()),
}
}
pub async fn register_channel(&self, channel: Box<dyn Channel>) -> Layer4Result<()> {
let id = channel.id().to_string();
let channel_type = channel.channel_type();
self.channels.write().insert(id.clone(), channel);
self.router.register_channel(&id, channel_type);
tracing::info!("Registered channel: {} ({})", id, channel_type);
Ok(())
}
pub async fn unregister_channel(&self, channel_id: &str) -> Layer4Result<bool> {
let channel = self.channels.write().remove(channel_id);
if let Some(channel) = channel {
channel.close().await?;
self.router.unregister_channel(channel_id);
tracing::info!("Unregistered channel: {}", channel_id);
Ok(true)
} else {
Ok(false)
}
}
pub fn get_channel(&self, _channel_id: &str) -> Option<Arc<dyn Channel>> {
None
}
pub fn list_channels(&self) -> Vec<(String, ChannelType)> {
self.channels
.read()
.iter()
.map(|(id, ch)| (id.clone(), ch.channel_type()))
.collect()
}
#[allow(clippy::await_holding_lock)]
pub async fn broadcast(&self, message: &OutboundMessage) -> Layer4Result<()> {
let channels = self.channels.read();
for (id, channel) in channels.iter() {
if let Err(e) = channel.send(message).await {
tracing::error!("Failed to send to channel {}: {}", id, e);
}
}
Ok(())
}
#[allow(clippy::await_holding_lock)]
pub async fn send_to(
&self,
target: &MessageTarget,
message: &OutboundMessage,
) -> Layer4Result<()> {
match target {
MessageTarget::All => self.broadcast(message).await,
MessageTarget::Channel(channel_id) => {
let channels = self.channels.read();
if let Some(channel) = channels.get(channel_id) {
channel.send(message).await?;
}
Ok(())
}
MessageTarget::User(user_id) => {
let channel_id = self.router.find_user_channel(user_id);
if let Some(cid) = channel_id {
let channels = self.channels.read();
if let Some(channel) = channels.get(&cid) {
channel.send(message).await?;
}
}
Ok(())
}
MessageTarget::Session(session_id) => {
let channel_id = self.router.find_session_channel(session_id);
if let Some(cid) = channel_id {
let channels = self.channels.read();
if let Some(channel) = channels.get(&cid) {
channel.send(message).await?;
}
}
Ok(())
}
}
}
#[allow(clippy::await_holding_lock)]
pub async fn receive(&self) -> Layer4Result<Option<InboundMessage>> {
if let Some(msg) = self.message_queue.write().pop() {
return Ok(Some(msg));
}
let channels = self.channels.read();
for (_, channel) in channels.iter() {
if let Some(msg) = channel.try_receive().await? {
self.router
.update_user_channel(&msg.user_id, &msg.channel_id);
if let Some(ref session_id) = msg.session_id {
self.router
.update_session_channel(session_id, &msg.channel_id);
}
return Ok(Some(msg));
}
}
Ok(None)
}
#[allow(clippy::await_holding_lock)]
pub async fn receive_all(&self) -> Layer4Result<Vec<InboundMessage>> {
let mut messages = Vec::new();
messages.append(&mut self.message_queue.write());
let channels = self.channels.read();
for (_, channel) in channels.iter() {
while let Some(msg) = channel.try_receive().await? {
messages.push(msg);
}
}
Ok(messages)
}
pub fn channel_count(&self) -> usize {
self.channels.read().len()
}
#[allow(clippy::await_holding_lock)]
pub async fn close_all(&self) -> Layer4Result<()> {
let mut channels = self.channels.write();
for (id, channel) in channels.drain() {
if let Err(e) = channel.close().await {
tracing::error!("Failed to close channel {}: {}", id, e);
}
}
Ok(())
}
}
impl Default for ChannelGateway {
fn default() -> Self {
Self::new()
}
}
/// Bookkeeping for where replies should go: which channel each user and session was
/// most recently seen on, plus the type of every registered channel.
///
/// Each table sits behind its own lock, so the router can be used through `&self`.
pub struct MessageRouter {
// user_id -> channel_id; overwritten on every update, so it holds the latest channel.
user_channels: RwLock<HashMap<String, String>>,
// session_id -> channel_id; overwritten on every update.
session_channels: RwLock<HashMap<String, String>>,
// channel_id -> ChannelType for every registered channel.
channel_registry: RwLock<HashMap<String, ChannelType>>,
}
impl MessageRouter {
pub fn new() -> Self {
Self {
user_channels: RwLock::new(HashMap::new()),
session_channels: RwLock::new(HashMap::new()),
channel_registry: RwLock::new(HashMap::new()),
}
}
pub fn register_channel(&self, channel_id: &str, channel_type: ChannelType) {
self.channel_registry
.write()
.insert(channel_id.to_string(), channel_type);
}
pub fn unregister_channel(&self, channel_id: &str) {
self.channel_registry.write().remove(channel_id);
self.user_channels.write().retain(|_, v| v != channel_id);
self.session_channels.write().retain(|_, v| v != channel_id);
}
pub fn update_user_channel(&self, user_id: &str, channel_id: &str) {
self.user_channels
.write()
.insert(user_id.to_string(), channel_id.to_string());
}
pub fn update_session_channel(&self, session_id: &str, channel_id: &str) {
self.session_channels
.write()
.insert(session_id.to_string(), channel_id.to_string());
}
pub fn find_user_channel(&self, user_id: &str) -> Option<String> {
self.user_channels.read().get(user_id).cloned()
}
pub fn find_session_channel(&self, session_id: &str) -> Option<String> {
self.session_channels.read().get(session_id).cloned()
}
}
impl Default for MessageRouter {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `InboundMessage::new` stores the channel, user and content it was given.
    #[test]
    fn test_inbound_message_creation() {
        let inbound = InboundMessage::new("cli-1", "user-1", "Hello");
        assert_eq!(
            (
                inbound.channel_id.as_str(),
                inbound.user_id.as_str(),
                inbound.content.as_str()
            ),
            ("cli-1", "user-1", "Hello")
        );
    }

    /// `OutboundMessage::broadcast` targets every channel.
    #[test]
    fn test_outbound_message_broadcast() {
        let outbound = OutboundMessage::broadcast("Hello all");
        assert!(
            matches!(outbound.target, MessageTarget::All),
            "broadcast must target MessageTarget::All"
        );
    }

    /// A freshly built gateway has no channels.
    #[test]
    fn test_channel_gateway_creation() {
        assert_eq!(ChannelGateway::new().channel_count(), 0);
    }

    /// A recorded user -> channel mapping can be looked up again.
    #[test]
    fn test_message_router() {
        let router = MessageRouter::new();
        router.update_user_channel("user-1", "cli-1");
        assert_eq!(router.find_user_channel("user-1").as_deref(), Some("cli-1"));
    }
}