use crate::network::client::CommandErrors;
use crate::network::future::Identity;
use crate::network::protocol::Protocol;
use crate::network::response::{MemoryParameters, ResponseBuffer};
use alloc::vec;
use alloc::vec::Vec;
use bytes::BytesMut;
use core::cell::RefCell;
use core::fmt::{Debug, Formatter};
use core::ops::{Deref, DerefMut};
use embedded_nal::TcpClientStack;
use redis_protocol::error::RedisProtocolErrorKind::BufferTooSmall;
/// Connection state shared by the client and its futures: the TCP stack and socket, plus the
/// buffer that collects responses.
///
/// Every mutable piece of state sits behind a `RefCell`, so all operations can work through a
/// shared `&self` reference.
pub(crate) struct Network<'a, N, P>
where
    N: TcpClientStack,
    P: Protocol,
{
    /// Protocol implementation used to encode outgoing frames.
    protocol: P,
    /// Network stack that performs the actual send/receive calls.
    stack: RefCell<&'a mut N>,
    /// TCP socket handed to the stack on every send/receive call.
    socket: RefCell<&'a mut N::TcpSocket>,
    /// Collects received bytes and hands out decoded frames by command index.
    buffer: RefCell<ResponseBuffer<P>>,
    /// Generation counter. Identities carrying a different series are no longer valid.
    current_series: RefCell<usize>,
    /// Index that will be assigned to the next command sent within the current series.
    next_index: RefCell<usize>,
    /// When set, the next send first drains the socket and clears the response buffer.
    clear_buffer: RefCell<bool>,
    /// Identities registered via `drop_future` whose responses still have to be discarded.
    dropped_futures: RefCell<Vec<Identity>>,
}
impl<'a, N: TcpClientStack, P: Protocol> Network<'a, N, P> {
pub(crate) fn new(
stack: RefCell<&'a mut N>,
socket: RefCell<&'a mut N::TcpSocket>,
protocol: P,
memory: MemoryParameters,
) -> Self {
Network {
protocol: protocol.clone(),
stack,
socket,
buffer: RefCell::new(ResponseBuffer::new(protocol, memory)),
current_series: RefCell::new(0),
next_index: RefCell::new(0),
clear_buffer: RefCell::new(false),
dropped_futures: RefCell::new(vec![]),
}
}
pub(crate) fn receive_chunk(&self) -> nb::Result<(), N::Error> {
let mut local_buffer: [u8; 32] = [0; 32];
let mut stack = self.stack.borrow_mut();
let mut socket = self.socket.borrow_mut();
match stack.receive(socket.deref_mut(), &mut local_buffer) {
Ok(byte_count) => {
self.buffer.borrow_mut().append(&local_buffer[0..byte_count]);
Ok(())
}
Err(error) => nb::Result::Err(error),
}
}
pub(crate) fn is_buffer_full(&self) -> bool {
self.buffer.borrow().is_full()
}
pub(crate) fn send(&self, frame: P::FrameType) -> Result<Identity, CommandErrors> {
if *self.clear_buffer.borrow().deref() {
self.clear_socket();
*self.clear_buffer.borrow_mut() = false;
}
self.handle_dropped_futures();
self.send_frame(frame)?;
let identity = Identity {
series: *self.current_series.borrow(),
index: *self.next_index.borrow(),
};
*self.next_index.borrow_mut() += 1;
Ok(identity)
}
pub(crate) fn send_frame(&self, frame: P::FrameType) -> Result<(), CommandErrors> {
let mut buffer = BytesMut::new();
while let Err(error) = self.protocol.encode_bytes(&mut buffer, &frame) {
if let BufferTooSmall(size) = error.kind() {
buffer.resize(buffer.len() + *size, 0x0);
} else {
return Err(CommandErrors::EncodingCommandFailed);
}
}
let mut stack = self.stack.borrow_mut();
let mut socket = self.socket.borrow_mut();
if stack.send(socket.deref_mut(), buffer.as_ref()).is_err() {
return Err(CommandErrors::TcpError);
};
Ok(())
}
pub(crate) fn is_complete(&self, id: &Identity) -> Result<bool, CommandErrors> {
if self.current_series.borrow().deref() != &id.series {
return Err(CommandErrors::InvalidFuture);
}
if self.buffer.borrow().is_complete(id.index) {
return Ok(true);
}
if self.buffer.borrow().is_faulty() {
self.invalidate_futures();
return Err(CommandErrors::ProtocolViolation);
}
Ok(false)
}
pub(crate) fn take_frame(&self, id: &Identity) -> Option<P::FrameType> {
if self.current_series.borrow().deref() != &id.series {
return None;
}
self.buffer.borrow_mut().take_frame(id.index)
}
pub(crate) fn take_next_frame(&self) -> Option<P::FrameType> {
self.buffer.borrow_mut().take_next_frame()
}
pub(crate) fn invalidate_futures(&self) {
*self.current_series.borrow_mut() += 1;
*self.next_index.borrow_mut() = 0;
*self.clear_buffer.borrow_mut() = true;
}
pub(crate) fn drop_future(&self, id: Identity) {
self.dropped_futures.borrow_mut().push(id);
}
pub fn handle_dropped_futures(&self) {
if self.dropped_futures.borrow().is_empty() {
return;
}
self.receive_all();
let mut buffer = self.buffer.borrow_mut();
self.dropped_futures.borrow_mut().retain(|id| {
if &id.series != self.current_series.borrow().deref() {
return false;
}
if buffer.is_complete(id.index) {
buffer.take_frame(id.index);
return false;
}
true
})
}
pub fn remaining_dropped_futures(&self) -> bool {
!self.dropped_futures.borrow().is_empty()
}
pub fn receive_all(&self) {
let mut result = Ok(());
while result.is_ok() {
result = self.receive_chunk();
}
}
fn clear_socket(&self) {
let mut stack = self.stack.borrow_mut();
let mut socket = self.socket.borrow_mut();
loop {
let mut local_buffer: [u8; 32] = [0; 32];
match stack.receive(socket.deref_mut(), &mut local_buffer) {
Ok(_) => {}
Err(_) => {
break;
}
}
}
self.buffer.borrow_mut().clear();
}
pub fn get_protocol(&self) -> P {
self.protocol.clone()
}
#[cfg(test)]
pub fn get_dropped_future_count(&self) -> usize {
self.dropped_futures.borrow().len()
}
#[cfg(test)]
pub fn get_pending_frame_count(&self) -> usize {
self.buffer.borrow().pending_frame_count()
}
}
/// Opaque `Debug` output: only the type name is printed, never the stack, socket or buffers.
impl<N, P> Debug for Network<'_, N, P>
where
    N: TcpClientStack,
    P: Protocol,
{
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        // Same output as an empty `debug_struct("Network").finish()`.
        f.write_str("Network")
    }
}