turn-server 4.0.1

A pure Rust implementation of a TURN server.
Documentation
use std::time::{Duration, Instant};

use anyhow::Result;
use tokio::sync::{
    Mutex,
    mpsc::{Sender, channel},
};

use tonic::{
    Request, Response, Status,
    transport::{Channel, Server},
};

#[cfg(feature = "ssl")]
use tonic::transport::{Certificate, ClientTlsConfig, Identity, ServerTlsConfig};

use protos::{
    GetTurnPasswordRequest, PasswordAlgorithm as ProtoPasswordAlgorithm, SessionQueryParams,
    TurnAllocatedEvent, TurnChannelBindEvent, TurnCreatePermissionEvent, TurnDestroyEvent,
    TurnRefreshEvent, TurnServerInfo, TurnSession, TurnSessionStatistics,
    turn_hooks_service_client::TurnHooksServiceClient,
    turn_service_server::{TurnService, TurnServiceServer},
};

use crate::{
    Service,
    codec::{crypto::Password, message::attributes::PasswordAlgorithm},
    config::Config,
    service::session::{Identifier, Session},
    statistics::Statistics,
};

impl From<PasswordAlgorithm> for ProtoPasswordAlgorithm {
    fn from(val: PasswordAlgorithm) -> Self {
        match val {
            PasswordAlgorithm::Md5 => Self::Md5,
            PasswordAlgorithm::Sha256 => Self::Sha256,
        }
    }
}

/// Conversion between a value and the textual id used by the gRPC API.
///
/// `from_string` is intended to accept the format that `to_string` produces,
/// so ids handed out to API clients can be passed back in.
pub trait IdString {
    /// Error returned when a string cannot be parsed back into a value.
    type Error;

    /// Parses a value from its string id form.
    fn from_string(s: String) -> Result<Self, Self::Error>
    where
        Self: Sized;

    /// Renders this value as a string id.
    fn to_string(&self) -> String;
}

impl IdString for Identifier {
    type Error = Status;

    /// Formats the identifier as `"{source}/{interface}"`.
    fn to_string(&self) -> String {
        format!("{}/{}", self.source(), self.interface())
    }

    /// Parses the `"{source}/{interface}"` form produced by `to_string`.
    ///
    /// The string is split at the first `/`; each half is parsed with `str::parse`.
    ///
    /// # Errors
    ///
    /// Returns `Status::invalid_argument` when the separator is missing or
    /// either half fails to parse.
    fn from_string(s: String) -> Result<Self, Self::Error> {
        // `ok_or_else` so the `Status` (message + metadata) is only built on failure,
        // not on every successful lookup.
        let (source, interface) = s
            .split_once('/')
            .ok_or_else(|| Status::invalid_argument("Invalid identifier"))?;

        Ok(Self::new(
            source
                .parse()
                .map_err(|_| Status::invalid_argument("Invalid source address"))?,
            interface
                .parse()
                .map_err(|_| Status::invalid_argument("Invalid interface address"))?,
        ))
    }
}

/// Backing state for the `TurnService` gRPC API.
struct RpcService {
    /// Moment the API service was constructed; `get_info` reports the elapsed
    /// seconds as uptime.
    uptime: Instant,
    /// Server configuration, read by `get_info` for interfaces and port capacity.
    config: Config,
    /// TURN service whose session manager answers session queries.
    service: Service,
    /// Per-session counters returned by `get_session_statistics`.
    statistics: Statistics,
}

#[tonic::async_trait]
impl TurnService for RpcService {
    async fn get_info(&self, _: Request<()>) -> Result<Response<TurnServerInfo>, Status> {
        Ok(Response::new(TurnServerInfo {
            software: crate::SOFTWARE.to_string(),
            uptime: self.uptime.elapsed().as_secs(),
            interfaces: self
                .config
                .server
                .get_external_addresses()
                .iter()
                .map(|addr| addr.to_string())
                .collect(),
            port_capacity: self.config.server.port_range.size() as u32,
            port_allocated: self.service.get_session_manager().allocated() as u32,
        }))
    }

