1use cfg_if::cfg_if;
2pub use error::{HtsGetError, Result};
3pub use htsget_config::config::Config;
4use htsget_config::types::Format::{Bam, Bcf, Cram, Vcf};
5use htsget_config::types::{Format, Query, Request, Response};
6pub use http_core::{get, post};
7pub use post_request::{PostRequest, Region};
8use query_builder::QueryBuilder;
9pub use service_info::get_service_info_json;
10pub use service_info::{Htsget, ServiceInfo, Type};
11use std::collections::HashMap;
12use std::result;
13use std::str::FromStr;
14
15pub mod error;
16pub mod http_core;
17pub mod middleware;
18pub mod post_request;
19pub mod query_builder;
20pub mod service_info;
21
/// The htsget endpoint a request is addressed to.
///
/// Parsed from the strings `"reads"` and `"variants"` via [`FromStr`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Endpoint {
    /// The `reads` endpoint (served formats: BAM, CRAM).
    Reads,
    /// The `variants` endpoint (served formats: VCF, BCF).
    Variants,
}

impl FromStr for Endpoint {
    /// Parsing carries no error detail: any unrecognised name is rejected with `()`.
    type Err = ();

    /// Parses an endpoint name.
    ///
    /// Matching is exact and case-sensitive: only `"reads"` and `"variants"` are accepted.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` for any other input.
    fn from_str(s: &str) -> result::Result<Self, Self::Err> {
        match s {
            "reads" => Ok(Self::Reads),
            "variants" => Ok(Self::Variants),
            _ => Err(()),
        }
    }
}
41
42pub fn match_format_from_query(
44 endpoint: &Endpoint,
45 query: &HashMap<String, String>,
46) -> Result<Format> {
47 match_format(endpoint, query.get("format"))
48}
49
50pub fn match_format(endpoint: &Endpoint, format: Option<impl Into<String>>) -> Result<Format> {
52 let format = format.map(Into::into).map(|format| format.to_lowercase());
53
54 match (endpoint, format) {
55 (Endpoint::Reads, None) => Ok(Bam),
56 (Endpoint::Variants, None) => Ok(Vcf),
57 (Endpoint::Reads, Some(s)) if s == "bam" => Ok(Bam),
58 (Endpoint::Reads, Some(s)) if s == "cram" => Ok(Cram),
59 (Endpoint::Variants, Some(s)) if s == "vcf" => Ok(Vcf),
60 (Endpoint::Variants, Some(s)) if s == "bcf" => Ok(Bcf),
61 (_, Some(format)) => Err(HtsGetError::UnsupportedFormat(format!(
62 "{format} isn't a supported format for this endpoint"
63 ))),
64 }
65}
66
67fn convert_to_query(request: Request, format: Format) -> Result<Query> {
68 let query = request.query().clone();
69
70 set_query_builder(
71 QueryBuilder::new(request, format),
72 query.get("class"),
73 query.get("referenceName"),
74 query.get("fields"),
75 (query.get("tags"), query.get("notags")),
76 (query.get("start"), query.get("end")),
77 query.get("encryptionScheme"),
78 )
79}
80
81fn set_query_builder(
82 builder: QueryBuilder,
83 class: Option<impl Into<String>>,
84 reference_name: Option<impl Into<String>>,
85 fields: Option<impl Into<String>>,
86 (tags, no_tags): (Option<impl Into<String>>, Option<impl Into<String>>),
87 (start, end): (Option<impl Into<String>>, Option<impl Into<String>>),
88 _encryption_scheme: Option<impl Into<String>>,
89) -> Result<Query> {
90 let builder = builder
91 .with_class(class)?
92 .with_fields(fields)
93 .with_tags(tags, no_tags)?
94 .with_reference_name(reference_name)
95 .with_range(start, end)?;
96
97 cfg_if! {
98 if #[cfg(feature = "experimental")] {
99 Ok(builder.with_encryption_scheme(_encryption_scheme)?.build())
100 } else {
101 Ok(builder.build())
102 }
103 }
104}
105
106fn merge_responses(responses: Vec<Response>) -> Option<Response> {
107 responses.into_iter().reduce(|mut acc, mut response| {
108 acc.urls.append(&mut response.urls);
109 acc
110 })
111}
112
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::path::PathBuf;

    use htsget_config::storage;
    use htsget_config::types::{Headers, JsonResponse, Request, Scheme, Url};
    use htsget_search::FileStorage;
    use htsget_search::HtsGet;
    use htsget_search::Storage;
    use htsget_search::from_storage::HtsGetFromStorage;
    use http::uri::Authority;

    use super::*;

    #[test]
    fn endpoint_from_str() {
        assert_eq!(Endpoint::from_str("reads"), Ok(Endpoint::Reads));
        assert_eq!(Endpoint::from_str("variants"), Ok(Endpoint::Variants));
        assert_eq!(Endpoint::from_str("unknown"), Err(()));
    }

    #[test]
    fn match_with_invalid_format() {
        assert!(matches!(
            match_format(&Endpoint::Reads, Some("Invalid".to_string())).unwrap_err(),
            HtsGetError::UnsupportedFormat(_)
        ));
    }

    #[test]
    fn match_with_invalid_endpoint() {
        assert!(matches!(
            match_format(&Endpoint::Variants, Some("bam".to_string())).unwrap_err(),
            HtsGetError::UnsupportedFormat(_)
        ));
    }

    #[test]
    fn match_with_valid_format() {
        assert!(matches!(
            match_format(&Endpoint::Reads, Some("bam".to_string())).unwrap(),
            Bam,
        ));
    }

    // With no explicit format, each endpoint falls back to its default.
    #[test]
    fn match_with_no_format_uses_endpoint_default() {
        assert!(matches!(
            match_format(&Endpoint::Reads, None::<String>).unwrap(),
            Format::Bam
        ));
        assert!(matches!(
            match_format(&Endpoint::Variants, None::<String>).unwrap(),
            Format::Vcf
        ));
    }

    // `match_format` lowercases the requested format before matching.
    #[test]
    fn match_with_uppercase_format() {
        assert!(matches!(
            match_format(&Endpoint::Reads, Some("CRAM")).unwrap(),
            Format::Cram
        ));
        assert!(matches!(
            match_format(&Endpoint::Variants, Some("BCF")).unwrap(),
            Format::Bcf
        ));
    }

    #[test]
    fn match_format_from_query_reads_format_key() {
        let mut query = HashMap::new();
        query.insert("format".to_string(), "cram".to_string());

        assert!(matches!(
            match_format_from_query(&Endpoint::Reads, &query).unwrap(),
            Format::Cram
        ));
        assert!(matches!(
            match_format_from_query(&Endpoint::Variants, &HashMap::new()).unwrap(),
            Format::Vcf
        ));
    }

    #[tokio::test]
    async fn get_request() {
        let request = HashMap::new();

        let mut expected_response_headers = Headers::default();
        expected_response_headers.insert("Range".to_string(), "bytes=0-2596798".to_string());

        let request = Request::new(
            "bam/htsnexus_test_NA12878".to_string(),
            request,
            Default::default(),
        );

        assert_eq!(
            get(get_searcher(), request, Endpoint::Reads, None, None).await,
            Ok(expected_bam_json_response(expected_response_headers))
        );
    }

    #[tokio::test]
    async fn get_reads_request_with_variants_format() {
        let mut request = HashMap::new();
        request.insert("format".to_string(), "VCF".to_string());

        let request = Request::new(
            "bam/htsnexus_test_NA12878".to_string(),
            request,
            Default::default(),
        );

        assert!(matches!(
            get(get_searcher(), request, Endpoint::Reads, None, None).await,
            Err(HtsGetError::UnsupportedFormat(_))
        ));
    }

    #[tokio::test]
    async fn get_request_with_range() {
        let mut request = HashMap::new();
        request.insert("referenceName".to_string(), "chrM".to_string());
        request.insert("start".to_string(), "149".to_string());
        request.insert("end".to_string(), "200".to_string());

        let mut expected_response_headers = Headers::default();
        expected_response_headers.insert("Range".to_string(), "bytes=0-3493".to_string());

        let request = Request::new(
            "vcf/sample1-bcbio-cancer".to_string(),
            request,
            Default::default(),
        );

        assert_eq!(
            get(get_searcher(), request, Endpoint::Variants, None, None).await,
            Ok(expected_vcf_json_response(expected_response_headers))
        );
    }

    #[tokio::test]
    async fn post_request() {
        let request = Request::new_with_id("bam/htsnexus_test_NA12878".to_string());
        let body = PostRequest {
            format: None,
            class: None,
            fields: None,
            tags: None,
            notags: None,
            regions: None,
            encryption_scheme: None,
        };

        let mut expected_response_headers = Headers::default();
        expected_response_headers.insert("Range".to_string(), "bytes=0-2596798".to_string());

        assert_eq!(
            post(get_searcher(), body, request, Endpoint::Reads, None, None).await,
            Ok(expected_bam_json_response(expected_response_headers))
        );
    }

    #[tokio::test]
    async fn post_variants_request_with_reads_format() {
        let request = Request::new_with_id("bam/htsnexus_test_NA12878".to_string());
        let body = PostRequest {
            format: Some("BAM".to_string()),
            class: None,
            fields: None,
            tags: None,
            notags: None,
            regions: None,
            encryption_scheme: None,
        };

        assert!(matches!(
            post(
                get_searcher(),
                body,
                request,
                Endpoint::Variants,
                None,
                None
            )
            .await,
            Err(HtsGetError::UnsupportedFormat(_))
        ));
    }

    #[tokio::test]
    async fn post_request_with_range() {
        let request = Request::new_with_id("vcf/sample1-bcbio-cancer".to_string());
        let body = PostRequest {
            format: Some("VCF".to_string()),
            class: None,
            fields: None,
            tags: None,
            notags: None,
            regions: Some(vec![Region {
                reference_name: "chrM".to_string(),
                start: Some(149),
                end: Some(200),
            }]),
            encryption_scheme: None,
        };

        let mut expected_response_headers = Headers::default();
        expected_response_headers.insert("Range".to_string(), "bytes=0-3493".to_string());

        assert_eq!(
            post(
                get_searcher(),
                body,
                request,
                Endpoint::Variants,
                None,
                None
            )
            .await,
            Ok(expected_vcf_json_response(expected_response_headers))
        );
    }

    /// Expected single-URL VCF response for `vcf/sample1-bcbio-cancer`.
    fn expected_vcf_json_response(headers: Headers) -> JsonResponse {
        JsonResponse::from(Response::new(
            Vcf,
            vec![
                Url::new("http://127.0.0.1:8081/vcf/sample1-bcbio-cancer.vcf.gz".to_string())
                    .with_headers(headers),
            ],
        ))
    }

    /// Expected single-URL BAM response for `bam/htsnexus_test_NA12878`.
    fn expected_bam_json_response(headers: Headers) -> JsonResponse {
        JsonResponse::from(Response::new(
            Bam,
            vec![
                Url::new("http://127.0.0.1:8081/bam/htsnexus_test_NA12878.bam".to_string())
                    .with_headers(headers),
            ],
        ))
    }

    /// The `data` directory next to the current working directory's parent.
    ///
    /// Assumes tests run with the crate directory as the working directory (cargo's
    /// default); NOTE(review): confirm if the test runner changes.
    fn get_base_path() -> PathBuf {
        std::env::current_dir()
            .unwrap()
            .parent()
            .unwrap()
            .join("data")
    }

    /// A file-backed searcher whose URLs point at `http://127.0.0.1:8081/data`.
    fn get_searcher() -> impl HtsGet + Clone {
        HtsGetFromStorage::new(Storage::new(
            FileStorage::new(
                get_base_path(),
                storage::file::File::new(
                    Scheme::Http,
                    Authority::from_static("127.0.0.1:8081"),
                    "data".to_string(),
                ),
                vec![],
            )
            .unwrap(),
        ))
    }
}