use std::{collections::HashMap, env, sync::Arc, time::Instant};
use indicatif::{ProgressBar, ProgressStyle};
use itertools::Itertools;
use miette::{Context, IntoDiagnostic};
use rattler_conda_types::{Channel, ChannelConfig, MatchSpec, ParseMatchSpecOptions, Platform};
use rattler_networking::AuthenticationMiddleware;
#[cfg(feature = "s3")]
use rattler_networking::AuthenticationStorage;
use rattler_repodata_gateway::{Gateway, RepoData, SourceConfig};
use reqwest::Client;
#[derive(Debug, clap::Parser)]
#[clap(after_help = r#"Examples:
rattler search 'python*' # glob pattern
rattler search '^numpy-.*$' # regex pattern
rattler search openssl -c bioconda # search in specific channel"#)]
pub struct Opt {
#[clap(required = true)]
matchspec: String,
#[clap(short, long, default_value = "conda-forge")]
channels: Vec<String>,
#[clap(short, long, default_value_t = Platform::current())]
platform: Platform,
#[clap(long, default_value = "3")]
limit_packages: usize,
#[clap(long, default_value = "5")]
limit: usize,
#[clap(long)]
all: bool,
#[clap(long, default_value = "true", action = clap::ArgAction::Set)]
sharded: bool,
}
pub async fn search(opt: Opt) -> miette::Result<()> {
let channel_config =
ChannelConfig::default_with_root_dir(env::current_dir().into_diagnostic()?);
println!("Searching for '{}' on {}", opt.matchspec, opt.platform);
let matchspec = MatchSpec::from_str(
&opt.matchspec,
ParseMatchSpecOptions::strict()
.with_exact_names_only(false)
.with_experimental_extras(true),
)
.into_diagnostic()
.context("failed to parse pattern as matchspec")?;
let channels = opt
.channels
.into_iter()
.map(|channel_str| Channel::from_str(channel_str, &channel_config))
.collect::<Result<Vec<_>, _>>()
.into_diagnostic()?;
println!(
"Channels: {}",
channels.iter().map(Channel::canonical_name).join(", ")
);
let download_client = Client::builder()
.no_gzip()
.build()
.expect("failed to create client");
let download_client = reqwest_middleware::ClientBuilder::new(download_client.clone())
.with_arc(Arc::new(
AuthenticationMiddleware::from_env_and_defaults().into_diagnostic()?,
))
.with(rattler_networking::OciMiddleware::new(download_client));
#[cfg(feature = "s3")]
let download_client = download_client.with(rattler_networking::S3Middleware::new(
HashMap::new(),
AuthenticationStorage::from_env_and_defaults().into_diagnostic()?,
));
#[cfg(feature = "gcs")]
let download_client = download_client.with(rattler_networking::GCSMiddleware::default());
let download_client = download_client.build();
let gateway = Gateway::builder()
.with_client(download_client)
.with_channel_config(rattler_repodata_gateway::ChannelConfig {
default: SourceConfig {
sharded_enabled: opt.sharded,
..SourceConfig::default()
},
per_channel: HashMap::new(),
})
.finish();
let pb = ProgressBar::new_spinner();
pb.enable_steady_tick(std::time::Duration::from_millis(100));
pb.set_style(ProgressStyle::with_template("{spinner:.green} {msg}").unwrap());
pb.set_message("Loading repodata...");
let start = Instant::now();
let repo_data = gateway
.query(
channels,
[opt.platform, Platform::NoArch],
vec![matchspec.clone()],
)
.recursive(false) .await
.into_diagnostic()
.context("failed to query repodata")?;
pb.finish_and_clear();
let total_records: usize = repo_data.iter().map(RepoData::len).sum();
println!(
"Found {} matching records in {:?}\n",
total_records,
start.elapsed()
);
if total_records == 0 {
println!("No packages found matching '{}'", opt.matchspec);
return Ok(());
}
let mut packages: HashMap<String, Vec<_>> = HashMap::new();
for record in repo_data.iter().flat_map(RepoData::iter) {
packages
.entry(record.package_record.name.as_normalized().to_string())
.or_default()
.push(record);
}
let mut package_names: Vec<_> = packages.keys().cloned().collect();
package_names.sort();
let limit_versions = if opt.all { usize::MAX } else { opt.limit };
let limit_packages = if opt.all {
usize::MAX
} else {
opt.limit_packages
};
let total_packages = package_names.len();
let shown_packages = total_packages.min(limit_packages);
for name in package_names.into_iter().take(limit_packages) {
let mut records = packages.remove(&name).unwrap();
records.sort_unstable();
records.reverse();
let total = records.len();
let shown = records.len().min(limit_versions);
println!(
"{} ({} version{})",
console::style(&name).bold().green(),
total,
if total == 1 { "" } else { "s" }
);
for record in records.iter().take(limit_versions) {
let channel = record.channel.as_deref().unwrap_or("unknown");
println!(
" {} {} [{}] {}",
console::style(&record.package_record.version).cyan(),
record.package_record.build,
record.package_record.subdir,
console::style(channel).dim()
);
}
if shown < total {
println!(
" {} ... and {} more version{}",
console::style("").dim(),
total - shown,
if total - shown == 1 { "" } else { "s" }
);
}
println!();
}
if shown_packages < total_packages {
println!(
"... and {} more package{}",
total_packages - shown_packages,
if total_packages - shown_packages == 1 {
""
} else {
"s"
}
);
}
Ok(())
}