use chorba::{decode, encode};
use engine::KVEngine;
use protocol::{
CLEAR, CLEAR_OK, DELETE, DELETE_OK, DeleteRequest, ERROR, GET, GET_OK, GetRequest, GetResponse,
PACKET_BYTE_LIMIT, PACKET_INVALID, PAYLOAD_CHUNK_SIZE, PAYLOAD_FIRST_MAX_VALUE_SIZE, PING,
PONG, SET, SET_OK, SetRequest, generate_packet, parse_start_packet,
};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream,
};
mod engine;
pub mod protocol;
#[tokio::main]
async fn main() {
let engine = KVEngine::new();
let address = "0.0.0.0:13535";
log::debug!("Listening on {}", address);
let listener = tokio::net::TcpListener::bind(address).await.unwrap();
loop {
if let Ok((tcp_stream, socket_address)) = listener.accept().await {
log::debug!("Accepted connection from {}", socket_address);
let engine = engine.clone();
tokio::spawn(async move {
handle_stream(tcp_stream, engine).await;
});
} else {
log::error!("Failed to accept connection");
}
}
}
/// Per-connection read state tracked by `handle_stream`.
///
/// `NONE` means the next chunk read from the socket starts a new request. The
/// other variants mark a request whose payload spans several reads; the `u32`
/// is the payload length declared in that request's start packet.
// Variant names deliberately mirror the protocol's upper-case command constants.
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub enum StreamStatus {
    /// No request is in progress.
    #[default]
    NONE,
    /// A SET request is waiting for the rest of its payload.
    SET(u32),
    /// A GET request is waiting for the rest of its payload.
    GET(u32),
    /// A DELETE request is waiting for the rest of its payload.
    DELETE(u32),
}
async fn handle_stream(mut tcp_stream: TcpStream, mut engine: KVEngine) {
let mut read_buffer: [u8; PAYLOAD_CHUNK_SIZE as usize] = [0; PAYLOAD_CHUNK_SIZE as usize];
let mut state = StreamStatus::default();
let mut context_buffer: Vec<u8> = Vec::new();
loop {
let stream_result = tcp_stream.read(&mut read_buffer).await;
let size = match stream_result {
Ok(size) => size,
Err(error) => {
log::error!("Failed to read from socket: {}", error);
break;
}
};
let read_buffer = &read_buffer[..size];
match state {
StreamStatus::NONE => {
if size == 0 {
log::debug!("No data received");
return;
}
log::debug!("Received {:?} bytes", read_buffer);
let first_byte = read_buffer[0];
match first_byte {
PING => {
log::debug!("Received PING");
if let Err(error) = tcp_stream.write_all(&[PONG]).await {
log::error!("Failed to send PONG: {}", error);
continue;
}
}
SET => {
log::debug!("Received SET");
let start_packet = parse_start_packet(read_buffer);
log::debug!("start packet: {:?}", start_packet);
match start_packet {
Some(packet) => {
if packet.length > PACKET_BYTE_LIMIT {
log::error!("packet size exceeds limit");
let _ = tcp_stream.write_all(&[PACKET_INVALID]).await;
continue;
}
if packet.length > PAYLOAD_FIRST_MAX_VALUE_SIZE {
context_buffer.extend_from_slice(packet.value);
state = StreamStatus::SET(packet.length);
continue;
}
process_set(&mut tcp_stream, &mut engine, packet.value).await;
}
None => {
let _ = tcp_stream.write_all(&[PACKET_INVALID]).await;
}
}
}
GET => {
log::debug!("Received GET");
let start_packet = parse_start_packet(read_buffer);
match start_packet {
Some(packet) => {
if packet.length > PACKET_BYTE_LIMIT {
log::error!("packet size exceeds limit");
let _ = tcp_stream.write_all(&[PACKET_INVALID]).await;
continue;
}
if packet.length > PAYLOAD_FIRST_MAX_VALUE_SIZE {
context_buffer.extend_from_slice(packet.value);
state = StreamStatus::GET(packet.length);
continue;
}
process_get(&mut tcp_stream, &mut engine, packet.value).await;
}
None => {
let _ = tcp_stream.write_all(&[PACKET_INVALID]).await;
}
}
}
DELETE => {
log::debug!("Received DELETE");
let start_packet = parse_start_packet(read_buffer);
match start_packet {
Some(packet) => {
if packet.length > PACKET_BYTE_LIMIT {
log::error!("packet size exceeds limit");
let _ = tcp_stream.write_all(&[PACKET_INVALID]).await;
continue;
}
if packet.length > PAYLOAD_FIRST_MAX_VALUE_SIZE {
context_buffer.extend_from_slice(packet.value);
state = StreamStatus::DELETE(packet.length);
continue;
}
process_delete(&mut tcp_stream, &mut engine, packet.value).await;
}
None => {
let _ = tcp_stream.write_all(&[PACKET_INVALID]).await;
}
}
}
CLEAR => {
log::debug!("Received CLEAR");
if let Err(error) = engine.clear_all() {
log::error!("Failed to clear all key-value pairs: {}", error);
let _ = tcp_stream.write_all(&[ERROR]).await;
continue;
}
let _ = tcp_stream.write_all(&[CLEAR_OK]).await;
}
_ => {
log::error!("Unknown command: {}", first_byte);
}
}
}
StreamStatus::SET(length) => {
log::debug!("Stream status: SET");
if read_buffer.is_empty() {
log::debug!("No data received");
state = StreamStatus::NONE;
continue;
}
context_buffer.extend_from_slice(read_buffer);
if context_buffer.len() as u32 >= length {
process_set(&mut tcp_stream, &mut engine, &context_buffer).await;
state = StreamStatus::NONE;
} else {
log::debug!("Waiting for more data...");
}
}
StreamStatus::GET(_) => {
log::debug!("Stream status: GET");
if read_buffer.is_empty() {
log::debug!("No data received");
state = StreamStatus::NONE;
continue;
}
context_buffer.extend_from_slice(read_buffer);
if context_buffer.len() as u32 >= PAYLOAD_FIRST_MAX_VALUE_SIZE {
process_get(&mut tcp_stream, &mut engine, &context_buffer).await;
state = StreamStatus::NONE;
} else {
log::debug!("Waiting for more data...");
}
}
StreamStatus::DELETE(_) => {
log::debug!("Stream status: DELETE");
if read_buffer.is_empty() {
log::debug!("No data received");
state = StreamStatus::NONE;
continue;
}
context_buffer.extend_from_slice(read_buffer);
if context_buffer.len() as u32 >= PAYLOAD_FIRST_MAX_VALUE_SIZE {
process_delete(&mut tcp_stream, &mut engine, &context_buffer).await;
state = StreamStatus::NONE;
} else {
log::debug!("Waiting for more data...");
}
}
}
}
}
pub async fn process_set(stream: &mut TcpStream, engine: &mut KVEngine, bytes: &[u8]) {
let decode_result = decode::<SetRequest>(bytes);
let set_request = match decode_result {
Ok(set_request) => set_request,
Err(error) => {
log::error!("Failed to decode SetRequest: {}", error);
let _ = stream.write_all(&[PACKET_INVALID]).await;
return;
}
};
let key = set_request.key;
let value = set_request.value;
if let Err(error) = engine.set_key_value(key, value) {
log::error!("Failed to set key-value pair: {}", error);
let _ = stream.write_all(&[ERROR]).await;
}
let _ = stream.write_all(&[SET_OK]).await;
}
/// Handles a complete GET request payload.
///
/// Decodes `bytes` as a `GetRequest` and replies `PACKET_INVALID` if that fails.
/// Otherwise it looks the key up in `engine`:
/// - on success, it replies with a `GET_OK` packet (built by `generate_packet`)
///   wrapping the encoded `GetResponse`;
/// - if the lookup fails, it replies with a bare `ERROR` byte.
pub async fn process_get(stream: &mut TcpStream, engine: &mut KVEngine, bytes: &[u8]) {
    let request = match decode::<GetRequest>(bytes) {
        Ok(request) => request,
        Err(error) => {
            log::error!("Failed to decode GetRequest: {}", error);
            let _ = stream.write_all(&[PACKET_INVALID]).await;
            return;
        }
    };
    match engine.get_key_value(&request.key) {
        Err(error) => {
            log::error!("Failed to get key-value pair: {}", error);
            let _ = stream.write_all(&[ERROR]).await;
        }
        Ok(value) => {
            let encoded = encode(&GetResponse { value });
            let packet = generate_packet(GET_OK, &encoded);
            let _ = stream.write_all(&packet).await;
        }
    }
}
pub async fn process_delete(stream: &mut TcpStream, engine: &mut KVEngine, bytes: &[u8]) {
let decode_result = decode::<DeleteRequest>(bytes);
let get_request = match decode_result {
Ok(get_request) => get_request,
Err(error) => {
log::error!("Failed to decode GetRequest: {}", error);
let _ = stream.write_all(&[PACKET_INVALID]).await;
return;
}
};
let key = get_request.key;
if let Err(error) = engine.delete_key_value(&key) {
log::error!("Failed to delete key-value pair: {}", error);
let _ = stream.write_all(&[ERROR]).await;
}
let _ = stream.write_all(&[DELETE_OK]).await;
}