use crate::diameter::DiameterMessage;
use crate::dictionary::Dictionary;
use crate::error::Result;
use crate::transport::Codec;
use std::future::Future;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpListener;
/// Settings for a [`DiameterServer`].
pub struct DiameterServerConfig {
/// Server-side TLS identity.
///
/// When `Some`, `DiameterServer::listen` builds a TLS acceptor from this
/// identity and performs a TLS handshake on every accepted connection.
/// When `None`, connections are served over plain TCP.
pub native_tls: Option<native_tls::Identity>,
}
/// A Diameter server that owns a bound TCP listener.
///
/// Construct with `DiameterServer::new`, then call `DiameterServer::listen`
/// to start accepting peers.
pub struct DiameterServer {
/// Listener that incoming peer connections are accepted from.
listener: TcpListener,
/// Server settings (optional TLS identity).
config: DiameterServerConfig,
}
impl DiameterServer {
pub async fn new(addr: &str, config: DiameterServerConfig) -> Result<DiameterServer> {
let listener = TcpListener::bind(addr).await?;
Ok(DiameterServer { listener, config })
}
pub async fn listen<F, Fut>(&mut self, handler: F, dict: Arc<Dictionary>) -> Result<()>
where
F: Fn(DiameterMessage) -> Fut + Clone + Send + 'static,
Fut: Future<Output = Result<DiameterMessage>> + Send + 'static,
{
loop {
match self.config.native_tls {
Some(ref identity) => {
let acceptor = native_tls::TlsAcceptor::new(identity.clone())?;
let acceptor = tokio_native_tls::TlsAcceptor::from(acceptor);
let (stream, peer_addr) = self.listener.accept().await?;
match acceptor.accept(stream).await {
Ok(stream) => {
Self::handle_peer(
peer_addr,
stream,
handler.clone(),
Arc::clone(&dict),
);
}
Err(e) => {
log::error!("TLS handshake failed: {:?}", e);
}
}
}
None => {
let (stream, peer_addr) = self.listener.accept().await?;
Self::handle_peer(peer_addr, stream, handler.clone(), Arc::clone(&dict));
}
};
}
}
fn handle_peer<F, Fut, S>(peer_addr: SocketAddr, stream: S, handler: F, dict: Arc<Dictionary>)
where
F: Fn(DiameterMessage) -> Fut + Clone + Send + 'static,
Fut: Future<Output = Result<DiameterMessage>> + Send + 'static,
S: AsyncReadExt + AsyncWriteExt + Unpin + Send + 'static,
{
tokio::spawn(async move {
log::info!("[{}] Connection established", peer_addr);
match Self::process_incoming_message(stream, handler, dict).await {
Ok(_) => {
log::info!("[{}] Connection closed", peer_addr);
}
Err(e) => {
log::error!("Fatal error occurred: {:?}", e);
}
}
});
}
async fn process_incoming_message<F, Fut, S>(
mut stream: S,
handler: F,
dict: Arc<Dictionary>,
) -> Result<()>
where
F: Fn(DiameterMessage) -> Fut,
Fut: Future<Output = Result<DiameterMessage>>,
S: AsyncReadExt + AsyncWriteExt + Unpin,
{
loop {
let req = match Codec::decode(&mut stream, Arc::clone(&dict)).await {
Ok(req) => req,
Err(e) => match e {
crate::error::Error::IoError(ref e)
if e.kind() == std::io::ErrorKind::UnexpectedEof =>
{
return Ok(());
}
_ => {
return Err(e);
}
},
};
let res = handler(req).await?;
Codec::encode(&mut stream, &res).await?;
}
}
}