use super::tls::{TlsConfig, TlsConnection};
use super::websocket::WebSocketConnection;
use super::{connection::TransportSender, sip_addr::SipAddr, tcp::TcpConnection, SipConnection};
use crate::resolver::SipResolver;
use crate::sip::{Host, HostWithPort, Transport};
use crate::transaction::key::TransactionKey;
use crate::transport::connection::TransportReceiver;
use crate::{transport::TransportEvent, Result};
use async_trait::async_trait;
use parking_lot::{Mutex, RwLock};
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::Arc;
use tokio::select;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, warn};
/// Resolves a SIP address whose host is a domain name into a concrete address.
///
/// `TransportLayerInner::lookup` only calls this for targets whose host is
/// `Host::Domain`; IP-addressed targets are used as-is.
#[async_trait]
pub trait DomainResolver: Send + Sync {
/// Returns the resolved address for `target`, or an error when it cannot be resolved.
async fn resolve(&self, target: &SipAddr) -> Result<SipAddr>;
}
/// Decides whether a given IP address is allowed by the transport layer.
///
/// Installed through `TransportLayer::set_whitelist`; when no whitelist is
/// installed, `TransportLayerInner::is_whitelisted` allows every address.
#[async_trait]
pub trait TransportWhitelist: Send + Sync {
/// Returns `true` if `ip` is allowed.
async fn allow(&self, ip: IpAddr) -> bool;
}
/// Lets any `Fn(IpAddr) -> impl Future<Output = bool>` closure act as a whitelist,
/// e.g. `layer.set_whitelist(|ip: IpAddr| async move { ip.is_loopback() })`.
#[async_trait]
impl<F, Fut> TransportWhitelist for F
where
F: Send + Sync + Fn(IpAddr) -> Fut,
Fut: std::future::Future<Output = bool> + Send,
{
/// Calls the closure with `ip` and awaits its verdict.
async fn allow(&self, ip: IpAddr) -> bool {
(self)(ip).await
}
}
/// Shared, type-erased whitelist handle stored inside `TransportLayerInner`.
pub(crate) type TransportWhitelistRef = Arc<dyn TransportWhitelist>;
/// [`DomainResolver`] backed by the crate's [`SipResolver`].
pub struct DefaultDomainResolver {
/// Resolver used for every lookup.
resolver: SipResolver,
}
impl DefaultDomainResolver {
pub fn new() -> Self {
Self {
resolver: SipResolver::default(),
}
}
pub async fn resolve_with_lookup(&self, target: &SipAddr) -> Result<SipAddr> {
let domain = match &target.addr.host {
Host::Domain(domain) => domain,
_ => {
return Err(crate::Error::DnsResolutionError(target.addr.to_string()));
}
};
let secure = match target.r#type {
Some(Transport::Tls) | Some(Transport::Wss) | Some(Transport::TlsSctp) => true,
_ => false,
};
let addrs = self
.resolver
.lookup(
domain,
target.addr.port.clone(),
target.r#type.clone(),
secure,
)
.await
.map_err(|e| crate::Error::DnsResolutionError(format!("{}: {}", target.addr, e)))?;
if let Some(first) = addrs.first() {
return Ok(SipAddr {
r#type: Some(first.transport),
addr: HostWithPort {
host: Host::IpAddr(first.addr.ip()),
port: Some(first.addr.port().into()),
},
});
}
Err(crate::Error::DnsResolutionError(target.addr.to_string()))
}
}
impl Default for DefaultDomainResolver {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl DomainResolver for DefaultDomainResolver {
    /// Delegates to [`DefaultDomainResolver::resolve_with_lookup`].
    async fn resolve(&self, target: &SipAddr) -> Result<SipAddr> {
        self.resolve_with_lookup(target).await
    }
}
/// Shared state behind every [`TransportLayer`] handle.
pub struct TransportLayerInner {
    /// Root token; child tokens derived from it are handed to connections and serve tasks.
    pub(crate) cancel_token: CancellationToken,
    /// Registered listening transports (UDP sockets and stream listeners).
    listens: Arc<RwLock<Vec<SipConnection>>>,
    /// Registered connections, keyed by each connection's own address.
    connections: Arc<RwLock<HashMap<SipAddr, SipConnection>>>,
    /// Sending half of the transport event channel (`TransportEvent::New` / `Closed`, ...).
    pub(crate) transport_tx: TransportSender,
    /// Receiving half of the event channel; presumably taken once by its consumer (verify).
    pub(crate) transport_rx: Mutex<Option<TransportReceiver>>,
    /// Resolves domain-hosted targets before lookup.
    pub domain_resolver: Box<dyn DomainResolver>,
    /// Optional IP whitelist; `None` allows every address.
    whitelist: RwLock<Option<TransportWhitelistRef>>,
    /// Optional TLS configuration applied to outbound TLS connections.
    tls_config: RwLock<Option<TlsConfig>>,
}
/// Reference-counted handle to the shared transport-layer state.
pub(crate) type TransportLayerInnerRef = Arc<TransportLayerInner>;
/// Cloneable handle to the SIP transport layer.
///
/// Clones share the same [`TransportLayerInner`] (listeners, connections, event
/// channel, whitelist, TLS config), but each clone carries its own `outbound` value.
#[derive(Clone)]
pub struct TransportLayer {
/// Optional outbound address; when set, `lookup` routes to it instead of the
/// requested destination.
pub outbound: Option<SipAddr>,
/// Shared state.
pub inner: TransportLayerInnerRef,
}
impl TransportLayer {
    /// Creates a transport layer that resolves domain targets with `domain_resolver`.
    ///
    /// `cancel_token` becomes the layer's root token; child tokens derived from it are
    /// handed to the connections and serve tasks the layer creates. The new layer has
    /// no listeners, connections, whitelist, TLS configuration or outbound address.
    pub fn new_with_domain_resolver(
        cancel_token: CancellationToken,
        domain_resolver: Box<dyn DomainResolver>,
    ) -> Self {
        let (transport_tx, transport_rx) = mpsc::unbounded_channel();
        let inner = TransportLayerInner {
            cancel_token,
            listens: Arc::new(RwLock::new(Vec::new())),
            connections: Arc::new(RwLock::new(HashMap::new())),
            transport_tx,
            transport_rx: Mutex::new(Some(transport_rx)),
            domain_resolver,
            whitelist: RwLock::new(None),
            tls_config: RwLock::new(None),
        };
        Self {
            outbound: None,
            inner: Arc::new(inner),
        }
    }

