use serde::{Deserialize, Serialize};
use crate::collections::{Feature, FeatureList, FeatureSetProvider};
#[derive(Debug, PartialEq)]
pub struct RankedFeatureListItem<'a> {
feature: &'a Feature,
rank: u32,
index: usize,
}
impl<'a> RankedFeatureListItem<'a> {
    /// Returns the feature this item refers to.
    ///
    /// The reference borrows from the underlying list (lifetime `'a`), not from this
    /// item, so it remains usable after the item itself is dropped.
    pub fn feature(&self) -> &'a Feature {
        self.feature
    }

    /// Returns the rank associated with the feature.
    pub fn rank(&self) -> u32 {
        self.rank
    }

    /// Returns the position of this item in the list it was taken from.
    pub fn index(&self) -> usize {
        self.index
    }
}
/// Whether a [`RankedFeatureList`]'s thresholds were generated for its current contents.
///
/// Variant order is kept as-is because it can affect index-based serde formats.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum ThresholdState {
    /// Thresholds have not been generated, or an element was removed since they were.
    Unthresholded,
    /// Thresholds were generated after the last modification.
    Thresholded,
}
/// Positions to remove from a [`RankedFeatureList`], accepted by `RankedFeatureList::remove`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RemoveIndices {
    /// Remove exactly one position.
    Single(usize),
    /// Remove several positions; order is irrelevant and duplicates are ignored.
    Multiple(Vec<usize>),
}
/// A [`FeatureList`] paired with one rank per feature.
///
/// `genes[i]` has rank `ranks[i]`. `RankedFeatureList::from` checks that the two have
/// the same length and sorts them together by ascending rank.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub struct RankedFeatureList {
/// The features, in the same order as `ranks`.
genes: FeatureList,
/// Rank of each feature; parallel to `genes`.
ranks: Vec<u32>,
/// Rank cut-offs produced by `generate_thresholds`.
thresholds: Vec<u32>,
/// Set to `Thresholded` once thresholds are generated in `from`; any removal sets it
/// back to `Unthresholded`.
threshold_state: ThresholdState,
}
impl Default for RankedFeatureList {
fn default() -> Self {
Self::new()
}
}
impl RankedFeatureList {
    /// Creates an empty, unthresholded list.
    pub fn new() -> Self {
        Self {
            genes: FeatureList::new(),
            ranks: Vec::new(),
            thresholds: Vec::new(),
            threshold_state: ThresholdState::Unthresholded,
        }
    }

    /// Builds a ranked list from parallel `genes` and `ranks`.
    ///
    /// The pairs are sorted by ascending rank (stable, so equal ranks keep their input
    /// order), and thresholds are generated immediately.
    ///
    /// # Errors
    ///
    /// Returns an error if `genes` and `ranks` have different lengths.
    pub fn from(genes: FeatureList, ranks: Vec<u32>) -> Result<Self, String> {
        Self::check_lengths(&genes, &ranks)?;
        let mut ranked_list = Self {
            genes,
            ranks,
            thresholds: Vec::new(),
            threshold_state: ThresholdState::Unthresholded,
        };
        ranked_list.sort_genes_and_ranks();
        // Also marks the list as `Thresholded`.
        ranked_list.generate_thresholds();
        Ok(ranked_list)
    }

    /// Returns the features in list order.
    pub fn genes(&self) -> &FeatureList {
        &self.genes
    }

    /// Returns the ranks, parallel to [`genes`](Self::genes).
    pub fn ranks(&self) -> &[u32] {
        &self.ranks
    }

    /// Returns the thresholds from the last call to
    /// [`generate_thresholds`](Self::generate_thresholds).
    pub fn thresholds(&self) -> &[u32] {
        &self.thresholds
    }

