1use cfg_if::cfg_if;
2pub use error::{HtsGetError, Result};
3pub use htsget_config::config::Config;
4use htsget_config::types::Format::{Bam, Bcf, Cram, Vcf};
5use htsget_config::types::{Format, Query, Request, Response};
6pub use http_core::{get, post};
7pub use post_request::{PostRequest, Region};
8use query_builder::QueryBuilder;
9pub use service_info::get_service_info_json;
10pub use service_info::{Htsget, ServiceInfo, Type};
11use std::collections::HashMap;
12use std::result;
13use std::str::FromStr;
14
15pub mod error;
16pub mod http_core;
17pub mod middleware;
18pub mod post_request;
19pub mod query_builder;
20pub mod service_info;
21
/// The htsget endpoint a request targets.
///
/// The reads endpoint serves BAM/CRAM and the variants endpoint serves VCF/BCF
/// (see `match_format`). The type is a plain fieldless enum, so it is `Copy` and can be
/// passed by value or used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Endpoint {
    /// The `reads` endpoint.
    Reads,
    /// The `variants` endpoint.
    Variants,
}

impl FromStr for Endpoint {
    type Err = ();

    /// Parses an endpoint name.
    ///
    /// Matching is exact and case-sensitive: only `"reads"` and `"variants"` are accepted.
    /// Any other input yields `Err(())`.
    fn from_str(s: &str) -> result::Result<Self, Self::Err> {
        match s {
            "reads" => Ok(Self::Reads),
            "variants" => Ok(Self::Variants),
            _ => Err(()),
        }
    }
}
41
42pub fn match_format_from_query(
44 endpoint: &Endpoint,
45 query: &HashMap<String, String>,
46) -> Result<Format> {
47 match_format(endpoint, query.get("format"))
48}
49
50pub fn match_format(endpoint: &Endpoint, format: Option<impl Into<String>>) -> Result<Format> {
52 let format = format.map(Into::into).map(|format| format.to_lowercase());
53
54 match (endpoint, format) {
55 (Endpoint::Reads, None) => Ok(Bam),
56 (Endpoint::Variants, None) => Ok(Vcf),
57 (Endpoint::Reads, Some(s)) if s == "bam" => Ok(Bam),
58 (Endpoint::Reads, Some(s)) if s == "cram" => Ok(Cram),
59 (Endpoint::Variants, Some(s)) if s == "vcf" => Ok(Vcf),
60 (Endpoint::Variants, Some(s)) if s == "bcf" => Ok(Bcf),
61 (_, Some(format)) => Err(HtsGetError::UnsupportedFormat(format!(
62 "{format} isn't a supported format for this endpoint"
63 ))),
64 }
65}
66
67fn convert_to_query(request: Request, format: Format) -> Result<Query> {
68 let query = request.query().clone();
69
70 let builder = QueryBuilder::new(request, format)
71 .with_class(query.get("class"))?
72 .with_reference_name(query.get("referenceName"))
73 .with_range(query.get("start"), query.get("end"))?
74 .with_fields(query.get("fields"))
75 .with_tags(query.get("tags"), query.get("notags"))?;
76
77 cfg_if! {
78 if #[cfg(feature = "experimental")] {
79 Ok(builder.with_encryption_scheme(query.get("encryptionScheme"))?.build())
80 } else {
81 Ok(builder.build())
82 }
83 }
84}
85
86fn merge_responses(responses: Vec<Response>) -> Option<Response> {
87 responses.into_iter().reduce(|mut acc, mut response| {
88 acc.urls.append(&mut response.urls);
89 acc
90 })
91}
92
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::path::PathBuf;

    use htsget_config::storage;
    use htsget_config::types::{Headers, JsonResponse, Request, Scheme, Url};
    use htsget_search::FileStorage;
    use htsget_search::HtsGet;
    use htsget_search::Storage;
    use htsget_search::from_storage::HtsGetFromStorage;
    use http::uri::Authority;

    use super::*;

    #[test]
    fn match_with_invalid_format() {
        assert!(matches!(
            match_format(&Endpoint::Reads, Some("Invalid".to_string())).unwrap_err(),
            HtsGetError::UnsupportedFormat(_)
        ));
    }

    #[test]
    fn match_with_invalid_endpoint() {
        assert!(matches!(
            match_format(&Endpoint::Variants, Some("bam".to_string())).unwrap_err(),
            HtsGetError::UnsupportedFormat(_)
        ));
    }

    #[test]
    fn match_with_valid_format() {
        assert!(matches!(
            match_format(&Endpoint::Reads, Some("bam".to_string())).unwrap(),
            Bam,
        ));
    }

    /// With no explicit format, each endpoint falls back to its default format.
    #[test]
    fn match_without_format_uses_endpoint_default() {
        assert!(matches!(
            match_format(&Endpoint::Reads, None::<String>).unwrap(),
            Format::Bam
        ));
        assert!(matches!(
            match_format(&Endpoint::Variants, None::<String>).unwrap(),
            Format::Vcf
        ));
    }

    /// Format names are matched case-insensitively.
    #[test]
    fn match_format_ignores_case() {
        assert!(matches!(
            match_format(&Endpoint::Reads, Some("CRAM")).unwrap(),
            Format::Cram
        ));
        assert!(matches!(
            match_format(&Endpoint::Variants, Some("Bcf")).unwrap(),
            Format::Bcf
        ));
    }

    #[test]
    fn match_format_from_query_uses_format_key() {
        let mut query = HashMap::new();
        query.insert("format".to_string(), "CRAM".to_string());

        assert!(matches!(
            match_format_from_query(&Endpoint::Reads, &query).unwrap(),
            Format::Cram
        ));
    }

    #[test]
    fn parse_endpoint() {
        assert_eq!("reads".parse::<Endpoint>(), Ok(Endpoint::Reads));
        assert_eq!("variants".parse::<Endpoint>(), Ok(Endpoint::Variants));
        assert_eq!("Reads".parse::<Endpoint>(), Err(()));
    }

    #[tokio::test]
    async fn get_request() {
        let query = HashMap::new();

        let mut expected_response_headers = Headers::default();
        expected_response_headers.insert("Range".to_string(), "bytes=0-2596798".to_string());

        let request = Request::new(
            "bam/htsnexus_test_NA12878".to_string(),
            query,
            Default::default(),
        );

        assert_eq!(
            get(get_searcher(), request, Endpoint::Reads).await,
            Ok(expected_bam_json_response(expected_response_headers))
        );
    }

    #[tokio::test]
    async fn get_reads_request_with_variants_format() {
        let mut query = HashMap::new();
        query.insert("format".to_string(), "VCF".to_string());

        let request = Request::new(
            "bam/htsnexus_test_NA12878".to_string(),
            query,
            Default::default(),
        );

        assert!(matches!(
            get(get_searcher(), request, Endpoint::Reads).await,
            Err(HtsGetError::UnsupportedFormat(_))
        ));
    }

    #[tokio::test]
    async fn get_request_with_range() {
        let mut query = HashMap::new();
        query.insert("referenceName".to_string(), "chrM".to_string());
        query.insert("start".to_string(), "149".to_string());
        query.insert("end".to_string(), "200".to_string());

        let mut expected_response_headers = Headers::default();
        expected_response_headers.insert("Range".to_string(), "bytes=0-3493".to_string());

        let request = Request::new(
            "vcf/sample1-bcbio-cancer".to_string(),
            query,
            Default::default(),
        );

        assert_eq!(
            get(get_searcher(), request, Endpoint::Variants).await,
            Ok(expected_vcf_json_response(expected_response_headers))
        );
    }

    #[tokio::test]
    async fn post_request() {
        let request = Request::new_with_id("bam/htsnexus_test_NA12878".to_string());
        let body = PostRequest {
            format: None,
            class: None,
            fields: None,
            tags: None,
            notags: None,
            regions: None,
        };

        let mut expected_response_headers = Headers::default();
        expected_response_headers.insert("Range".to_string(), "bytes=0-2596798".to_string());

        assert_eq!(
            post(get_searcher(), body, request, Endpoint::Reads).await,
            Ok(expected_bam_json_response(expected_response_headers))
        );
    }

    #[tokio::test]
    async fn post_variants_request_with_reads_format() {
        let request = Request::new_with_id("bam/htsnexus_test_NA12878".to_string());
        let body = PostRequest {
            format: Some("BAM".to_string()),
            class: None,
            fields: None,
            tags: None,
            notags: None,
            regions: None,
        };

        assert!(matches!(
            post(get_searcher(), body, request, Endpoint::Variants).await,
            Err(HtsGetError::UnsupportedFormat(_))
        ));
    }

    #[tokio::test]
    async fn post_request_with_range() {
        let request = Request::new_with_id("vcf/sample1-bcbio-cancer".to_string());
        let body = PostRequest {
            format: Some("VCF".to_string()),
            class: None,
            fields: None,
            tags: None,
            notags: None,
            regions: Some(vec![Region {
                reference_name: "chrM".to_string(),
                start: Some(149),
                end: Some(200),
            }]),
        };

        let mut expected_response_headers = Headers::default();
        expected_response_headers.insert("Range".to_string(), "bytes=0-3493".to_string());

        assert_eq!(
            post(get_searcher(), body, request, Endpoint::Variants).await,
            Ok(expected_vcf_json_response(expected_response_headers))
        );
    }

    /// Expected VCF response for `vcf/sample1-bcbio-cancer`, carrying `headers` on its single URL.
    fn expected_vcf_json_response(headers: Headers) -> JsonResponse {
        JsonResponse::from(Response::new(
            Vcf,
            vec![
                Url::new("http://127.0.0.1:8081/vcf/sample1-bcbio-cancer.vcf.gz".to_string())
                    .with_headers(headers),
            ],
        ))
    }

    /// Expected BAM response for `bam/htsnexus_test_NA12878`, carrying `headers` on its single URL.
    fn expected_bam_json_response(headers: Headers) -> JsonResponse {
        JsonResponse::from(Response::new(
            Bam,
            vec![
                Url::new("http://127.0.0.1:8081/bam/htsnexus_test_NA12878.bam".to_string())
                    .with_headers(headers),
            ],
        ))
    }

    /// Path to the `data` directory next to this crate's directory.
    ///
    /// Resolved from the crate manifest directory rather than the process working directory,
    /// so the tests find their fixtures regardless of where the test binary is launched from.
    fn get_base_path() -> PathBuf {
        std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .parent()
            .expect("crate manifest directory should have a parent directory")
            .join("data")
    }

    /// A file-backed searcher over the test data, issuing URLs under `http://127.0.0.1:8081/data`.
    fn get_searcher() -> impl HtsGet + Clone {
        HtsGetFromStorage::new(Storage::new(
            FileStorage::new(
                get_base_path(),
                storage::file::File::new(
                    Scheme::Http,
                    Authority::from_static("127.0.0.1:8081"),
                    "data".to_string(),
                ),
                vec![],
            )
            .unwrap(),
        ))
    }
}