rustolio-db 0.1.0

A DB extension for the rustolio HTTP server
Documentation
//
// SPDX-License-Identifier: MPL-2.0
//
// Copyright (c) 2026 Tobias Binnewies. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
//

use std::{fs, io, path::PathBuf};

use rustolio_utils::{
    bytes::encoding::{decode_from_std, encode_to_std},
    crypto::signature::PublicKey,
};
use tokio::sync::{mpsc, oneshot};

use crate::{Error, Key, KeyType, Result, Value};

/// A request handled by the background disk task started by [`Store::init`].
pub enum Action {
    /// Look up the value stored for the given key.
    ///
    /// The optional public key identifies the requesting signer; it is only compared
    /// against the stored value's signer for `KeyType::SecureReadWrite` keys. The lookup
    /// result is sent back through the oneshot sender (`Ok(None)` when no file exists
    /// for the key).
    Get(
        Key,
        Option<PublicKey>,
        oneshot::Sender<Result<Option<Value>>>,
    ),
    /// Store the value under the given key. No reply is sent to the requester.
    Set(Key, Value),
    /// Stop the disk task: the channel is closed, actions that are already queued are
    /// still handled, and `()` is then sent on the given sender.
    Shutdown(oneshot::Sender<()>),
}

/// Disk-backed key/value store that keeps one file per key inside a directory.
#[derive(Debug)]
pub struct Store {
    /// Directory that contains the value files; created on construction if missing.
    path: PathBuf,
}

impl Store {
    /// Spawns the background task that owns the on-disk store rooted at `dir`.
    ///
    /// The task handles every [`Action`] received on `disk_channel`, one at a time and in
    /// arrival order:
    ///
    /// * `Get` replies with the lookup result. A requester that dropped its receiver is
    ///   ignored instead of bringing the task down.
    /// * `Set` is written to disk before the next action is handled, so a later `Get` for
    ///   the same key observes it. Failures are reported on stderr because `Set` carries
    ///   no reply channel.
    /// * `Shutdown` closes the channel; messages that are already buffered are still
    ///   processed and the shutdown sender is signalled afterwards.
    ///
    /// The task also ends (without signalling anyone) when every sender has been dropped.
    ///
    /// # Panics
    ///
    /// `tokio::spawn` requires a running Tokio runtime. The spawned task itself panics if
    /// `dir` cannot be created.
    pub fn init(dir: PathBuf, mut disk_channel: mpsc::Receiver<Action>) {
        tokio::spawn(async move {
            let mut shutdown_tx = None;
            let mut disk_store = Self::new(dir);
            // TODO: Skip multiple entries with same key in channel queue
            while let Some(action) = disk_channel.recv().await {
                match action {
                    Action::Shutdown(tx) => {
                        shutdown_tx = Some(tx);
                        disk_channel.close();
                        // Keep looping to consume any buffered messages; `recv` yields
                        // `None` once the closed channel has been drained.
                    }
                    Action::Get(k, s, tx) => {
                        let value = disk_store.get(k, s).await;
                        // The requester may have given up waiting; that must not kill
                        // the disk task.
                        let _ = tx.send(value);
                    }
                    Action::Set(k, v) => {
                        // A rejected or failed write only affects this entry, so report
                        // it and keep serving the remaining actions.
                        if let Err(e) = disk_store.set(k, v).await {
                            eprintln!("rustolio-db: failed to persist value to disk: {:?}", e);
                        }
                    }
                }
            }

            // The channel is closed and drained. Acknowledge the shutdown if one was
            // requested; there is nobody to notify if all senders were simply dropped.
            if let Some(tx) = shutdown_tx {
                let _ = tx.send(());
            }
        });
    }

    /// Creates a store rooted at `path`, creating the directory (and parents) if needed.
    ///
    /// # Panics
    ///
    /// Panics if the directory cannot be created.
    fn new(path: PathBuf) -> Self {
        fs::create_dir_all(&path).expect("failed to create the disk store directory");
        Self { path }
    }

    /// Reads the value stored for `key` on a blocking thread.
    ///
    /// Returns `Ok(None)` if nothing is stored for `key`. `signer` is forwarded to
    /// [`read`], which enforces the read permission of the key type.
    async fn get(&self, key: Key, signer: Option<PublicKey>) -> Result<Option<Value>> {
        let file = self.file(key);
        tokio::task::spawn_blocking(move || read(&file, key.ty(), signer))
            .await
            .expect("blocking disk read task panicked")
    }

    /// Writes `value` for `key` on a blocking thread and waits for the write to finish.
    ///
    /// Waiting keeps writes ordered with respect to later actions and surfaces the error
    /// returned by [`write`] to the caller instead of dropping it.
    async fn set(&mut self, key: Key, value: Value) -> Result<()> {
        let file = self.file(key);
        tokio::task::spawn_blocking(move || write(&file, &value))
            .await
            .expect("blocking disk write task panicked")
    }

