use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkMotifData {
pub num_nodes: usize,
pub node_labels: Vec<String>,
pub adjacency_matrices: Vec<AdjacencyMatrix>,
pub delay_values: Vec<f64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdjacencyMatrix {
pub index: usize,
pub delay: f64,
pub matrix: Vec<f64>,
pub edges: Vec<NetworkEdge>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkEdge {
pub from: usize,
pub to: usize,
pub weight: f64,
}
pub fn transform_cd_to_network_motifs(
q_matrix: &[Vec<f64>],
channel_pairs: &[[usize; 2]],
channel_names: &[String],
delays: &[f64],
threshold: Option<f64>,
) -> Result<NetworkMotifData, String> {
let threshold = threshold.unwrap_or(0.25);
let num_timepoints = if q_matrix.is_empty() {
0
} else {
q_matrix[0].len()
};
if num_timepoints == 0 {
return Err("Q-matrix has no timepoints".to_string());
}
let mut unique_nodes: Vec<usize> = channel_pairs
.iter()
.flat_map(|pair| vec![pair[0], pair[1]])
.collect();
unique_nodes.sort();
unique_nodes.dedup();
let num_nodes = unique_nodes.len();
if num_nodes == 0 {
return Err("No nodes found in channel pairs".to_string());
}
let node_labels: Vec<String> = unique_nodes
.iter()
.map(|&idx| {
channel_names
.get(idx)
.cloned()
.unwrap_or_else(|| format!("Ch{}", idx + 1))
})
.collect();
let node_index_map: std::collections::HashMap<usize, usize> = unique_nodes
.iter()
.enumerate()
.map(|(i, &orig)| (orig, i))
.collect();
let indices = if num_timepoints >= 3 {
vec![
num_timepoints / 4,
num_timepoints / 2,
(num_timepoints * 3) / 4,
]
} else if num_timepoints == 2 {
vec![0, 1, 1]
} else {
vec![0, 0, 0]
};
let delay_values: Vec<f64> = indices
.iter()
.map(|&i| delays.get(i).copied().unwrap_or(i as f64))
.collect();
let mut adjacency_matrices = Vec::with_capacity(3);
for (matrix_idx, &time_idx) in indices.iter().enumerate() {
let mut values: Vec<f64> = channel_pairs
.iter()
.enumerate()
.map(|(pair_idx, _)| {
q_matrix
.get(pair_idx)
.and_then(|row| row.get(time_idx))
.copied()
.unwrap_or(f64::NAN)
})
.collect();
let valid_values: Vec<f64> = values.iter().filter(|v| !v.is_nan()).copied().collect();
if !valid_values.is_empty() {
let min_val = valid_values.iter().copied().fold(f64::INFINITY, f64::min);
let max_val = valid_values
.iter()
.copied()
.fold(f64::NEG_INFINITY, f64::max);
let range = max_val - min_val;
if range > 1e-10 {
for v in &mut values {
if !v.is_nan() {
*v = (*v - min_val) / range;
}
}
} else {
for v in &mut values {
if !v.is_nan() {
*v = 1.0;
}
}
}
}
let mut matrix = vec![0.0; num_nodes * num_nodes];
let mut edges = Vec::new();
for (pair_idx, pair) in channel_pairs.iter().enumerate() {
let from = *node_index_map.get(&pair[0]).unwrap_or(&0);
let to = *node_index_map.get(&pair[1]).unwrap_or(&0);
let mut weight = values.get(pair_idx).copied().unwrap_or(0.0);
if weight < threshold || weight.is_nan() {
weight = 0.0;
}
matrix[from * num_nodes + to] = weight;
if weight > 0.0 {
edges.push(NetworkEdge { from, to, weight });
}
}
adjacency_matrices.push(AdjacencyMatrix {
index: matrix_idx,
delay: delay_values[matrix_idx],
matrix,
edges,
});
}
Ok(NetworkMotifData {
num_nodes,
node_labels,
adjacency_matrices,
delay_values,
})
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_transform_cd_to_network_motifs() {
let channel_pairs = vec![[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]];
let q_matrix = vec![
vec![0.1, 0.2, 0.3, 0.4, 0.5],
vec![0.2, 0.3, 0.4, 0.5, 0.6],
vec![0.3, 0.4, 0.5, 0.6, 0.7],
vec![0.4, 0.5, 0.6, 0.7, 0.8],
vec![0.5, 0.6, 0.7, 0.8, 0.9],
vec![0.6, 0.7, 0.8, 0.9, 1.0],
];
let channel_names = vec!["A".to_string(), "B".to_string(), "C".to_string()];
let scales = vec![0.1, 0.2, 0.3, 0.4, 0.5];
let result = transform_cd_to_network_motifs(
&q_matrix,
&channel_pairs,
&channel_names,
&scales,
Some(0.25),
)
.unwrap();
assert_eq!(result.num_nodes, 3);
assert_eq!(result.node_labels, vec!["A", "B", "C"]);
assert_eq!(result.adjacency_matrices.len(), 3);
for matrix in &result.adjacency_matrices {
assert_eq!(matrix.matrix.len(), 9);
}
}
#[test]
fn test_normalization_and_threshold() {
let channel_pairs = vec![[0, 1], [1, 0]];
let q_matrix = vec![vec![0.0, 0.5, 1.0], vec![0.0, 0.25, 0.5]];
let channel_names = vec!["X".to_string(), "Y".to_string()];
let scales = vec![0.1, 0.2, 0.3];
let result = transform_cd_to_network_motifs(
&q_matrix,
&channel_pairs,
&channel_names,
&scales,
Some(0.25),
)
.unwrap();
let middle_matrix = &result.adjacency_matrices[1];
assert!(middle_matrix.edges.len() >= 1);
}
}