    /// Creates a transport layer that uses [`DefaultDomainResolver`].
    pub fn new(cancel_token: CancellationToken) -> Self {
        Self::new_with_domain_resolver(cancel_token, Box::new(DefaultDomainResolver::default()))
    }

    /// Registers a listening transport; [`TransportLayer::serve_listens`] starts serving it.
    pub fn add_transport(&self, transport: SipConnection) {
        self.inner.add_listener(transport)
    }

    /// Removes every listening transport whose address equals `addr`.
    pub fn del_transport(&self, addr: &SipAddr) {
        self.inner.del_listener(addr)
    }

    /// Registers an established connection and spawns a task that serves it.
    pub fn add_connection(&self, connection: SipConnection) {
        self.inner.add_connection(connection);
    }

    /// Removes the connection registered under `addr` from the registry (does not close it).
    pub fn del_connection(&self, addr: &SipAddr) {
        self.inner.del_connection(addr)
    }

    /// Returns the connection to use for `target` together with the resolved target.
    ///
    /// Routes to [`TransportLayer::outbound`] instead of `target` when it is set.
    /// `key` is only used for logging.
    pub async fn lookup(
        &self,
        target: &SipAddr,
        key: Option<&TransactionKey>,
    ) -> Result<(SipConnection, SipAddr)> {
        self.inner.lookup(target, self.outbound.as_ref(), key).await
    }

