1use crate::command::Command;
2use crate::connection::Connection;
3use bytes::Bytes;
4use crossbeam_utils::CachePadded;
5use std::collections::HashMap;
6use std::hash::{Hash, Hasher};
7use std::net::SocketAddr;
8use std::str;
9use std::sync::{Arc, Mutex};
10use tokio::net::{TcpListener, TcpStream};
11
12type ShardedDb = Arc<Vec<CachePadded<Mutex<HashMap<String, Bytes>>>>>;
13
14pub struct Server {
15 addr: SocketAddr,
16 db: ShardedDb,
17}
18
19impl Server {
20 pub fn new(addr: SocketAddr, num_shards: usize) -> Self {
21 let db = new_sharded_db(num_shards);
22 Server { addr, db }
23 }
24
25 pub async fn run(&self) -> Result<(), Box<dyn std::error::Error>> {
26 let listener = TcpListener::bind(self.addr).await?;
27
28 loop {
29 let (socekt, _) = listener.accept().await?;
30 let db = self.db.clone();
31 tokio::spawn(async move {
32 process(socekt, db).await;
33 });
34 }
35 }
36}
37
38fn new_sharded_db(num_shards: usize) -> ShardedDb {
39 let mut shards = Vec::with_capacity(num_shards);
40 for _ in 0..num_shards {
41 shards.push(CachePadded::new(Mutex::new(HashMap::new())));
42 }
43 Arc::new(shards)
44}
45
/// Hashes `key` with the standard library's `DefaultHasher` and returns the digest as a
/// `usize`.
///
/// The hasher is created with `DefaultHasher::new()` on every call, so equal keys always
/// produce equal results within one build of the program.
fn hash_key(key: &str) -> usize {
    use std::collections::hash_map::DefaultHasher;

    let mut state = DefaultHasher::new();
    key.hash(&mut state);
    let digest: u64 = state.finish();
    // On targets where `usize` is narrower than 64 bits this keeps only the low bits.
    digest as usize
}
51
52async fn process(socket: TcpStream, db: ShardedDb) {
53 let mut connection = match Connection::new(socket).await {
54 Ok(connection) => connection,
55 Err(_e) => {
56 return;
57 }
58 };
59
60 loop {
61 let buff = match connection.read_stream().await {
62 Ok(buff) => buff,
63 Err(_e) => {
64 return;
65 }
66 };
67
68 let response: Bytes = match Command::from_bytes(&buff) {
69 Command::Get(cmd) => {
70 if cmd.is_valid() {
71 let idx = hash_key(cmd.key()) % db.len();
72 let db = db[idx].lock().unwrap();
73
74 if let Some(value) = db.get(cmd.key()) {
75 let value_string = std::str::from_utf8(&value).unwrap();
76 Bytes::from(format!("{{\"{}\":\"{}\"}}", cmd.key(), value_string))
77 } else {
78 Bytes::copy_from_slice(b"{}")
79 }
80 } else {
81 Bytes::copy_from_slice(b"{\"GET\": \"Invalid \"}")
82 }
83 }
84 Command::Set(cmd) => {
85 if cmd.is_valid() {
86 let idx: usize = hash_key(cmd.key()) % db.len();
87 let mut db = db[idx].lock().unwrap();
88
89 db.insert(
90 cmd.key().to_string(),
91 Bytes::copy_from_slice(cmd.val().as_bytes()),
92 );
93
94 Bytes::copy_from_slice(b"{\"SET\": \"OK\"}")
95 } else {
96 Bytes::copy_from_slice(b"{\"SET\": \"Invalid \"}")
97 }
98 }
99 Command::MultipleSet(cmd) => {
100 if cmd.is_valid() {
101 for (key, val) in cmd.kv().iter() {
102 let idx: usize = hash_key(key) % db.len();
103 let mut db = db[idx].lock().unwrap();
104
105 db.insert(
106 key.to_string(),
107 Bytes::copy_from_slice(val.as_str().unwrap().to_string().as_bytes()),
108 );
109 }
110 Bytes::copy_from_slice(b"{\"SET\": \"OK\"}")
111 } else {
112 Bytes::copy_from_slice(b"{\"SET\": \"Invalid \"}")
113 }
114 }
115 Command::Invalid => Bytes::copy_from_slice(b"{}"),
116 };
117
118 if let Err(e) = connection.write_stream(&response).await {
119 println!("write stream error: {e}");
120 return;
121 }
122 }
123}