1use mlua::prelude::*;
15use std::collections::HashMap;
16
17use crate::util::{check_env_get, check_env_set};
18
/// Script-local environment overlay written by `std.env.set`.
///
/// Keys are variable names and values are the strings a script assigned.
/// Readers in this module consult this map before the process environment.
/// It is stored as `Lua` app data, so each Lua state has its own overlay.
#[derive(Debug, Default)]
struct EnvOverrides(HashMap<String, String>);
26
27pub fn module(lua: &Lua) -> LuaResult<LuaTable> {
28 if lua.app_data_ref::<EnvOverrides>().is_none() {
29 lua.set_app_data(EnvOverrides(HashMap::new()));
30 }
31
32 let t = lua.create_table()?;
33
34 t.set(
35 "get",
36 lua.create_function(|lua, key: String| {
37 check_env_get(lua, &key)?;
38 if let Some(ov) = lua.app_data_ref::<EnvOverrides>() {
39 if let Some(val) = ov.0.get(&key) {
40 return Ok(Some(val.clone()));
41 }
42 }
43 Ok(std::env::var(&key).ok())
44 })?,
45 )?;
46
47 t.set(
48 "get_or",
49 lua.create_function(|lua, (key, default): (String, String)| {
50 check_env_get(lua, &key)?;
51 if let Some(ov) = lua.app_data_ref::<EnvOverrides>() {
52 if let Some(val) = ov.0.get(&key) {
53 return Ok(val.clone());
54 }
55 }
56 Ok(std::env::var(&key).unwrap_or(default))
57 })?,
58 )?;
59
60 t.set(
61 "set",
62 lua.create_function(|lua, (key, value): (String, String)| {
63 check_env_set(lua, &key)?;
64 let mut ov = lua
65 .app_data_mut::<EnvOverrides>()
66 .ok_or_else(|| LuaError::external("env overlay not initialized"))?;
67 ov.0.insert(key, value);
68 Ok(())
69 })?,
70 )?;
71
72 t.set(
73 "home",
74 lua.create_function(|lua, _: ()| {
75 let home_allowed = check_env_get(lua, "HOME").is_ok();
77
78 let userprofile_allowed = check_env_get(lua, "USERPROFILE").is_ok();
83
84 if !home_allowed && !userprofile_allowed {
85 check_env_get(lua, "HOME")?;
88 }
89
90 if let Some(ov) = lua.app_data_ref::<EnvOverrides>() {
92 if home_allowed {
93 if let Some(val) = ov.0.get("HOME") {
94 return Ok(Some(val.clone()));
95 }
96 }
97 if userprofile_allowed {
98 if let Some(val) = ov.0.get("USERPROFILE") {
99 return Ok(Some(val.clone()));
100 }
101 }
102 }
103
104 if home_allowed {
106 if let Ok(val) = std::env::var("HOME") {
107 return Ok(Some(val));
108 }
109 }
110 if userprofile_allowed {
111 if let Ok(val) = std::env::var("USERPROFILE") {
112 return Ok(Some(val));
113 }
114 }
115
116 Ok(None)
117 })?,
118 )?;
119
120 Ok(t)
121}
122
#[cfg(test)]
mod tests {
    use crate::util::test_eval as eval;

    #[test]
    fn get_existing_var() {
        let s: String = eval(
            r#"
            return type(std.env.get("PATH"))
            "#,
        );
        assert_eq!(s, "string");
    }

    #[test]
    fn get_missing_var_returns_nil() {
        let b: bool = eval(
            r#"
            return std.env.get("__MLUA_STD_DOES_NOT_EXIST__") == nil
            "#,
        );
        assert!(b);
    }

    #[test]
    fn get_or_returns_default() {
        let s: String = eval(
            r#"
            return std.env.get_or("__MLUA_STD_MISSING__", "fallback")
            "#,
        );
        assert_eq!(s, "fallback");
    }

    /// home() must agree with the HOME -> USERPROFILE fallback chain as seen
    /// through `std.env.get`. Comparing against that chain, rather than
    /// asserting non-nil, keeps the test valid on hosts where neither
    /// variable is set (e.g. stripped-down CI sandboxes).
    #[test]
    fn home_matches_get_fallback_chain() {
        let b: bool = eval(
            r#"
            local expected = std.env.get("HOME") or std.env.get("USERPROFILE")
            return std.env.home() == expected
            "#,
        );
        assert!(b);
    }

    #[test]
    fn set_and_get_roundtrip() {
        let s: String = eval(
            r#"
            std.env.set("__MLUA_STD_TEST__", "test_value")
            return std.env.get("__MLUA_STD_TEST__")
            "#,
        );
        assert_eq!(s, "test_value");
    }

    #[test]
    fn set_overrides_os_var() {
        let s: String = eval(
            r#"
            std.env.set("PATH", "overridden")
            return std.env.get("PATH")
            "#,
        );
        assert_eq!(s, "overridden");
    }

    #[test]
    fn home_reflects_overlay() {
        let s: String = eval(
            r#"
            std.env.set("HOME", "/custom/home")
            return std.env.home()
            "#,
        );
        assert_eq!(s, "/custom/home");
    }

    /// When both variables are present in the overlay, HOME wins.
    #[test]
    fn home_prefers_home_over_userprofile() {
        let s: String = eval(
            r#"
            std.env.set("USERPROFILE", "/from/userprofile")
            std.env.set("HOME", "/from/home")
            return std.env.home()
            "#,
        );
        assert_eq!(s, "/from/home");
    }

    #[test]
    fn set_then_get_or_returns_overlay() {
        let s: String = eval(
            r#"
            std.env.set("__MLUA_STD_OVERLAY__", "from_overlay")
            return std.env.get_or("__MLUA_STD_OVERLAY__", "default")
            "#,
        );
        assert_eq!(s, "from_overlay");
    }

    #[test]
    fn home_blocked_when_both_vars_denied() {
        use crate::policy::EnvAllowList;

        let lua = mlua::Lua::new();
        let config = crate::config::Config::builder()
            .env_policy(EnvAllowList::new(["PATH"]))
            .build()
            .unwrap();
        crate::register_all_with(&lua, "std", config).unwrap();

        let result: mlua::Result<mlua::Value> = lua.load(r#"return std.env.home()"#).eval();
        assert!(
            result.is_err(),
            "home() should fail when both HOME and USERPROFILE are denied"
        );
    }

    #[test]
    fn home_allowed_when_home_permitted() {
        use crate::policy::EnvAllowList;

        let lua = mlua::Lua::new();
        let config = crate::config::Config::builder()
            .env_policy(EnvAllowList::new(["HOME", "USERPROFILE"]))
            .build()
            .unwrap();
        crate::register_all_with(&lua, "std", config).unwrap();

        let result: mlua::Result<mlua::Value> = lua.load(r#"return std.env.home()"#).eval();
        assert!(result.is_ok(), "home() should succeed when HOME is allowed");
    }

    #[test]
    fn home_works_with_only_home_allowed() {
        use crate::policy::EnvAllowList;

        let lua = mlua::Lua::new();
        let config = crate::config::Config::builder()
            .env_policy(EnvAllowList::new(["HOME"]))
            .build()
            .unwrap();
        crate::register_all_with(&lua, "std", config).unwrap();

        let result: mlua::Result<mlua::Value> = lua.load(r#"return std.env.home()"#).eval();
        assert!(
            result.is_ok(),
            "home() should succeed when HOME is allowed even if USERPROFILE is denied"
        );
    }
}
269}