use crate::diameter::DiameterMessage;
use crate::dictionary::Dictionary;
use crate::error::{Error, Result};
use crate::transport::Codec;
use std::collections::HashMap;
use std::future::Future;
use std::ops::DerefMut;
use std::pin::Pin;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio::sync::oneshot;
use tokio::sync::oneshot::Receiver;
use tokio::sync::oneshot::Sender;
use tokio::sync::Mutex;
/// Connection options for a [`DiameterClient`].
#[derive(Debug, Clone)]
pub struct DiameterClientConfig {
    /// Wrap the TCP connection in TLS (via `native_tls`) when `true`.
    pub use_tls: bool,
    /// When TLS is used, validate the server certificate. Setting this to
    /// `false` enables `danger_accept_invalid_certs` on the TLS connector.
    pub verify_cert: bool,
}
/// Diameter client: owns the write half of the connection and the table of
/// in-flight requests that are still waiting for an answer.
pub struct DiameterClient {
    /// Remote address handed to `TcpStream::connect` (presumably `host:port`).
    address: String,
    /// Options supplied at construction time.
    config: DiameterClientConfig,
    /// Write half of the connection; `None` until `connect` succeeds.
    writer: Option<Arc<Mutex<dyn AsyncWrite + Send + Unpin>>>,
    /// Pending requests keyed by hop-by-hop id; the same map is shared with
    /// the `ClientHandler` returned by `connect`.
    msg_caches: Arc<Mutex<HashMap<u32, Sender<DiameterMessage>>>>,
    /// Last value handed out by `get_next_seq_num`.
    seq_num: u32,
}
impl DiameterClient {
pub fn new(addr: &str, config: DiameterClientConfig) -> DiameterClient {
DiameterClient {
config,
address: addr.into(),
writer: None,
msg_caches: Arc::new(Mutex::new(HashMap::new())),
seq_num: 0,
}
}
pub async fn connect(&mut self) -> Result<ClientHandler> {
let stream = TcpStream::connect(self.address.clone()).await?;
if self.config.use_tls {
let tls_connector = tokio_native_tls::TlsConnector::from(
native_tls::TlsConnector::builder()
.danger_accept_invalid_certs(!self.config.verify_cert)
.build()?,
);
let tls_stream = tls_connector.connect(&self.address.clone(), stream).await?;
let (reader, writer) = tokio::io::split(tls_stream);
let writer = Arc::new(Mutex::new(writer));
self.writer = Some(writer);
let msg_caches = Arc::clone(&self.msg_caches);
Ok(ClientHandler {
reader: Box::new(reader),
msg_caches,
})
} else {
let (reader, writer) = tokio::io::split(stream);
let writer = Arc::new(Mutex::new(writer));
self.writer = Some(writer);
let msg_caches = Arc::clone(&self.msg_caches);
Ok(ClientHandler {
reader: Box::new(reader),
msg_caches,
})
}
}
pub async fn handle(handler: &mut ClientHandler, dictionary: Arc<Dictionary>) {
loop {
match Codec::decode(&mut handler.reader, Arc::clone(&dictionary)).await {
Ok(res) => {
if let Err(e) = Self::process_decoded_msg(handler.msg_caches.clone(), res).await
{
log::error!("Failed to process response; error: {:?}", e);
return;
}
}
Err(e) => {
log::error!("Failed to read message from socket; error: {:?}", e);
return;
}
}
}
}
async fn process_decoded_msg(
msg_caches: Arc<Mutex<HashMap<u32, Sender<DiameterMessage>>>>,
res: DiameterMessage,
) -> Result<()> {
let hop_by_hop = res.get_hop_by_hop_id();
let sender_opt = {
let mut msg_caches = msg_caches.lock().await;
msg_caches.remove(&hop_by_hop)
};
match sender_opt {
Some(sender) => {
sender.send(res).map_err(|e| {
Error::ClientError(format!("Failed to send response; error: {:?}", e))
})?;
}
None => {
Err(Error::ClientError(format!(
"No request found for hop_by_hop_id {}",
hop_by_hop
)))?;
}
};
Ok(())
}
pub async fn send_message(&mut self, req: DiameterMessage) -> Result<ResponseFuture> {
if let Some(writer) = &self.writer {
let (tx, rx) = oneshot::channel();
let hop_by_hop = req.get_hop_by_hop_id();
{
let mut msg_caches = self.msg_caches.lock().await;
msg_caches.insert(hop_by_hop, tx);
}
let mut writer = writer.lock().await;
Codec::encode(&mut writer.deref_mut(), &req).await?;
Ok(ResponseFuture { receiver: rx })
} else {
Err(Error::ClientError("Not connected".into()))
}
}
pub fn get_next_seq_num(&mut self) -> u32 {
self.seq_num += 1;
self.seq_num
}
}
/// Read side of a connection created by `DiameterClient::connect`.
///
/// Pass it to `DiameterClient::handle` so decoded answers are routed to the
/// pending requests stored in `msg_caches`.
pub struct ClientHandler {
    /// Pending requests keyed by hop-by-hop id, shared with the owning client.
    msg_caches: Arc<Mutex<HashMap<u32, Sender<DiameterMessage>>>>,
    /// Read half of the (optionally TLS-wrapped) TCP stream.
    reader: Box<dyn AsyncRead + Send + Unpin>,
}
/// Future for the answer to a request sent with `DiameterClient::send_message`.
///
/// Resolves to the answer delivered by `DiameterClient::handle`, or to an
/// error if the sending half of the channel is dropped without a value.
#[derive(Debug)]
#[must_use = "the Diameter answer is discarded unless this future is awaited"]
pub struct ResponseFuture {
    /// Receiving half of the oneshot channel registered for the request.
    pub receiver: Receiver<DiameterMessage>,
}
impl Future for ResponseFuture {
type Output = Result<DiameterMessage>;
fn poll(
mut self: Pin<&mut Self>,
ctx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
match Pin::new(&mut self.receiver).poll(ctx) {
std::task::Poll::Ready(result) => match result {
Ok(response) => std::task::Poll::Ready(Ok(response)),
Err(_) => std::task::Poll::Ready(Err(Error::ClientError(
"Response channel closed".into(),
))),
},
std::task::Poll::Pending => std::task::Poll::Pending,
}
}
}