1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
/// Where the SQL text for a query comes from.
#[derive(Clone, Debug)]
pub enum SqlSource {
    /// The SQL string was written directly in the attribute.
    Inline(String),
    /// A path to a file containing the SQL.
    File(String),
}
impl SqlSource {
    /// Returns the SQL text for this source, reading the file for [`SqlSource::File`].
    ///
    /// The result is always SQL text, never a file path.
    ///
    /// # Errors
    ///
    /// For [`SqlSource::File`], returns an error if the path fails validation
    /// or the file cannot be read.
    pub fn resolve_content(self) -> syn::Result<String> {
        match self {
            SqlSource::Inline(sql) => Ok(sql),
            SqlSource::File(file_path) => Self::read_file_content(&file_path),
        }
    }

    /// Validates `file_path` and reads the file as UTF-8 text.
    ///
    /// When `CARGO_MANIFEST_DIR` is set, the path is resolved against it first.
    /// The bare relative path is tried only if that file does not exist. Any
    /// other failure (e.g. permission denied, invalid UTF-8) is reported as-is,
    /// so a same-named file elsewhere is never silently read instead.
    ///
    /// # Errors
    ///
    /// Returns an error if validation fails or the file cannot be read.
    fn read_file_content(file_path: &str) -> syn::Result<String> {
        // Security validation: prevent path traversal attacks.
        Self::validate_file_path(file_path)?;

        let read_result = match std::env::var("CARGO_MANIFEST_DIR") {
            Ok(manifest_dir) => {
                let full_path = std::path::Path::new(&manifest_dir).join(file_path);
                match std::fs::read_to_string(&full_path) {
                    // Only a missing manifest-relative file justifies the fallback.
                    Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
                        std::fs::read_to_string(file_path)
                    }
                    other => other,
                }
            }
            // Manifest dir unset or not valid Unicode: read the path as given.
            Err(_) => std::fs::read_to_string(file_path),
        };

        read_result.map_err(|e| {
            syn::Error::new(
                proc_macro2::Span::call_site(),
                format!("Failed to read SQL file '{}': {}", file_path, e),
            )
        })
    }

    /// Rejects file paths that could escape the intended directory or are not SQL files.
    ///
    /// The checks run in this order, and the first failure is returned:
    /// 1. any occurrence of `..`. This is deliberately conservative and also
    ///    rejects names like `a..b.sql`.
    /// 2. a leading `/` or `\` (absolute or UNC-style paths).
    /// 3. a `:` as the second character (Windows drive letters).
    /// 4. a path not ending in `.sql`.
    /// 5. an embedded NUL byte.
    ///
    /// # Errors
    ///
    /// Returns a `syn::Error` at the call-site span describing the violated rule.
    fn validate_file_path(file_path: &str) -> syn::Result<()> {
        // Check for path traversal attempts.
        if file_path.contains("..") {
            return Err(syn::Error::new(
                proc_macro2::Span::call_site(),
                "File path cannot contain '..' (path traversal prevented for security)",
            ));
        }
        // Check for absolute paths (Unix root, Windows root, or UNC prefix).
        if file_path.starts_with('/') || file_path.starts_with('\\') {
            return Err(syn::Error::new(
                proc_macro2::Span::call_site(),
                "File path must be relative (absolute paths not allowed for security)",
            ));
        }
        // Windows drive letters such as `C:`; a `:` at index 1 implies len >= 2.
        if file_path.chars().nth(1) == Some(':') {
            return Err(syn::Error::new(
                proc_macro2::Span::call_site(),
                "File path cannot contain drive letters (absolute paths not allowed for security)",
            ));
        }
        // Restrict to .sql files only.
        if !file_path.ends_with(".sql") {
            return Err(syn::Error::new(
                proc_macro2::Span::call_site(),
                "Only .sql files are allowed for security reasons",
            ));
        }
        // Check for null bytes (security measure).
        if file_path.contains('\0') {
            return Err(syn::Error::new(
                proc_macro2::Span::call_site(),
                "File path cannot contain null bytes",
            ));
        }
        Ok(())
    }
}
/// Classification of the Rust types the code works with.
#[derive(Debug, Clone, PartialEq)]
pub enum ReturnType {
    /// A scalar type such as `i32`, `String` or `bool`.
    Scalar { name: syn::Ident },
    /// A struct type that can be mapped from a database row.
    Struct { name: syn::Ident },
    /// A tuple type such as `(i32, String, bool)`.
    Tuple { elements: Vec<ReturnType> },
    /// `Vec<T>`, used for multiple results.
    Vec { element_type: Box<ReturnType> },
    /// `Option<T>`, used for nullable results.
    Option { inner_type: Box<ReturnType> },
    /// `Result<T, E>`, the standard return type.
    Result {
        ok_type: Box<ReturnType>,
        err_type: Box<ReturnType>,
    },
    /// `Stream<Item = T>`, used for streaming results.
    Stream { item_type: Box<ReturnType> },
    /// The unit type `()`.
    Unit,
    /// A type that is unknown or not supported; `name` holds its textual form.
    Unknown { name: String },
}
/// Which SQLx query macro strategy to use.
#[derive(Debug, Clone, PartialEq)]
pub enum QueryType {
    /// `query_as!`, for mapping rows into a struct.
    QueryAs,
    /// `query_scalar!`, for scalar values.
    QueryScalar,
    /// Plain `query!`, for complex cases.
    Query,
}
/// The method used to fetch results from the database.
#[derive(Debug, Clone, PartialEq)]
pub enum FetchMethod {
    /// `execute()`: no results expected.
    Execute,
    /// `fetch_one()`: expects exactly one result.
    FetchOne,
    /// `fetch_all()`: returns a `Vec` of results.
    FetchAll,
    /// `fetch_optional()`: returns an `Option`; an empty result is not an error.
    FetchOptional,
    /// `fetch()`: returns a `Stream` for streaming results.
    Fetch,
}
// NOTE(review): empty inherent impl — no `ReturnType` methods are defined in
// this block; populate it or remove it.
impl ReturnType {}