use crate::error::MotifError;
use crate::fasta::reverse_complement;
use crate::types::*;
use polars::lazy::dsl::*;
use polars::prelude::*;
use std::collections::HashMap;
use std::fmt::format;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::iter::Peekable;
/// Small constant added to every PWM entry before normalisation in
/// `read_pwm_to_ewm`, so that zero entries do not reach the logarithm.
const PSEUDOCOUNT: f64 = 0.0001;
/// Scale factor applied to the negated natural log of the normalised PWM
/// entries when building an energy matrix.
// NOTE(review): presumably the thermal energy R*T (about 2.5 kJ/mol near room
// temperature) — confirm the intended units.
const RT: f64 = 2.5;
/// Advances `lines` past every successfully read line that does not begin
/// with `"MOTIF"`.
///
/// The iterator is left positioned on the first line starting with `"MOTIF"`,
/// on the first read error, or at the end of input, whichever comes first.
/// The stopping item itself is not consumed.
fn skip_until_motif<I>(lines: &mut Peekable<I>)
where
    I: Iterator<Item = Result<String, std::io::Error>>,
{
    // `next_if` only consumes the peeked item when the predicate holds, so the
    // header line (or an error) stays in place for the caller.
    while lines
        .next_if(|entry| matches!(entry, Ok(text) if !text.starts_with("MOTIF")))
        .is_some()
    {}
}
/// Parses one `MOTIF` record from `lines` into `(motif_id, pwm)`.
///
/// The expected layout is a line starting with `MOTIF` whose second
/// whitespace-separated token is the motif ID, two further lines that are
/// skipped, and then the matrix rows. The layout resembles the MEME motif
/// format — confirm against the files this is used with.
///
/// A matrix row is any non-blank line that starts with whitespace, `0` or `1`.
/// Only the first four values of each row are used, stored as columns
/// `A`, `C`, `G`, `T` in that order.
///
/// Returns `Ok(None)` when the next line is missing or is not a `MOTIF`
/// header.
///
/// # Errors
///
/// * `MotifError::Io` if reading any line fails.
/// * `MotifError::InvalidFileFormat` if the motif ID is missing, a value does
///   not parse as `f64`, a row has fewer than four values, or there are no rows.
/// * `MotifError::DataError` if the data frame cannot be built.
fn parse_pwm<I>(lines: &mut I) -> Result<Option<(String, PWM)>, MotifError>
where
    I: Iterator<Item = Result<String, std::io::Error>>,
{
    let motif_line = match lines.next() {
        Some(Ok(line)) if line.starts_with("MOTIF") => line,
        // A failed read must not be mistaken for "no more motifs".
        Some(Err(e)) => return Err(MotifError::Io(e)),
        _ => return Ok(None),
    };
    let motif_id = motif_line
        .split_whitespace()
        .nth(1)
        .ok_or_else(|| MotifError::InvalidFileFormat("Missing motif ID".into()))?
        .to_string();

    // Skip the two lines between the header and the matrix, still reporting
    // read failures.
    for _ in 0..2 {
        if let Some(skipped) = lines.next() {
            skipped.map_err(|e| MotifError::Io(e))?;
        }
    }

    // NOTE: `take_while` also consumes the first line that ends the matrix, so
    // a `MOTIF` header directly after the last row (no separator line) is lost.
    let pwm_rows: Vec<Vec<f64>> = lines
        .take_while(|line| match line {
            // A whitespace-only line ends the matrix just like an empty one.
            Ok(text) => {
                !text.trim().is_empty()
                    && text.starts_with(|c: char| c.is_whitespace() || c == '0' || c == '1')
            }
            // Let read errors through so the next stage can report them.
            Err(_) => true,
        })
        .map(|line| {
            let line = line.map_err(|e| MotifError::Io(e))?;
            let values: Vec<f64> = line
                .split_whitespace()
                .map(|s| s.parse::<f64>())
                .collect::<Result<Vec<_>, _>>()
                .map_err(|e| MotifError::InvalidFileFormat(format!("Invalid PWM value: {}", e)))?;
            // The column construction below indexes 0..=3, so guard it here.
            if values.len() < 4 {
                return Err(MotifError::InvalidFileFormat(format!(
                    "PWM row has {} values, expected at least 4",
                    values.len()
                )));
            }
            Ok(values)
        })
        .collect::<Result<Vec<_>, MotifError>>()?;

    if pwm_rows.is_empty() {
        return Err(MotifError::InvalidFileFormat("Empty PWM".into()));
    }

    let pwm = DataFrame::new(vec![
        Column::new(
            "A".into(),
            pwm_rows.iter().map(|row| row[0]).collect::<Vec<f64>>(),
        ),
        Column::new(
            "C".into(),
            pwm_rows.iter().map(|row| row[1]).collect::<Vec<f64>>(),
        ),
        Column::new(
            "G".into(),
            pwm_rows.iter().map(|row| row[2]).collect::<Vec<f64>>(),
        ),
        Column::new(
            "T".into(),
            pwm_rows.iter().map(|row| row[3]).collect::<Vec<f64>>(),
        ),
    ])
    .map_err(|e| MotifError::DataError(e.to_string()))?;

    Ok(Some((motif_id, pwm)))
}
/// Reads every `MOTIF` record in `filename` into a collection keyed by motif ID.
///
/// Lines between records are skipped. When two records share an ID, the later
/// one replaces the earlier one.
///
/// # Errors
///
/// Returns an error if the file cannot be opened, if a record fails to parse,
/// or `MotifError::InvalidFileFormat` if the file contains no records.
pub fn read_pwm_files(filename: &str) -> Result<PWMCollection, MotifError> {
    let mut lines = BufReader::new(File::open(filename)?).lines().peekable();
    let mut pwms = HashMap::new();

    loop {
        // Position the reader on the next header before each parse attempt.
        skip_until_motif(&mut lines);
        match parse_pwm(&mut lines)? {
            Some((id, pwm)) => {
                pwms.insert(id, pwm);
            }
            None => break,
        }
    }

    if pwms.is_empty() {
        Err(MotifError::InvalidFileFormat("No PWMs found".into()))
    } else {
        Ok(pwms)
    }
}
pub fn read_pwm_to_ewm(filename: &str) -> Result<EWMCollection, MotifError> {
let pwms = read_pwm_files(filename)?;
let ewms: EWMCollection = pwms
.into_iter()
.map(|(id, pwm)| {
let normalized = pwm
.clone()
.lazy()
.select([
(col("A") + lit(PSEUDOCOUNT)).alias("A_pseudo"),
(col("C") + lit(PSEUDOCOUNT)).alias("C_pseudo"),
(col("G") + lit(PSEUDOCOUNT)).alias("G_pseudo"),
(col("T") + lit(PSEUDOCOUNT)).alias("T_pseudo"),
])
.with_column(
max_horizontal([
col("A_pseudo"),
col("C_pseudo"),
col("G_pseudo"),
col("T_pseudo"),
])
.unwrap()
.alias("max_val"),
)
.select([
(col("A_pseudo") / col("max_val")).alias("A_norm"),
(col("C_pseudo") / col("max_val")).alias("C_norm"),
(col("G_pseudo") / col("max_val")).alias("G_norm"),
(col("T_pseudo") / col("max_val")).alias("T_norm"),
])
.select([
(-lit(RT) * col("A_norm").log(std::f64::consts::E)).alias("A"),
(-lit(RT) * col("C_norm").log(std::f64::consts::E)).alias("C"),
(-lit(RT) * col("G_norm").log(std::f64::consts::E)).alias("G"),
(-lit(RT) * col("T_norm").log(std::f64::consts::E)).alias("T"),
])
.collect()
.map_err(|e| MotifError::DataError(e.to_string()))?;
Ok((id, normalized))
})
.collect::<Result<HashMap<_, _>, MotifError>>()?;
Ok(ewms)
}
/// Sums the energies of `kmer` against `ewm`, one matrix row per position.
///
/// Each byte of `kmer` selects the column of that name in `ewm`, and the
/// byte's offset selects the row.
fn kmer_energy(kmer: &[u8], ewm: &EWM) -> Result<f64, MotifError> {
    kmer.iter()
        .enumerate()
        .map(|(row, base)| -> Result<f64, MotifError> {
            // Column names are strings; a non-ASCII byte cannot name a base.
            let name = std::str::from_utf8(std::slice::from_ref(base)).map_err(|_| {
                MotifError::DataError(format!("Non-ASCII byte 0x{:02x} in sequence", base))
            })?;
            ewm.column(name)
                .map_err(|e| MotifError::DataError(e.to_string()))?
                .get(row)
                .map_err(|e| MotifError::DataError(e.to_string()))?
                .try_extract::<f64>()
                .map_err(|e| MotifError::DataError(e.to_string()))
        })
        .sum()
}

