secure_layer 0.1.0

A mechanism to secure server client communication
Documentation
/**
 * Secure Layer
 * Author: Amit Hendin
 * Created: 22/1/2024
 * Description: This library is a simple mechanism a web server developer can employ to secure communications between client and server.
 * The main idea is a "rolling shared key": a key that constantly changes in an unpredictable way, yet is securely shared between server and client.
 * To achieve this, the client and the server first share a key once; in code this key is called the "init_key_hash", and it is generated by the function start_session. After that,
 * the client encrypts its message with this key, sends the cipher text to the server, and then updates its own key by concatenating the plain text to the current key
 * and hashing the result. The server receives the cipher text, decrypts it with the current key, and then updates its own key in the same manner using the decrypted plain text.
 * The result is that client and server both hold the same new key without the key ever being sent over the network. Notice that the server must successfully decrypt the cipher text in order
 * to derive the new key. Also, replicating the process becomes harder for a third party over time: to generate the same key, one would have to know the initial key and every
 * plain text subsequently exchanged between the client and the server.
 */
extern crate rusqlite;
extern crate aes_gcm;
extern crate sha3;
extern crate rand;

use rusqlite::{Connection, params};
use std::time::{SystemTime, UNIX_EPOCH};
use aes_gcm::{
    aead::{AeadCore, AeadInPlace, KeyInit},
    Aes256Gcm, Nonce, Key // Or `Aes128Gcm`
};
use aes_gcm::aead::Aead;
use rand::Rng;
use sha3::{Digest, Sha3_256};

/// Number of random bytes drawn (via `rand_bytes`) to seed a new session's initial key.
/// Note: despite the name this is a byte count, not the AES key size; the key actually used
/// for encryption is the 32-byte SHA3-256 digest of this (passcode-masked) seed.
const KEY_SIZE: usize = 256;

/**
 * Returns the current wall-clock time as milliseconds since the Unix epoch.
 *
 * `Duration::as_millis` yields a 128 bit value; it is narrowed to a 64 bit unsigned int because
 * the value is stored in an sqlite INTEGER column (`sessions.modified`).
 *
 * Output: milliseconds since the Unix epoch as a 64 bit unsigned int.
 * Panics if the system clock reports a time before the Unix epoch.
 */
fn now_milli() -> u64 {
    let elapsed = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("Time went backwards");
    elapsed.as_millis() as u64
}

/**
 * Produces a vector of `size` bytes drawn from the thread-local random number generator.
 *
 * Input: size - How many random bytes to generate
 * Output: A freshly allocated vector holding exactly <size> random bytes
 */
fn rand_bytes(size: usize) -> Vec<u8> {
    let mut generator = rand::thread_rng();
    let mut bytes = Vec::with_capacity(size);
    for _ in 0..size {
        bytes.push(generator.gen::<u8>());
    }
    bytes
}

/**
 * The secure hash function: computes the SHA3-256 digest of the given bytes.
 *
 * Input: plain_text - The bytes to digest
 * Output: The 32 byte digest as an owned vector
 */
fn hash(plain_text: &[u8]) -> Vec<u8> {
    Sha3_256::digest(plain_text).to_vec()
}

/**
* A secure decryption function with integrated message authentication
*
* Input: cipher_text - The cipher text to decrypt
* key - The key to decrypt with
* Output: The plain text decrypted from the cipher text using the key, may return error is integrated MAC authentication fails
*/
fn decrypt(cipher_text: &[u8], key: &[u8]) -> Result<Vec<u8>, String> {
    let aes_key = Key::<Aes256Gcm>::from_slice(key);
    let cipher = Aes256Gcm::new(aes_key);
    let plain_text = cipher.decrypt(Nonce::from_slice(&[0u8; 12]), cipher_text).unwrap();

    return Ok(plain_text)
}

