1use anyhow::anyhow;
2use anyhow::Result;
3use clap::ValueEnum;
4use derive_new::new;
5use distmat::SquareMatrix;
6use noodles::{bgzf, fasta};
7use polars::{lazy::dsl::col, prelude::*};
8use std::fmt;
9use std::fs::File;
10use std::io::BufReader;
11use std::io::ErrorKind;
12use std::ops::Mul;
13use std::rc::Rc;
14use textdistance::{
15 nstr::{lcsseq, lcsstr},
16 str::{damerau_levenshtein, jaro_winkler, levenshtein, ratcliff_obershelp, smith_waterman},
17 str::{entropy_ncd, hamming, jaccard},
18};
19
/// String-comparison algorithms that can be selected for building the
/// pairwise distance matrix.
///
/// `ValueEnum` is derived so the variants can be used as command-line argument
/// values. Each variant is dispatched to the `textdistance` function named in
/// the comment beside it by this type's `DistanceCalculator` implementation.
#[derive(ValueEnum, Debug, Clone, PartialEq)]
pub enum DistanceMethods {
    // Scored with `textdistance::str::hamming`.
    Hamming,

    // Scored with `textdistance::str::levenshtein`.
    Levenshtein,

    // Scored with `textdistance::str::damerau_levenshtein`.
    DamerauLevenshtein,

    // Scored with `textdistance::str::jaro_winkler`.
    JaroWinkler,

    // Scored with `textdistance::str::smith_waterman`.
    SmithWaterman,

    // Scored with `textdistance::str::ratcliff_obershelp`.
    RatcliffObershelp,

    // Scored with `textdistance::nstr::lcsseq`.
    LCSSeq,

    // Scored with `textdistance::nstr::lcsstr`.
    LCSStr,

    // Scored with `textdistance::str::jaccard`.
    Jaccard,