    // Disabled sketch of an atomic get-and-set. It still calls `read` with a single
    // argument, so it has to be adapted to the current `read` signature before it can be
    // re-enabled.
    //
    // async fn getset(&mut self, key: Key, value: Value) -> Result<Option<Value>> {
    //     let file = self.file(key);

    //     tokio::task::spawn_blocking(move || {
    //         let old_value = read(&file);
    //         write(&file, &value)?;
    //         old_value
    //     })
    //     .await
    //     .unwrap()
    // }

    /// Returns the path of the file that holds the value for `key`: the store directory
    /// joined with the string form of the key's hash.
    fn file(&self, key: Key) -> PathBuf {
        self.path.join(key.hash().to_string())
    }
}

fn read(dir: &PathBuf, ty: KeyType, signer: Option<PublicKey>) -> Result<Option<Value>> {
    let mut file = match fs::File::open(dir) {
        Ok(f) => f,
        Err(e) => {
            if e.kind() == io::ErrorKind::NotFound {
                return Ok(None);
            }
            return Err(Error::FileError(e));
        }
    };
    let value: Value = decode_from_std(&mut file).map_err(Error::EncodingError)?;

    match ty {
        KeyType::ReadWrite | KeyType::ReadSecureWrite => {}
        KeyType::SecureReadWrite => {
            if value.signer() != signer {
                return Err(crate::Error::NotAllowed);
            }
        }
    }

    Ok(Some(value))
}

fn write(dir: &PathBuf, value: &crate::Value) -> Result<()> {
    if let Ok(mut file) = fs::File::open(dir) {
        let current_signer = decode_from_std(&mut file).map_err(Error::EncodingError)?;
        if value.signer() != current_signer {
            return Err(crate::Error::NotAllowed);
        }
    }

    let mut file = fs::File::create(dir).map_err(Error::FileError)?;
    encode_to_std(value, &mut file).map_err(Error::EncodingError)?;
    Ok(())
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    use crate::{test_utils::file::generate_dir, KeyType};

    /// Time granted to a `set` whose blocking write may still be in flight.
    const WRITE_SETTLE: Duration = Duration::from_millis(50);

    /// Sends a signer-less `Get` for `key` through `actions` and returns the reply.
    async fn fetch(actions: &mpsc::Sender<Action>, key: Key) -> Option<Value> {
        let (reply_tx, reply_rx) = oneshot::channel();
        actions
            .send(Action::Get(key, None, reply_tx))
            .await
            .unwrap();
        reply_rx.await.unwrap().unwrap()
    }

    /// The free functions `read`/`write` round-trip a value through a single file.
    #[tokio::test]
    async fn test_read_write() {
        let (dir, _guard) = generate_dir("read_write");
        let path = dir.join("filename");

        // Nothing has been written yet.
        assert!(read(&path, KeyType::ReadWrite, None).unwrap().is_none());

        let written: Value = Value::from_value(&"foo").unwrap();
        write(&path, &written).unwrap();

        let loaded = read(&path, KeyType::ReadWrite, None).unwrap();
        assert_eq!(loaded.unwrap(), written);
    }

    /// `Store::get`/`Store::set` round-trip a value when used directly.
    #[tokio::test]
    async fn test_disk_store() {
        let (dir, _guard) = generate_dir("disk_store");
        let mut store = Store::new(dir.clone());

        let key = Key::from_value(&0, KeyType::ReadWrite);
        let expected = Value::from_value(&"foo").unwrap();

        assert!(store.get(key, None).await.unwrap().is_none());

        store.set(key, expected.clone()).await.unwrap();
        // `set` may return before the write has reached the disk, so give it time.
        tokio::time::sleep(WRITE_SETTLE).await;

        let loaded = store.get(key, None).await.unwrap();
        assert_eq!(loaded.unwrap(), expected);

        // TODO: once `Store::getset` exists, check that replacing "foo" with "bar"
        // returns the old value and that a following `get` yields the new one.
    }

    /// The same round-trip, driven through the action channel of `Store::init`.
    #[tokio::test]
    async fn test_disk_store_channel() {
        let (dir, _guard) = generate_dir("disk_store_channel");

        let (actions, action_rx) = mpsc::channel(5);
        Store::init(dir.clone(), action_rx);

        let key = Key::from_value(&0, KeyType::ReadWrite);
        let expected = Value::from_value(&"foo").unwrap();

        assert!(fetch(&actions, key).await.is_none());

        actions
            .send(Action::Set(key, expected.clone()))
            .await
            .unwrap();
        // `Set` has no reply, so wait until the write has had time to land on disk.
        tokio::time::sleep(WRITE_SETTLE).await;

        assert_eq!(fetch(&actions, key).await.unwrap(), expected);

        // TODO: once `Action::GetSet` exists, check that sending it with "bar" replies
        // with the old value and that a following `Get` yields the new one.
    }
}