Skip to main content

fraiseql_wire/operators/
field.rs

//! Field and value type definitions for operators
//!
//! Provides type-safe representations of database fields and values
//! to prevent SQL injection and improve API ergonomics.
5
6use std::fmt;
7
/// Represents a field reference in a WHERE clause or ORDER BY
///
/// Supports both JSONB payload fields (read from the `data` column) and
/// direct database columns. Use [`Field::validate`] to check names and
/// [`Field::to_sql`] to render the SQL expression.
///
/// # Examples
///
/// ```ignore
/// // JSONB field: (data->'name')
/// Field::JsonbField("name".to_string())
///
/// // Direct column: created_at
/// Field::DirectColumn("created_at".to_string())
///
/// // Nested JSONB: (data->'user'->>'name')
/// Field::JsonbPath(vec!["user".to_string(), "name".to_string()])
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Field {
    /// A top-level key of the JSONB `data` column
    ///
    /// Generated SQL: `(data->'field_name')`. This uses the `->` operator,
    /// which yields `jsonb`, whereas the final step of [`Field::JsonbPath`]
    /// uses `->>`, which yields `text`.
    ///
    /// NOTE(review): earlier docs described this variant as text extraction
    /// (`->>`), but `to_sql` and its unit test emit `->`; confirm which is
    /// intended before relying on text-comparison semantics.
    JsonbField(String),

    /// A direct database column (not from JSONB)
    ///
    /// Uses the native type stored in the database; the name is emitted
    /// verbatim (it is neither quoted nor escaped).
    ///
    /// Generated SQL: `column_name`
    DirectColumn(String),

    /// A nested path within the JSONB `data` column
    ///
    /// The path is traversed left-to-right, with intermediate steps using `->`
    /// (JSON navigation) and the final step using `->>` (text extraction), so
    /// the result is text and is wrapped in parentheses. An empty path renders
    /// as the bare `data` column.
    ///
    /// Generated SQL: `(data->'path[0]'->...->>'path[n]')`
    JsonbPath(Vec<String>),
}

52impl Field {
53    /// Validate field name to prevent SQL injection
54    ///
55    /// Allows: alphanumeric, underscore
56    /// Disallows: quotes, brackets, dashes, special characters
57    pub fn validate(&self) -> Result<(), String> {
58        let name = match self {
59            Field::JsonbField(n) => n,
60            Field::DirectColumn(n) => n,
61            Field::JsonbPath(path) => {
62                for segment in path {
63                    if !is_valid_field_name(segment) {
64                        return Err(format!("Invalid field name in path: {}", segment));
65                    }
66                }
67                return Ok(());
68            }
69        };
70
71        if !is_valid_field_name(name) {
72            return Err(format!("Invalid field name: {}", name));
73        }
74
75        Ok(())
76    }
77
78    /// Generate SQL for this field
79    pub fn to_sql(&self) -> String {
80        match self {
81            Field::JsonbField(name) => format!("(data->'{}')", name),
82            Field::DirectColumn(name) => name.clone(),
83            Field::JsonbPath(path) => {
84                if path.is_empty() {
85                    return "data".to_string();
86                }
87
88                let mut sql = String::from("(data");
89                for (i, segment) in path.iter().enumerate() {
90                    if i == path.len() - 1 {
91                        // Last segment: use ->> for text extraction
92                        sql.push_str(&format!("->>'{}\'", segment));
93                    } else {
94                        // Intermediate segments: use -> for JSON objects
95                        sql.push_str(&format!("->'{}\'", segment));
96                    }
97                }
98                sql.push(')');
99                sql
100            }
101        }
102    }
103}
104
105impl fmt::Display for Field {
106    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107        match self {
108            Field::JsonbField(name) => write!(f, "data->'{}'", name),
109            Field::DirectColumn(name) => write!(f, "{}", name),
110            Field::JsonbPath(path) => {
111                write!(f, "data")?;
112                for (i, segment) in path.iter().enumerate() {
113                    if i == path.len() - 1 {
114                        write!(f, "->>{}", segment)?;
115                    } else {
116                        write!(f, "->{}", segment)?;
117                    }
118                }
119                Ok(())
120            }
121        }
122    }
123}
124
/// Represents a value to bind in a WHERE clause
///
/// # Examples
///
/// ```ignore
/// Value::String("John".to_string())
/// Value::Number(42.0)
/// Value::Bool(true)
/// Value::Null
/// Value::Array(vec![Value::String("a".to_string()), Value::String("b".to_string())])
/// ```
#[derive(Debug, Clone)]
pub enum Value {
    /// String value
    String(String),

