watchso/
framework_utils.rs

1//! Utilities for framework implementations.
2
3use std::{
4    collections::HashMap,
5    path::{Path, PathBuf},
6    sync::Arc,
7};
8
9use lazy_static::lazy_static;
10use miette::IntoDiagnostic;
11use regex::{Match, Regex, RegexBuilder};
12use tokio::{fs, sync::RwLock, time};
13use watchexec_filterer_globset::GlobsetFilterer;
14
15use crate::{
16    command::WCommand,
17    constants::{dirname, extension, filename},
18    error::WatchError,
19    glob::glob,
20    toml::read_cargo_toml,
21};
22
23/// A mapping of program names and their paths. Using `RwLock` because the process is read heavy.
24#[derive(Default)]
25pub struct ProjectMap(Arc<RwLock<HashMap<String, PathBuf>>>);
26
27impl ProjectMap {
28    /// Get the program's path from the given path. Mainly used for getting the program path from
29    /// program keypair or ELF path.
30    pub async fn get_program_path<P: AsRef<Path>>(&self, path: P) -> Option<PathBuf> {
31        let program_name = match path.as_ref().extension().map(|ext| ext.to_str()) {
32            Some(Some(ext)) => match ext {
33                extension::JSON => ProgramName::from_keypair_path(path),
34                extension::SO => ProgramName::from_elf_path(path),
35                _ => None,
36            },
37            _ => None,
38        };
39
40        match program_name {
41            Some(program_name) => match self
42                .get_program_path_from_name(program_name.original())
43                .await
44            {
45                Some(program_path) => Some(program_path),
46                None => {
47                    self.get_program_path_from_name(program_name.kebab_case())
48                        .await
49                }
50            },
51            None => None,
52        }
53    }
54
55    /// Set the program path based on program name.
56    pub async fn set_program_path<S, P>(&self, name: S, path: P)
57    where
58        S: Into<String>,
59        P: Into<PathBuf>,
60    {
61        let mut program_hm = self.0.write().await;
62        program_hm.insert(name.into(), path.into());
63    }
64
65    /// Get the program path from the program name.
66    async fn get_program_path_from_name<S: AsRef<str>>(&self, name: S) -> Option<PathBuf> {
67        self.0
68            .read()
69            .await
70            .get(name.as_ref())
71            .map(|path| path.to_owned())
72    }
73}
74
/// Program name derived from a Solana build artifact's file name.
///
/// Solana build tools name the program keypair `<program_name>-keypair.json` and the program ELF
/// `<program_name>.so`, where `<program_name>` is always written in snake_case. The original
/// package name therefore can't be recovered exactly: programs named "hello-world" and
/// "hello_world" produce identical output files. Use [`ProgramName::original`] and
/// [`ProgramName::kebab_case`] to try both spellings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProgramName(String);

impl ProgramName {
    /// Create a new [`ProgramName`] from the given `name`.
    pub fn new<S: Into<String>>(name: S) -> Self {
        Self(name.into())
    }

    /// Get the program's name from the program keypair path.
    ///
    /// Expects a file name of the form `<program_name>-keypair.json`; returns `None` otherwise.
    pub fn from_keypair_path<P: AsRef<Path>>(program_keypair_path: P) -> Option<Self> {
        Self::from_path(program_keypair_path, "-keypair.json")
    }

    /// Get the program's name from the program ELF path.
    ///
    /// Expects a file name of the form `<program_name>.so`; returns `None` otherwise.
    pub fn from_elf_path<P: AsRef<Path>>(program_elf_path: P) -> Option<Self> {
        Self::from_path(program_elf_path, ".so")
    }

    /// Get the program's name by removing `suffix` from the file name of `path`.
    ///
    /// Returns `None` when `path` has no file name, the file name is not valid UTF-8, it does not
    /// end with `suffix`, or nothing is left once `suffix` is removed. The suffix is removed exactly
    /// once, so `a.so.so` yields `a.so`.
    fn from_path<P, S>(path: P, suffix: S) -> Option<Self>
    where
        P: AsRef<Path>,
        S: AsRef<str>,
    {
        path.as_ref()
            .file_name()
            .and_then(|name| name.to_str())
            // `strip_suffix` removes a single occurrence, unlike `trim_end_matches`, which would
            // also eat repeated suffixes that belong to the name itself.
            .and_then(|name| name.strip_suffix(suffix.as_ref()))
            // A bare `-keypair.json` / `.so` carries no program name.
            .filter(|name| !name.is_empty())
            .map(Self::new)
    }

    /// Reference to the original (snake_case) program name.
    pub fn original(&self) -> &str {
        &self.0
    }