    async fn get_session(
        &self,
        request: Request<SessionQueryParams>,
    ) -> Result<Response<TurnSession>, Status> {
        if let Some(Session::Authenticated {
            username,
            allocate_port,
            allocate_channels,
            permissions,
            expires,
            ..
        }) = self
            .service
            .get_session_manager()
            .get_session(&Identifier::from_string(request.into_inner().id)?)
            .get_ref()
        {
            Ok(Response::new(TurnSession {
                username: username.to_string(),
                permissions: permissions.iter().map(|p| *p as i32).collect(),
                channels: allocate_channels.iter().map(|p| *p as i32).collect(),
                port: allocate_port.map(|p| p as i32),
                expires: *expires as i64,
            }))
        } else {
            Err(Status::not_found("Session not found"))
        }
    }

    async fn get_session_statistics(
        &self,
        request: Request<SessionQueryParams>,
    ) -> Result<Response<TurnSessionStatistics>, Status> {
        if let Some(counts) = self
            .statistics
            .get(&Identifier::from_string(request.into_inner().id)?)
        {
            Ok(Response::new(TurnSessionStatistics {
                received_bytes: counts.received_bytes as u64,
                send_bytes: counts.send_bytes as u64,
                received_pkts: counts.received_pkts as u64,
                send_pkts: counts.send_pkts as u64,
                error_pkts: counts.error_pkts as u64,
            }))
        } else {
            Err(Status::not_found("Session not found"))
        }
    }

    async fn destroy_session(
        &self,
        request: Request<SessionQueryParams>,
    ) -> Result<Response<()>, Status> {
        if self
            .service
            .get_session_manager()
            .refresh(&Identifier::from_string(request.into_inner().id)?, 0)
        {
            Ok(Response::new(()))
        } else {
            Err(Status::failed_precondition("Session not found"))
        }
    }
}

/// An event queued for delivery to the hooks server.
///
/// The background task spawned by `RpcHooksService::new` forwards each variant
/// through the matching `TurnHooksService` RPC.
pub enum HooksEvent {
    /// Wraps a `TurnAllocatedEvent`; delivered via `on_allocated_event`.
    Allocated(TurnAllocatedEvent),
    /// Wraps a `TurnChannelBindEvent`; delivered via `on_channel_bind_event`.
    ChannelBind(TurnChannelBindEvent),
    /// Wraps a `TurnCreatePermissionEvent`; delivered via `on_create_permission_event`.
    CreatePermission(TurnCreatePermissionEvent),
    /// Wraps a `TurnRefreshEvent`; delivered via `on_refresh_event`.
    Refresh(TurnRefreshEvent),
    /// Wraps a `TurnDestroyEvent`; delivered via `on_destroy_event`.
    Destroy(TurnDestroyEvent),
}

/// State of an enabled hooks integration.
struct RpcHooksServiceInner {
    /// gRPC client used for request/response calls such as `get_password`.
    client: Mutex<TurnHooksServiceClient<Channel>>,
    /// Producer side of the bounded event queue drained by the forwarding task.
    event_channel: Sender<HooksEvent>,
}

/// Client side of the optional hooks server integration.
///
/// Holds `None` when the configuration has no hooks section; every method is
/// then a no-op.
pub struct RpcHooksService(Option<RpcHooksServiceInner>);