/// Scores every window of `seq` against `ewm` on both strands.
///
/// Returns `(forward, reverse)`, each holding one summed energy per start
/// position. There are `seq.len() - motif_len + 1` positions, where
/// `motif_len` is the height of `ewm`. Reverse scores are computed on
/// `reverse_complement(seq)` and then reversed, so both vectors share the
/// same index order.
///
/// If `seq` is shorter than the motif there are no windows, and both vectors
/// are empty.
///
/// # Errors
///
/// Propagates any error from `reverse_complement`, and returns
/// `MotifError::DataError` if a base has no matching column in `ewm` or a
/// matrix value cannot be read as `f64`.
pub fn energy_landscape(seq: &str, ewm: &EWM) -> Result<(Vec<f64>, Vec<f64>), MotifError> {
    let motif_len = ewm.height();
    let r_seq = reverse_complement(seq)?;

    // `checked_sub` avoids the usize underflow for sequences shorter than the motif.
    let n_scores = match seq.len().checked_sub(motif_len) {
        Some(extra) => extra + 1,
        None => return Ok((Vec::new(), Vec::new())),
    };

    // Work on bytes so windowing never splits a multi-byte character.
    let fwd = seq.as_bytes();
    let rev = r_seq.as_bytes();

    let mut fscores = Vec::with_capacity(n_scores);
    let mut rscores = Vec::with_capacity(n_scores);
    for pos in 0..n_scores {
        let end = pos + motif_len;
        fscores.push(kmer_energy(&fwd[pos..end], ewm)?);
        rscores.push(kmer_energy(&rev[pos..end], ewm)?);
    }

    // Flip so reverse-strand scores share the forward index order.
    rscores.reverse();
    Ok((fscores, rscores))
}
/// Converts the energy landscape of `seq` under `ewm` into occupancies.
///
/// Each energy `s` from `energy_landscape` is mapped through
/// `1 / (1 + exp(s - mu))`. Returns `(forward, reverse)` in the same order and
/// with the same lengths as the energy landscape.
///
/// # Errors
///
/// Propagates any error from `energy_landscape`.
pub fn occupancy_landscape(
    seq: &str,
    ewm: &EWM,
    mu: f64,
) -> Result<(Vec<f64>, Vec<f64>), MotifError> {
    // Shared logistic transform for both strands.
    let to_occupancy = |energies: Vec<f64>| -> Vec<f64> {
        energies
            .into_iter()
            .map(|energy| 1.0 / (1.0 + (energy - mu).exp()))
            .collect()
    };

    let (forward, reverse) = energy_landscape(seq, ewm)?;
    Ok((to_occupancy(forward), to_occupancy(reverse)))
}
/// Builds a data frame of occupancy landscapes for every motif in `ewms`.
///
/// Each motif contributes two columns, `<name>_F` (forward strand) and
/// `<name>_R` (reverse strand). Both are padded at the end with `0.0` up to
/// `seq.len()` rows, so motifs of different lengths fit in one frame.
/// Columns are ordered by motif name, which keeps the layout the same from
/// run to run.
///
/// # Errors
///
/// Propagates any error from `occupancy_landscape`, and returns
/// `MotifError::DataError` if a landscape is longer than the sequence or the
/// data frame cannot be built.
pub fn total_landscape(seq: &str, ewms: &EWMCollection, mu: f64) -> Result<DataFrame, MotifError> {
    let seq_len = seq.len();

    // Iteration order of the collection is not stable, so sort by name.
    let mut motifs: Vec<_> = ewms.iter().collect();
    motifs.sort_by(|a, b| a.0.cmp(b.0));

    let mut columns: Vec<Column> = Vec::with_capacity(motifs.len() * 2);
    for (name, ewm) in motifs {
        let (mut fscores, mut rscores) = occupancy_landscape(seq, ewm, mu)?;
        // Report an over-long landscape instead of underflowing the padding size.
        if fscores.len() > seq_len || rscores.len() > seq_len {
            return Err(MotifError::DataError(format!(
                "Landscape for motif {} is longer than the sequence",
                name
            )));
        }
        // Pad in place; the vectors are owned, so no copies are needed.
        fscores.resize(seq_len, 0.0);
        rscores.resize(seq_len, 0.0);
        columns.push(Column::new(format!("{}_F", name).into(), fscores));
        columns.push(Column::new(format!("{}_R", name).into(), rscores));
    }

    DataFrame::new(columns).map_err(|e| MotifError::DataError(e.to_string()))
}