    /// Starts serving every registered listener.
    ///
    /// Failures are logged per listener and do not stop the others; this always
    /// returns `Ok(())`.
    pub async fn serve_listens(&self) -> Result<()> {
        // Snapshot the list so the synchronous lock is not held across `.await`.
        let listens = self.inner.listens.read().clone();
        for transport in listens {
            let addr = transport.get_addr().clone();
            let result = TransportLayerInner::serve_listener(self.inner.clone(), transport).await;
            if let Err(e) = result {
                warn!(error = ?e, %addr, "Failed to serve listener");
            }
        }
        Ok(())
    }

    /// Returns the addresses of all listeners, followed by the local addresses of
    /// registered TCP connections. Other connection kinds contribute nothing.
    pub fn get_addrs(&self) -> Vec<SipAddr> {
        let mut addrs: Vec<SipAddr> = self
            .inner
            .listens
            .read()
            .iter()
            .map(|t| t.get_addr().to_owned())
            .collect();
        addrs.extend(
            self.inner
                .connections
                .read()
                .values()
                .filter_map(|conn| match conn {
                    SipConnection::Tcp(tcp) => Some(tcp.inner.local_addr.clone()),
                    _ => None,
                }),
        );
        addrs
    }

    /// Installs `whitelist`, replacing any previous one.
    pub fn set_whitelist<T>(&self, whitelist: T)
    where
        T: TransportWhitelist + 'static,
    {
        self.inner.set_whitelist(Some(Arc::new(whitelist)));
    }

    /// Removes the whitelist so that every address is allowed.
    pub fn clear_whitelist(&self) {
        self.inner.set_whitelist(None);
    }

    /// Sets the TLS configuration used for new outbound TLS connections.
    pub fn set_tls_config(&self, tls_config: TlsConfig) {
        self.inner.set_tls_config(Some(tls_config));
    }

    /// Clears the TLS configuration; new TLS connections fall back to `TlsConfig::default()`.
    pub fn clear_tls_config(&self) {
        self.inner.set_tls_config(None);
    }
}
impl TransportLayerInner {
    /// Replaces the IP whitelist; `None` disables filtering.
    pub(super) fn set_whitelist(&self, whitelist: Option<TransportWhitelistRef>) {
        *self.whitelist.write() = whitelist;
    }

    /// Returns whether `ip` passes the configured whitelist (`true` when none is set).
    pub(crate) async fn is_whitelisted(&self, ip: IpAddr) -> bool {
        // Clone the `Arc` out of the lock so the guard is released before awaiting.
        let whitelist = self.whitelist.read().clone();
        match whitelist {
            Some(whitelist) => whitelist.allow(ip).await,
            None => true,
        }
    }

    /// Replaces the TLS configuration used for outbound TLS connections.
    fn set_tls_config(&self, tls_config: Option<TlsConfig>) {
        *self.tls_config.write() = tls_config;
    }

    /// Returns a snapshot of the current TLS configuration.
    fn tls_config(&self) -> Option<TlsConfig> {
        self.tls_config.read().clone()
    }

    /// Records a listening transport; serving happens separately in `serve_listener`.
    pub fn add_listener(&self, connection: SipConnection) {
        self.listens.write().push(connection);
    }

    /// Removes every listener whose address equals `addr`.
    pub(super) fn del_listener(&self, addr: &SipAddr) {
        self.listens.write().retain(|t| t.get_addr() != addr);
    }

    /// Registers `connection` under its own address and starts serving it.
    ///
    /// NOTE(review): an existing entry for the same address is replaced in the map, but
    /// the task serving the previous connection (it owns its own clone) keeps running.
    pub(super) fn add_connection(&self, connection: SipConnection) {
        // The write guard is a temporary released at the end of this statement,
        // i.e. before `serve_connection` runs.
        self.connections
            .write()
            .insert(connection.get_addr().to_owned(), connection.clone());
        self.serve_connection(connection);
    }

