1pub fn resolve_vars(
21 input: &str,
22 params: Option<&std::collections::HashMap<String, String>>,
23) -> crate::error::Result<String> {
24 let mut result = input.to_string();
25 let mut search_from = 0;
26 while let Some(rel_start) = result[search_from..].find("${") {
27 let start = search_from + rel_start;
28 let Some(rel_end) = result[start..].find('}') else {
29 break;
30 };
31 let end = start + rel_end;
32 let var_name = &result[start + 2..end];
33
34 let value = if var_name.is_empty() {
35 String::new()
38 } else if let Some(v) = params.and_then(|p| p.get(var_name)) {
39 v.clone()
40 } else {
41 match std::env::var(var_name) {
42 Ok(v) => v,
43 Err(_) => anyhow::bail!(
44 "environment variable '{}' referenced in config is not set \
45 (a missing secret silently becomes an empty string — refusing)",
46 var_name
47 ),
48 }
49 };
50
51 if value.contains('\0') {
63 anyhow::bail!(
64 "value for '${{{var_name}}}' contains a NUL byte; refusing to substitute it \
65 (check the parameter/environment source)"
66 );
67 }
68
69 result = format!("{}{}{}", &result[..start], value, &result[end + 1..]);
70 search_from = start + value.len();
71 }
72 Ok(result)
73}
74
75pub fn resolve_env_vars(input: &str) -> crate::error::Result<String> {
77 resolve_vars(input, None)
78}
79
/// Returns the `params` keys that never appear as a `${KEY}` placeholder in `haystack`.
///
/// Matching is exact and case-sensitive on the full placeholder text, closing
/// brace included. A key `max` is therefore *not* considered used by `${max_id}`.
/// The result is sorted ascending. `None` params yield an empty list.
pub fn find_unused_params(
    haystack: &str,
    params: Option<&std::collections::HashMap<String, String>>,
) -> Vec<String> {
    let mut missing = Vec::new();
    if let Some(map) = params {
        for name in map.keys() {
            let placeholder = format!("${{{}}}", name);
            if !haystack.contains(placeholder.as_str()) {
                missing.push(name.clone());
            }
        }
    }
    // Map keys are unique, so an unstable sort yields the same order as a stable one.
    missing.sort_unstable();
    missing
}
100
101pub fn warn_unused_params(
112 haystack: &str,
113 params: Option<&std::collections::HashMap<String, String>>,
114) {
115 for key in find_unused_params(haystack, params) {
116 log::warn!(
117 "--param '{}' was not referenced by any `${{{}}}` placeholder in the config — \
118 check the parameter name (case-sensitive) or remove the unused --param",
119 key,
120 key
121 );
122 }
123}
124
125pub fn parse_file_size(s: &str) -> crate::error::Result<u64> {
131 let s = s.trim().to_uppercase();
132 let (num, multiplier) = if let Some(n) = s.strip_suffix("GB") {
133 (n.trim(), 1024u64 * 1024 * 1024)
134 } else if let Some(n) = s.strip_suffix("MB") {
135 (n.trim(), 1024u64 * 1024)
136 } else if let Some(n) = s.strip_suffix("KB") {
137 (n.trim(), 1024u64)
138 } else if let Some(n) = s.strip_suffix('B') {
139 (n.trim(), 1u64)
140 } else {
141 (s.as_str(), 1u64)
142 };
143 let value: f64 = num.parse().map_err(|_| {
144 anyhow::anyhow!(
145 "invalid file size: '{}' — expected a number with an optional unit \
146 B/KB/MB/GB (e.g. '512MB', '1.5GB', or a bare byte count like '1048576'); \
147 a fractional value is allowed and units are binary (KB = 1024 bytes)",
148 s
149 )
150 })?;
151 Ok((value * multiplier as f64) as u64)
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Sets an environment variable for the guard's lifetime and removes it on drop.
    /// The variable is therefore cleaned up even when the test panics.
    struct EnvVarGuard(&'static str);

    impl EnvVarGuard {
        fn set(name: &'static str, value: &str) -> Self {
            // SAFETY: `set_var` is unsafe because another thread could read the
            // environment through non-Rust code (e.g. libc `getenv`) at the same
            // time. Every test uses a variable name unique to itself. This assumes
            // no test in this binary reads the environment via FFI concurrently;
            // revisit if such tests are added.
            unsafe { std::env::set_var(name, value) };
            Self(name)
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: same reasoning as in `EnvVarGuard::set`.
            unsafe { std::env::remove_var(self.0) };
        }
    }

    #[test]
    fn no_placeholders_returned_verbatim() {
        assert_eq!(resolve_vars("SELECT 1", None).unwrap(), "SELECT 1");
    }

    #[test]
    fn empty_string_returned_verbatim() {
        assert_eq!(resolve_vars("", None).unwrap(), "");
    }

    #[test]
    fn param_substitutes_placeholder() {
        let mut p = HashMap::new();
        p.insert("TABLE".into(), "orders".into());
        let result = resolve_vars("SELECT * FROM ${TABLE}", Some(&p)).unwrap();
        assert_eq!(result, "SELECT * FROM orders");
    }

    #[test]
    fn param_takes_precedence_over_env() {
        let _env = EnvVarGuard::set("RIVET_TEST_OVERRIDE_VAR", "from_env");
        let mut p = HashMap::new();
        p.insert("RIVET_TEST_OVERRIDE_VAR".into(), "from_param".into());
        let result = resolve_vars("${RIVET_TEST_OVERRIDE_VAR}", Some(&p)).unwrap();
        assert_eq!(result, "from_param");
    }

    #[test]
    fn multiple_placeholders_all_substituted() {
        let mut p = HashMap::new();
        p.insert("A".into(), "hello".into());
        p.insert("B".into(), "world".into());
        let result = resolve_vars("${A} ${B}", Some(&p)).unwrap();
        assert_eq!(result, "hello world");
    }

    #[test]
    fn sec_param_value_with_nul_rejected() {
        let mut p = HashMap::new();
        p.insert("x".into(), "1\0injected".into());
        let err = resolve_vars("${x}", Some(&p)).expect_err("a NUL value must be rejected");
        assert!(err.to_string().contains("x"), "must name the param: {err}");
        assert!(
            !err.to_string().contains("injected"),
            "must not echo the value: {err}"
        );
    }

    #[test]
    fn sec_param_value_newline_passes_through_guard() {
        let mut p = HashMap::new();
        p.insert("frag".into(), "a\nb".into());
        let result = resolve_vars("X=${frag}", Some(&p)).unwrap();
        assert_eq!(result, "X=a\nb");
    }

    #[test]
    fn sec_normal_param_value_substitutes_fine_guard() {
        let mut p = HashMap::new();
        p.insert("id_min".into(), "100".into());
        let result = resolve_vars("WHERE id >= ${id_min}", Some(&p)).unwrap();
        assert_eq!(result, "WHERE id >= 100");
    }

    #[test]
    fn sec_normal_param_value_with_spaces_and_quotes_substitutes_fine_guard() {
        let mut p = HashMap::new();
        p.insert("filter".into(), "name = 'o''brien'".into());
        let result = resolve_vars("WHERE ${filter}", Some(&p)).unwrap();
        assert_eq!(result, "WHERE name = 'o''brien'");
    }

    #[test]
    fn env_var_substituted_when_set() {
        let _env = EnvVarGuard::set("RIVET_TEST_RESOLVE_VAR", "secret123");
        let result = resolve_vars("pass=${RIVET_TEST_RESOLVE_VAR}", None).unwrap();
        assert_eq!(result, "pass=secret123");
    }

    #[test]
    fn missing_env_var_returns_error() {
        // SAFETY: see `EnvVarGuard::set`. This variable name is unique to this test.
        unsafe { std::env::remove_var("RIVET_DEFINITELY_NOT_SET_VAR_XYZ") };
        let err = resolve_vars("${RIVET_DEFINITELY_NOT_SET_VAR_XYZ}", None).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("RIVET_DEFINITELY_NOT_SET_VAR_XYZ"),
            "got: {msg}"
        );
    }

    #[test]
    fn empty_placeholder_expands_to_empty_string() {
        let result = resolve_vars("pre${}post", None).unwrap();
        assert_eq!(result, "prepost");
    }

    #[test]
    fn unclosed_placeholder_left_as_is() {
        let result = resolve_vars("${UNCLOSED", None).unwrap();
        assert_eq!(result, "${UNCLOSED");
    }

    #[test]
    fn find_unused_params_returns_empty_when_no_params() {
        assert!(find_unused_params("SELECT 1", None).is_empty());
    }

    #[test]
    fn find_unused_params_used_key_not_flagged() {
        let mut p = HashMap::new();
        p.insert("max_id".into(), "20".into());
        let unused = find_unused_params("SELECT * FROM t WHERE id <= ${max_id}", Some(&p));
        assert!(unused.is_empty(), "got: {unused:?}");
    }

    #[test]
    fn find_unused_params_unused_key_flagged_once() {
        let mut p = HashMap::new();
        p.insert("typo_id".into(), "20".into());
        let unused = find_unused_params("SELECT * FROM t WHERE id <= ${max_id}", Some(&p));
        assert_eq!(unused, vec!["typo_id".to_string()]);
    }

    #[test]
    fn find_unused_params_mixed_used_and_unused() {
        let mut p = HashMap::new();
        p.insert("col".into(), "id".into());
        p.insert("typo".into(), "x".into());
        let unused = find_unused_params("SELECT ${col} FROM t", Some(&p));
        assert_eq!(unused, vec!["typo".to_string()]);
    }

    #[test]
    fn find_unused_params_partial_match_does_not_count() {
        let mut p = HashMap::new();
        // `${max_id}` must not count as a use of `max`.
        p.insert("max".into(), "20".into());
        let unused = find_unused_params("SELECT * FROM t WHERE id <= ${max_id}", Some(&p));
        assert_eq!(unused, vec!["max".to_string()]);
    }

    #[test]
    fn resolve_env_vars_reads_env() {
        let _env = EnvVarGuard::set("RIVET_TEST_ENV_WRAPPER", "wrapped");
        let result = resolve_env_vars("v=${RIVET_TEST_ENV_WRAPPER}").unwrap();
        assert_eq!(result, "v=wrapped");
    }

    #[test]
    fn parse_1gb() {
        assert_eq!(parse_file_size("1GB").unwrap(), 1024 * 1024 * 1024);
    }

    #[test]
    fn parse_512mb() {
        assert_eq!(parse_file_size("512MB").unwrap(), 512 * 1024 * 1024);
    }

    #[test]
    fn parse_100kb() {
        assert_eq!(parse_file_size("100KB").unwrap(), 100 * 1024);
    }

    #[test]
    fn parse_bytes_suffix() {
        assert_eq!(parse_file_size("2048B").unwrap(), 2048);
    }

    #[test]
    fn parse_no_suffix_treated_as_bytes() {
        assert_eq!(parse_file_size("4096").unwrap(), 4096);
    }

    #[test]
    fn parse_whitespace_trimmed() {
        assert_eq!(parse_file_size(" 256MB ").unwrap(), 256 * 1024 * 1024);
    }

    #[test]
    fn parse_lowercase_accepted() {
        assert_eq!(parse_file_size("1gb").unwrap(), 1024 * 1024 * 1024);
    }

    #[test]
    fn parse_invalid_returns_error() {
        assert!(parse_file_size("notanumber").is_err());
    }

    #[test]
    fn parse_invalid_error_names_accepted_units() {
        let err = parse_file_size("banana").unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("B/KB/MB/GB"), "got: {msg}");
        assert!(msg.contains("fractional"), "got: {msg}");
        assert!(msg.contains("1024"), "got: {msg}");
    }
}