use prost::Message;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::AsyncReadExt;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tokio::time;
use tokio_native_tls::TlsStream;
use crate::client::Registry;
use crate::error::Error;
use crate::payload;
use crate::{proto::common::ProtoMessage, transport::Transport};
/// Routes raw frames received on `rx` until the channel is closed.
///
/// Each frame is decoded as a [`ProtoMessage`] envelope:
/// - if it carries a non-empty `client_msg_id`, the matching entry is removed from
///   `registry` and the envelope is sent to it; if no entry exists, a warning is logged;
/// - otherwise the envelope is passed to `event_handler`, or logged at debug level
///   when no handler is installed.
///
/// Frames that fail to decode are logged at error level and skipped.
pub async fn dispatch_loop(
    mut rx: mpsc::UnboundedReceiver<Vec<u8>>,
    registry: Registry,
    event_handler: Option<Arc<dyn Fn(ProtoMessage) + Send + Sync>>,
) {
    while let Some(frame) = rx.recv().await {
        let envelope = match ProtoMessage::decode(frame.as_slice()) {
            Ok(envelope) => envelope,
            Err(e) => {
                tracing::error!("failed to decode ProtoMessage envelope: {e}");
                continue;
            }
        };

        // Cloned because the envelope itself is moved into the waiter below.
        match envelope.client_msg_id.clone() {
            Some(id) if !id.is_empty() => {
                // Remove the waiter in a single statement so the registry guard is
                // dropped here rather than held across delivery and logging.
                let waiter = registry.lock().await.remove(&id);
                match waiter {
                    Some(tx) => {
                        // Send errors are deliberately ignored (presumably the
                        // requesting side has already gone away).
                        let _ = tx.send(envelope);
                    }
                    None => tracing::warn!("received response for unknown clientMsgId: {id}"),
                }
            }
            _ => match event_handler.as_ref() {
                Some(handler) => handler(envelope),
                None => tracing::debug!(
                    "unhandled event payloadType={:?}",
                    envelope.payload_type
                ),
            },
        }
    }
}
pub async fn keepalive(transport: Arc<Transport>) {
let heartbeat = ProtoMessage {
payload_type: payload::HEARTBEAT,
payload: None,
client_msg_id: None,
};
let mut frame = Vec::new();
if heartbeat.encode(&mut frame).is_err() {
return;
}
let mut interval = time::interval(Duration::from_secs(10));
loop {
interval.tick().await;
if let Err(e) = transport.send(&frame).await {
tracing::error!("keepalive send failed: {e}");
return;
}
tracing::debug!("heartbeat sent");
}
}
/// Reads length-prefixed frames from `reader` and forwards each payload to `tx`.
///
/// Each frame is a 4-byte big-endian length followed by that many payload bytes.
/// Only the payload (without the length prefix) is sent on `tx`.
///
/// # Errors
///
/// - Returns an error if a read fails. This includes the stream ending, because
///   `read_exact` reports EOF as an error.
/// - Returns an error if a frame declares a length larger than `MAX_FRAME_LEN`.
///
/// Returns `Ok(())` when sending on `tx` fails, i.e. the receiving side is gone.
pub async fn read_loop(
    mut reader: tokio::io::ReadHalf<TlsStream<TcpStream>>,
    tx: mpsc::UnboundedSender<Vec<u8>>,
) -> Result<(), Error> {
    // The length prefix comes straight off the wire, so without a limit a peer
    // could force an allocation of up to 4 GiB per frame.
    // NOTE(review): 64 MiB is a generous guess; confirm it against the largest
    // message the server actually sends.
    const MAX_FRAME_LEN: usize = 64 * 1024 * 1024;

    let mut len_buf = [0u8; 4];
    loop {
        reader.read_exact(&mut len_buf).await?;
        // Lossless on the 32/64-bit targets tokio supports.
        let payload_len = u32::from_be_bytes(len_buf) as usize;
        if payload_len > MAX_FRAME_LEN {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("frame length {payload_len} exceeds maximum of {MAX_FRAME_LEN} bytes"),
            )
            .into());
        }
        let mut payload = vec![0u8; payload_len];
        reader.read_exact(&mut payload).await?;
        // A send error means the receiver was dropped; stop reading.
        if tx.send(payload).is_err() {
            break;
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    // NOTE(review): placeholder. This test has an empty body and exercises none of
    // the functions in this module. It also runs under the async-std test runtime,
    // while the code above uses tokio types. A real test would likely want
    // `#[tokio::test]`; confirm that the required tokio features are enabled.
    #[async_std::test]
    async fn test() {}
}