    /// Removes the connection registered under `addr`, if any. It is not closed here.
    pub(super) fn del_connection(&self, addr: &SipAddr) {
        self.connections.write().remove(addr);
    }

    /// Finds, or creates, the connection used to reach `destination`.
    ///
    /// The routing target is `outbound` when given, otherwise `destination`; a domain
    /// host is first resolved through `domain_resolver`. Then, in order:
    /// 1. an already registered connection for the resolved target is reused;
    /// 2. TCP / TLS / WS / WSS targets get a freshly dialed connection, which is
    ///    registered and served via `add_connection`;
    /// 3. any other transport uses a listener (see `select_listener`).
    ///
    /// Returns the chosen connection and the resolved target. `key` is only logged.
    ///
    /// # Errors
    ///
    /// Propagates resolver and connect errors, and returns
    /// `crate::Error::TransportLayerError` when no listener fits.
    async fn lookup(
        &self,
        destination: &SipAddr,
        outbound: Option<&SipAddr>,
        key: Option<&TransactionKey>,
    ) -> Result<(SipConnection, SipAddr)> {
        let requested = outbound.unwrap_or(destination);
        // Deferred initialization keeps the resolved address alive for the rest of the
        // function without relying on temporary lifetime extension through `if`.
        let resolved;
        let target = if matches!(requested.addr.host, Host::Domain(_)) {
            resolved = self.domain_resolver.resolve(requested).await?;
            &resolved
        } else {
            requested
        };
        debug!(?key, src = %destination, %target, "lookup target");

        // Clone out of the map so the read guard is dropped before any `.await` below.
        let existing = self.connections.read().get(target).cloned();
        if let Some(transport) = existing {
            return Ok((transport, target.clone()));
        }

        let sip_connection = match target.r#type {
            Some(Transport::Tcp) => {
                let connection =
                    TcpConnection::connect(target, Some(self.cancel_token.child_token())).await?;
                SipConnection::Tcp(connection)
            }
            Some(Transport::Tls) => {
                let mut effective_config = self.tls_config().unwrap_or_default();
                if effective_config.sni_hostname.is_none() {
                    // Use the pre-resolution domain for SNI; the resolved target may no
                    // longer carry the domain name.
                    effective_config.sni_hostname = match &requested.addr.host {
                        Host::Domain(domain) => Some(domain.to_string()),
                        _ => None,
                    };
                }
                let connection = TlsConnection::connect(
                    target,
                    Some(&effective_config),
                    None,
                    Some(self.cancel_token.child_token()),
                )
                .await?;
                SipConnection::Tls(connection)
            }
            Some(Transport::Ws | Transport::Wss) => {
                let connection =
                    WebSocketConnection::connect(target, Some(self.cancel_token.child_token()))
                        .await?;
                SipConnection::WebSocket(connection)
            }
            // Datagram or unspecified transports are sent through a registered listener.
            _ => return self.select_listener(target),
        };
        self.add_connection(sip_connection.clone());
        Ok((sip_connection, target.clone()))
    }

