use crate::error::{Result, RypeError};
/// Parses a human-readable byte size such as `"4G"`, `"512MB"`, `"1.5g"` or `"1024"`.
///
/// Suffixes are case-insensitive binary multiples (`K` = 1024, `M` = 1024², `G` = 1024³,
/// `T` = 1024⁴), an optional trailing `B` is accepted, and whitespace around the value and
/// between number and suffix is ignored. The special value `"auto"` (any case) yields
/// `Ok(None)`. Fractional values are scaled and then rounded to the nearest byte.
///
/// # Errors
///
/// Returns a validation error when the value is negative, has no numeric part, is not a
/// valid number, uses an unknown suffix, or does not fit in `usize` after scaling/rounding.
pub fn parse_byte_suffix(s: &str) -> Result<Option<usize>> {
    let s = s.trim();
    if s.eq_ignore_ascii_case("auto") {
        return Ok(None);
    }
    // The numeric scan below accepts only digits and '.', so without this check a leading
    // '-' would be reported as "no numeric value" instead of as a negative size.
    if s.starts_with('-') {
        return Err(RypeError::validation(format!(
            "Byte size cannot be negative: '{}'",
            s
        )));
    }
    let numeric_end = s
        .find(|c: char| !c.is_ascii_digit() && c != '.')
        .unwrap_or(s.len());
    if numeric_end == 0 {
        return Err(RypeError::validation(format!(
            "Invalid byte size: '{}' (no numeric value)",
            s
        )));
    }
    let (numeric_part, suffix_part) = s.split_at(numeric_end);
    let suffix_part = suffix_part.trim();
    let value: f64 = numeric_part
        .parse()
        .map_err(|_| RypeError::validation(format!("Invalid numeric value: '{}'", numeric_part)))?;
    let multiplier: u64 = match suffix_part.to_ascii_uppercase().as_str() {
        "" | "B" => 1,
        "K" | "KB" => 1024,
        "M" | "MB" => 1024 * 1024,
        "G" | "GB" => 1024 * 1024 * 1024,
        "T" | "TB" => 1024 * 1024 * 1024 * 1024,
        _ => {
            return Err(RypeError::validation(format!(
                "Unknown byte suffix: '{}' (use B, K, M, G, or T)",
                suffix_part
            )))
        }
    };
    // Round first so the range check covers exactly the value that gets converted.
    // `usize::MAX as f64` rounds UP to 2^64 on 64-bit targets, so `> usize::MAX as f64` would
    // accept 2^64 and the `as usize` cast would silently saturate; compare against 2^BITS.
    let bytes = (value * multiplier as f64).round();
    let limit = 2f64.powi(usize::BITS as i32);
    if !bytes.is_finite() || bytes >= limit {
        return Err(RypeError::validation(format!(
            "Byte size overflow: '{}' exceeds maximum representable value",
            s
        )));
    }
    Ok(Some(bytes as usize))
}
/// Which probe produced the figure returned by `detect_available_memory`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemorySource {
    /// A cgroups v2 `memory.max` limit.
    CgroupsV2,
    /// A cgroups v1 `memory.limit_in_bytes` limit.
    CgroupsV1,
    /// The `MinMemoryNode` of the current Slurm job (via `scontrol`).
    Slurm,
    /// `MemAvailable` from `/proc/meminfo`.
    ProcMeminfo,
    /// `hw.memsize` from `sysctl`.
    // Only constructed on macOS, so silence the dead-code lint on other targets only.
    #[cfg_attr(not(target_os = "macos"), allow(dead_code))]
    MacOsSysctl,
    /// No probe succeeded; `FALLBACK_MEMORY_BYTES` was used.
    Fallback,
}

/// A detected memory budget together with where it came from.
#[derive(Debug, Clone)]
pub struct AvailableMemory {
    /// Budget in bytes.
    pub bytes: usize,
    /// Probe that produced `bytes`.
    pub source: MemorySource,
}

/// Budget assumed when no platform probe succeeds: 8 GiB.
pub const FALLBACK_MEMORY_BYTES: usize = 8 * 1024 * 1024 * 1024;
pub fn detect_available_memory() -> AvailableMemory {
#[cfg(target_os = "linux")]
{
if let Some(bytes) = read_cgroups_v2_limit() {
return AvailableMemory {
bytes,
source: MemorySource::CgroupsV2,
};
}
if let Some(bytes) = read_cgroups_v1_limit() {
return AvailableMemory {
bytes,
source: MemorySource::CgroupsV1,
};
}
if let Some(bytes) = read_slurm_job_memory() {
return AvailableMemory {
bytes,
source: MemorySource::Slurm,
};
}
if let Some(bytes) = read_proc_meminfo_available() {
return AvailableMemory {
bytes,
source: MemorySource::ProcMeminfo,
};
}
}
#[cfg(target_os = "macos")]
{
if let Some(bytes) = read_macos_memsize() {
return AvailableMemory {
bytes,
source: MemorySource::MacOsSysctl,
};
}
}
AvailableMemory {
bytes: FALLBACK_MEMORY_BYTES,
source: MemorySource::Fallback,
}
}
/// Reads the cgroups v2 memory limit that applies to this process, if any.
///
/// Uses the unified-hierarchy entry (`0::<path>`) of `/proc/self/cgroup` and the hierarchy
/// mounted at `/sys/fs/cgroup`. Returns `None` when there is no v2 entry or no level of the
/// hierarchy has a numeric `memory.max` (`max` means unlimited).
#[cfg(target_os = "linux")]
fn read_cgroups_v2_limit() -> Option<usize> {
    let cgroup_content = std::fs::read_to_string("/proc/self/cgroup").ok()?;
    cgroup_v2_limit_under(std::path::Path::new("/sys/fs/cgroup"), &cgroup_content)
}

/// Computes the effective cgroups v2 memory limit for the hierarchy mounted at `mount`,
/// given the text of `/proc/self/cgroup`.
///
/// A limit on any ancestor also caps this process, so every level from the process's cgroup
/// up to and including `mount` is read and the smallest numeric `memory.max` is returned.
/// The mount root is included because a process in a container with its own cgroup
/// namespace is reported as `0::/`, and the container's limit is then the root's
/// `memory.max`. On a host the root cgroup has no `memory.max`, so that read simply fails.
#[cfg(target_os = "linux")]
fn cgroup_v2_limit_under(mount: &std::path::Path, proc_self_cgroup: &str) -> Option<usize> {
    use std::path::{Component, Path};

    let relative = proc_self_cgroup
        .lines()
        .find_map(|line| line.strip_prefix("0::"))?;
    let relative = Path::new(relative);
    // A `..` component means the cgroup lies outside the visible hierarchy; it cannot be
    // resolved under `mount`, so give up rather than read files outside it.
    if relative.components().any(|c| c == Component::ParentDir) {
        return None;
    }
    let mut dir = mount.to_path_buf();
    dir.extend(relative.components().filter_map(|c| match c {
        Component::Normal(part) => Some(part),
        _ => None,
    }));

    let mut limit: Option<usize> = None;
    loop {
        let level_limit = std::fs::read_to_string(dir.join("memory.max"))
            .ok()
            .and_then(|text| text.trim().parse::<usize>().ok());
        if let Some(value) = level_limit {
            limit = Some(limit.map_or(value, |current| current.min(value)));
        }
        if dir.as_path() == mount || !dir.pop() {
            break;
        }
    }
    limit
}
/// Reads the cgroups v1 memory-controller limit for this process, if any.
///
/// Looks up the memory controller's path in `/proc/self/cgroup` and reads
/// `memory.limit_in_bytes` under `/sys/fs/cgroup/memory`. Values above 1 TiB are treated as
/// "no limit" and yield `None`.
#[cfg(target_os = "linux")]
fn read_cgroups_v1_limit() -> Option<usize> {
    const MOUNT: &str = "/sys/fs/cgroup/memory";
    // Unlimited v1 cgroups report a huge sentinel (typically close to i64::MAX).
    // NOTE(review): this cutoff also discards genuine limits above 1 TiB.
    const ONE_TB: usize = 1024 * 1024 * 1024 * 1024;

    let cgroup_content = std::fs::read_to_string("/proc/self/cgroup").ok()?;
    let relative = cgroup_v1_memory_path(&cgroup_content)?;
    // Inside a container the reported path is usually the host-side path, which does not
    // exist under the container's own mount; the container's cgroup is then the mount root.
    let content = std::fs::read_to_string(format!("{}{}/memory.limit_in_bytes", MOUNT, relative))
        .or_else(|_| std::fs::read_to_string(format!("{}/memory.limit_in_bytes", MOUNT)))
        .ok()?;
    let value: usize = content.trim().parse().ok()?;
    (value <= ONE_TB).then_some(value)
}

/// Returns the cgroup path of the v1 memory controller from `/proc/self/cgroup` text.
///
/// Lines have the form `id:controllers:path`; the controller list may be comma-separated
/// (e.g. `memory,hugetlb`), so it is matched per entry rather than as a whole.
#[cfg(target_os = "linux")]
fn cgroup_v1_memory_path(proc_self_cgroup: &str) -> Option<&str> {
    proc_self_cgroup.lines().find_map(|line| {
        let mut fields = line.splitn(3, ':');
        let _hierarchy_id = fields.next()?;
        let controllers = fields.next()?;
        let path = fields.next()?;
        controllers.split(',').any(|c| c == "memory").then_some(path)
    })
}
/// Reads the per-node memory of the current Slurm job via `scontrol show job $SLURM_JOB_ID`.
///
/// Returns `None` outside a Slurm job (no `SLURM_JOB_ID`), when `scontrol` cannot be run or
/// fails, or when no usable `MinMemoryNode` value is reported.
#[cfg(target_os = "linux")]
fn read_slurm_job_memory() -> Option<usize> {
    use std::process::Command;
    let job_id = std::env::var("SLURM_JOB_ID").ok()?;
    let output = Command::new("scontrol")
        .args(["show", "job", &job_id])
        .output()
        .ok()?;
    if !output.status.success() {
        return None;
    }
    slurm_min_memory_node(&String::from_utf8_lossy(&output.stdout))
}

