use std::collections::{BTreeSet, HashSet};
use std::env;
use std::ffi::{OsStr, OsString};
use std::fs;
use std::path::{Path, PathBuf};
use std::process::Command;

use crate::metadata::{manifest_dir, Metadata};
use crate::opt::ConnectOpts;
use crate::Config;
use anyhow::{bail, Context};
use console::style;
use sqlx::Connection;

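/// Shared context for the `prepare` and `prepare --check` code paths.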
pub struct PrepareCtx<'a> {
    pub config: &'a Config,
    pub workspace: bool,
    pub all: bool,
    pub cargo: OsString,
    pub cargo_args: Vec<String>,
    pub metadata: Metadata,
    pub connect_opts: ConnectOpts,
}

impl PrepareCtx<'_> {
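    /// Path to the directory where cached queries are placed: `.sqlx` under
    /// the workspace root with `--workspace`, or under the current package's
    /// manifest directory otherwise.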
    fn prepare_dir(&self) -> anyhow::Result<PathBuf> {
        if self.workspace {
            Ok(self.metadata.workspace_root().join(".sqlx"))
        } else {
            Ok(manifest_dir(&self.cargo)?.join(".sqlx"))
        }
    }
}

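/// Entry point for `cargo sqlx prepare`: builds the `PrepareCtx` and
/// dispatches to the generate or check flow.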
pub async fn run(
    config: &Config,
    check: bool,
    all: bool,
    workspace: bool,
    connect_opts: ConnectOpts,
    cargo_args: Vec<String>,
) -> anyhow::Result<()> {
    let cargo = env::var_os("CARGO")
        .context("failed to get value of `CARGO`; `prepare` subcommand may only be invoked as `cargo sqlx prepare`")?;

    anyhow::ensure!(
        Path::new("Cargo.toml").exists(),
        r#"Failed to read `Cargo.toml`.
hint: This command only works in the manifest directory of a Cargo package or workspace."#
    );

    let metadata: Metadata = Metadata::from_current_directory(&cargo)?;
    let ctx = PrepareCtx {
        config,
        workspace,
        all,
        cargo,
        cargo_args,
        metadata,
        connect_opts,
    };

    if check {
        prepare_check(&ctx).await
    } else {
        prepare(&ctx).await
    }
}

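/// Generate fresh query data and write it to the `.sqlx` directory.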
async fn prepare(ctx: &PrepareCtx<'_>) -> anyhow::Result<()> {
    if ctx.connect_opts.database_url.is_some() {
        check_backend(ctx.config, &ctx.connect_opts).await?;
    }

    let prepare_dir = ctx.prepare_dir()?;
    run_prepare_step(ctx, &prepare_dir)?;

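    // Warn rather than fail when no query files were produced; the project may
    // not invoke any of the query macros.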
    if glob_query_files(prepare_dir)?.is_empty() {
        println!("{} no queries found", style("warning:").yellow());
        return Ok(());
    }

    if ctx.workspace {
        println!(
            "query data written to .sqlx in the workspace root; \
             please check this into version control"
        );
    } else {
        println!(
            "query data written to .sqlx in the current directory; \
             please check this into version control"
        );
    }
    Ok(())
}

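/// Verify that the committed `.sqlx` directory is up to date: regenerate query
/// data into a scratch directory under `target/` and diff it against `.sqlx`.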
async fn prepare_check(ctx: &PrepareCtx<'_>) -> anyhow::Result<()> {
    if ctx.connect_opts.database_url.is_some() {
        check_backend(ctx.config, &ctx.connect_opts).await?;
    }

    let prepare_dir = ctx.prepare_dir()?;
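    // Regenerate into a separate directory under `target/` so the check never
    // touches the committed `.sqlx` data.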
    let cache_dir = ctx.metadata.target_directory().join("sqlx-prepare-check");
    run_prepare_step(ctx, &cache_dir)?;

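    // Compare by filename first: a freshly generated query missing from `.sqlx`
    // is an error, while an `.sqlx` entry with no fresh counterpart is only
    // potentially stale and warrants a warning.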
    let prepare_filenames: HashSet<String> = glob_query_files(&prepare_dir)?
        .into_iter()
        .filter_map(|path| path.file_name().map(|f| f.to_string_lossy().into_owned()))
        .collect();
    let cache_filenames: HashSet<String> = glob_query_files(&cache_dir)?
        .into_iter()
        .filter_map(|path| path.file_name().map(|f| f.to_string_lossy().into_owned()))
        .collect();

    if cache_filenames
        .difference(&prepare_filenames)
        .next()
        .is_some()
    {
        bail!("prepare check failed: .sqlx is missing one or more queries; you should re-run sqlx prepare");
    }
    if prepare_filenames
        .difference(&cache_filenames)
        .next()
        .is_some()
    {
        println!(
            "{} potentially unused queries found in .sqlx; you may want to re-run sqlx prepare",
            style("warning:").yellow()
        );
    }

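    // The filenames agree; now make sure the contents of each query file match.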
    for filename in cache_filenames {
        let prepare_json = load_json_file(prepare_dir.join(&filename))?;
        let cache_json = load_json_file(cache_dir.join(&filename))?;
        if prepare_json != cache_json {
            bail!("prepare check failed: one or more query files differ ({}); you should re-run sqlx prepare", filename);
        }
    }

    Ok(())
}

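/// Run `cargo check` with the sqlx macros pointed at `cache_dir`, forcing the
/// relevant crates to recompile so that every query macro re-emits its data.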
fn run_prepare_step(ctx: &PrepareCtx, cache_dir: &Path) -> anyhow::Result<()> {
    fs::create_dir_all(cache_dir).context(format!(
        "Failed to create query cache directory: {:?}",
        cache_dir
    ))?;

    let tmp_dir = ctx.metadata.target_directory().join("sqlx-tmp");
    fs::create_dir_all(&tmp_dir).context(format!(
        "Failed to create temporary query cache directory: {:?}",
        tmp_dir
    ))?;

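    // Drop query files left over from a previous run so deleted queries don't
    // linger in the cache.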
    for query_file in glob_query_files(cache_dir).context("Failed to read query cache files")? {
        fs::remove_file(&query_file)
            .with_context(|| format!("Failed to delete query file: {}", query_file.display()))?;
    }

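    // A fully cached build would skip the macro invocations entirely, so make
    // sure every crate that uses the sqlx macros is recompiled.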
    setup_minimal_project_recompile(&ctx.cargo, &ctx.metadata, ctx.all, ctx.workspace)?;

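    // `SQLX_OFFLINE=false` ensures the macros talk to the live database rather
    // than reading existing offline data, while `SQLX_OFFLINE_DIR` directs
    // their query output into `cache_dir`.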
    let check_status = {
        let mut check_command = Command::new(&ctx.cargo);
        check_command
            .arg("check")
            .args(&ctx.cargo_args)
            .env("SQLX_TMP", tmp_dir)
            .env("SQLX_OFFLINE", "false")
            .env("SQLX_OFFLINE_DIR", cache_dir);

        if let Some(database_url) = &ctx.connect_opts.database_url {
            check_command.env("DATABASE_URL", database_url);
        }

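        // `cargo check` recompiles when rust flags change, and an unset
        // `RUSTFLAGS` is not the same as an empty one (flags may come from
        // Cargo configuration instead), so only forward the variable when it
        // is actually present.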
        if let Ok(rustflags) = env::var("RUSTFLAGS") {
            check_command.env("RUSTFLAGS", rustflags);
        }

        check_command.status()?
    };
    if !check_status.success() {
        bail!("`cargo check` failed with status: {}", check_status);
    }

    Ok(())
}

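/// The minimal work needed to force the sqlx macros to rerun: touch the source
/// files of workspace packages and `cargo clean -p` out-of-workspace ones.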
#[derive(Debug, PartialEq)]
struct ProjectRecompileAction {
    clean_packages: Vec<String>,
    touch_paths: Vec<PathBuf>,
}

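/// Compute the minimal recompile action and apply it, falling back to a full
/// `cargo clean` if the targeted cleanup fails.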
fn setup_minimal_project_recompile(
    cargo: impl AsRef<OsStr>,
    metadata: &Metadata,
    all: bool,
    workspace: bool,
) -> anyhow::Result<()> {
    let recompile_action: ProjectRecompileAction = if workspace {
        minimal_project_recompile_action(metadata, all)
    } else {
        ProjectRecompileAction {
            clean_packages: Vec::new(),
            touch_paths: metadata
                .current_package()
                .context("failed to get package in current working directory, pass `--workspace` if running from a workspace root")?
                .src_paths()
                .to_vec(),
        }
    };

    if let Err(err) = minimal_project_clean(&cargo, recompile_action) {
        println!(
            "Failed minimal recompile setup. Cleaning entire project. Err: {}",
            err
        );
        let clean_status = Command::new(&cargo).arg("clean").status()?;
        if !clean_status.success() {
            bail!("`cargo clean` failed with status: {}", clean_status);
        }
    }

    Ok(())
}

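/// Apply a `ProjectRecompileAction`: bump mtimes so cargo considers the touched
/// sources dirty, and `cargo clean -p` each listed package.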
fn minimal_project_clean(
    cargo: impl AsRef<OsStr>,
    action: ProjectRecompileAction,
) -> anyhow::Result<()> {
    let ProjectRecompileAction {
        clean_packages,
        touch_paths,
    } = action;

    for file in touch_paths {
        let now = filetime::FileTime::now();
        filetime::set_file_times(&file, now, now)
            .with_context(|| format!("Failed to update mtime for {file:?}"))?;
    }

    for pkg_id in &clean_packages {
        let clean_status = Command::new(&cargo)
            .args(["clean", "-p", pkg_id])
            .status()?;

        if !clean_status.success() {
            bail!("`cargo clean -p {}` failed", pkg_id);
        }
    }

    Ok(())
}

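/// Collect every package that depends on `sqlx-macros`, directly or
/// transitively, then split the set: workspace members have their sources
/// touched, and (with `--all`) out-of-workspace dependents are cleaned.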
fn minimal_project_recompile_action(metadata: &Metadata, all: bool) -> ProjectRecompileAction {
    let mut sqlx_macros_dependents = BTreeSet::new();
    let sqlx_macros_ids: BTreeSet<_> = metadata
        .entries()
        .filter(|(_, package)| package.name() == "sqlx-macros")
        .map(|(id, _)| id)
        .collect();
    for sqlx_macros_id in sqlx_macros_ids {
        sqlx_macros_dependents.extend(metadata.all_dependents_of(sqlx_macros_id));
    }

    let mut in_workspace_dependents = Vec::new();
    let mut out_of_workspace_dependents = Vec::new();
    for dependent in sqlx_macros_dependents {
        if metadata.workspace_members().contains(dependent) {
            in_workspace_dependents.push(dependent);
        } else {
            out_of_workspace_dependents.push(dependent);
        }
    }

    let files_to_touch: Vec<_> = in_workspace_dependents
        .iter()
        .filter_map(|id| {
            metadata
                .package(id)
                .map(|package| package.src_paths().to_owned())
        })
        .flatten()
        .collect();

    let packages_to_clean: Vec<_> = if all {
        out_of_workspace_dependents
            .iter()
            .filter_map(|id| {
                metadata
                    .package(id)
                    .map(|package| package.name().to_owned())
            })
            .filter(|name| name != "sqlx")
            .collect()
    } else {
        Vec::new()
    };

    ProjectRecompileAction {
        clean_packages: packages_to_clean,
        touch_paths: files_to_touch,
    }
}

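/// List the `query-*.json` files in a query cache directory, ignoring any
/// unrelated files alongside them.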
fn glob_query_files(path: impl AsRef<Path>) -> anyhow::Result<Vec<PathBuf>> {
    let path = path.as_ref();
    let pattern = path.join("query-*.json");
    glob::glob(
        pattern
            .to_str()
            .context("query cache path is invalid UTF-8")?,
    )
    .with_context(|| format!("failed to read query cache path: {}", path.display()))?
    .collect::<Result<Vec<_>, _>>()
    .context("glob failed")
}

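/// Read a file and parse it as JSON for structural comparison of query data.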
fn load_json_file(path: impl AsRef<Path>) -> anyhow::Result<serde_json::Value> {
    let path = path.as_ref();
    let file_bytes =
        fs::read(path).with_context(|| format!("failed to load file: {}", path.display()))?;
    Ok(serde_json::from_slice(&file_bytes)?)
}

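/// Verify the database is reachable by opening and immediately closing a
/// connection.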
async fn check_backend(config: &Config, opts: &ConnectOpts) -> anyhow::Result<()> {
    crate::connect(config, opts).await?.close().await?;
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::assert_eq;

    #[test]
    fn minimal_project_recompile_action_works() -> anyhow::Result<()> {
        let sample_metadata_path = Path::new("tests")
            .join("assets")
            .join("sample_metadata.json");
        let sample_metadata = std::fs::read_to_string(sample_metadata_path)?;
        let metadata: Metadata = sample_metadata.parse()?;

        let action = minimal_project_recompile_action(&metadata, false);
        assert_eq!(
            action,
            ProjectRecompileAction {
                clean_packages: vec![],
                touch_paths: vec![
                    "/home/user/problematic/workspace/b_in_workspace_lib/src/lib.rs".into(),
                    "/home/user/problematic/workspace/c_in_workspace_bin/src/main.rs".into(),
                ],
            }
        );

        Ok(())
    }
}