use std::{
borrow::Borrow,
collections::{HashMap, HashSet},
fmt::{Display, Formatter},
path::Path,
};
use rattler_conda_types::{PackageName, PackageRecord, Platform, PrefixRecord};
use rattler_shell::shell::{Bash, CmdExe, ShellEnum};
use thiserror::Error;
use super::{InstallDriver, Transaction, installer::Reporter};
#[derive(Debug, thiserror::Error)]
pub enum LinkScriptError {
#[error("{0}")]
IoError(String, #[source] std::io::Error),
}
/// The kind of link script a package can ship.
///
/// The variant selects which script file [`LinkScriptType::get_path`] resolves
/// and how the script kind is spelled in log output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LinkScriptType {
    /// The `pre-unlink` script, run for packages that a transaction removes.
    PreUnlink,
    /// The `post-link` script, run for packages that a transaction installs.
    PostLink,
}
impl LinkScriptType {
    /// Returns the prefix-relative path of this link script for `package_record`.
    ///
    /// On Windows the script is a `.bat` file in `Scripts/`; on every other
    /// platform it is a `.sh` file in `bin/`. The file name is
    /// `.<normalized package name>-<pre-unlink|post-link>.<ext>`.
    pub fn get_path(&self, package_record: &PackageRecord, platform: &Platform) -> String {
        let (directory, extension) = if platform.is_windows() {
            ("Scripts", "bat")
        } else {
            ("bin", "sh")
        };
        let action = match self {
            LinkScriptType::PreUnlink => "pre-unlink",
            LinkScriptType::PostLink => "post-link",
        };
        let package = package_record.name.as_normalized();
        format!("{directory}/.{package}-{action}.{extension}")
    }
}
impl Display for LinkScriptType {
    /// Writes the hyphenated script kind (`pre-unlink` or `post-link`), the same
    /// spelling that appears in the script file names.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            LinkScriptType::PreUnlink => "pre-unlink",
            LinkScriptType::PostLink => "post-link",
        };
        f.write_str(label)
    }
}
/// Outcome of running one kind of link script over a set of packages.
#[derive(Clone, Debug)]
pub struct PrePostLinkResult {
    /// Per package whose script was executed: the contents of the
    /// `.messages.txt` file the script left in the prefix, or an empty string
    /// if it left none.
    pub messages: HashMap<PackageName, String>,
    /// Packages whose script could not be started or exited unsuccessfully.
    pub failed_packages: Vec<PackageName>,
}
#[derive(Debug, Error)]
pub enum PrePostLinkError {
#[error("failed to determine the installed packages")]
FailedToDetectInstalledPackages(#[source] std::io::Error),
}
pub fn run_link_scripts<'a>(
link_script_type: LinkScriptType,
prefix_records: impl Iterator<Item = &'a PrefixRecord>,
target_prefix: &Path,
platform: &Platform,
reporter: Option<&dyn Reporter>,
) -> Result<PrePostLinkResult, LinkScriptError> {
let mut env = HashMap::new();
env.insert(
"PREFIX".to_string(),
target_prefix.to_string_lossy().to_string(),
);
let mut failed_packages = Vec::new();
let mut messages = HashMap::<PackageName, String>::new();
for record in prefix_records {
let prec = &record.repodata_record.package_record;
let link_file = target_prefix.join(link_script_type.get_path(prec, platform));
if link_file.exists() {
env.insert(
"PKG_NAME".to_string(),
prec.name.as_normalized().to_string(),
);
env.insert("PKG_VERSION".to_string(), prec.version.to_string());
env.insert("PKG_BUILDNUM".to_string(), prec.build_number.to_string());
let shell = if platform.is_windows() {
ShellEnum::CmdExe(CmdExe)
} else {
ShellEnum::Bash(Bash::default())
};
tracing::info!(
"Running {} script for {}",
link_script_type.to_string(),
prec.name.as_normalized()
);
let script_path = link_script_type.get_path(prec, platform);
let reporter_idx = match (&reporter, &link_script_type) {
(Some(r), LinkScriptType::PostLink) => {
Some(r.on_post_link_start(prec.name.as_normalized(), &script_path))
}
(Some(r), LinkScriptType::PreUnlink) => {
Some(r.on_pre_unlink_start(prec.name.as_normalized(), &script_path))
}
_ => None,
};
let success =
match rattler_shell::run_in_environment(target_prefix, &link_file, shell, &env) {
Ok(o) if o.status.success() => true,
Ok(o) => {
failed_packages.push(prec.name.clone());
tracing::warn!("Error running post-link script. Status: {:?}", o.status);
tracing::warn!(" stdout: {}", String::from_utf8_lossy(&o.stdout));
tracing::warn!(" stderr: {}", String::from_utf8_lossy(&o.stderr));
false
}
Err(e) => {
failed_packages.push(prec.name.clone());
tracing::error!("Error running post-link script: {:?}", e);
false
}
};
if let (Some(r), Some(idx)) = (reporter, reporter_idx) {
match link_script_type {
LinkScriptType::PostLink => r.on_post_link_complete(idx, success),
LinkScriptType::PreUnlink => r.on_pre_unlink_complete(idx, success),
}
}
let message_file = target_prefix.join(".messages.txt");
if message_file.exists() {
let message = std::fs::read_to_string(&message_file).map_err(|err| {
LinkScriptError::IoError(
format!(
"error reading message file from {0}",
message_file.display()
),
err,
)
})?;
tracing::info!(
"Message from {} for {}: {}",
link_script_type.to_string(),
prec.name.as_normalized(),
message
);
messages.insert(prec.name.clone(), message);
std::fs::remove_file(&message_file).map_err(|err| {
LinkScriptError::IoError(
format!(
"error removing message file from {0}",
message_file.display()
),
err,
)
})?;
} else {
messages.insert(prec.name.clone(), "".to_string());
}
}
}
Ok(PrePostLinkResult {
messages,
failed_packages,
})
}
impl InstallDriver {
    /// Runs the post-link scripts of the packages that `transaction` installs.
    ///
    /// Only the entries of `prefix_records` whose package name is among the
    /// transaction's installed packages are passed on. See [`run_link_scripts`]
    /// for how each script is executed and how failures and messages are
    /// reported.
    pub fn run_post_link_scripts<Old, New>(
        &self,
        transaction: &Transaction<Old, New>,
        prefix_records: &[&PrefixRecord],
        target_prefix: &Path,
        reporter: Option<&dyn Reporter>,
    ) -> Result<PrePostLinkResult, LinkScriptError>
    where
        Old: AsRef<New>,
        New: AsRef<PackageRecord>,
    {
        let installed_names: HashSet<_> = transaction
            .installed_packages()
            .map(|installed| &installed.as_ref().name)
            .collect();
        let records_to_run = prefix_records.iter().copied().filter(|record| {
            installed_names.contains(&record.repodata_record.package_record.name)
        });
        run_link_scripts(
            LinkScriptType::PostLink,
            records_to_run,
            target_prefix,
            &transaction.platform,
            reporter,
        )
    }