    /// Numeric value
    ///
    /// `f64` represents every `f32` and every integer up to 2^53 in magnitude
    /// exactly; larger `i64`/`u64` values are rounded to the nearest `f64`.
    Number(f64),

    /// Boolean value
    Bool(bool),

    /// SQL `NULL`
    Null,

    /// Array of values (for IN operators)
    Array(Vec<Value>),

    /// Vector of floats (for pgvector distance operators)
    FloatArray(Vec<f32>),

    /// Raw SQL expression (use with caution!)
    ///
    /// Emitted verbatim by [`Value::to_sql_literal`]. This should only be used
    /// for trusted SQL fragments, never for user input.
    RawSql(String),
}

impl Value {
    /// Returns `true` if this is [`Value::Null`].
    pub fn is_null(&self) -> bool {
        matches!(self, Value::Null)
    }

    /// Convert value to SQL literal
    ///
    /// For parameterized queries, prefer using parameter placeholders ($1, $2, etc.)
    /// This is primarily for documentation and debugging.
    ///
    /// - Strings are single-quoted with embedded `'` doubled.
    /// - Finite numbers use `f64`'s `Display` output (e.g. `42`, `-1.5`).
    ///   NaN and ±infinity have no bare SQL spelling, so they become the
    ///   PostgreSQL float literals `'NaN'::float8`, `'Infinity'::float8` and
    ///   `'-Infinity'::float8`.
    /// - Booleans render as `true` / `false`, and null as `NULL`.
    /// - Arrays render as `ARRAY[...]`, recursing into elements.
    /// - Float vectors render as `[x, y, ...]` (unquoted).
    /// - Raw SQL is returned verbatim.
    pub fn to_sql_literal(&self) -> String {
        match self {
            Value::String(s) => format!("'{}'", s.replace('\'', "''")),
            Value::Number(n) => Self::number_literal(*n),
            Value::Bool(b) => b.to_string(),
            Value::Null => "NULL".to_string(),
            Value::Array(arr) => {
                // NOTE(review): an empty array renders as `ARRAY[]`, which
                // PostgreSQL rejects unless an explicit type cast is added.
                let items: Vec<String> = arr.iter().map(Value::to_sql_literal).collect();
                format!("ARRAY[{}]", items.join(", "))
            }
            Value::FloatArray(arr) => {
                // NOTE(review): pgvector's text input form is a quoted string
                // such as '[1,2,3]'. This unquoted rendering presumably relies
                // on the caller adding quotes/casts; confirm at call sites.
                let items: Vec<String> = arr.iter().map(f32::to_string).collect();
                format!("[{}]", items.join(", "))
            }
            Value::RawSql(sql) => sql.clone(),
        }
    }

    /// Render an `f64` as a SQL numeric literal.
    ///
    /// Non-finite values are mapped to PostgreSQL's quoted float spellings,
    /// because a bare `NaN` / `inf` would be parsed as an identifier.
    fn number_literal(n: f64) -> String {
        if n.is_nan() {
            "'NaN'::float8".to_string()
        } else if n == f64::INFINITY {
            "'Infinity'::float8".to_string()
        } else if n == f64::NEG_INFINITY {
            "'-Infinity'::float8".to_string()
        } else {
            n.to_string()
        }
    }
}

/// Displays the value as its SQL literal (see [`Value::to_sql_literal`]).
impl fmt::Display for Value {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.to_sql_literal())
    }
}

