1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
//
// (c) 2019 Alexander Becker
// Released under the MIT license.
//

#[macro_use]
extern crate log;

pub mod args;
pub mod logo;
pub mod logger;
mod threadp;
mod error;

use std::io::prelude::*;
use std::net::{TcpListener, TcpStream};
use std::{process, result, str, io};
use std::sync::Arc;
use chashmap::{CHashMap};

type Result<T> = result::Result<T, Box<std::error::Error>>;
type Response = Result<Option<String>>;
type Command = Box<Fn(Arc<CHashMap<String, String>>) -> Response>;

const SEP: char = '\u{1f}';

fn parse_command(string: String) -> Result<Command> {
    let split: Vec<String> = string.split(SEP).map(|s| s.to_string()).collect();

    match split[0].as_ref() {

        "TEST" => {
            Ok(Box::new(|_| Ok(None)))
        },

        // Locates the given key inside the database and returns an Ok with the
        // corresponding value if existing or an None if not.
        "GET" => {
            if split.len() != 2 {
                Err(Box::new(error::ParseError))
            } else {
                Ok(Box::new(move |map| {
                    if let Some(rg) = map.get(&split[1]) {
                        Ok(Some(rg.to_string()))
                    } else {
                        Ok(None)
                    }
                }))
            }
        },

        // Inserts a specified value at a specified key. Return the old value if existing.
        "INSERT" => {
            if split.len() != 3 {
                Err(Box::new(error::ParseError))
            } else {
                Ok(Box::new(move |map| {
                    if let Some(old) = map.insert(split[1].clone(), split[2].clone()) {
                        Ok(Some(old))
                    } else {
                        Ok(None)
                    }
                }))
            }
        },

        // Removes the value corresponding to a key. Returns Err if key is not found.
        "REMOVE" => {
            if split.len() != 2 {
                Err(Box::new(error::ParseError))
            } else {
                Ok(Box::new(move |map| {
                    if let Some(old) = map.remove(&split[1]) {
                        Ok(Some(old))
                    } else {
                        Err(Box::new(error::StorageError(format!("Key not found: {}", split[1]))))
                    }
                }))
            }
        },

        // Removes all entries from the database.
        "CLEAR" => {
            if split.len() != 1 {
                Err(Box::new(error::ParseError))
            } else {
                Ok(Box::new(move |map| {
                    map.clear();
                    Ok(None)
                }))
            }
        },

        _ => Err(Box::new(error::ParseError))
    }
}

fn serialize(response: Response) -> String {
    match response {
        Ok(message) => {
            let mut string = "OK".to_string();
            if let Some(v) = message {
                string.push(SEP);
                string.push_str(&v);
            }
            string
        },

        Err(e) => {
            let mut string = "ERR".to_string();
            string.push(SEP);
            string.push_str(&format!("{}", e));
            string
        }
    }
}

/// Reads one request from `stream`, parses it and executes it against `map`.
///
/// The request is received with a single `read` call into a fixed-size buffer,
/// decoded as UTF-8, stripped of trailing NUL characters, logged at debug level
/// and handed to `parse_command`.
///
/// # Errors
///
/// Returns an error if reading from the socket fails, the received bytes are
/// not valid UTF-8, the command cannot be parsed, or the command itself fails.
fn handle_request(stream: &mut TcpStream, map: Arc<CHashMap<String, String>>) -> Response {
    // Largest request accepted in one read; anything beyond it is not consumed.
    const MAX_REQUEST_LEN: usize = 524_288;

    // Heap-allocated: a 512 KiB array is a sizeable chunk of a worker thread's stack.
    let mut buffer = vec![0u8; MAX_REQUEST_LEN];

    // NOTE(review): a single `read` may return only a prefix of the request if it
    // arrives in several TCP segments; the protocol has no length framing, so this
    // presumably relies on requests being small — confirm.
    let received = stream.read(&mut buffer)?;

    // Decode only the bytes actually received rather than the whole zero-padded
    // buffer. Trailing NULs are still trimmed so clients that NUL-terminate their
    // messages keep working as before.
    let string = str::from_utf8(&buffer[..received])?
        .trim_end_matches('\0')
        .to_string();

    debug!("{}", string);

    let command: Command = parse_command(string)?;
    command(map)
}

/// Serializes `response` and sends it back to the client over `stream`.
///
/// # Errors
///
/// Returns any I/O error raised while writing to or flushing the stream.
fn write_response(stream: &mut TcpStream, response: Response) -> Result<()> {
    // `write` may accept only part of the buffer and report a short count;
    // `write_all` keeps writing until the whole serialized response is sent.
    stream.write_all(serialize(response).as_bytes())?;
    stream.flush()?;
    Ok(())
}

/// Runs a new instance of yocto.
///
/// Binds a TCP listener to `config.iface`, exiting the process with status 1
/// if binding fails. It then submits every accepted connection to a thread pool
/// of `config.threads` workers. Each submitted job handles one request, logs a
/// failed response, and writes the response back to the client.
///
/// # Arguments
///
/// * `config` - A config struct specifying the run parameters. When
///   `config.exit_after` is `Some(n)`, the accept loop stops after `n` incoming
///   connection attempts (successful or not); otherwise it runs indefinitely.
///
pub fn run(config: args::Config) {
    let listener = TcpListener::bind(&config.iface).unwrap_or_else(|e| {
        error!("Failed to bind to {}: {}", config.iface, e);
        process::exit(1)
    });
    info!("Successfully bound to {}", config.iface);

    let store: Arc<CHashMap<String, String>> = Arc::new(CHashMap::new());

    let pool = threadp::ThreadPool::new(config.threads);
    info!("Initialized thread pool with {} worker threads", config.threads);
    info!("Listening.");

    // Boxed so both the bounded and the unbounded accept loops share one type.
    let connections: Box<dyn Iterator<Item = io::Result<TcpStream>>> = match config.exit_after {
        Some(limit) => Box::new(listener.incoming().take(limit)),
        None => Box::new(listener.incoming()),
    };

    for incoming in connections {
        let mut stream = match incoming {
            Ok(stream) => stream,
            Err(e) => {
                error!("Unable to accept connection: {}", e);
                continue;
            }
        };

        let store = Arc::clone(&store);
        pool.assign(move || {
            let response = handle_request(&mut stream, store);
            // Log failures server-side too; they are still sent to the client.
            if let Err(e) = &response {
                error!("{}", e);
            }
            if let Err(e) = write_response(&mut stream, response) {
                error!("{}", e);
            }
        });
    }
    // NOTE(review): whether pending jobs finish before returning depends on how
    // `threadp::ThreadPool` behaves when dropped here — confirm.
}