use std::collections::HashMap;
use std::sync::Arc;
use futures::stream;
use tokio::sync::RwLock;
use crate::channels::{Channel, IncomingMessage, MessageStream, OutgoingResponse, StatusUpdate};
use crate::error::ChannelError;
pub struct ChannelManager {
channels: Arc<RwLock<HashMap<String, Box<dyn Channel>>>>,
}
impl ChannelManager {
pub fn new() -> Self {
Self {
channels: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn add(&mut self, channel: Box<dyn Channel>) {
let name = channel.name().to_string();
if let Ok(mut channels) = self.channels.try_write() {
channels.insert(name.clone(), channel);
tracing::debug!("Added channel: {}", name);
} else {
tracing::error!("Failed to add channel: {} (lock contention)", name);
}
}
pub async fn start_all(&self) -> Result<MessageStream, ChannelError> {
let channels = self.channels.read().await;
let mut streams = Vec::new();
for (name, channel) in channels.iter() {
match channel.start().await {
Ok(stream) => {
tracing::info!("Started channel: {}", name);
streams.push(stream);
}
Err(e) => {
tracing::error!("Failed to start channel {}: {}", name, e);
}
}
}
if streams.is_empty() {
return Err(ChannelError::StartupFailed {
name: "all".to_string(),
reason: "No channels started successfully".to_string(),
});
}
let merged = stream::select_all(streams);
Ok(Box::pin(merged))
}
pub async fn respond(
&self,
msg: &IncomingMessage,
response: OutgoingResponse,
) -> Result<(), ChannelError> {
let channels = self.channels.read().await;
if let Some(channel) = channels.get(&msg.channel) {
channel.respond(msg, response).await
} else {
Err(ChannelError::SendFailed {
name: msg.channel.clone(),
reason: "Channel not found".to_string(),
})
}
}
pub async fn send_status(
&self,
channel_name: &str,
status: StatusUpdate,
metadata: &serde_json::Value,
) -> Result<(), ChannelError> {
let channels = self.channels.read().await;
if let Some(channel) = channels.get(channel_name) {
channel.send_status(status, metadata).await
} else {
Ok(())
}
}
pub async fn broadcast(
&self,
channel_name: &str,
user_id: &str,
response: OutgoingResponse,
) -> Result<(), ChannelError> {
let channels = self.channels.read().await;
if let Some(channel) = channels.get(channel_name) {
channel.broadcast(user_id, response).await
} else {
Err(ChannelError::SendFailed {
name: channel_name.to_string(),
reason: "Channel not found".to_string(),
})
}
}
pub async fn broadcast_all(
&self,
user_id: &str,
response: OutgoingResponse,
) -> Vec<(String, Result<(), ChannelError>)> {
let channels = self.channels.read().await;
let mut results = Vec::new();
for (name, channel) in channels.iter() {
let result = channel.broadcast(user_id, response.clone()).await;
results.push((name.clone(), result));
}
results
}
pub async fn health_check_all(&self) -> HashMap<String, Result<(), ChannelError>> {
let channels = self.channels.read().await;
let mut results = HashMap::new();
for (name, channel) in channels.iter() {
results.insert(name.clone(), channel.health_check().await);
}
results
}
pub async fn shutdown_all(&self) -> Result<(), ChannelError> {
let channels = self.channels.read().await;
for (name, channel) in channels.iter() {
if let Err(e) = channel.shutdown().await {
tracing::error!("Error shutting down channel {}: {}", name, e);
}
}
Ok(())
}
pub async fn channel_names(&self) -> Vec<String> {
self.channels.read().await.keys().cloned().collect()
}
}
impl Default for ChannelManager {
fn default() -> Self {
Self::new()
}
}