impl RpcHooksService {
    pub async fn new(config: &Config) -> Result<Self> {
        if let Some(hooks) = &config.hooks {
            let (event_channel, mut rx) = channel(hooks.max_channel_size);
            let client = {
                let mut builder = Channel::builder(hooks.endpoint.as_str().try_into()?);

                builder = builder.timeout(Duration::from_secs(hooks.timeout as u64));

                #[cfg(feature = "ssl")]
                if let Some(ssl) = &hooks.ssl {
                    builder = builder.tls_config(
                        ClientTlsConfig::new()
                            .ca_certificate(Certificate::from_pem(ssl.certificate_chain.clone()))
                            .domain_name(
                                url::Url::parse(&hooks.endpoint)?.domain().ok_or_else(|| {
                                    anyhow::anyhow!("Invalid hooks server domain")
                                })?,
                            ),
                    )?;
                }

                TurnHooksServiceClient::new(
                    builder
                        .connect_timeout(Duration::from_secs(5))
                        .timeout(Duration::from_secs(1))
                        .connect_lazy(),
                )
            };

            {
                let mut client = client.clone();

                tokio::spawn(async move {
                    while let Some(event) = rx.recv().await {
                        if match event {
                            HooksEvent::Allocated(event) => {
                                client.on_allocated_event(Request::new(event)).await
                            }
                            HooksEvent::ChannelBind(event) => {
                                client.on_channel_bind_event(Request::new(event)).await
                            }
                            HooksEvent::CreatePermission(event) => {
                                client.on_create_permission_event(Request::new(event)).await
                            }
                            HooksEvent::Refresh(event) => {
                                client.on_refresh_event(Request::new(event)).await
                            }
                            HooksEvent::Destroy(event) => {
                                client.on_destroy_event(Request::new(event)).await
                            }
                        }
                        .is_err()
                        {
                            break;
                        }
                    }
                });
            }

            log::info!("create hooks client, endpoint={}", hooks.endpoint);

            Ok(Self(Some(RpcHooksServiceInner {
                client: Mutex::new(client),
                event_channel,
            })))
        } else {
            Ok(Self(None))
        }
    }

    pub fn send_event(&self, event: HooksEvent) {
        if let Some(inner) = &self.0
            && !inner.event_channel.is_closed()
            && let Err(e) = inner.event_channel.try_send(event)
        {
            log::error!("Failed to send event to hooks server: {}", e);
        }
    }

    pub async fn get_password(
        &self,
        realm: &str,
        username: &str,
        algorithm: PasswordAlgorithm,
    ) -> Option<Password> {
        if let Some(inner) = &self.0 {
            let algorithm: ProtoPasswordAlgorithm = algorithm.into();
            let password = inner
                .client
                .lock()
                .await
                .get_password(Request::new(GetTurnPasswordRequest {
                    realm: realm.to_string(),
                    username: username.to_string(),
                    algorithm: algorithm as i32,
                }))
                .await
                .ok()?
                .into_inner()
                .password;

            return Some(match algorithm {
                ProtoPasswordAlgorithm::Md5 => Password::Md5(password.try_into().ok()?),
                ProtoPasswordAlgorithm::Sha256 => Password::Sha256(password.try_into().ok()?),
                ProtoPasswordAlgorithm::Unspecified => unreachable!(),
            });
        }

        None
    }
}

/// Runs the gRPC management API described by `config.api`.
///
/// When no API section is configured, this future never completes.
///
/// Otherwise it serves `TurnService` on `api.listen` with these settings:
/// - HTTP/2 only;
/// - a per-request timeout of `api.timeout` seconds;
/// - TLS when the `ssl` feature and `api.ssl` are set.
///
/// # Errors
///
/// Returns an error if the TLS configuration is rejected or serving fails.
pub async fn start_server(config: Config, service: Service, statistics: Statistics) -> Result<()> {
    let Some(api) = &config.api else {
        // Nothing to serve: park forever instead of returning.
        return std::future::pending().await;
    };

    let mut server = Server::builder()
        .timeout(Duration::from_secs(api.timeout as u64))
        .accept_http1(false);

    #[cfg(feature = "ssl")]
    if let Some(ssl) = &api.ssl {
        server = server.tls_config(ServerTlsConfig::new().identity(Identity::from_pem(
            ssl.certificate_chain.clone(),
            ssl.private_key.clone(),
        )))?;
    }

    log::info!("api server listening: listen={}", api.listen);

    let rpc = RpcService {
        config: config.clone(),
        uptime: Instant::now(),
        statistics,
        service,
    };

    server
        .add_service(TurnServiceServer::new(rpc))
        .serve(api.listen)
        .await?;

    Ok(())
}