// rudis_http/server.rs

1use crate::command::Command;
2use crate::connection::Connection;
3use bytes::Bytes;
4use crossbeam_utils::CachePadded;
5use std::collections::HashMap;
6use std::hash::{Hash, Hasher};
7use std::net::SocketAddr;
8use std::str;
9use std::sync::{Arc, Mutex};
10use tokio::net::{TcpListener, TcpStream};
11
12type ShardedDb = Arc<Vec<CachePadded<Mutex<HashMap<String, Bytes>>>>>;
13
14pub struct Server {
15    addr: SocketAddr,
16    db: ShardedDb,
17}
18
19impl Server {
20    pub fn new(addr: SocketAddr, num_shards: usize) -> Self {
21        let db = new_sharded_db(num_shards);
22        Server { addr, db }
23    }
24
25    pub async fn run(&self) -> Result<(), Box<dyn std::error::Error>> {
26        let listener = TcpListener::bind(self.addr).await?;
27
28        loop {
29            let (socekt, _) = listener.accept().await?;
30            let db = self.db.clone();
31            tokio::spawn(async move {
32                process(socekt, db).await;
33            });
34        }
35    }
36}
37
38fn new_sharded_db(num_shards: usize) -> ShardedDb {
39    let mut shards = Vec::with_capacity(num_shards);
40    for _ in 0..num_shards {
41        shards.push(CachePadded::new(Mutex::new(HashMap::new())));
42    }
43    Arc::new(shards)
44}
45
/// Hashes `key` with the standard library's `DefaultHasher` for shard selection.
///
/// `DefaultHasher::new()` always starts from the same state, so a given key maps to
/// the same value every time within a process. On 32-bit targets the `as usize`
/// cast truncates the 64-bit hash, which is harmless for a modulo-based shard pick.
fn hash_key(key: &str) -> usize {
    let mut state = std::collections::hash_map::DefaultHasher::new();
    Hash::hash(key, &mut state);
    Hasher::finish(&state) as usize
}
51
52async fn process(socket: TcpStream, db: ShardedDb) {
53    let mut connection = match Connection::new(socket).await {
54        Ok(connection) => connection,
55        Err(_e) => {
56            return;
57        }
58    };
59
60    loop {
61        let buff = match connection.read_stream().await {
62            Ok(buff) => buff,
63            Err(_e) => {
64                return;
65            }
66        };
67
68        let response: Bytes = match Command::from_bytes(&buff) {
69            Command::Get(cmd) => {
70                if cmd.is_valid() {
71                    let idx = hash_key(cmd.key()) % db.len();
72                    let db = db[idx].lock().unwrap();
73
74                    if let Some(value) = db.get(cmd.key()) {
75                        let value_string = std::str::from_utf8(&value).unwrap();
76                        Bytes::from(format!("{{\"{}\":\"{}\"}}", cmd.key(), value_string))
77                    } else {
78                        Bytes::copy_from_slice(b"{}")
79                    }
80                } else {
81                    Bytes::copy_from_slice(b"{\"GET\": \"Invalid \"}")
82                }
83            }
84            Command::Set(cmd) => {
85                if cmd.is_valid() {
86                    let idx: usize = hash_key(cmd.key()) % db.len();
87                    let mut db = db[idx].lock().unwrap();
88
89                    db.insert(
90                        cmd.key().to_string(),
91                        Bytes::copy_from_slice(cmd.val().as_bytes()),
92                    );
93
94                    Bytes::copy_from_slice(b"{\"SET\": \"OK\"}")
95                } else {
96                    Bytes::copy_from_slice(b"{\"SET\": \"Invalid \"}")
97                }
98            }
99            Command::MultipleSet(cmd) => {
100                if cmd.is_valid() {
101                    for (key, val) in cmd.kv().iter() {
102                        let idx: usize = hash_key(key) % db.len();
103                        let mut db = db[idx].lock().unwrap();
104
105                        db.insert(
106                            key.to_string(),
107                            Bytes::copy_from_slice(val.as_str().unwrap().to_string().as_bytes()),
108                        );
109                    }
110                    Bytes::copy_from_slice(b"{\"SET\": \"OK\"}")
111                } else {
112                    Bytes::copy_from_slice(b"{\"SET\": \"Invalid \"}")
113                }
114            }
115            Command::Invalid => Bytes::copy_from_slice(b"{}"),
116        };
117
118        if let Err(e) = connection.write_stream(&response).await {
119            println!("write stream error: {e}");
120            return;
121        }
122    }
123}