1use std::fmt::Write;
2use std::path::Path;
3
4use scythe_backend::manifest::{BackendManifest, load_manifest};
5use scythe_backend::naming::{
6 enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
7};
8
9use scythe_core::analyzer::{AnalyzedColumn, AnalyzedQuery, CompositeInfo, EnumInfo};
10use scythe_core::errors::{ErrorCode, ScytheError};
11use scythe_core::parser::QueryCommand;
12
13use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
14use crate::singularize;
15
/// Backend manifest embedded into the binary at compile time. `load_sqlx_manifest`
/// parses it when no on-disk `backends/rust-sqlx/manifest.toml` override exists.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/rust-sqlx.toml");
18
19pub struct SqlxBackend {
21 manifest: BackendManifest,
22}
23
24impl SqlxBackend {
25 pub fn new(engine: &str) -> Result<Self, ScytheError> {
26 match engine {
29 "postgresql" | "postgres" | "pg" | "mysql" | "mariadb" | "sqlite" | "sqlite3" => {}
30 _ => {
31 return Err(ScytheError::new(
32 ErrorCode::InternalError,
33 format!("unsupported engine '{}' for rust-sqlx backend", engine),
34 ));
35 }
36 }
37 let manifest = load_sqlx_manifest()?;
38 Ok(Self { manifest })
39 }
40}
41
42fn load_sqlx_manifest() -> Result<BackendManifest, ScytheError> {
43 let manifest_path = Path::new("backends/rust-sqlx/manifest.toml");
44 if manifest_path.exists() {
45 load_manifest(manifest_path).map_err(|e| {
46 ScytheError::new(
47 ErrorCode::InternalError,
48 format!("failed to load manifest: {e}"),
49 )
50 })
51 } else {
52 toml::from_str(DEFAULT_MANIFEST_TOML).map_err(|e| {
53 ScytheError::new(
54 ErrorCode::InternalError,
55 format!("failed to parse embedded manifest: {e}"),
56 )
57 })
58 }
59}
60
61impl CodegenBackend for SqlxBackend {
62 fn name(&self) -> &str {
63 "rust-sqlx"
64 }
65
66 fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
67 &self.manifest
68 }
69
70 fn supported_engines(&self) -> &[&str] {
71 &["postgresql", "mysql", "sqlite"]
72 }
73
74 fn file_header(&self) -> String {
75 "// Auto-generated by scythe. Do not edit.\n#![allow(dead_code, unused_imports, clippy::all)]"
76 .to_string()
77 }
78
79 fn generate_row_struct(
80 &self,
81 query_name: &str,
82 columns: &[ResolvedColumn],
83 ) -> Result<String, ScytheError> {
84 let struct_name = row_struct_name(query_name, &self.manifest.naming);
85 let mut out = String::new();
86
87 let _ = writeln!(out, "#[derive(Debug, sqlx::FromRow)]");
88 let _ = writeln!(out, "pub struct {} {{", struct_name);
89
90 for col in columns {
91 let _ = writeln!(out, " pub {}: {},", col.field_name, col.full_type);
92 }
93
94 let _ = write!(out, "}}");
95 Ok(out)
96 }
97
98 fn generate_model_struct(
99 &self,
100 table_name: &str,
101 columns: &[ResolvedColumn],
102 ) -> Result<String, ScytheError> {
103 let singular = singularize(table_name);
104 let struct_name = to_pascal_case(&singular).into_owned();
105 let mut out = String::new();
106
107 let _ = writeln!(out, "#[derive(Debug, sqlx::FromRow)]");
108 let _ = writeln!(out, "pub struct {} {{", struct_name);
109
110 for col in columns {
111 let _ = writeln!(out, " pub {}: {},", col.field_name, col.full_type);
112 }
113
114 let _ = write!(out, "}}");
115 Ok(out)
116 }
117
118 fn generate_query_fn(
119 &self,
120 analyzed: &AnalyzedQuery,
121 struct_name: &str,
122 _columns: &[ResolvedColumn],
123 params: &[ResolvedParam],
124 ) -> Result<String, ScytheError> {
125 let func_name = fn_name(&analyzed.name, &self.manifest.naming);
126 let mut out = String::new();
127
128 if let Some(ref msg) = analyzed.deprecated {
130 let _ = writeln!(out, "#[deprecated(note = \"{}\")]", msg);
131 }
132
133 let mut param_parts: Vec<String> = vec!["pool: &sqlx::PgPool".to_string()];
135 for param in params {
136 param_parts.push(format!("{}: {}", param.field_name, param.borrowed_type));
137 }
138
139 let return_type = match &analyzed.command {
141 QueryCommand::One => struct_name.to_string(),
142 QueryCommand::Many => format!("Vec<{}>", struct_name),
143 QueryCommand::Exec => "()".to_string(),
144 QueryCommand::ExecResult => "sqlx::postgres::PgQueryResult".to_string(),
145 QueryCommand::ExecRows => "u64".to_string(),
146 QueryCommand::Batch => format!("Vec<{}>", struct_name),
147 };
148
149 let _ = writeln!(
151 out,
152 "pub async fn {}({}) -> Result<{}, sqlx::Error> {{",
153 func_name,
154 param_parts.join(", "),
155 return_type
156 );
157
158 let sql_raw = super::clean_sql(&analyzed.sql);
160 let sql = rewrite_sql_for_enums(&sql_raw, &analyzed.columns, &self.manifest);
161
162 let has_row_struct = matches!(
164 analyzed.command,
165 QueryCommand::One | QueryCommand::Many | QueryCommand::Batch
166 );
167
168 let bind_params: String = analyzed
170 .params
171 .iter()
172 .map(|p| {
173 let param_name = to_snake_case(&p.name);
174 if p.neutral_type.starts_with("enum::") {
175 let enum_name = p.neutral_type.strip_prefix("enum::").unwrap();
176 let rust_type = enum_type_name(enum_name, &self.manifest.naming);
177 format!(", {} as &{}", param_name, rust_type)
178 } else {
179 format!(", {}", param_name)
180 }
181 })
182 .collect();
183
184 let is_exec_rows = matches!(analyzed.command, QueryCommand::ExecRows);
185
186 if is_exec_rows {
187 if has_row_struct && !analyzed.columns.is_empty() {
188 let _ = write!(
189 out,
190 " let result = sqlx::query_as!({}, \"{}\"{})",
191 struct_name, sql, bind_params
192 );
193 } else {
194 let _ = write!(
195 out,
196 " let result = sqlx::query!(\"{}\"{})",
197 sql, bind_params
198 );
199 }
200 } else if has_row_struct && !analyzed.columns.is_empty() {
201 let _ = write!(
202 out,
203 " sqlx::query_as!({}, \"{}\"{})",
204 struct_name, sql, bind_params
205 );
206 } else {
207 let _ = write!(out, " sqlx::query!(\"{}\"{})", sql, bind_params);
208 }
209
210 let _ = writeln!(out);
211
212 let fetch_method = match &analyzed.command {
214 QueryCommand::One => ".fetch_one(pool)",
215 QueryCommand::Many => ".fetch_all(pool)",
216 QueryCommand::Exec => ".execute(pool)",
217 QueryCommand::ExecResult => ".execute(pool)",
218 QueryCommand::ExecRows => ".execute(pool)",
219 QueryCommand::Batch => ".fetch_all(pool)",
220 };
221
222 let _ = write!(out, " {}", fetch_method);
223 let _ = writeln!(out);
224
225 match &analyzed.command {
227 QueryCommand::Exec => {
228 let _ = writeln!(out, " .await?;");
229 let _ = writeln!(out, " Ok(())");
230 }
231 QueryCommand::ExecRows => {
232 let _ = writeln!(out, " .await?;");
233 let _ = writeln!(out, " Ok(result.rows_affected())");
234 }
235 _ => {
236 let _ = writeln!(out, " .await");
237 }
238 }
239
240 let _ = write!(out, "}}");
241 Ok(out)
242 }
243
244 fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
245 let mut out = String::with_capacity(256);
246 let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
247
248 let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq, sqlx::Type)]");
249 let _ = writeln!(
250 out,
251 "#[sqlx(type_name = \"{}\", rename_all = \"snake_case\")]",
252 enum_info.sql_name
253 );
254 let _ = writeln!(out, "pub enum {type_name} {{");
255
256 for value in &enum_info.values {
257 let variant = enum_variant_name(value, &self.manifest.naming);
258 let _ = writeln!(out, " {variant},");
259 }
260
261 let _ = write!(out, "}}");
262 Ok(out)
263 }
264
265 fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
266 use scythe_backend::types::resolve_type;
267
268 let struct_name = to_pascal_case(&composite.sql_name).into_owned();
269 let mut out = String::new();
270
271 let _ = writeln!(out, "#[derive(Debug, Clone, sqlx::Type)]");
272 let _ = writeln!(out, "#[sqlx(type_name = \"{}\")]", composite.sql_name);
273 let _ = writeln!(out, "pub struct {} {{", struct_name);
274 for field in &composite.fields {
275 let rust_type = resolve_type(&field.neutral_type, &self.manifest, false)
276 .map(|t| t.into_owned())
277 .map_err(|e| {
278 ScytheError::new(
279 ErrorCode::InternalError,
280 format!("composite field type error: {}", e),
281 )
282 })?;
283 let _ = writeln!(
284 out,
285 " pub {}: {},",
286 to_snake_case(&field.name),
287 rust_type
288 );
289 }
290 let _ = write!(out, "}}");
291 Ok(out)
292 }
293}
294
295fn rewrite_sql_for_enums(
301 sql: &str,
302 columns: &[AnalyzedColumn],
303 manifest: &BackendManifest,
304) -> String {
305 let enum_cols: Vec<(&str, String)> = columns
306 .iter()
307 .filter_map(|col| {
308 if let Some(enum_name) = col.neutral_type.strip_prefix("enum::") {
309 let rust_type = enum_type_name(enum_name, &manifest.naming);
310 let annotation = if col.nullable {
311 format!("Option<{}>", rust_type)
312 } else {
313 rust_type
314 };
315 Some((col.name.as_str(), annotation))
316 } else {
317 None
318 }
319 })
320 .collect();
321
322 if enum_cols.is_empty() {
323 return sql.to_string();
324 }
325
326 let mut result = sql.to_string();
327 for (col_name, annotation) in &enum_cols {
328 let alias = format!("{} AS \\\"{}: {}\\\"", col_name, col_name, annotation);
329 if let Some(from_pos) = result.to_uppercase().find(" FROM ") {
330 let select_part = &result[..from_pos];
331 let rest = &result[from_pos..];
332 let new_select = replace_column_in_select(select_part, col_name, &alias);
333 result = format!("{}{}", new_select, rest);
334 }
335 }
336 result
337}
338
/// Replaces one whole-word reference to `col_name` in a `SELECT` list with
/// `replacement` and returns the new text.
///
/// A reference preceded by `", "` is preferred over one preceded only by a
/// space (such as the first column after `SELECT`). Within each form, the
/// right-most match that is followed by whitespace, a comma, or the end of the
/// text wins. Partial matches such as `status_note` for `status` are skipped
/// rather than stopping the search. If nothing matches, or `col_name` is empty,
/// the input is returned unchanged.
fn replace_column_in_select(select: &str, col_name: &str, replacement: &str) -> String {
    if col_name.is_empty() {
        return select.to_string();
    }
    let patterns = [format!(", {}", col_name), format!(" {}", col_name)];
    for pattern in &patterns {
        for (pos, _) in select.rmatch_indices(pattern.as_str()) {
            let after = pos + pattern.len();
            let ends_word = match select[after..].chars().next() {
                None => true,
                Some(c) => c.is_whitespace() || c == ',',
            };
            if ends_word {
                // `pattern` ends with `col_name`, so this is a char boundary.
                let name_start = after - col_name.len();
                return format!(
                    "{}{}{}",
                    &select[..name_start],
                    replacement,
                    &select[after..]
                );
            }
        }
    }
    select.to_string()
}