1use std::collections::{HashMap, HashSet};
18use std::fs;
19use std::io::{self, BufRead, Read, Write};
20use std::path::PathBuf;
21
22use regex::Regex;
23use serde::{Deserialize, Serialize};
24use serde_json;
25
26pub mod errors;
27pub mod taxa;
28
29use crate::errors::ASDBTaxonError;
30use crate::taxa::NcbiTaxEntry;
31
32#[derive(Debug, Clone, Deserialize, Serialize)]
33pub struct TaxonCache {
34 pub deprecated_ids: HashMap<i64, i64>,
35 pub mappings: HashMap<i64, NcbiTaxEntry>,
36}
37
38impl TaxonCache {
39 pub fn new() -> TaxonCache {
40 TaxonCache {
41 deprecated_ids: HashMap::new(),
42 mappings: HashMap::new(),
43 }
44 }
45
46 pub fn initialise(
47 &mut self,
48 taxdump: impl Read,
49 merged_id_dump: impl Read,
50 taxids: &mut HashSet<i64>,
51 ) -> Result<(), ASDBTaxonError> {
52 populate_merged_ids(merged_id_dump, taxids, &mut self.deprecated_ids)?;
53
54 populate_mappings(taxdump, taxids, &self.deprecated_ids, &mut self.mappings)?;
55
56 Ok(())
57 }
58
59 pub fn initialise_from_paths(
60 &mut self,
61 taxdump_path: PathBuf,
62 merged_id_dump_path: PathBuf,
63 datadir_path: PathBuf,
64 ) -> Result<(), ASDBTaxonError> {
65 let mut taxids = self.find_taxids(datadir_path)?;
66 let taxdump = fs::File::open(taxdump_path)?;
67 let mergeddump = fs::File::open(merged_id_dump_path)?;
68
69 self.initialise(taxdump, mergeddump, &mut taxids)?;
70
71 Ok(())
72 }
73
74 pub fn find_taxids(&self, datadir: PathBuf) -> Result<HashSet<i64>, ASDBTaxonError> {
75 let re = Regex::new(r#""taxon:(\d+)"#)?;
76 let mut taxids: HashSet<i64> = HashSet::new();
77 let mut entries = fs::read_dir(datadir)?
78 .map(|res| res.map(|e| e.path()))
79 .filter(|p| p.is_ok() && p.as_ref().unwrap().extension() == Some("json".as_ref()))
80 .collect::<Result<Vec<_>, io::Error>>()?;
81
82 entries.sort();
83
84 for path in entries {
85 let content = fs::read_to_string(&path)?;
86 let cap = re.captures(&content);
87 if cap.is_none() {
88 continue;
89 }
90 let taxid_match = cap.unwrap().get(1);
91 if taxid_match.is_none() {
92 continue;
93 }
94 if let Ok(taxid) = taxid_match.unwrap().as_str().parse::<i64>() {
95 taxids.insert(taxid);
96 }
97 }
98 Ok(taxids)
99 }
100
101 pub fn save(&self, mut output: impl Write) -> Result<usize, ASDBTaxonError> {
102 let json_data = serde_json::to_string(self)?;
103 output.write(json_data.as_bytes())?;
104
105 Ok(self.mappings.len())
106 }
107
108 pub fn save_path(&self, outfile: &PathBuf) -> Result<usize, ASDBTaxonError> {
109 let out = fs::File::create(outfile)?;
110 self.save(out)
111 }
112
113 pub fn load(&mut self, mut input: impl Read) -> Result<usize, ASDBTaxonError> {
114 let mut json_data = String::new();
115 input.read_to_string(&mut json_data)?;
116 let loaded_cache: TaxonCache = serde_json::from_str(&json_data)?;
117 self.mappings = loaded_cache.mappings;
118 self.deprecated_ids = loaded_cache.deprecated_ids;
119
120 Ok(self.mappings.len())
121 }
122
123 pub fn load_path(&mut self, infile: &PathBuf) -> Result<usize, ASDBTaxonError> {
124 let handle = fs::File::open(infile)?;
125 self.load(handle)
126 }
127}
128
129fn populate_merged_ids(
130 merged_id_dump: impl Read,
131 taxids: &mut HashSet<i64>,
132 deprecated_ids: &mut HashMap<i64, i64>,
133) -> Result<(), ASDBTaxonError> {
134 for line_option in io::BufReader::new(merged_id_dump).lines() {
135 if let Ok(line) = line_option {
136 let parts: Vec<String> = line
137 .trim()
138 .splitn(3, "|")
139 .map(|part| part.trim().to_string())
140 .collect();
141
142 let old_id: i64 = parts[0].parse()?;
143 if !taxids.contains(&old_id) {
144 continue;
145 }
146
147 let new_id: i64 = parts[1].parse()?;
148
149 deprecated_ids.insert(old_id, new_id);
150 taxids.remove(&old_id);
151 taxids.insert(new_id);
152 }
153 }
154 Ok(())
155}
156
157fn populate_mappings(
158 taxdump: impl Read,
159 taxids: &HashSet<i64>,
160 deprecated_ids: &HashMap<i64, i64>,
161 mappings: &mut HashMap<i64, NcbiTaxEntry>,
162) -> Result<(), ASDBTaxonError> {
163 for line_option in io::BufReader::new(taxdump).lines() {
164 if let Ok(line) = line_option {
165 let parts: Vec<String> = line
166 .trim()
167 .splitn(11, "|")
168 .map(|part| match part.trim() {
169 "" => "Unknown".to_string(),
170 part => part.to_string(),
171 })
172 .collect();
173
174 let mut tax_id: i64 = parts[0].parse()?;
175 if deprecated_ids.contains_key(&tax_id) {
176 tax_id = *deprecated_ids.get(&tax_id).unwrap();
177 }
178
179 if !taxids.contains(&tax_id) {
180 continue;
181 }
182
183 let mut entry = NcbiTaxEntry {
184 tax_id,
185 name: parts[1].to_owned(),
186 species: parts[2]
187 .split_whitespace()
188 .next_back()
189 .unwrap_or(parts[2].as_str())
190 .to_owned(),
191 genus: parts[3].to_owned(),
192 family: parts[4].to_owned(),
193 order: parts[5].to_owned(),
194 class: parts[6].to_owned(),
195 phylum: parts[7].to_owned(),
196 kingdom: parts[8].to_owned(),
197 superkingdom: parts[9].to_owned(),
198 };
199
200 if !entry.name.contains(" sp. ") && entry.species == "Unknown" {
201 entry.species = entry.name.split_whitespace().next_back().unwrap_or("Unknown").to_owned();
202 }
203
204 mappings.insert(tax_id, entry.to_owned());
205 }
206 }
207 Ok(())
208}
209
#[cfg(test)]
mod tests {
    use super::*;

    /// A wanted taxon ID that has been merged must be recorded as deprecated
    /// and its taxonomy entry stored under the replacement ID.
    #[test]
    fn test_initialise() {
        // In-memory stand-ins for the merged-ID dump and the taxonomy dump.
        let merged_ids = "12345 | 23456 |".as_bytes();
        let taxdump = "23456 | Streptomyces examplis NBC12345 | Streptomyces examplis | Streptomyces | Streptomycetaceae | Streptomycetales | Actinomycetia | Actinobacteria | | Bacteria |".as_bytes();

        let mut taxids: HashSet<i64> = vec![12345].into_iter().collect();
        let mut taxon_cache = TaxonCache::new();

        let res = taxon_cache.initialise(taxdump, merged_ids, &mut taxids);
        assert!(res.is_ok());

        let deprecated_count = taxon_cache.deprecated_ids.len();
        assert_eq!(
            deprecated_count, 1,
            "unexpected length of deprecated_ids: {}",
            deprecated_count
        );

        let mapping_count = taxon_cache.mappings.len();
        assert_eq!(
            mapping_count, 1,
            "unexpected length of mappings: {}",
            mapping_count
        );

        let entry = taxon_cache.mappings.get(&23456).unwrap();
        assert_eq!(entry.name, "Streptomyces examplis NBC12345");
    }
}