1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
//! Database Batch management implementation for memdb.
use std::{
    collections::HashMap,
    io,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    },
};
use tokio::sync::{Mutex, RwLock};

use crate::subnet::rpc::{database::BoxedDatabase, errors::Error};

/// A single operation queued in a [`Batch`]: either a put or a delete.
struct KeyValue {
    /// `true` when this entry removes `key`; `false` when it stores `value`.
    delete: bool,
    /// Key targeted by the operation.
    key: Vec<u8>,
    /// Value to store under `key`; the batch leaves it empty for deletions.
    value: Vec<u8>,
}

/// A write-only staging area for a memdb database.
///
/// `put`/`delete` only buffer operations; they reach the host database's
/// state map when `write` is called. The type is async and its buffer is
/// lock-protected, but it is not intended to be driven from several tasks
/// at the same time.
///
/// NOTE(review): the derived `Clone` shares `writes` between clones (it sits
/// behind an `Arc`) but copies `size` by value, so clones can report a size
/// that no longer matches the shared buffer.
#[derive(Clone)]
pub struct Batch {
    /// Handle to the host database's closed flag.
    db_closed: Arc<AtomicBool>,
    /// Handle to the host database's key/value state.
    db_state: Arc<RwLock<HashMap<Vec<u8>, Vec<u8>>>>,

    /// Running total of queued key and value lengths.
    size: usize,
    /// Buffered operations, in the order they were queued.
    writes: Arc<Mutex<Vec<KeyValue>>>,
}

impl Batch {
    pub fn new(
        db_state: Arc<RwLock<HashMap<Vec<u8>, Vec<u8>>>>,
        db_closed: Arc<AtomicBool>,
    ) -> Self {
        Self {
            writes: Arc::new(Mutex::new(Vec::new())),
            size: 0,
            db_state,
            db_closed,
        }
    }
}

#[tonic::async_trait]
impl crate::subnet::rpc::database::batch::Batch for Batch {
    /// Implements the [`crate::subnet::rpc::database::batch::Batch`] trait.
    ///
    /// Queues a put of `value` under `key` and adds `key.len() + value.len()`
    /// to the tracked size. Nothing is applied to the database until `write`.
    async fn put(&mut self, key: &[u8], value: &[u8]) -> io::Result<()> {
        let mut writes = self.writes.lock().await;
        writes.push(KeyValue {
            key: key.to_owned(),
            value: value.to_owned(),
            delete: false,
        });
        self.size += key.len() + value.len();
        Ok(())
    }

    /// Implements the [`crate::subnet::rpc::database::batch::Batch`] trait.
    ///
    /// Queues a deletion of `key` and adds `key.len()` to the tracked size.
    async fn delete(&mut self, key: &[u8]) -> io::Result<()> {
        let mut writes = self.writes.lock().await;
        writes.push(KeyValue {
            key: key.to_owned(),
            value: vec![],
            delete: true,
        });
        self.size += key.len();
        Ok(())
    }

    /// Implements the [`crate::subnet::rpc::database::batch::Batch`] trait.
    ///
    /// Returns the accumulated length of all queued keys and values since the
    /// batch was created or last reset.
    async fn size(&self) -> io::Result<usize> {
        Ok(self.size)
    }

    /// Implements the [`crate::subnet::rpc::database::batch::Batch`] trait.
    ///
    /// Applies every queued operation, in insertion order, to the database
    /// state while holding its write lock. The queue is not drained, so the
    /// batch can still be replayed (or written again) afterwards.
    ///
    /// # Errors
    ///
    /// Returns the `DatabaseClosed` error if the database's closed flag is set.
    async fn write(&self) -> io::Result<()> {
        let writes = self.writes.lock().await;
        let mut db = self.db_state.write().await;

        // Check the closed flag only once the state write lock is held. If it
        // is checked before acquiring the lock, the database can be closed
        // between the check and the mutation below.
        // NOTE(review): this fully closes the window only if the close path
        // sets the flag while holding the same state lock — confirm.
        if self.db_closed.load(Ordering::Relaxed) {
            return Err(Error::DatabaseClosed.to_err());
        }

        for kv in writes.iter() {
            if kv.delete {
                db.remove(&kv.key);
            } else {
                db.insert(kv.key.clone(), kv.value.clone());
            }
        }
        Ok(())
    }

    /// Implements the [`crate::subnet::rpc::database::batch::Batch`] trait.
    ///
    /// Discards all queued operations and resets the tracked size to zero so
    /// the batch can be reused. When the buffer's capacity exceeds
    /// `MAX_EXCESS_CAPACITY_FACTOR` times its length, the buffer is replaced
    /// by a smaller one (capacity divided by `CAPACITY_REDUCTION_FACTOR`)
    /// instead of merely being cleared.
    async fn reset(&mut self) {
        let mut writes = self.writes.lock().await;
        if writes.capacity()
            > writes.len() * crate::subnet::rpc::database::batch::MAX_EXCESS_CAPACITY_FACTOR
        {
            let reduced_capacity =
                writes.capacity() / crate::subnet::rpc::database::batch::CAPACITY_REDUCTION_FACTOR;
            // Replacing the Vec drops the old entries together with the
            // oversized allocation, so no separate `clear()` is needed.
            *writes = Vec::with_capacity(reduced_capacity);
        } else {
            writes.clear();
        }
        // The queued operations are gone, so their byte count must go too;
        // otherwise `size()` keeps reporting discarded data.
        self.size = 0;
    }

    /// Implements the [`crate::subnet::rpc::database::batch::Batch`] trait.
    ///
    /// Re-applies the queued operations, in insertion order, to the database
    /// behind `k`, holding its mutex for the whole replay. Stops at the first
    /// failing operation.
    ///
    /// # Errors
    ///
    /// A failed `delete` or `put` on `k` is returned as an
    /// `io::ErrorKind::Other` error whose message includes the underlying
    /// error's debug output.
    async fn replay(&self, k: Arc<Mutex<BoxedDatabase>>) -> io::Result<()> {
        let writes = self.writes.lock().await;
        let mut db = k.lock().await;
        for kv in writes.iter() {
            if kv.delete {
                db.delete(&kv.key).await.map_err(|e| {
                    io::Error::new(
                        io::ErrorKind::Other,
                        format!("replay delete failed: {:?}", e),
                    )
                })?;
            } else {
                db.put(&kv.key, &kv.value).await.map_err(|e| {
                    io::Error::new(io::ErrorKind::Other, format!("replay put failed: {:?}", e))
                })?;
            }
        }

        Ok(())
    }
}