alpine/lib/
distmat.rs

1use anyhow::anyhow;
2use anyhow::Result;
3use clap::ValueEnum;
4use derive_new::new;
5use distmat::SquareMatrix;
6use noodles::{bgzf, fasta};
7use polars::{lazy::dsl::col, prelude::*};
8use std::fmt;
9use std::fs::File;
10use std::io::BufReader;
11use std::io::ErrorKind;
12use std::ops::Mul;
13use std::rc::Rc;
14use textdistance::{
15    nstr::{lcsseq, lcsstr},
16    str::{damerau_levenshtein, jaro_winkler, levenshtein, ratcliff_obershelp, smith_waterman},
17    str::{entropy_ncd, hamming, jaccard},
18};
19
20#[derive(ValueEnum, Debug, Clone, PartialEq)]
21pub enum DistanceMethods {
22    /// Hamming edit distance
23    Hamming,
24
25    /// Levenshtein edit distance
26    Levenshtein,
27
28    /// Damerau-Levenshtein edit distance
29    DamerauLevenshtein,
30
31    /// Jaro-Winkler edit distance
32    JaroWinkler,
33
34    /// Smith-Waterman edit distance
35    SmithWaterman,
36
37    /// Ratcliff-Obershelp/Gestalt pattern matching sequence-based distance
38    RatcliffObershelp,
39
40    /// Longest Common SubSequence distance
41    LCSSeq,
42
43    /// Longest Common SubString
44    LCSStr,
45
46    /// Jaccard token/kmer-based distance
47    Jaccard,
48
49    /// Entropy normalized compression distance
50    Entropy,
51}
52
53impl fmt::Display for DistanceMethods {
54    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55        write!(
56            f,
57            "{}",
58            match self {
59                DistanceMethods::Hamming => "hamming",
60                DistanceMethods::Levenshtein => "levenshtein",
61                DistanceMethods::DamerauLevenshtein => "damerau-levenshtein",
62                DistanceMethods::JaroWinkler => "jaro-winkler",
63                DistanceMethods::SmithWaterman => "smith-waterman",
64                DistanceMethods::RatcliffObershelp => "ratcliff-obershelp",
65                DistanceMethods::LCSSeq => "lcs-seq",
66                DistanceMethods::LCSStr => "lcs-str",
67                DistanceMethods::Jaccard => "jaccard",
68                DistanceMethods::Entropy => "entropy",
69            }
70        )
71    }
72}
73
74#[derive(ValueEnum, Debug, Clone, PartialEq)]
75pub enum Stringency {
76    Lenient,
77    Intermediate,
78    Strict,
79    Extreme,
80}
81
82impl fmt::Display for Stringency {
83    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84        write!(
85            f,
86            "{}",
87            match self {
88                Stringency::Lenient => "lenient",
89                Stringency::Intermediate => "intermediate",
90                Stringency::Strict => "strict",
91                Stringency::Extreme => "extreme",
92            }
93        )
94    }
95}
96
/// Something that can score a pair of strings with a single `f64`.
trait DistanceCalculator {
    /// Returns the score for comparing `s1` against `s2`.
    fn calculate_distance(&self, s1: &str, s2: &str) -> f64;
}
100
101impl DistanceCalculator for DistanceMethods {
102    fn calculate_distance(&self, s1: &str, s2: &str) -> f64 {
103        match self {
104            DistanceMethods::Hamming => hamming(s1, s2) as f64,
105            DistanceMethods::Levenshtein => levenshtein(s1, s2) as f64,
106            DistanceMethods::DamerauLevenshtein => damerau_levenshtein(s1, s2) as f64,
107            DistanceMethods::JaroWinkler => jaro_winkler(s1, s2),
108            DistanceMethods::SmithWaterman => smith_waterman(s1, s2) as f64,
109            DistanceMethods::RatcliffObershelp => ratcliff_obershelp(s1, s2),
110            DistanceMethods::LCSSeq => lcsseq(s1, s2),
111            DistanceMethods::LCSStr => lcsstr(s1, s2),
112            DistanceMethods::Jaccard => jaccard(s1, s2),
113            DistanceMethods::Entropy => entropy_ncd(s1, s2),
114        }
115    }
116}
117
118fn collect_fa_data(fasta: &str) -> Result<(Vec<String>, Vec<Rc<str>>)> {
119    // pull out record IDs and sequences into their own string vectors, while
120    // handling potential bgzip compression
121    let parsed_fasta: std::io::Result<Vec<(String, Rc<str>)>> = if fasta.ends_with(".gz") {
122        File::open(fasta)
123            .map(bgzf::Reader::new)
124            .map(fasta::Reader::new)?
125            .records()
126            .map(|result| {
127                result.and_then(|record| {
128                    let id = record.name().to_owned();
129                    unpack_sequence(&record).map(|sequence_string| (id, Rc::from(sequence_string)))
130                })
131            })
132            .collect()
133    } else {
134        File::open(fasta)
135            .map(BufReader::new)
136            .map(fasta::Reader::new)?
137            .records()
138            .map(|result| {
139                result.and_then(|record| {
140                    let id = record.name().to_owned();
141                    unpack_sequence(&record).map(|sequence_string| (id, Rc::from(sequence_string)))
142                })
143            })
144            .collect()
145    };
146
147    let (ids, sequences) = match parsed_fasta {
148        Ok(pairs) => pairs.into_iter().unzip(),
149        Err(e) => return Err(e.into()),
150    };
151
152    Ok((ids, sequences))
153}
154
155/// Cluster columns contains the column names where the information ALPINE needs is stored
156#[derive(new, Debug, Clone)]
157struct ClusterColumns {
158    type_col: Rc<str>,
159    index_col: Rc<str>,
160    id_col: Rc<str>,
161    size_col: Rc<str>,
162}
163
164fn get_cluster_cols(cluster_table: &LazyFrame) -> Result<ClusterColumns> {
165    // separate out the columns of information we need
166    let cluster_query = cluster_table.clone().collect()?;
167    let col_names = cluster_query.get_column_names();
168
169    let type_col: Rc<str> = match col_names.first() {
170        Some(col_name) => Rc::from(col_name.to_string()),
171        None => {
172            eprintln!(
173                "Please double check that the column of VSEARCH cluster types is the first column."
174            );
175            return Err(anyhow!(
176                "Member types could not be parsed from provided cluster table,"
177            ));
178        }
179    };
180
181    let index_col: Rc<str> = match col_names.get(1) {
182        Some(col_name) => Rc::from(col_name.to_string()),
183        None => {
184            eprintln!(
185                "Please double check that the column of VSEARCH cluster index is the second column."
186            );
187            return Err(anyhow!(
188                "Column indices could not be parsed from provided cluster table,"
189            ));
190        }
191    };
192
193    let name_col: Rc<str> = match col_names.get(8) {
194        Some(col_name) => Rc::from(col_name.to_string()),
195        None => {
196            eprintln!("Please double check that the column of sequence names is the ninth column.");
197            return Err(anyhow!(
198                "Sequence names could not be parsed from provided cluster table,"
199            ));
200        }
201    };
202
203    let size_col: Rc<str> = match col_names.get(2) {
204        Some(col_name) => Rc::from(col_name.to_string()),
205        None => {
206            eprintln!(
207                "Please double check that the column of VSEARCH cluster sizes is the third column."
208            );
209            return Err(anyhow!(
210                "Cluster sizes could not be parsed from provided cluster table,"
211            ));
212        }
213    };
214
215    Ok(ClusterColumns::new(type_col, index_col, name_col, size_col))
216}
217
218fn get_size_per_member(
219    cluster_table: &LazyFrame,
220    centroids_only: &LazyFrame,
221    clust_cols: &ClusterColumns,
222) -> Result<(f64, DataFrame)> {
223    // pull out the sizes for each centroid by index
224    let centroid_sizes = centroids_only
225        .clone()
226        .select(&[col(&clust_cols.index_col), col(&clust_cols.size_col)])
227        .collect()?;
228
229    // Filter down to hits only and use to get a total number of sequences for the current month
230    let member_count = cluster_table
231        .clone()
232        .filter(
233            col(&clust_cols.type_col)
234                .eq(lit("H"))
235                .or(col(&clust_cols.type_col).eq(lit("S"))),
236        )
237        .select(&[col(&clust_cols.index_col)])
238        .collect()?
239        .shape()
240        .0;
241
242    // count rows of members to get the month total
243    let month_total: f64 = if member_count == 0 {
244        1.0
245    } else {
246        member_count as f64
247    };
248
249    Ok((month_total, centroid_sizes))
250}
251
252fn get_cluster_index(
253    cluster_table: &LazyFrame,
254    clust_cols: &ClusterColumns,
255    seq_name: &str,
256) -> Result<i64> {
257    let filtered = cluster_table
258        .clone()
259        .filter(col(&clust_cols.id_col).eq(lit(seq_name)))
260        .select(&[col(&clust_cols.index_col)])
261        .collect()?;
262
263    let index = filtered
264        .column(&clust_cols.index_col)?
265        .get(0)?
266        .try_extract::<i64>()?;
267
268    Ok(index)
269}
270
271fn compute_weighting_freq(
272    centroid_lf: LazyFrame,
273    clust_index: i64,
274    month_total: f64,
275    clust_cols: &ClusterColumns,
276) -> Result<f64> {
277    let collected_df = centroid_lf
278        .clone()
279        .filter(col(&clust_cols.index_col).eq(clust_index))
280        .select([col(&clust_cols.size_col)])
281        .collect()?;
282
283    let attempt = match collected_df
284        .column(&clust_cols.size_col)?
285        .iter()
286        .next() {
287            Some(value) => value,
288            None => return Err(anyhow!("Could not parse centroid data to compute a weight. Please double check the input cluster table."))
289        };
290
291    let cluster_freq = attempt.try_extract::<f64>()? / month_total;
292
293    Ok(cluster_freq)
294}
295
296fn weight_by_cluster_size(
297    seq_name: &str,
298    stringency: &Stringency,
299    cluster_table: &LazyFrame,
300) -> Result<(String, Series)> {
301    let clust_cols = get_cluster_cols(cluster_table)?;
302
303    // Filter down the df so that only rows representing centroids are present
304    let centroids_only = cluster_table
305        .clone()
306        .filter(col(&clust_cols.type_col).eq(lit("C")));
307
308    // Filter down to hits only and use to get a total number of sequences for the current month
309    let (month_total, all_size_df) =
310        get_size_per_member(cluster_table, &centroids_only, &clust_cols)?;
311
312    // determine the cluster index for the current cluster member
313    let index: i64 = get_cluster_index(cluster_table, &clust_cols, seq_name)?;
314
315    // find the frequency of the cluster for the member accession in question
316    let weighting_freq = compute_weighting_freq(centroids_only, index, month_total, &clust_cols)?;
317
318    // compute Polars series of weights along with the series name
319    let weights_header = format!("{}_weights", seq_name);
320    let weights_lf = match *stringency {
321        Stringency::Strict | Stringency::Extreme => all_size_df
322            .lazy()
323            .with_column(lit(-1.0).alias("negative"))
324            .with_column(lit(weighting_freq.ln()).alias("log_freq"))
325            .with_column(lit(month_total).alias("total"))
326            .with_column(
327                col(&clust_cols.size_col) * ((col("negative") * col("log_freq")) / col("total")),
328            )
329            .rename([&clust_cols.size_col], [&weights_header]),
330        _ => all_size_df
331            .lazy()
332            .with_column(lit(1.0).alias("tmp_int"))
333            .with_column(lit(-1.0).alias("negative"))
334            .with_column(lit(weighting_freq).alias("freq"))
335            .with_column(lit(month_total).alias("total"))
336            .with_column(
337                col(&clust_cols.size_col)
338                    * ((col("tmp_int") + col("negative") * col("freq")) / col("total")),
339            )
340            .rename([&clust_cols.size_col], [&weights_header]),
341    };
342
343    let weights = weights_lf
344        .select(&[col(&weights_header)])
345        .collect()?
346        .column(&weights_header)?
347        .to_owned();
348
349    Ok((weights_header, weights))
350}
351
352fn unpack_sequence(record: &fasta::Record) -> std::io::Result<String> {
353    let seq_attempt =
354        match record.sequence().get(..) {
355            Some(seq) => seq.to_vec(),
356            None => return Err(std::io::Error::new(
357                ErrorKind::InvalidData,
358                "No sequence was found for the provided record. Double check FASTA completeness.",
359            )),
360        };
361
362    let seq_as_string = String::from_utf8(seq_attempt).unwrap();
363
364    Ok(seq_as_string)
365}
366
367fn process_cluster_info(
368    cluster_table: Option<&str>,
369    dist_col_vec: Vec<Series>,
370    ids: &Vec<String>,
371    stringency: &Stringency,
372) -> Result<DataFrame> {
373    let mut dist_df = DataFrame::new(dist_col_vec)?;
374    dist_df = match cluster_table {
375        Some(table) => {
376            // read the cluster table into a lazyframe to query for cluster-size-based distance weights
377            let cluster_df = CsvReader::from_path(table)?
378                .has_header(false)
379                .with_delimiter(b'\t')
380                .finish()?
381                .lazy();
382
383            // multiply weights onto each column based on the sequence it represents
384            for id in ids {
385                let (weights_header, weights) =
386                    weight_by_cluster_size(id, stringency, &cluster_df)?;
387                dist_df = dist_df
388                    .hstack(&[weights])?
389                    .lazy()
390                    .with_columns(&[col(id).mul(col(&weights_header)).alias(id)])
391                    .collect()?
392                    .drop(&weights_header)?
393            }
394            let col_series = Series::new("Sequence_Name", &ids);
395            dist_df.hstack(&[col_series])?
396        }
397        None => {
398            let col_series = Series::new("Sequence_Name", &ids);
399            dist_df.hstack(&[col_series])?
400        }
401    };
402
403    Ok(dist_df)
404}
405
406pub fn compute_distance_matrix(
407    fasta: &str,
408    cluster_table: Option<&str>,
409    label: &str,
410    stringency: &Stringency,
411    distance_method: &DistanceMethods,
412) -> Result<()> {
413    let (ids, sequences) = collect_fa_data(fasta)?;
414
415    // double check that there are as many ids as there are sequences
416    assert!(
417        ids.len() == sequences.len(),
418        "Unable to identify an ID for each sequence from the FASTA {}.",
419        &fasta
420    );
421
422    // call a distance matrix with the chosen distance metric (defaulting to Levenshtein)
423    let mut pw_distmat = SquareMatrix::from_pw_distances_with(&sequences, |seq1, seq2| {
424        distance_method.calculate_distance(seq1, seq2)
425    });
426    pw_distmat.set_labels(ids.clone());
427
428    // pull distmat information out of the SquareMatrix struct and convert to dataframe
429    let mut dist_col_vec: Vec<Series> = vec![Default::default(); pw_distmat.size()];
430    for (i, (column, label)) in pw_distmat
431        .iter_cols()
432        .zip(pw_distmat.iter_labels())
433        .enumerate()
434    {
435        let series = Series::new(label, column.collect::<Vec<f64>>());
436        dist_col_vec[i] = series;
437    }
438
439    let mut dist_df = process_cluster_info(cluster_table, dist_col_vec, &ids, stringency)?;
440
441    // write out the weighted distance matrix
442    let out_name = format!("{}-dist-matrix.csv", label);
443    let out_handle = File::create(out_name).expect(
444        "File could not be created to write the distance matrix to. Please check file-write permissions."
445    );
446    CsvWriter::new(out_handle)
447        .has_header(true)
448        .finish(&mut dist_df)
449        .expect("Weighted distance matrix could not be written.");
450
451    Ok(())
452}