stratiphy 0.1.3

Phenotype-driven identification of disease subgroups
Documentation
use std::collections::BTreeMap;

use ontolius::{TermId, TermIdParseError};

use crate::{io::generated::stratiphy_workflow::ObservationState, model::Cohort};

/// Result of the clustering workflow, expressed in the crate's own model types.
///
/// Built from the generated `stratiphy_workflow::ClusteringWorkflowResult` through
/// the `TryFrom` implementation in this module and (de)serialized by serde with
/// camelCase field names.
#[derive(Debug, PartialEq, Clone, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ClusteringWorkflowResult {
    /// Affinity matrix values, taken over unchanged from the generated message.
    /// NOTE(review): presumably a flattened square matrix — confirm the layout.
    pub affinity_matrix: Vec<f64>,
    /// Cluster labels, keyed by the `k` of each generated cluster-labels entry.
    pub cluster_labels: BTreeMap<u16, Vec<u16>>,
    /// Sort order; `None` when the generated message carried an empty sort order.
    pub sort_order: Option<Vec<usize>>,
    /// Gap values, present only if the generated message carried them.
    pub gap_values: Option<GapValues>,
    /// Split check, present only if the generated message carried one.
    pub split_check: Option<SplitCheck>,
    /// Term associations, keyed by the `k` of each generated entry.
    pub term_associations: BTreeMap<u16, Vec<TermAssociation>>,
}

impl TryFrom<crate::io::generated::stratiphy_workflow::ClusteringWorkflowResult>
    for ClusteringWorkflowResult
{
    type Error = String;

    fn try_from(
        value: crate::io::generated::stratiphy_workflow::ClusteringWorkflowResult,
    ) -> Result<Self, Self::Error> {
        let cluster_labels: BTreeMap<u16, Vec<u16>> = value
            .cluster_labels
            .into_iter()
            .map(|cl| {
                (
                    u16::try_from(cl.k).unwrap(),
                    cl.labels.into_iter().map(convert_cluster_id).collect(),
                )
            })
            .collect();

        let sort_order = if value.sort_order.is_empty() {
            None
        } else {
            Some(
                value
                    .sort_order
                    .into_iter()
                    .map(|v| usize::try_from(v).unwrap())
                    .collect(),
            )
        };

        let gap_values = if let Some(gv) = value.gap_values {
            Some(GapValues::try_from(gv)?)
        } else {
            None
        };

        let split_check = value.split_check.map(SplitCheck::from);

        let mut term_associations = BTreeMap::new();
        for ele in value.term_associations {
            let mut tas = Vec::with_capacity(ele.associations.len());
            for ass in ele.associations {
                tas.push(TermAssociation::try_from(ass)?);
            }
            term_associations.insert(convert_cluster_id(ele.k), tas);
        }

        Ok(ClusteringWorkflowResult {
            affinity_matrix: value.affinity_matrix,
            cluster_labels,
            sort_order,
            gap_values,
            split_check,
            term_associations,
        })
    }
}

/// Gap values in the crate's own model types, (de)serialized with camelCase
/// field names.
///
/// NOTE(review): the field names suggest log within-cluster dispersion values
/// for the observed data and for random reference data — confirm against the
/// workflow that produces them.
#[derive(Debug, PartialEq, Clone, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GapValues {
    /// One value per key of the generated `log_wk_data` map.
    pub log_wk_data: BTreeMap<u16, f64>,
    /// The `values` of each generated `log_wk_rand` entry, keyed by that entry's `k`.
    pub log_wk_rand: BTreeMap<u16, Vec<f64>>,
}

impl TryFrom<crate::io::generated::stratiphy_workflow::GapValues> for GapValues {
    type Error = String;

    fn try_from(
        value: crate::io::generated::stratiphy_workflow::GapValues,
    ) -> Result<Self, Self::Error> {
        let log_wk_data = value
            .log_wk_data
            .into_iter()
            .map(|(k, v)| {
                (
                    u16::try_from(k).expect("cluster id should never exceed u16 range"),
                    v,
                )
            })
            .collect();
        let log_wk_rand = value
            .log_wk_rand
            .into_iter()
            .map(|v| (convert_cluster_id(v.k), v.values))
            .collect();
        Ok(GapValues {
            log_wk_data,
            log_wk_rand,
        })
    }
}

/// Split check in the crate's own model types, (de)serialized with camelCase
/// field names. Both fields are taken over unchanged from the generated
/// `stratiphy_workflow::SplitCheck`.
#[derive(Debug, PartialEq, Clone, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SplitCheck {
    /// NOTE(review): presumably the decision whether to split further — confirm.
    pub should_split: bool,
    /// NOTE(review): presumably the probability backing `should_split` — confirm
    /// the expected range.
    pub split_proba: f64,
}

impl From<crate::io::generated::stratiphy_workflow::SplitCheck> for SplitCheck {
    /// Copies both fields of the generated split check into the model type.
    fn from(value: crate::io::generated::stratiphy_workflow::SplitCheck) -> Self {
        let should_split = value.should_split;
        let split_proba = value.split_proba;
        Self {
            should_split,
            split_proba,
        }
    }
}

/// A count paired with a present/excluded flag, (de)serialized with camelCase
/// field names.
///
/// NOTE(review): unlike the other structs in this module the fields are
/// private, so code outside this module can only reach them through serde —
/// confirm that this is intended.
#[derive(Debug, PartialEq, Clone, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TermCount {
    // `true` for `ObservationState::Present`, `false` for `ObservationState::Excluded`.
    is_present: bool,
    // Count taken from the generated `TermCount`.
    count: u64,
}

