//! sfu 0.0.3
//!
//! WebRTC Selective Forwarding Unit (SFU) in Rust with Sans-IO.
//!
//! Documentation
extern crate num_cpus;

use std::cell::RefCell;
use std::collections::HashMap;
use std::io::Write;
use std::net::SocketAddr;
use std::rc::Rc;
use std::str::FromStr;
use std::sync::Arc;

use async_broadcast::broadcast;
use clap::Parser;
use dtls::extension::extension_use_srtp::SrtpProtectionProfile;
use log::{error, info};
use retty::bootstrap::BootstrapUdpServer;
use retty::channel::Pipeline;
use retty::executor::LocalExecutorBuilder;
use retty::transport::TaggedBytesMut;
use waitgroup::WaitGroup;

use sfu::{
    DataChannelHandler, DemuxerHandler, DtlsHandler, ExceptionHandler, GatewayHandler,
    InterceptorHandler, RTCCertificate, SctpHandler, ServerConfig, ServerStates, SrtpHandler,
    StunHandler,
};

mod async_signal;

use async_signal::{handle_signaling_message, SignalingMessage, SignalingServer};

// Verbosity accepted by the `--level` command-line option. It is converted into a
// `log::LevelFilter` when the logger is installed. `Info` is the default choice.
#[derive(Clone, Copy, Debug, Default)]
#[derive(clap::ValueEnum)]
enum Level {
    Error,
    Warn,
    #[default]
    Info,
    Debug,
    Trace,
}

impl From<Level> for log::LevelFilter {
    fn from(level: Level) -> Self {
        match level {
            Level::Error => log::LevelFilter::Error,
            Level::Warn => log::LevelFilter::Warn,
            Level::Info => log::LevelFilter::Info,
            Level::Debug => log::LevelFilter::Debug,
            Level::Trace => log::LevelFilter::Trace,
        }
    }
}

// Command-line options for the example SFU server.
#[derive(Parser)]
#[command(name = "SFU Server")]
#[command(author = "Rusty Rain <y@ngr.tc>")]
#[command(version = "0.1.0")]
#[command(about = "An example of SFU Server", long_about = None)]
struct Cli {
    /// IP address the signaling and media servers bind to
    #[arg(long, default_value = "127.0.0.1")]
    host: String,
    /// Port of the signaling server
    #[arg(short, long, default_value_t = 8080)]
    signal_port: u16,
    /// First UDP port of the media port range (inclusive)
    #[arg(long, default_value_t = 3478)]
    media_port_min: u16,
    /// Last UDP port of the media port range (inclusive)
    #[arg(long, default_value_t = 3495)]
    media_port_max: u16,

    /// Enable logging output
    #[arg(short, long)]
    debug: bool,
    /// Log level used when --debug is set
    #[arg(short, long, value_enum, default_value_t = Level::Info)]
    level: Level,
}

