1use std::collections::HashMap;
36use std::ffi::OsString;
37use std::str::FromStr;
38use std::sync::{OnceLock, RwLock};
39
/// Process-wide table of code-level overrides, keyed by normalized `RLX_*` name.
static OVERRIDES: OnceLock<RwLock<HashMap<String, String>>> = OnceLock::new();

/// Returns the override table, creating an empty one on first use.
fn map() -> &'static RwLock<HashMap<String, String>> {
    OVERRIDES.get_or_init(Default::default)
}
45
/// Canonicalizes a variable name so it carries the `RLX_` prefix.
///
/// Names already beginning with `RLX_` (case-sensitive) are returned unchanged;
/// every other name gets the prefix prepended, e.g. `VERBOSE` -> `RLX_VERBOSE`.
pub fn normalize_key(key: &str) -> String {
    const PREFIX: &str = "RLX_";
    if !key.starts_with(PREFIX) {
        return [PREFIX, key].concat();
    }
    key.to_owned()
}
54
55pub fn set(key: impl AsRef<str>, value: impl Into<String>) {
58 let key = normalize_key(key.as_ref());
59 if let Ok(mut g) = map().write() {
60 g.insert(key, value.into());
61 }
62}
63
64pub fn unset(key: impl AsRef<str>) {
66 let key = normalize_key(key.as_ref());
67 if let Ok(mut g) = map().write() {
68 g.remove(&key);
69 }
70}
71
72pub fn clear_overrides() {
74 if let Ok(mut g) = map().write() {
75 g.clear();
76 }
77}
78
79pub fn var(key: &str) -> Option<String> {
81 let key = normalize_key(key);
82 if let Ok(g) = map().read() {
83 if let Some(v) = g.get(&key) {
84 return Some(v.clone());
85 }
86 }
87 std::env::var(&key).ok()
88}
89
90pub fn var_os(key: &str) -> Option<OsString> {
92 var(key).map(Into::into)
93}
94
95pub fn flag(key: &str) -> bool {
98 match var(key) {
99 Some(v) => truthy(&v),
100 None => false,
101 }
102}
103
104pub fn is_unset(key: &str) -> bool {
106 var(key).is_none()
107}
108
109pub fn parse_or<T: FromStr>(key: &str, default: T) -> T {
111 var(key).and_then(|s| s.parse().ok()).unwrap_or(default)
112}
113
/// Interprets an environment value as a boolean.
///
/// After trimming surrounding whitespace:
/// - empty -> `false`
/// - `0`, `false`, `off`, `no` (ASCII case-insensitive) -> `false`
/// - `1`, `true`, `yes`, `on` (ASCII case-insensitive) -> `true`
/// - any other all-digit string -> `true` iff it is numerically non-zero
///   (`00` is `false`, `10` is `true`)
/// - anything else -> `true`
fn truthy(v: &str) -> bool {
    let s = v.trim();
    if s.is_empty() {
        return false;
    }
    match s.to_ascii_lowercase().as_str() {
        "0" | "false" | "off" | "no" => false,
        "1" | "true" | "yes" | "on" => true,
        // Numeric value: true iff some digit is non-zero, so zero-padded
        // forms such as "00" read as zero rather than as enabled.
        _ if s.bytes().all(|b| b.is_ascii_digit()) => s.bytes().any(|b| b != b'0'),
        _ => true,
    }
}
126
127#[derive(Debug, Clone, Default)]
129pub struct RlxEnv {
130 pairs: Vec<(String, String)>,
131}
132
133impl RlxEnv {
134 pub fn new() -> Self {
135 Self::default()
136 }
137
138 pub fn set(mut self, key: impl AsRef<str>, value: impl Into<String>) -> Self {
139 self.pairs.push((normalize_key(key.as_ref()), value.into()));
140 self
141 }
142
143 pub fn flag(mut self, key: impl AsRef<str>, on: bool) -> Self {
144 self.pairs.push((
145 normalize_key(key.as_ref()),
146 if on { "1" } else { "0" }.into(),
147 ));
148 self
149 }
150
151 pub fn apply(self) {
153 for (k, v) in self.pairs {
154 set(&k, v);
155 }
156 }
157}
158
159pub struct RuntimeOverrides {
161 saved: Vec<(String, Option<String>)>,
162}
163
164impl RuntimeOverrides {
165 pub fn install(pairs: impl IntoIterator<Item = (impl AsRef<str>, impl Into<String>)>) -> Self {
167 let mut saved = Vec::new();
168 for (key, value) in pairs {
169 let key = normalize_key(key.as_ref());
170 let prev = map().read().ok().and_then(|g| g.get(&key).cloned());
171 saved.push((key.clone(), prev));
172 set(&key, value);
173 }
174 Self { saved }
175 }
176}
177
178impl Drop for RuntimeOverrides {
179 fn drop(&mut self) {
180 for (key, prev) in self.saved.drain(..) {
181 match prev {
182 Some(v) => set(&key, v),
183 None => unset(&key),
184 }
185 }
186 }
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, PoisonError};

    /// Serializes tests that read or mutate the process-wide override table.
    static ENV_TEST_LOCK: Mutex<()> = Mutex::new(());

    /// Runs `f` while holding `ENV_TEST_LOCK`, with the override table cleared
    /// before and after.
    ///
    /// A poisoned lock is recovered instead of propagated. The mutex guards `()`
    /// and the table is cleared on entry, so a panic in an earlier test cannot
    /// leak state into this one. One failing test therefore no longer turns every
    /// later test into a lock-poisoning failure that hides the real cause.
    fn with_clean_overrides(f: impl FnOnce()) {
        let _guard = ENV_TEST_LOCK
            .lock()
            .unwrap_or_else(PoisonError::into_inner);
        clear_overrides();
        f();
        clear_overrides();
    }

    #[test]
    fn code_override_wins_over_process_env() {
        with_clean_overrides(|| {
            let _g = RuntimeOverrides::install([("VERBOSE", "2")]);
            assert_eq!(var("RLX_VERBOSE"), Some("2".into()));
            assert!(flag("RLX_VERBOSE"));
        });
    }

    #[test]
    fn flag_parses_falsy_override() {
        with_clean_overrides(|| {
            set("RLX_DISABLE_MPSGRAPH", "0");
            assert!(!flag("RLX_DISABLE_MPSGRAPH"));
        });
    }

    #[test]
    fn rlx_env_bulk_apply() {
        with_clean_overrides(|| {
            RlxEnv::new()
                .set("MPSGRAPH_MIN_FLOPS", "42")
                .flag("USE_ICB", true)
                .apply();
            assert_eq!(parse_or("RLX_MPSGRAPH_MIN_FLOPS", 0u64), 42);
            assert!(flag("RLX_USE_ICB"));
        });
    }
}