    /// Convert the original program name to kebab-case by replacing `_` with `-`.
    pub fn kebab_case(&self) -> String {
        self.0.replace('_', "-")
    }
}
129
130/// Start a new test validator by running `solana-test-validator` command.
131///
132/// This won't have any effect if there is already a running test validator.
133///
134/// NOTE: This function will spawn a tokio task because `solana-test-validator` command never
135/// resolves. It will then sleep for a small duration to give time for the initialization. This
136/// means it will not confirm that the test validator has started.
137pub async fn start_test_validator<P: Into<PathBuf>>(origin: P) -> miette::Result<()> {
138    let origin = origin.into();
139    tokio::spawn(async {
140        let _ = WCommand::new("solana-test-validator")
141            .current_dir(origin)
142            .output()
143            .await;
144    });
145
146    // Wait 2 seconds for the test validator to start
147    time::sleep(time::Duration::from_secs(2)).await;
148
149    Ok(())
150}
151
152/// Get all the directory paths that will be watched by default.
153///
154/// If the `origin` is a workspace, the paths will be filtered by `workspace.members` and
155/// `workspace.exclude`. Otherwise it's the `src` dir by default.
156///
157/// Paths always include `target/deploy`.
158pub async fn get_watch_pathset<P: AsRef<Path>>(origin: P) -> miette::Result<Vec<PathBuf>> {
159    let mut paths = vec![Path::new(dirname::TARGET).join(dirname::DEPLOY)];
160    match filter_workspace_programs(origin).await? {
161        Some(filtered_paths) => paths.extend(filtered_paths),
162        None => paths.push(PathBuf::from(dirname::SRC)),
163    }
164
165    Ok(paths)
166}
167
168/// Filter workspace programs based on the manifest file at `origin`.
169///
170/// Returns `Ok(None)` if the `origin` is not a workspace but has a manifest file.
171async fn filter_workspace_programs<P: AsRef<Path>>(
172    origin: P,
173) -> miette::Result<Option<Vec<PathBuf>>> {
174    let manifest = read_cargo_toml(&origin).await?;
175    match manifest.workspace {
176        Some(workspace) => {
177            let paths = glob(origin.as_ref(), workspace.members, workspace.exclude, true).await?;
178            Ok(Some(paths))
179        }
180        None => Ok(None),
181    }
182}
183
184/// Get a mapping of program names and paths based on the manifest file at `origin`.
185pub async fn get_program_name_path_hashmap<P: AsRef<Path>>(
186    origin: P,
187) -> miette::Result<HashMap<String, PathBuf>> {
188    let mut program_name_path_hm = HashMap::new();
189    let program_paths = filter_workspace_programs(&origin)
190        .await?
191        .unwrap_or(vec![origin.as_ref().to_path_buf()]);
192    for program_path in program_paths {
193        if let Ok(manifest) = read_cargo_toml(&program_path).await {
194            if let Some(package) = manifest.package {
195                program_name_path_hm.insert(package.name, program_path);
196            }
197        }
198    }
199
200    Ok(program_name_path_hm)
201}
202
203/// Get program's root path by running `cargo locate-project` command.
204pub async fn get_program_path<P: AsRef<Path>>(modified_file_path: P) -> miette::Result<PathBuf> {
205    let output = WCommand::new("cargo locate-project --message-format plain")
206        .current_dir(modified_file_path.as_ref().parent().unwrap())
207        .output()
208        .await?;
209    if output.status().success() {
210        Ok(Path::new(output.stdout().trim_end_matches('\n'))
211            .parent()
212            .unwrap()
213            .to_path_buf())
214    } else {
215        Err(WatchError::CommandNotFound("cargo locate-project"))?
216    }
217}
218
219/// Get the keypair's address by running `solana address` command.
220pub async fn get_pubkey_from_keypair_path<P: AsRef<Path>>(
221    keypair_path: P,
222) -> miette::Result<String> {
223    let keypair_output = WCommand::new(format!(
224        "solana address -k {}",
225        keypair_path.as_ref().display()
226    ))
227    .output()
228    .await?;
229
230    if !keypair_output.status().success() {
231        return Err(WatchError::CouldNotGetKeypair(
232            keypair_output.stderr().into(),
233        ))?;
234    }
235
236    let program_id = keypair_output.stdout().trim_end_matches('\n');
237
238    Ok(program_id.to_owned())
239}
240
241/// Find the file that includes `declare_id!` macro and update the program id if it has changed.
242///
243/// This function will check `lib.rs` first and **only** if it doesn't find the declaration it will
244/// then check all the remaining source files.
245pub async fn find_and_update_program_id<P1, P2>(
246    program_path: P1,
247    program_keypair_path: P2,
248) -> miette::Result<()>
249where
250    P1: AsRef<Path>,
251    P2: AsRef<Path>,
252{
253    // Get the keypair program id
254    let program_id = get_pubkey_from_keypair_path(program_keypair_path).await?;
255
256    // Check lib.rs first for the program id
257    let src_path = program_path.as_ref().join(dirname::SRC);
258    let lib_rs_path = src_path.join(filename::LIB_RS);
259
260    if update_rust_program_id(lib_rs_path, &program_id).await? {
261        return Ok(());
262    }
263
264    // Check all the other files if the program_id doesn't exist in lib.rs
265    let rust_src_paths = glob(src_path, [format!("*.{}", extension::RS)], [], false).await?;
266    for path in rust_src_paths {
267        if update_rust_program_id(path, &program_id).await? {
268            // Not necessary to continue the loop after program id update
269            break;
270        }
271    }
272
273    Ok(())
274}
275
276/// Update the file at the given path's `declare_id!` macro with the given program id.
277///
278/// Returns whether the program id was updated successfully.
279async fn update_rust_program_id<P, S>(path: P, program_id: S) -> miette::Result<bool>
280where
281    P: AsRef<Path>,
282    S: AsRef<str>,
283{
284    lazy_static! {
285        static ref REGEX: Regex = RegexBuilder::new(r#"^(([\w]+::)*)declare_id!\("(\w*)"\)"#)
286            .multi_line(true)
287            .build()
288            .unwrap();
289    };
290
291    update_file_program_id_with(path, &program_id, |content| {
292        REGEX.captures(content).and_then(|captures| captures.get(3))
293    })
294    .await
295}
296
297/// Update the file's `declare_id!` macro with the program id based on the given callback.
298///
299/// Returns whether the program id was updated successfully.
300pub async fn update_file_program_id_with<P, S, F>(
301    path: P,
302    program_id: S,
303    cb: F,
304) -> miette::Result<bool>
305where
306    P: AsRef<Path>,
307    S: AsRef<str>,
308    F: Fn(&str) -> Option<Match<'_>>,
309{
310    let mut content = fs::read_to_string(&path).await.into_diagnostic()?;
311    if let Some(program_id_match) =
312        cb(&content).filter(|program_id_match| program_id_match.as_str() != program_id.as_ref())
313    {
314        // Update the program id
315        content.replace_range(program_id_match.range(), program_id.as_ref());
316
317        // Save the file
318        fs::write(&path, content).await.into_diagnostic()?;
319
320        return Ok(true);
321    }
322
323    Ok(false)
324}
325
326/// Get Solana build tool.
327///
328/// Checks for `cargo build-sbf` and `cargo build-bpf` in order.
329///
330/// Returns an error if the Solana build tools are not installed.
331pub async fn get_bpf_or_sbf() -> miette::Result<&'static str> {
332    const BUILD_SBF: &str = "cargo build-sbf";
333    const BUILD_BPF: &str = "cargo build-bpf";
334
335    let build_cmd = if WCommand::exists(BUILD_SBF).await {
336        BUILD_SBF
337    } else if WCommand::exists(BUILD_BPF).await {
338        BUILD_BPF
339    } else {
340        return Err(WatchError::CommandNotFound("solana"))?;
341    };
342
343    Ok(build_cmd)
344}
345
346/// Create a globset filterer that will be used to filter the watched files.
347///
348/// The filterer will always ignore `target`, `test-ledger` and `node_modules` paths.
349pub async fn create_globset_filterer<P: AsRef<Path>>(
350    origin: P,
351    filters: &[&str],
352    ignores: &[&str],
353    extensions: &[&str],
354) -> Arc<GlobsetFilterer> {
355    let filters = filters
356        .iter()
357        .map(|glob| (glob.to_string(), None))
358        .collect::<Vec<(String, Option<PathBuf>)>>();
359    let ignores = [
360        &[
361            "**/*/target/**/*",
362            "**/*/test-ledger/**/*",
363            "**/*/node_modules/**/*",
364        ],
365        ignores,
366    ]
367    .concat()
368    .iter()
369    .map(|glob| (glob.to_string(), None))
370    .collect::<Vec<(String, Option<PathBuf>)>>();
371    let ignore_files = [];
372    let extensions = extensions.iter().map(|ext| ext.into());
373
374    Arc::new(
375        GlobsetFilterer::new(origin, filters, ignores, ignore_files, extensions)
376            .await
377            .unwrap(),
378    )
379}