/// Association of a single term, with per-key counts and p values;
/// (de)serialized with camelCase field names.
#[derive(Debug, PartialEq, Clone, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TermAssociation {
    /// Term identifier; (de)serialized through `TermId::serialize_as_curie` and
    /// `TermId::deserialize_from_curie`.
    #[serde(
        serialize_with = "TermId::serialize_as_curie",
        deserialize_with = "TermId::deserialize_from_curie"
    )]
    pub term_id: TermId,
    /// Term counts, keyed by the `k` of each generated counts entry.
    pub counts: BTreeMap<u16, Vec<TermCount>>,
    /// P value as found in the generated message.
    pub pval: f64,
    /// Corrected p value, if the generated message carried one.
    pub corrected_pval: Option<f64>,
}

impl TryFrom<crate::io::generated::stratiphy_workflow::TermAssociation> for TermAssociation {
    type Error = String;

    fn try_from(
        value: crate::io::generated::stratiphy_workflow::TermAssociation,
    ) -> Result<Self, Self::Error> {
        let term_id: TermId = value
            .term_id
            .parse::<TermId>()
            .map_err(|e: TermIdParseError| format!("{}: {}", value.term_id, e))?;
        let counts: BTreeMap<u16, Vec<TermCount>> = value
            .counts
            .into_iter()
            .map(|tc| (convert_cluster_id(tc.k), convert_term_counts(tc.counts)))
            .collect();
        Ok(TermAssociation {
            term_id,
            counts,
            pval: value.pval,
            corrected_pval: value.corrected_pval,
        })
    }
}

/// Metadata of a clustering workflow run, (de)serialized with camelCase field
/// names.
#[derive(Debug, PartialEq, Clone, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ClusteringWorkflowMetadata {
    /// Version string of stratiphy, taken over from the generated metadata.
    pub stratiphy_version: String,
    /// HPO version string. The conversion from the generated metadata in this
    /// module always wraps its value in `Some`.
    pub hpo_version: Option<String>,
}

impl From<crate::io::generated::stratiphy_workflow::ClusteringWorkflowMetadata>
    for ClusteringWorkflowMetadata
{
    /// Moves both version strings out of the generated metadata.
    ///
    /// The HPO version is always wrapped in `Some`, whatever its content.
    /// NOTE(review): if the generated field can be an empty "unset" string,
    /// consider whether that should map to `None` — confirm with the producer.
    fn from(value: crate::io::generated::stratiphy_workflow::ClusteringWorkflowMetadata) -> Self {
        let stratiphy_version = value.stratiphy_version;
        let hpo_version = Some(value.hpo_version);
        Self {
            stratiphy_version,
            hpo_version,
        }
    }
}

/// Narrows a cluster id from `u32` to `u16`.
///
/// # Panics
///
/// Panics if `val` is larger than `u16::MAX`.
fn convert_cluster_id(val: u32) -> u16 {
    let narrowed: Result<u16, _> = val.try_into();
    narrowed.expect("cluster id should never exceed the range of u16")
}

/// Converts generated term counts into model [`TermCount`]s, preserving order.
///
/// # Panics
///
/// Panics if any count carries `ObservationState::Unspecified`
/// (see `convert_observation_state`).
fn convert_term_counts(
    counts: Vec<crate::io::generated::stratiphy_workflow::TermCount>,
) -> Vec<TermCount> {
    let mut converted = Vec::with_capacity(counts.len());
    for raw in counts {
        let is_present = convert_observation_state(raw.state());
        let count = raw.count as u64;
        converted.push(TermCount { is_present, count });
    }
    converted
}

fn convert_observation_state(state: ObservationState) -> bool {
    match state {
        ObservationState::Present => true,
        ObservationState::Excluded => false,
        ObservationState::Unspecified => panic!("state must be set!"),
    }
}

/// Top-level result bundling the clustering result, the cohort and the run
/// metadata; (de)serialized with camelCase field names.
///
/// The `TryFrom` conversion from the generated `stratiphy_workflow::StratiphyResult`
/// in this module requires all three parts to be present.
#[derive(Debug, PartialEq, Clone, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StratiphyResult {
    /// Converted clustering workflow result.
    pub clustering_result: ClusteringWorkflowResult,
    /// Converted cohort.
    pub cohort: Cohort,
    /// Converted workflow metadata.
    pub meta_data: ClusteringWorkflowMetadata,
}

impl TryFrom<crate::io::generated::stratiphy_workflow::StratiphyResult> for StratiphyResult {
    type Error = String;

    /// Converts the generated top-level result into the model type.
    ///
    /// The parts are checked in the order clustering result, cohort, metadata;
    /// the first missing part or failing conversion ends the conversion.
    ///
    /// # Errors
    ///
    /// Returns an error message if any of the three parts is missing, or if the
    /// conversion of the clustering result or of the cohort fails.
    fn try_from(
        value: crate::io::generated::stratiphy_workflow::StratiphyResult,
    ) -> Result<Self, Self::Error> {
        let clustering_result = match value.clustering_result {
            Some(raw) => ClusteringWorkflowResult::try_from(raw)?,
            None => return Err("Missing clustering result".to_string()),
        };

        let cohort = match value.cohort {
            Some(raw) => Cohort::try_from(raw)?,
            None => return Err("Missing cohort".to_string()),
        };

        let meta_data = match value.meta_data {
            Some(raw) => ClusteringWorkflowMetadata::from(raw),
            None => return Err("Missing meta_data".to_string()),
        };

        Ok(Self {
            clustering_result,
            cohort,
            meta_data,
        })
    }
}