1use cfg_if::cfg_if;
2pub use error::{HtsGetError, Result};
3pub use htsget_config::config::Config;
4use htsget_config::types::Format::{Bam, Bcf, Cram, Vcf};
5use htsget_config::types::{Format, Query, Request, Response};
6pub use http_core::{get, post};
7pub use post_request::{PostRequest, Region};
8use query_builder::QueryBuilder;
9pub use service_info::get_service_info_json;
10pub use service_info::{Htsget, ServiceInfo, Type};
11use std::result;
12use std::str::FromStr;
13
14mod error;
15mod http_core;
16mod post_request;
17mod query_builder;
18mod service_info;
19
/// An htsget endpoint that a request can target.
///
/// Parsed from the exact (case-sensitive) path segments `"reads"` and `"variants"`
/// via [`FromStr`]. The endpoint decides which formats [`match_format`] accepts.
///
/// The enum has no fields, so it is `Copy` and `Hash` and can be passed by value or
/// used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Endpoint {
  /// The `reads` endpoint (alignment data).
  Reads,
  /// The `variants` endpoint (variant data).
  Variants,
}

impl FromStr for Endpoint {
  type Err = ();

  /// Parses `"reads"` or `"variants"`.
  ///
  /// # Errors
  ///
  /// Returns `Err(())` for any other string. The comparison is case-sensitive and
  /// does no trimming.
  fn from_str(s: &str) -> result::Result<Self, Self::Err> {
    match s {
      "reads" => Ok(Self::Reads),
      "variants" => Ok(Self::Variants),
      _ => Err(()),
    }
  }
}
39
40pub fn match_format(endpoint: &Endpoint, format: Option<impl Into<String>>) -> Result<Format> {
42 let format = format.map(Into::into).map(|format| format.to_lowercase());
43
44 match (endpoint, format) {
45 (Endpoint::Reads, None) => Ok(Bam),
46 (Endpoint::Variants, None) => Ok(Vcf),
47 (Endpoint::Reads, Some(s)) if s == "bam" => Ok(Bam),
48 (Endpoint::Reads, Some(s)) if s == "cram" => Ok(Cram),
49 (Endpoint::Variants, Some(s)) if s == "vcf" => Ok(Vcf),
50 (Endpoint::Variants, Some(s)) if s == "bcf" => Ok(Bcf),
51 (_, Some(format)) => Err(HtsGetError::UnsupportedFormat(format!(
52 "{format} isn't a supported format for this endpoint"
53 ))),
54 }
55}
56
57fn convert_to_query(request: Request, format: Format) -> Result<Query> {
58 let query = request.query().clone();
59
60 let builder = QueryBuilder::new(request, format)
61 .with_class(query.get("class"))?
62 .with_reference_name(query.get("referenceName"))
63 .with_range(query.get("start"), query.get("end"))?
64 .with_fields(query.get("fields"))
65 .with_tags(query.get("tags"), query.get("notags"))?;
66
67 cfg_if! {
68 if #[cfg(feature = "experimental")] {
69 Ok(builder.with_encryption_scheme(query.get("encryptionScheme"))?.build())
70 } else {
71 Ok(builder.build())
72 }
73 }
74}
75
76fn merge_responses(responses: Vec<Response>) -> Option<Response> {
77 responses.into_iter().reduce(|mut acc, mut response| {
78 acc.urls.append(&mut response.urls);
79 acc
80 })
81}
82
#[cfg(test)]
mod tests {
  use std::collections::HashMap;
  use std::path::PathBuf;

  use htsget_config::storage;
  use htsget_config::types::{Headers, JsonResponse, Request, Scheme, Url};
  use htsget_search::from_storage::HtsGetFromStorage;
  use htsget_search::FileStorage;
  use htsget_search::HtsGet;
  use htsget_search::Storage;
  use http::uri::Authority;

  use super::*;

  #[test]
  fn endpoint_from_str() {
    assert_eq!("reads".parse::<Endpoint>(), Ok(Endpoint::Reads));
    assert_eq!("variants".parse::<Endpoint>(), Ok(Endpoint::Variants));
    assert_eq!("unknown".parse::<Endpoint>(), Err(()));
  }

  #[test]
  fn match_with_invalid_format() {
    assert!(matches!(
      match_format(&Endpoint::Reads, Some("Invalid".to_string())).unwrap_err(),
      HtsGetError::UnsupportedFormat(_)
    ));
  }

  /// A valid format that belongs to the other endpoint must be rejected.
  #[test]
  fn match_with_invalid_endpoint() {
    assert!(matches!(
      match_format(&Endpoint::Variants, Some("bam".to_string())).unwrap_err(),
      HtsGetError::UnsupportedFormat(_)
    ));
  }

  #[test]
  fn match_with_valid_format() {
    assert!(matches!(
      match_format(&Endpoint::Reads, Some("bam".to_string())).unwrap(),
      Bam,
    ));
  }

  #[test]
  fn match_without_format_uses_endpoint_default() {
    assert!(matches!(
      match_format(&Endpoint::Reads, None::<String>).unwrap(),
      Format::Bam
    ));
    assert!(matches!(
      match_format(&Endpoint::Variants, None::<String>).unwrap(),
      Format::Vcf
    ));
  }

  #[test]
  fn match_format_ignores_case() {
    assert!(matches!(
      match_format(&Endpoint::Reads, Some("CRAM")).unwrap(),
      Format::Cram
    ));
    assert!(matches!(
      match_format(&Endpoint::Variants, Some("Bcf")).unwrap(),
      Format::Bcf
    ));
  }

