use super::*;
/// RTMP server: accepts TCP connections and keeps track of the live ones.
///
/// Holds the server configuration, a table of per-connection outgoing message
/// senders, the shared [`StreamRegistry`], and the [`AuthHandler`] that is handed
/// to every accepted connection.
pub struct RtmpServer {
/// Settings used to bind the listener; a clone is passed to each connection.
pub config: RtmpServerConfig,
/// Outgoing-message senders for currently registered connections, keyed by
/// connection id. Read by `broadcast` / `send_to_connection`.
connections: Arc<RwLock<HashMap<u64, mpsc::UnboundedSender<OutgoingMessage>>>>,
/// Next connection id to hand out; incremented for every admitted connection.
next_connection_id: Arc<RwLock<u64>>,
/// Stream registry shared (via `Arc`) with every connection.
stream_registry: Arc<StreamRegistry>,
/// Authentication hook shared (via `Arc`) with every connection.
auth_handler: Arc<dyn AuthHandler>,
}
impl RtmpServer {
    /// Creates a server with the given configuration and authentication hook.
    ///
    /// The connection table starts empty, connection ids start at 1, and a fresh
    /// [`StreamRegistry`] is created for this server.
    #[must_use]
    pub fn new(config: RtmpServerConfig, auth_handler: Arc<dyn AuthHandler>) -> Self {
        Self {
            config,
            connections: Arc::new(RwLock::new(HashMap::new())),
            next_connection_id: Arc::new(RwLock::new(1)),
            stream_registry: Arc::new(StreamRegistry::new()),
            auth_handler,
        }
    }

    /// Creates a server using `RtmpServerConfig::default()` and `AllowAllAuth`
    /// as the authentication hook.
    #[must_use]
    pub fn with_default_config() -> Self {
        Self::new(RtmpServerConfig::default(), Arc::new(AllowAllAuth))
    }

    /// Returns the stream registry shared by all connections of this server.
    #[must_use]
    pub fn stream_registry(&self) -> &Arc<StreamRegistry> {
        &self.stream_registry
    }

    /// Binds to `config.bind_address` and serves connections until a fatal error.
    ///
    /// Each accepted socket becomes a `ServerConnection` driven on its own Tokio
    /// task. The connection is registered in the connection table *before* the
    /// task is spawned, under the same write lock as the capacity check. This means
    /// bursts of incoming sockets cannot push the table past
    /// `config.max_connections`. Sockets that arrive while the server is full are
    /// dropped (closed) without being assigned an id.
    ///
    /// # Errors
    ///
    /// Returns an error if the listener cannot be bound, or if `accept` fails with
    /// an error that is not specific to a single incoming connection. Errors that
    /// only concern one peer (see `is_per_connection_accept_error`) are skipped.
    pub async fn run(&self) -> NetResult<()> {
        let listener = TcpListener::bind(&self.config.bind_address)
            .await
            .map_err(|e| {
                NetError::connection(format!(
                    "Failed to bind to {}: {e}",
                    self.config.bind_address
                ))
            })?;
        loop {
            let (stream, addr) = match listener.accept().await {
                Ok(accepted) => accepted,
                // One client vanishing mid-accept must not take the whole server down.
                Err(e) if Self::is_per_connection_accept_error(&e) => continue,
                Err(e) => return Err(NetError::connection(format!("Accept failed: {e}"))),
            };

            // Check capacity, allocate the id and register the sender atomically
            // with respect to other accepts, so the limit cannot be overshot.
            let (connection_id, conn) = {
                let mut conns = self.connections.write().await;
                if conns.len() >= self.config.max_connections {
                    // At capacity: `stream` is dropped here, closing the socket.
                    continue;
                }
                let connection_id = {
                    let mut next_id = self.next_connection_id.write().await;
                    let id = *next_id;
                    *next_id += 1;
                    id
                };
                let conn = ServerConnection::new(
                    connection_id,
                    stream,
                    addr,
                    self.config.clone(),
                    Arc::clone(&self.stream_registry),
                    Arc::clone(&self.auth_handler),
                );
                conns.insert(connection_id, conn.message_sender());
                (connection_id, conn)
            };

            let connections = Arc::clone(&self.connections);
            tokio::spawn(async move {
                let result = conn.run().await;
                // Unregister before reporting so the table only holds live connections.
                connections.write().await.remove(&connection_id);
                if let Err(e) = result {
                    eprintln!("Connection {connection_id} error: {e}");
                }
            });
        }
    }

