Skip to main content

ferrule_core/
params.rs

1use ferrule_sql::render_value;
2use ferrule_sql::value::Value;
3use ferrule_sql::{Backend, SqlError};
4use indexmap::IndexMap;
5use std::str::FromStr;
6
7/// A named set of runtime parameters.
8#[derive(Debug, Clone, Default, PartialEq)]
9pub struct ParameterSet {
10    pub map: IndexMap<String, Value>,
11}
12
13impl ParameterSet {
14    /// Insert or overwrite a parameter.
15    pub fn set(&mut self, name: String, value: Value) {
16        self.map.insert(name, value);
17    }
18
19    /// Remove a parameter.
20    pub fn clear(&mut self) {
21        self.map.clear();
22    }
23}
24
25/// Substitute `${name}` placeholders in SQL with values from `params`.
26///
27/// One pass only — no recursive substitution.
28/// Missing parameters return `SqlError::QueryFailed`.
29pub fn substitute(sql: &str, params: &ParameterSet, backend: Backend) -> Result<String, SqlError> {
30    let mut result = String::with_capacity(sql.len());
31    let mut chars = sql.chars().peekable();
32
33    while let Some(ch) = chars.next() {
34        if ch == '$' && chars.next_if_eq(&'{').is_some() {
35            let name: String = chars.by_ref().take_while(|c| *c != '}').collect();
36            if name.is_empty() {
37                result.push_str("${}");
38                continue;
39            }
40            match params.map.get(&name) {
41                Some(value) => result.push_str(&render_value(value, backend)),
42                None => {
43                    return Err(SqlError::QueryFailed(format!(
44                        "Missing parameter: {}",
45                        name
46                    )));
47                }
48            }
49        } else {
50            result.push(ch);
51        }
52    }
53
54    Ok(result)
55}
56
57/// Parse a `NAME=VALUE` string, splitting at the first `=`.
58pub fn parse_param(s: &str) -> Result<(String, String), SqlError> {
59    let pos = s.find('=').ok_or_else(|| {
60        SqlError::QueryFailed(format!(
61            "Invalid parameter format '{}', expected NAME=VALUE",
62            s
63        ))
64    })?;
65    let (name, value) = s.split_at(pos);
66    Ok((name.to_string(), value[1..].to_string()))
67}
68
69/// Infer a `Value` type from a raw string.
70///
71/// * `"true"` / `"false"` (case‑insensitive) → `Value::Bool`
72/// * `"-42"` → `Value::Int64`
73/// * `"3.14"` → `Value::Float64`
74/// * anything else → `Value::String`
75pub fn infer_type(v: &str) -> Value {
76    let trimmed = v.trim();
77    if trimmed.eq_ignore_ascii_case("true") {
78        return Value::Bool(true);
79    }
80    if trimmed.eq_ignore_ascii_case("false") {
81        return Value::Bool(false);
82    }
83    if trimmed.bytes().all(|b| b.is_ascii_digit() || b == b'-') && trimmed.len() > 1
84        || trimmed.bytes().all(|b| b.is_ascii_digit()) && !trimmed.is_empty()
85    {
86        if let Ok(i) = i64::from_str(trimmed) {
87            return Value::Int64(i);
88        }
89    }
90    if trimmed.bytes().filter(|b| *b == b'.').count() == 1 {
91        if let Ok(f) = f64::from_str(trimmed) {
92            return Value::Float64(f);
93        }
94    }
95    Value::String(trimmed.to_string())
96}
97
98/// Load parameters from a JSON file (object mapping name → raw string value).
99pub fn load_from_json(path: &std::path::Path) -> Result<ParameterSet, SqlError> {
100    let content = std::fs::read_to_string(path).map_err(|e| {
101        SqlError::QueryFailed(format!(
102            "Cannot read parameter file '{}': {}",
103            path.display(),
104            e
105        ))
106    })?;
107    let obj: serde_json::Map<String, serde_json::Value> =
108        serde_json::from_str(&content).map_err(|e| {
109            SqlError::QueryFailed(format!(
110                "Invalid JSON in parameter file '{}': {}",
111                path.display(),
112                e
113            ))
114        })?;
115
116    let mut set = ParameterSet::default();
117    for (key, val) in obj {
118        let value = match val {
119            serde_json::Value::Bool(b) => Value::Bool(b),
120            serde_json::Value::Number(n) => {
121                if let Some(i) = n.as_i64() {
122                    Value::Int64(i)
123                } else if let Some(f) = n.as_f64() {
124                    if f.fract() == 0.0 && f >= i64::MIN as f64 && f <= i64::MAX as f64 {
125                        Value::Int64(f as i64)
126                    } else {
127                        Value::Float64(f)
128                    }
129                } else {
130                    Value::String(n.to_string())
131                }
132            }
133            serde_json::Value::String(s) => infer_type(&s),
134            other => Value::String(other.to_string()),
135        };
136        set.set(key, value);
137    }
138    Ok(set)
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    /// Per-process path for a throwaway JSON fixture, so concurrent test runs
    /// sharing a temp dir don't overwrite or delete each other's files.
    fn temp_json_path(stem: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("{}_{}.json", stem, std::process::id()))
    }

    #[test]
    fn test_infer_string() {
        assert_eq!(infer_type("Alice"), Value::String("Alice".into()));
    }

    #[test]
    fn test_infer_string_edge_cases() {
        // Surrounding whitespace is trimmed from the stored string too.
        assert_eq!(infer_type("  Alice "), Value::String("Alice".into()));
        // Only `-?digits` counts as an integer.
        assert_eq!(infer_type("-"), Value::String("-".into()));
        assert_eq!(infer_type("--5"), Value::String("--5".into()));
        assert_eq!(infer_type("+42"), Value::String("+42".into()));
        // Floats need exactly one '.'.
        assert_eq!(infer_type("1.2.3"), Value::String("1.2.3".into()));
        assert_eq!(infer_type("1e5"), Value::String("1e5".into()));
        // Digit strings that overflow i64 are kept as text.
        assert_eq!(
            infer_type("99999999999999999999"),
            Value::String("99999999999999999999".into())
        );
    }

    #[test]
    fn test_infer_int() {
        assert_eq!(infer_type("-42"), Value::Int64(-42));
        assert_eq!(infer_type("0"), Value::Int64(0));
        assert_eq!(infer_type(" 7 "), Value::Int64(7));
    }

    #[test]
    fn test_infer_float() {
        assert_eq!(infer_type("2.5"), Value::Float64(2.5));
        assert_eq!(infer_type(".5"), Value::Float64(0.5));
    }

    #[test]
    fn test_infer_bool() {
        assert_eq!(infer_type("false"), Value::Bool(false));
        assert_eq!(infer_type("true"), Value::Bool(true));
        assert_eq!(infer_type("TRUE"), Value::Bool(true));
        assert_eq!(infer_type("FALSE"), Value::Bool(false));
    }

    #[test]
    fn test_parse_param() {
        assert_eq!(
            parse_param("name=Alice").unwrap(),
            ("name".into(), "Alice".into())
        );
        assert_eq!(
            parse_param("host=localhost:5432").unwrap(),
            ("host".into(), "localhost:5432".into())
        );
        // Only the first '=' splits; later ones belong to the value.
        assert_eq!(
            parse_param("expr=a=b").unwrap(),
            ("expr".into(), "a=b".into())
        );
        assert_eq!(parse_param("flag=").unwrap(), ("flag".into(), "".into()));
        assert!(parse_param("no_equals").is_err());
    }

    #[cfg(feature = "postgres")]
    #[test]
    fn test_substitute_completes() {
        let mut params = ParameterSet::default();
        params.set("name".into(), Value::String("Alice".into()));
        params.set("age".into(), Value::Int64(30));
        let sql = substitute(
            "SELECT * FROM t WHERE n = ${name} AND a = ${age}",
            &params,
            Backend::Postgres,
        )
        .unwrap();
        assert_eq!(sql, "SELECT * FROM t WHERE n = 'Alice' AND a = 30");
    }

    #[cfg(feature = "postgres")]
    #[test]
    fn test_substitute_missing_errors() {
        let params = ParameterSet::default();
        let result = substitute("SELECT ${x}", &params, Backend::Postgres);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("Missing parameter: x"));
    }

    #[cfg(feature = "postgres")]
    #[test]
    fn test_substitute_passthrough() {
        let mut params = ParameterSet::default();
        params.set("name".into(), Value::String("Alice".into()));
        // An empty placeholder is copied through untouched.
        let sql = substitute("SELECT ${}", &params, Backend::Postgres).unwrap();
        assert_eq!(sql, "SELECT ${}");
        // A '$' that does not start a placeholder is kept.
        let sql = substitute("SELECT $${name}", &params, Backend::Postgres).unwrap();
        assert_eq!(sql, "SELECT $'Alice'");
    }

    #[cfg(feature = "oracle")]
    #[test]
    fn test_substitute_oracle_bool() {
        let mut params = ParameterSet::default();
        params.set("active".into(), Value::Bool(true));
        let sql = substitute("SELECT ${active}", &params, Backend::Oracle).unwrap();
        assert_eq!(sql, "SELECT 1");
    }

    #[cfg(feature = "postgres")]
    #[test]
    fn test_substitute_postgres_bool() {
        let mut params = ParameterSet::default();
        params.set("active".into(), Value::Bool(false));
        let sql = substitute("SELECT ${active}", &params, Backend::Postgres).unwrap();
        assert_eq!(sql, "SELECT FALSE");
    }

    #[cfg(feature = "postgres")]
    #[test]
    fn test_substitute_no_recursive() {
        let mut params = ParameterSet::default();
        params.set("foo".into(), Value::String("${bar}".into()));
        let sql = substitute("SELECT ${foo}", &params, Backend::Postgres).unwrap();
        assert_eq!(sql, "SELECT '${bar}'");
    }

    #[test]
    fn test_load_from_json() {
        let path = temp_json_path("ferrule_test_params");
        std::fs::write(
            &path,
            r#"{"name":"Alice","age":30,"active":true,"score":99.5,"whole":30.0,"quoted":"42","nothing":null}"#,
        )
        .unwrap();
        let loaded = load_from_json(&path);
        // Remove the fixture before asserting so a failed assertion doesn't leak it.
        std::fs::remove_file(&path).ok();
        let set = loaded.unwrap();
        assert_eq!(set.map.get("name"), Some(&Value::String("Alice".into())));
        assert_eq!(set.map.get("age"), Some(&Value::Int64(30)));
        assert_eq!(set.map.get("active"), Some(&Value::Bool(true)));
        assert_eq!(set.map.get("score"), Some(&Value::Float64(99.5)));
        // Integral floats are stored as integers.
        assert_eq!(set.map.get("whole"), Some(&Value::Int64(30)));
        // JSON strings go through `infer_type`.
        assert_eq!(set.map.get("quoted"), Some(&Value::Int64(42)));
        // Other JSON values are kept as their JSON text.
        assert_eq!(set.map.get("nothing"), Some(&Value::String("null".into())));
    }

    #[test]
    fn test_load_from_json_errors() {
        let missing = temp_json_path("ferrule_test_params_missing");
        std::fs::remove_file(&missing).ok();
        let err = load_from_json(&missing).unwrap_err().to_string();
        assert!(err.contains("Cannot read parameter file"));

        let not_object = temp_json_path("ferrule_test_params_array");
        std::fs::write(&not_object, "[1, 2]").unwrap();
        let loaded = load_from_json(&not_object);
        std::fs::remove_file(&not_object).ok();
        let err = loaded.unwrap_err().to_string();
        assert!(err.contains("Invalid JSON in parameter file"));
    }
}