1use cfg_if::cfg_if;
2pub use error::{HtsGetError, Result};
3pub use htsget_config::config::Config;
4use htsget_config::types::Format::{Bam, Bcf, Cram, Vcf};
5use htsget_config::types::{Format, Query, Request, Response};
6pub use http_core::{get, post};
7pub use post_request::{PostRequest, Region};
8use query_builder::QueryBuilder;
9pub use service_info::get_service_info_json;
10pub use service_info::{Htsget, ServiceInfo, Type};
11use std::collections::HashMap;
12use std::fmt::{Display, Formatter};
13use std::str::FromStr;
14use std::{fmt, result};
15
16pub mod error;
17pub mod http_core;
18pub mod middleware;
19pub mod post_request;
20pub mod query_builder;
21pub mod service_info;
22
/// The htsget endpoints served by this crate.
///
/// This is a fieldless enum, so it is cheap to copy. `Copy`, `Clone` and `Hash` are derived
/// so it can be passed by value and used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Endpoint {
  /// The `reads` endpoint (BAM/CRAM in `match_format`).
  Reads,
  /// The `variants` endpoint (VCF/BCF in `match_format`).
  Variants,
}
30
31impl FromStr for Endpoint {
32 type Err = ();
33
34 fn from_str(s: &str) -> result::Result<Self, Self::Err> {
35 match s {
36 "reads" => Ok(Self::Reads),
37 "variants" => Ok(Self::Variants),
38 _ => Err(()),
39 }
40 }
41}
42
43impl Display for Endpoint {
44 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
45 match self {
46 Self::Reads => write!(f, "reads"),
47 Self::Variants => write!(f, "variants"),
48 }
49 }
50}
51
52pub fn match_format_from_query(
54 endpoint: &Endpoint,
55 query: &HashMap<String, String>,
56) -> Result<Format> {
57 match_format(endpoint, query.get("format"))
58}
59
60pub fn match_format(endpoint: &Endpoint, format: Option<impl Into<String>>) -> Result<Format> {
62 let format = format.map(Into::into).map(|format| format.to_lowercase());
63
64 match (endpoint, format) {
65 (Endpoint::Reads, None) => Ok(Bam),
66 (Endpoint::Variants, None) => Ok(Vcf),
67 (Endpoint::Reads, Some(s)) if s == "bam" => Ok(Bam),
68 (Endpoint::Reads, Some(s)) if s == "cram" => Ok(Cram),
69 (Endpoint::Variants, Some(s)) if s == "vcf" => Ok(Vcf),
70 (Endpoint::Variants, Some(s)) if s == "bcf" => Ok(Bcf),
71 (_, Some(format)) => Err(HtsGetError::UnsupportedFormat(format!(
72 "{format} isn't a supported format for this endpoint"
73 ))),
74 }
75}
76
77fn convert_to_query(request: Request, format: Format) -> Result<Query> {
78 let query = request.query().clone();
79
80 set_query_builder(
81 QueryBuilder::new(request, format),
82 query.get("class"),
83 query.get("referenceName"),
84 query.get("fields"),
85 (query.get("tags"), query.get("notags")),
86 (query.get("start"), query.get("end")),
87 query.get("encryptionScheme"),
88 )
89}
90
91fn set_query_builder(
92 builder: QueryBuilder,
93 class: Option<impl Into<String>>,
94 reference_name: Option<impl Into<String>>,
95 fields: Option<impl Into<String>>,
96 (tags, no_tags): (Option<impl Into<String>>, Option<impl Into<String>>),
97 (start, end): (Option<impl Into<String>>, Option<impl Into<String>>),
98 _encryption_scheme: Option<impl Into<String>>,
99) -> Result<Query> {
100 let builder = builder
101 .with_class(class)?
102 .with_fields(fields)
103 .with_tags(tags, no_tags)?
104 .with_reference_name(reference_name)
105 .with_range(start, end)?;
106
107 cfg_if! {
108 if #[cfg(feature = "experimental")] {
109 Ok(builder.with_encryption_scheme(_encryption_scheme)?.build())
110 } else {
111 Ok(builder.build())
112 }
113 }
114}
115
116fn merge_responses(responses: Vec<Response>) -> Option<Response> {
117 responses.into_iter().reduce(|mut acc, mut response| {
118 acc.urls.append(&mut response.urls);
119 acc
120 })
121}
122
/// Unit and integration-style tests for format matching and the GET/POST handlers.
///
/// The handler tests read fixture files from a `data` directory next to the crate directory
/// (see `get_base_path`). They expect byte ranges computed from those fixtures.
#[cfg(test)]
mod tests {
  use std::collections::HashMap;
  use std::path::PathBuf;