    // Scored with `textdistance::str::entropy_ncd`.
    Entropy,
}
52
53impl fmt::Display for DistanceMethods {
54 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55 write!(
56 f,
57 "{}",
58 match self {
59 DistanceMethods::Hamming => "hamming",
60 DistanceMethods::Levenshtein => "levenshtein",
61 DistanceMethods::DamerauLevenshtein => "damerau-levenshtein",
62 DistanceMethods::JaroWinkler => "jaro-winkler",
63 DistanceMethods::SmithWaterman => "smith-waterman",
64 DistanceMethods::RatcliffObershelp => "ratcliff-obershelp",
65 DistanceMethods::LCSSeq => "lcs-seq",
66 DistanceMethods::LCSStr => "lcs-str",
67 DistanceMethods::Jaccard => "jaccard",
68 DistanceMethods::Entropy => "entropy",
69 }
70 )
71 }
72}
73
/// Stringency level that selects which cluster-size weighting formula is
/// applied to the distance matrix.
///
/// `ValueEnum` is derived so the variants can be used as command-line argument
/// values. In `weight_by_cluster_size`, `Strict` and `Extreme` share one
/// formula (based on the natural log of the cluster frequency) while `Lenient`
/// and `Intermediate` share the other (based on one minus the frequency).
#[derive(ValueEnum, Debug, Clone, PartialEq)]
pub enum Stringency {
    Lenient,
    Intermediate,
    Strict,
    Extreme,
}
81
82impl fmt::Display for Stringency {
83 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84 write!(
85 f,
86 "{}",
87 match self {
88 Stringency::Lenient => "lenient",
89 Stringency::Intermediate => "intermediate",
90 Stringency::Strict => "strict",
91 Stringency::Extreme => "extreme",
92 }
93 )
94 }
95}
96
/// Scores a pair of strings and reports the result as an `f64`.
trait DistanceCalculator {
    /// Returns the score for comparing `s1` against `s2`.
    fn calculate_distance(&self, s1: &str, s2: &str) -> f64;
}
100
101impl DistanceCalculator for DistanceMethods {
102 fn calculate_distance(&self, s1: &str, s2: &str) -> f64 {
103 match self {
104 DistanceMethods::Hamming => hamming(s1, s2) as f64,
105 DistanceMethods::Levenshtein => levenshtein(s1, s2) as f64,
106 DistanceMethods::DamerauLevenshtein => damerau_levenshtein(s1, s2) as f64,
107 DistanceMethods::JaroWinkler => jaro_winkler(s1, s2),
108 DistanceMethods::SmithWaterman => smith_waterman(s1, s2) as f64,
109 DistanceMethods::RatcliffObershelp => ratcliff_obershelp(s1, s2),
110 DistanceMethods::LCSSeq => lcsseq(s1, s2),
111 DistanceMethods::LCSStr => lcsstr(s1, s2),
112 DistanceMethods::Jaccard => jaccard(s1, s2),
113 DistanceMethods::Entropy => entropy_ncd(s1, s2),
114 }
115 }
116}
117
118fn collect_fa_data(fasta: &str) -> Result<(Vec<String>, Vec<Rc<str>>)> {
119 let parsed_fasta: std::io::Result<Vec<(String, Rc<str>)>> = if fasta.ends_with(".gz") {
122 File::open(fasta)
123 .map(bgzf::Reader::new)
124 .map(fasta::Reader::new)?
125 .records()
126 .map(|result| {
127 result.and_then(|record| {
128 let id = record.name().to_owned();
129 unpack_sequence(&record).map(|sequence_string| (id, Rc::from(sequence_string)))
130 })
131 })
132 .collect()
133 } else {
134 File::open(fasta)
135 .map(BufReader::new)
136 .map(fasta::Reader::new)?
137 .records()
138 .map(|result| {
139 result.and_then(|record| {
140 let id = record.name().to_owned();
141 unpack_sequence(&record).map(|sequence_string| (id, Rc::from(sequence_string)))
142 })
143 })
144 .collect()
145 };
146
147 let (ids, sequences) = match parsed_fasta {
148 Ok(pairs) => pairs.into_iter().unzip(),
149 Err(e) => return Err(e.into()),
150 };
151
152 Ok((ids, sequences))
153}
154
/// Names of the cluster-table columns that the weighting code reads.
///
/// The derived `new` constructor takes the fields in declaration order.
/// `get_cluster_cols` fills them from fixed column positions of the table.
#[derive(new, Debug, Clone)]
struct ClusterColumns {
    // First column; its values are compared against "C", "H" and "S".
    type_col: Rc<str>,
    // Second column; holds the cluster index.
    index_col: Rc<str>,
    // Ninth column; holds the sequence name.
    id_col: Rc<str>,
    // Third column; holds the cluster size.
    size_col: Rc<str>,
}
163
164fn get_cluster_cols(cluster_table: &LazyFrame) -> Result<ClusterColumns> {
165 let cluster_query = cluster_table.clone().collect()?;
167 let col_names = cluster_query.get_column_names();
168
169 let type_col: Rc<str> = match col_names.first() {
170 Some(col_name) => Rc::from(col_name.to_string()),
171 None => {
172 eprintln!(
173 "Please double check that the column of VSEARCH cluster types is the first column."
174 );
175 return Err(anyhow!(
176 "Member types could not be parsed from provided cluster table,"
177 ));
178 }
179 };
180
181 let index_col: Rc<str> = match col_names.get(1) {
182 Some(col_name) => Rc::from(col_name.to_string()),
183 None => {
184 eprintln!(
185 "Please double check that the column of VSEARCH cluster index is the second column."
186 );
187 return Err(anyhow!(
188 "Column indices could not be parsed from provided cluster table,"
189 ));
190 }
191 };
192
193 let name_col: Rc<str> = match col_names.get(8) {
194 Some(col_name) => Rc::from(col_name.to_string()),
195 None => {
196 eprintln!("Please double check that the column of sequence names is the ninth column.");
197 return Err(anyhow!(
198 "Sequence names could not be parsed from provided cluster table,"
199 ));
200 }
201 };
202
203 let size_col: Rc<str> = match col_names.get(2) {
204 Some(col_name) => Rc::from(col_name.to_string()),
205 None => {
206 eprintln!(
207 "Please double check that the column of VSEARCH cluster sizes is the third column."
208 );
209 return Err(anyhow!(
210 "Cluster sizes could not be parsed from provided cluster table,"
211 ));
212 }
213 };
214
215 Ok(ClusterColumns::new(type_col, index_col, name_col, size_col))
216}
217
218fn get_size_per_member(
219 cluster_table: &LazyFrame,
220 centroids_only: &LazyFrame,
221 clust_cols: &ClusterColumns,
222) -> Result<(f64, DataFrame)> {
223 let centroid_sizes = centroids_only
225 .clone()
226 .select(&[col(&clust_cols.index_col), col(&clust_cols.size_col)])
227 .collect()?;
228
229 let member_count = cluster_table
231 .clone()
232 .filter(
233 col(&clust_cols.type_col)
234 .eq(lit("H"))
235 .or(col(&clust_cols.type_col).eq(lit("S"))),
236 )
237 .select(&[col(&clust_cols.index_col)])
238 .collect()?
239 .shape()
240 .0;
241
242 let month_total: f64 = if member_count == 0 {
244 1.0
245 } else {
246 member_count as f64
247 };
248
249 Ok((month_total, centroid_sizes))
250}
251
252fn get_cluster_index(
253 cluster_table: &LazyFrame,
254 clust_cols: &ClusterColumns,
255 seq_name: &str,
256) -> Result<i64> {
257 let filtered = cluster_table
258 .clone()
259 .filter(col(&clust_cols.id_col).eq(lit(seq_name)))
260 .select(&[col(&clust_cols.index_col)])
261 .collect()?;
262
263 let index = filtered
264 .column(&clust_cols.index_col)?
265 .get(0)?
266 .try_extract::<i64>()?;
267
268 Ok(index)
269}
270
271fn compute_weighting_freq(
272 centroid_lf: LazyFrame,
273 clust_index: i64,
274 month_total: f64,
275 clust_cols: &ClusterColumns,
276) -> Result<f64> {
277 let collected_df = centroid_lf
278 .clone()
279 .filter(col(&clust_cols.index_col).eq(clust_index))
280 .select([col(&clust_cols.size_col)])
281 .collect()?;
282
283 let attempt = match collected_df
284 .column(&clust_cols.size_col)?
285 .iter()
286 .next() {
287 Some(value) => value,
288 None => return Err(anyhow!("Could not parse centroid data to compute a weight. Please double check the input cluster table."))
289 };
290
291 let cluster_freq = attempt.try_extract::<f64>()? / month_total;
292
293 Ok(cluster_freq)
294}
295
296fn weight_by_cluster_size(
297 seq_name: &str,
298 stringency: &Stringency,
299 cluster_table: &LazyFrame,
300) -> Result<(String, Series)> {
301 let clust_cols = get_cluster_cols(cluster_table)?;
302
303 let centroids_only = cluster_table
305 .clone()
306 .filter(col(&clust_cols.type_col).eq(lit("C")));
307
308 let (month_total, all_size_df) =
310 get_size_per_member(cluster_table, ¢roids_only, &clust_cols)?;
311
312 let index: i64 = get_cluster_index(cluster_table, &clust_cols, seq_name)?;
314
315 let weighting_freq = compute_weighting_freq(centroids_only, index, month_total, &clust_cols)?;
317
318 let weights_header = format!("{}_weights", seq_name);
320 let weights_lf = match *stringency {
321 Stringency::Strict | Stringency::Extreme => all_size_df
322 .lazy()
323 .with_column(lit(-1.0).alias("negative"))
324 .with_column(lit(weighting_freq.ln()).alias("log_freq"))
325 .with_column(lit(month_total).alias("total"))
326 .with_column(
327 col(&clust_cols.size_col) * ((col("negative") * col("log_freq")) / col("total")),
328 )
329 .rename([&clust_cols.size_col], [&weights_header]),
330 _ => all_size_df
331 .lazy()
332 .with_column(lit(1.0).alias("tmp_int"))
333 .with_column(lit(-1.0).alias("negative"))
334 .with_column(lit(weighting_freq).alias("freq"))
335 .with_column(lit(month_total).alias("total"))
336 .with_column(
337 col(&clust_cols.size_col)
338 * ((col("tmp_int") + col("negative") * col("freq")) / col("total")),
339 )
340 .rename([&clust_cols.size_col], [&weights_header]),
341 };
342
343 let weights = weights_lf
344 .select(&[col(&weights_header)])
345 .collect()?
346 .column(&weights_header)?
347 .to_owned();
348
349 Ok((weights_header, weights))
350}
351
352fn unpack_sequence(record: &fasta::Record) -> std::io::Result<String> {
353 let seq_attempt =
354 match record.sequence().get(..) {
355 Some(seq) => seq.to_vec(),
356 None => return Err(std::io::Error::new(
357 ErrorKind::InvalidData,
358 "No sequence was found for the provided record. Double check FASTA completeness.",
359 )),
360 };
361
362 let seq_as_string = String::from_utf8(seq_attempt).unwrap();
363
364 Ok(seq_as_string)
365}
366
367fn process_cluster_info(
368 cluster_table: Option<&str>,
369 dist_col_vec: Vec<Series>,
370 ids: &Vec<String>,
371 stringency: &Stringency,
372) -> Result<DataFrame> {
373 let mut dist_df = DataFrame::new(dist_col_vec)?;
374 dist_df = match cluster_table {
375 Some(table) => {
376 let cluster_df = CsvReader::from_path(table)?
378 .has_header(false)
379 .with_delimiter(b'\t')
380 .finish()?
381 .lazy();
382
383 for id in ids {
385 let (weights_header, weights) =
386 weight_by_cluster_size(id, stringency, &cluster_df)?;
387 dist_df = dist_df
388 .hstack(&[weights])?
389 .lazy()
390 .with_columns(&[col(id).mul(col(&weights_header)).alias(id)])
391 .collect()?
392 .drop(&weights_header)?
393 }
394 let col_series = Series::new("Sequence_Name", &ids);
395 dist_df.hstack(&[col_series])?
396 }
397 None => {
398 let col_series = Series::new("Sequence_Name", &ids);
399 dist_df.hstack(&[col_series])?
400 }
401 };
402
403 Ok(dist_df)
404}
405
406pub fn compute_distance_matrix(
407 fasta: &str,
408 cluster_table: Option<&str>,
409 label: &str,
410 stringency: &Stringency,
411 distance_method: &DistanceMethods,
412) -> Result<()> {
413 let (ids, sequences) = collect_fa_data(fasta)?;
414
415 assert!(
417 ids.len() == sequences.len(),
418 "Unable to identify an ID for each sequence from the FASTA {}.",
419 &fasta
420 );
421
422 let mut pw_distmat = SquareMatrix::from_pw_distances_with(&sequences, |seq1, seq2| {
424 distance_method.calculate_distance(seq1, seq2)
425 });
426 pw_distmat.set_labels(ids.clone());
427
428 let mut dist_col_vec: Vec<Series> = vec![Default::default(); pw_distmat.size()];
430 for (i, (column, label)) in pw_distmat
431 .iter_cols()
432 .zip(pw_distmat.iter_labels())
433 .enumerate()
434 {
435 let series = Series::new(label, column.collect::<Vec<f64>>());
436 dist_col_vec[i] = series;
437 }
438
439 let mut dist_df = process_cluster_info(cluster_table, dist_col_vec, &ids, stringency)?;
440
441 let out_name = format!("{}-dist-matrix.csv", label);
443 let out_handle = File::create(out_name).expect(
444 "File could not be created to write the distance matrix to. Please check file-write permissions."
445 );
446 CsvWriter::new(out_handle)
447 .has_header(true)
448 .finish(&mut dist_df)
449 .expect("Weighted distance matrix could not be written.");
450
451 Ok(())
452}