use std::path::Path;
use anyhow::Result;
use rype::memory::{
calculate_batch_config, detect_available_memory, estimate_shard_reservation, format_bytes,
InputFormat, MemoryConfig, MemorySource, ReadMemoryProfile,
};
use super::load_index_metadata;
/// Inputs for [`compute_effective_batch_size`].
///
/// Every field is either a plain value or a borrowed path, so the struct is
/// cheap to copy; this lets callers (and tests) derive variants with
/// struct-update syntax.
#[derive(Debug, Clone, Copy)]
pub struct BatchSizeConfig<'a> {
    /// Batch size chosen by the user. When `Some`, it is used verbatim and no
    /// memory estimation is performed.
    pub batch_size_override: Option<usize>,
    /// Memory budget in bytes. `0` means "auto-detect available memory".
    pub max_memory: usize,
    /// First (or only) read file.
    pub r1_path: &'a Path,
    /// Mate file for paired-end input, if any.
    pub r2_path: Option<&'a Path>,
    /// `true` when the reads are stored in Parquet rather than FASTX.
    pub is_parquet_input: bool,
    /// Directory of the (numerator / primary) index.
    pub index_path: &'a Path,
    /// Value of `--trim-to`. `Some(0)` is treated as "no trimming".
    pub trim_to: Option<usize>,
    /// Minimum read length requested by the user (presumably applied by the
    /// reader — confirm against the reader implementation).
    pub minimum_length: Option<usize>,
    /// Forwarded to the memory model as `MemoryConfig::is_log_ratio`.
    pub is_log_ratio: bool,
    /// Optional second index whose largest shard is also considered when
    /// reserving shard-loading memory.
    pub denominator_index_path: Option<&'a Path>,
}
pub struct BatchSizeResult {
pub batch_size: usize,
pub peak_memory: usize,
pub input_format: InputFormat,
pub shard_reservation: usize,
}
/// Builds the [`InputFormat`] descriptor used by the memory model.
///
/// FASTX input only records whether it is paired. For Parquet input,
/// `trimmed_in_reader` is set when the reader is asked to trim reads
/// (`trim_to`) or to apply a minimum length (`minimum_length`).
///
/// A `trim_to` of `Some(0)` does NOT count as trimming. This matches
/// `compute_effective_batch_size`, which warns that `--trim-to 0` is treated
/// as no trimming and builds the read profile untrimmed.
fn determine_input_format(config: &BatchSizeConfig, is_paired: bool) -> InputFormat {
    if !config.is_parquet_input {
        return InputFormat::Fastx { is_paired };
    }
    // `Some(0)` means "no trimming"; flagging it as trimmed would make the
    // format disagree with the untrimmed read profile used for the estimate.
    let trims_reads = matches!(config.trim_to, Some(n) if n > 0);
    let trimmed_in_reader = trims_reads || config.minimum_length.is_some();
    InputFormat::Parquet {
        is_paired,
        trimmed_in_reader,
    }
}
/// Decides how many reads to process per batch.
///
/// If `config.batch_size_override` is set, it is used verbatim. In that case
/// no index or read files are opened, and `peak_memory` / `shard_reservation`
/// are reported as `0` (not estimated).
///
/// Otherwise the batch size comes from `calculate_batch_config`, fed with:
/// * the memory budget (`config.max_memory`, or auto-detected when `0`);
/// * index metadata: total bucket minimizers, bucket count, and the largest
///   shard (the larger of the primary and denominator index when
///   `denominator_index_path` is set);
/// * a [`ReadMemoryProfile`] sampled from the read files, or a default
///   profile if sampling fails;
/// * the current rayon thread count.
///
/// # Errors
///
/// Returns an error if the metadata of the index or of the denominator index
/// cannot be loaded.
pub fn compute_effective_batch_size(config: &BatchSizeConfig) -> Result<BatchSizeResult> {
    let is_paired_hint = config.r2_path.is_some();

    if let Some(bs) = config.batch_size_override {
        log::info!("Using user-specified batch size: {}", bs);
        let input_format = determine_input_format(config, is_paired_hint);
        return Ok(BatchSizeResult {
            batch_size: bs,
            // Nothing is estimated when the user picks the batch size.
            peak_memory: 0,
            input_format,
            shard_reservation: 0,
        });
    }

    let metadata = load_index_metadata(config.index_path)?;
    // In log-ratio mode both indexes' shards may be loaded, so reserve for the
    // larger of the two.
    let largest_shard_entries = match config.denominator_index_path {
        Some(denom_path) => {
            let denom_meta = load_index_metadata(denom_path)?;
            metadata
                .largest_shard_entries
                .max(denom_meta.largest_shard_entries)
        }
        None => metadata.largest_shard_entries,
    };

    let mem_limit = resolve_memory_limit(config.max_memory);
    let effective_trim_to = normalize_trim_to(config.trim_to);

    let read_profile = ReadMemoryProfile::from_files(
        config.r1_path,
        config.r2_path,
        1000, // sample size for from_files (assumed reads sampled; see its docs)
        metadata.k,
        metadata.w,
        config.is_parquet_input,
        effective_trim_to,
    )
    .unwrap_or_else(|| {
        log::warn!("Could not sample read lengths, using default profile");
        ReadMemoryProfile::default_profile(is_paired_hint, metadata.k, metadata.w)
    });
    // The sampled profile, not the presence of r2, decides pairedness from here on.
    let is_paired = read_profile.is_paired;
    log::debug!(
        "Read profile: avg_read_length={}, avg_query_length={}, minimizers_per_query={}, is_paired={}",
        read_profile.avg_read_length,
        read_profile.avg_query_length,
        read_profile.minimizers_per_query,
        is_paired
    );

    // 8 bytes per stored minimizer (minimizers are 64-bit values).
    let estimated_index_mem = metadata.bucket_minimizer_counts.values().sum::<usize>() * 8;
    let num_buckets = metadata.bucket_names.len();
    let input_format = determine_input_format(config, is_paired);

    // Read once so the shard reservation and the memory model agree on it.
    let num_threads = rayon::current_num_threads();
    let shard_reservation = estimate_shard_reservation(largest_shard_entries, num_threads);

    let mem_config = MemoryConfig {
        max_memory: mem_limit,
        num_threads,
        index_memory: estimated_index_mem,
        shard_reservation,
        read_profile,
        num_buckets,
        input_format,
        is_log_ratio: config.is_log_ratio,
    };
    let batch_config = calculate_batch_config(&mem_config);

    Ok(BatchSizeResult {
        batch_size: batch_config.batch_size,
        peak_memory: batch_config.peak_memory,
        input_format,
        shard_reservation,
    })
}

