1use std::io::{BufRead, Write};
2
3use rayon::prelude::*;
4use rsomics_common::{Result, RsomicsError};
5
6pub mod dm;
7pub mod interaction;
8mod rng;
9
10pub use dm::DistanceMatrix;
11pub use interaction::Interaction;
12
/// Outcome of a [`hommola`] cospeciation test.
#[derive(Debug, Clone)]
pub struct HommolaResult {
    /// Observed Pearson correlation between host and parasite distances over
    /// all interaction-edge pairs; NaN when the correlation is undefined.
    pub corr_coeff: f64,
    /// One-sided permutation p-value; NaN when no permutations were run or
    /// `corr_coeff` is NaN.
    pub p_value: f64,
    /// Number of hosts (size of the host distance matrix).
    pub num_hosts: usize,
    /// Number of parasites (size of the parasite distance matrix).
    pub num_pars: usize,
    /// Number of host–parasite interaction edges.
    pub num_interactions: usize,
    /// Number of unordered edge pairs, `num_interactions * (num_interactions - 1) / 2`.
    pub num_pairs: usize,
    /// Number of permutations requested.
    pub permutations: usize,
    /// Statistic for each permutation (all NaN when the test was not run).
    pub perm_stats: Vec<f64>,
}
23
24pub fn hommola(
33 host: &DistanceMatrix,
34 par: &DistanceMatrix,
35 inter: &Interaction,
36 permutations: usize,
37 seed: u64,
38) -> HommolaResult {
39 let num_hosts = host.n;
40 let num_pars = par.n;
41 let edges = inter.edges();
42 let pairs = edges * (edges - 1) / 2;
43
44 let mut ei = Vec::with_capacity(pairs);
46 let mut ej = Vec::with_capacity(pairs);
47 for i in 0..edges {
48 for j in (i + 1)..edges {
49 ei.push(i);
50 ej.push(j);
51 }
52 }
53
54 let identity_h: Vec<usize> = (0..num_hosts).collect();
55 let identity_p: Vec<usize> = (0..num_pars).collect();
56 let corr_coeff = pearson_over_pairs(host, par, inter, &ei, &ej, &identity_h, &identity_p);
57
58 let (p_value, perm_stats) = if permutations == 0 || corr_coeff.is_nan() {
59 (f64::NAN, vec![f64::NAN; permutations])
60 } else {
61 let stats: Vec<f64> = (0..permutations)
62 .into_par_iter()
63 .map(|k| {
64 let (mh, mp) = rng::host_par_perms(num_hosts, num_pars, seed, k as u64);
65 pearson_over_pairs(host, par, inter, &ei, &ej, &mh, &mp)
66 })
67 .collect();
68 let extreme = stats.iter().filter(|&&s| s >= corr_coeff).count();
69 let p = (extreme + 1) as f64 / (permutations + 1) as f64;
70 (p, stats)
71 };
72
73 HommolaResult {
74 corr_coeff,
75 p_value,
76 num_hosts,
77 num_pars,
78 num_interactions: edges,
79 num_pairs: pairs,
80 permutations,
81 perm_stats,
82 }
83}
84
85fn pearson_over_pairs(
88 host: &DistanceMatrix,
89 par: &DistanceMatrix,
90 inter: &Interaction,
91 ei: &[usize],
92 ej: &[usize],
93 mh: &[usize],
94 mp: &[usize],
95) -> f64 {
96 let nh = host.n;
97 let np = par.n;
98 let hi = &inter.host_idx;
99 let pi = &inter.par_idx;
100
101 let mut sx = 0.0;
102 let mut sy = 0.0;
103 let mut sxx = 0.0;
104 let mut syy = 0.0;
105 let mut sxy = 0.0;
106 let count = ei.len() as f64;
107
108 for (&a, &b) in ei.iter().zip(ej) {
109 let hx = host.data[mh[hi[a]] * nh + mh[hi[b]]];
110 let py = par.data[mp[pi[a]] * np + mp[pi[b]]];
111 sx += hx;
112 sy += py;
113 sxx += hx * hx;
114 syy += py * py;
115 sxy += hx * py;
116 }
117
118 let cov = sxy - sx * sy / count;
119 let vx = sxx - sx * sx / count;
120 let vy = syy - sy * sy / count;
121 let denom = (vx * vy).sqrt();
122 if denom == 0.0 {
123 f64::NAN
124 } else {
125 (cov / denom).clamp(-1.0, 1.0)
126 }
127}
128
129pub fn read_matrix<R: BufRead>(reader: R, source: &str) -> Result<DistanceMatrix> {
130 DistanceMatrix::read(reader, source)
131}
132
133pub fn read_interaction<R: BufRead>(reader: R, source: &str) -> Result<Interaction> {
134 Interaction::read(reader, source)
135}
136
137pub fn write_result<W: Write>(out: &mut W, res: &HommolaResult) -> Result<()> {
138 writeln!(
139 out,
140 "statistic\tp_value\tnum_hosts\tnum_parasites\tnum_interactions\tnum_pairs\tpermutations"
141 )
142 .map_err(RsomicsError::Io)?;
143 writeln!(
144 out,
145 "{:.12}\t{}\t{}\t{}\t{}\t{}\t{}",
146 res.corr_coeff,
147 fmt_p(res.p_value),
148 res.num_hosts,
149 res.num_pars,
150 res.num_interactions,
151 res.num_pairs,
152 res.permutations,
153 )
154 .map_err(RsomicsError::Io)?;
155 Ok(())
156}
157
/// Renders a float with 12 decimal places, spelling NaN as lowercase `nan`.
fn fmt_p(p: f64) -> String {
    match p {
        nan if nan.is_nan() => String::from("nan"),
        value => format!("{value:.12}"),
    }
}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// `(parasite, host)` interaction edges shared by the scikit-bio example tests.
    const SKBIO_EDGES: [(usize, usize); 6] = [(0, 0), (1, 1), (2, 2), (3, 3), (4, 3), (4, 4)];

    /// Host distances for the scikit-bio documentation example.
    fn skbio_host_dm() -> DistanceMatrix {
        dm(&[
            &[0.0, 3.0, 8.0, 8.0, 9.0],
            &[3.0, 0.0, 7.0, 7.0, 8.0],
            &[8.0, 7.0, 0.0, 6.0, 7.0],
            &[8.0, 7.0, 6.0, 0.0, 3.0],
            &[9.0, 8.0, 7.0, 3.0, 0.0],
        ])
    }

    /// Parasite distances for the scikit-bio documentation example.
    fn skbio_par_dm() -> DistanceMatrix {
        dm(&[
            &[0.0, 5.0, 8.0, 8.0, 8.0],
            &[5.0, 0.0, 7.0, 7.0, 7.0],
            &[8.0, 7.0, 0.0, 4.0, 4.0],
            &[8.0, 7.0, 4.0, 0.0, 2.0],
            &[8.0, 7.0, 4.0, 2.0, 0.0],
        ])
    }

    /// Builds a row-major `DistanceMatrix` from explicit rows, with ids "0", "1", ...
    fn dm(rows: &[&[f64]]) -> DistanceMatrix {
        let n = rows.len();
        let mut data = vec![0.0; n * n];
        rows.iter().enumerate().for_each(|(i, row)| {
            row.iter()
                .enumerate()
                .for_each(|(j, &v)| data[i * n + j] = v);
        });
        DistanceMatrix {
            ids: (0..n).map(|i| i.to_string()).collect(),
            data,
            n,
        }
    }

    /// Builds an `Interaction` from `(parasite, host)` edge tuples.
    fn inter(edges: &[(usize, usize)], num_hosts: usize, num_pars: usize) -> Interaction {
        Interaction {
            host_ids: (0..num_hosts).map(|i| i.to_string()).collect(),
            par_ids: (0..num_pars).map(|i| i.to_string()).collect(),
            par_idx: edges.iter().map(|edge| edge.0).collect(),
            host_idx: edges.iter().map(|edge| edge.1).collect(),
        }
    }

    #[test]
    fn skbio_doc_example() {
        let it = inter(&SKBIO_EDGES, 5, 5);
        let res = hommola(&skbio_host_dm(), &skbio_par_dm(), &it, 0, 42);
        assert!(
            (res.corr_coeff - 0.832).abs() < 5e-4,
            "corr={}",
            res.corr_coeff
        );
        assert_eq!(res.num_interactions, 6);
        assert_eq!(res.num_pairs, 15);
        assert!(res.p_value.is_nan());
    }

    #[test]
    fn strong_signal_is_significant() {
        let it = inter(&SKBIO_EDGES, 5, 5);
        let res = hommola(&skbio_host_dm(), &skbio_par_dm(), &it, 999, 42);
        assert!(res.p_value <= 0.05, "p={}", res.p_value);
        assert_eq!(res.perm_stats.len(), 999);
    }
}