  #[tokio::test]
  async fn get_request() {
    let query = HashMap::new();

    let mut expected_response_headers = Headers::default();
    expected_response_headers.insert("Range".to_string(), "bytes=0-2596798".to_string());

    let request = Request::new(
      "bam/htsnexus_test_NA12878".to_string(),
      query,
      Default::default(),
    );

    assert_eq!(
      get(get_searcher(), request, Endpoint::Reads).await,
      Ok(expected_bam_json_response(expected_response_headers))
    );
  }

  #[tokio::test]
  async fn get_reads_request_with_variants_format() {
    let mut query = HashMap::new();
    query.insert("format".to_string(), "VCF".to_string());

    let request = Request::new(
      "bam/htsnexus_test_NA12878".to_string(),
      query,
      Default::default(),
    );

    assert!(matches!(
      get(get_searcher(), request, Endpoint::Reads).await,
      Err(HtsGetError::UnsupportedFormat(_))
    ));
  }

  #[tokio::test]
  async fn get_request_with_range() {
    let mut query = HashMap::new();
    query.insert("referenceName".to_string(), "chrM".to_string());
    query.insert("start".to_string(), "149".to_string());
    query.insert("end".to_string(), "200".to_string());

    let mut expected_response_headers = Headers::default();
    expected_response_headers.insert("Range".to_string(), "bytes=0-3493".to_string());

    let request = Request::new(
      "vcf/sample1-bcbio-cancer".to_string(),
      query,
      Default::default(),
    );

    assert_eq!(
      get(get_searcher(), request, Endpoint::Variants).await,
      Ok(expected_vcf_json_response(expected_response_headers))
    );
  }

  #[tokio::test]
  async fn post_request() {
    let request = Request::new_with_id("bam/htsnexus_test_NA12878".to_string());
    let body = PostRequest {
      format: None,
      class: None,
      fields: None,
      tags: None,
      notags: None,
      regions: None,
    };

    let mut expected_response_headers = Headers::default();
    expected_response_headers.insert("Range".to_string(), "bytes=0-2596798".to_string());

    assert_eq!(
      post(get_searcher(), body, request, Endpoint::Reads).await,
      Ok(expected_bam_json_response(expected_response_headers))
    );
  }

  #[tokio::test]
  async fn post_variants_request_with_reads_format() {
    let request = Request::new_with_id("bam/htsnexus_test_NA12878".to_string());
    let body = PostRequest {
      format: Some("BAM".to_string()),
      class: None,
      fields: None,
      tags: None,
      notags: None,
      regions: None,
    };

    assert!(matches!(
      post(get_searcher(), body, request, Endpoint::Variants).await,
      Err(HtsGetError::UnsupportedFormat(_))
    ));
  }

  #[tokio::test]
  async fn post_request_with_range() {
    let request = Request::new_with_id("vcf/sample1-bcbio-cancer".to_string());
    let body = PostRequest {
      format: Some("VCF".to_string()),
      class: None,
      fields: None,
      tags: None,
      notags: None,
      regions: Some(vec![Region {
        reference_name: "chrM".to_string(),
        start: Some(149),
        end: Some(200),
      }]),
    };

    let mut expected_response_headers = Headers::default();
    expected_response_headers.insert("Range".to_string(), "bytes=0-3493".to_string());

    assert_eq!(
      post(get_searcher(), body, request, Endpoint::Variants).await,
      Ok(expected_vcf_json_response(expected_response_headers))
    );
  }

  /// Expected JSON response for the `vcf/sample1-bcbio-cancer` fixture, with `headers`.
  fn expected_vcf_json_response(headers: Headers) -> JsonResponse {
    JsonResponse::from(Response::new(
      Vcf,
      vec![
        Url::new("http://127.0.0.1:8081/vcf/sample1-bcbio-cancer.vcf.gz".to_string())
          .with_headers(headers),
      ],
    ))
  }

  /// Expected JSON response for the `bam/htsnexus_test_NA12878` fixture, with `headers`.
  fn expected_bam_json_response(headers: Headers) -> JsonResponse {
    JsonResponse::from(Response::new(
      Bam,
      vec![
        Url::new("http://127.0.0.1:8081/bam/htsnexus_test_NA12878.bam".to_string())
          .with_headers(headers),
      ],
    ))
  }

  /// Returns the `data` directory that sits next to this crate's directory.
  ///
  /// It is resolved from the crate's manifest directory (fixed at compile time)
  /// rather than the process working directory. The fixtures are therefore found
  /// even when the test binary is run from somewhere other than the crate root.
  fn get_base_path() -> PathBuf {
    PathBuf::from(env!("CARGO_MANIFEST_DIR"))
      .parent()
      .expect("crate manifest directory should have a parent directory")
      .join("data")
  }

  /// Builds a file-storage searcher over the test fixtures, serving URLs from
  /// `http://127.0.0.1:8081/data`.
  fn get_searcher() -> impl HtsGet + Clone {
    HtsGetFromStorage::new(Storage::new(
      FileStorage::new(
        get_base_path(),
        storage::file::File::new(
          Scheme::Http,
          Authority::from_static("127.0.0.1:8081"),
          "data".to_string(),
        ),
      )
      .expect("test data directory should be usable as file storage"),
    ))
  }
}