use clap_derive::{Parser, Subcommand, ValueEnum};
use serde::{Deserialize, Deserializer};
use std::fmt::Display;
use std::path::PathBuf;
use std::str::FromStr;
use crate::graph::score::HaplotypeMetric;
use crate::translation::distance::DistanceMetric;
#[derive(Parser, Debug)]
#[clap(version, about)]
pub(crate) struct Predictosaurus {
#[clap(subcommand)]
pub(crate) command: Command,
#[clap(short, long, global = true)]
pub(crate) verbose: bool,
#[arg(short, long)]
pub(crate) threads: Option<usize>,
}
#[derive(Subcommand, Debug)]
pub(crate) enum Command {
Build {
#[clap(short, long)]
calls: PathBuf,
#[clap(short, long)]
observations: Vec<ObservationFile>,
#[clap(short, long, default_value = "0.8")]
min_prob_present: f64,
#[clap(long, default_value = "0.05")]
min_vaf: f32,
#[clap(long)]
output: PathBuf,
},
Process {
#[clap(short, long)]
features: PathBuf,
#[clap(short, long)]
reference: PathBuf,
#[clap(short, long)]
graph: PathBuf,
#[clap(short, long, default_value = "grantham")]
distance_metric: DistanceMetric,
#[clap(long, default_value = "minimum")]
haplotype_metric: HaplotypeMetric,
#[clap(short, long)]
output: PathBuf,
#[clap(long, default_value = "5000")]
max_cds_length: u64,
},
Peptides {
#[clap(short, long)]
features: PathBuf,
#[clap(short, long)]
reference: PathBuf,
#[clap(short, long)]
graph: PathBuf,
#[clap(short, long, default_value_t = Interval::default())]
interval: Interval,
#[clap(short, long)]
sample: String,
#[clap(short, long)]
events: Vec<String>,
#[clap(long)]
min_event_prob: f64,
#[clap(short, long)]
background_events: Vec<String>,
#[clap(long)]
max_background_event_prob: f64,
#[clap(short, long)]
output: PathBuf,
#[clap(long, default_value = "5000")]
max_cds_length: u64,
},
Plot {
#[clap(short, long)]
input: PathBuf,
#[clap(short, long)]
output: PathBuf,
},
}
/// Inclusive range of `u32` values, written on the command line as
/// `start-end`.
///
/// The type is also its own [`Iterator`]. Iterating advances `start` until it
/// passes `end`, so an iterated `Interval` no longer describes the original
/// range. Clone it first if the bounds are needed afterwards.
///
/// `Copy` is deliberately not derived: implicit copies of an iterator are easy
/// to consume by accident.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct Interval {
    /// First value of the range (inclusive).
    pub(crate) start: u32,
    /// Last value of the range (inclusive).
    pub(crate) end: u32,
}
impl Display for Interval {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}-{}", self.start, self.end)
}
}
impl Default for Interval {
fn default() -> Self {
Interval { start: 8, end: 11 }
}
}
/// Yields every value from `start` through `end` (both inclusive), advancing
/// `start` as it goes.
///
/// An interval with `start > end` yields nothing. Once exhausted,
/// `start == end + 1` and every further call returns `None`, including for
/// ranges that end at `u32::MAX`.
impl Iterator for Interval {
    type Item = u32;

    fn next(&mut self) -> Option<Self::Item> {
        if self.start > self.end {
            return None;
        }
        let current = self.start;
        match current.checked_add(1) {
            Some(next) => self.start = next,
            // Only reachable when `current == end == u32::MAX`. `start` cannot
            // move past `end` without overflowing, which would panic in debug
            // builds and wrap to 0 (iterating forever) in release builds.
            // Pull `end` back instead, which still leaves the exhausted state
            // `start == end + 1`.
            None => self.end = current - 1,
        }
        Some(current)
    }
}
impl FromStr for Interval {
type Err = String;
fn from_str(string: &str) -> Result<Interval, Self::Err> {
let (start, end) = string
.split_once('-')
.expect("Invalid interval format. Make sure to use the format `start-end`");
let start = start.parse::<u32>().unwrap();
let end = end.parse::<u32>().unwrap();
if start > end {
panic!("Invalid interval format. Make sure to use the format `start-end`");
}
Ok(Interval { start, end })
}
}
#[derive(Debug, Clone, ValueEnum)]
pub(crate) enum Format {
Html,
Tsv,
Vega,
}
#[derive(Debug, Clone)]
pub(crate) struct ObservationFile {
pub(crate) path: PathBuf,
pub(crate) sample: String,
}
impl FromStr for ObservationFile {
type Err = String;
fn from_str(string: &str) -> Result<ObservationFile, Self::Err> {
let (sample, path) = string.split_once('=').expect("Invalid observation file parameter format. Make sure to use the format `--observations sample=observations.vcf`");
Ok(ObservationFile {
sample: sample.to_string(),
path: PathBuf::from(path),
})
}
}
impl<'de> Deserialize<'de> for ObservationFile {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let string = String::deserialize(deserializer)?;
Ok(ObservationFile::from_str(&string).unwrap())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn from_str_parses_valid_observation_file() {
let input = "sample1=observations.vcf";
let observation_file = ObservationFile::from_str(input).unwrap();
assert_eq!(observation_file.sample, "sample1");
assert_eq!(observation_file.path, PathBuf::from("observations.vcf"));
}
#[test]
#[should_panic]
fn from_str_fails_on_invalid_format() {
let input = "invalid_format";
let result = ObservationFile::from_str(input);
}
#[test]
fn from_str_parses_valid_interval() {
let input = "10-20";
let interval = Interval::from_str(input).unwrap();
assert_eq!(interval.start, 10);
assert_eq!(interval.end, 20);
}
#[test]
#[should_panic]
fn from_str_fails_on_invalid_interval() {
let input = "10:20";
let result = Interval::from_str(input);
}
#[test]
#[should_panic]
fn from_str_fails_when_start_greater_than_end() {
let input = "20-10";
let result = Interval::from_str(input);
}
#[test]
fn iterator_yields_all_values_in_range() {
let mut interval = Interval { start: 1, end: 3 };
assert_eq!(interval.next(), Some(1));
assert_eq!(interval.next(), Some(2));
assert_eq!(interval.next(), Some(3));
assert_eq!(interval.next(), None);
}
#[test]
fn default_interval_has_correct_start_and_end() {
let interval = Interval::default();
assert_eq!(interval.start, 8);
assert_eq!(interval.end, 11);
}
#[test]
fn display_formats_interval_correctly() {
let interval = Interval { start: 5, end: 10 };
assert_eq!(format!("{}", interval), "5-10");
}
}