1use std::{
5 fs,
6 path::{Path, PathBuf},
7};
8
9use anyhow::Context as _;
10use semver::Version;
11
12use crate::utils::version::FOREST_VERSION;
13
/// Name of the environment variable that `DbMode::read` inspects to decide
/// which database directory `choose_db` returns.
pub(super) const FOREST_DB_DEV_MODE: &str = "FOREST_DB_DEV_MODE";
20
21fn list_versioned_databases(chain_data_path: &Path) -> anyhow::Result<Vec<Version>> {
24 let versions = fs::read_dir(chain_data_path)?
25 .filter_map(|entry| entry.ok())
26 .filter_map(|entry| {
27 let path = entry.path();
28 Version::parse(path.file_name()?.to_str()?).ok()
29 })
30 .collect();
31
32 Ok(versions)
33}
34
35pub(super) fn get_latest_versioned_database(
37 chain_data_path: &Path,
38) -> anyhow::Result<Option<Version>> {
39 let versions = list_versioned_databases(chain_data_path)?;
40 Ok(versions.iter().max().cloned())
41}
42
43pub fn choose_db(chain_data_path: &Path) -> anyhow::Result<PathBuf> {
46 let db = match DbMode::read() {
47 DbMode::Current => chain_data_path.join(FOREST_VERSION.to_string()),
48 DbMode::Latest => {
49 let versions = list_versioned_databases(chain_data_path)?;
50
51 if versions.is_empty() {
52 chain_data_path.join(FOREST_VERSION.to_string())
53 } else {
54 let latest = versions
55 .iter()
56 .max()
57 .context("Failed to find latest versioned database")?; chain_data_path.join(latest.to_string())
59 }
60 }
61 DbMode::Custom(custom) => chain_data_path.join(custom),
62 };
63
64 Ok(db)
65}
66
67#[derive(Debug, PartialEq, Eq, Clone)]
69pub enum DbMode {
70 Current,
73 Latest,
75 Custom(String),
77}
78
79impl DbMode {
80 pub fn read() -> Self {
82 match std::env::var(FOREST_DB_DEV_MODE)
83 .map(|s| s.to_lowercase())
84 .as_deref()
85 {
86 Ok("latest") => Self::Latest,
87 Ok("current") | Err(_) => Self::Current,
88 Ok(val) => Self::Custom(val.to_owned()),
89 }
90 }
91}
92
#[cfg(test)]
mod tests {
    use itertools::Itertools;

    use super::*;
    use std::env;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes the tests in this module that mutate `FOREST_DB_DEV_MODE`.
    ///
    /// The test harness runs tests concurrently on several threads.
    /// `test_db_mode` and `test_choose_db` both set and remove the same
    /// process-wide environment variable. Without this lock, one test can
    /// observe the other's value and fail spuriously.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`], recovering from poisoning so that one failing
    /// test does not cascade into failures of the others.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Sets `FOREST_DB_DEV_MODE` to `value`, or removes it when `value` is `None`.
    ///
    /// The guard parameter reminds callers that they must hold [`ENV_LOCK`].
    fn set_db_mode(_guard: &MutexGuard<'_, ()>, value: Option<&str>) {
        // SAFETY: the caller holds `ENV_LOCK`, so no other test in this module
        // touches the environment concurrently. NOTE(review): this assumes no
        // test elsewhere in the crate reads or writes the environment while
        // these tests run.
        unsafe {
            match value {
                Some(value) => env::set_var(FOREST_DB_DEV_MODE, value),
                None => env::remove_var(FOREST_DB_DEV_MODE),
            }
        }
    }

    #[test]
    fn test_db_mode() {
        let guard = lock_env();

        set_db_mode(&guard, Some("latest"));
        assert_eq!(DbMode::read(), DbMode::Latest);

        set_db_mode(&guard, Some("current"));
        assert_eq!(DbMode::read(), DbMode::Current);

        set_db_mode(&guard, Some("cthulhu"));
        assert_eq!(DbMode::read(), DbMode::Custom("cthulhu".to_owned()));

        set_db_mode(&guard, None);
        assert_eq!(DbMode::read(), DbMode::Current);
    }

    #[test]
    fn test_list_versioned_databases() {
        use tempfile::tempdir;

        let dir = tempdir().unwrap();
        let path = dir.path();

        for name in &["0.1.0", "0.2.0", "0.3.0", "Elder God", "my0.4.0"] {
            std::fs::create_dir(path.join(name)).unwrap();
        }

        // `read_dir` order is unspecified, so sort before comparing.
        let versions = list_versioned_databases(path)
            .unwrap()
            .iter()
            .sorted()
            .cloned()
            .collect_vec();
        assert_eq!(
            versions,
            vec![
                Version::parse("0.1.0").unwrap(),
                Version::parse("0.2.0").unwrap(),
                Version::parse("0.3.0").unwrap()
            ]
        );
    }

    #[test]
    fn test_choose_db() {
        use tempfile::tempdir;

        let dir = tempdir().unwrap();
        let path = dir.path();

        for name in &["0.1.0", "0.2.0", "0.3.0", "Elder God", "my0.4.0"] {
            std::fs::create_dir(path.join(name)).unwrap();
        }

        let cases = [
            ("latest", path.join("0.3.0")),
            ("current", path.join(FOREST_VERSION.to_string())),
            ("cthulhu", path.join("cthulhu")),
        ];

        let guard = lock_env();
        for (mode, expected) in cases {
            set_db_mode(&guard, Some(mode));
            assert_eq!(choose_db(path).unwrap(), expected, "mode: {mode}");
        }

        set_db_mode(&guard, None);
        let db = choose_db(path).unwrap();
        assert_eq!(db, path.join(FOREST_VERSION.to_string()));
    }
}