  use htsget_config::storage;
  use htsget_config::types::{Headers, JsonResponse, Request, Scheme, Url};
  use htsget_search::FileStorage;
  use htsget_search::HtsGet;
  use htsget_search::Storage;
  use htsget_search::from_storage::HtsGetFromStorage;
  use http::uri::Authority;

  use super::*;

  // A format name that is not recognised at all is rejected.
  #[test]
  fn match_with_invalid_format() {
    assert!(matches!(
      match_format(&Endpoint::Reads, Some("Invalid".to_string())).unwrap_err(),
      HtsGetError::UnsupportedFormat(_)
    ));
  }

  // A real format (`bam`) requested on the wrong endpoint (`Variants`) is rejected.
  #[test]
  fn match_with_invalid_endpoint() {
    assert!(matches!(
      match_format(&Endpoint::Variants, Some("bam".to_string())).unwrap_err(),
      HtsGetError::UnsupportedFormat(_)
    ));
  }

  #[test]
  fn match_with_valid_format() {
    assert!(matches!(
      match_format(&Endpoint::Reads, Some("bam".to_string())).unwrap(),
      Bam,
    ));
  }

  // GET on the reads endpoint with no query parameters should default to BAM.
  // It is expected to return the whole-file byte range of the BAM fixture.
  // NOTE(review): the meaning of the three trailing `None` arguments is defined in
  // `http_core::get`, which is not visible here.
  #[tokio::test]
  async fn get_request() {
    let request = HashMap::new();

    let mut expected_response_headers = Headers::default();
    expected_response_headers.insert("Range".to_string(), "bytes=0-2596798".to_string());

    let request = Request::new(
      "bam/htsnexus_test_NA12878".to_string(),
      request,
      Default::default(),
    );

    assert_eq!(
      get(get_searcher(), request, Endpoint::Reads, None, None, None).await,
      Ok(expected_bam_json_response(expected_response_headers))
    );
  }

  // An uppercase variants format (`VCF`) on the reads endpoint must be rejected.
  // This shows that case-insensitive matching does not bypass the endpoint check.
  #[tokio::test]
  async fn get_reads_request_with_variants_format() {
    let mut request = HashMap::new();
    request.insert("format".to_string(), "VCF".to_string());

    let request = Request::new(
      "bam/htsnexus_test_NA12878".to_string(),
      request,
      Default::default(),
    );

    assert!(matches!(
      get(get_searcher(), request, Endpoint::Reads, None, None, None).await,
      Err(HtsGetError::UnsupportedFormat(_))
    ));
  }

  // GET on the variants endpoint with a reference name and a start/end range.
  // It should fall back to VCF and return the fixture's expected byte range.
  #[tokio::test]
  async fn get_request_with_range() {
    let mut request = HashMap::new();
    request.insert("referenceName".to_string(), "chrM".to_string());
    request.insert("start".to_string(), "149".to_string());
    request.insert("end".to_string(), "200".to_string());

    let mut expected_response_headers = Headers::default();
    expected_response_headers.insert("Range".to_string(), "bytes=0-3493".to_string());

    let request = Request::new(
      "vcf/sample1-bcbio-cancer".to_string(),
      request,
      Default::default(),
    );

    assert_eq!(
      get(
        get_searcher(),
        request,
        Endpoint::Variants,
        None,
        None,
        None
      )
      .await,
      Ok(expected_vcf_json_response(expected_response_headers))
    );
  }

