use std::collections::BTreeMap;
use ontolius::{TermId, TermIdParseError};
/// Serializable result of a clustering workflow.
///
/// Built from `crate::io::workflow::ClusteringWorkflowResult` through the `TryFrom` impl in
/// this module. Field names are serialized in camelCase (`#[serde(rename_all = "camelCase")]`).
#[derive(Debug, PartialEq, Clone, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ClusteringWorkflowResult {
/// Affinity values copied verbatim from the I/O result.
/// NOTE(review): the layout (e.g. a flattened square matrix) is not visible here — confirm.
affinity_matrix: Vec<f64>,
/// Cluster labels keyed by the I/O entry's `k`.
/// Presumably `k` is the number of clusters — verify against the producer.
cluster_labels: BTreeMap<u16, Vec<u16>>,
/// Ordering indices; `None` when the I/O result's `sort_order` is empty.
sort_order: Option<Vec<usize>>,
/// Gap-statistic values, if present in the I/O result.
gap_values: Option<GapValues>,
/// Split decision, if present in the I/O result.
split_check: Option<SplitCheck>,
/// Term associations keyed by the I/O entry's `k`.
term_associations: BTreeMap<u16, Vec<TermAssociation>>,
/// Version metadata, if present in the I/O result.
metadata: Option<ClusteringWorkflowMetadata>,
}
impl TryFrom<crate::io::workflow::ClusteringWorkflowResult> for ClusteringWorkflowResult {
    type Error = String;

    /// Converts the I/O-layer workflow result into its serializable representation.
    ///
    /// An empty `sort_order` in the input becomes `None`.
    ///
    /// # Errors
    ///
    /// Returns `Err` in these cases:
    /// - a cluster-label entry's `k` does not fit into `u16`;
    /// - a sort-order entry does not fit into `usize`;
    /// - converting the gap values or any term association fails.
    ///
    /// # Panics
    ///
    /// Panics (via `convert_cluster_id`) if a cluster label or a term-association `k`
    /// exceeds `u16::MAX`.
    fn try_from(value: crate::io::workflow::ClusteringWorkflowResult) -> Result<Self, Self::Error> {
        // `k` comes from external data, so an out-of-range value is reported as an error
        // instead of panicking (this function is already fallible).
        let cluster_labels = value
            .cluster_labels
            .into_iter()
            .map(|cl| -> Result<(u16, Vec<u16>), String> {
                let k = u16::try_from(cl.k).map_err(|e| {
                    format!("invalid cluster labels entry: k does not fit into u16: {}", e)
                })?;
                Ok((k, cl.labels.into_iter().map(convert_cluster_id).collect()))
            })
            .collect::<Result<BTreeMap<u16, Vec<u16>>, String>>()?;

        let sort_order = if value.sort_order.is_empty() {
            None
        } else {
            let order = value
                .sort_order
                .into_iter()
                .map(|v| {
                    usize::try_from(v).map_err(|e| {
                        format!("invalid sort order: index does not fit into usize: {}", e)
                    })
                })
                .collect::<Result<Vec<usize>, String>>()?;
            Some(order)
        };

        let gap_values = value.gap_values.map(GapValues::try_from).transpose()?;
        let split_check = value.split_check.map(SplitCheck::from);

        let mut term_associations = BTreeMap::new();
        for ele in value.term_associations {
            let tas = ele
                .associations
                .into_iter()
                .map(TermAssociation::try_from)
                .collect::<Result<Vec<_>, _>>()?;
            term_associations.insert(convert_cluster_id(ele.k), tas);
        }

        Ok(ClusteringWorkflowResult {
            affinity_matrix: value.affinity_matrix,
            cluster_labels,
            sort_order,
            gap_values,
            split_check,
            term_associations,
            metadata: value.metadata.map(ClusteringWorkflowMetadata::from),
        })
    }
}
/// Serializable gap-statistic values, converted from `crate::io::workflow::GapValues`.
#[derive(Debug, PartialEq, Clone, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GapValues {
/// One value per `k`, keyed by the narrowed I/O key.
/// Presumably log(W_k) for the observed data — verify against the producer.
log_wk_data: BTreeMap<u16, f64>,
/// A list of values per `k`.
/// Presumably log(W_k) for each random reference dataset — verify against the producer.
log_wk_rand: BTreeMap<u16, Vec<f64>>,
}
impl TryFrom<crate::io::workflow::GapValues> for GapValues {
type Error = String;
fn try_from(value: crate::io::workflow::GapValues) -> Result<Self, Self::Error> {
let log_wk_data = value
.log_wk_data
.into_iter()
.map(|(k, v)| {
(
u16::try_from(k).expect("cluster id should never exceed u16 range"),
v,
)
})
.collect();
let log_wk_rand = value
.log_wk_rand
.into_iter()
.map(|v| (convert_cluster_id(v.k), v.values))
.collect();
Ok(GapValues {
log_wk_data,
log_wk_rand,
})
}
}
/// Serializable split decision, converted from `crate::io::workflow::SplitCheck`.
#[derive(Debug, PartialEq, Clone, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SplitCheck {
/// Whether a split is recommended, copied from the I/O value.
should_split: bool,
/// Score copied from the I/O value; presumably a probability in [0, 1] — confirm.
split_proba: f64,
}
impl From<crate::io::workflow::SplitCheck> for SplitCheck {
    /// Copies both fields of the I/O-layer split check unchanged.
    fn from(value: crate::io::workflow::SplitCheck) -> Self {
        let should_split = value.should_split;
        let split_proba = value.split_proba;
        Self {
            should_split,
            split_proba,
        }
    }
}
/// Serializable association between an ontology term and the clusters.
///
/// Converted from `crate::io::workflow::TermAssociation`.
#[derive(Debug, PartialEq, Clone, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TermAssociation {
/// Term identifier, (de)serialized through `TermId`'s CURIE helpers.
#[serde(
serialize_with = "TermId::serialize_as_curie",
deserialize_with = "TermId::deserialize_from_curie"
)]
term_id: TermId,
/// Counts keyed by cluster id (narrowed to `u16` during conversion).
counts: BTreeMap<u16, u64>,
/// P-value copied from the I/O value.
pval: f64,
/// Corrected p-value copied from the I/O value, if any.
corrected_pval: Option<f64>,
}
impl TryFrom<crate::io::workflow::TermAssociation> for TermAssociation {
    type Error = String;