    /// Returns the item at `index`, or `None` if `index` is out of bounds.
    pub fn get(&self, index: usize) -> Option<RankedFeatureListItem<'_>> {
        if index >= self.genes.len() {
            return None;
        }
        Some(RankedFeatureListItem {
            feature: &self.genes[index],
            rank: self.ranks[index],
            index,
        })
    }

    /// Returns the number of features in the list.
    pub fn len(&self) -> usize {
        self.genes.len()
    }

    /// Returns `true` if the list contains no features.
    pub fn is_empty(&self) -> bool {
        self.genes.is_empty()
    }

    /// Recomputes the rank thresholds and marks the list as `Thresholded`.
    ///
    /// Thresholds start at 1 and grow geometrically (`floor(t * 1.01 + 1)`), stopping
    /// at the largest rank. The final threshold is replaced by the largest rank, so the
    /// last cut-off always covers every feature. An empty list (or one whose ranks are
    /// all 0) gets no thresholds.
    pub fn generate_thresholds(&mut self) {
        let max_rank = self.ranks.iter().copied().max().unwrap_or(0);
        let mut thresholds = Vec::new();
        // Track the candidate in u64: with u32 the float cast saturates at u32::MAX,
        // so the loop would never end when max_rank == u32::MAX.
        let mut current: u64 = 1;
        while current <= u64::from(max_rank) {
            // Lossless: current <= max_rank <= u32::MAX inside the loop.
            thresholds.push(current as u32);
            current = ((current as f64) * 1.01 + 1.0).floor() as u64;
        }
        // Adjust the freshly built list. The old code touched `self.thresholds`,
        // which is overwritten just below.
        if let Some(last) = thresholds.last_mut() {
            *last = max_rank;
        }
        self.thresholds = thresholds;
        self.threshold_state = ThresholdState::Thresholded;
    }

    /// Removes one or more positions from the list.
    ///
    /// For several positions, order does not matter and duplicates are ignored. Removal
    /// proceeds from the highest position down, so an out-of-bounds position is
    /// reported before any element has been removed.
    ///
    /// # Errors
    ///
    /// Returns an error if a position is out of bounds or the feature list rejects
    /// the removal.
    pub fn remove<T>(&mut self, indices: T) -> Result<(), String>
    where
        T: Into<RemoveIndices>,
    {
        match indices.into() {
            RemoveIndices::Single(index) => self.remove_index(index),
            RemoveIndices::Multiple(mut indices) => {
                // Highest first, so each removal leaves pending (lower) positions unshifted.
                indices.sort_unstable_by(|a, b| b.cmp(a));
                // Removing the same position twice would delete a *different* element
                // (or fail part-way), so collapse duplicates.
                indices.dedup();
                for index in indices {
                    self.remove_index(index)?;
                }
                Ok(())
            }
        }
    }

    /// Removes every item for which `filter_fn` returns `true`.
    pub fn filter_and_remove<F>(&mut self, filter_fn: F)
    where
        F: Fn(&RankedFeatureListItem) -> bool,
    {
        let indices_to_remove: Vec<usize> = self
            .iter()
            .filter(|item| filter_fn(item))
            .map(|item| item.index())
            .collect();
        self.remove(indices_to_remove)
            .expect("indices collected from this list are unique and in bounds");
    }

    /// Removes a single position and marks the thresholds as out of date.
    ///
    /// Thresholds are not recomputed here; call
    /// [`generate_thresholds`](Self::generate_thresholds) to refresh them.
    fn remove_index(&mut self, index: usize) -> Result<(), String> {
        if index >= self.genes.len() {
            return Err(format!("Index {} is out of bounds", index));
        }
        self.genes
            .remove(index)
            .map_err(|_| "Feature removal failed".to_string())?;
        self.ranks.remove(index);
        self.threshold_state = ThresholdState::Unthresholded;
        Ok(())
    }

    /// Fails if `genes` and `ranks` are not the same length.
    fn check_lengths(genes: &FeatureList, ranks: &[u32]) -> Result<(), String> {
        if genes.len() != ranks.len() {
            return Err(format!(
                "Feature list length ({}) does not match ranks length ({})",
                genes.len(),
                ranks.len()
            ));
        }
        Ok(())
    }

    /// Sorts features and ranks together by ascending rank.
    ///
    /// The sort is stable, so features with equal ranks keep their relative order.
    fn sort_genes_and_ranks(&mut self) {
        let mut combined: Vec<(u32, &Feature)> = self
            .ranks
            .iter()
            .zip(self.genes.genes())
            .map(|(rank, feature)| (*rank, feature))
            .collect();
        combined.sort_by_key(|(rank, _)| *rank);
        self.ranks = combined.iter().map(|(rank, _)| *rank).collect();
        self.genes = FeatureList::from(
            combined
                .into_iter()
                .map(|(_, feature)| feature.clone())
                .collect::<Vec<_>>(),
        );
    }
}
impl FeatureSetProvider for RankedFeatureList {
    /// Collects clones of every feature whose rank is at most `threshold`, in list order.
    fn get_feature_set_by_threshold(&self, threshold: u32) -> Vec<Feature> {
        let mut selected = Vec::new();
        for (position, rank) in self.ranks.iter().copied().enumerate() {
            if rank > threshold {
                continue;
            }
            selected.push(self.genes[position].clone());
        }
        selected
    }
}
impl From<usize> for RemoveIndices {
fn from(index: usize) -> Self {
RemoveIndices::Single(index)
}
}
impl From<&[usize]> for RemoveIndices {
fn from(indices: &[usize]) -> Self {
RemoveIndices::Multiple(indices.to_vec())
}
}
impl From<Vec<usize>> for RemoveIndices {
fn from(indices: Vec<usize>) -> Self {
RemoveIndices::Multiple(indices)
}
}
/// Borrowing iterator over a [`RankedFeatureList`], yielding [`RankedFeatureListItem`]s
/// in stored order.
///
/// The order is ascending rank when the list was built with `RankedFeatureList::from`.
#[derive(Debug, Clone)]
pub struct RankedFeatureListIterator<'a> {
    /// The list being iterated.
    ranked_feature_list: &'a RankedFeatureList,
    /// Position of the next item to yield.
    index: usize,
}
impl RankedFeatureList {
    /// Returns a borrowing iterator that yields each feature with its rank and position,
    /// starting from position 0.
    pub fn iter(&self) -> RankedFeatureListIterator<'_> {
        RankedFeatureListIterator {
            index: 0,
            ranked_feature_list: self,
        }
    }
}
impl<'a> Iterator for RankedFeatureListIterator<'a> {
type Item = RankedFeatureListItem<'a>;
fn next(&mut self) -> Option<Self::Item> {
if self.index < self.ranked_feature_list.genes().len() {
let item = RankedFeatureListItem {
feature: &self.ranked_feature_list.genes()[self.index],
rank: self.ranked_feature_list.ranks()[self.index],
index: self.index,
};
self.index += 1;
Some(item)
} else {
None
}
}
}
impl<'list> IntoIterator for &'list RankedFeatureList {
    type Item = RankedFeatureListItem<'list>;
    type IntoIter = RankedFeatureListIterator<'list>;

