1use ferrule_sql::render_value;
2use ferrule_sql::value::Value;
3use ferrule_sql::{Backend, SqlError};
4use indexmap::IndexMap;
5use std::str::FromStr;
6
7#[derive(Debug, Clone, Default, PartialEq)]
9pub struct ParameterSet {
10 pub map: IndexMap<String, Value>,
11}
12
13impl ParameterSet {
14 pub fn set(&mut self, name: String, value: Value) {
16 self.map.insert(name, value);
17 }
18
19 pub fn clear(&mut self) {
21 self.map.clear();
22 }
23}
24
25pub fn substitute(sql: &str, params: &ParameterSet, backend: Backend) -> Result<String, SqlError> {
30 let mut result = String::with_capacity(sql.len());
31 let mut chars = sql.chars().peekable();
32
33 while let Some(ch) = chars.next() {
34 if ch == '$' && chars.next_if_eq(&'{').is_some() {
35 let name: String = chars.by_ref().take_while(|c| *c != '}').collect();
36 if name.is_empty() {
37 result.push_str("${}");
38 continue;
39 }
40 match params.map.get(&name) {
41 Some(value) => result.push_str(&render_value(value, backend)),
42 None => {
43 return Err(SqlError::QueryFailed(format!(
44 "Missing parameter: {}",
45 name
46 )));
47 }
48 }
49 } else {
50 result.push(ch);
51 }
52 }
53
54 Ok(result)
55}
56
57pub fn parse_param(s: &str) -> Result<(String, String), SqlError> {
59 let pos = s.find('=').ok_or_else(|| {
60 SqlError::QueryFailed(format!(
61 "Invalid parameter format '{}', expected NAME=VALUE",
62 s
63 ))
64 })?;
65 let (name, value) = s.split_at(pos);
66 Ok((name.to_string(), value[1..].to_string()))
67}
68
69pub fn infer_type(v: &str) -> Value {
76 let trimmed = v.trim();
77 if trimmed.eq_ignore_ascii_case("true") {
78 return Value::Bool(true);
79 }
80 if trimmed.eq_ignore_ascii_case("false") {
81 return Value::Bool(false);
82 }
83 if trimmed.bytes().all(|b| b.is_ascii_digit() || b == b'-') && trimmed.len() > 1
84 || trimmed.bytes().all(|b| b.is_ascii_digit()) && !trimmed.is_empty()
85 {
86 if let Ok(i) = i64::from_str(trimmed) {
87 return Value::Int64(i);
88 }
89 }
90 if trimmed.bytes().filter(|b| *b == b'.').count() == 1 {
91 if let Ok(f) = f64::from_str(trimmed) {
92 return Value::Float64(f);
93 }
94 }
95 Value::String(trimmed.to_string())
96}
97
98pub fn load_from_json(path: &std::path::Path) -> Result<ParameterSet, SqlError> {
100 let content = std::fs::read_to_string(path).map_err(|e| {
101 SqlError::QueryFailed(format!(
102 "Cannot read parameter file '{}': {}",
103 path.display(),
104 e
105 ))
106 })?;
107 let obj: serde_json::Map<String, serde_json::Value> =
108 serde_json::from_str(&content).map_err(|e| {
109 SqlError::QueryFailed(format!(
110 "Invalid JSON in parameter file '{}': {}",
111 path.display(),
112 e
113 ))
114 })?;
115
116 let mut set = ParameterSet::default();
117 for (key, val) in obj {
118 let value = match val {
119 serde_json::Value::Bool(b) => Value::Bool(b),
120 serde_json::Value::Number(n) => {
121 if let Some(i) = n.as_i64() {
122 Value::Int64(i)
123 } else if let Some(f) = n.as_f64() {
124 if f.fract() == 0.0 && f >= i64::MIN as f64 && f <= i64::MAX as f64 {
125 Value::Int64(f as i64)
126 } else {
127 Value::Float64(f)
128 }
129 } else {
130 Value::String(n.to_string())
131 }
132 }
133 serde_json::Value::String(s) => infer_type(&s),
134 other => Value::String(other.to_string()),
135 };
136 set.set(key, value);
137 }
138 Ok(set)
139}
140
// Unit tests for type inference, argument parsing, placeholder substitution and JSON
// loading. Substitution tests are gated on backend features because `render_value`
// output is backend-specific.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_infer_string() {
        assert_eq!(infer_type("Alice"), Value::String("Alice".into()));
    }

    #[test]
    fn test_infer_int() {
        assert_eq!(infer_type("-42"), Value::Int64(-42));
        assert_eq!(infer_type("0"), Value::Int64(0));
    }

    #[test]
    fn test_infer_float() {
        assert_eq!(infer_type("2.5"), Value::Float64(2.5));
    }

    #[test]
    fn test_infer_bool() {
        assert_eq!(infer_type("false"), Value::Bool(false));
        assert_eq!(infer_type("true"), Value::Bool(true));
        assert_eq!(infer_type("TRUE"), Value::Bool(true));
        assert_eq!(infer_type("FALSE"), Value::Bool(false));
    }

    #[test]
    fn test_parse_param() {
        assert_eq!(
            parse_param("name=Alice").unwrap(),
            ("name".into(), "Alice".into())
        );
        assert_eq!(
            parse_param("host=localhost:5432").unwrap(),
            ("host".into(), "localhost:5432".into())
        );
        assert!(parse_param("no_equals").is_err());
    }

    #[cfg(feature = "postgres")]
    #[test]
    fn test_substitute_completes() {
        let mut params = ParameterSet::default();
        params.set("name".into(), Value::String("Alice".into()));
        params.set("age".into(), Value::Int64(30));
        let sql = substitute(
            "SELECT * FROM t WHERE n = ${name} AND a = ${age}",
            &params,
            Backend::Postgres,
        )
        .unwrap();
        assert_eq!(sql, "SELECT * FROM t WHERE n = 'Alice' AND a = 30");
    }

    #[cfg(feature = "postgres")]
    #[test]
    fn test_substitute_missing_errors() {
        let params = ParameterSet::default();
        let result = substitute("SELECT ${x}", &params, Backend::Postgres);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("Missing parameter: x"));
    }

    #[cfg(feature = "oracle")]
    #[test]
    fn test_substitute_oracle_bool() {
        let mut params = ParameterSet::default();
        params.set("active".into(), Value::Bool(true));
        let sql = substitute("SELECT ${active}", &params, Backend::Oracle).unwrap();
        assert_eq!(sql, "SELECT 1");
    }

    #[cfg(feature = "postgres")]
    #[test]
    fn test_substitute_postgres_bool() {
        let mut params = ParameterSet::default();
        params.set("active".into(), Value::Bool(false));
        let sql = substitute("SELECT ${active}", &params, Backend::Postgres).unwrap();
        assert_eq!(sql, "SELECT FALSE");
    }

    #[cfg(feature = "postgres")]
    #[test]
    fn test_substitute_no_recursive() {
        let mut params = ParameterSet::default();
        params.set("foo".into(), Value::String("${bar}".into()));
        let sql = substitute("SELECT ${foo}", &params, Backend::Postgres).unwrap();
        assert_eq!(sql, "SELECT '${bar}'");
    }

    #[test]
    fn test_load_from_json() {
        // Include the process id so concurrent test runs do not write, read and delete the
        // same file.
        let path = std::env::temp_dir().join(format!(
            "ferrule_test_params_{}.json",
            std::process::id()
        ));
        std::fs::write(
            &path,
            r#"{"name":"Alice","age":30,"active":true,"score":99.5}"#,
        )
        .unwrap();
        let set = load_from_json(&path).unwrap();
        assert_eq!(set.map.get("name"), Some(&Value::String("Alice".into())));
        assert_eq!(set.map.get("age"), Some(&Value::Int64(30)));
        assert_eq!(set.map.get("active"), Some(&Value::Bool(true)));
        assert_eq!(set.map.get("score"), Some(&Value::Float64(99.5)));
        std::fs::remove_file(&path).ok();
    }
}