fn main() -> anyhow::Result<()> {
    let cli = Cli::parse();
    if cli.debug {
        env_logger::Builder::new()
            .format(|buf, record| {
                writeln!(
                    buf,
                    "{}:{} [{}] {} - {}",
                    record.file().unwrap_or("unknown"),
                    record.line().unwrap_or(0),
                    record.level(),
                    chrono::Local::now().format("%H:%M:%S.%6f"),
                    record.args()
                )
            })
            .filter(None, cli.level.into())
            .init();
    }

    println!(
        "listening {}:{}(signal)/[{}-{}](media)...",
        cli.host, cli.signal_port, cli.media_port_min, cli.media_port_max
    );

    let media_ports: Vec<u16> = (cli.media_port_min..=cli.media_port_max).collect();
    let (stop_tx, mut stop_rx) = broadcast::<()>(1);
    let mut media_port_thread_map = HashMap::new();

    let key_pair = rcgen::KeyPair::generate(&rcgen::PKCS_ECDSA_P256_SHA256)?;
    let certificates = vec![RTCCertificate::from_key_pair(key_pair)?];
    let dtls_handshake_config = Arc::new(
        dtls::config::ConfigBuilder::default()
            .with_certificates(
                certificates
                    .iter()
                    .map(|c| c.dtls_certificate.clone())
                    .collect(),
            )
            .with_srtp_protection_profiles(vec![SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_80])
            .with_extended_master_secret(dtls::config::ExtendedMasterSecretType::Require)
            .build(false, None)?,
    );
    let sctp_endpoint_config = Arc::new(sctp::EndpointConfig::default());
    let sctp_server_config = Arc::new(sctp::ServerConfig::default());
    let server_config = Arc::new(
        ServerConfig::new(certificates)
            .with_dtls_handshake_config(dtls_handshake_config)
            .with_sctp_endpoint_config(sctp_endpoint_config)
            .with_sctp_server_config(sctp_server_config),
    );
    let wait_group = WaitGroup::new();
    let core_num = num_cpus::get();

    for port in media_ports {
        let worker = wait_group.worker();
        let host = cli.host.clone();
        let mut stop_rx = stop_rx.clone();
        let (signaling_tx, signaling_rx) = smol::channel::unbounded::<SignalingMessage>();
        media_port_thread_map.insert(port, signaling_tx);

        let server_config = server_config.clone();
        LocalExecutorBuilder::new()
            .name(format!("media_port_{}", port).as_str())
            .core_id(core_affinity::CoreId {
                id: (port as usize) % core_num,
            })
            .spawn(move || async move {
                let _worker = worker;
                let local_addr = SocketAddr::from_str(&format!("{}:{}", host, port)).unwrap();
                let server_states = Rc::new(RefCell::new(ServerStates::new(server_config, local_addr).unwrap()));

                info!("listening {}:{}...", host, port);

                let server_states_moved = server_states.clone();
                let mut bootstrap = BootstrapUdpServer::new();
                bootstrap.pipeline(Box::new(
                    move || {
                        let pipeline: Pipeline<TaggedBytesMut, TaggedBytesMut> = Pipeline::new();

                        let demuxer_handler = DemuxerHandler::new();
                        let stun_handler = StunHandler::new();
                        // DTLS
                        let dtls_handler = DtlsHandler::new(local_addr, Rc::clone(&server_states_moved));
                        let sctp_handler = SctpHandler::new(local_addr, Rc::clone(&server_states_moved));
                        let data_channel_handler = DataChannelHandler::new();
                        // SRTP
                        let srtp_handler = SrtpHandler::new(Rc::clone(&server_states_moved));
                        let interceptor_handler = InterceptorHandler::new(Rc::clone(&server_states_moved));
                        // Gateway
                        let gateway_handler = GatewayHandler::new(Rc::clone(&server_states_moved));
                        let exception_handler = ExceptionHandler::new();

                        pipeline.add_back(demuxer_handler);
                        pipeline.add_back(stun_handler);
                        // DTLS
                        pipeline.add_back(dtls_handler);
                        pipeline.add_back(sctp_handler);
                        pipeline.add_back(data_channel_handler);
                        // SRTP
                        pipeline.add_back(srtp_handler);
                        pipeline.add_back(interceptor_handler);
                        // Gateway
                        pipeline.add_back(gateway_handler);
                        pipeline.add_back(exception_handler);

                        pipeline.finalize()
                    },
                ));

                if let Err(err) = bootstrap.bind(format!("{}:{}", host, port)).await {
                    error!("bootstrap binding error: {}", err);
                    return;
                }

                loop {
                    tokio::select! {
                        _ = stop_rx.recv() => {
                            info!("media server on {}:{} receives stop signal", host, port);
                            break;
                        }
                        recv = signaling_rx.recv() => {
                            match recv {
                                Ok(signaling_msg) => {
                                    if let Err(err) = handle_signaling_message(&server_states, signaling_msg) {
                                        error!("handle_signaling_message error: {}", err);
                                    }
                                }
                                Err(err) => {
                                    error!("signal_rx recv error: {}", err);
                                    break;
                                }
                            }
                        }
                    }
                }

                bootstrap.graceful_stop().await;
                info!("media server on {}:{} is gracefully down", host, port);
            })?;
    }

    let signaling_addr = SocketAddr::from_str(&format!("{}:{}", cli.host, cli.signal_port))?;
    let signaling_stop_rx = stop_rx.clone();
    let signaling_handle = std::thread::spawn(move || {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()
            .unwrap();

        rt.block_on(async {
            let signaling_server = SignalingServer::new(signaling_addr, media_port_thread_map);
            let mut done_rx = signaling_server.run(signaling_stop_rx).await;
            let _ = done_rx.recv().await;
            wait_group.wait().await;
            info!("signaling server is gracefully down");
        })
    });

    LocalExecutorBuilder::default().run(async move {
        println!("Press Ctrl-C to stop");
        std::thread::spawn(move || {
            let mut stop_tx = Some(stop_tx);
            ctrlc::set_handler(move || {
                if let Some(stop_tx) = stop_tx.take() {
                    let _ = stop_tx.try_broadcast(());
                }
            })
            .expect("Error setting Ctrl-C handler");
        });
        let _ = stop_rx.recv().await;
        println!("Wait for Signaling Sever and Media Server Gracefully Shutdown...");
    });

    let _ = signaling_handle.join();

    Ok(())
}