    /// Runs the pre-unlink scripts of every package that `transaction` removes.
    ///
    /// See [`run_link_scripts`] for how each script is executed and how
    /// failures and messages are reported.
    pub fn run_pre_unlink_scripts<Old, New>(
        &self,
        transaction: &Transaction<Old, New>,
        target_prefix: &Path,
        reporter: Option<&dyn Reporter>,
    ) -> Result<PrePostLinkResult, LinkScriptError>
    where
        Old: Borrow<PrefixRecord>,
    {
        let removed_records = transaction
            .removed_packages()
            .map(<Old as Borrow<PrefixRecord>>::borrow);
        run_link_scripts(
            LinkScriptType::PreUnlink,
            removed_records,
            target_prefix,
            &transaction.platform,
            reporter,
        )
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        get_repodata_record, get_test_data_dir,
        install::{
            InstallDriver, InstallOptions, TransactionOperation, test_utils::execute_transaction,
            transaction,
        },
        package_cache::PackageCache,
    };
    use rattler_conda_types::{Platform, PrefixRecord, RepoDataRecord, prefix::Prefix};
    use rattler_networking::LazyClient;

    /// File that must exist in the prefix after the test package is installed
    /// and must be gone after it is removed.
    const POST_LINK_MARKER: &str = "i-was-post-linked";

    /// Operations that install the `link-scripts` test package.
    fn test_operations() -> Vec<TransactionOperation<PrefixRecord, RepoDataRecord>> {
        let archive =
            get_test_data_dir().join("link-scripts/link-scripts-0.1.0-h4616a5c_0.conda");
        vec![TransactionOperation::Install(get_repodata_record(archive))]
    }

    /// Wraps `operations` in a transaction for the current platform, with no
    /// python info and no unchanged packages.
    fn transaction_for(
        operations: Vec<TransactionOperation<PrefixRecord, RepoDataRecord>>,
    ) -> transaction::Transaction<PrefixRecord, RepoDataRecord> {
        transaction::Transaction {
            operations,
            python_info: None,
            current_python_info: None,
            platform: Platform::current(),
            unchanged: Vec::new(),
        }
    }

    #[tokio::test]
    async fn test_run_link_scripts() {
        // Keep the TempDir guards alive for the whole test.
        let prefix_dir = tempfile::tempdir().unwrap();
        let target_prefix = Prefix::create(prefix_dir.path()).unwrap();
        let packages_dir = tempfile::tempdir().unwrap();
        let cache = PackageCache::new(packages_dir.path());
        let driver = InstallDriver::builder().execute_link_scripts(true).finish();
        let marker = target_prefix.path().join(POST_LINK_MARKER);

        // Installing the package must run its post-link script.
        execute_transaction(
            transaction_for(test_operations()),
            &target_prefix,
            &LazyClient::default(),
            &cache,
            &driver,
            &InstallOptions::default(),
        )
        .await;
        assert!(marker.exists());

        // Removing it again must leave the marker gone.
        let installed: Vec<PrefixRecord> =
            PrefixRecord::collect_from_prefix(&target_prefix).unwrap();
        let removal = vec![TransactionOperation::Remove(installed[0].clone())];
        execute_transaction(
            transaction_for(removal),
            &target_prefix,
            &LazyClient::default(),
            &cache,
            &driver,
            &InstallOptions::default(),
        )
        .await;
        assert!(!marker.exists());
    }
}