    /// Picks a listener for a target that does not use a stream transport: a listener
    /// whose address equals `target` if there is one, otherwise the first UDP listener.
    fn select_listener(&self, target: &SipAddr) -> Result<(SipConnection, SipAddr)> {
        let listens = self.listens.read();
        let chosen = listens
            .iter()
            .find(|t| t.get_addr() == target)
            .or_else(|| {
                listens
                    .iter()
                    .find(|t| t.get_addr().r#type == Some(Transport::Udp))
            });
        match chosen {
            Some(transport) => Ok((transport.clone(), target.clone())),
            None => Err(crate::Error::TransportLayerError(
                format!("unsupported transport type: {:?}", target.r#type),
                target.to_owned(),
            )),
        }
    }

    /// Starts serving a listening transport.
    ///
    /// A UDP socket gets a spawned `serve_loop_with_whitelist` task, which receives this
    /// layer (presumably for whitelist checks), and the call returns immediately. Stream
    /// listeners delegate to their own `serve_listener`, awaited here. Unsupported
    /// variants are logged and ignored.
    pub(super) async fn serve_listener(self: Arc<Self>, transport: SipConnection) -> Result<()> {
        let sender = self.transport_tx.clone();
        match transport {
            SipConnection::Udp(transport) => {
                let transport_layer_inner = self.clone();
                tokio::spawn(async move {
                    transport
                        .serve_loop_with_whitelist(sender, Some(transport_layer_inner))
                        .await
                });
                Ok(())
            }
            SipConnection::TcpListener(connection) => connection.serve_listener(self.clone()).await,
            #[cfg(feature = "rustls")]
            SipConnection::TlsListener(connection) => connection.serve_listener(self.clone()).await,
            #[cfg(feature = "websocket")]
            SipConnection::WebSocketListener(connection) => {
                connection.serve_listener(self.clone()).await
            }
            _ => {
                warn!(
                    "serve_listener: unsupported transport type: {:?}",
                    transport.get_addr()
                );
                Ok(())
            }
        }
    }

    /// Spawns a task that serves `transport`.
    ///
    /// The task first emits `TransportEvent::New`. If that send fails, it exits without
    /// serving or closing. Otherwise it runs `serve_loop` until the loop ends or the
    /// layer's token is cancelled, then closes the transport and emits
    /// `TransportEvent::Closed`.
    ///
    /// NOTE(review): the entry in `connections` is not removed here; presumably the
    /// consumer of the `Closed` event calls `del_connection`. Verify.
    pub fn serve_connection(&self, transport: SipConnection) {
        let sub_token = self.cancel_token.child_token();
        let sender_clone = self.transport_tx.clone();
        info!(addr=%transport.get_addr(), "serve_connection: starting serve_loop");
        tokio::spawn(async move {
            match sender_clone.send(TransportEvent::New(transport.clone())) {
                Ok(()) => {
                    info!(addr=%transport.get_addr(), "serve_connection: New event sent");
                }
                Err(e) => {
                    warn!(addr=%transport.get_addr(), error = ?e, "Error sending new connection event");
                    return;
                }
            }
            select! {
                _ = sub_token.cancelled() => { }
                result = async {
                    transport.serve_loop(sender_clone.clone()).await
                } => {
                    if let Err(e) = result {
                        warn!(addr=%transport.get_addr(), error = %e, "serve_loop error");
                    }
                }
            }
            info!(addr=%transport.get_addr(), "transport serve_loop exited");
            transport.close().await.ok();
            sender_clone.send(TransportEvent::Closed(transport)).ok();
        });
    }
}
// NOTE(review): `TransportLayer` derives `Clone` and every clone shares the same `inner`,
// so dropping *any* clone cancels the shared token for all of them. Confirm this is intended.
impl Drop for TransportLayer {
/// Cancels the layer's root cancellation token. Tasks started by `serve_connection`
/// observe this through their child tokens, then close their transport and emit `Closed`.
fn drop(&mut self) {
self.inner.cancel_token.cancel();
}
}
#[cfg(test)]
mod tests {
use crate::resolver::SipResolver;
use crate::sip::uri::ParamsExt;
use crate::sip::{Host, HostWithPort, Transport};
use crate::{
transport::{udp::UdpConnection, SipAddr},
Result,
};
/// UDP lookup: fails with no listeners, falls back to the first UDP listener for an
/// unmatched target, and routes to the `outbound` address (an exact listener match)
/// once `outbound` is set.
#[tokio::test]
async fn test_lookup() -> Result<()> {
let mut tl = super::TransportLayer::new(tokio_util::sync::CancellationToken::new());
let first_uri = SipAddr {
r#type: Some(Transport::Udp),
addr: HostWithPort {
host: Host::IpAddr("127.0.0.1".parse()?),
port: Some(5060.into()),
},
};
// No listeners registered yet, so there is nothing to fall back to.
assert!(tl.lookup(&first_uri, None).await.is_err());
let udp_peer = UdpConnection::create_connection(
"127.0.0.1:0".parse()?,
None,
Some(tl.inner.cancel_token.child_token()),
)
.await?;
let udp_peer_addr = udp_peer.get_addr().to_owned();
tl.add_transport(udp_peer.into());
// The target address matches no listener, so the first UDP listener is chosen.
let (target, _) = tl.lookup(&first_uri, None).await?;
assert_eq!(target.get_addr(), &udp_peer_addr);
let outbound_peer = UdpConnection::create_connection(
"127.0.0.1:0".parse()?,
None,
Some(tl.inner.cancel_token.child_token()),
)
.await?;
let outbound = outbound_peer.get_addr().to_owned();
tl.add_transport(outbound_peer.into());
tl.outbound = Some(outbound.clone());
// With `outbound` set, the lookup target is the outbound listener's own address.
let (target, _) = tl.lookup(&first_uri, None).await?;
assert_eq!(target.get_addr(), &outbound);
Ok(())
}
/// Resolves SIP URIs through `SipResolver` and checks the first target's transport,
/// IP and port. NOTE: the `localhost` case depends on the host resolving `localhost`
/// to 127.0.0.1.
#[tokio::test]
async fn test_sip_dns_lookup() -> Result<()> {
let check_list = vec![
(
"sip:bob@127.0.0.1:5061;transport=udp",
("bob", "127.0.0.1", 5061, Transport::Udp),
),
(
"sip:bob@127.0.0.1:5062;transport=tcp",
("bob", "127.0.0.1", 5062, Transport::Tcp),
),
(
"sip:bob@localhost:5063;transport=tls",
("bob", "127.0.0.1", 5063, Transport::Tls),
),
];
let resolver = SipResolver::default();
for item in check_list {
let uri = crate::sip::uri::Uri::try_from(item.0)?;
// IP hosts are passed to the resolver as domain strings.
let domain = match &uri.host_with_port.host {
crate::sip::Host::Domain(d) => d.clone(),
crate::sip::Host::IpAddr(ip) => crate::sip::Domain::from(ip.to_string()),
};
let secure = match uri.scheme {
Some(crate::sip::Scheme::Sips) => true,
_ => false,
};
let transport_param = uri.transport().map(|t| *t);
let targets = resolver
.lookup(&domain, uri.host_with_port.port, transport_param, secure)
.await;
assert!(targets.is_ok(), "Failed to resolve {}", item.0);
let targets = targets.unwrap();
assert!(!targets.is_empty());
let target = &targets[0];
assert_eq!(uri.user().unwrap(), item.1 .0);
assert_eq!(target.transport, item.1 .3);
assert_eq!(target.addr.ip().to_string(), item.1 .1);
assert_eq!(target.addr.port(), item.1 .2);
}
Ok(())
}
/// Serving a registered UDP listener succeeds, and that listener is the only
/// address reported by `get_addrs`.
#[tokio::test]
async fn test_serve_listens() -> Result<()> {
let tl = super::TransportLayer::new(tokio_util::sync::CancellationToken::new());
let udp_conn = UdpConnection::create_connection(
"127.0.0.1:0".parse()?,
None,
Some(tl.inner.cancel_token.child_token()),
)
.await?;
let addr = udp_conn.get_addr().clone();
tl.add_transport(udp_conn.into());
tl.serve_listens().await?;
let addrs = tl.get_addrs();
assert_eq!(addrs.len(), 1);
assert_eq!(addrs[0], addr);
// Give the spawned UDP serve task a moment to run before the layer is dropped.
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
drop(tl);
Ok(())
}
}