1use std::collections::HashMap;
12
13use arrow::{
14 array::{Array, Float64Array, StringArray, UInt8Array},
15 datatypes::{DataType, Field, Schema},
16 record_batch::RecordBatch,
17};
18
19use crate::{Error, Result};
20
/// Aggregated expression values indexed as `values[gene][sample]`.
///
/// The outer vector holds one row per gene and each inner vector holds one
/// entry per sample; `None` marks a gene/sample pair for which no value was
/// collected.
pub type GeneValues = Vec<Vec<Option<f64>>>;
23
/// Gene-by-sample expression matrix backed by an Arrow [`RecordBatch`].
///
/// When produced by [`MatrixBuilder`], `values` holds one nullable `Float64`
/// column per entry of `samples` (in the same order) and one row per entry of
/// `genes` (in the same order).
#[derive(Debug, Clone)]
pub struct ExpressionMatrix {
    /// Row labels. [`MatrixBuilder`] emits them sorted; a probe without a gene
    /// mapping is labelled with its own probe ID.
    pub genes: Vec<String>,

    /// Column labels: each sample's GEO accession, or its local ID when the
    /// accession is absent.
    pub samples: Vec<String>,

    /// The numeric payload; a null cell means the gene has no value for that
    /// sample.
    pub values: RecordBatch,
}
45
46impl ExpressionMatrix {
47 #[must_use]
52 pub fn get(&self, gene: &str, sample: &str) -> Option<f64> {
53 let gene_idx = self.genes.iter().position(|g| g == gene)?;
54 let sample_idx = self.samples.iter().position(|s| s == sample)?;
55
56 let col = self.values.column(sample_idx);
57 let array = col.as_any().downcast_ref::<Float64Array>()?;
58
59 if array.is_null(gene_idx) {
60 None
61 } else {
62 Some(array.value(gene_idx))
63 }
64 }
65}
66
/// How the values of all probes that map to the same gene are combined into a
/// single value per sample.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum AggregationMethod {
    /// Arithmetic mean of the probe values (the default).
    #[default]
    Mean,

    /// Median of the probe values; for an even count, the midpoint of the two
    /// middle values.
    Median,

    /// Largest probe value.
    Max,

    /// Smallest probe value.
    Min,
}
83
/// Options controlling how [`MatrixBuilder`] assembles an expression matrix.
#[derive(Debug, Clone)]
pub struct MatrixConfig {
    /// How values from several probes of the same gene are combined.
    pub aggregation: AggregationMethod,