/// Returns the memory budget in bytes.
///
/// A non-zero `max_memory` is returned unchanged. `0` triggers auto-detection;
/// the detected value and its source are logged, with a warning when
/// detection fell back to the default.
fn resolve_memory_limit(max_memory: usize) -> usize {
    if max_memory != 0 {
        return max_memory;
    }
    let detected = detect_available_memory();
    if detected.source == MemorySource::Fallback {
        log::warn!(
            "Could not detect available memory, using 8GB fallback. \
             Consider specifying --max-memory explicitly."
        );
    } else {
        log::info!(
            "Auto-detected available memory: {} (source: {:?})",
            format_bytes(detected.bytes),
            detected.source
        );
    }
    detected.bytes
}

/// Normalises `--trim-to`.
///
/// `Some(0)` becomes `None` (with a warning); any other value is logged and
/// passed through.
fn normalize_trim_to(trim_to: Option<usize>) -> Option<usize> {
    match trim_to {
        Some(0) => {
            log::warn!("--trim-to 0 specified, treating as no trimming");
            None
        }
        Some(n) => {
            log::info!("Read trimming enabled: --trim-to {}", n);
            Some(n)
        }
        None => None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::io::Write;
    use std::path::Path;
    use std::sync::Arc;

    use arrow::array::{
        ArrayRef, LargeListBuilder, LargeStringArray, LargeStringBuilder, UInt32Array,
        UInt64Array,
    };
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use parquet::arrow::ArrowWriter;
    use tempfile::{NamedTempFile, TempDir};

    /// One gibibyte, for readable memory budgets.
    const GIB: usize = 1024 * 1024 * 1024;

    /// Name of the index directory created inside each fixture's temp dir.
    const INDEX_DIR_NAME: &str = "test.ryxdi";

    /// Creates a temp dir holding `test.ryxdi/manifest.toml` with `manifest`
    /// as its contents. The returned guard must outlive any use of the index.
    fn init_index_dir(manifest: &str) -> TempDir {
        let dir = TempDir::new().unwrap();
        let index_path = dir.path().join(INDEX_DIR_NAME);
        fs::create_dir(&index_path).unwrap();
        fs::write(index_path.join("manifest.toml"), manifest).unwrap();
        dir
    }

    /// Writes `buckets.parquet` into `index_path`. It gets bucket ids `0..n`,
    /// the given names, and one `source{i}` entry in each bucket's `sources`.
    fn write_buckets_parquet(index_path: &Path, bucket_names: &[&str]) {
        let schema = Arc::new(Schema::new(vec![
            Field::new("bucket_id", DataType::UInt32, false),
            Field::new("bucket_name", DataType::LargeUtf8, false),
            Field::new(
                "sources",
                DataType::LargeList(Arc::new(Field::new("item", DataType::LargeUtf8, true))),
                false,
            ),
        ]));

        let mut sources = LargeListBuilder::new(LargeStringBuilder::new());
        for i in 0..bucket_names.len() {
            sources.values().append_value(format!("source{i}"));
            sources.append(true);
        }
        let sources_array: ArrayRef = Arc::new(sources.finish());
        let bucket_ids: Vec<u32> = (0u32..).take(bucket_names.len()).collect();

        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(UInt32Array::from(bucket_ids)) as ArrayRef,
                Arc::new(LargeStringArray::from(bucket_names.to_vec())) as ArrayRef,
                sources_array,
            ],
        )
        .unwrap();

        let file = fs::File::create(index_path.join("buckets.parquet")).unwrap();
        let mut writer = ArrowWriter::try_new(file, schema, None).unwrap();
        writer.write(&batch).unwrap();
        writer.close().unwrap();
    }

    /// Creates `inverted/` under `index_path` and writes `shard.0.parquet`
    /// with one row per `(minimizer, bucket_id)` pair.
    fn write_single_shard(index_path: &Path, minimizers: Vec<u64>, bucket_ids: Vec<u32>) {
        let inverted_path = index_path.join("inverted");
        fs::create_dir(&inverted_path).unwrap();

        let shard_schema = Arc::new(Schema::new(vec![
            Field::new("minimizer", DataType::UInt64, false),
            Field::new("bucket_id", DataType::UInt32, false),
        ]));
        let shard_batch = RecordBatch::try_new(
            shard_schema.clone(),
            vec![
                Arc::new(UInt64Array::from(minimizers)) as ArrayRef,
                Arc::new(UInt32Array::from(bucket_ids)) as ArrayRef,
            ],
        )
        .unwrap();

        let shard_file = fs::File::create(inverted_path.join("shard.0.parquet")).unwrap();
        let mut shard_writer = ArrowWriter::try_new(shard_file, shard_schema, None).unwrap();
        shard_writer.write(&shard_batch).unwrap();
        shard_writer.close().unwrap();
    }

    /// Two-bucket index (`bucket0`, `bucket1`) with one shard of three entries.
    /// Minimizers 1 and 2 belong to bucket 0; minimizer 3 belongs to bucket 1.
    fn create_test_index() -> TempDir {
        let manifest = r#"magic = "RYPE_PARQUET_V1"
format_version = 1
k = 64
w = 50
salt = "0x5555555555555555"
source_hash = "0xDEADBEEF"
num_buckets = 2
total_minimizers = 1000
[inverted]
num_shards = 1
total_entries = 3
has_overlapping_shards = false
[[inverted.shards]]
shard_id = 0
min_minimizer = "0x0000000000000001"
max_minimizer = "0x0000000000000003"
num_entries = 3
"#;
        let dir = init_index_dir(manifest);
        let index_path = dir.path().join(INDEX_DIR_NAME);
        write_buckets_parquet(&index_path, &["bucket0", "bucket1"]);
        write_single_shard(&index_path, vec![1, 2, 3], vec![0, 0, 1]);
        dir
    }

    /// Builds a single-bucket index whose manifest advertises `num_entries`
    /// shard entries. The shard file itself holds only one row.
    ///
    /// Only the manifest figure matters here: `load_index_metadata` takes
    /// `largest_shard_entries` from the manifest, as
    /// `test_shard_info_flows_to_batch_sizing` checks.
    fn create_test_index_with_entries(num_entries: usize) -> TempDir {
        let manifest = format!(
            r#"magic = "RYPE_PARQUET_V1"
format_version = 1
k = 64
w = 50
salt = "0x5555555555555555"
source_hash = "0xDEADBEEF"
num_buckets = 1
total_minimizers = {num_entries}
[inverted]
num_shards = 1
total_entries = {num_entries}
has_overlapping_shards = false
[[inverted.shards]]
shard_id = 0
min_minimizer = "0x0000000000000001"
max_minimizer = "0xFFFFFFFFFFFFFFFF"
num_entries = {num_entries}
"#
        );
        let dir = init_index_dir(&manifest);
        let index_path = dir.path().join(INDEX_DIR_NAME);
        write_buckets_parquet(&index_path, &["bucket0"]);
        write_single_shard(&index_path, vec![1], vec![0]);
        dir
    }

    /// Writes `num_reads` FASTQ records of `read_length` `A`s (quality `I`).
    fn create_test_fastq(read_length: usize, num_reads: usize) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        for i in 0..num_reads {
            writeln!(file, "@read{}", i).unwrap();
            writeln!(file, "{}", "A".repeat(read_length)).unwrap();
            writeln!(file, "+").unwrap();
            writeln!(file, "{}", "I".repeat(read_length)).unwrap();
        }
        file.flush().unwrap();
        file
    }

    /// Baseline auto-sizing config: single-end FASTX, 4 GiB budget, no
    /// trimming or length filter, no log-ratio. Tests override individual
    /// fields with struct-update syntax.
    fn auto_config<'a>(r1_path: &'a Path, index_path: &'a Path) -> BatchSizeConfig<'a> {
        BatchSizeConfig {
            batch_size_override: None,
            max_memory: 4 * GIB,
            r1_path,
            r2_path: None,
            is_parquet_input: false,
            index_path,
            trim_to: None,
            minimum_length: None,
            is_log_ratio: false,
            denominator_index_path: None,
        }
    }

    #[test]
    fn test_user_specified_batch_size_used_directly() {
        let index_dir = create_test_index();
        let index_path = index_dir.path().join(INDEX_DIR_NAME);
        let r1_file = create_test_fastq(150, 10);
        let config = BatchSizeConfig {
            batch_size_override: Some(5000),
            max_memory: 0,
            ..auto_config(r1_file.path(), &index_path)
        };
        let result = compute_effective_batch_size(&config).unwrap();
        assert_eq!(result.batch_size, 5000);
    }

    #[test]
    fn test_auto_batch_size_returns_reasonable_value() {
        let index_dir = create_test_index();
        let index_path = index_dir.path().join(INDEX_DIR_NAME);
        let r1_file = create_test_fastq(150, 10);
        let config = auto_config(r1_file.path(), &index_path);
        let result = compute_effective_batch_size(&config).unwrap();
        assert!(
            result.batch_size >= 1000,
            "Batch size {} should be >= 1000",
            result.batch_size
        );
        assert!(result.peak_memory > 0);
    }

    #[test]
    fn test_paired_end_detected_correctly() {
        let index_dir = create_test_index();
        let index_path = index_dir.path().join(INDEX_DIR_NAME);
        let r1_file = create_test_fastq(150, 10);
        let r2_file = create_test_fastq(150, 10);
        let config = BatchSizeConfig {
            r2_path: Some(r2_file.path()),
            ..auto_config(r1_file.path(), &index_path)
        };
        let result = compute_effective_batch_size(&config).unwrap();
        assert!(matches!(
            result.input_format,
            InputFormat::Fastx { is_paired: true }
        ));
    }

    #[test]
    fn test_parquet_format_detected() {
        let index_dir = create_test_index();
        let index_path = index_dir.path().join(INDEX_DIR_NAME);
        let r1_file = create_test_fastq(150, 10);
        let config = BatchSizeConfig {
            is_parquet_input: true,
            ..auto_config(r1_file.path(), &index_path)
        };
        let result = compute_effective_batch_size(&config).unwrap();
        assert!(matches!(
            result.input_format,
            InputFormat::Parquet {
                is_paired: false,
                trimmed_in_reader: false
            }
        ));
    }

    #[test]
    #[ignore = "requires local perf-assessment data"]
    fn test_real_sharded_index_batch_size_is_reasonable() {
        let index_path = Path::new("perf-assessment/parquet-index/n100-w200.ryxdi");
        let query_path = Path::new("perf-assessment/query-files/long_read.parquet");
        if !index_path.exists() || !query_path.exists() {
            eprintln!("Skipping: perf-assessment data not available");
            return;
        }
        let config = BatchSizeConfig {
            max_memory: 64 * GIB,
            is_parquet_input: true,
            ..auto_config(query_path, index_path)
        };
        let result = compute_effective_batch_size(&config).unwrap();
        eprintln!(
            "Real index: batch_size={}, peak_memory={:.2}GB, shard_reservation={:.2}MB",
            result.batch_size,
            result.peak_memory as f64 / (1024.0 * 1024.0 * 1024.0),
            result.shard_reservation as f64 / (1024.0 * 1024.0)
        );
        assert!(
            result.batch_size > 1_000_000,
            "batch_size should be > 1M for 64GB with sequential batches, got {}",
            result.batch_size
        );
    }

    #[test]
    #[ignore = "requires local perf-assessment data"]
    fn test_real_index_shard_reservation_affects_batch_size() {
        let sharded_index = Path::new("perf-assessment/parquet-index/n100-w200.ryxdi");
        if !sharded_index.exists() {
            eprintln!("Skipping: perf-assessment data not available");
            return;
        }
        let metadata = load_index_metadata(sharded_index).unwrap();
        eprintln!("largest_shard_entries: {}", metadata.largest_shard_entries);
        assert!(
            metadata.largest_shard_entries > 50_000_000,
            "8-shard index should have largest shard > 50M entries, got {}",
            metadata.largest_shard_entries
        );
        let reservation = estimate_shard_reservation(metadata.largest_shard_entries, 8);
        let reservation_mb = reservation / (1024 * 1024);
        eprintln!("shard_reservation: {}MB", reservation_mb);
        assert!(
            reservation_mb > 100,
            "Shard reservation should be > 100MB for 62M-entry shard, got {}MB",
            reservation_mb
        );
    }

    /// Checks that the manifest's shard entry count reaches the metadata used
    /// for sizing. It does not run the sizing itself.
    #[test]
    fn test_shard_info_flows_to_batch_sizing() {
        let index_dir = create_test_index();
        let index_path = index_dir.path().join(INDEX_DIR_NAME);
        let metadata = load_index_metadata(&index_path).unwrap();
        assert_eq!(
            metadata.largest_shard_entries, 3,
            "load_index_metadata should populate largest_shard_entries from manifest"
        );
    }

    #[test]
    fn test_log_ratio_uses_larger_denominator_shard() {
        let num_dir = create_test_index_with_entries(3);
        let denom_dir = create_test_index_with_entries(1_000_000);
        let num_path = num_dir.path().join(INDEX_DIR_NAME);
        let denom_path = denom_dir.path().join(INDEX_DIR_NAME);
        let r1_file = create_test_fastq(150, 10);
        let config = BatchSizeConfig {
            is_log_ratio: true,
            denominator_index_path: Some(&denom_path),
            ..auto_config(r1_file.path(), &num_path)
        };
        let result = compute_effective_batch_size(&config).unwrap();
        let threads = rayon::current_num_threads();
        let expected_reservation = estimate_shard_reservation(1_000_000, threads);
        assert!(
            result.shard_reservation >= expected_reservation,
            "shard_reservation {} should be >= {} (from denominator's 1M-entry shard)",
            result.shard_reservation,
            expected_reservation
        );
    }

    #[test]
    fn test_log_ratio_batch_size_shrinks_for_large_denominator() {
        let num_dir = create_test_index_with_entries(3);
        let denom_dir = create_test_index_with_entries(1_000_000);
        let num_path = num_dir.path().join(INDEX_DIR_NAME);
        let denom_path = denom_dir.path().join(INDEX_DIR_NAME);
        let r1_file = create_test_fastq(150, 10);
        let config_no_denom = BatchSizeConfig {
            is_log_ratio: true,
            ..auto_config(r1_file.path(), &num_path)
        };
        let config_with_denom = BatchSizeConfig {
            is_log_ratio: true,
            denominator_index_path: Some(&denom_path),
            ..auto_config(r1_file.path(), &num_path)
        };
        let result_no_denom = compute_effective_batch_size(&config_no_denom).unwrap();
        let result_with_denom = compute_effective_batch_size(&config_with_denom).unwrap();
        assert!(
            result_with_denom.batch_size < result_no_denom.batch_size,
            "Batch size with large denom {} should be < without denom {} \
             (more memory reserved for shard loading)",
            result_with_denom.batch_size,
            result_no_denom.batch_size
        );
    }
}