/**
 * A secure encryption function with integrated message authentication
 *
 * Input: plain_text - The plain text to encrypt
 * key - The key to encrypt with
 * Output: The cipher text encrypted from the given plain text using the key
 */
fn encrypt(plain_text: &[u8], key: &[u8]) -> Result<Vec<u8>, String> {
    let aes_key = Key::<Aes256Gcm>::from_slice(key);
    let cipher = Aes256Gcm::new(aes_key);
    let cipher_text = cipher.encrypt(Nonce::from_slice(&[0u8; 12]), plain_text).unwrap();

    return Ok(cipher_text)
}

/**
 * Just a struct to hold the sqlite connection
 */
pub struct SecureLayer {
    conn: Connection
}

impl SecureLayer {
    pub fn new(conn: Connection) -> SecureLayer {
        /* Used to store the hashed passcode the secure layer recognises */
        conn.execute(r#"
            CREATE TABLE IF NOT EXISTS passcodes (
                id INTEGER PRIMARY KEY,
                hash BLOB
            )
        "#, ()).unwrap();
        /* Used to store the open sessions for each passcode */
        conn.execute(r#"
            CREATE TABLE IF NOT EXISTS sessions (
                id INTEGER PRIMARY KEY,
                hash BLOB,
                passcode,
                modified INTEGER,
                FOREIGN KEY(passcode) REFERENCES passcodes(id)
            )
        "#, ()).unwrap();

        return SecureLayer {
            conn
        }
    }

    /**
    * Registers a passcode so the secure layer will recognize it
    *
    * Input: passcode - The passcode string to store
    * Output: Can fail on database error.
    */
    pub fn register_passcode(&self, passcode: &str) -> Result<(), String> {
        /* Hash the passcode */
        let passcode_hash = hash(passcode.as_bytes());

        /* Save the hashed passcode */
        self.conn.execute(r#"
            INSERT INTO passcodes (hash)
            VALUES (?1)
        "#, params![passcode_hash]).unwrap();

        return Ok(())
    }

    /**
    * Deletes passcode from the secure layer
    *
    * Input: passcode - The passcode string to store
    * Output: Can fail on database error.
    */
    pub fn delete_passcode(&self, passcode: &str) -> Result<(), String> {
        let passcode_hash = hash(passcode.as_bytes());

        /* Delete the passcode from the system which has the same hash as the given passcode */
        self.conn.execute(r#"
            DELETE FROM passcodes
            WHERE hash=?1
        "#, params![passcode_hash]).unwrap();

        return Ok(())
    }

    /**
    * Deletes session from the system
    *
    * Input: session_id - The id of the session to delete
    * Output: Can fail on database error.
    */
    pub fn end_session(&self, session_id: u64) -> Result<(), String> {
        self.conn.execute(r#"
            DELETE FROM sessions
            WHERE id=?1
        "#, params![session_id]).unwrap();

        return Ok(())
    }

    /**
    * Creates a session in the system for a given passcode
    *
    * Input: passcode - The passcode for which to start a session
    * Output: session_id - The id of the new session,
    * init_key_hash - The initial key after it wased hashed with the passcode and some random string of bytes.
    * Can fail on database error or if the given passcode is not present in the system.
    */
    pub fn start_session(&self, passcode_hash: &[u8]) -> Result<(u64, Vec<u8>), String> {
        /* Find passcode in the system */
        let passcode_id: u64 = match self.conn.query_row(
            "SELECT * FROM passcodes WHERE hash=?",
            params![passcode_hash],
            |row| row.get(0),
        ) {
            Ok(id) => id,
            Err(e) => return Err(format!("sql exception: {:?}", e))
        };

        /* Initialize init key to a random string of bytes */
        let mut init_key_source = rand_bytes(KEY_SIZE);
        /* Xor the passcode hash to the random string of bytes */
        init_key_source.iter_mut()
            .zip(passcode_hash.iter())
            .for_each(|(x1, x2)| *x1 ^= *x2);
        /* Hash the modified init key */
        let init_key_hash = hash(init_key_source.as_slice());

        /* Create a new session and store the hashed init key as the sessions current key */
        self.conn.execute(r#"
            INSERT INTO sessions (hash, passcode, modified)
            VALUES (?1, ?2, ?3)
        "#, params![init_key_hash.as_slice(), passcode_id, now_milli()]).unwrap();

        let session_id = self.conn.last_insert_rowid() as u64;

        Ok((session_id, init_key_hash)) /* return session id and hash init key to the user */
    }

    /**
    * Authenticates the request to the server and undated the shared key.
    *
    * Input: session_id - The id of the session in which the client has made the request
    * cipher_text - The cipher text encrypting the body of the client request
    * Output: Returns the decrypted plain text using the hashed key stored in the database for the given session key, also updated the hashed key for the session
    * with the plain text in order to be synchronised with the client. May fail on decryption error or database error.
    */
    pub fn authenticate_request(&self, session_id: u64, cipher_text: &[u8]) -> Result<Vec<u8>, String> {
        /* fetch the hashed key for the given session */
        let mut session_key: Vec<u8> = match self.conn.query_row(
            "SELECT * FROM sessions WHERE id=?",
            params![session_id],
            |row| row.get(1),
        ) {
            Ok(hash) => hash,
            Err(e) => return Err(format!("sql exception: {:?}", e))
        };

        /* decrypt the cipher text using the session key from the database */
        let plain_text = decrypt(cipher_text, session_key.as_slice()).unwrap();

        /* update the session's hashed key using the decrypted plain text */
        session_key.extend(plain_text.clone());
        let new_hash = hash(session_key.as_slice());
        self.conn.execute(r#"
            UPDATE sessions
            SET hash=?1, modified=?2
            WHERE id=?3
        "#, params![new_hash, now_milli(), session_id]).unwrap();

        /* return the request's plain text */
        Ok(plain_text)
    }
}


#[cfg(test)]
mod tests {
    use super::*;
    const PASSCODE: &str = "abc123";

    /// Advances a client-side rolling key the same way `authenticate_request` does on the
    /// server: new_key = hash(old_key || plain_text).
    fn roll_key(key: &[u8], plain_text: &[u8]) -> Vec<u8> {
        let mut material = key.to_vec();
        material.extend_from_slice(plain_text);
        hash(&material)
    }

    #[test]
    fn test_client_server_communications() {
        let conn = Connection::open_in_memory().unwrap();
        let sl = SecureLayer::new(conn);

        sl.register_passcode(PASSCODE).unwrap();
        let passcode_hash = hash(PASSCODE.as_bytes());

        let (sid, init_hash) = sl.start_session(&passcode_hash).unwrap();
        let mut key = init_hash;

        // Each message is encrypted under the key as rolled by every previous message, and the
        // server must hand back exactly the plain text that was sent.
        for plain_text in ["hello world", "hai", "it works!"].iter() {
            let payload = encrypt(plain_text.as_bytes(), &key).unwrap();
            let received = sl.authenticate_request(sid, &payload).unwrap();
            assert_eq!(received, plain_text.as_bytes());
            key = roll_key(&key, plain_text.as_bytes());
        }

        sl.end_session(sid).unwrap();

        // The session no longer exists, so even a correctly keyed request must be rejected.
        let payload4 = encrypt("This should'nt work".as_bytes(), &key).unwrap();
        assert!(sl.authenticate_request(sid, &payload4).is_err());
    }

    #[test]
    fn start_session_requires_registered_passcode() {
        let sl = SecureLayer::new(Connection::open_in_memory().unwrap());
        sl.register_passcode(PASSCODE).unwrap();

        assert!(sl.start_session(&hash(b"not registered")).is_err());

        sl.delete_passcode(PASSCODE).unwrap();
        assert!(sl.start_session(&hash(PASSCODE.as_bytes())).is_err());
    }
}