/// Extracts the first `MinMemoryNode=` value from `scontrol show job` output, in bytes.
///
/// A zero value is rejected: zero cannot be a real memory budget, and returning `None` lets
/// the caller fall through to other probes. NOTE(review): Slurm reports 0 for e.g.
/// `--mem=0` ("all node memory") — confirm for the deployed Slurm version.
#[cfg(target_os = "linux")]
fn slurm_min_memory_node(scontrol_output: &str) -> Option<usize> {
    let value = scontrol_output
        .split_whitespace()
        .find_map(|field| field.strip_prefix("MinMemoryNode="))?;
    parse_slurm_memory(value).filter(|&bytes| bytes > 0)
}
/// Converts a Slurm memory string (`"4000M"`, `"250G"`, `"512"`) to bytes.
///
/// The leading run of ASCII digits is the amount; the remainder is a case-insensitive unit
/// among `K`, `M`, `G`, `T`, where a missing unit means megabytes. Returns `None` for empty
/// input, missing digits, an unknown unit, or overflow.
#[cfg(target_os = "linux")]
fn parse_slurm_memory(s: &str) -> Option<usize> {
    let text = s.trim();
    let digits_len = text
        .char_indices()
        .find(|&(_, ch)| !ch.is_ascii_digit())
        .map_or(text.len(), |(idx, _)| idx);
    // Covers both the empty string and a value that does not start with a digit.
    if digits_len == 0 {
        return None;
    }
    let (digits, unit) = text.split_at(digits_len);
    let amount: usize = digits.parse().ok()?;
    let shift: u32 = match unit.to_ascii_uppercase().as_str() {
        "K" => 10,
        "" | "M" => 20,
        "G" => 30,
        "T" => 40,
        _ => return None,
    };
    amount.checked_mul(1usize << shift)
}
/// Reads `MemAvailable` from `/proc/meminfo`, in bytes.
///
/// Returns `None` if the file cannot be read or the field is missing or malformed.
#[cfg(target_os = "linux")]
fn read_proc_meminfo_available() -> Option<usize> {
    let content = std::fs::read_to_string("/proc/meminfo").ok()?;
    parse_meminfo_available(&content)
}

/// Extracts `MemAvailable` (reported in KiB) from `/proc/meminfo` text and converts it to
/// bytes.
///
/// The conversion is overflow-checked: an unchecked `* 1024` would panic in debug builds
/// and wrap in release builds on absurd input.
#[cfg(target_os = "linux")]
fn parse_meminfo_available(meminfo: &str) -> Option<usize> {
    let kib: usize = meminfo
        .lines()
        .find_map(|line| line.strip_prefix("MemAvailable:"))?
        .split_whitespace()
        .next()?
        .parse()
        .ok()?;
    kib.checked_mul(1024)
}
/// Reads `hw.memsize` via `sysctl -n hw.memsize`, in bytes.
///
/// Returns `None` if the command cannot be run, exits unsuccessfully, or prints something
/// that does not parse as an integer.
#[cfg(target_os = "macos")]
fn read_macos_memsize() -> Option<usize> {
    use std::process::Command;
    let out = Command::new("sysctl")
        .args(["-n", "hw.memsize"])
        .output()
        .ok()?;
    if !out.status.success() {
        return None;
    }
    let text = String::from_utf8_lossy(&out.stdout);
    text.trim().parse().ok()
}
/// Average read/query lengths and the derived minimizer count used to size batches.
#[derive(Debug, Clone)]
pub struct ReadMemoryProfile {
    /// Average length of one read (one mate, for paired-end input).
    pub avg_read_length: usize,
    /// Average length of a query: one read, or both mates combined for paired-end input.
    pub avg_query_length: usize,
    /// Estimated minimizers extracted per query.
    pub minimizers_per_query: usize,
    /// Whether the input is paired-end.
    pub is_paired: bool,
}

impl ReadMemoryProfile {
    /// Builds a profile from an average per-read length.
    ///
    /// For paired-end input the query length is twice the read length. `k` is the k-mer
    /// length and `w` the minimizer window used to estimate `minimizers_per_query`.
    pub fn new(avg_read_length: usize, is_paired: bool, k: usize, w: usize) -> Self {
        let avg_query_length = if is_paired {
            avg_read_length * 2
        } else {
            avg_read_length
        };
        ReadMemoryProfile {
            avg_read_length,
            avg_query_length,
            minimizers_per_query: Self::estimate_minimizers_per_query(avg_query_length, k, w),
            is_paired,
        }
    }

    /// Profile assuming 5000-base reads, for when input files cannot be sampled.
    pub fn default_profile(is_paired: bool, k: usize, w: usize) -> Self {
        Self::new(5000, is_paired, k, w)
    }

    /// Builds a profile by sampling up to `sample_size` records from the input files.
    ///
    /// For Parquet input only `r1_path` is read and paired-ness is detected from the file;
    /// `r2_path` is ignored. For FASTX input, `r2_path` (if given) is the second-mate file.
    /// When `trim_to` is `Some(n)` with `n > 0`, read lengths are capped at `n`.
    ///
    /// Returns `None` if a file cannot be read or contains no usable records.
    pub fn from_files(
        r1_path: &std::path::Path,
        r2_path: Option<&std::path::Path>,
        sample_size: usize,
        k: usize,
        w: usize,
        is_parquet: bool,
        trim_to: Option<usize>,
    ) -> Option<Self> {
        let (avg_read_length, avg_query_length, is_paired) = if is_parquet {
            let (total_length, count, is_paired) = sample_parquet_lengths(r1_path, sample_size)?;
            if count == 0 {
                return None;
            }
            let avg_query_length = total_length / count;
            // `total_length` covers both mates of each row when paired.
            let avg_read_length = if is_paired {
                total_length / (count * 2)
            } else {
                avg_query_length
            };
            (avg_read_length, avg_query_length, is_paired)
        } else {
            let (r1_total, r1_count) = sample_fastx_lengths(r1_path, sample_size)?;
            if r1_count == 0 {
                return None;
            }
            let avg_r1_length = r1_total / r1_count;
            match r2_path {
                Some(r2) => {
                    let (r2_total, r2_count) = sample_fastx_lengths(r2, sample_size)?;
                    if r2_count == 0 {
                        return None;
                    }
                    let avg_r2_length = r2_total / r2_count;
                    let avg_read = (r1_total + r2_total) / (r1_count + r2_count);
                    (avg_read, avg_r1_length + avg_r2_length, true)
                }
                None => (avg_r1_length, avg_r1_length, false),
            }
        };
        let (avg_read_length, avg_query_length) =
            apply_trim_to_limit(avg_read_length, avg_query_length, is_paired, trim_to);
        Some(ReadMemoryProfile {
            avg_read_length,
            avg_query_length,
            minimizers_per_query: Self::estimate_minimizers_per_query(avg_query_length, k, w),
            is_paired,
        })
    }

    /// Estimates minimizers for a query of `query_length` bases.
    ///
    /// A query has `query_length - k + 1` k-mers when `query_length >= k` (so a query of
    /// exactly `k` bases still has one). Roughly one minimizer is kept per `w` k-mers, at
    /// least one, and the result is doubled. NOTE(review): the factor of 2 is presumably
    /// headroom or both strands — confirm. A zero `w` is treated as 1 rather than dividing
    /// by zero.
    fn estimate_minimizers_per_query(query_length: usize, k: usize, w: usize) -> usize {
        if query_length == 0 || query_length < k {
            return 0;
        }
        ((query_length - k + 1) / w.max(1)).max(1) * 2
    }

    /// Rough per-row size of a buffered Arrow batch.
    ///
    /// Counts a ~50-byte read id plus the query bytes, a 4-byte offset per string column,
    /// validity-bitmap bytes, and 1 extra byte. It assumes 3 string columns when paired,
    /// otherwise 2 (presumably read id, sequence1 and sequence2).
    pub fn estimate_arrow_bytes_per_row(&self, is_paired: bool) -> usize {
        let num_string_cols = if is_paired { 3 } else { 2 };
        let read_id_bytes = 50;
        let data_bytes = read_id_bytes + self.avg_query_length;
        let offset_overhead = 4 * num_string_cols;
        let validity_overhead = (num_string_cols + 7) / 8;
        let array_overhead = 1;
        data_bytes + offset_overhead + validity_overhead + array_overhead
    }

    /// Rough size of one owned record.
    ///
    /// Counts an 8-byte id plus, per sequence, a 24-byte `Vec` header and the read bytes. A
    /// second sequence is counted when `is_paired`.
    pub fn estimate_owned_record_bytes(&self, is_paired: bool) -> usize {
        let id_bytes = 8;
        let vec_overhead = 24;
        let seq1_bytes = vec_overhead + self.avg_read_length;
        let seq2_bytes = if is_paired {
            vec_overhead + self.avg_read_length
        } else {
            0
        };
        id_bytes + seq1_bytes + seq2_bytes
    }
}
/// Input file format; it determines prefetch depth and how buffered rows are sized.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InputFormat {
    /// FASTA/FASTQ input; buffered rows are sized as owned records.
    Fastx { is_paired: bool },
    /// Parquet input. With `trimmed_in_reader` buffered rows are sized as owned records,
    /// otherwise as Arrow batch rows.
    Parquet {
        is_paired: bool,
        trimmed_in_reader: bool,
    },
}

impl InputFormat {
    /// Number of batches the reader prefetches for this format.
    pub fn prefetch_slots(&self) -> usize {
        match *self {
            Self::Fastx { .. } => FASTX_PREFETCH_BUFFER_SLOTS,
            Self::Parquet { .. } => PARQUET_PREFETCH_BUFFER_SLOTS,
        }
    }

    /// Estimated bytes per buffered row under `profile`, using the Arrow row estimate for
    /// untrimmed Parquet and the owned-record estimate otherwise.
    pub fn estimate_buffer_bytes_per_row(&self, profile: &ReadMemoryProfile) -> usize {
        let paired = self.is_paired();
        if self.buffers_arrow_rows() {
            profile.estimate_arrow_bytes_per_row(paired)
        } else {
            profile.estimate_owned_record_bytes(paired)
        }
    }

    /// Whether the input is paired-end.
    pub fn is_paired(&self) -> bool {
        match *self {
            Self::Fastx { is_paired } | Self::Parquet { is_paired, .. } => is_paired,
        }
    }

