stratiphy 0.1.1

Phenotype-driven identification of disease subgroups
Documentation
use std::collections::BTreeMap;

use ontolius::{TermId, TermIdParseError};

/// Serializable counterpart of `crate::io::workflow::ClusteringWorkflowResult`.
///
/// Instances are produced through the `TryFrom` conversion from the I/O-layer
/// type. Field names are (de)serialized in camelCase.
#[derive(Debug, PartialEq, Clone, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ClusteringWorkflowResult {
    /// Affinity matrix values, carried over unchanged from the I/O-layer value.
    // NOTE(review): presumably a flattened (row-major?) square matrix — confirm with the producer.
    affinity_matrix: Vec<f64>,
    /// Cluster labels keyed by the `k` of the corresponding I/O-layer entry.
    // NOTE(review): `k` presumably is the number of clusters and each vector holds one
    // label per sample — confirm against the workflow that fills it.
    cluster_labels: BTreeMap<u16, Vec<u16>>,
    /// Sort order indices; `None` when the I/O-layer value carries an empty list.
    sort_order: Option<Vec<usize>>,
    /// Gap values, if the I/O-layer value provides them.
    gap_values: Option<GapValues>,
    /// Split check outcome, if the I/O-layer value provides one.
    split_check: Option<SplitCheck>,
    /// Term associations keyed by the `k` of the corresponding I/O-layer entry.
    term_associations: BTreeMap<u16, Vec<TermAssociation>>,
    /// Workflow metadata, if the I/O-layer value provides it.
    metadata: Option<ClusteringWorkflowMetadata>,
}

impl TryFrom<crate::io::workflow::ClusteringWorkflowResult> for ClusteringWorkflowResult {
    type Error = String;

    fn try_from(value: crate::io::workflow::ClusteringWorkflowResult) -> Result<Self, Self::Error> {
        let cluster_labels: BTreeMap<u16, Vec<u16>> = value
            .cluster_labels
            .into_iter()
            .map(|cl| {
                (
                    u16::try_from(cl.k).unwrap(),
                    cl.labels.into_iter().map(convert_cluster_id).collect(),
                )
            })
            .collect();

        let sort_order = if value.sort_order.is_empty() {
            None
        } else {
            Some(
                value
                    .sort_order
                    .into_iter()
                    .map(|v| usize::try_from(v).unwrap())
                    .collect(),
            )
        };

        let gap_values = if let Some(gv) = value.gap_values {
            Some(GapValues::try_from(gv)?)
        } else {
            None
        };

        let split_check = value.split_check.map(SplitCheck::from);

        let mut term_associations = BTreeMap::new();
        for ele in value.term_associations {
            let mut tas = Vec::with_capacity(ele.associations.len());
            for ass in ele.associations {
                tas.push(TermAssociation::try_from(ass)?);
            }
            term_associations.insert(convert_cluster_id(ele.k), tas);
        }

        Ok(ClusteringWorkflowResult {
            affinity_matrix: value.affinity_matrix,
            cluster_labels,
            sort_order,
            gap_values,
            split_check,
            term_associations,
            metadata: value.metadata.map(ClusteringWorkflowMetadata::from),
        })
    }
}

/// Serializable counterpart of `crate::io::workflow::GapValues`.
///
/// Field names are (de)serialized in camelCase.
// NOTE(review): the names suggest log within-cluster dispersion values as used by the
// gap statistic — confirm against the code that computes them.
#[derive(Debug, PartialEq, Clone, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GapValues {
    /// One value per key of the I/O-layer `log_wk_data` map.
    log_wk_data: BTreeMap<u16, f64>,
    /// A list of values per `k` of the I/O-layer `log_wk_rand` entries.
    log_wk_rand: BTreeMap<u16, Vec<f64>>,
}

impl TryFrom<crate::io::workflow::GapValues> for GapValues {
    type Error = String;

    fn try_from(value: crate::io::workflow::GapValues) -> Result<Self, Self::Error> {
        let log_wk_data = value
            .log_wk_data
            .into_iter()
            .map(|(k, v)| {
                (
                    u16::try_from(k).expect("cluster id should never exceed u16 range"),
                    v,
                )
            })
            .collect();
        let log_wk_rand = value
            .log_wk_rand
            .into_iter()
            .map(|v| (convert_cluster_id(v.k), v.values))
            .collect();
        Ok(GapValues {
            log_wk_data,
            log_wk_rand,
        })
    }
}