    /// Minimum number of distinct samples in which a probe must have a value
    /// for the probe to be kept; probes seen in fewer samples are dropped
    /// before aggregation.
    pub min_sample_presence: usize,
}
93
94impl Default for MatrixConfig {
95 fn default() -> Self {
96 Self {
97 aggregation: AggregationMethod::Mean,
98 min_sample_presence: 1,
99 }
100 }
101}
102
103pub struct MatrixBuilder {
105 config: MatrixConfig,
106}
107
108impl Default for MatrixBuilder {
109 fn default() -> Self {
110 Self::new()
111 }
112}
113
114impl MatrixBuilder {
115 #[must_use]
117 pub fn new() -> Self {
118 Self {
119 config: MatrixConfig::default(),
120 }
121 }
122
123 #[must_use]
125 pub fn with_config(config: MatrixConfig) -> Self {
126 Self { config }
127 }
128
129 pub fn from_soft<R>(&self, mut reader: geo_soft_rs::SoftReader<R>) -> Result<ExpressionMatrix>
142 where
143 R: std::io::BufRead,
144 {
145 let (samples, platform_opt) = Self::collect_records(&mut reader)?;
146 self.assemble_matrix(&samples, platform_opt.as_ref())
147 }
148
149 pub fn build_all<R>(
163 &self,
164 mut reader: geo_soft_rs::SoftReader<R>,
165 ) -> Result<(ExpressionMatrix, SampleMetadata, Option<PlatformAnnotation>)>
166 where
167 R: std::io::BufRead,
168 {
169 let (samples, platform_opt) = Self::collect_records(&mut reader)?;
170 let metadata = SampleMetadata::from_samples(&samples)?;
171 let annotation = platform_opt
172 .as_ref()
173 .map(PlatformAnnotation::from_platform)
174 .transpose()?
175 .flatten();
176 let matrix = self.assemble_matrix(&samples, platform_opt.as_ref())?;
177 Ok((matrix, metadata, annotation))
178 }
179
180 fn collect_records<R>(
183 reader: &mut geo_soft_rs::SoftReader<R>,
184 ) -> Result<(Vec<geo_soft_rs::GsmRecord>, Option<geo_soft_rs::GplRecord>)>
185 where
186 R: std::io::BufRead,
187 {
188 let mut samples: Vec<geo_soft_rs::GsmRecord> = Vec::new();
189 let mut platform_opt: Option<geo_soft_rs::GplRecord> = None;
190
191 while let Some(result) = reader.next_record() {
192 match result? {
193 geo_soft_rs::SoftRecord::Sample(s) if s.data_table.is_some() => {
194 samples.push(s);
195 }
196 geo_soft_rs::SoftRecord::Platform(p) if platform_opt.is_none() => {
197 platform_opt = Some(p);
198 }
199 _ => {}
200 }
201 }
202
203 if samples.is_empty() {
204 return Err(Error::Matrix(
205 "No samples with data tables found in SOFT file".to_string(),
206 ));
207 }
208
209 Ok((samples, platform_opt))
210 }
211
212 fn assemble_matrix(
215 &self,
216 samples: &[geo_soft_rs::GsmRecord],
217 platform_opt: Option<&geo_soft_rs::GplRecord>,
218 ) -> Result<ExpressionMatrix> {
219 let mut probe_data: HashMap<String, Vec<(usize, f64)>> = HashMap::new();
221 let mut sample_ids: Vec<String> = Vec::with_capacity(samples.len());
222
223 for (sample_idx, sample) in samples.iter().enumerate() {
224 let sample_id = sample
225 .geo_accession
226 .clone()
227 .unwrap_or_else(|| sample.local_id.clone());
228 sample_ids.push(sample_id);
229
230 if let Some(ref table) = sample.data_table {
231 let id_ref_idx = table
233 .columns
234 .iter()
235 .position(|c| c.name.eq_ignore_ascii_case("ID_REF"))
236 .ok_or_else(|| {
237 Error::Matrix(format!(
238 "Sample {} missing `ID_REF` column",
239 sample.local_id
240 ))
241 })?;
242
243 let value_idx = table
244 .columns
245 .iter()
246 .position(|c| c.name.eq_ignore_ascii_case("VALUE"))
247 .ok_or_else(|| {
248 Error::Matrix(format!("Sample {} missing `VALUE` column", sample.local_id))
249 })?;
250
251 for row in &table.rows {
253 if let Some(probe_id) = row.get(id_ref_idx) {
254 if let Some(value_str) = row.get(value_idx) {
255 if let Ok(value) = value_str.parse::<f64>() {
256 probe_data
257 .entry(probe_id.clone())
258 .or_default()
259 .push((sample_idx, value));
260 }
261 }
263 }
264 }
265 }
266 }
267
268 let probe_to_gene = Self::build_probe_to_gene_map(platform_opt);
270
271 let (genes, gene_values) =
273 self.aggregate_by_gene(&probe_data, &probe_to_gene, samples.len());
274
275 let values = Self::build_record_batch(&genes, &sample_ids, &gene_values)?;
277
278 Ok(ExpressionMatrix {
279 genes,
280 samples: sample_ids,
281 values,
282 })
283 }
284
285 fn build_probe_to_gene_map(
287 platform: Option<&geo_soft_rs::GplRecord>,
288 ) -> HashMap<String, String> {
289 let mut mapping = HashMap::new();
290
291 if let Some(p) = platform {
292 if let Some(ref table) = p.annotation_table {
293 let probe_idx = table.columns.iter().position(|c| {
295 c.name.eq_ignore_ascii_case("ID")
296 || c.name.eq_ignore_ascii_case("PROBE_ID")
297 || c.name.eq_ignore_ascii_case("ID_REF")
298 });
299
300 let gene_idx = table.columns.iter().position(|c| {
301 c.name.eq_ignore_ascii_case("GENE_SYMBOL")
302 || c.name.eq_ignore_ascii_case("SYMBOL")
303 || c.name.eq_ignore_ascii_case("GENE")
304 });
305
306 if let (Some(p_idx), Some(g_idx)) = (probe_idx, gene_idx) {
307 for row in &table.rows {
308 if let (Some(probe), Some(gene)) = (row.get(p_idx), row.get(g_idx)) {
309 if !gene.is_empty() && gene != "---" {
310 mapping.insert(probe.clone(), gene.clone());
311 }
312 }
313 }
314 }
315 }
316 }
317
318 mapping
319 }
320
321 #[allow(clippy::cast_precision_loss)]
323 fn aggregate_by_gene(
324 &self,
325 probe_data: &HashMap<String, Vec<(usize, f64)>>,
326 probe_to_gene: &HashMap<String, String>,
327 num_samples: usize,
328 ) -> (Vec<String>, GeneValues) {
329 let mut gene_probes: HashMap<String, Vec<String>> = HashMap::new();
331
332 for (probe_id, sample_entries) in probe_data {
333 let distinct_samples = sample_entries
334 .iter()
335 .map(|(s_idx, _)| s_idx)
336 .collect::<std::collections::HashSet<_>>()
337 .len();
338 if distinct_samples < self.config.min_sample_presence {
339 continue;
340 }
341 let gene = probe_to_gene
342 .get(probe_id)
343 .cloned()
344 .unwrap_or_else(|| probe_id.clone());
345 gene_probes.entry(gene).or_default().push(probe_id.clone());
346 }
347
348 let mut genes: Vec<String> = gene_probes.keys().cloned().collect();
350 genes.sort();
351
352 let mut gene_values: Vec<Vec<Option<f64>>> = Vec::with_capacity(genes.len());
354
355 for gene in &genes {
356 let probes = gene_probes.get(gene).unwrap();
357 let mut sample_values: Vec<Vec<f64>> = vec![Vec::new(); num_samples];
358
359 for probe_id in probes {
361 if let Some(values) = probe_data.get(probe_id) {
362 for (sample_idx, value) in values {
363 sample_values[*sample_idx].push(*value);
364 }
365 }
366 }
367
368 let mut aggregated: Vec<Option<f64>> = Vec::with_capacity(num_samples);
370 for values in sample_values {
371 if values.is_empty() {
372 aggregated.push(None);
373 } else {
374 let agg = match self.config.aggregation {
375 AggregationMethod::Mean => values.iter().sum::<f64>() / values.len() as f64,
376 AggregationMethod::Median => {
377 let mut sorted = values;
378 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
379 let mid = sorted.len() / 2;
380 if sorted.len() % 2 == 0 {
381 f64::midpoint(sorted[mid - 1], sorted[mid])
382 } else {
383 sorted[mid]
384 }
385 }
386 AggregationMethod::Max => values
387 .iter()
388 .max_by(|a, b| a.partial_cmp(b).unwrap())
389 .copied()
390 .expect("non-empty guaranteed by is_empty check above"),
391 AggregationMethod::Min => values
392 .iter()
393 .min_by(|a, b| a.partial_cmp(b).unwrap())
394 .copied()
395 .expect("non-empty guaranteed by is_empty check above"),
396 };
397 aggregated.push(Some(agg));
398 }
399 }
400
401 gene_values.push(aggregated);
402 }
403
404 (genes, gene_values)
405 }
406
407 fn build_record_batch(
409 genes: &[String],
410 sample_ids: &[String],
411 gene_values: &[Vec<Option<f64>>],
412 ) -> Result<RecordBatch> {
413 let fields: Vec<Field> = sample_ids
415 .iter()
416 .map(|id| Field::new(id.clone(), DataType::Float64, true))
417 .collect();
418 let schema = Schema::new(fields);
419
420 let mut columns: Vec<arrow::array::ArrayRef> = Vec::with_capacity(sample_ids.len());
422
423 for sample_idx in 0..sample_ids.len() {
424 let mut values: Vec<Option<f64>> = Vec::with_capacity(genes.len());
425 for gene_values_row in gene_values {
426 debug_assert!(
427 sample_idx < gene_values_row.len(),
428 "gene_values row length ({}) must equal num_samples ({})",
429 gene_values_row.len(),
430 sample_ids.len()
431 );
432 values.push(gene_values_row[sample_idx]);
433 }
434 let array = Float64Array::from(values);
435 columns.push(std::sync::Arc::new(array));
436 }
437
438 let batch = RecordBatch::try_new(std::sync::Arc::new(schema), columns)?;
439 Ok(batch)
440 }
441}
442
/// Sample characteristics in long format: one row per characteristic
/// key/value pair.
///
/// A sample without characteristics contributes a single row with empty key
/// and value. The exact schema depends on the constructor:
/// [`SampleMetadata::from_samples`] adds a `channel_index` column that
/// [`SampleMetadata::from_soft`] does not produce.
#[derive(Debug, Clone)]
pub struct SampleMetadata {
    /// The metadata table; every column is non-nullable.
    pub data: RecordBatch,
}
452
453impl SampleMetadata {
454 pub fn from_soft<R>(mut reader: geo_soft_rs::SoftReader<R>) -> Result<Self>
468 where
469 R: std::io::BufRead,
470 {
471 let mut records: Vec<(String, String, String, String)> = Vec::new();
472
473 while let Some(result) = reader.next_sample() {
474 let sample = result?;
475 let gsm_accession = sample
476 .geo_accession
477 .clone()
478 .unwrap_or_else(|| sample.local_id.clone());
479
480 for char_map in &sample.characteristics {
482 for (key, value) in char_map {
483 records.push((
484 gsm_accession.clone(),
485 sample.title.clone(),
486 key.clone(),
487 value.clone(),
488 ));
489 }
490 }
491
492 if sample.characteristics.is_empty() {
494 records.push((gsm_accession, sample.title, String::new(), String::new()));
495 }
496 }
497
498 let schema = Schema::new(vec![
500 Field::new("gsm_accession", DataType::Utf8, false),
501 Field::new("title", DataType::Utf8, false),
502 Field::new("characteristic_key", DataType::Utf8, false),
503 Field::new("characteristic_value", DataType::Utf8, false),
504 ]);
505
506 let gsm_accessions: Vec<&str> = records.iter().map(|r| r.0.as_str()).collect();
507 let titles: Vec<&str> = records.iter().map(|r| r.1.as_str()).collect();
508 let keys: Vec<&str> = records.iter().map(|r| r.2.as_str()).collect();
509 let values: Vec<&str> = records.iter().map(|r| r.3.as_str()).collect();
510
511 let batch = RecordBatch::try_new(
512 std::sync::Arc::new(schema),
513 vec![
514 std::sync::Arc::new(StringArray::from(gsm_accessions)),
515 std::sync::Arc::new(StringArray::from(titles)),
516 std::sync::Arc::new(StringArray::from(keys)),
517 std::sync::Arc::new(StringArray::from(values)),
518 ],
519 )?;
520
521 Ok(Self { data: batch })
522 }
523
524 pub fn from_samples(samples: &[geo_soft_rs::GsmRecord]) -> Result<Self> {
533 let mut records: Vec<(String, String, u8, String, String)> = Vec::new();
535
536 for sample in samples {
537 let gsm_accession = sample
538 .geo_accession
539 .clone()
540 .unwrap_or_else(|| sample.local_id.clone());
541
542 for (channel_idx, char_map) in sample.characteristics.iter().enumerate() {
543 #[allow(clippy::cast_possible_truncation)]
544 let ch = channel_idx as u8;
545 for (key, value) in char_map {
546 records.push((
547 gsm_accession.clone(),
548 sample.title.clone(),
549 ch,
550 key.clone(),
551 value.clone(),
552 ));
553 }
554 }
555
556 if sample.characteristics.is_empty() {
557 records.push((
558 gsm_accession,
559 sample.title.clone(),
560 0,
561 String::new(),
562 String::new(),
563 ));
564 }
565 }
566
567 let schema = Schema::new(vec![
568 Field::new("gsm_accession", DataType::Utf8, false),
569 Field::new("title", DataType::Utf8, false),
570 Field::new("channel_index", DataType::UInt8, false),
571 Field::new("characteristic_key", DataType::Utf8, false),
572 Field::new("characteristic_value", DataType::Utf8, false),
573 ]);
574
575 let gsm_accessions: Vec<&str> = records.iter().map(|r| r.0.as_str()).collect();
576 let titles: Vec<&str> = records.iter().map(|r| r.1.as_str()).collect();
577 let channels: Vec<u8> = records.iter().map(|r| r.2).collect();
578 let keys: Vec<&str> = records.iter().map(|r| r.3.as_str()).collect();
579 let values: Vec<&str> = records.iter().map(|r| r.4.as_str()).collect();
580
581 let batch = RecordBatch::try_new(
582 std::sync::Arc::new(schema),
583 vec![
584 std::sync::Arc::new(StringArray::from(gsm_accessions)),
585 std::sync::Arc::new(StringArray::from(titles)),
586 std::sync::Arc::new(UInt8Array::from(channels)),
587 std::sync::Arc::new(StringArray::from(keys)),
588 std::sync::Arc::new(StringArray::from(values)),
589 ],
590 )?;
591
592 Ok(Self { data: batch })
593 }
594}
595
/// Probe annotation extracted from a platform record.
#[derive(Debug, Clone)]
pub struct PlatformAnnotation {
    /// One row per annotated probe with the string columns `probe_id`
    /// (non-nullable), `gene_symbol`, `entrez_id` and `description`
    /// (nullable).
    pub data: RecordBatch,
}
604
605impl PlatformAnnotation {
606 #[allow(clippy::similar_names)]
615 pub fn from_platform(platform: &geo_soft_rs::GplRecord) -> Result<Option<Self>> {
616 let Some(ref table) = platform.annotation_table else {
617 return Ok(None);
618 };
619
620 let probe_idx = table
621 .columns
622 .iter()
623 .position(|c| {
624 c.name.eq_ignore_ascii_case("ID")
625 || c.name.eq_ignore_ascii_case("PROBE_ID")
626 || c.name.eq_ignore_ascii_case("ID_REF")
627 })
628 .ok_or_else(|| {
629 Error::Matrix("Platform annotation missing probe ID column".to_string())
630 })?;
631
632 let gene_idx = table.columns.iter().position(|c| {
633 c.name.eq_ignore_ascii_case("GENE_SYMBOL")
634 || c.name.eq_ignore_ascii_case("SYMBOL")
635 || c.name.eq_ignore_ascii_case("GENE")
636 });
637
638 let entrez_idx = table.columns.iter().position(|c| {
639 c.name.eq_ignore_ascii_case("ENTREZ_ID")
640 || c.name.eq_ignore_ascii_case("ENTREZ")
641 || c.name.eq_ignore_ascii_case("GENE_ID")
642 });
643
644 let desc_idx = table.columns.iter().position(|c| {
645 c.name.eq_ignore_ascii_case("DESCRIPTION")
646 || c.name.eq_ignore_ascii_case("DESC")
647 || c.name.eq_ignore_ascii_case("GENE_TITLE")
648 });
649
650 let mut probe_ids: Vec<&str> = Vec::new();
651 let mut gene_symbols: Vec<Option<&str>> = Vec::new();
652 let mut gene_entrez_ids: Vec<Option<&str>> = Vec::new();
653 let mut descriptions: Vec<Option<&str>> = Vec::new();
654
655 for row in &table.rows {
656 if let Some(probe) = row.get(probe_idx) {
657 probe_ids.push(probe);
658 gene_symbols.push(gene_idx.and_then(|i| row.get(i).map(String::as_str)));
659 gene_entrez_ids.push(entrez_idx.and_then(|i| row.get(i).map(String::as_str)));
660 descriptions.push(desc_idx.and_then(|i| row.get(i).map(String::as_str)));
661 }
662 }
663
664 let schema = Schema::new(vec![
665 Field::new("probe_id", DataType::Utf8, false),
666 Field::new("gene_symbol", DataType::Utf8, true),
667 Field::new("entrez_id", DataType::Utf8, true),
668 Field::new("description", DataType::Utf8, true),
669 ]);
670
671 let batch = RecordBatch::try_new(
672 std::sync::Arc::new(schema),
673 vec![
674 std::sync::Arc::new(StringArray::from(probe_ids)),
675 std::sync::Arc::new(StringArray::from(gene_symbols)),
676 std::sync::Arc::new(StringArray::from(gene_entrez_ids)),
677 std::sync::Arc::new(StringArray::from(descriptions)),
678 ],
679 )?;
680
681 Ok(Some(Self { data: batch }))
682 }
683
684 pub fn from_soft<R>(mut reader: geo_soft_rs::SoftReader<R>) -> Result<Option<Self>>
693 where
694 R: std::io::BufRead,
695 {
696 while let Some(result) = reader.next_platform() {
697 let platform = result?;
698 if let Some(annotation) = Self::from_platform(&platform)? {
699 return Ok(Some(annotation));
700 }
701 }
702 Ok(None)
703 }
704}