extern crate rusqlite;
extern crate aes_gcm;
extern crate sha3;
extern crate rand;
use rusqlite::{Connection, params};
use std::time::{SystemTime, UNIX_EPOCH};
use aes_gcm::{
aead::{AeadCore, AeadInPlace, KeyInit},
Aes256Gcm, Nonce, Key };
use aes_gcm::aead::Aead;
use rand::Rng;
use sha3::{Digest, Sha3_256};
/// Number of random bytes `start_session` draws (via `rand_bytes`) when seeding a new
/// session key. The seed is hashed down to a 32-byte SHA3-256 value afterwards, so this is
/// the seed length, not the length of the resulting key.
const KEY_SIZE: usize = 0x100;
/// Current wall-clock time as whole milliseconds since the Unix epoch.
///
/// # Panics
/// Panics with "Time went backwards" if the system clock reports a time before 1970-01-01.
fn now_milli() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_millis() as u64)
        .expect("Time went backwards")
}
/// Returns a vector of `size` bytes drawn one at a time from the thread-local RNG.
fn rand_bytes(size: usize) -> Vec<u8> {
    let mut generator = rand::thread_rng();
    let mut bytes = Vec::with_capacity(size);
    for _ in 0..size {
        bytes.push(generator.gen::<u8>());
    }
    bytes
}
/// Returns the SHA3-256 digest of `plain_text` as an owned 32-byte vector.
fn hash(plain_text: &[u8]) -> Vec<u8> {
    Sha3_256::digest(plain_text).to_vec()
}
fn decrypt(cipher_text: &[u8], key: &[u8]) -> Result<Vec<u8>, String> {
let aes_key = Key::<Aes256Gcm>::from_slice(key);
let cipher = Aes256Gcm::new(aes_key);
let plain_text = cipher.decrypt(Nonce::from_slice(&[0u8; 12]), cipher_text).unwrap();
return Ok(plain_text)
}
fn encrypt(plain_text: &[u8], key: &[u8]) -> Result<Vec<u8>, String> {
let aes_key = Key::<Aes256Gcm>::from_slice(key);
let cipher = Aes256Gcm::new(aes_key);
let cipher_text = cipher.encrypt(Nonce::from_slice(&[0u8; 12]), plain_text).unwrap();
return Ok(cipher_text)
}
/// Passcode-gated session store persisted through a SQLite [`Connection`].
///
/// Its `passcodes` and `sessions` tables are created by `SecureLayer::new`.
pub struct SecureLayer {
    conn: Connection,
}
impl SecureLayer {
pub fn new(conn: Connection) -> SecureLayer {
conn.execute(r#"
CREATE TABLE IF NOT EXISTS passcodes (
id INTEGER PRIMARY KEY,
hash BLOB
)
"#, ()).unwrap();
conn.execute(r#"
CREATE TABLE IF NOT EXISTS sessions (
id INTEGER PRIMARY KEY,
hash BLOB,
passcode,
modified INTEGER,
FOREIGN KEY(passcode) REFERENCES passcodes(id)
)
"#, ()).unwrap();
return SecureLayer {
conn
}
}
pub fn register_passcode(&self, passcode: &str) -> Result<(), String> {
let passcode_hash = hash(passcode.as_bytes());
self.conn.execute(r#"
INSERT INTO passcodes (hash)
VALUES (?1)
"#, params![passcode_hash]).unwrap();
return Ok(())
}
pub fn delete_passcode(&self, passcode: &str) -> Result<(), String> {
let passcode_hash = hash(passcode.as_bytes());
self.conn.execute(r#"
DELETE FROM passcodes
WHERE hash=?1
"#, params![passcode_hash]).unwrap();
return Ok(())
}
pub fn end_session(&self, session_id: u64) -> Result<(), String> {
self.conn.execute(r#"
DELETE FROM sessions
WHERE id=?1
"#, params![session_id]).unwrap();
return Ok(())
}
pub fn start_session(&self, passcode_hash: &[u8]) -> Result<(u64, Vec<u8>), String> {
let passcode_id: u64 = match self.conn.query_row(
"SELECT * FROM passcodes WHERE hash=?",
params![passcode_hash],
|row| row.get(0),
) {
Ok(id) => id,
Err(e) => return Err(format!("sql exception: {:?}", e))
};
let mut init_key_source = rand_bytes(KEY_SIZE);
init_key_source.iter_mut()
.zip(passcode_hash.iter())
.for_each(|(x1, x2)| *x1 ^= *x2);
let init_key_hash = hash(init_key_source.as_slice());
self.conn.execute(r#"
INSERT INTO sessions (hash, passcode, modified)
VALUES (?1, ?2, ?3)
"#, params![init_key_hash.as_slice(), passcode_id, now_milli()]).unwrap();
let session_id = self.conn.last_insert_rowid() as u64;
Ok((session_id, init_key_hash))
}
pub fn authenticate_request(&self, session_id: u64, cipher_text: &[u8]) -> Result<Vec<u8>, String> {
let mut session_key: Vec<u8> = match self.conn.query_row(
"SELECT * FROM sessions WHERE id=?",
params![session_id],
|row| row.get(1),
) {
Ok(hash) => hash,
Err(e) => return Err(format!("sql exception: {:?}", e))
};
let plain_text = decrypt(cipher_text, session_key.as_slice()).unwrap();
session_key.extend(plain_text.clone());
let new_hash = hash(session_key.as_slice());
self.conn.execute(r#"
UPDATE sessions
SET hash=?1, modified=?2
WHERE id=?3
"#, params![new_hash, now_milli(), session_id]).unwrap();
Ok(plain_text)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    const PASSCODE: &str = "abc123";

    /// Advances a client-side key exactly as `authenticate_request` does on the server:
    /// `hash(key || plain_text)`.
    fn ratchet(key: &[u8], plain_text: &[u8]) -> Vec<u8> {
        let mut material = key.to_vec();
        material.extend_from_slice(plain_text);
        hash(&material)
    }

    /// Round-trips several messages through one session, checking each plaintext arrives
    /// intact, then checks that a request after `end_session` is rejected.
    #[test]
    fn test_client_server_communications() {
        let conn = Connection::open_in_memory().unwrap();
        let sl = SecureLayer::new(conn);
        sl.register_passcode(PASSCODE).unwrap();
        let passcode_hash = hash(PASSCODE.as_bytes());
        let (sid, mut key) = sl.start_session(passcode_hash.as_slice()).unwrap();

        let messages: [&[u8]; 3] = [b"hello world", b"hai", b"it works!"];
        for plain_text in messages {
            let payload = encrypt(plain_text, &key).unwrap();
            let received = sl.authenticate_request(sid, &payload).unwrap();
            assert_eq!(received, plain_text);
            key = ratchet(&key, plain_text);
        }

        sl.end_session(sid).unwrap();
        let plain_text4 = "This should'nt work".as_bytes();
        let payload4 = encrypt(plain_text4, &key).unwrap();
        let req4 = sl.authenticate_request(sid, &payload4);
        assert!(req4.is_err());
    }
}