seqtkrs 0.1.1

A Rust reimplementation of seqtk, a fast and lightweight tool for processing biological sequences in FASTA/FASTQ format
Documentation
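//! Reservoir sampling algorithms.
//!
//! Two sampling strategies are provided:
//! - `ReservoirSampler`: standard reservoir sampling; the selected records are kept in memory.
//! - `TwoPassSampler`: two-pass sampling that stores only indices; suited to large files, but
//!   the input must be seekable.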

use crate::core::seq_record::SeqRecord;
use crate::utils::random::RandomGenerator;

/// Reservoir sampler (standard, single-pass mode).
///
/// Picks a fixed number of records out of a stream whose length is not known
/// in advance. The first `capacity` records go straight into the reservoir;
/// every later record draws a slot number below the running record count and
/// overwrites that slot when the number falls inside the reservoir.
///
/// # Examples
///
/// ```
/// use seqtkrs::algorithms::sampler::ReservoirSampler;
/// use seqtkrs::core::SeqRecord;
///
/// let mut sampler = ReservoirSampler::new(10, Some(42));  // sample 10 records, seed 42
///
/// // Feed the records
/// for i in 0..100 {
///     let record = SeqRecord::new(format!("seq{}", i).into_bytes(), b"ACGT".to_vec());
///     sampler.add(record);
/// }
///
/// // Collect the sample
/// let samples = sampler.into_samples();
/// assert_eq!(samples.len(), 10);
/// ```
pub struct ReservoirSampler {
    /// Maximum number of records to keep.
    capacity: usize,
    /// Number of records offered so far.
    count: usize,
    /// Source of the random slot draws.
    rng: RandomGenerator,
    /// Records currently held in the sample.
    reservoir: Vec<SeqRecord>,
}

impl ReservoirSampler {
    /// Creates an empty sampler.
    ///
    /// # Arguments
    ///
    /// * `capacity` - number of records to sample; storage for that many
    ///   records is reserved up front.
    /// * `seed` - seed for the random generator. `None` falls back to
    ///   `RandomGenerator::new()` (presumably seeded from system time —
    ///   confirm against that type).
    pub fn new(capacity: usize, seed: Option<u64>) -> Self {
        let rng = if let Some(value) = seed {
            RandomGenerator::with_seed(value)
        } else {
            RandomGenerator::new()
        };
        let reservoir = Vec::with_capacity(capacity);

        Self {
            capacity,
            count: 0,
            rng,
            reservoir,
        }
    }

    /// Offers one record to the sampler.
    ///
    /// # Algorithm
    ///
    /// 1. While fewer than `capacity` records are held, the record is appended.
    /// 2. Afterwards a slot number is drawn below the updated record count and
    ///    the record replaces that slot if it lies inside the reservoir;
    ///    otherwise the record is dropped. Assuming `random_range(n)` is
    ///    uniform over `0..n`, this keeps a record with probability
    ///    `capacity / count`.
    pub fn add(&mut self, record: SeqRecord) {
        self.count += 1;

        // Fill phase: nothing to draw until the reservoir is full.
        if self.reservoir.len() < self.capacity {
            self.reservoir.push(record);
            return;
        }

        // Replacement phase: exactly one draw per record.
        let slot = self.rng.random_range(self.count as u64) as usize;
        if slot < self.capacity {
            self.reservoir[slot] = record;
        }
    }

    /// Consumes the sampler and returns the records held in the reservoir.
    pub fn into_samples(self) -> Vec<SeqRecord> {
        let Self { reservoir, .. } = self;
        reservoir
    }

    /// Returns how many records have been offered via [`ReservoirSampler::add`].
    pub fn count(&self) -> usize {
        self.count
    }
}

/// Two-pass reservoir sampler that stores only record indices.
///
/// Intended for inputs where the sampled records should come out in their
/// original order, or where keeping the records themselves in memory is too
/// costly. During the first pass the caller invokes
/// [`TwoPassSampler::add_index`] once per record and only the chosen positions
/// are remembered; the caller then reads the input a second time and emits the
/// records found at those positions.
///
/// Note: the input has to be readable twice, so streaming sources such as
/// stdin are not supported.
///
/// # Examples
///
/// ```
/// use seqtkrs::algorithms::sampler::TwoPassSampler;
///
/// // First pass: collect indices
/// let mut sampler = TwoPassSampler::new(10, Some(42));
/// for _ in 0..100 {
///     sampler.add_index();
/// }
///
/// // Retrieve the selected indices (sorted)
/// let selected = sampler.get_selected_indices();
/// assert_eq!(selected.len(), 10);
/// ```
pub struct TwoPassSampler {
    /// Maximum number of indices to keep.
    capacity: usize,
    /// Number of records seen so far; also the zero-based index of the next record.
    count: usize,
    /// Source of the random slot draws.
    rng: RandomGenerator,
    /// Indices currently selected (unsorted until handed out).
    selected_indices: Vec<usize>,
}

impl TwoPassSampler {
    /// Creates an empty two-pass sampler.
    ///
    /// # Arguments
    ///
    /// * `capacity` - number of records to sample; storage for that many
    ///   indices is reserved up front.
    /// * `seed` - seed for the random generator. `None` falls back to
    ///   `RandomGenerator::new()` (presumably seeded from system time —
    ///   confirm against that type).
    pub fn new(capacity: usize, seed: Option<u64>) -> Self {
        let rng = match seed {
            Some(s) => RandomGenerator::with_seed(s),
            None => RandomGenerator::new(),
        };

        Self {
            capacity,
            count: 0,
            rng,
            selected_indices: Vec::with_capacity(capacity),
        }
    }

    /// First pass: registers the next record and decides whether its index is kept.
    ///
    /// The record's index is the number of records registered before it.
    /// While fewer than `capacity` indices are held the index is appended;
    /// afterwards one slot number is drawn below the updated record count and
    /// the index overwrites that slot if it lies inside the reservoir.
    ///
    /// # Returns
    ///
    /// `true` if the current index was stored (appended or written over an
    /// earlier one), `false` if it was discarded. A stored index may still be
    /// overwritten by a later call. The value is meant for debugging only.
    pub fn add_index(&mut self) -> bool {
        let current_index = self.count;
        self.count += 1;

        // Fill phase: every index is kept until the reservoir is full.
        if self.selected_indices.len() < self.capacity {
            self.selected_indices.push(current_index);
            return true;
        }

        // Replacement phase: exactly one draw per record. Report the actual
        // outcome of the draw instead of an unconditional `false`.
        let slot = self.rng.random_range(self.count as u64) as usize;
        let replaced = slot < self.capacity;
        if replaced {
            self.selected_indices[slot] = current_index;
        }
        replaced
    }

    /// Consumes the sampler and returns the selected indices in ascending order.
    ///
    /// Sorting lets the second pass walk the input front to back and emit the
    /// chosen records in their original order. Every stored index is distinct,
    /// so an unstable sort is sufficient.
    pub fn get_selected_indices(mut self) -> Vec<usize> {
        self.selected_indices.sort_unstable();
        self.selected_indices
    }

    /// Returns how many records have been registered via [`TwoPassSampler::add_index`].
    pub fn count(&self) -> usize {
        self.count
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a placeholder record named `seq<i>` carrying the bases `ACGT`.
    fn make_record(i: usize) -> SeqRecord {
        SeqRecord::new(format!("seq{}", i).into_bytes(), b"ACGT".to_vec())
    }

    #[test]
    fn test_reservoir_sampler_basic() {
        let mut sampler = ReservoirSampler::new(5, Some(42));

        // Offer twice as many records as the reservoir can hold.
        (0..10)
            .map(make_record)
            .for_each(|record| sampler.add(record));

        assert_eq!(sampler.count(), 10);

        // Only `capacity` records may survive.
        let samples = sampler.into_samples();
        assert_eq!(samples.len(), 5);
    }

    #[test]
    fn test_reservoir_sampler_less_than_capacity() {
        let mut sampler = ReservoirSampler::new(10, Some(42));

        // Offer fewer records than the reservoir can hold.
        (0..5)
            .map(make_record)
            .for_each(|record| sampler.add(record));

        // Every offered record must be kept.
        let samples = sampler.into_samples();
        assert_eq!(samples.len(), 5);
    }

    #[test]
    fn test_reservoir_sampler_deterministic() {
        // Two samplers sharing a seed must pick the same records.
        let mut first = ReservoirSampler::new(3, Some(12345));
        let mut second = ReservoirSampler::new(3, Some(12345));

        for i in 0..10 {
            first.add(make_record(i));
            second.add(make_record(i));
        }

        let left = first.into_samples();
        let right = second.into_samples();

        assert_eq!(left.len(), right.len());
        for (a, b) in left.iter().zip(&right) {
            assert_eq!(a.name, b.name);
        }
    }

    #[test]
    fn test_two_pass_sampler_basic() {
        let mut sampler = TwoPassSampler::new(5, Some(42));

        // First pass: register ten records.
        for _ in 0..10 {
            sampler.add_index();
        }

        assert_eq!(sampler.count(), 10);

        // Exactly `capacity` indices are selected.
        let selected = sampler.get_selected_indices();
        assert_eq!(selected.len(), 5);

        // The indices come back strictly ascending.
        assert!(selected.windows(2).all(|pair| pair[1] > pair[0]));

        // Every index refers to one of the registered records.
        assert!(selected.iter().all(|&idx| idx < 10));
    }

    #[test]
    fn test_two_pass_sampler_less_than_capacity() {
        let mut sampler = TwoPassSampler::new(10, Some(42));

        // Register fewer records than the capacity.
        for _ in 0..5 {
            sampler.add_index();
        }

        // All of them are selected, i.e. indices 0 through 4.
        let selected = sampler.get_selected_indices();
        assert_eq!(selected.len(), 5);
        assert_eq!(selected, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_two_pass_sampler_deterministic() {
        // Two samplers sharing a seed must select the same indices.
        let mut first = TwoPassSampler::new(3, Some(12345));
        let mut second = TwoPassSampler::new(3, Some(12345));

        for _ in 0..10 {
            first.add_index();
            second.add_index();
        }

        assert_eq!(
            first.get_selected_indices(),
            second.get_selected_indices()
        );
    }
}