    /// True for Parquet input that is not trimmed in the reader.
    fn buffers_arrow_rows(&self) -> bool {
        matches!(
            self,
            Self::Parquet {
                trimmed_in_reader: false,
                ..
            }
        )
    }
}
/// Caps average lengths at `trim_to`, when it is `Some` and non-zero.
///
/// The read length becomes `min(avg_read_length, limit)`, and the query length is rebuilt
/// from it (doubled for paired input). When a limit applies, any difference between the
/// incoming query length and 2× read length (rounding, unequal mates) is therefore discarded.
/// Without a limit both values are returned unchanged. Returns `(read_length, query_length)`.
fn apply_trim_to_limit(
    avg_read_length: usize,
    avg_query_length: usize,
    is_paired: bool,
    trim_to: Option<usize>,
) -> (usize, usize) {
    let limit = match trim_to.filter(|&limit| limit > 0) {
        Some(limit) => limit,
        None => return (avg_read_length, avg_query_length),
    };
    let read_length = avg_read_length.min(limit);
    let mates = if is_paired { 2 } else { 1 };
    (read_length, read_length * mates)
}
/// Sums sequence lengths over the first `sample_size` records of a FASTA/FASTQ file.
///
/// Returns `(total_bases, records_read)`, or `None` if the file cannot be opened or a
/// sampled record fails to parse. If the file has any records, at least one is read even
/// when `sample_size` is 0.
fn sample_fastx_lengths(path: &std::path::Path, sample_size: usize) -> Option<(usize, usize)> {
    use needletail::parse_fastx_file;

    let mut reader = parse_fastx_file(path).ok()?;
    let mut bases = 0usize;
    let mut records = 0usize;
    while let Some(next) = reader.next() {
        bases += next.ok()?.seq().len();
        records += 1;
        // Checked after counting, so sampling never stops before the first record.
        if records >= sample_size {
            break;
        }
    }
    Some((bases, records))
}
/// Samples sequence lengths from a Parquet file with `sequence1` (and optional `sequence2`)
/// string columns.
///
/// Counts up to `sample_size` rows whose `sequence1` is a non-null `Utf8`/`LargeUtf8` value.
/// It adds their lengths, plus the non-null `sequence2` lengths when the input is paired.
/// Paired-ness is decided once, from the first batch: paired if it has a `sequence2` column
/// with at least one non-null value.
///
/// Returns `(total_bases, rows_counted, is_paired)`. Returns `None` if the file cannot be
/// opened or read, has no `sequence1` column, or no row was counted.
fn sample_parquet_lengths(
    path: &std::path::Path,
    sample_size: usize,
) -> Option<(usize, usize, bool)> {
    use arrow::array::{Array, LargeStringArray, StringArray};
    use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
    use std::fs::File;

    /// Byte length of the string at `row`; `None` if it is null or the column is neither
    /// `Utf8` nor `LargeUtf8`.
    fn string_len_at(column: &dyn Array, row: usize) -> Option<usize> {
        let any = column.as_any();
        if let Some(strings) = any.downcast_ref::<StringArray>() {
            return (!strings.is_null(row)).then(|| strings.value(row).len());
        }
        if let Some(strings) = any.downcast_ref::<LargeStringArray>() {
            return (!strings.is_null(row)).then(|| strings.value(row).len());
        }
        None
    }

    let file = File::open(path).ok()?;
    let builder = ParquetRecordBatchReaderBuilder::try_new(file).ok()?;
    let seq1_idx = builder
        .schema()
        .fields()
        .iter()
        .position(|f| f.name() == "sequence1")?;
    let seq2_idx = builder
        .schema()
        .fields()
        .iter()
        .position(|f| f.name() == "sequence2");
    // Decode only the sequence columns.
    let roots: Vec<usize> = std::iter::once(seq1_idx).chain(seq2_idx).collect();
    let mask = parquet::arrow::ProjectionMask::roots(builder.parquet_schema(), roots);
    let reader = builder.with_projection(mask).build().ok()?;

    let mut total_length = 0usize;
    let mut rows_counted = 0usize;
    // Decided from the first batch and then kept for the rest of the scan.
    let mut paired: Option<bool> = None;

    for batch in reader {
        let batch = batch.ok()?;
        let seq1 = batch.column_by_name("sequence1")?;
        let seq2 = batch.column_by_name("sequence2");
        let is_paired =
            *paired.get_or_insert_with(|| seq2.map_or(false, |col| col.null_count() < col.len()));

        for row in 0..batch.num_rows() {
            if rows_counted >= sample_size {
                break;
            }
            let len1 = match string_len_at(seq1.as_ref(), row) {
                Some(len) => len,
                None => continue,
            };
            total_length += len1;
            rows_counted += 1;
            if is_paired {
                if let Some(len2) = seq2.and_then(|col| string_len_at(col.as_ref(), row)) {
                    total_length += len2;
                }
            }
        }
        if rows_counted >= sample_size {
            break;
        }
    }

    (rows_counted > 0).then(|| (total_length, rows_counted, paired.unwrap_or(false)))
}
/// Batches prefetched by the FASTX reader.
pub const FASTX_PREFETCH_BUFFER_SLOTS: usize = 2;
/// Batches prefetched by the Parquet reader.
pub const PARQUET_PREFETCH_BUFFER_SLOTS: usize = 4;
/// Default prefetch depth; an alias for the Parquet setting.
pub const DEFAULT_PREFETCH_BUFFER_SLOTS: usize = PARQUET_PREFETCH_BUFFER_SLOTS;
/// Inputs to the batch-size calculation.
#[derive(Debug, Clone)]
pub struct MemoryConfig {
    /// Total memory budget in bytes.
    pub max_memory: usize,
    /// Worker thread count (validated to be non-zero; not read by `calculate_batch_config`).
    pub num_threads: usize,
    /// Bytes reserved for the loaded index.
    pub index_memory: usize,
    /// Bytes reserved for shard loading.
    pub shard_reservation: usize,
    /// Read-length profile used for per-row estimates.
    pub read_profile: ReadMemoryProfile,
    /// Number of buckets; the estimate assumes each read touches at most `min(4, num_buckets)`.
    pub num_buckets: usize,
    /// Input format, which selects prefetch depth and buffered-row sizing.
    pub input_format: InputFormat,
    /// Whether log-ratio mode is active, which adds deferred per-read state to the estimate.
    pub is_log_ratio: bool,
}

impl MemoryConfig {
    /// Creates a validated configuration.
    ///
    /// # Errors
    ///
    /// Returns a validation error if `max_memory`, `num_threads` or `num_buckets` is zero,
    /// checked in that order.
    // A flat constructor mirrors the struct's fields one-to-one, hence the many arguments.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        max_memory: usize,
        num_threads: usize,
        index_memory: usize,
        shard_reservation: usize,
        read_profile: ReadMemoryProfile,
        num_buckets: usize,
        input_format: InputFormat,
        is_log_ratio: bool,
    ) -> Result<Self> {
        let must_be_nonzero = [
            (max_memory, "max_memory must be > 0"),
            (num_threads, "num_threads must be > 0"),
            (num_buckets, "num_buckets must be > 0"),
        ];
        if let Some(&(_, message)) = must_be_nonzero.iter().find(|(value, _)| *value == 0) {
            return Err(RypeError::validation(message));
        }
        Ok(Self {
            max_memory,
            num_threads,
            index_memory,
            shard_reservation,
            read_profile,
            num_buckets,
            input_format,
            is_log_ratio,
        })
    }

    /// Number of batches the reader prefetches for the configured input format.
    pub fn prefetch_buffer_slots(&self) -> usize {
        let format = &self.input_format;
        format.prefetch_slots()
    }

    /// Estimated bytes per buffered row for the configured input format and read profile.
    pub fn buffer_bytes_per_row(&self) -> usize {
        let profile = &self.read_profile;
        self.input_format.estimate_buffer_bytes_per_row(profile)
    }
}
/// Batch sizing chosen by `calculate_batch_config`.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Queries per batch.
    pub batch_size: usize,
    /// Number of batches the memory estimate assumes are resident at once.
    pub batch_count: usize,
    /// Estimated working memory of one batch, excluding prefetch I/O buffers.
    pub per_batch_memory: usize,
    /// Estimated peak: reserved memory plus `per_batch_memory` plus prefetch I/O buffers.
    pub peak_memory: usize,
}

