use std::collections::{BTreeSet, HashSet};
use std::env;
use std::ffi::{OsStr, OsString};
use std::fs;
use std::path::{Path, PathBuf};
use std::process::Command;

use anyhow::{bail, Context};
use console::style;
use sqlx::Connection;

use crate::metadata::{manifest_dir, Metadata};
use crate::opt::ConnectOpts;

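/// Everything needed to drive `cargo sqlx prepare`, collected up front in `run()`.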
pub struct PrepareCtx {
    pub workspace: bool,
    pub all: bool,
    pub cargo: OsString,
    pub cargo_args: Vec<String>,
    pub metadata: Metadata,
    pub connect_opts: ConnectOpts,
}

impl PrepareCtx {
    /// Resolve the `.sqlx` directory that query data is written to and read from.
    fn prepare_dir(&self) -> anyhow::Result<PathBuf> {
        if self.workspace {
            Ok(self.metadata.workspace_root().join(".sqlx"))
        } else {
            Ok(manifest_dir(&self.cargo)?.join(".sqlx"))
        }
    }
}

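/// Entry point for `cargo sqlx prepare [--check]`.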
pub async fn run(
    check: bool,
    all: bool,
    workspace: bool,
    connect_opts: ConnectOpts,
    cargo_args: Vec<String>,
) -> anyhow::Result<()> {
    let cargo = env::var_os("CARGO")
        .context("failed to get value of `CARGO`; `prepare` subcommand may only be invoked as `cargo sqlx prepare`")?;

    anyhow::ensure!(
        Path::new("Cargo.toml").exists(),
        r#"Failed to read `Cargo.toml`.
hint: This command only works in the manifest directory of a Cargo package or workspace."#
    );

    let metadata: Metadata = Metadata::from_current_directory(&cargo)?;
    let ctx = PrepareCtx {
        workspace,
        all,
        cargo,
        cargo_args,
        metadata,
        connect_opts,
    };

    if check {
        prepare_check(&ctx).await
    } else {
        prepare(&ctx).await
    }
}

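/// Regenerate the offline query data in `.sqlx` and report where it was written.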
async fn prepare(ctx: &PrepareCtx) -> anyhow::Result<()> {
    // Only verify connectivity if a database URL was actually supplied.
    if ctx.connect_opts.database_url.is_some() {
        check_backend(&ctx.connect_opts).await?;
    }

    let prepare_dir = ctx.prepare_dir()?;
    run_prepare_step(ctx, &prepare_dir)?;

    if glob_query_files(prepare_dir)?.is_empty() {
        println!("{} no queries found", style("warning:").yellow());
        return Ok(());
    }

    if ctx.workspace {
        println!(
            "query data written to .sqlx in the workspace root; \
             please check this into version control"
        );
    } else {
        println!(
            "query data written to .sqlx in the current directory; \
             please check this into version control"
        );
    }
    Ok(())
}

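/// Compare freshly generated query data against the checked-in `.sqlx` directory
/// without modifying it.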
async fn prepare_check(ctx: &PrepareCtx) -> anyhow::Result<()> {
    if ctx.connect_opts.database_url.is_some() {
        check_backend(&ctx.connect_opts).await?;
    }

    // Regenerate the query files into a scratch directory and compare against `.sqlx`.
    let prepare_dir = ctx.prepare_dir()?;
    let cache_dir = ctx.metadata.target_directory().join("sqlx-prepare-check");
    run_prepare_step(ctx, &cache_dir)?;

    let prepare_filenames: HashSet<String> = glob_query_files(&prepare_dir)?
        .into_iter()
        .filter_map(|path| path.file_name().map(|f| f.to_string_lossy().into_owned()))
        .collect();
    let cache_filenames: HashSet<String> = glob_query_files(&cache_dir)?
        .into_iter()
        .filter_map(|path| path.file_name().map(|f| f.to_string_lossy().into_owned()))
        .collect();

    // A query present in the fresh build but missing from `.sqlx` means the cache is stale.
    if cache_filenames
        .difference(&prepare_filenames)
        .next()
        .is_some()
    {
        bail!("prepare check failed: .sqlx is missing one or more queries; you should re-run sqlx prepare");
    }
    // The reverse is only a warning: extra files may belong to queries that were removed.
    if prepare_filenames
        .difference(&cache_filenames)
        .next()
        .is_some()
    {
        println!(
            "{} potentially unused queries found in .sqlx; you may want to re-run sqlx prepare",
            style("warning:").yellow()
        );
    }

    // Every file present in both places must match exactly.
    for filename in cache_filenames {
        let prepare_json = load_json_file(prepare_dir.join(&filename))?;
        let cache_json = load_json_file(cache_dir.join(&filename))?;
        if prepare_json != cache_json {
            bail!("prepare check failed: one or more query files differ ({}); you should re-run sqlx prepare", filename);
        }
    }

    Ok(())
}

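/// Invoke `cargo check` with the sqlx macros pointed at `cache_dir` so that fresh
/// query files are emitted there.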
fn run_prepare_step(ctx: &PrepareCtx, cache_dir: &Path) -> anyhow::Result<()> {
    fs::create_dir_all(cache_dir).context(format!(
        "Failed to create query cache directory: {:?}",
        cache_dir
    ))?;

    let tmp_dir = ctx.metadata.target_directory().join("sqlx-tmp");
    fs::create_dir_all(&tmp_dir).context(format!(
        "Failed to create temporary query cache directory: {:?}",
        tmp_dir
    ))?;

    // Clear out any query files left over from a previous run.
    for query_file in glob_query_files(cache_dir).context("Failed to read query cache files")? {
        fs::remove_file(&query_file)
            .with_context(|| format!("Failed to delete query file: {}", query_file.display()))?;
    }

    setup_minimal_project_recompile(&ctx.cargo, &ctx.metadata, ctx.all, ctx.workspace)?;

    let check_status = {
        let mut check_command = Command::new(&ctx.cargo);
        check_command
            .arg("check")
            .args(&ctx.cargo_args)
            .env("SQLX_TMP", tmp_dir)
            .env("SQLX_OFFLINE", "false")
            .env("SQLX_OFFLINE_DIR", cache_dir);

        if let Some(database_url) = &ctx.connect_opts.database_url {
            check_command.env("DATABASE_URL", database_url);
        }

        // Pass `RUSTFLAGS` through explicitly so the child `cargo check` sees the same flags.
        if let Ok(rustflags) = env::var("RUSTFLAGS") {
            check_command.env("RUSTFLAGS", rustflags);
        }

        check_command.status()?
    };
    if !check_status.success() {
        bail!("`cargo check` failed with status: {}", check_status);
    }

    Ok(())
}

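/// What `setup_minimal_project_recompile` should do: touch these source files and
/// `cargo clean -p` these packages so the query macros re-run on the next `cargo check`.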
#[derive(Debug, PartialEq)]
struct ProjectRecompileAction {
    clean_packages: Vec<String>,
    touch_paths: Vec<PathBuf>,
}

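/// Compute and apply the minimal set of touches/cleans needed to force the query
/// macros to re-run, falling back to a full `cargo clean` if that fails.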
fn setup_minimal_project_recompile(
    cargo: impl AsRef<OsStr>,
    metadata: &Metadata,
    all: bool,
    workspace: bool,
) -> anyhow::Result<()> {
    let recompile_action: ProjectRecompileAction = if workspace {
        minimal_project_recompile_action(metadata, all)
    } else {
        // Not a workspace-wide run: only recompile the current package.
        ProjectRecompileAction {
            clean_packages: Vec::new(),
            touch_paths: metadata.current_package()
                .context("failed to get package in current working directory, pass `--workspace` if running from a workspace root")?
                .src_paths()
                .to_vec(),
        }
    };

    if let Err(err) = minimal_project_clean(&cargo, recompile_action) {
        println!(
            "Failed minimal recompile setup. Cleaning entire project. Err: {}",
            err
        );
        let clean_status = Command::new(&cargo).arg("clean").status()?;
        if !clean_status.success() {
            bail!("`cargo clean` failed with status: {}", clean_status);
        }
    }

    Ok(())
}

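/// Apply a `ProjectRecompileAction`: bump mtimes on the given sources and
/// `cargo clean -p` the given packages.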
fn minimal_project_clean(
    cargo: impl AsRef<OsStr>,
    action: ProjectRecompileAction,
) -> anyhow::Result<()> {
    let ProjectRecompileAction {
        clean_packages,
        touch_paths,
    } = action;

    // Update mtimes so cargo considers these sources dirty and recompiles them.
    for file in touch_paths {
        let now = filetime::FileTime::now();
        filetime::set_file_times(&file, now, now)
            .with_context(|| format!("Failed to update mtime for {file:?}"))?;
    }

    // Clean entire packages where touching sources isn't an option.
    for pkg_id in &clean_packages {
        let clean_status = Command::new(&cargo)
            .args(["clean", "-p", pkg_id])
            .status()?;

        if !clean_status.success() {
            bail!("`cargo clean -p {}` failed", pkg_id);
        }
    }

    Ok(())
}

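/// Determine which packages depend on `sqlx-macros`: in-workspace dependents get
/// their sources touched, and (with `--all`) out-of-workspace dependents get cleaned.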
fn minimal_project_recompile_action(metadata: &Metadata, all: bool) -> ProjectRecompileAction {
    // Collect the transitive dependents of every `sqlx-macros` package in the graph.
    let mut sqlx_macros_dependents = BTreeSet::new();
    let sqlx_macros_ids: BTreeSet<_> = metadata
        .entries()
        .filter(|(_, package)| package.name() == "sqlx-macros")
        .map(|(id, _)| id)
        .collect();
    for sqlx_macros_id in sqlx_macros_ids {
        sqlx_macros_dependents.extend(metadata.all_dependents_of(sqlx_macros_id));
    }

    // Partition the dependents by workspace membership.
    let mut in_workspace_dependents = Vec::new();
    let mut out_of_workspace_dependents = Vec::new();
    for dependent in sqlx_macros_dependents {
        if metadata.workspace_members().contains(dependent) {
            in_workspace_dependents.push(dependent);
        } else {
            out_of_workspace_dependents.push(dependent);
        }
    }

    // In-workspace dependents only need their source files touched.
    let files_to_touch: Vec<_> = in_workspace_dependents
        .iter()
        .filter_map(|id| {
            metadata
                .package(id)
                .map(|package| package.src_paths().to_owned())
        })
        .flatten()
        .collect();

    // With `--all`, out-of-workspace dependents (other than `sqlx` itself) are cleaned.
    let packages_to_clean: Vec<_> = if all {
        out_of_workspace_dependents
            .iter()
            .filter_map(|id| {
                metadata
                    .package(id)
                    .map(|package| package.name().to_owned())
            })
            .filter(|name| name != "sqlx")
            .collect()
    } else {
        Vec::new()
    };

    ProjectRecompileAction {
        clean_packages: packages_to_clean,
        touch_paths: files_to_touch,
    }
}

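/// List all `query-*.json` files in the given directory.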
fn glob_query_files(path: impl AsRef<Path>) -> anyhow::Result<Vec<PathBuf>> {
    let path = path.as_ref();
    let pattern = path.join("query-*.json");
    glob::glob(
        pattern
            .to_str()
            .context("query cache path is invalid UTF-8")?,
    )
    .with_context(|| format!("failed to read query cache path: {}", path.display()))?
    .collect::<Result<Vec<_>, _>>()
    .context("glob failed")
}

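/// Read a file and parse it as JSON.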
fn load_json_file(path: impl AsRef<Path>) -> anyhow::Result<serde_json::Value> {
    let path = path.as_ref();
    let file_bytes =
        fs::read(path).with_context(|| format!("failed to load file: {}", path.display()))?;
    Ok(serde_json::from_slice(&file_bytes)?)
}

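/// Verify the database is reachable by opening and cleanly closing a connection.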
async fn check_backend(opts: &ConnectOpts) -> anyhow::Result<()> {
    crate::connect(opts).await?.close().await?;
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::assert_eq;

    #[test]
    fn minimal_project_recompile_action_works() -> anyhow::Result<()> {
        let sample_metadata_path = Path::new("tests")
            .join("assets")
            .join("sample_metadata.json");
        let sample_metadata = std::fs::read_to_string(sample_metadata_path)?;
        let metadata: Metadata = sample_metadata.parse()?;

        let action = minimal_project_recompile_action(&metadata, false);
        assert_eq!(
            action,
            ProjectRecompileAction {
                clean_packages: vec![],
                touch_paths: vec![
                    "/home/user/problematic/workspace/b_in_workspace_lib/src/lib.rs".into(),
                    "/home/user/problematic/workspace/c_in_workspace_bin/src/main.rs".into(),
                ],
            }
        );

        Ok(())
    }
}