/// Serializable counterpart of `crate::io::workflow::SplitCheck`.
///
/// Field names are (de)serialized in camelCase.
#[derive(Debug, PartialEq, Clone, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SplitCheck {
    /// Split flag, copied from the I/O-layer value.
    should_split: bool,
    /// Value accompanying the flag, copied from the I/O-layer value.
    // NOTE(review): presumably a probability in [0, 1] — confirm with the producer.
    split_proba: f64,
}

impl From<crate::io::workflow::SplitCheck> for SplitCheck {
    /// Copies both fields of the I/O-layer split check one-to-one.
    fn from(value: crate::io::workflow::SplitCheck) -> Self {
        let should_split = value.should_split;
        let split_proba = value.split_proba;
        Self {
            should_split,
            split_proba,
        }
    }
}

/// Serializable counterpart of `crate::io::workflow::TermAssociation`.
///
/// Field names are (de)serialized in camelCase.
#[derive(Debug, PartialEq, Clone, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TermAssociation {
    /// Identifier of the associated term; (de)serialized through
    /// `TermId::serialize_as_curie` / `TermId::deserialize_from_curie`.
    #[serde(
        serialize_with = "TermId::serialize_as_curie",
        deserialize_with = "TermId::deserialize_from_curie"
    )]
    term_id: TermId,
    /// Counts keyed by cluster id.
    // NOTE(review): presumably the number of samples annotated with the term in each
    // cluster — confirm with the code that computes the association.
    counts: BTreeMap<u16, u64>,
    /// P-value, copied from the I/O-layer value.
    pval: f64,
    /// Corrected p-value, if the I/O-layer value carries one.
    corrected_pval: Option<f64>,
}

impl TryFrom<crate::io::workflow::TermAssociation> for TermAssociation {
    type Error = String;

    /// Converts an I/O-layer term association, parsing the term id and
    /// narrowing the count keys to `u16`.
    ///
    /// # Errors
    ///
    /// Returns a descriptive message when the term id cannot be parsed or when
    /// a cluster id used as a count key does not fit into `u16`.
    fn try_from(value: crate::io::workflow::TermAssociation) -> Result<Self, Self::Error> {
        let term_id: TermId = value
            .term_id
            .parse::<TermId>()
            .map_err(|e: TermIdParseError| format!("{}: {}", value.term_id, e))?;
        // Report an out-of-range cluster id as an error, just like a malformed
        // term id, instead of panicking inside a fallible conversion.
        let counts = value
            .counts
            .into_iter()
            .map(|(k, v)| -> Result<(u16, u64), String> {
                let cluster_id = u16::try_from(k)
                    .map_err(|e| format!("cluster id {} does not fit into u16: {}", k, e))?;
                Ok((cluster_id, u64::from(v)))
            })
            .collect::<Result<BTreeMap<u16, u64>, String>>()?;
        Ok(TermAssociation {
            term_id,
            counts,
            pval: value.pval,
            corrected_pval: value.corrected_pval,
        })
    }
}

/// Serializable counterpart of `crate::io::workflow::ClusteringWorkflowMetadata`.
///
/// Field names are (de)serialized in camelCase.
#[derive(Debug, PartialEq, Clone, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ClusteringWorkflowMetadata {
    /// Version string of stratiphy, copied from the I/O-layer value.
    stratiphy_version: String,
    /// HPO version string; always `Some` when converted from the I/O-layer
    /// value, but may be absent in deserialized data.
    hpo_version: Option<String>,
}

impl From<crate::io::workflow::ClusteringWorkflowMetadata> for ClusteringWorkflowMetadata {
    /// Moves the version strings out of the I/O-layer metadata; the HPO
    /// version is always wrapped in `Some`.
    fn from(value: crate::io::workflow::ClusteringWorkflowMetadata) -> Self {
        let stratiphy_version = value.stratiphy_version;
        let hpo_version = Some(value.hpo_version);
        Self {
            stratiphy_version,
            hpo_version,
        }
    }
}

/// Narrows a cluster id from `u32` to `u16`.
///
/// # Panics
///
/// Panics when `val` is larger than `u16::MAX`.
fn convert_cluster_id(val: u32) -> u16 {
    let narrowed = u16::try_from(val);
    narrowed.expect("cluster id should never exceed the range of u16")
}