use ferriskey::{Cmd, ConnectionAddr, Value};
use futures_intrusive::sync::ManualResetEvent;
use std::collections::HashMap;
use std::io;
use std::io::Read;
use std::io::Write;
use std::net::TcpListener;
use std::net::TcpStream as StdTcpStream;
use std::str::from_utf8;
use std::sync::{
Arc,
atomic::{AtomicU16, Ordering},
};
use tokio::sync::mpsc::UnboundedSender;
pub struct MockedRequest {
pub expected_message: String,
pub response: String,
}
pub struct ServerMock {
request_sender: UnboundedSender<MockedRequest>,
address: ConnectionAddr,
received_commands: Arc<AtomicU16>,
runtime: Option<tokio::runtime::Runtime>, closing_signal: Arc<ManualResetEvent>,
closing_completed_signal: Arc<ManualResetEvent>,
}
fn read_from_socket(
buffer: &mut [u8],
socket: &mut StdTcpStream,
closing_signal: &Arc<ManualResetEvent>,
) -> Option<usize> {
while !closing_signal.is_set() {
let read_res = socket.read(buffer); match read_res {
Ok(0) => {
return None;
}
Ok(size) => {
return Some(size);
}
Err(ref e)
if e.kind() == io::ErrorKind::WouldBlock
|| e.kind() == io::ErrorKind::Interrupted =>
{
continue;
}
Err(_) => {
return None;
}
}
}
None
}
fn log_resp_message(msg: &str) {
tracing::info!(
"Test - {:?} {}",
std::thread::current().id(),
msg.replace('\r', "\\r").replace('\n', "\\n")
);
}
fn receive_and_respond_to_next_message(
receiver: &mut tokio::sync::mpsc::UnboundedReceiver<MockedRequest>,
socket: &mut StdTcpStream,
received_commands: &Arc<AtomicU16>,
constant_responses: &HashMap<String, Value>,
closing_signal: &Arc<ManualResetEvent>,
) -> bool {
let mut buffer = vec![0; 1024];
let size = match read_from_socket(&mut buffer, socket, closing_signal) {
Some(size) => size,
None => {
return false;
}
};
let message = from_utf8(&buffer[..size]).unwrap().to_string();
log_resp_message(&message);
let setinfo_count = message.matches("SETINFO").count();
if setinfo_count > 0 {
let mut buffer = Vec::new();
for _ in 0..setinfo_count {
super::encode_value(&Value::Okay, &mut buffer).unwrap();
}
socket.write_all(&buffer).unwrap();
return true;
}
if message.contains("HELLO") {
let mut buffer = Vec::new();
let response = Value::Map(vec![
(Value::BulkString(b"proto".to_vec().into()), Value::Int(3)),
(
Value::BulkString(b"role".to_vec().into()),
Value::BulkString(b"master".to_vec().into()),
),
]);
super::encode_value(&response, &mut buffer).unwrap();
socket.write_all(&buffer).unwrap();
return true;
}
if let Some(response) = constant_responses.get(&message) {
let mut buffer = Vec::new();
super::encode_value(response, &mut buffer).unwrap();
socket.write_all(&buffer).unwrap();
return true;
}
let Ok(request) = receiver.try_recv() else {
panic!("Received unexpected message: {message}");
};
received_commands.fetch_add(1, Ordering::AcqRel);
assert_eq!(message, request.expected_message);
socket.write_all(request.response.as_bytes()).unwrap();
true
}
pub trait Mock {
fn get_addresses(&self) -> Vec<ConnectionAddr>;
fn add_response(&self, request: &Cmd, response: String);
fn get_number_of_received_commands(&self) -> u16;
}
impl ServerMock {
pub fn new(constant_responses: HashMap<String, Value>) -> Self {
let listener = super::get_listener_on_available_port();
Self::new_with_listener(constant_responses, listener)
}
pub fn new_with_listener(
constant_responses: HashMap<String, Value>,
listener: TcpListener,
) -> Self {
let (request_sender, mut receiver) = tokio::sync::mpsc::unbounded_channel();
let received_commands = Arc::new(AtomicU16::new(0));
let received_commands_clone = received_commands.clone();
let address = ConnectionAddr::Tcp(
"localhost".to_string(),
listener.local_addr().unwrap().port(),
);
let closing_signal = Arc::new(ManualResetEvent::new(false));
let closing_signal_clone = closing_signal.clone();
let closing_completed_signal = Arc::new(ManualResetEvent::new(false));
let closing_completed_signal_clone = closing_completed_signal.clone();
let address_clone = address.clone();
std::thread::spawn(move || {
tracing::info!("Test - ServerMock started on: {address_clone}");
let mut socket: StdTcpStream = listener.accept().unwrap().0;
let _ = socket.set_read_timeout(Some(std::time::Duration::from_millis(10)));
while receive_and_respond_to_next_message(
&mut receiver,
&mut socket,
&received_commands_clone,
&constant_responses,
&closing_signal_clone,
) {}
let _ = socket.shutdown(std::net::Shutdown::Both);
closing_completed_signal_clone.set();
tracing::info!(
"Test - {:?} ServerMock exited", std::thread::current().id()
);
});
Self {
request_sender,
address,
received_commands,
runtime: None,
closing_signal,
closing_completed_signal,
}
}
pub async fn close(self) {
self.closing_signal.set();
self.closing_completed_signal.wait().await;
}
}
impl Mock for ServerMock {
fn get_addresses(&self) -> Vec<ConnectionAddr> {
vec![self.address.clone()]
}
fn add_response(&self, request: &Cmd, response: String) {
let expected_message = String::from_utf8(request.get_packed_command().to_vec()).unwrap();
let _ = self.request_sender.send(MockedRequest {
expected_message,
response,
});
}
fn get_number_of_received_commands(&self) -> u16 {
self.received_commands.load(Ordering::Acquire)
}
}
impl Drop for ServerMock {
fn drop(&mut self) {
self.closing_signal.set();
}
}