    /// Builds a [`TermAssociation`] from its I/O-layer counterpart.
    ///
    /// The term identifier is parsed from its string form. Count keys are narrowed with
    /// `convert_cluster_id`, and count values are widened to `u64`.
    ///
    /// # Errors
    ///
    /// Returns `Err("<term id>: <parse error>")` when the term identifier cannot be parsed.
    ///
    /// # Panics
    ///
    /// Panics (via `convert_cluster_id`) if a count key exceeds `u16::MAX`.
    fn try_from(value: crate::io::workflow::TermAssociation) -> Result<Self, Self::Error> {
        let parsed: Result<TermId, TermIdParseError> = value.term_id.parse();
        let term_id = match parsed {
            Ok(id) => id,
            Err(err) => return Err(format!("{}: {}", value.term_id, err)),
        };

        let mut counts = BTreeMap::new();
        for (cluster, count) in value.counts {
            counts.insert(convert_cluster_id(cluster), u64::from(count));
        }

        Ok(Self {
            term_id,
            counts,
            pval: value.pval,
            corrected_pval: value.corrected_pval,
        })
    }
}
/// Serializable version metadata for a clustering workflow run.
///
/// Converted from `crate::io::workflow::ClusteringWorkflowMetadata`.
#[derive(Debug, PartialEq, Clone, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ClusteringWorkflowMetadata {
/// Version string of the producing `stratiphy` build, copied from the I/O value.
stratiphy_version: String,
/// HPO version string, when available.
hpo_version: Option<String>,
}
impl From<crate::io::workflow::ClusteringWorkflowMetadata> for ClusteringWorkflowMetadata {
    /// Copies the version metadata from the I/O layer.
    ///
    /// The I/O value stores `hpo_version` as a plain `String`, so an absent version can only
    /// show up there as an empty string. That case is mapped to `None` instead of `Some("")`.
    /// This mirrors how an empty `sort_order` is turned into `None` elsewhere in this module.
    fn from(value: crate::io::workflow::ClusteringWorkflowMetadata) -> Self {
        let hpo_version = Some(value.hpo_version).filter(|version| !version.is_empty());
        ClusteringWorkflowMetadata {
            stratiphy_version: value.stratiphy_version,
            hpo_version,
        }
    }
}
/// Narrows a `u32` cluster identifier to `u16`.
///
/// # Panics
///
/// Panics if `val` is greater than `u16::MAX`. Callers treat such a value as a broken
/// invariant rather than as recoverable input.
fn convert_cluster_id(val: u32) -> u16 {
    let narrowed = u16::try_from(val);
    narrowed.expect("cluster id should never exceed the range of u16")
}