use std::{
    collections::{HashMap, HashSet},
    time::Duration,
};

use anyhow::Result;

use crate::{messages, session, transport};
/// How long a single receive waits; when it elapses, unacked messages are retransmitted.
const RECEIVE_TIMEOUT: Duration = Duration::from_millis(3_000);
/// Total time budget after which `RetrContext::get_next_message` gives up waiting.
const MAX_RETRANSMIT_TIME: Duration = Duration::from_millis(10_000);
/// Reliability layer over a [`transport::Connection`] and a [`session::Session`].
///
/// It keeps every outgoing message until the peer acks it, so the message can be
/// retransmitted when a receive times out. It acks incoming messages (except
/// standalone acks) and drops duplicate incoming messages. Optionally, it filters
/// delivered messages by exchange id.
pub struct RetrContext<'a> {
    /// Message counters of incoming messages already seen; used to drop duplicates.
    // NOTE(review): never pruned, so it grows for the lifetime of the context.
    received: HashSet<u32>,
    /// Encoded outgoing messages keyed by message counter, kept until acked.
    sent: HashMap<u32, Vec<u8>>,
    /// Exchange ids whose messages `get_next_message` returns; empty means "all".
    subscribed_exchanges: HashSet<u16>,
    connection: &'a transport::Connection,
    session: &'a session::Session,
}

impl<'a> RetrContext<'a> {
    /// Creates a context with no pending, received, or subscribed state.
    pub fn new(connection: &'a transport::Connection, session: &'a session::Session) -> Self {
        Self {
            received: HashSet::new(),
            sent: HashMap::new(),
            subscribed_exchanges: HashSet::new(),
            connection,
            session,
        }
    }

    /// Records an already-encoded outgoing message so it can be retransmitted until acked.
    ///
    /// # Errors
    /// Returns an error (instead of panicking) if the message header of `d` can't be decoded.
    fn send_internal(&mut self, d: &[u8]) -> Result<()> {
        let h = messages::MessageHeader::decode(d)
            .map_err(|e| anyhow::anyhow!("can't decode outgoing message header: {:?}", e))?;
        log::trace!("send msg counter:{}", h.0.message_counter);
        self.sent.insert(h.0.message_counter, d.to_owned());
        Ok(())
    }

    /// Forgets the pending outgoing message with counter `c`; the peer has acked it.
    fn received_ack(&mut self, c: u32) {
        log::trace!("received ack counter:{}", c);
        self.sent.remove(&c);
    }

    /// Records incoming message counter `c`.
    ///
    /// Returns `true` the first time `c` is seen and `false` for a duplicate.
    fn received(&mut self, c: u32) -> bool {
        self.received.insert(c)
    }

    /// Retransmits every outgoing message that has not been acked yet.
    ///
    /// All pending messages are resent, not just one. Otherwise, which message
    /// gets retried would depend on `HashMap` iteration order.
    async fn resend_pending(&self) -> Result<()> {
        for (cnt, msg) in &self.sent {
            log::trace!("retransmit counter = {}", cnt);
            self.connection.send(msg).await?;
        }
        Ok(())
    }

    /// Encodes and sends an ack for incoming message `counter` on exchange `exchange_id`.
    async fn send_ack(&self, exchange_id: u16, counter: u32) -> Result<()> {
        let ack = messages::ack(exchange_id, i64::from(counter))?;
        let out = self.session.encode_message(&ack)?;
        self.connection.send(&out).await?;
        log::trace!(
            "sending ack for exchange:{} counter:{}",
            exchange_id,
            counter
        );
        Ok(())
    }

    /// Restricts delivered messages to subscribed exchanges.
    ///
    /// Once at least one exchange is subscribed, `get_next_message` only returns
    /// messages belonging to a subscribed exchange.
    pub fn subscribe_exchange(&mut self, e: u16) {
        self.subscribed_exchanges.insert(e);
    }

    /// Waits for the next new, deliverable message and handles acks/retransmits on the way.
    ///
    /// - Each receive waits up to [`RECEIVE_TIMEOUT`]. When a receive fails, all
    ///   unacked outgoing messages are retransmitted.
    /// - Messages the session can't decode are logged and skipped.
    /// - Duplicates are re-acked (unless they are standalone acks) and dropped.
    /// - Standalone acks are consumed without being acked.
    /// - Every other new message is acked. It is returned unless exchange
    ///   subscriptions exist and it belongs to an unsubscribed exchange.
    ///
    /// # Errors
    /// - Fails once [`MAX_RETRANSMIT_TIME`] has elapsed since the call started.
    /// - Fails when a session-decoded message can't be parsed.
    /// - Fails when sending an ack or a retransmission fails.
    pub async fn get_next_message(&mut self) -> Result<messages::Message> {
        let start_time = tokio::time::Instant::now();
        loop {
            if start_time.elapsed() > MAX_RETRANSMIT_TIME {
                anyhow::bail!("retransmit timeout exceeded");
            }

            let raw = match self.connection.receive(RECEIVE_TIMEOUT).await {
                Ok(v) => v,
                Err(_) => {
                    // NOTE(review): every receive error is treated like a timeout
                    // (presumably the common case); retransmit whatever is unacked.
                    self.resend_pending().await?;
                    continue;
                }
            };

            let plain = match self.session.decode_message(&raw) {
                Ok(p) => p,
                Err(e) => {
                    log::debug!("can't decode incoming message {:?}", e);
                    continue;
                }
            };

            let decoded = messages::Message::decode(&plain)?;
            log::trace!("received message {:?}", decoded);

            let exchange_id = decoded.protocol_header.exchange_id;
            let counter = decoded.message_header.message_counter;

            // NOTE(review): applied to every message, even ones that may not carry an
            // ack; confirm `ack_counter` can't collide with a pending counter when unset.
            self.received_ack(decoded.protocol_header.ack_counter);

            let is_standalone_ack = decoded.protocol_header.protocol_id
                == messages::ProtocolMessageHeader::PROTOCOL_ID_SECURE_CHANNEL
                && decoded.protocol_header.opcode == messages::ProtocolMessageHeader::OPCODE_ACK;

            if !self.received(counter) {
                // Duplicate: our earlier ack was probably lost, so ack again.
                // Standalone acks are never acked, consistent with the path below.
                if !is_standalone_ack {
                    self.send_ack(exchange_id, counter).await?;
                }
                log::trace!(
                    "dropping duplicit message exchange:{} counter:{}",
                    exchange_id,
                    counter
                );
                continue;
            }

            if is_standalone_ack {
                log::trace!(
                    "standalone ack exchange:{} ack_counter:{}",
                    exchange_id,
                    decoded.protocol_header.ack_counter
                );
                continue;
            }

            self.send_ack(exchange_id, counter).await?;

            if !self.subscribed_exchanges.is_empty()
                && !self.subscribed_exchanges.contains(&exchange_id)
            {
                continue;
            }
            return Ok(decoded);
        }
    }

    /// Encodes `data` through the session, records it for retransmission, and sends it.
    ///
    /// # Errors
    /// Fails if session encoding, outgoing-header decoding, or the send fails.
    pub async fn send(&mut self, data: &[u8]) -> Result<()> {
        let out = self.session.encode_message(data)?;
        self.send_internal(&out)?;
        self.connection.send(&out).await?;
        Ok(())
    }
}