Skip to main content

forest/db/
db_mode.rs

1// Copyright 2019-2026 ChainSafe Systems
2// SPDX-License-Identifier: Apache-2.0, MIT
3
4use std::{
5    fs,
6    path::{Path, PathBuf},
7};
8
9use anyhow::Context as _;
10use semver::Version;
11
12use crate::utils::version::FOREST_VERSION;
13
14/// Environment variable used to set the development mode
15/// It is used for development purposes. Possible values:
16/// - `current`: use the database matching the current binary version. May result in migrations.
17/// - `latest`: use the latest database version.
18/// - other values: use the database matching the provided name.
19pub(super) const FOREST_DB_DEV_MODE: &str = "FOREST_DB_DEV_MODE";
20
21/// Lists all versioned databases in the chain data directory.
22/// Versioned databases are directories with a `SemVer` version as their name. The rest is discarded.
23fn list_versioned_databases(chain_data_path: &Path) -> anyhow::Result<Vec<Version>> {
24    let versions = fs::read_dir(chain_data_path)?
25        .filter_map(|entry| entry.ok())
26        .filter_map(|entry| {
27            let path = entry.path();
28            Version::parse(path.file_name()?.to_str()?).ok()
29        })
30        .collect();
31
32    Ok(versions)
33}
34
35/// Returns the latest versioned database in the chain data directory (if such one exists).
36pub(super) fn get_latest_versioned_database(
37    chain_data_path: &Path,
38) -> anyhow::Result<Option<Version>> {
39    let versions = list_versioned_databases(chain_data_path)?;
40    Ok(versions.iter().max().cloned())
41}
42
43/// Chooses the correct database directory to use based on the `[FOREST_DB_DEV_MODE]`
44/// environment variable (or the lack of it).
45pub fn choose_db(chain_data_path: &Path) -> anyhow::Result<PathBuf> {
46    let db = match DbMode::read() {
47        DbMode::Current => chain_data_path.join(FOREST_VERSION.to_string()),
48        DbMode::Latest => {
49            let versions = list_versioned_databases(chain_data_path)?;
50
51            if versions.is_empty() {
52                chain_data_path.join(FOREST_VERSION.to_string())
53            } else {
54                let latest = versions
55                    .iter()
56                    .max()
57                    .context("Failed to find latest versioned database")?; // This should never happen
58                chain_data_path.join(latest.to_string())
59            }
60        }
61        DbMode::Custom(custom) => chain_data_path.join(custom),
62    };
63
64    Ok(db)
65}
66
67/// Represents different modes of access to the database
68#[derive(Debug, PartialEq, Eq, Clone)]
69pub enum DbMode {
70    /// Using the database matching the binary version. This is the default, and is the only mode
71    /// in which migrations are run.
72    Current,
73    /// Using the latest versioned database if exists
74    Latest,
75    /// Using a custom database
76    Custom(String),
77}
78
79impl DbMode {
80    /// Returns the database mode based on the environment variable
81    pub fn read() -> Self {
82        match std::env::var(FOREST_DB_DEV_MODE)
83            .map(|s| s.to_lowercase())
84            .as_deref()
85        {
86            Ok("latest") => Self::Latest,
87            Ok("current") | Err(_) => Self::Current,
88            Ok(val) => Self::Custom(val.to_owned()),
89        }
90    }
91}
92
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use tempfile::tempdir;

    use super::*;
    use std::env;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes the tests in this module that touch `FOREST_DB_DEV_MODE`.
    ///
    /// The environment is process-global and the test harness runs tests on parallel threads.
    /// Without this lock, `test_db_mode` and `test_choose_db` can observe each other's values.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`]. Poisoning is ignored so that one failing test does not cascade
    /// into failures of the other environment-dependent tests.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Sets `FOREST_DB_DEV_MODE`. The caller must hold the guard returned by [`lock_env`].
    fn set_mode(value: &str) {
        // SAFETY: all tests in this module that access the environment hold `ENV_LOCK`, so they
        // never modify or read it concurrently with each other. This does not guard against
        // environment access from tests outside this module.
        unsafe { env::set_var(FOREST_DB_DEV_MODE, value) };
    }

    /// Removes `FOREST_DB_DEV_MODE`. The caller must hold the guard returned by [`lock_env`].
    fn clear_mode() {
        // SAFETY: same reasoning as in `set_mode`.
        unsafe { env::remove_var(FOREST_DB_DEV_MODE) };
    }

    /// Creates a mix of versioned and non-versioned directories directly under `path`.
    fn create_db_dirs(path: &Path) {
        for name in ["0.1.0", "0.2.0", "0.3.0", "Elder God", "my0.4.0"] {
            std::fs::create_dir(path.join(name)).unwrap();
        }
    }

    #[test]
    fn test_db_mode() {
        let _guard = lock_env();

        set_mode("latest");
        assert_eq!(DbMode::read(), DbMode::Latest);

        set_mode("current");
        assert_eq!(DbMode::read(), DbMode::Current);

        set_mode("cthulhu");
        assert_eq!(DbMode::read(), DbMode::Custom("cthulhu".to_owned()));

        clear_mode();
        assert_eq!(DbMode::read(), DbMode::Current);
    }

    #[test]
    fn test_list_versioned_databases() {
        let tmp = tempdir().unwrap();
        let path = tmp.path();
        create_db_dirs(path);

        let versions = list_versioned_databases(path)
            .unwrap()
            .into_iter()
            .sorted()
            .collect_vec();
        assert_eq!(
            versions,
            vec![
                Version::parse("0.1.0").unwrap(),
                Version::parse("0.2.0").unwrap(),
                Version::parse("0.3.0").unwrap()
            ]
        );
    }

    #[test]
    fn test_choose_db() {
        let _guard = lock_env();

        let tmp = tempdir().unwrap();
        let path = tmp.path();
        create_db_dirs(path);

        let cases = [
            ("latest", path.join("0.3.0")),
            ("current", path.join(FOREST_VERSION.to_string())),
            ("cthulhu", path.join("cthulhu")),
        ];

        for (mode, expected) in &cases {
            set_mode(mode);
            assert_eq!(choose_db(path).unwrap(), *expected, "mode: {mode}");
        }

        clear_mode();
        assert_eq!(
            choose_db(path).unwrap(),
            path.join(FOREST_VERSION.to_string())
        );
    }
}