    /// Returns `true` for `accept` errors that concern only the single incoming
    /// connection, as opposed to listener-level failures. These are a peer that
    /// aborted, reset or refused before it was accepted, or an interrupted call.
    fn is_per_connection_accept_error(err: &std::io::Error) -> bool {
        matches!(
            err.kind(),
            std::io::ErrorKind::ConnectionAborted
                | std::io::ErrorKind::ConnectionReset
                | std::io::ErrorKind::ConnectionRefused
                | std::io::ErrorKind::Interrupted
        )
    }

    /// Queues a clone of `message` on chunk stream `csid` for every registered
    /// connection.
    ///
    /// This is best-effort: a failed send to any single connection (for example,
    /// one whose receiving side is already gone) is ignored.
    ///
    /// # Errors
    ///
    /// Currently never fails; always returns `Ok(())`.
    pub async fn broadcast(&self, message: RtmpMessage, csid: u32) -> NetResult<()> {
        let connections = self.connections.read().await;
        for sender in connections.values() {
            let _ = sender.send(OutgoingMessage {
                message: message.clone(),
                chunk_stream_id: csid,
            });
        }
        Ok(())
    }

    /// Queues `message` on chunk stream `csid` for the connection `connection_id`.
    ///
    /// If no connection with that id is registered, the message is discarded and
    /// `Ok(())` is returned.
    ///
    /// # Errors
    ///
    /// Returns a connection error if the connection's channel rejects the message
    /// (its receiving side has been dropped).
    pub async fn send_to_connection(
        &self,
        connection_id: u64,
        message: RtmpMessage,
        csid: u32,
    ) -> NetResult<()> {
        let connections = self.connections.read().await;
        if let Some(sender) = connections.get(&connection_id) {
            sender
                .send(OutgoingMessage {
                    message,
                    chunk_stream_id: csid,
                })
                .map_err(|e| NetError::connection(format!("Failed to send message: {e}")))?;
        }
        Ok(())
    }

    /// Number of connections currently registered in the connection table.
    pub async fn connection_count(&self) -> usize {
        self.connections.read().await.len()
    }
}
/// Fluent builder for [`RtmpServer`].
///
/// Starts from `RtmpServerConfig::default()`; each setter overrides one config
/// field. If no auth handler is supplied, `build` falls back to `AllowAllAuth`.
pub struct RtmpServerBuilder {
/// Configuration accumulated by the setter methods.
config: RtmpServerConfig,
/// Handler chosen via `auth_handler`, or `None` to use `AllowAllAuth` at build time.
auth_handler: Option<Arc<dyn AuthHandler>>,
}
impl RtmpServerBuilder {
    /// Starts a builder from `RtmpServerConfig::default()` with no auth handler
    /// selected yet.
    #[must_use]
    pub fn new() -> Self {
        let config = RtmpServerConfig::default();
        Self {
            config,
            auth_handler: None,
        }
    }

    /// Sets the address the server's listener will bind to.
    #[must_use]
    pub fn bind_address(self, address: impl Into<String>) -> Self {
        let mut builder = self;
        builder.config.bind_address = address.into();
        builder
    }

    /// Sets the per-connection read timeout.
    #[must_use]
    pub const fn read_timeout(self, timeout: Duration) -> Self {
        let mut builder = self;
        builder.config.read_timeout = timeout;
        builder
    }

    /// Sets the per-connection write timeout.
    #[must_use]
    pub const fn write_timeout(self, timeout: Duration) -> Self {
        let mut builder = self;
        builder.config.write_timeout = timeout;
        builder
    }

    /// Sets the configured chunk size.
    #[must_use]
    pub const fn chunk_size(self, size: u32) -> Self {
        let mut builder = self;
        builder.config.chunk_size = size;
        builder
    }

    /// Sets the configured window acknowledgement size.
    #[must_use]
    pub const fn window_ack_size(self, size: u32) -> Self {
        let mut builder = self;
        builder.config.window_ack_size = size;
        builder
    }

    /// Sets the maximum number of simultaneously registered connections.
    #[must_use]
    pub const fn max_connections(self, max: usize) -> Self {
        let mut builder = self;
        builder.config.max_connections = max;
        builder
    }

    /// Selects the authentication hook, replacing any previously chosen one.
    #[must_use]
    pub fn auth_handler(self, handler: Arc<dyn AuthHandler>) -> Self {
        let mut builder = self;
        builder.auth_handler = Some(handler);
        builder
    }

    /// Consumes the builder and creates the server.
    ///
    /// Uses `AllowAllAuth` when no auth handler was selected.
    #[must_use]
    pub fn build(self) -> RtmpServer {
        let Self {
            config,
            auth_handler,
        } = self;
        let handler: Arc<dyn AuthHandler> = match auth_handler {
            Some(chosen) => chosen,
            None => Arc::new(AllowAllAuth),
        };
        RtmpServer::new(config, handler)
    }
}
impl Default for RtmpServerBuilder {
fn default() -> Self {
Self::new()
}
}