  // POST with an entirely empty body should behave like the default GET on the reads endpoint.
  #[tokio::test]
  async fn post_request() {
    let request = Request::new_with_id("bam/htsnexus_test_NA12878".to_string());
    let body = PostRequest {
      format: None,
      class: None,
      fields: None,
      tags: None,
      notags: None,
      regions: None,
      encryption_scheme: None,
    };

    let mut expected_response_headers = Headers::default();
    expected_response_headers.insert("Range".to_string(), "bytes=0-2596798".to_string());

    assert_eq!(
      post(
        get_searcher(),
        body,
        request,
        Endpoint::Reads,
        None,
        None,
        None
      )
      .await,
      Ok(expected_bam_json_response(expected_response_headers))
    );
  }

  // A reads format (`BAM`) in a POST body to the variants endpoint must be rejected.
  #[tokio::test]
  async fn post_variants_request_with_reads_format() {
    let request = Request::new_with_id("bam/htsnexus_test_NA12878".to_string());
    let body = PostRequest {
      format: Some("BAM".to_string()),
      class: None,
      fields: None,
      tags: None,
      notags: None,
      regions: None,
      encryption_scheme: None,
    };

    assert!(matches!(
      post(
        get_searcher(),
        body,
        request,
        Endpoint::Variants,
        None,
        None,
        None
      )
      .await,
      Err(HtsGetError::UnsupportedFormat(_))
    ));
  }

  // POST equivalent of `get_request_with_range`. The same chrM 149-200 region is supplied as
  // a single `Region`, and the same VCF byte range is expected.
  #[tokio::test]
  async fn post_request_with_range() {
    let request = Request::new_with_id("vcf/sample1-bcbio-cancer".to_string());
    let body = PostRequest {
      format: Some("VCF".to_string()),
      class: None,
      fields: None,
      tags: None,
      notags: None,
      regions: Some(vec![Region {
        reference_name: "chrM".to_string(),
        start: Some(149),
        end: Some(200),
      }]),
      encryption_scheme: None,
    };

    let mut expected_response_headers = Headers::default();
    expected_response_headers.insert("Range".to_string(), "bytes=0-3493".to_string());

    assert_eq!(
      post(
        get_searcher(),
        body,
        request,
        Endpoint::Variants,
        None,
        None,
        None
      )
      .await,
      Ok(expected_vcf_json_response(expected_response_headers))
    );
  }

  /// Expected single-URL VCF response for the `vcf/sample1-bcbio-cancer` fixture.
  fn expected_vcf_json_response(headers: Headers) -> JsonResponse {
    JsonResponse::from(Response::new(
      Vcf,
      vec![
        Url::new("http://127.0.0.1:8081/vcf/sample1-bcbio-cancer.vcf.gz".to_string())
          .with_headers(headers),
      ],
    ))
  }

  /// Expected single-URL BAM response for the `bam/htsnexus_test_NA12878` fixture.
  fn expected_bam_json_response(headers: Headers) -> JsonResponse {
    JsonResponse::from(Response::new(
      Bam,
      vec![
        Url::new("http://127.0.0.1:8081/bam/htsnexus_test_NA12878.bam".to_string())
          .with_headers(headers),
      ],
    ))
  }

  /// Returns `<parent of the current working directory>/data`.
  ///
  /// Panics if the current directory cannot be read or has no parent, which is acceptable in
  /// tests.
  fn get_base_path() -> PathBuf {
    std::env::current_dir()
      .unwrap()
      .parent()
      .unwrap()
      .join("data")
  }

  /// Builds a file-backed searcher rooted at `get_base_path()`. It advertises URLs on
  /// `http://127.0.0.1:8081`, matching the expected responses above.
  ///
  /// NOTE(review): the meanings of the `"data"` string and the empty `vec![]` passed here are
  /// defined by `storage::file::File::new` and `FileStorage::new`; confirm against those APIs.
  fn get_searcher() -> impl HtsGet + Clone {
    HtsGetFromStorage::new(Storage::new(
      FileStorage::new(
        get_base_path(),
        storage::file::File::new(
          Scheme::Http,
          Authority::from_static("127.0.0.1:8081"),
          "data".to_string(),
        ),
        vec![],
      )
      .unwrap(),
    ))
  }
}