/// Smallest batch size the calculator will return.
pub const MIN_BATCH_SIZE: usize = 1_000;
/// Upper bound of the batch-size search.
pub const MAX_BATCH_SIZE: usize = 100_000_000;
/// Fraction (despite the name, not a percentage) of `max_memory` held back as a safety margin.
const SAFETY_MARGIN_PERCENT: f64 = 0.10;
/// Lower bound on the safety margin: 256 MiB.
const SAFETY_MARGIN_MIN_BYTES: usize = 256 << 20;
/// Multiplier on the itemized per-batch estimate, as headroom for unitemized allocations.
const MEMORY_FUDGE_FACTOR: f64 = 1.8;
/// Extra multiplier on prefetch buffers for untrimmed Parquet input.
const ARROW_BUILDER_OVERHEAD: f64 = 1.5;
/// Estimates the working memory needed to process one batch of `batch_size` queries.
///
/// The per-read byte counts are fixed heuristics. The estimate sums:
/// - input records (72 bytes + query length);
/// - minimizer vectors (16 bytes per minimizer);
/// - query index entries (12 bytes per minimizer);
/// - bucket accumulators (24 bytes × at most 4 buckets);
/// - in log-ratio mode, deferred per-read state (108 bytes + 20 bytes per minimizer).
///
/// The sum is then scaled by `MEMORY_FUDGE_FACTOR`.
///
/// Returns `None` if any intermediate value or the final scaled value overflows `usize`.
pub fn estimate_batch_memory(
    batch_size: usize,
    profile: &ReadMemoryProfile,
    num_buckets: usize,
    is_log_ratio: bool,
) -> Option<usize> {
    let record_overhead: usize = 72;
    let input_records =
        batch_size.checked_mul(record_overhead.checked_add(profile.avg_query_length)?)?;
    let minimizer_vecs = batch_size
        .checked_mul(profile.minimizers_per_query)?
        .checked_mul(16)?;
    let query_index = batch_size
        .checked_mul(profile.minimizers_per_query)?
        .checked_mul(12)?;
    let estimated_buckets_per_read = 4.min(num_buckets);
    let accumulators = batch_size
        .checked_mul(estimated_buckets_per_read)?
        .checked_mul(24)?;
    let mut base_estimate = input_records
        .checked_add(minimizer_vecs)?
        .checked_add(query_index)?
        .checked_add(accumulators)?;
    if is_log_ratio {
        // Every read in the batch may be deferred, so the full batch is counted.
        let deferred_reads = batch_size;
        let meta_bytes: usize = 48;
        let estimated_header_bytes: usize = 60;
        let per_read_overhead = meta_bytes.checked_add(estimated_header_bytes)?;
        let minimizer_cost = profile.minimizers_per_query.checked_mul(12)?;
        let query_mins_cost = profile.minimizers_per_query.checked_mul(8)?;
        let deferred_memory = deferred_reads.checked_mul(
            per_read_overhead
                .checked_add(minimizer_cost)?
                .checked_add(query_mins_cost)?,
        )?;
        base_estimate = base_estimate.checked_add(deferred_memory)?;
    }
    // `f64 as usize` saturates instead of failing, so check the scaled value explicitly to
    // keep the "None on overflow" contract for this last step as well.
    let scaled = (base_estimate as f64 * MEMORY_FUDGE_FACTOR).round();
    let usize_limit = 2f64.powi(usize::BITS as i32);
    (scaled < usize_limit).then(|| scaled as usize)
}
/// Estimates memory held by the reader's prefetch buffers for batches of `batch_size` rows.
///
/// Computed as `prefetch slots × batch_size × bytes per row`. For untrimmed Parquet input it
/// is further scaled by `ARROW_BUILDER_OVERHEAD`. Returns `None` on overflow, including when
/// the scaled value no longer fits in `usize`.
fn estimate_io_buffer_memory(batch_size: usize, config: &MemoryConfig) -> Option<usize> {
    let prefetch_slots = config.prefetch_buffer_slots();
    let bytes_per_row = config.buffer_bytes_per_row();
    let base_memory = prefetch_slots
        .checked_mul(batch_size)?
        .checked_mul(bytes_per_row)?;
    match config.input_format {
        InputFormat::Parquet {
            trimmed_in_reader: false,
            ..
        } => {
            let scaled = (base_memory as f64 * ARROW_BUILDER_OVERHEAD).round();
            // `as usize` saturates rather than failing; report overflow as `None` instead.
            (scaled < 2f64.powi(usize::BITS as i32)).then(|| scaled as usize)
        }
        InputFormat::Parquet {
            trimmed_in_reader: true,
            ..
        }
        | InputFormat::Fastx { .. } => Some(base_memory),
    }
}
fn estimate_total_batch_memory(
batch_size: usize,
batch_count: usize,
config: &MemoryConfig,
) -> Option<usize> {
let per_batch = estimate_batch_memory(
batch_size,
&config.read_profile,
config.num_buckets,
config.is_log_ratio,
)?;
let io_buffers = estimate_io_buffer_memory(batch_size, config)?;
per_batch.checked_mul(batch_count)?.checked_add(io_buffers)
}
/// Chooses the largest batch size whose estimated memory fits within `config.max_memory`.
///
/// First reserved: the index, the shard reservation, and a safety margin of
/// `SAFETY_MARGIN_PERCENT` of `max_memory` (at least `SAFETY_MARGIN_MIN_BYTES`). What remains
/// is the budget for one batch plus prefetch I/O buffers, searched between `MIN_BATCH_SIZE`
/// and `MAX_BATCH_SIZE`.
///
/// If even `MIN_BATCH_SIZE` does not fit, a `MIN_BATCH_SIZE` configuration is still
/// returned, so `peak_memory` can exceed `max_memory` in that case. `batch_count` is always 1.
pub fn calculate_batch_config(config: &MemoryConfig) -> BatchConfig {
    const BATCH_COUNT: usize = 1;

    let proportional_margin = (config.max_memory as f64 * SAFETY_MARGIN_PERCENT).round() as usize;
    let safety_margin = proportional_margin.max(SAFETY_MARGIN_MIN_BYTES);
    let base_reserved = config
        .index_memory
        .saturating_add(config.shard_reservation)
        .saturating_add(safety_margin);
    let available = config.max_memory.saturating_sub(base_reserved);

    let fits = |size: usize| {
        estimate_total_batch_memory(size, BATCH_COUNT, config).is_some_and(|total| total <= available)
    };
    // Builds the reported config for a given size; unknown (overflowing) parts count as MAX.
    let describe = |batch_size: usize| {
        let per_batch_memory = estimate_batch_memory(
            batch_size,
            &config.read_profile,
            config.num_buckets,
            config.is_log_ratio,
        )
        .unwrap_or(usize::MAX);
        let io_buffer_memory = estimate_io_buffer_memory(batch_size, config).unwrap_or(usize::MAX);
        BatchConfig {
            batch_size,
            batch_count: BATCH_COUNT,
            per_batch_memory,
            peak_memory: base_reserved
                .saturating_add(per_batch_memory)
                .saturating_add(io_buffer_memory),
        }
    };

    if !fits(MIN_BATCH_SIZE) {
        return describe(MIN_BATCH_SIZE);
    }
    let candidate = binary_search_batch_size_with_io(available, BATCH_COUNT, config);
    let chosen = if candidate >= MIN_BATCH_SIZE && fits(candidate) {
        candidate
    } else {
        MIN_BATCH_SIZE
    };
    describe(chosen)
}
/// Binary-searches `[MIN_BATCH_SIZE, MAX_BATCH_SIZE]` for the largest batch size whose total
/// estimated memory (compute plus I/O buffers, for `batch_count` batches) is within
/// `memory_budget`.
///
/// Assumes the estimate grows with batch size. Returns `MIN_BATCH_SIZE` when no probed size
/// fits.
fn binary_search_batch_size_with_io(
    memory_budget: usize,
    batch_count: usize,
    config: &MemoryConfig,
) -> usize {
    let (mut lo, mut hi) = (MIN_BATCH_SIZE, MAX_BATCH_SIZE);
    let mut largest_fitting = MIN_BATCH_SIZE;
    while lo <= hi {
        let probe = lo + (hi - lo) / 2;
        match estimate_total_batch_memory(probe, batch_count, config) {
            Some(total) if total <= memory_budget => {
                largest_fitting = probe;
                lo = probe + 1;
            }
            _ => match probe.checked_sub(1) {
                Some(below) => hi = below,
                None => break,
            },
        }
    }
    largest_fitting
}
/// Estimates memory to reserve for loading index shards.
///
/// Adds two parts, each with saturating arithmetic:
/// - per-thread decode buffers: `num_threads × DEFAULT_ROW_GROUP_SIZE × 12 bytes × 3`;
/// - the filtered CSR of the largest shard, estimated as `entries × 12 bytes × 0.10 × 2`.
///   NOTE(review): 0.10 presumably models the fraction of entries surviving filtering —
///   confirm.
///
/// Returns 0 when `largest_shard_entries` is 0.
pub fn estimate_shard_reservation(largest_shard_entries: u64, num_threads: usize) -> usize {
    const SHARD_SELECTIVITY_ESTIMATE: f64 = 0.10;
    const CSR_CONCAT_MULTIPLIER: f64 = 2.0;
    const DECODE_OVERHEAD_MULTIPLIER: usize = 3;
    const BYTES_PER_ENTRY: usize = 12;

    if largest_shard_entries == 0 {
        return 0;
    }
    let decode_buffers = [
        crate::constants::DEFAULT_ROW_GROUP_SIZE,
        BYTES_PER_ENTRY,
        DECODE_OVERHEAD_MULTIPLIER,
    ]
    .iter()
    .fold(num_threads, |acc, &factor| acc.saturating_mul(factor));
    let filtered_csr = {
        let bytes = largest_shard_entries as f64
            * BYTES_PER_ENTRY as f64
            * SHARD_SELECTIVITY_ESTIMATE
            * CSR_CONCAT_MULTIPLIER;
        bytes.min(usize::MAX as f64).round() as usize
    };
    decode_buffers.saturating_add(filtered_csr)
}
/// Formats a byte count with binary units, e.g. `"1.50 GB"` or `"512 B"`.
///
/// Uses the largest of TB/GB/MB/KB (powers of 1024) that is not larger than `bytes`, with
/// two decimals. Counts below 1024 are printed as whole bytes.
pub fn format_bytes(bytes: usize) -> String {
    const UNITS: [(&str, u32); 4] = [("TB", 40), ("GB", 30), ("MB", 20), ("KB", 10)];
    for &(label, shift) in UNITS.iter() {
        let unit = 1u64 << shift;
        if bytes as u64 >= unit {
            return format!("{:.2} {}", bytes as f64 / unit as f64, label);
        }
    }
    format!("{} B", bytes)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_byte_suffix_gigabytes() {
assert_eq!(
parse_byte_suffix("4G").unwrap(),
Some(4 * 1024 * 1024 * 1024)
);
assert_eq!(
parse_byte_suffix("4GB").unwrap(),
Some(4 * 1024 * 1024 * 1024)
);
assert_eq!(
parse_byte_suffix("4g").unwrap(),
Some(4 * 1024 * 1024 * 1024)
);
assert_eq!(
parse_byte_suffix("4gb").unwrap(),
Some(4 * 1024 * 1024 * 1024)
);
}
#[test]
fn test_parse_byte_suffix_megabytes() {
assert_eq!(parse_byte_suffix("512M").unwrap(), Some(512 * 1024 * 1024));
assert_eq!(parse_byte_suffix("512MB").unwrap(), Some(512 * 1024 * 1024));
assert_eq!(parse_byte_suffix("512m").unwrap(), Some(512 * 1024 * 1024));
}
#[test]
fn test_parse_byte_suffix_kilobytes() {
assert_eq!(parse_byte_suffix("1024K").unwrap(), Some(1024 * 1024));
assert_eq!(parse_byte_suffix("1024KB").unwrap(), Some(1024 * 1024));
}
#[test]
fn test_parse_byte_suffix_bytes() {
assert_eq!(parse_byte_suffix("1024").unwrap(), Some(1024));
assert_eq!(parse_byte_suffix("1024B").unwrap(), Some(1024));
}
#[test]
fn test_parse_byte_suffix_terabytes() {
assert_eq!(
parse_byte_suffix("1T").unwrap(),
Some(1024 * 1024 * 1024 * 1024)
);
assert_eq!(
parse_byte_suffix("1TB").unwrap(),
Some(1024 * 1024 * 1024 * 1024)
);
}
#[test]
fn test_parse_byte_suffix_decimal() {
assert_eq!(
parse_byte_suffix("1.5G").unwrap(),
Some((1.5 * 1024.0 * 1024.0 * 1024.0) as usize)
);
assert_eq!(
parse_byte_suffix("2.5M").unwrap(),
Some((2.5 * 1024.0 * 1024.0) as usize)
);
}
#[test]
fn test_parse_byte_suffix_auto() {
assert_eq!(parse_byte_suffix("auto").unwrap(), None);
assert_eq!(parse_byte_suffix("AUTO").unwrap(), None);
assert_eq!(parse_byte_suffix("Auto").unwrap(), None);
}
#[test]
fn test_parse_byte_suffix_whitespace() {
assert_eq!(
parse_byte_suffix(" 4G ").unwrap(),
Some(4 * 1024 * 1024 * 1024)
);
assert_eq!(
parse_byte_suffix("4 G").unwrap(),
Some(4 * 1024 * 1024 * 1024)
);
}
#[test]
fn test_parse_byte_suffix_invalid() {
assert!(parse_byte_suffix("").is_err());
assert!(parse_byte_suffix("G").is_err());
assert!(parse_byte_suffix("abc").is_err());
assert!(parse_byte_suffix("4X").is_err());
assert!(parse_byte_suffix("-4G").is_err());
}
#[test]
fn test_parse_byte_suffix_overflow() {
assert!(parse_byte_suffix("99999999999G").is_err());
assert!(parse_byte_suffix("99999999999T").is_err());
assert!(parse_byte_suffix("1e400G").is_err());
}
#[test]
fn test_detect_available_memory_returns_nonzero() {
let result = detect_available_memory();
assert!(result.bytes > 0);
}
#[test]
fn test_fallback_memory_is_8gb() {
assert_eq!(FALLBACK_MEMORY_BYTES, 8 * 1024 * 1024 * 1024);
}
#[test]
fn test_estimate_batch_memory_scales_linearly() {
let profile = ReadMemoryProfile::new(150, false, 64, 50);
let mem_1k = estimate_batch_memory(1000, &profile, 100, false).unwrap();
let mem_2k = estimate_batch_memory(2000, &profile, 100, false).unwrap();
let ratio = mem_2k as f64 / mem_1k as f64;
assert!(
ratio > 1.8 && ratio < 2.2,
"Expected ~2x scaling, got {}",
ratio
);
}
#[test]
fn test_estimate_batch_memory_increases_with_read_length() {
let profile_short = ReadMemoryProfile::new(150, false, 64, 50);
let profile_long = ReadMemoryProfile::new(10000, false, 64, 50);
let mem_short = estimate_batch_memory(10000, &profile_short, 100, false).unwrap();
let mem_long = estimate_batch_memory(10000, &profile_long, 100, false).unwrap();
assert!(mem_long > mem_short);
}
#[test]
fn test_estimate_batch_memory_overflow_protection() {
let profile = ReadMemoryProfile::new(150, false, 64, 50);
let result = estimate_batch_memory(usize::MAX, &profile, 100, false);
assert!(result.is_none(), "Should return None on overflow");
let large_profile = ReadMemoryProfile {
avg_read_length: 150,
avg_query_length: 150,
minimizers_per_query: usize::MAX / 2,
is_paired: false,
};
let result = estimate_batch_memory(1000000, &large_profile, 100, false);
assert!(result.is_none(), "Should return None on overflow");
}
#[test]
fn test_estimate_batch_memory_with_log_ratio_larger() {
let profile = ReadMemoryProfile::new(1000, false, 64, 50);
let mem_normal = estimate_batch_memory(10000, &profile, 100, false).unwrap();
let mem_log_ratio = estimate_batch_memory(10000, &profile, 100, true).unwrap();
assert!(
mem_log_ratio > mem_normal,
"Log-ratio memory {} should exceed normal memory {}",
mem_log_ratio,
mem_normal
);
}
#[test]
fn test_estimate_batch_memory_log_ratio_scales() {
let profile = ReadMemoryProfile::new(1000, false, 64, 50);
let mem_10k = estimate_batch_memory(10000, &profile, 100, true).unwrap();
let mem_20k = estimate_batch_memory(20000, &profile, 100, true).unwrap();
let ratio = mem_20k as f64 / mem_10k as f64;
assert!(
ratio > 1.8 && ratio < 2.2,
"Expected ~2x scaling for log-ratio memory, got {}",
ratio
);
}
#[test]
fn test_calculate_batch_config_shrinks_for_log_ratio() {
let profile = ReadMemoryProfile::new(1000, false, 64, 50);
let config_normal = MemoryConfig {
max_memory: 4 * 1024 * 1024 * 1024, num_threads: 4,
index_memory: 100 * 1024 * 1024,
shard_reservation: 50 * 1024 * 1024,
read_profile: profile.clone(),
num_buckets: 100,
input_format: InputFormat::Fastx { is_paired: false },
is_log_ratio: false,
};
let config_log_ratio = MemoryConfig {
is_log_ratio: true,
..config_normal.clone()
};
let batch_normal = calculate_batch_config(&config_normal);
let batch_log_ratio = calculate_batch_config(&config_log_ratio);
assert!(
batch_log_ratio.batch_size < batch_normal.batch_size,
"Log-ratio batch_size {} should be < normal batch_size {}",
batch_log_ratio.batch_size,
batch_normal.batch_size
);
}
#[test]
fn test_log_ratio_deferred_covers_full_batch() {
let profile = ReadMemoryProfile::new(1000, false, 64, 50);
let batch_size: usize = 10_000;
let num_buckets: usize = 100;
let mem_normal = estimate_batch_memory(batch_size, &profile, num_buckets, false).unwrap();
let mem_lr = estimate_batch_memory(batch_size, &profile, num_buckets, true).unwrap();
let deferred_component = mem_lr - mem_normal;
let per_read = 48 + 60 + profile.minimizers_per_query * 20;
let expected_min =
(batch_size as f64 * per_read as f64 * MEMORY_FUDGE_FACTOR * 0.95) as usize;
assert!(
deferred_component >= expected_min,
"Deferred component {} should be >= {} (full batch_size, not batch_size/2)",
deferred_component,
expected_min
);
}
#[test]
fn test_calculate_batch_config_respects_limit() {
let profile = ReadMemoryProfile::new(150, false, 64, 50);
let config = MemoryConfig {
max_memory: 1024 * 1024 * 1024, num_threads: 4,
index_memory: 100 * 1024 * 1024, shard_reservation: 50 * 1024 * 1024, read_profile: profile,
num_buckets: 100,
input_format: InputFormat::Fastx { is_paired: false },
is_log_ratio: false,
};
let batch_config = calculate_batch_config(&config);
assert!(
batch_config.peak_memory <= config.max_memory,
"Peak memory {} exceeds max {}",
batch_config.peak_memory,
config.max_memory
);
}
#[test]
fn test_calculate_batch_config_accounts_for_index() {
let profile = ReadMemoryProfile::new(150, false, 64, 50);
let config_small_index = MemoryConfig {
max_memory: 1024 * 1024 * 1024,
num_threads: 4,
index_memory: 100 * 1024 * 1024,
shard_reservation: 0,
read_profile: profile.clone(),
num_buckets: 100,
input_format: InputFormat::Fastx { is_paired: false },
is_log_ratio: false,
};
let config_large_index = MemoryConfig {
max_memory: 1024 * 1024 * 1024,
num_threads: 4,
index_memory: 500 * 1024 * 1024,
shard_reservation: 0,
read_profile: profile,
num_buckets: 100,
input_format: InputFormat::Fastx { is_paired: false },
is_log_ratio: false,
};
let batch_small = calculate_batch_config(&config_small_index);
let batch_large = calculate_batch_config(&config_large_index);
assert!(
batch_small.batch_size >= batch_large.batch_size,
"Small index batch {} should be >= large index batch {}",
batch_small.batch_size,
batch_large.batch_size
);
}
#[test]
fn test_calculate_batch_config_minimum_values() {
let profile = ReadMemoryProfile::new(150, false, 64, 50);
let config = MemoryConfig {
max_memory: 50 * 1024 * 1024, num_threads: 4,
index_memory: 10 * 1024 * 1024,
shard_reservation: 5 * 1024 * 1024,
read_profile: profile,
num_buckets: 100,
input_format: InputFormat::Fastx { is_paired: false },
is_log_ratio: false,
};
let batch_config = calculate_batch_config(&config);
assert!(batch_config.batch_size >= MIN_BATCH_SIZE);
assert!(batch_config.batch_count >= 1);
}
#[test]
fn test_calculate_batch_config_uses_threads() {
let profile = ReadMemoryProfile::new(150, false, 64, 50);
let config = MemoryConfig {
max_memory: 16 * 1024 * 1024 * 1024, num_threads: 8,
index_memory: 100 * 1024 * 1024,
shard_reservation: 0,
read_profile: profile,
num_buckets: 100,
input_format: InputFormat::Fastx { is_paired: false },
is_log_ratio: false,
};
let batch_config = calculate_batch_config(&config);
assert!(batch_config.batch_count >= 1);
}
#[test]
fn test_read_memory_profile_paired() {
let profile_single = ReadMemoryProfile::new(150, false, 64, 50);
let profile_paired = ReadMemoryProfile::new(150, true, 64, 50);
assert_eq!(profile_single.avg_query_length, 150);
assert_eq!(profile_paired.avg_query_length, 300);
}
#[test]
fn test_read_memory_profile_minimizers() {
let profile = ReadMemoryProfile::new(1000, false, 64, 50);
assert!(profile.minimizers_per_query > 0);
}
#[test]
fn test_format_bytes() {
assert_eq!(format_bytes(500), "500 B");
assert_eq!(format_bytes(1024), "1.00 KB");
assert_eq!(format_bytes(1024 * 1024), "1.00 MB");
assert_eq!(format_bytes(1024 * 1024 * 1024), "1.00 GB");
assert_eq!(format_bytes(1024 * 1024 * 1024 * 1024), "1.00 TB");
}
#[test]
fn test_read_memory_profile_from_files() {
use std::io::Write;
use tempfile::NamedTempFile;
let mut file = NamedTempFile::new().unwrap();
for i in 0..3 {
writeln!(file, "@read{}", i).unwrap();
writeln!(file, "{}", "A".repeat(100)).unwrap();
writeln!(file, "+").unwrap();
writeln!(file, "{}", "I".repeat(100)).unwrap();
}
file.flush().unwrap();
let profile = ReadMemoryProfile::from_files(
file.path(),
None,
10, 64, 50, false, None, );
assert!(profile.is_some());
let profile = profile.unwrap();
assert_eq!(profile.avg_read_length, 100);
assert_eq!(profile.avg_query_length, 100); }
#[test]
fn test_read_memory_profile_from_files_paired() {
use std::io::Write;
use tempfile::NamedTempFile;
let mut r1 = NamedTempFile::new().unwrap();
for i in 0..3 {
writeln!(r1, "@read{}", i).unwrap();
writeln!(r1, "{}", "A".repeat(100)).unwrap();
writeln!(r1, "+").unwrap();
writeln!(r1, "{}", "I".repeat(100)).unwrap();
}
r1.flush().unwrap();
let mut r2 = NamedTempFile::new().unwrap();
for i in 0..3 {
writeln!(r2, "@read{}", i).unwrap();
writeln!(r2, "{}", "T".repeat(150)).unwrap();
writeln!(r2, "+").unwrap();
writeln!(r2, "{}", "I".repeat(150)).unwrap();
}
r2.flush().unwrap();
let profile =
ReadMemoryProfile::from_files(r1.path(), Some(r2.path()), 10, 64, 50, false, None);
assert!(profile.is_some());
let profile = profile.unwrap();
assert_eq!(profile.avg_read_length, 125);
assert_eq!(profile.avg_query_length, 250);
}
#[test]
fn test_read_memory_profile_from_files_nonexistent() {
let profile = ReadMemoryProfile::from_files(
std::path::Path::new("/nonexistent/file.fq"),
None,
10,
64,
50,
false,
None,
);
assert!(profile.is_none());
}
#[test]
fn test_read_memory_profile_from_parquet_single_end() {
use arrow::array::{ArrayRef, StringArray};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use parquet::arrow::ArrowWriter;
use std::sync::Arc;
use tempfile::NamedTempFile;
let file = NamedTempFile::new().unwrap();
let schema = Arc::new(Schema::new(vec![
Field::new("read_id", DataType::Utf8, false),
Field::new("sequence1", DataType::Utf8, false),
]));
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(StringArray::from(vec!["read0", "read1", "read2"])) as ArrayRef,
Arc::new(StringArray::from(vec![
"A".repeat(100),
"A".repeat(100),
"A".repeat(100),
])) as ArrayRef,
],
)
.unwrap();
let writer_file = std::fs::File::create(file.path()).unwrap();
let mut writer = ArrowWriter::try_new(writer_file, schema, None).unwrap();
writer.write(&batch).unwrap();
writer.close().unwrap();
let profile = ReadMemoryProfile::from_files(
file.path(),
None,
10, 64, 50, true, None, );
assert!(profile.is_some());
let profile = profile.unwrap();
assert_eq!(profile.avg_read_length, 100);
assert_eq!(profile.avg_query_length, 100); }
/// Paired-end Parquet input (`sequence1` = 100 bp, `sequence2` = 150 bp).
/// The profile reports the mean mate length (125) as the read length and
/// the combined mate length (250) as the query length.
#[test]
fn test_read_memory_profile_from_parquet_paired_end() {
    use arrow::array::{ArrayRef, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use parquet::arrow::ArrowWriter;
    use std::sync::Arc;
    use tempfile::NamedTempFile;

    let parquet_file = NamedTempFile::new().unwrap();
    let schema = Arc::new(Schema::new(vec![
        Field::new("read_id", DataType::Utf8, false),
        Field::new("sequence1", DataType::Utf8, false),
        // The second mate column is nullable.
        Field::new("sequence2", DataType::Utf8, true),
    ]));

    let ids: ArrayRef = Arc::new(StringArray::from(vec!["read0", "read1", "read2"]));
    let mate1: ArrayRef = Arc::new(StringArray::from(vec!["A".repeat(100); 3]));
    let mate2: ArrayRef = Arc::new(StringArray::from(vec!["T".repeat(150); 3]));
    let batch = RecordBatch::try_new(Arc::clone(&schema), vec![ids, mate1, mate2]).unwrap();

    let sink = std::fs::File::create(parquet_file.path()).unwrap();
    let mut writer = ArrowWriter::try_new(sink, schema, None).unwrap();
    writer.write(&batch).unwrap();
    writer.close().unwrap();

    let profile =
        ReadMemoryProfile::from_files(parquet_file.path(), None, 10, 64, 50, true, None);
    assert!(profile.is_some());
    let profile = profile.unwrap();
    assert_eq!(profile.avg_read_length, 125);
    assert_eq!(profile.avg_query_length, 250);
}
/// Profiling a Parquet path that does not exist must yield no profile.
#[test]
fn test_read_memory_profile_from_parquet_nonexistent() {
    let missing = std::path::Path::new("/nonexistent/file.parquet");
    let profile = ReadMemoryProfile::from_files(missing, None, 10, 64, 50, true, None);
    assert!(profile.is_none());
}
/// Parquet files whose string columns are `LargeUtf8` (rather than `Utf8`)
/// must still be sampled: two 200 bp single-end reads give 200/200.
#[test]
fn test_sample_parquet_lengths_with_large_utf8() {
    use arrow::array::{ArrayRef, LargeStringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use parquet::arrow::ArrowWriter;
    use std::sync::Arc;
    use tempfile::NamedTempFile;

    let parquet_file = NamedTempFile::new().unwrap();
    let schema = Arc::new(Schema::new(vec![
        Field::new("read_id", DataType::LargeUtf8, false),
        Field::new("sequence1", DataType::LargeUtf8, false),
    ]));

    let ids: ArrayRef = Arc::new(LargeStringArray::from(vec!["read0", "read1"]));
    let seqs: ArrayRef = Arc::new(LargeStringArray::from(vec!["A".repeat(200); 2]));
    let batch = RecordBatch::try_new(Arc::clone(&schema), vec![ids, seqs]).unwrap();

    let sink = std::fs::File::create(parquet_file.path()).unwrap();
    let mut writer = ArrowWriter::try_new(sink, schema, None).unwrap();
    writer.write(&batch).unwrap();
    writer.close().unwrap();

    let profile =
        ReadMemoryProfile::from_files(parquet_file.path(), None, 10, 64, 50, true, None);
    assert!(profile.is_some());
    let profile = profile.unwrap();
    assert_eq!(profile.avg_read_length, 200);
    assert_eq!(profile.avg_query_length, 200);
}
/// Single-end trimming: `Some(n)` caps both lengths at `n`. `None`,
/// `Some(0)`, or a limit longer than the reads leaves them unchanged.
#[test]
fn test_apply_trim_to_limit_single_end() {
    // (read_len, query_len, trim_to, expected (read_len, query_len))
    let cases = [
        (1000, 1000, Some(100), (100, 100)),
        (1000, 1000, None, (1000, 1000)),
        (100, 100, Some(1000), (100, 100)),
        (1000, 1000, Some(0), (1000, 1000)),
    ];
    for (read_len, query_len, trim_to, expected) in cases {
        assert_eq!(
            apply_trim_to_limit(read_len, query_len, false, trim_to),
            expected
        );
    }
}
/// Paired-end trimming: `trim_to` caps the per-mate read length, and the
/// query length (both mates) is capped at twice that limit. `None`,
/// `Some(0)`, or a limit longer than the reads leaves the lengths unchanged.
#[test]
fn test_apply_trim_to_limit_paired_end() {
    let (read_len, query_len) = apply_trim_to_limit(1000, 2000, true, Some(100));
    assert_eq!(read_len, 100);
    assert_eq!(query_len, 200);

    let (read_len, query_len) = apply_trim_to_limit(1000, 2000, true, None);
    assert_eq!(read_len, 1000);
    assert_eq!(query_len, 2000);

    // A limit above the actual read length must not inflate the lengths.
    // This mirrors the equivalent single-end case.
    let (read_len, query_len) = apply_trim_to_limit(100, 200, true, Some(1000));
    assert_eq!(read_len, 100);
    assert_eq!(query_len, 200);

    // Zero is treated as "no trimming".
    let (read_len, query_len) = apply_trim_to_limit(1000, 2000, true, Some(0));
    assert_eq!(read_len, 1000);
    assert_eq!(query_len, 2000);
}
/// Single-end FASTQ with 1000 bp reads: without `trim_to` the profile sees
/// 1000 bp; with `Some(100)` both lengths are capped at 100.
#[test]
fn test_read_memory_profile_from_fastx_with_trim_to() {
    use std::io::Write;
    use tempfile::NamedTempFile;

    let mut fastq = NamedTempFile::new().unwrap();
    let bases = "A".repeat(1000);
    let quals = "I".repeat(1000);
    for idx in 0..3 {
        // One four-line FASTQ record: header, sequence, separator, qualities.
        writeln!(fastq, "@read{}\n{}\n+\n{}", idx, bases, quals).unwrap();
    }
    fastq.flush().unwrap();

    let untrimmed =
        ReadMemoryProfile::from_files(fastq.path(), None, 10, 64, 50, false, None).unwrap();
    assert_eq!(untrimmed.avg_read_length, 1000);
    assert_eq!(untrimmed.avg_query_length, 1000);

    let trimmed =
        ReadMemoryProfile::from_files(fastq.path(), None, 10, 64, 50, false, Some(100)).unwrap();
    assert_eq!(trimmed.avg_read_length, 100);
    assert_eq!(trimmed.avg_query_length, 100);
}
/// Single-end Parquet with 1000 bp reads: `trim_to = Some(100)` caps both
/// profile lengths at 100, while `None` keeps the full 1000 bp.
#[test]
fn test_read_memory_profile_from_parquet_with_trim_to() {
    use arrow::array::{ArrayRef, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use parquet::arrow::ArrowWriter;
    use std::sync::Arc;
    use tempfile::NamedTempFile;

    let parquet_file = NamedTempFile::new().unwrap();
    let schema = Arc::new(Schema::new(vec![
        Field::new("read_id", DataType::Utf8, false),
        Field::new("sequence1", DataType::Utf8, false),
    ]));

    let ids: ArrayRef = Arc::new(StringArray::from(vec!["read0", "read1", "read2"]));
    let seqs: ArrayRef = Arc::new(StringArray::from(vec!["A".repeat(1000); 3]));
    let batch = RecordBatch::try_new(Arc::clone(&schema), vec![ids, seqs]).unwrap();

    let sink = std::fs::File::create(parquet_file.path()).unwrap();
    let mut writer = ArrowWriter::try_new(sink, schema, None).unwrap();
    writer.write(&batch).unwrap();
    writer.close().unwrap();

    let path = parquet_file.path();
    let untrimmed = ReadMemoryProfile::from_files(path, None, 10, 64, 50, true, None).unwrap();
    assert_eq!(untrimmed.avg_read_length, 1000);
    assert_eq!(untrimmed.avg_query_length, 1000);

    let trimmed = ReadMemoryProfile::from_files(path, None, 10, 64, 50, true, Some(100)).unwrap();
    assert_eq!(trimmed.avg_read_length, 100);
    assert_eq!(trimmed.avg_query_length, 100);
}
/// Paired FASTQ (R1 and R2 files, 1000 bp each). Untrimmed, the profile
/// reports a 1000 bp read and a 2000 bp query. With `trim_to = Some(100)`
/// it reports a 100 bp read and a 200 bp query.
#[test]
fn test_read_memory_profile_paired_fastx_with_trim_to() {
    use std::io::Write;
    use tempfile::NamedTempFile;

    // Writes three 1000 bp FASTQ records whose bases are all `base`.
    let write_fastq = |base: &str| {
        let mut out = NamedTempFile::new().unwrap();
        let seq = base.repeat(1000);
        let qual = "I".repeat(1000);
        for i in 0..3 {
            writeln!(out, "@read{}\n{}\n+\n{}", i, seq, qual).unwrap();
        }
        out.flush().unwrap();
        out
    };
    let r1 = write_fastq("A");
    let r2 = write_fastq("T");

    let untrimmed =
        ReadMemoryProfile::from_files(r1.path(), Some(r2.path()), 10, 64, 50, false, None)
            .unwrap();
    assert_eq!(untrimmed.avg_read_length, 1000);
    assert_eq!(untrimmed.avg_query_length, 2000);

    let trimmed =
        ReadMemoryProfile::from_files(r1.path(), Some(r2.path()), 10, 64, 50, false, Some(100))
            .unwrap();
    assert_eq!(trimmed.avg_read_length, 100);
    assert_eq!(trimmed.avg_query_length, 200);
}
/// Paired-end Parquet (`sequence1` and `sequence2`, 1000 bp each).
/// Untrimmed gives 1000/2000; `trim_to = Some(100)` gives 100/200.
#[test]
fn test_read_memory_profile_paired_parquet_with_trim_to() {
    use arrow::array::{ArrayRef, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use parquet::arrow::ArrowWriter;
    use std::sync::Arc;
    use tempfile::NamedTempFile;

    let parquet_file = NamedTempFile::new().unwrap();
    let schema = Arc::new(Schema::new(vec![
        Field::new("read_id", DataType::Utf8, false),
        Field::new("sequence1", DataType::Utf8, false),
        Field::new("sequence2", DataType::Utf8, true),
    ]));

    // Three 1000 bp sequences made of `base`.
    let column = |base: &str| -> ArrayRef { Arc::new(StringArray::from(vec![base.repeat(1000); 3])) };
    let ids: ArrayRef = Arc::new(StringArray::from(vec!["read0", "read1", "read2"]));
    let batch =
        RecordBatch::try_new(Arc::clone(&schema), vec![ids, column("A"), column("T")]).unwrap();

    let sink = std::fs::File::create(parquet_file.path()).unwrap();
    let mut writer = ArrowWriter::try_new(sink, schema, None).unwrap();
    writer.write(&batch).unwrap();
    writer.close().unwrap();

    let path = parquet_file.path();
    let untrimmed = ReadMemoryProfile::from_files(path, None, 10, 64, 50, true, None).unwrap();
    assert_eq!(untrimmed.avg_read_length, 1000);
    assert_eq!(untrimmed.avg_query_length, 2000);

    let trimmed = ReadMemoryProfile::from_files(path, None, 10, 64, 50, true, Some(100)).unwrap();
    assert_eq!(trimmed.avg_read_length, 100);
    assert_eq!(trimmed.avg_query_length, 200);
}
/// The peak memory reported by `calculate_batch_config` must be at least
/// the memory held by the prefetch buffer slots at the chosen batch size.
#[test]
fn test_batch_config_accounts_for_prefetch_buffer() {
    let config = MemoryConfig {
        max_memory: 1024 * 1024 * 1024, // 1 GiB
        num_threads: 4,
        index_memory: 100 * 1024 * 1024,
        shard_reservation: 0,
        read_profile: ReadMemoryProfile::new(150, false, 64, 50),
        num_buckets: 100,
        input_format: InputFormat::Parquet {
            is_paired: false,
            trimmed_in_reader: false,
        },
        is_log_ratio: false,
    };
    let batch_config = calculate_batch_config(&config);

    let slots = config.prefetch_buffer_slots();
    let bytes_per_row = config.buffer_bytes_per_row();
    let prefetch_overhead = slots * batch_config.batch_size * bytes_per_row;

    assert!(
        batch_config.peak_memory >= prefetch_overhead,
        "Peak memory {} should include prefetch overhead {}",
        batch_config.peak_memory,
        prefetch_overhead
    );
}
/// FASTX input uses 2 prefetch slots and Parquet uses 4. With an otherwise
/// identical config, FASTX must therefore get a batch size at least as large
/// as Parquet's.
#[test]
fn test_fastx_vs_parquet_uses_different_prefetch_slots() {
    let profile = ReadMemoryProfile::new(150, false, 64, 50);
    // Builds the shared 1 GiB config, varying only the input format.
    let config_for = |input_format| MemoryConfig {
        max_memory: 1024 * 1024 * 1024,
        num_threads: 4,
        index_memory: 100 * 1024 * 1024,
        shard_reservation: 0,
        read_profile: profile.clone(),
        num_buckets: 100,
        input_format,
        is_log_ratio: false,
    };
    let config_fastx = config_for(InputFormat::Fastx { is_paired: false });
    let config_parquet = config_for(InputFormat::Parquet {
        is_paired: false,
        trimmed_in_reader: false,
    });

    let fastx_slots = config_fastx.prefetch_buffer_slots();
    let parquet_slots = config_parquet.prefetch_buffer_slots();
    assert_eq!(fastx_slots, FASTX_PREFETCH_BUFFER_SLOTS);
    assert_eq!(parquet_slots, PARQUET_PREFETCH_BUFFER_SLOTS);
    // Pin the concrete values as well as the constants.
    assert_eq!(fastx_slots, 2);
    assert_eq!(parquet_slots, 4);

    let fastx_batch = calculate_batch_config(&config_fastx).batch_size;
    let parquet_batch = calculate_batch_config(&config_parquet).batch_size;
    assert!(
        fastx_batch >= parquet_batch,
        "FASTX batch {} should be >= Parquet batch {} (fewer prefetch slots)",
        fastx_batch,
        parquet_batch
    );
}
/// Sanity bounds on the per-record buffer estimates for 150 bp single-end
/// reads (strictly between 100 and 500 bytes). Also checks that a paired
/// owned record is estimated larger than a single one.
#[test]
fn test_owned_fastx_record_vs_arrow_bytes_estimation() {
    let profile = ReadMemoryProfile::new(150, false, 64, 50);
    let owned_single = profile.estimate_owned_record_bytes(false);
    let arrow_single = profile.estimate_arrow_bytes_per_row(false);

    // Exclusive bounds: 100 < bytes < 500.
    assert!(
        (101..500).contains(&owned_single),
        "OwnedFastxRecord bytes {} should be reasonable",
        owned_single
    );
    assert!(
        (101..500).contains(&arrow_single),
        "Arrow bytes {} should be reasonable",
        arrow_single
    );

    let owned_paired = profile.estimate_owned_record_bytes(true);
    assert!(
        owned_paired > owned_single,
        "Paired OwnedFastxRecord {} should be > single {}",
        owned_paired,
        owned_single
    );
}
/// Longer reads must increase both per-record buffer estimates: the owned
/// (FASTX-style) record estimate and the Arrow row estimate.
///
/// Failure messages include both values so a regression shows how the
/// estimates compare.
#[test]
fn test_read_length_affects_buffer_bytes() {
    let profile_short = ReadMemoryProfile::new(150, false, 64, 50);
    let profile_long = ReadMemoryProfile::new(10000, false, 64, 50);

    let short_owned = profile_short.estimate_owned_record_bytes(false);
    let long_owned = profile_long.estimate_owned_record_bytes(false);
    assert!(
        long_owned > short_owned,
        "Owned record bytes for 10000bp reads ({}) should exceed those for 150bp reads ({})",
        long_owned,
        short_owned
    );

    let short_arrow = profile_short.estimate_arrow_bytes_per_row(false);
    let long_arrow = profile_long.estimate_arrow_bytes_per_row(false);
    assert!(
        long_arrow > short_arrow,
        "Arrow row bytes for 10000bp reads ({}) should exceed those for 150bp reads ({})",
        long_arrow,
        short_arrow
    );
}
/// Long (5000 bp) paired-end Parquet reads on an 8 GiB budget: the computed
/// peak memory must not exceed `max_memory`.
#[test]
fn test_total_memory_with_io_buffers_within_budget() {
    let config = MemoryConfig {
        max_memory: 8 * 1024 * 1024 * 1024, // 8 GiB
        num_threads: 8,
        index_memory: 500 * 1024 * 1024,
        shard_reservation: 100 * 1024 * 1024,
        read_profile: ReadMemoryProfile::new(5000, true, 64, 50),
        num_buckets: 1000,
        input_format: InputFormat::Parquet {
            is_paired: true,
            trimmed_in_reader: false,
        },
        is_log_ratio: false,
    };

    let peak = calculate_batch_config(&config).peak_memory;
    assert!(
        peak <= config.max_memory,
        "Peak {} exceeds max {}",
        peak,
        config.max_memory
    );
}
/// The Arrow per-row estimate has a fixed overhead above 40 bytes and grows
/// with sequence length. Going from 150 bp to 10000 bp must add roughly
/// the 9850-byte length difference (within ±100 bytes).
#[test]
fn test_estimate_arrow_bytes_per_row() {
    let bytes_short =
        ReadMemoryProfile::new(150, false, 64, 50).estimate_arrow_bytes_per_row(false);
    let bytes_long =
        ReadMemoryProfile::new(10000, false, 64, 50).estimate_arrow_bytes_per_row(false);

    assert!(
        bytes_short > 40,
        "Short read arrow bytes {} should be > 40 (fixed overhead)",
        bytes_short
    );
    // Checked first so the subtraction below cannot underflow.
    assert!(
        bytes_long > bytes_short,
        "Long read arrow bytes {} should be > short read bytes {}",
        bytes_long,
        bytes_short
    );

    let expected_diff = 10000 - 150;
    let actual_diff = bytes_long - bytes_short;
    assert!(
        actual_diff.abs_diff(expected_diff) <= 100,
        "Arrow bytes difference {} should be close to sequence length difference {}",
        actual_diff,
        expected_diff
    );
}
/// The batch chosen by `calculate_batch_config` must fit in the memory left
/// after reserving index, shard and safety-margin memory.
///
/// NOTE(review): the safety margin here (10% of `max_memory`, at least
/// 256 MiB) is assumed to mirror the production calculation. Keep them in
/// sync.
#[test]
fn test_binary_search_validates_result() {
    let config = MemoryConfig {
        max_memory: 500 * 1024 * 1024, // 500 MiB
        num_threads: 4,
        index_memory: 100 * 1024 * 1024,
        shard_reservation: 50 * 1024 * 1024,
        read_profile: ReadMemoryProfile::new(150, false, 64, 50),
        num_buckets: 100,
        input_format: InputFormat::Fastx { is_paired: false },
        is_log_ratio: false,
    };
    let chosen = calculate_batch_config(&config);
    let total = estimate_total_batch_memory(chosen.batch_size, chosen.batch_count, &config);

    let ten_percent = (config.max_memory as f64 * 0.10).round() as usize;
    let safety_margin = ten_percent.max(256 * 1024 * 1024);
    let reserved = config.index_memory + config.shard_reservation + safety_margin;
    let available = config.max_memory.saturating_sub(reserved);

    assert!(
        total.is_some_and(|t| t <= available),
        "Binary search result should fit in available memory"
    );
}
/// `MemoryConfig::new` accepts a sane config and rejects a zero
/// `max_memory`, zero threads, or zero buckets.
#[test]
fn test_memory_config_validation() {
    let profile = ReadMemoryProfile::new(150, false, 64, 50);
    // Builds a config varying only the three validated fields.
    let build = |max_memory, num_threads, num_buckets| {
        MemoryConfig::new(
            max_memory,
            num_threads,
            100 * 1024 * 1024,
            0,
            profile.clone(),
            num_buckets,
            InputFormat::Fastx { is_paired: false },
            false,
        )
    };

    assert!(build(1024 * 1024 * 1024, 4, 100).is_ok());
    assert!(build(0, 4, 100).is_err(), "zero max_memory");
    assert!(build(1024 * 1024 * 1024, 0, 100).is_err(), "zero threads");
    assert!(build(1024 * 1024 * 1024, 4, 0).is_err(), "zero buckets");
}
/// An empty shard (zero entries) needs no memory reservation.
#[test]
fn test_estimate_shard_reservation_zero_entries() {
    let reservation = estimate_shard_reservation(0, 8);
    assert_eq!(reservation, 0);
}
/// A realistic shard (~62M entries, 8 threads) should need about 170 MiB of
/// reservation, within ±15 MiB.
#[test]
fn test_estimate_shard_reservation_realistic() {
    let expected_mb = 170;
    let tolerance_mb = 15;
    let actual_mb = estimate_shard_reservation(62_443_845, 8) / (1024 * 1024);
    assert!(
        actual_mb.abs_diff(expected_mb) <= tolerance_mb,
        "Shard reservation should be ~{}MB for 62M entries, got {}MB",
        expected_mb,
        actual_mb
    );
}
/// Extreme inputs (`u64::MAX` entries, 64 threads) must not panic and must
/// still produce a positive reservation.
#[test]
fn test_estimate_shard_reservation_overflow_safety() {
    let huge = estimate_shard_reservation(u64::MAX, 64);
    assert!(huge > 0, "Should return a positive value for large inputs");
}
/// At a fixed thread count, a 100M-entry shard must need more reservation
/// than a 1M-entry shard.
#[test]
fn test_estimate_shard_reservation_scales_with_entries() {
    let threads = 8;
    let (small, large) = (
        estimate_shard_reservation(1_000_000, threads),
        estimate_shard_reservation(100_000_000, threads),
    );
    assert!(
        large > small,
        "Larger shards should require more reservation: small={}, large={}",
        small,
        large
    );
}
/// At a fixed shard size, 16 threads must reserve more than 2 threads.
#[test]
fn test_estimate_shard_reservation_scales_with_threads() {
    let entries = 62_000_000;
    let (few, many) = (
        estimate_shard_reservation(entries, 2),
        estimate_shard_reservation(entries, 16),
    );
    assert!(
        many > few,
        "More threads should increase decode buffer reservation: 2t={}, 16t={}",
        few,
        many
    );
}
/// Batches are processed sequentially, so even with a large 64 GiB budget
/// the configuration must report a `batch_count` of exactly 1.
#[test]
fn test_batch_count_is_one_for_sequential_processing() {
    let config = MemoryConfig {
        max_memory: 64 * 1024 * 1024 * 1024, // 64 GiB
        num_threads: 8,
        index_memory: 0,
        shard_reservation: 0,
        read_profile: ReadMemoryProfile::new(5000, false, 64, 200),
        num_buckets: 160,
        input_format: InputFormat::Fastx { is_paired: false },
        is_log_ratio: false,
    };

    let batch_count = calculate_batch_config(&config).batch_count;
    assert_eq!(
        batch_count, 1,
        "batch_count should be 1 (sequential processing), got {}",
        batch_count
    );
}
/// Because batches are sequential, `batch_size` must be the same for 1, 8
/// and 16 threads under an otherwise identical config.
#[test]
fn test_batch_size_independent_of_thread_count() {
    let base = MemoryConfig {
        max_memory: 64 * 1024 * 1024 * 1024,
        num_threads: 1,
        index_memory: 0,
        shard_reservation: 0,
        read_profile: ReadMemoryProfile::new(5000, false, 64, 200),
        num_buckets: 160,
        input_format: InputFormat::Fastx { is_paired: false },
        is_log_ratio: false,
    };
    // Same config as `base`, with only the thread count overridden.
    let batch_size_with_threads = |num_threads| {
        calculate_batch_config(&MemoryConfig {
            num_threads,
            ..base.clone()
        })
        .batch_size
    };

    let single = calculate_batch_config(&base).batch_size;
    let eight = batch_size_with_threads(8);
    let sixteen = batch_size_with_threads(16);

    assert_eq!(
        single, eight,
        "1-thread batch_size ({}) != 8-thread batch_size ({}); \
         thread count should not affect batch_size when batches are sequential",
        single, eight
    );
    assert_eq!(
        single, sixteen,
        "1-thread batch_size ({}) != 16-thread batch_size ({})",
        single, sixteen
    );
}
/// Reserving 158 MiB for shard loading must leave less room for reads,
/// giving a strictly smaller `batch_size` than with no reservation.
#[test]
fn test_shard_reservation_reduces_batch_size() {
    let baseline = MemoryConfig {
        max_memory: 8 * 1024 * 1024 * 1024, // 8 GiB
        num_threads: 8,
        index_memory: 0,
        shard_reservation: 0,
        read_profile: ReadMemoryProfile::new(5000, false, 64, 200),
        num_buckets: 160,
        input_format: InputFormat::Fastx { is_paired: false },
        is_log_ratio: false,
    };
    let reserved_bytes = 158 * 1024 * 1024;
    let reserved = MemoryConfig {
        shard_reservation: reserved_bytes,
        ..baseline.clone()
    };

    let without_reservation = calculate_batch_config(&baseline).batch_size;
    let with_reservation = calculate_batch_config(&reserved).batch_size;

    assert!(
        with_reservation < without_reservation,
        "Shard reservation of {}MB should reduce batch_size: \
         without={}, with={}",
        reserved_bytes / (1024 * 1024),
        without_reservation,
        with_reservation
    );
}
/// When Parquet reads are trimmed in the reader, the per-row buffer estimate
/// must match FASTX, i.e. the owned-record estimate rather than the Arrow
/// estimate.
#[test]
fn test_parquet_trimmed_uses_owned_format_for_io_buffer_estimate() {
    let profile = ReadMemoryProfile::new(150, false, 64, 50);

    let trimmed_bytes = InputFormat::Parquet {
        is_paired: false,
        trimmed_in_reader: true,
    }
    .estimate_buffer_bytes_per_row(&profile);
    let fastx_bytes =
        InputFormat::Fastx { is_paired: false }.estimate_buffer_bytes_per_row(&profile);

    assert_eq!(
        trimmed_bytes, fastx_bytes,
        "Parquet trimmed_in_reader should use owned record estimate ({}), not Arrow estimate (got {})",
        fastx_bytes, trimmed_bytes
    );
    assert_eq!(trimmed_bytes, profile.estimate_owned_record_bytes(false));
}
/// Untrimmed Parquet input must size its buffers with the Arrow per-row
/// estimate.
#[test]
fn test_parquet_untrimmed_uses_arrow_estimate() {
    let profile = ReadMemoryProfile::new(150, false, 64, 50);
    let untrimmed = InputFormat::Parquet {
        is_paired: false,
        trimmed_in_reader: false,
    };

    let expected = profile.estimate_arrow_bytes_per_row(false);
    let actual = untrimmed.estimate_buffer_bytes_per_row(&profile);
    assert_eq!(
        actual, expected,
        "Parquet untrimmed should use Arrow estimate ({}), got {}",
        expected, actual
    );
}
/// Trimming in the reader does not change the Parquet prefetch slot count.
/// Both variants use `PARQUET_PREFETCH_BUFFER_SLOTS`, which is 4.
#[test]
fn test_parquet_trimmed_prefetch_slots_still_four() {
    let format_trimmed = InputFormat::Parquet {
        is_paired: false,
        trimmed_in_reader: true,
    };
    let format_untrimmed = InputFormat::Parquet {
        is_paired: false,
        trimmed_in_reader: false,
    };

    for format in [&format_trimmed, &format_untrimmed] {
        assert_eq!(format.prefetch_slots(), PARQUET_PREFETCH_BUFFER_SLOTS);
    }
    assert_eq!(format_trimmed.prefetch_slots(), 4);
}
/// For trimmed Parquet input the IO buffer estimate is exactly
/// `slots * batch_size * buffer_bytes_per_row`. Untrimmed Parquet must
/// estimate more at the same batch size.
#[test]
fn test_parquet_trimmed_no_arrow_builder_overhead() {
    let profile = ReadMemoryProfile::new(150, false, 64, 50);
    let batch_size = 10_000;
    // Identical 1 GiB Parquet configs, differing only in `trimmed_in_reader`.
    let config_for = |trimmed_in_reader| MemoryConfig {
        max_memory: 1024 * 1024 * 1024,
        num_threads: 4,
        index_memory: 0,
        shard_reservation: 0,
        read_profile: profile.clone(),
        num_buckets: 100,
        input_format: InputFormat::Parquet {
            is_paired: false,
            trimmed_in_reader,
        },
        is_log_ratio: false,
    };
    let config_untrimmed = config_for(false);
    let config_trimmed = config_for(true);

    let io_untrimmed = estimate_io_buffer_memory(batch_size, &config_untrimmed).unwrap();
    let io_trimmed = estimate_io_buffer_memory(batch_size, &config_trimmed).unwrap();
    assert!(
        io_untrimmed > io_trimmed,
        "Untrimmed IO buffer ({}) should be > trimmed IO buffer ({}) \
         due to Arrow builder overhead",
        io_untrimmed,
        io_trimmed
    );

    let slots = config_trimmed.prefetch_buffer_slots();
    let expected_trimmed = slots * batch_size * config_trimmed.buffer_bytes_per_row();
    assert_eq!(
        io_trimmed, expected_trimmed,
        "Trimmed IO buffer should equal slots * batch * owned_bytes (no 1.5x overhead)"
    );
}
/// Trimmed Parquet picks the owned-record estimate that matches its own
/// `is_paired` flag. The paired estimate must exceed the single-end one.
#[test]
fn test_parquet_trimmed_paired_end_uses_paired_owned_estimate() {
    let profile = ReadMemoryProfile::new(150, true, 64, 50);
    let trimmed_bytes = |is_paired| {
        InputFormat::Parquet {
            is_paired,
            trimmed_in_reader: true,
        }
        .estimate_buffer_bytes_per_row(&profile)
    };

    let paired_bytes = trimmed_bytes(true);
    let single_bytes = trimmed_bytes(false);
    assert!(
        paired_bytes > single_bytes,
        "Paired trimmed Parquet ({}) should use more memory than single ({})",
        paired_bytes,
        single_bytes
    );

    let expected_paired = profile.estimate_owned_record_bytes(true);
    let expected_single = profile.estimate_owned_record_bytes(false);
    assert_eq!(
        paired_bytes, expected_paired,
        "Paired trimmed Parquet should use paired owned estimate ({}), got {}",
        expected_paired, paired_bytes
    );
    assert_eq!(
        single_bytes, expected_single,
        "Single trimmed Parquet should use single owned estimate ({}), got {}",
        expected_single, single_bytes
    );
}
}