Skip to main content

rsomics_sc_cell_cycle/
lib.rs

1use std::fs::File;
2use std::io::{BufWriter, Write};
3use std::path::{Path, PathBuf};
4
5use rsomics_common::{Result, RsomicsError};
6use rsomics_sc_score_genes::{CountMatrix, ScoreParams, read_10x, read_gene_list, score};
7
/// Per-cell cell-cycle scoring result.
///
/// The four vectors are parallel: index `i` in each one refers to the same
/// cell (`write_tsv` reads them all with a single shared index).
#[derive(Debug, Clone, PartialEq)]
pub struct CellCycle {
    /// Cell identifiers, copied from the scored matrix's `barcodes`.
    pub barcodes: Vec<String>,
    /// Score of each cell against the S-phase gene list.
    pub s_score: Vec<f64>,
    /// Score of each cell against the G2M-phase gene list.
    pub g2m_score: Vec<f64>,
    /// Phase call for each cell: one of `"S"`, `"G2M"` or `"G1"`.
    pub phase: Vec<&'static str>,
}
14
/// Tuning knobs for `score_cell_cycle`.
///
/// Both fields are forwarded unchanged into the `ScoreParams` used for the
/// S and the G2M scoring pass.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CellCycleParams {
    /// Forwarded as `ScoreParams::n_bins`.
    // NOTE(review): presumably the number of expression bins used to pick
    // control genes, as in scanpy's `score_genes` — confirm against
    // `rsomics_sc_score_genes`.
    pub n_bins: usize,
    /// Forwarded as `ScoreParams::seed`; the same seed is used for both passes.
    pub seed: u32,
}
19
/// Calls the cell-cycle phase of one cell from its S and G2M scores.
///
/// Follows scanpy's precedence: every cell starts as `"S"`, becomes `"G2M"`
/// when the G2M score strictly exceeds the S score, and finally becomes
/// `"G1"` when both scores are negative. The G1 rule is applied last, so it
/// overrides the other two.
fn assign_phase(s: f64, g2m: f64) -> &'static str {
    let both_negative = s < 0.0 && g2m < 0.0;
    let g2m_leads = g2m > s;
    match (both_negative, g2m_leads) {
        // G1 wins regardless of which score is larger.
        (true, _) => "G1",
        (false, true) => "G2M",
        // Ties (and any comparison that is false) keep the default.
        (false, false) => "S",
    }
}
31
32/// score_genes twice (ctrl_size = min over the two sets, as scanpy does), then
33/// the deterministic phase call.
34pub fn score_cell_cycle(
35    m: &CountMatrix,
36    s_genes: &[String],
37    g2m_genes: &[String],
38    params: &CellCycleParams,
39) -> Result<CellCycle> {
40    let ctrl_size = s_genes.len().min(g2m_genes.len());
41    let mk = |names: &[String]| -> Result<Vec<f64>> {
42        score(
43            m,
44            names,
45            &ScoreParams {
46                ctrl_size,
47                n_bins: params.n_bins,
48                seed: params.seed,
49            },
50        )
51    };
52    let s_score = mk(s_genes)?;
53    let g2m_score = mk(g2m_genes)?;
54
55    let phase = s_score
56        .iter()
57        .zip(&g2m_score)
58        .map(|(&s, &g)| assign_phase(s, g))
59        .collect();
60
61    Ok(CellCycle {
62        barcodes: m.barcodes.clone(),
63        s_score,
64        g2m_score,
65        phase,
66    })
67}
68
69pub fn write_tsv(cc: &CellCycle, out: impl Write) -> Result<()> {
70    let mut w = BufWriter::with_capacity(1 << 20, out);
71    w.write_all(b"cell_id\tS_score\tG2M_score\tphase\n")
72        .map_err(RsomicsError::Io)?;
73    let mut fmt = ryu::Buffer::new();
74    let mut line: Vec<u8> = Vec::with_capacity(96);
75    for i in 0..cc.barcodes.len() {
76        line.clear();
77        line.extend_from_slice(cc.barcodes[i].as_bytes());
78        line.push(b'\t');
79        line.extend_from_slice(fmt.format(cc.s_score[i]).as_bytes());
80        line.push(b'\t');
81        line.extend_from_slice(fmt.format(cc.g2m_score[i]).as_bytes());
82        line.push(b'\t');
83        line.extend_from_slice(cc.phase[i].as_bytes());
84        line.push(b'\n');
85        w.write_all(&line).map_err(RsomicsError::Io)?;
86    }
87    w.flush().map_err(RsomicsError::Io)?;
88    Ok(())
89}
90
91pub fn open_output(path: &str) -> Result<Box<dyn Write>> {
92    if path == "-" {
93        Ok(Box::new(std::io::stdout().lock()))
94    } else {
95        Ok(Box::new(
96            File::create(PathBuf::from(path)).map_err(RsomicsError::Io)?,
97        ))
98    }
99}
100
101pub fn run(
102    mtx_dir: &Path,
103    s_genes: &Path,
104    g2m_genes: &Path,
105    params: &CellCycleParams,
106    out: impl Write,
107) -> Result<usize> {
108    let m = read_10x(mtx_dir)?;
109    let s = read_gene_list(s_genes)?;
110    let g2m = read_gene_list(g2m_genes)?;
111    let cc = score_cell_cycle(&m, &s, &g2m, params)?;
112    write_tsv(&cc, out)?;
113    Ok(m.n_cells)
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use rsomics_sc_score_genes::{CountMatrix, Entry};

    /// Builds an in-memory matrix with genes named `G0..` and cells named
    /// `C0..` from `(gene, cell, value)` triples.
    fn matrix(n_genes: usize, n_cells: usize, triples: &[(u32, u32, f64)]) -> CountMatrix {
        CountMatrix {
            n_genes,
            n_cells,
            gene_ids: (0..n_genes).map(|g| format!("G{g}")).collect(),
            barcodes: (0..n_cells).map(|c| format!("C{c}")).collect(),
            entries: triples
                .iter()
                .map(|&(gene, cell, value)| Entry { gene, cell, value })
                .collect(),
        }
    }

    #[test]
    fn phase_precedence() {
        assert_eq!(assign_phase(0.5, -0.2), "S"); // S positive, G2M loses
        assert_eq!(assign_phase(0.1, 0.4), "G2M"); // G2M outscores S, both >= 0
        assert_eq!(assign_phase(-0.3, -0.1), "G1"); // both negative
        assert_eq!(assign_phase(0.0, 0.0), "S"); // tie at zero stays S
        assert_eq!(assign_phase(-0.5, 0.2), "G2M"); // only S<0, so not G1; G2M>S -> G2M
    }

    #[test]
    fn g2m_strictly_greater_for_g2m_call() {
        // when G2M does not strictly exceed S, phase stays S.
        assert_eq!(assign_phase(0.3, 0.3), "S");
        assert_eq!(assign_phase(0.3, 0.3000001), "G2M");
    }

    #[test]
    fn nan_scores_fall_back_to_s() {
        // Every comparison against NaN is false, so neither the G2M rule nor
        // the G1 rule fires and the default call is kept.
        assert_eq!(assign_phase(f64::NAN, f64::NAN), "S");
        assert_eq!(assign_phase(f64::NAN, 1.0), "S");
        assert_eq!(assign_phase(-1.0, f64::NAN), "S");
    }

    #[test]
    fn end_to_end_runs_and_calls_phases() {
        // 60 genes / 8 cells so control bins are populated after excluding the
        // 3+3 list genes; checks the output shape and that every cell gets one
        // of the three valid phase labels.
        let mut triples = Vec::new();
        for g in 0..60u32 {
            for c in 0..8u32 {
                let v = ((g * 7 + c * 3) % 11) as f64 + 1.0;
                triples.push((g, c, v));
            }
        }
        let m = matrix(60, 8, &triples);
        let s = vec!["G0".to_string(), "G1".to_string(), "G2".to_string()];
        let g2m = vec!["G30".to_string(), "G31".to_string(), "G32".to_string()];
        let params = CellCycleParams {
            n_bins: 25,
            seed: 0,
        };
        let cc = score_cell_cycle(&m, &s, &g2m, &params).unwrap();
        assert_eq!(cc.barcodes, m.barcodes);
        assert_eq!(cc.s_score.len(), 8);
        assert_eq!(cc.g2m_score.len(), 8);
        assert_eq!(cc.phase.len(), 8);
        for p in &cc.phase {
            assert!(matches!(*p, "S" | "G2M" | "G1"));
        }
    }

    #[test]
    fn write_tsv_emits_header_and_one_row_per_cell() {
        let cc = CellCycle {
            barcodes: vec!["C0".to_string(), "C1".to_string()],
            s_score: vec![0.5, -0.25],
            g2m_score: vec![-1.0, 0.125],
            phase: vec!["S", "G2M"],
        };
        let mut out: Vec<u8> = Vec::new();
        write_tsv(&cc, &mut out).unwrap();
        assert_eq!(
            String::from_utf8(out).unwrap(),
            "cell_id\tS_score\tG2M_score\tphase\n\
             C0\t0.5\t-1.0\tS\n\
             C1\t-0.25\t0.125\tG2M\n"
        );
    }
}