1use sctp::EndpointConfig;
2pub use actix_web::{App, HttpServer, web};
3use clap::{arg, command, Parser};
4use log::error;
5use opentelemetry::{KeyValue};
6use opentelemetry_sdk::metrics::{PeriodicReader, SdkMeterProvider};
7use opentelemetry_stdout::MetricsExporterBuilder;
8use opentelemetry_sdk::{Resource, runtime};
9
10use std::net::{IpAddr, Ipv4Addr};
11use std::str::FromStr;
12use tera::Tera;
13use wg::WaitGroup;
14use std::cell::RefCell;
15use std::collections::HashMap;
16use std::io::{Error, ErrorKind};
17use std::net::{SocketAddr, UdpSocket};
18use std::rc::Rc;
19use std::sync::mpsc::{Receiver, SyncSender};
20use std::sync::{Arc, mpsc};
21use std::time::{Duration, Instant};
22use bytes::{Bytes, BytesMut};
23
24use std::io::Write;
25use actix_web::{HttpRequest, HttpResponse};
26
27use retty::channel::{InboundPipeline, Pipeline};
28use retty::transport::{TaggedBytesMut, TransportContext};
29use sfu::{DtlsHandler, ExceptionHandler, GatewayHandler, InterceptorHandler, SrtpHandler, StunHandler, DataChannelHandler, DemuxerHandler, RTCSessionDescription};
30
31pub mod util;
32pub mod messages;
33pub mod interceptors;
34pub mod types;
35pub mod metrics;
36
37use dtls::config;
38use dtls::extension::extension_use_srtp::SrtpProtectionProfile;
39use sfu::{RTCCertificate, SctpHandler, ServerConfig, ServerStates};
40
41
42
// Log verbosity selectable with `--level` (see `Cli::level`). It is converted into a
// `log::LevelFilter` when `main` installs the `env_logger` backend, which only happens
// with `--debug`. `Info` is the default.
// Plain `//` comments are used on purpose: clap's `ValueEnum` derive turns `///` doc
// comments on variants into `--help` text.
#[derive(Default, Debug, Copy, Clone, clap::ValueEnum)]
pub enum Level {
    Error,
    Warn,
    #[default]
    Info,
    Debug,
    Trace,
}
52
53impl From<Level> for log::LevelFilter {
54 fn from(level: Level) -> Self {
55 match level {
56 Level::Error => log::LevelFilter::Error,
57 Level::Warn => log::LevelFilter::Warn,
58 Level::Info => log::LevelFilter::Info,
59 Level::Debug => log::LevelFilter::Debug,
60 Level::Trace => log::LevelFilter::Trace,
61 }
62 }
63}
64
65#[derive(Parser)]
66#[command(name = "SFU Server")]
67#[command(author = "Rusty Rain <y@ngr.tc>")]
68#[command(version = "0.1.0")]
69#[command(about = "An example of SFU Server", long_about = None)]
70pub struct Cli {
71 #[arg(long, default_value_t = format!("127.0.0.1"))]
72 pub host: String,
73 #[arg(short, long, default_value_t = 8080)]
74 pub signal_port: u16,
75 #[arg(long, default_value_t = 3478)]
76 pub media_port_min: u16,
77 #[arg(long, default_value_t = 3495)]
78 pub media_port_max: u16,
79
80 #[arg(short, long)]
81 pub force_local_loop: bool,
82 #[arg(short, long)]
83 pub debug: bool,
84 #[arg(short, long, default_value_t = Level::Info)]
85 #[clap(value_enum)]
86 pub level: Level,
87}
88
89pub fn init_meter_provider(
90 mut _stop_rx: async_broadcast::Receiver<()>,
91 _wait_group: WaitGroup,
92) -> SdkMeterProvider {
93 let exporter = MetricsExporterBuilder::default()
94 .with_encoder(|writer, data| {
95 Ok(serde_json::to_writer_pretty(writer, &data).unwrap())
96 })
97 .build();
98 let reader = PeriodicReader::builder(exporter, runtime::TokioCurrentThread)
99 .with_interval(Duration::from_secs(30))
100 .build();
101 let meter_provider = SdkMeterProvider::builder()
102 .with_reader(reader)
103 .with_resource(Resource::new(vec![KeyValue::new("chat", "metrics")]))
104 .build();
105
106 meter_provider
107}
108
109#[actix_web::main]
110async fn main() -> anyhow::Result<()> {
111
112 let (_stop_tx, _stop_rx) = crossbeam_channel::bounded::<()>(1);
113 let cli = Cli::parse();
114 if cli.debug {
115 env_logger::Builder::new()
116 .format(|buf, record| {
117 writeln!(
118 buf,
119 "{}:{} [{}] {} - {}",
120 record.file().unwrap_or("unknown"),
121 record.line().unwrap_or(0),
122 record.level(),
123 chrono::Local::now().format("%H:%M:%S.%6f"),
124 record.args()
125 )
126 })
127 .filter(None, cli.level.into())
128 .init();
129 }
130
131 let host_addr = if cli.host == "127.0.0.1" && !cli.force_local_loop {
133 util::select_host_address()
134 } else {
135 IpAddr::from_str(&cli.host)?
136 };
137
138 let _media_ports: Vec<u16> = (cli.media_port_min..=cli.media_port_max).collect();
139 let (_stop_tx, stop_rx) = crossbeam_channel::bounded::<()>(1);
140 let mut media_port_thread_map = HashMap::new();
141
142
143 let key_pair = rcgen::KeyPair::generate(&rcgen::PKCS_ECDSA_P256_SHA256)?;
144 let certificates = vec![RTCCertificate::from_key_pair(key_pair)?];
145 let dtls_handshake_config = Arc::new(
146 config::ConfigBuilder::default()
147 .with_certificates(
148 certificates
149 .iter()
150 .map(|c| c.dtls_certificate.clone())
151 .collect(),
152 )
153 .with_srtp_protection_profiles(vec![SrtpProtectionProfile::Srtp_Aes128_Cm_Hmac_Sha1_80])
154 .with_extended_master_secret(config::ExtendedMasterSecretType::Require)
155 .build(false, None)?,
156 );
157 let sctp_endpoint_config = Arc::new(EndpointConfig::default());
158 let sctp_server_config = Arc::new(sctp::ServerConfig::default());
159 let server_config = Arc::new(
160 ServerConfig::new(certificates)
161 .with_dtls_handshake_config(dtls_handshake_config)
162 .with_sctp_endpoint_config(sctp_endpoint_config)
163 .with_sctp_server_config(sctp_server_config)
164 .with_idle_timeout(Duration::from_secs(30)),
165 );
166 let (_stop_meter_tx, stop_meter_rx) = async_broadcast::broadcast::<()>(1);
167 let wait_group = WaitGroup::new();
168 let meter_provider = init_meter_provider(stop_meter_rx, wait_group.clone());
169 let media_ports: Vec<u16> = (cli.media_port_min..=cli.media_port_max).collect();
170 for port in media_ports {
171 let worker = wait_group.add(1);
172 let stop_rx = stop_rx.clone();
173 let (signaling_tx, signaling_rx) = mpsc::sync_channel(1);
174
175 media_port_thread_map.insert(port, signaling_tx);
176 let socket = UdpSocket::bind(format!("{host_addr}:{port}"))
179 .expect(&format!("binding to {host_addr}:{port}"));
180
181 let server_config = server_config.clone();
182 let meter_provider = meter_provider.clone();
183 std::thread::spawn(move || {
184 if let Err(err) = sync_run(stop_rx, socket, signaling_rx, server_config, meter_provider)
185 {
186 eprintln!("run_sfu got error: {}", err);
187 }
188 worker.done();
189 });
190 }
191
192 let media_port_thread_map = Arc::new(media_port_thread_map);
193 let signal_port = cli.signal_port;
194 let host_addr = if cli.host == "127.0.0.1" && !cli.force_local_loop {
195 util::select_host_address()
196 } else {
197 IpAddr::from_str(&cli.host)?
198 };
199
200 println!("Connect a browser to https://{}:{}", host_addr, signal_port);
201
202 HttpServer::new(move || {
203 let tera = Tera::new("templates/**/*").unwrap();
204 App::new()
205 .service(actix_files::Files::new("/static", "./src/static/"))
206 .app_data(web::Data::new(tera))
207 .app_data(web::Data::new(media_port_thread_map.clone()))
208 .service(web::resource("/").route(web::get().to(index)))
209 .service(web::resource("/{path}/{session_id}/{endpoint_id}")
210 .route(web::get().to(web_request))
211 .route(web::post().to(web_request))
212 )
213 })
214 .bind(SocketAddr::new(IpAddr::from(Ipv4Addr::new(0,0,0,0)), 80))?
215 .run()
216 .await?;
217
218 println!("Wait for Signaling Server and Media Server Gracefully Shutdown...");
219 wait_group.wait();
220
221 Ok(())
222}
223
/// Requests and replies exchanged between the HTTP signaling handlers and the per-port
/// media threads.
///
/// Every variant identifies the endpoint it concerns by `session_id` / `endpoint_id`.
/// `Offer` and `Leave` are requests. `Ok`, `Err` and `Answer` are replies;
/// `handle_signaling_message` rejects them as "Invalid Request" if they arrive as a request.
pub enum SignalingProtocolMessage {
    /// Payload-free acknowledgement (sent in reply to `Leave`).
    Ok {
        session_id: u64,
        endpoint_id: u64,
    },
    /// Failure reply; `reason` holds the error text as bytes.
    Err {
        session_id: u64,
        endpoint_id: u64,
        reason: Bytes,
    },
    /// Client offer; `offer_sdp` is UTF-8 JSON for an `RTCSessionDescription`.
    Offer {
        session_id: u64,
        endpoint_id: u64,
        offer_sdp: Bytes,
    },
    /// Server answer; `answer_sdp` is the JSON-serialized answer description.
    Answer {
        session_id: u64,
        endpoint_id: u64,
        answer_sdp: Bytes,
    },
    /// Notifies the media thread that an endpoint is leaving its session.
    Leave {
        session_id: u64,
        endpoint_id: u64,
    },
}
249
/// A signaling request delivered to a media thread, paired with the channel its reply
/// must be sent on. Each handler in this file sends one reply per request.
pub struct SignalingMessage {
    /// The request to process (`Offer` or `Leave`; other variants are rejected).
    pub request: SignalingProtocolMessage,
    /// Channel on which the media thread sends its reply.
    pub response_tx: SyncSender<SignalingProtocolMessage>,
}
254
255pub async fn index(tera: web::Data<Tera>) -> HttpResponse {
256 let rendered = tera.render("chat.html.tera", &tera::Context::new()).unwrap();
257 HttpResponse::Ok().content_type("text/html").body(rendered)
258}
259
260
261fn build_pipeline(local_addr: SocketAddr, server_states: Rc<RefCell<ServerStates>>) -> Rc<Pipeline<TaggedBytesMut, TaggedBytesMut>> {
262 let pipeline: Pipeline<TaggedBytesMut, TaggedBytesMut> = Pipeline::new();
263
264 let demuxer_handler = DemuxerHandler::new();
265 let stun_handler = StunHandler::new();
266 let dtls_handler = DtlsHandler::new(local_addr, Rc::clone(&server_states));
268 let sctp_handler = SctpHandler::new(local_addr, Rc::clone(&server_states));
269 let data_channel_handler = DataChannelHandler::new();
270 let srtp_handler = SrtpHandler::new(Rc::clone(&server_states));
272 let interceptor_handler = InterceptorHandler::new(Rc::clone(&server_states));
273 let gateway_handler = GatewayHandler::new(Rc::clone(&server_states));
275 let exception_handler = ExceptionHandler::new();
276
277 pipeline.add_back(demuxer_handler);
278 pipeline.add_back(stun_handler);
279 pipeline.add_back(dtls_handler);
281 pipeline.add_back(sctp_handler);
282 pipeline.add_back(data_channel_handler);
283 pipeline.add_back(srtp_handler);
285 pipeline.add_back(interceptor_handler);
286 pipeline.add_back(gateway_handler);
288 pipeline.add_back(exception_handler);
289
290 pipeline.finalize()
291}
292
293fn write_socket_output(socket: &UdpSocket, pipeline: &Rc<Pipeline<TaggedBytesMut, TaggedBytesMut>>) -> anyhow::Result<()> {
294 while let Some(transmit) = pipeline.poll_transmit() {
295 socket.send_to(&transmit.message, transmit.transport.peer_addr)?;
296 }
297
298 Ok(())
299}
300
301fn handle_offer_message(
302 server_states: &Rc<RefCell<ServerStates>>,
303 session_id: u64,
304 endpoint_id: u64,
305 offer: Bytes,
306 response_tx: SyncSender<SignalingProtocolMessage>,
307) -> anyhow::Result<()> {
308 let try_handle = || -> anyhow::Result<Bytes> {
309 let offer_str = String::from_utf8(offer.to_vec())?;
310 log::info!(
311 "handle_offer_message: {}/{}/{}",
312 session_id,
313 endpoint_id,
314 offer_str,
315 );
316 let mut server_states = server_states.borrow_mut();
317
318 let offer_sdp = serde_json::from_str::<RTCSessionDescription>(&offer_str)?;
319 let answer = server_states.accept_offer(session_id, endpoint_id, None, offer_sdp)?;
320 let answer_str = serde_json::to_string(&answer)?;
321 log::info!("generate answer sdp: {}", answer_str);
322 Ok(Bytes::from(answer_str))
323 };
324
325 match try_handle() {
326 Ok(answer_sdp) => Ok(response_tx
327 .send(SignalingProtocolMessage::Answer {
328 session_id,
329 endpoint_id,
330 answer_sdp,
331 })
332 .map_err(|_| {
333 Error::new(
334 ErrorKind::Other,
335 "failed to send back signaling message response".to_string(),
336 )
337 })?),
338 Err(err) => Ok(response_tx
339 .send(SignalingProtocolMessage::Err {
340 session_id,
341 endpoint_id,
342 reason: Bytes::from(err.to_string()),
343 })
344 .map_err(|_| {
345 Error::new(
346 ErrorKind::Other,
347 "failed to send back signaling message response".to_string(),
348 )
349 })?),
350 }
351}
352
353fn handle_leave_message(_server_states: &Rc<RefCell<ServerStates>>, session_id: u64, endpoint_id: u64, response_tx: SyncSender<SignalingProtocolMessage>) -> anyhow::Result<()> {
354 let try_handle = || -> anyhow::Result<()> {
355 log::info!("handle_leave_message: {}/{}", session_id, endpoint_id,);
356 Ok(())
357 };
358
359 match try_handle() {
360 Ok(_) => Ok(response_tx
361 .send(SignalingProtocolMessage::Ok {
362 session_id,
363 endpoint_id,
364 })
365 .map_err(|_| {
366 Error::new(
367 ErrorKind::Other,
368 "failed to send back signaling message response".to_string(),
369 )
370 })?),
371 Err(err) => Ok(response_tx
372 .send(SignalingProtocolMessage::Err {
373 session_id,
374 endpoint_id,
375 reason: Bytes::from(err.to_string()),
376 })
377 .map_err(|_| {
378 Error::new(
379 ErrorKind::Other,
380 "failed to send back signaling message response".to_string(),
381 )
382 })?),
383 }
384}
385
386pub fn handle_signaling_message(
387 server_states: &Rc<RefCell<ServerStates>>,
388 signaling_msg: SignalingMessage,
389) -> anyhow::Result<()> {
390 match signaling_msg.request {
391 SignalingProtocolMessage::Offer {
392 session_id,
393 endpoint_id,
394 offer_sdp,
395 } => handle_offer_message(
396 server_states,
397 session_id,
398 endpoint_id,
399 offer_sdp,
400 signaling_msg.response_tx,
401 ),
402 SignalingProtocolMessage::Leave {
403 session_id,
404 endpoint_id,
405 } => handle_leave_message(
406 server_states,
407 session_id,
408 endpoint_id,
409 signaling_msg.response_tx,
410 ),
411 SignalingProtocolMessage::Ok {
412 session_id,
413 endpoint_id,
414 }
415 | SignalingProtocolMessage::Err {
416 session_id,
417 endpoint_id,
418 reason: _,
419 }
420 | SignalingProtocolMessage::Answer {
421 session_id,
422 endpoint_id,
423 answer_sdp: _,
424 } => Ok(signaling_msg
425 .response_tx
426 .send(SignalingProtocolMessage::Err {
427 session_id,
428 endpoint_id,
429 reason: Bytes::from("Invalid Request"),
430 })
431 .map_err(|_| {
432 Error::new(
433 ErrorKind::Other,
434 "failed to send back signaling message response".to_string(),
435 )
436 })?),
437 }
438}
439
440pub fn sync_run(
441 stop_rx: crossbeam_channel::Receiver<()>,
442 socket: UdpSocket,
443 rx: Receiver<SignalingMessage>,
444 server_config: Arc<ServerConfig>,
445 _meter_provider: SdkMeterProvider,
446) -> anyhow::Result<()> {
447 let server_states = Rc::new(RefCell::new(ServerStates::new(
448 server_config,
449 socket.local_addr()?,
450 )?));
451
452 println!("listening {}...", socket.local_addr()?);
453
454 let pipeline = build_pipeline(socket.local_addr()?, server_states.clone());
455
456 let mut buf = vec![0; 2000];
457
458 pipeline.transport_active();
459 loop {
460 match stop_rx.try_recv() {
461 Ok(_) => break,
462 Err(err) => {
463 if err.is_disconnected() {
464 break;
465 }
466 }
467 };
468
469 write_socket_output(&socket, &pipeline)?;
470
471 if let Ok(signal_message) = rx.try_recv() {
473 if let Err(err) = handle_signaling_message(&server_states, signal_message) {
474 error!("handle_signaling_message got error:{}", err);
475 continue;
476 }
477 }
478
479 let mut eto = Instant::now() + Duration::from_millis(100);
481 pipeline.poll_timeout(&mut eto);
482
483 let delay_from_now = eto
484 .checked_duration_since(Instant::now())
485 .unwrap_or(Duration::from_secs(0));
486 if delay_from_now.is_zero() {
487 pipeline.handle_timeout(Instant::now());
488 continue;
489 }
490
491 socket
492 .set_read_timeout(Some(delay_from_now))
493 .expect("setting socket read timeout");
494
495 if let Some(input) = read_socket_input(&socket, &mut buf) {
496 pipeline.read(input);
497 }
498
499 pipeline.handle_timeout(Instant::now());
501 }
502 pipeline.transport_inactive();
503
504 println!(
505 "media server on {} is gracefully down",
506 socket.local_addr()?
507 );
508 Ok(())
509}
510
511
512
513fn read_socket_input(socket: &UdpSocket, buf: &mut [u8]) -> Option<TaggedBytesMut> {
514 match socket.recv_from(buf) {
515 Ok((n, peer_addr)) => {
516 return Some(TaggedBytesMut {
517 now: Instant::now(),
518 transport: TransportContext {
519 local_addr: socket.local_addr().unwrap(),
520 peer_addr,
521 ecn: None,
522 },
523 message: BytesMut::from(&buf[..n]),
524 });
525 }
526
527 Err(e) => match e.kind() {
528 ErrorKind::WouldBlock | ErrorKind::TimedOut => None,
530 _ => panic!("UdpSocket read failed: {e:?}"),
531 },
532 }
533}
534
535
536
537pub async fn web_request(
538 req: HttpRequest,
539 bytes: web::Bytes,
540 path: web::Path<(String, u64, u64)>,
541 tera: web::Data<Tera>,
542 media_port_thread_map: web::Data<Arc<HashMap<u16, SyncSender<SignalingMessage>>>>,
543) -> HttpResponse {
544 let (path, session_id, endpoint_id) = path.into_inner();
545
546 if req.method() == actix_web::http::Method::GET {
547 let rendered = tera.render("chat.html.tera", &tera::Context::new()).unwrap();
548 HttpResponse::Ok().content_type("text/html").body(rendered)
549 } else if req.method() == actix_web::http::Method::POST {
550 let mut sorted_ports: Vec<u16> = media_port_thread_map.keys().copied().collect();
551 sorted_ports.sort();
552 assert!(!sorted_ports.is_empty());
553 let port = sorted_ports[(session_id as usize) % sorted_ports.len()];
554 let tx = media_port_thread_map.get(&port);
555
556 if let Some(tx ) = tx {
557 let offer_sdp = bytes.to_vec();
558
559 let (response_tx, response_rx) = mpsc::sync_channel(1);
560 tx.send(SignalingMessage {
561 request: SignalingProtocolMessage::Offer {
562 session_id,
563 endpoint_id,
564 offer_sdp: Bytes::from(offer_sdp),
565 },
566 response_tx,
567 })
568 .expect("to send SignalingMessage instance");
569
570 let response = response_rx.recv().expect("receive answer offer");
571 match response {
572 SignalingProtocolMessage::Answer {
573 session_id: _,
574 endpoint_id: _,
575 answer_sdp,
576 } => HttpResponse::Ok()
577 .content_type("application/json")
578 .body(answer_sdp),
579 _ => HttpResponse::NotFound().finish(),
580 }
581 } else {
582 HttpResponse::NotAcceptable().finish()
583 }
584 } else {
585 HttpResponse::MethodNotAllowed().finish()
586 }
587}