    /// Enables `for item in &ranked_list`, iterating from the first position.
    fn into_iter(self) -> RankedFeatureListIterator<'list> {
        RankedFeatureListIterator {
            index: 0,
            ranked_feature_list: self,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ranked_feature_list_iterator() {
        let genes = FeatureList::from(vec![
            Feature::from("gene1"),
            Feature::from("gene2"),
            Feature::from("gene3"),
        ]);
        let ranks = vec![3, 2, 1];
        let ranked_list = RankedFeatureList::from(genes, ranks).unwrap();
        let mut iter = ranked_list.into_iter();
        let item = iter.next().unwrap();
        assert_eq!(item.feature().id(), "gene3");
        assert_eq!(item.rank(), 1);
        let item = iter.next().unwrap();
        assert_eq!(item.feature().id(), "gene2");
        assert_eq!(item.rank(), 2);
        let item = iter.next().unwrap();
        assert_eq!(item.feature().id(), "gene1");
        assert_eq!(item.rank(), 3);
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn test_ranked_feature_list_iterator_empty() {
        let ranked_list = RankedFeatureList::new();
        let mut iter = ranked_list.iter();
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn test_remove_single_index() {
        let genes = FeatureList::from(vec![
            Feature::from("gene1"),
            Feature::from("gene2"),
            Feature::from("gene3"),
        ]);
        let ranks = vec![3, 2, 1];
        let mut ranked_list = RankedFeatureList::from(genes, ranks).unwrap();
        assert!(ranked_list.remove(1).is_ok());
        assert_eq!(ranked_list.genes().len(), 2);
        assert_eq!(ranked_list.genes()[0].id(), "gene3");
        assert_eq!(ranked_list.genes()[1].id(), "gene1");
        assert_eq!(ranked_list.ranks(), &[1, 3]);
    }

    #[test]
    fn test_remove_multiple_indices() {
        let genes = FeatureList::from(vec![
            Feature::from("gene1"),
            Feature::from("gene2"),
            Feature::from("gene3"),
        ]);
        let ranks = vec![3, 2, 1];
        let mut ranked_list = RankedFeatureList::from(genes, ranks).unwrap();
        assert!(ranked_list.remove(vec![0, 2]).is_ok());
        assert_eq!(ranked_list.genes().len(), 1);
        assert_eq!(ranked_list.genes()[0].id(), "gene2");
        assert_eq!(ranked_list.ranks(), &[2]);
    }

    #[test]
    fn test_remove_invalid_index() {
        let genes = FeatureList::from(vec![Feature::from("gene1")]);
        let ranks = vec![1];
        let mut ranked_list = RankedFeatureList::from(genes, ranks).unwrap();
        let res = ranked_list.remove(5);
        match res {
            Err(e) => assert_eq!(e, "Index 5 is out of bounds"),
            _ => panic!("Expected an error"),
        }
    }

    #[test]
    fn test_remove_all_indices() {
        let genes = FeatureList::from(vec![
            Feature::from("gene1"),
            Feature::from("gene2"),
            Feature::from("gene3"),
        ]);
        let ranks = vec![3, 2, 1];
        let mut ranked_list = RankedFeatureList::from(genes, ranks).unwrap();
        assert!(ranked_list.remove(vec![0, 1, 2]).is_ok());
        assert!(ranked_list.genes().is_empty());
        assert!(ranked_list.ranks().is_empty());
    }

    #[test]
    fn test_remove_indices_with_duplicate_ranks() {
        let genes = FeatureList::from(vec![
            Feature::from("gene1"),
            Feature::from("gene2"),
            Feature::from("gene3"),
        ]);
        let ranks = vec![2, 2, 1];
        let mut ranked_list = RankedFeatureList::from(genes, ranks).unwrap();
        assert!(ranked_list.remove(vec![0, 2]).is_ok());
        assert_eq!(ranked_list.genes().len(), 1);
        assert_eq!(ranked_list.genes()[0].id(), "gene1");
        assert_eq!(ranked_list.ranks(), &[2]);
    }

    #[test]
    fn test_iterator_after_removals() {
        let genes = FeatureList::from(vec![
            Feature::from("gene1"),
            Feature::from("gene2"),
            Feature::from("gene3"),
        ]);
        let ranks = vec![3, 2, 1];
        let mut ranked_list = RankedFeatureList::from(genes, ranks).unwrap();
        assert!(ranked_list.remove(vec![1]).is_ok());
        let mut iter = ranked_list.into_iter();
        let item = iter.next().unwrap();
        assert_eq!(item.feature().id(), "gene3");
        assert_eq!(item.rank(), 1);
        let item = iter.next().unwrap();
        assert_eq!(item.feature().id(), "gene1");
        assert_eq!(item.rank(), 3);
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn test_filter_and_remove() {
        let genes = FeatureList::from(vec![
            Feature::from("gene1"),
            Feature::from("gene2"),
            Feature::from("gene3"),
            Feature::from("gene4"),
        ]);
        let ranks = vec![4, 3, 2, 1];
        let mut ranked_list = RankedFeatureList::from(genes, ranks).unwrap();
        ranked_list.filter_and_remove(|item| item.rank() > 2);
        assert_eq!(ranked_list.len(), 2);
        let remaining_items: Vec<_> = ranked_list.iter().collect();
        assert_eq!(remaining_items[0].feature().id(), "gene4");
        assert_eq!(remaining_items[0].rank(), 1);
        assert_eq!(remaining_items[1].feature().id(), "gene3");
        assert_eq!(remaining_items[1].rank(), 2);
    }

    #[test]
    fn test_filter_by_another_ranked_feature_list() {
        let genes1 = FeatureList::from(vec![
            Feature::from("gene1"),
            Feature::from("gene2"),
            Feature::from("gene3"),
            Feature::from("gene4"),
        ]);
        let ranks1 = vec![4, 3, 2, 1];
        let mut ranked_list1 = RankedFeatureList::from(genes1, ranks1).unwrap();
        let genes2 = FeatureList::from(vec![Feature::from("gene2"), Feature::from("gene4")]);
        let ranks2 = vec![1, 2];
        let ranked_list2 = RankedFeatureList::from(genes2, ranks2).unwrap();
        ranked_list1.filter_and_remove(|item| {
            !ranked_list2
                .iter()
                .any(|other_item| other_item.feature().id() == item.feature().id())
        });
        assert_eq!(ranked_list1.len(), 2);
        let remaining_items: Vec<_> = ranked_list1.iter().collect();
        assert_eq!(remaining_items[0].feature().id(), "gene4");
        assert_eq!(remaining_items[1].feature().id(), "gene2");
        assert_eq!(remaining_items[0].rank(), 1);
        assert_eq!(remaining_items[1].rank(), 3);
    }

    #[test]
    fn test_thresholds_cover_small_ranks() {
        let genes = FeatureList::from(vec![
            Feature::from("gene1"),
            Feature::from("gene2"),
            Feature::from("gene3"),
        ]);
        let ranks = vec![3, 2, 1];
        let ranked_list = RankedFeatureList::from(genes, ranks).unwrap();
        assert_eq!(ranked_list.thresholds(), &[1, 2, 3]);
        assert!(RankedFeatureList::new().thresholds().is_empty());
    }

    #[test]
    fn test_get_out_of_bounds() {
        let genes = FeatureList::from(vec![Feature::from("gene1")]);
        let ranked_list = RankedFeatureList::from(genes, vec![1]).unwrap();
        assert_eq!(ranked_list.get(0).unwrap().feature().id(), "gene1");
        assert!(ranked_list.get(1).is_none());
    }

    #[test]
    fn test_feature_set_by_threshold() {
        let genes = FeatureList::from(vec![
            Feature::from("gene1"),
            Feature::from("gene2"),
            Feature::from("gene3"),
        ]);
        let ranks = vec![3, 2, 1];
        let ranked_list = RankedFeatureList::from(genes, ranks).unwrap();
        let selected = ranked_list.get_feature_set_by_threshold(2);
        assert_eq!(selected.len(), 2);
        assert_eq!(selected[0].id(), "gene3");
        assert_eq!(selected[1].id(), "gene2");
    }
}