/// Returns `true` if `name` is an acceptable field or column name.
///
/// A valid name is non-empty and starts with an alphabetic character or `_`.
/// Every following character must be alphanumeric or `_`. The checks use
/// Rust's Unicode-aware [`char::is_alphabetic`] / [`char::is_alphanumeric`],
/// so non-ASCII letters and digits are accepted too.
fn is_valid_field_name(name: &str) -> bool {
    let mut chars = name.chars();
    match chars.next() {
        None => false,
        Some(head) => {
            (head == '_' || head.is_alphabetic())
                && chars.all(|c| c == '_' || c.is_alphanumeric())
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_valid_field_names() {
        assert!(is_valid_field_name("name"));
        assert!(is_valid_field_name("_private"));
        assert!(is_valid_field_name("field_123"));
        assert!(is_valid_field_name("a"));
    }

    #[test]
    fn test_invalid_field_names() {
        assert!(!is_valid_field_name(""));
        assert!(!is_valid_field_name("123field")); // starts with digit
        assert!(!is_valid_field_name("field-name")); // contains dash
        assert!(!is_valid_field_name("field.name")); // contains dot
        assert!(!is_valid_field_name("field'name")); // contains quote
    }

    #[test]
    fn test_field_validation() {
        assert!(Field::JsonbField("name".to_string()).validate().is_ok());
        assert!(Field::JsonbField("name-invalid".to_string())
            .validate()
            .is_err());
        assert!(
            Field::JsonbPath(vec!["user".to_string(), "name".to_string()])
                .validate()
                .is_ok()
        );
    }

    #[test]
    fn test_field_validation_direct_column_and_path_segments() {
        assert!(Field::DirectColumn("created_at".to_string())
            .validate()
            .is_ok());
        assert!(Field::DirectColumn("id; DROP TABLE users".to_string())
            .validate()
            .is_err());
        // A single bad segment anywhere in the path rejects the whole path.
        assert!(
            Field::JsonbPath(vec!["user".to_string(), "na'me".to_string()])
                .validate()
                .is_err()
        );
    }

    #[test]
    fn test_field_to_sql_jsonb() {
        let field = Field::JsonbField("name".to_string());
        assert_eq!(field.to_sql(), "(data->'name')");
    }

    #[test]
    fn test_field_to_sql_direct() {
        let field = Field::DirectColumn("created_at".to_string());
        assert_eq!(field.to_sql(), "created_at");
    }

    #[test]
    fn test_field_to_sql_path() {
        let field = Field::JsonbPath(vec!["user".to_string(), "name".to_string()]);
        assert_eq!(field.to_sql(), "(data->'user'->>'name')");
    }

    #[test]
    fn test_field_to_sql_path_edge_cases() {
        // A single-segment path still uses text extraction.
        assert_eq!(
            Field::JsonbPath(vec!["name".to_string()]).to_sql(),
            "(data->>'name')"
        );
        // An empty path refers to the whole `data` column.
        assert_eq!(Field::JsonbPath(Vec::new()).to_sql(), "data");
    }

    #[test]
    fn test_field_display() {
        assert_eq!(
            Field::JsonbField("name".to_string()).to_string(),
            "data->'name'"
        );
        assert_eq!(
            Field::DirectColumn("created_at".to_string()).to_string(),
            "created_at"
        );
        assert_eq!(
            Field::JsonbPath(vec!["user".to_string(), "name".to_string()]).to_string(),
            "data->user->>name"
        );
    }

    #[test]
    fn test_value_to_sql_literal() {
        assert_eq!(Value::String("test".to_string()).to_sql_literal(), "'test'");
        assert_eq!(Value::Number(42.0).to_sql_literal(), "42");
        assert_eq!(Value::Bool(true).to_sql_literal(), "true");
        assert_eq!(Value::Null.to_sql_literal(), "NULL");
    }

    #[test]
    fn test_value_is_null() {
        assert!(Value::Null.is_null());
        assert!(!Value::Bool(false).is_null());
        assert!(!Value::String(String::new()).is_null());
    }

    #[test]
    fn test_value_string_escaping() {
        let val = Value::String("O'Brien".to_string());
        assert_eq!(val.to_sql_literal(), "'O''Brien'");
    }

    #[test]
    fn test_value_collection_and_raw_literals() {
        let arr = Value::Array(vec![
            Value::Number(1.0),
            Value::String("a".to_string()),
            Value::Null,
        ]);
        assert_eq!(arr.to_sql_literal(), "ARRAY[1, 'a', NULL]");
        assert_eq!(
            Value::FloatArray(vec![0.5, 1.0]).to_sql_literal(),
            "[0.5, 1]"
        );
        assert_eq!(Value::RawSql("now()".to_string()).to_sql_literal(), "now()");
        // Display delegates to the SQL literal form.
        assert_eq!(Value::Bool(false).to_string(), "false");
    }
}