use std::{
fs::File,
io::{self, BufRead, BufReader, BufWriter, Read},
path::{Path, PathBuf},
sync::Arc,
};
use bstr::ByteSlice;
use lib_tsalign::{
a_star_aligner::{
alignment_result::{AlignmentResult, AlignmentStatistics, a_star_sequences},
template_switch_distance::AlignmentType,
},
costs::U64Cost,
};
use lib_tsshow::svg::SvgConfig;
use serde::{Deserialize, Serialize};
use tracing::{error, info};
use crate::{
RunnableCommand,
common::{
ImmutableSequence, MutableSequence,
alignment::{ForwardAlignment, cigar_to_alignment},
coords::GenomeRegion,
reference::{CliReferenceArg, ReferenceReader},
},
counter,
};
// Command-line arguments of the SVG visualisation subcommand.
//
// NOTE: plain `//` comments are used on purpose; `///` doc comments on a
// `clap::Args` struct would be picked up by clap as help text.
//
// The bool fields are independent CLI switches, hence the lint exemption.
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug, clap::Args)]
pub struct Command {
// Optional positional path of the CSV input; stdin is read when absent.
input: Option<PathBuf>,
// Directory the SVG files are written to (created by `run` if missing).
#[arg(short = 'o', long = "output-dir")]
output_dir: PathBuf,
// Reference passed on to `CliReferenceArg` / `ReferenceReader`.
#[arg(short = 'r', long)]
reference: PathBuf,
// Forwarded to `SvgConfig::render_arrows`.
#[arg(short = 'a', long = "arrows")]
arrows: bool,
// Forwarded to `SvgConfig::render_more_complement`.
#[arg(short = 'c', long = "more-complement")]
more_complement: bool,
// Exposed as `--zoom`; forwarded to `SvgConfig::restrict_context`.
#[arg(short = 'z', long = "zoom")]
context: Option<usize>,
// Forwarded to `SvgConfig::visualise_equal_cost_ranges`.
#[arg(short = 'e', long = "equal-cost-ranges")]
visualise_equal_cost_ranges: bool,
// When set, existing SVG files are regenerated instead of skipped.
#[arg(short = 'f', long = "overwrite")]
overwrite: bool,
// Record IDs to render; each occurrence may hold a comma-separated list.
#[arg(short = 'i', long)]
ids: Vec<String>,
// File with one record ID per line, merged with `ids`.
#[arg(short = 'I', long)]
id_file: Option<PathBuf>,
}
/// The subset of CSV columns needed to render one visualisation.
///
/// `Default` + `Serialize` are used by `run` to print the list of required
/// column names when an input file is missing one of them.
#[derive(Deserialize, Serialize, Default)]
struct VizCSVRecord {
/// Record identifier; used for ID filtering and to name the output SVG.
id: String,
/// Reference region string, parsed with `GenomeRegion::parse`.
ref_ctx_region: String,
/// Optional read name; used as the query name (empty when absent).
read_id: Option<String>,
/// CIGAR of the forward alignment handed to the renderer.
fw_cigar: String,
/// CIGAR used together with `fw_mi_ctx` to rebuild the query sequence
/// from the reference sequence of `ref_ctx_region`.
fw_cigar_ctx: String,
/// String whose ASCII-alphabetic bytes supply the substituted/inserted
/// query bases, consumed in alignment order.
fw_mi_ctx: String,
/// Cost of the forward alignment; accepts `42` as well as `42.0`.
#[serde(deserialize_with = "deser_int_or_float_str")]
fw_cost: u64,
/// CIGAR of the template-switch alignment handed to the renderer.
ts_cigar: String,
/// Cost of the template-switch alignment; accepts `30` as well as `30.0`.
#[serde(deserialize_with = "deser_int_or_float_str")]
ts_cost: u64,
/// Passed on as `reference_offset` of the alignment statistics.
#[serde(deserialize_with = "deser_int_or_float_str")]
ref_cluster_offset: usize,
/// Passed on as `query_offset` of the alignment statistics.
#[serde(deserialize_with = "deser_int_or_float_str")]
alt_cluster_offset: usize,
}
/// Deserializes an unsigned integer from a string field that holds either an
/// integer literal (`"42"`) or a float literal with an integral value
/// (`"42.0"`).
///
/// The string is first parsed directly as `T`. Only if that fails is it parsed
/// as `f64`, validated, and converted through `u64`.
///
/// # Errors
///
/// Returns a custom deserialization error when the string is not numeric, has
/// a fractional part (this also covers NaN and infinities, whose `fract()` is
/// NaN), is negative, does not fit into a `u64`, or does not fit into `T`.
fn deser_int_or_float_str<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
    D: serde::Deserializer<'de>,
    T: std::str::FromStr + TryFrom<u64>,
    <T as TryFrom<u64>>::Error: std::fmt::Display,
{
    // 2^64, the smallest f64 that no longer fits into a u64. `as u64` saturates
    // instead of failing, so the range has to be checked explicitly.
    const U64_EXCLUSIVE_LIMIT: f64 = 18_446_744_073_709_551_616.0;

    let s = String::deserialize(deserializer)?;
    if let Ok(v) = s.parse::<T>() {
        return Ok(v);
    }
    let f: f64 = s
        .parse()
        .map_err(|_| serde::de::Error::custom(format!("expected integer, got `{s}`")))?;
    if f.fract() != 0.0 {
        return Err(serde::de::Error::custom(format!(
            "expected integer value, got `{s}`"
        )));
    }
    if f < 0.0 {
        return Err(serde::de::Error::custom(format!(
            "expected non-negative value, got `{s}`"
        )));
    }
    if f >= U64_EXCLUSIVE_LIMIT {
        return Err(serde::de::Error::custom(format!(
            "value out of range, got `{s}`"
        )));
    }
    // Lossless: `f` is integral and within `0..2^64` at this point.
    T::try_from(f as u64).map_err(serde::de::Error::custom)
}
impl RunnableCommand for Command {
    /// Renders one SVG per selected record of the input CSV.
    ///
    /// Records are read from `input` (stdin when absent), optionally restricted
    /// to the IDs given via `ids` / `id_file`, and rendered one by one. A
    /// failure on a single record is logged and counted, but does not abort
    /// the run.
    ///
    /// When a record cannot be deserialized because a column is missing, the
    /// list of required columns is logged and the run stops with `Ok(())`.
    ///
    /// # Errors
    ///
    /// Fails when the reference path is not valid UTF-8, the reference cannot
    /// be opened, the output directory cannot be created, or the input / ID
    /// file cannot be opened or read.
    async fn run(self) -> anyhow::Result<()> {
        // A non-UTF-8 path is a user error, not a bug: report it instead of panicking.
        let reference_path = self.reference.to_str().ok_or_else(|| {
            anyhow::anyhow!(
                "Reference path {} is not valid UTF-8",
                self.reference.display()
            )
        })?;
        let cli_ref_args = CliReferenceArg::from(reference_path);
        let reference_reader = ReferenceReader::try_from(&cli_ref_args)?;

        if !self.output_dir.exists() {
            info!("Creating output folder {}", self.output_dir.display());
            // `create_dir_all` also creates missing parents and tolerates the
            // directory appearing between the check above and this call.
            tokio::fs::create_dir_all(&self.output_dir).await?;
        }

        let input = if let Some(path) = &self.input {
            Box::new(File::open(path)?) as Box<dyn Read + Send>
        } else {
            Box::new(io::stdin())
        };
        // `None` means "no filter"; otherwise a sorted list for binary search.
        let ids = parse_ids(&self.ids, self.id_file.as_ref())?;

        let reader = csv::Reader::from_reader(input);
        for record in reader.into_deserialize::<VizCSVRecord>() {
            match record {
                Ok(record) => {
                    if ids
                        .as_ref()
                        .is_some_and(|ids| ids.binary_search(&record.id).is_err())
                    {
                        counter!("viz.filtered").inc(1);
                        continue;
                    }
                    if let Err(e) = self.process_record(&record, &reference_reader).await {
                        counter!("viz.error").inc(1);
                        error!("{e}");
                    }
                }
                Err(e) => {
                    error!("Unable to read record: {e}");
                    if let csv::ErrorKind::Deserialize { err, .. } = e.into_kind()
                        && let csv::DeserializeErrorKind::Message(msg) = err.kind()
                        && msg.contains("missing field")
                    {
                        // Serialize a default record and keep only the header
                        // line: that is the list of required column names.
                        let first_line = {
                            let mut buf = Vec::new();
                            let mut w = csv::Writer::from_writer(&mut buf);
                            w.serialize(VizCSVRecord::default())?;
                            w.flush()?;
                            // Release the mutable borrow of `buf`.
                            drop(w);
                            if let Some(newline) = buf.find_byteset(b"\r\n") {
                                buf.truncate(newline);
                            }
                            buf
                        };
                        error!("Required fields are: {}", first_line.as_bstr());
                        // Every further record would fail the same way.
                        return Ok(());
                    }
                }
            }
        }
        Ok(())
    }
}
impl Command {
    /// Renders a single record to `<output_dir>/<id>.svg`.
    ///
    /// An existing output file is skipped (and counted as `viz.skipped`)
    /// unless `overwrite` is set. On success `viz.success` is incremented.
    ///
    /// # Errors
    ///
    /// Fails when the sequences or alignments cannot be reconstructed from the
    /// record, the output file cannot be created, or rendering fails. In the
    /// latter case the incomplete output file is removed again.
    async fn process_record(
        &self,
        record: &VizCSVRecord,
        reference: &ReferenceReader,
    ) -> anyhow::Result<()> {
        // Append the extension instead of using `Path::with_extension`, which
        // would replace everything after the last `.` of the ID and thereby
        // map e.g. `a.1` and `a.2` to the same file `a.svg`.
        let path = self.output_dir.join(format!("{}.svg", record.id));
        if path.exists() && !self.overwrite {
            info!("Skipping existing file {}", path.display());
            counter!("viz.skipped").inc(1);
            return Ok(());
        }

        let sequences = record.get_sequences(reference).await?;
        let ts_alignment = record.get_ts_alignment_result(sequences.clone())?;
        let fw_alignment = record.get_forward_alignment_result(sequences)?;

        let out = BufWriter::new(File::create(&path)?);
        let rendered = lib_tsshow::svg::create_ts_svg(
            out,
            &ts_alignment,
            &Some(fw_alignment),
            &SvgConfig {
                render_arrows: self.arrows,
                render_more_complement: self.more_complement,
                restrict_context: self.context,
                visualise_equal_cost_ranges: self.visualise_equal_cost_ranges,
            },
        );
        if rendered.is_err() {
            // Best effort: do not leave a truncated SVG behind, since a later
            // run without `overwrite` would skip it as "existing".
            let _ = std::fs::remove_file(&path);
        }
        rendered?;

        counter!("viz.success").inc(1);
        Ok(())
    }
}
impl VizCSVRecord {
async fn get_sequences(
&self,
reference: &ReferenceReader,
) -> anyhow::Result<a_star_sequences::SequencePair> {
let ref_region = GenomeRegion::parse(self.ref_ctx_region.as_bytes())?;
let ref_seq = reference.get_seq_exact_unmasked(ref_region).await?;
let forward_alignment = ForwardAlignment(cigar_to_alignment(&self.fw_cigar_ctx)?);
let query_seq = apply_cigar_and_mi(&ref_seq, &forward_alignment, &self.fw_mi_ctx);
Ok(a_star_sequences::SequencePair {
reference_name: reference.get_name().to_string(),
reference: String::from_utf8_lossy(&ref_seq).into_owned(),
reference_rc: String::from_utf8_lossy(&reverse_complement(&ref_seq)).into_owned(),
query_name: self.read_id.clone().unwrap_or_default(),
query: String::from_utf8_lossy(&query_seq).into_owned(),
query_rc: String::from_utf8_lossy(&reverse_complement(&query_seq)).into_owned(),
})
}
fn ts_num(&self) -> anyhow::Result<usize> {
let alignment = cigar_to_alignment(&self.ts_cigar)?;
Ok(alignment
.iter_flat()
.filter(|op| matches!(*op, AlignmentType::TemplateSwitchEntrance { .. }))
.count())
}
fn get_ts_alignment_result(
&self,
sequences: a_star_sequences::SequencePair,
) -> anyhow::Result<AlignmentResult<AlignmentType, U64Cost>> {
let alignment = cigar_to_alignment(&self.ts_cigar)?;
let statistics: AlignmentStatistics<U64Cost> = AlignmentStatistics {
result: generic_a_star::AStarResult::FoundTarget {
identifier: (),
cost: self.ts_cost.into(),
},
sequences,
reference_offset: self.ref_cluster_offset,
query_offset: self.alt_cluster_offset,
cost: (self.ts_cost as f64).try_into().unwrap(),
cost_per_base: 0f64.try_into().unwrap(),
duration_seconds: 0f64.try_into().unwrap(),
opened_nodes: 0f64.try_into().unwrap(),
closed_nodes: 0f64.try_into().unwrap(),
suboptimal_opened_nodes: 0f64.try_into().unwrap(),
suboptimal_opened_nodes_ratio: 0f64.try_into().unwrap(),
template_switch_amount: (self.ts_num()? as f64).try_into().unwrap(),
runtime: 0f64.try_into().unwrap(),
memory: 0f64.try_into().unwrap(),
};
Ok(AlignmentResult::WithTarget {
alignment,
statistics,
})
}
fn get_forward_alignment_result(
&self,
sequences: a_star_sequences::SequencePair,
) -> anyhow::Result<AlignmentResult<AlignmentType, U64Cost>> {
let alignment = cigar_to_alignment(&self.fw_cigar)?;
let statistics: AlignmentStatistics<U64Cost> = AlignmentStatistics {
result: generic_a_star::AStarResult::FoundTarget {
identifier: (),
cost: self.fw_cost.into(),
},
sequences,
reference_offset: self.ref_cluster_offset,
query_offset: self.alt_cluster_offset,
cost: (self.fw_cost as f64).try_into().unwrap(),
cost_per_base: 0f64.try_into().unwrap(),
duration_seconds: 0f64.try_into().unwrap(),
opened_nodes: 0f64.try_into().unwrap(),
closed_nodes: 0f64.try_into().unwrap(),
suboptimal_opened_nodes: 0f64.try_into().unwrap(),
suboptimal_opened_nodes_ratio: 0f64.try_into().unwrap(),
template_switch_amount: (self.ts_num()? as f64).try_into().unwrap(),
runtime: 0f64.try_into().unwrap(),
memory: 0f64.try_into().unwrap(),
};
Ok(AlignmentResult::WithTarget {
alignment,
statistics,
})
}
}
/// Returns the reverse complement of a DNA sequence.
///
/// `A`/`T` and `C`/`G` are swapped, preserving case (`a` becomes `t`, and so
/// on). Every other byte (e.g. `N`) is kept unchanged, only its position is
/// reversed.
fn reverse_complement(seq: &ImmutableSequence) -> ImmutableSequence {
    let rc: MutableSequence = seq
        .iter()
        .rev()
        .map(|&b| match b {
            b'A' => b'T',
            b'T' => b'A',
            b'C' => b'G',
            b'G' => b'C',
            // Lowercase bases must be complemented as well, otherwise the
            // result is not a valid reverse complement.
            b'a' => b't',
            b't' => b'a',
            b'c' => b'g',
            b'g' => b'c',
            other => other,
        })
        .collect();
    Arc::from(rc.as_slice())
}
/// Rebuilds the query sequence by replaying a forward alignment over the
/// reference.
///
/// Matches copy the current reference base, substitutions and insertions take
/// the next ASCII-alphabetic byte of `mi_string` (nothing is emitted once
/// those are used up), deletions only skip a reference base. All other
/// alignment operations are ignored.
///
/// # Panics
///
/// Panics if a match is reached after the reference has been fully consumed.
fn apply_cigar_and_mi(
    reference: &ImmutableSequence,
    forward_alignment: &ForwardAlignment,
    mi_string: &str,
) -> ImmutableSequence {
    let mut replacement_bases = mi_string
        .bytes()
        .filter(|byte| byte.is_ascii_alphabetic());
    let mut rebuilt = MutableSequence::new();
    let mut cursor: usize = 0;

    for step in forward_alignment.iter_flat_cloned() {
        // Classify the step first: which base (if any) it contributes to the
        // query, and whether it moves the reference cursor.
        let (emitted, consumes_reference) = match step {
            AlignmentType::PrimaryMatch | AlignmentType::PrimaryFlankMatch => {
                (Some(reference[cursor]), true)
            }
            AlignmentType::PrimarySubstitution | AlignmentType::PrimaryFlankSubstitution => {
                (replacement_bases.next(), true)
            }
            AlignmentType::PrimaryInsertion | AlignmentType::PrimaryFlankInsertion => {
                (replacement_bases.next(), false)
            }
            AlignmentType::PrimaryDeletion | AlignmentType::PrimaryFlankDeletion => (None, true),
            _ => (None, false),
        };

        if let Some(base) = emitted {
            rebuilt.push(base);
        }
        if consumes_reference {
            cursor += 1;
        }
    }

    Arc::from(rebuilt.as_slice())
}
/// Collects the record IDs to render from the command line and an optional
/// ID file.
///
/// Each entry of `args` may hold several comma-separated IDs; `file` holds one
/// ID per line. Surrounding whitespace is trimmed and empty tokens or lines
/// are ignored in both sources.
///
/// Returns `None` when no ID was given at all (meaning "do not filter"),
/// otherwise the IDs sorted and deduplicated, ready for `binary_search`.
///
/// # Errors
///
/// Fails when the ID file cannot be opened or read.
fn parse_ids<P: AsRef<Path>>(
    args: &[String],
    file: Option<P>,
) -> anyhow::Result<Option<Vec<String>>> {
    let mut ids: Vec<String> = args
        .iter()
        .flat_map(|arg| arg.split(','))
        .map(str::trim)
        .filter(|token| !token.is_empty())
        .map(str::to_string)
        .collect();

    if let Some(file) = file {
        for line in BufReader::new(File::open(file.as_ref())?).lines() {
            let line = line?;
            let id = line.trim();
            // Blank lines (e.g. a trailing newline or spacing between groups)
            // are not IDs; treat them like empty command-line tokens.
            if !id.is_empty() {
                ids.push(id.to_string());
            }
        }
    }

    if ids.is_empty() {
        return Ok(None);
    }
    ids.sort_unstable();
    ids.dedup();
    Ok(Some(ids))
}
// Unit tests for deserializing `VizCSVRecord` from CSV.
#[cfg(test)]
mod tests {
use super::VizCSVRecord;
use crate::common::csv::CSVRecord;
// A full `CSVRecord` written as CSV must be readable as `VizCSVRecord`:
// the shared columns keep their values, the extra columns are ignored.
#[test]
fn viz_csv_record_round_trip() {
let full = CSVRecord {
id: "abc123".to_string(),
cluster_id: "cluster_1".to_string(),
ref_ctx_region: "chr1:100-200".to_string(),
alt_ctx_region: "chr1:100-200".to_string(),
cluster_region: "chr1:110-190".to_string(),
ref_cluster_offset: 10,
alt_cluster_offset: 20,
ts_1_4_region: "chr1:115-185".to_string(),
read_id: Some("read_1".to_string()),
fw_cigar: "50M".to_string(),
fw_cigar_ctx: "100M".to_string(),
fw_mi_ctx: String::new(),
fw_cost: 42,
fw_cost_ctx: 50,
ts_cigar: "25M5I20M".to_string(),
ts_cigar_ctx: "50M5I45M".to_string(),
ts_cost: 30,
ts_cost_ctx: 40,
ts_num: 0,
ts_1_2: "5".to_string(),
ts_2_3: "3".to_string(),
ts_1_4: "8".to_string(),
ts_start_left_shift: "0".to_string(),
ts_start_right_shift: "2".to_string(),
ts_end_left_shift: "1".to_string(),
ts_end_right_shift: "3".to_string(),
ts_inner_alignment_cigar: "5M".to_string(),
};
// Serialize into an in-memory buffer; the inner scope ends the writer's
// mutable borrow of `buf` before it is read back.
let mut buf = Vec::new();
{
let mut writer = csv::Writer::from_writer(&mut buf);
writer.serialize(&full).unwrap();
writer.flush().unwrap();
}
let mut reader = csv::Reader::from_reader(buf.as_slice());
let viz: VizCSVRecord = reader.deserialize().next().unwrap().unwrap();
assert_eq!(viz.id, full.id);
assert_eq!(viz.ref_ctx_region, full.ref_ctx_region);
assert_eq!(viz.read_id, full.read_id);
assert_eq!(viz.fw_cigar, full.fw_cigar);
assert_eq!(viz.fw_cigar_ctx, full.fw_cigar_ctx);
assert_eq!(viz.fw_mi_ctx, full.fw_mi_ctx);
assert_eq!(viz.fw_cost, full.fw_cost);
assert_eq!(viz.ts_cigar, full.ts_cigar);
assert_eq!(viz.ts_cost, full.ts_cost);
assert_eq!(viz.ref_cluster_offset, full.ref_cluster_offset);
assert_eq!(viz.alt_cluster_offset, full.alt_cluster_offset);
}
// Deserializes the first data row of `csv` (header line included in `csv`).
// Panics if there is no data row at all.
fn deser_viz_from_csv(csv: &str) -> Result<VizCSVRecord, csv::Error> {
csv::Reader::from_reader(csv.as_bytes())
.deserialize()
.next()
.unwrap()
}
// Header with exactly the columns of `VizCSVRecord`.
const HEADER: &str =
"id,ref_ctx_region,read_id,fw_cigar,fw_cigar_ctx,fw_mi_ctx,fw_cost,ts_cigar,ts_cost,ref_cluster_offset,alt_cluster_offset\n";
// Plain integer strings are accepted for all numeric columns.
#[test]
fn deser_int_or_float_str_integer_strings() {
let rec = deser_viz_from_csv(&format!(
"{HEADER}x,chr1:1-2,,10M,10M,,42,10M,30,10,20"
))
.unwrap();
assert_eq!(rec.fw_cost, 42);
assert_eq!(rec.ts_cost, 30);
assert_eq!(rec.ref_cluster_offset, 10);
assert_eq!(rec.alt_cluster_offset, 20);
}
// Float-formatted values with a zero fractional part are accepted too.
#[test]
fn deser_int_or_float_str_float_formatted_integers() {
let rec = deser_viz_from_csv(&format!(
"{HEADER}x,chr1:1-2,,10M,10M,,42.0,10M,30.0,10.0,20.0"
))
.unwrap();
assert_eq!(rec.fw_cost, 42);
assert_eq!(rec.ts_cost, 30);
assert_eq!(rec.ref_cluster_offset, 10);
assert_eq!(rec.alt_cluster_offset, 20);
}
// A non-zero fractional part (`42.5` in `fw_cost`) is an error.
#[test]
fn deser_int_or_float_str_rejects_fractional() {
assert!(deser_viz_from_csv(&format!(
"{HEADER}x,chr1:1-2,,10M,10M,,42.5,10M,30,10,20"
))
.is_err());
}
// A negative value (`-1.0` in `ref_cluster_offset`) is an error.
#[test]
fn deser_int_or_float_str_rejects_negative() {
assert!(deser_viz_from_csv(&format!(
"{HEADER}x,chr1:1-2,,10M,10M,,42,10M,30,-1.0,20"
))
.is_err());
}
}