Skip to main content

reddb_file/
control_store.rs

1use std::fs::{self, File, OpenOptions};
2use std::io::Write;
3use std::path::{Path, PathBuf};
4
5use serde_json::Value as JsonValue;
6
7use crate::embedded::{RdbFileError, RdbFileResult};
8
/// Term reported by [`FileTermStore::new`] stores while their backing file does not exist.
pub const DEFAULT_FILE_TERM: u64 = 1;
/// Extension given to the temporary file written while atomically persisting a term.
pub const TERM_TEMP_EXTENSION: &str = "term.tmp";
/// Extension given to the temporary file written while atomically persisting a last vote.
pub const LAST_VOTE_TEMP_EXTENSION: &str = "lastvote.tmp";
12
/// A vote record as stored by [`FileLastVoteStore`].
///
/// The `Default` value is term `0` with no recorded vote.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct DurableLastVote {
    /// Term in which the vote was recorded.
    pub term: u64,
    /// Identifier that received the vote, or `None` when no vote is recorded.
    pub voted_for: Option<String>,
}
18
19impl DurableLastVote {
20    pub fn new(term: u64, voted_for: Option<String>) -> Self {
21        Self { term, voted_for }
22    }
23
24    pub fn encode(&self) -> RdbFileResult<Vec<u8>> {
25        let mut obj = serde_json::Map::new();
26        obj.insert("term".to_string(), JsonValue::Number(self.term.into()));
27        obj.insert(
28            "voted_for".to_string(),
29            match &self.voted_for {
30                Some(id) => JsonValue::String(id.clone()),
31                None => JsonValue::Null,
32            },
33        );
34        serde_json::to_vec(&JsonValue::Object(obj))
35            .map_err(|err| RdbFileError::InvalidOperation(format!("serialize last-vote: {err}")))
36    }
37
38    pub fn decode(bytes: &[u8]) -> RdbFileResult<Self> {
39        let value: JsonValue = serde_json::from_slice(bytes)
40            .map_err(|err| RdbFileError::InvalidOperation(format!("parse last-vote: {err}")))?;
41        let obj = value.as_object().ok_or_else(|| {
42            RdbFileError::InvalidOperation("last-vote json is not an object".into())
43        })?;
44        let term = obj
45            .get("term")
46            .and_then(JsonValue::as_u64)
47            .ok_or_else(|| RdbFileError::InvalidOperation("missing term".into()))?;
48        let voted_for = match obj.get("voted_for") {
49            None | Some(JsonValue::Null) => None,
50            Some(JsonValue::String(id)) => Some(id.clone()),
51            Some(_) => {
52                return Err(RdbFileError::InvalidOperation(
53                    "voted_for must be a string or null".into(),
54                ))
55            }
56        };
57        Ok(Self { term, voted_for })
58    }
59}
60
/// File-backed store for a single [`DurableLastVote`] record.
///
/// The record is kept as JSON at `path` and replaced atomically on each persist.
#[derive(Debug, Clone)]
pub struct FileLastVoteStore {
    /// Location of the JSON last-vote file.
    path: PathBuf,
}
64
65impl FileLastVoteStore {
66    pub fn new(path: impl Into<PathBuf>) -> Self {
67        Self { path: path.into() }
68    }
69
70    pub fn load_file(&self) -> RdbFileResult<DurableLastVote> {
71        match fs::read(&self.path) {
72            Ok(bytes) => DurableLastVote::decode(&bytes),
73            Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
74                Ok(DurableLastVote::default())
75            }
76            Err(err) => Err(err.into()),
77        }
78    }
79
80    pub fn persist_file(&self, vote: &DurableLastVote) -> RdbFileResult<()> {
81        write_bytes_atomically(&self.path, LAST_VOTE_TEMP_EXTENSION, &vote.encode()?)
82    }
83}
84
/// File-backed store for the current term.
///
/// The term is kept as JSON at `path`. While no file exists, `load_file`
/// reports `default_term`.
#[derive(Debug, Clone)]
pub struct FileTermStore {
    /// Location of the JSON term file.
    path: PathBuf,
    /// Term returned by `load_file` while the term file does not exist.
    default_term: u64,
}
89
90impl FileTermStore {
91    pub fn new(path: impl Into<PathBuf>) -> Self {
92        Self {
93            path: path.into(),
94            default_term: DEFAULT_FILE_TERM,
95        }
96    }
97
98    pub fn with_default_term(path: impl Into<PathBuf>, default_term: u64) -> Self {
99        Self {
100            path: path.into(),
101            default_term,
102        }
103    }
104
105    pub fn load_file(&self) -> RdbFileResult<u64> {
106        match fs::read(&self.path) {
107            Ok(bytes) => decode_term(&bytes),
108            Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(self.default_term),
109            Err(err) => Err(err.into()),
110        }
111    }
112
113    pub fn persist_file(&self, term: u64) -> RdbFileResult<()> {
114        write_bytes_atomically(&self.path, TERM_TEMP_EXTENSION, &encode_term(term)?)
115    }
116}
117
118fn encode_term(term: u64) -> RdbFileResult<Vec<u8>> {
119    let mut obj = serde_json::Map::new();
120    obj.insert("term".to_string(), JsonValue::Number(term.into()));
121    serde_json::to_vec(&JsonValue::Object(obj))
122        .map_err(|err| RdbFileError::InvalidOperation(format!("serialize term: {err}")))
123}
124
125fn decode_term(bytes: &[u8]) -> RdbFileResult<u64> {
126    let value: JsonValue = serde_json::from_slice(bytes)
127        .map_err(|err| RdbFileError::InvalidOperation(format!("parse term: {err}")))?;
128    value
129        .get("term")
130        .and_then(JsonValue::as_u64)
131        .ok_or_else(|| RdbFileError::InvalidOperation("missing term".into()))
132}
133
134fn write_bytes_atomically(path: &Path, temp_extension: &str, bytes: &[u8]) -> RdbFileResult<()> {
135    if let Some(parent) = path.parent() {
136        fs::create_dir_all(parent)?;
137    }
138    let tmp = path.with_extension(temp_extension);
139    {
140        let mut file = OpenOptions::new()
141            .create(true)
142            .truncate(true)
143            .write(true)
144            .open(&tmp)?;
145        file.write_all(bytes)?;
146        file.sync_all()?;
147    }
148    fs::rename(&tmp, path)?;
149    if let Some(parent) = path.parent() {
150        if let Ok(dir) = File::open(parent) {
151            let _ = dir.sync_all();
152        }
153    }
154    Ok(())
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Returns a unique `.json` path under the system temp directory.
    ///
    /// The process id and the current time in nanoseconds keep paths distinct
    /// across concurrent test runs.
    fn temp_path(name: &str) -> PathBuf {
        let suffix = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("clock")
            .as_nanos();
        std::env::temp_dir().join(format!(
            "reddb-file-control-{name}-{}-{suffix}.json",
            std::process::id()
        ))
    }

    #[test]
    fn last_vote_round_trips_and_defaults() {
        let path = temp_path("lastvote");
        let store = FileLastVoteStore::new(&path);
        assert_eq!(
            store.load_file().expect("default"),
            DurableLastVote::default()
        );

        let vote = DurableLastVote::new(9, Some("replica-a".into()));
        store.persist_file(&vote).expect("persist");
        assert_eq!(
            FileLastVoteStore::new(&path).load_file().expect("load"),
            vote
        );
        // The temporary file is renamed over the target, so none should remain.
        assert!(!path.with_extension(LAST_VOTE_TEMP_EXTENSION).exists());

        let _ = fs::remove_file(path);
    }

    #[test]
    fn term_round_trips_and_defaults() {
        let path = temp_path("term");
        let store = FileTermStore::with_default_term(&path, 3);
        assert_eq!(store.load_file().expect("default"), 3);

        store.persist_file(12).expect("persist");
        assert_eq!(FileTermStore::new(&path).load_file().expect("load"), 12);
        assert!(!path.with_extension(TERM_TEMP_EXTENSION).exists());

        let _ = fs::remove_file(path);
    }

    #[test]
    fn last_vote_decode_accepts_missing_or_null_voted_for() {
        assert_eq!(
            DurableLastVote::decode(br#"{"term":4}"#).expect("missing voted_for"),
            DurableLastVote::new(4, None)
        );
        assert_eq!(
            DurableLastVote::decode(br#"{"term":4,"voted_for":null}"#).expect("null voted_for"),
            DurableLastVote::new(4, None)
        );
    }

    #[test]
    fn last_vote_decode_rejects_malformed_input() {
        let cases: [&[u8]; 5] = [
            b"not json",
            b"[1,2]",
            br#"{"voted_for":"a"}"#,
            br#"{"term":-1}"#,
            br#"{"term":4,"voted_for":7}"#,
        ];
        for bad in cases {
            assert!(
                DurableLastVote::decode(bad).is_err(),
                "accepted {:?}",
                String::from_utf8_lossy(bad)
            );
        }
    }

    #[test]
    fn term_decode_rejects_malformed_input() {
        let cases: [&[u8]; 4] = [
            b"not json",
            b"[1,2]",
            br#"{"other":1}"#,
            br#"{"term":"5"}"#,
        ];
        for bad in cases {
            assert!(
                decode_term(bad).is_err(),
                "accepted {:?}",
                String::from_utf8_lossy(bad)
            );
        }
    }

    #[test]
    fn corrupt_files_surface_errors_instead_of_defaults() {
        let path = temp_path("corrupt");
        fs::write(&path, b"not json").expect("write corrupt file");
        assert!(FileLastVoteStore::new(&path).load_file().is_err());
        assert!(FileTermStore::new(&path).load_file().is_err());

        let _ = fs::remove_file(path);
    }
}