1use std::collections::HashSet;
2
3#[cfg(feature = "experimental")]
4use htsget_config::encryption_scheme::EncryptionScheme;
5use htsget_config::types::{Class, Fields, Format, Interval, Query, Request, Tags};
6use tracing::instrument;
7
8use crate::error::{HtsGetError, Result};
9
10#[derive(Debug)]
12pub struct QueryBuilder {
13 query: Query,
14}
15
16impl QueryBuilder {
17 pub fn new(request: Request, format: Format) -> Self {
18 let id = request.path().to_string();
19
20 Self {
21 query: Query::new(id, format, request),
22 }
23 }
24
25 pub fn build(self) -> Query {
26 self.query
27 }
28
29 #[instrument(level = "trace", skip_all, ret)]
30 pub fn with_class(mut self, class: Option<impl Into<String>>) -> Result<Self> {
31 let class = class.map(Into::into);
32
33 self.query = self.query.with_class(match class {
34 None => Class::Body,
35 Some(class) if class == "header" => Class::Header,
36 Some(class) => {
37 return Err(HtsGetError::InvalidInput(format!(
38 "invalid class `{class}`"
39 )));
40 }
41 });
42
43 Ok(self)
44 }
45
46 #[instrument(level = "trace", skip_all, ret)]
47 pub fn with_reference_name(mut self, reference_name: Option<impl Into<String>>) -> Self {
48 if let Some(reference_name) = reference_name {
49 self.query = self.query.with_reference_name(reference_name);
50 }
51 self
52 }
53
54 pub fn with_interval(mut self, interval: Interval) -> Self {
56 self.query = self.query.with_interval(interval);
57 self
58 }
59
60 #[instrument(level = "trace", skip_all, ret)]
61 pub fn with_range(
62 self,
63 start: Option<impl Into<String>>,
64 end: Option<impl Into<String>>,
65 ) -> Result<Self> {
66 let start = start
67 .map(Into::into)
68 .map(|start| {
69 start
70 .parse::<u32>()
71 .map_err(|err| HtsGetError::InvalidInput(format!("`{start}` isn't a valid start: {err}")))
72 })
73 .transpose()?;
74 let end = end
75 .map(Into::into)
76 .map(|end| {
77 end
78 .parse::<u32>()
79 .map_err(|err| HtsGetError::InvalidInput(format!("`{end}` isn't a valid end: {err}")))
80 })
81 .transpose()?;
82
83 self.with_range_from_u32(start, end)
84 }
85
86 pub fn with_range_from_u32(mut self, start: Option<u32>, end: Option<u32>) -> Result<Self> {
87 if let Some(start) = start {
88 self.query = self.query.with_start(start);
89 }
90 if let Some(end) = end {
91 self.query = self.query.with_end(end);
92 }
93
94 if (self.query.interval().start().is_some() || self.query.interval().end().is_some())
95 && self
96 .query
97 .reference_name()
98 .filter(|name| *name != "*")
99 .is_none()
100 {
101 return Err(HtsGetError::InvalidInput(
102 "reference name must be specified with start or end range".to_string(),
103 ));
104 }
105
106 if let (Some(start), Some(end)) = &(self.query.interval().start(), self.query.interval().end())
107 && start > end
108 {
109 return Err(HtsGetError::InvalidRange(format!(
110 "end is greater than start (`{start}` > `{end}`)"
111 )));
112 }
113
114 Ok(self)
115 }
116
117 #[instrument(level = "trace", skip_all, ret)]
118 pub fn with_fields(self, fields: Option<impl Into<String>>) -> Self {
119 self.with_fields_from_vec(
120 fields.map(|fields| fields.into().split(',').map(|s| s.to_string()).collect()),
121 )
122 }
123
124 pub fn with_fields_from_vec(mut self, fields: Option<Vec<impl Into<String>>>) -> Self {
125 if let Some(fields) = fields {
126 self.query = self
127 .query
128 .with_fields(Fields::List(fields.into_iter().map(Into::into).collect()));
129 }
130
131 self
132 }
133
134 #[instrument(level = "trace", skip_all, ret)]
135 pub fn with_tags(
136 self,
137 tags: Option<impl Into<String>>,
138 notags: Option<impl Into<String>>,
139 ) -> Result<Self> {
140 self.with_tags_from_vec(
141 tags.map(|tags| tags.into().split(',').map(|s| s.to_string()).collect()),
142 notags.map(|notags| notags.into().split(',').map(|s| s.to_string()).collect()),
143 )
144 }
145
146 pub fn with_tags_from_vec(
147 mut self,
148 tags: Option<Vec<impl Into<String>>>,
149 notags: Option<Vec<impl Into<String>>>,
150 ) -> Result<Self> {
151 let notags = match notags {
152 Some(notags) => notags.into_iter().map(Into::into).collect(),
153 None => vec![],
154 };
155
156 if let Some(tags) = tags {
157 let tags: HashSet<String> = tags.into_iter().map(Into::into).collect();
158 if tags.iter().any(|tag| notags.contains(tag)) {
159 return Err(HtsGetError::InvalidInput(
160 "tags and notags can't intersect".to_string(),
161 ));
162 }
163 self.query = self.query.with_tags(Tags::List(tags));
164 };
165
166 if !notags.is_empty() {
167 self.query = self.query.with_no_tags(notags);
168 }
169
170 Ok(self)
171 }
172
173 #[cfg(feature = "experimental")]
175 pub fn with_encryption_scheme(
176 mut self,
177 encryption_scheme: Option<impl Into<String>>,
178 ) -> Result<Self> {
179 if let Some(scheme) = encryption_scheme {
180 let scheme = match scheme.into().to_lowercase().as_str() {
181 "c4gh" => Ok(EncryptionScheme::C4GH),
182 scheme => Err(HtsGetError::UnsupportedFormat(format!(
183 "invalid encryption scheme `{scheme}`"
184 ))),
185 }?;
186
187 self.query = self.query.with_encryption_scheme(scheme);
188 }
189
190 Ok(self)
191 }
192}
193
#[cfg(test)]
mod tests {
  use htsget_config::types::Format::{Bam, Vcf};
  use htsget_config::types::NoTags;

  use super::*;

  #[test]
  fn query_with_id() {
    let request = Request::new_with_id("ValidId".to_string());
    assert_eq!(
      QueryBuilder::new(request, Bam).build().id(),
      "ValidId".to_string()
    );
  }

  #[test]
  fn query_with_format() {
    let request = Request::new_with_id("ValidId".to_string());
    assert_eq!(QueryBuilder::new(request, Vcf).build().format(), Vcf);
  }

  #[test]
  fn query_with_class() {
    let request = Request::new_with_id("ValidId".to_string());

    assert_eq!(
      QueryBuilder::new(request, Bam)
        .with_class(Some("header"))
        .unwrap()
        .build()
        .class(),
      Class::Header
    );
  }

  #[test]
  fn query_without_class_defaults_to_body() {
    let request = Request::new_with_id("ValidId".to_string());

    assert_eq!(
      QueryBuilder::new(request, Bam)
        .with_class(None::<&str>)
        .unwrap()
        .build()
        .class(),
      Class::Body
    );
  }

  #[test]
  fn query_with_invalid_class() {
    let request = Request::new_with_id("ValidId".to_string());

    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_class(Some("body"))
        .unwrap_err(),
      HtsGetError::InvalidInput(_)
    ));
  }

  #[test]
  fn query_with_reference_name() {
    let request = Request::new_with_id("ValidId".to_string());

    assert_eq!(
      QueryBuilder::new(request, Bam)
        .with_reference_name(Some("ValidName"))
        .build()
        .reference_name(),
      Some("ValidName")
    );
  }

  #[test]
  fn query_with_range() {
    let request = Request::new_with_id("ValidId".to_string());

    let query = QueryBuilder::new(request, Bam)
      .with_reference_name(Some("ValidName"))
      .with_range(Some("3"), Some("5"))
      .unwrap()
      .build();
    assert_eq!(
      (query.interval().start(), query.interval().end()),
      (Some(3), Some(5))
    );
  }

  #[test]
  fn query_with_equal_start_and_end() {
    let request = Request::new_with_id("ValidId".to_string());

    let query = QueryBuilder::new(request, Bam)
      .with_reference_name(Some("ValidName"))
      .with_range(Some("4"), Some("4"))
      .unwrap()
      .build();
    assert_eq!(
      (query.interval().start(), query.interval().end()),
      (Some(4), Some(4))
    );
  }

  #[test]
  fn query_with_range_but_without_reference_name() {
    let request = Request::new_with_id("ValidId".to_string());

    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_range(Some("3"), Some("5"))
        .unwrap_err(),
      HtsGetError::InvalidInput(_)
    ));
  }

  #[test]
  fn query_with_range_and_wildcard_reference_name() {
    let request = Request::new_with_id("ValidId".to_string());

    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_reference_name(Some("*"))
        .with_range(Some("3"), Some("5"))
        .unwrap_err(),
      HtsGetError::InvalidInput(_)
    ));
  }

  #[test]
  fn query_with_invalid_start() {
    let request = Request::new_with_id("ValidId".to_string());

    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_reference_name(Some("ValidName"))
        .with_range(Some("a"), Some("5"))
        .unwrap_err(),
      HtsGetError::InvalidInput(_)
    ));
  }

  #[test]
  fn query_with_invalid_end() {
    let request = Request::new_with_id("ValidId".to_string());

    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_reference_name(Some("ValidName"))
        .with_range(Some("5"), Some("a"))
        .unwrap_err(),
      HtsGetError::InvalidInput(_)
    ));
  }

  #[test]
  fn query_with_invalid_range() {
    let request = Request::new_with_id("ValidId".to_string());

    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_reference_name(Some("ValidName"))
        .with_range(Some("5"), Some("3"))
        .unwrap_err(),
      HtsGetError::InvalidRange(_)
    ));
  }

  #[test]
  fn query_with_fields() {
    let request = Request::new_with_id("ValidId".to_string());

    assert_eq!(
      QueryBuilder::new(request, Bam)
        .with_fields(Some("header,part1,part2"))
        .build()
        .fields(),
      &Fields::List(HashSet::from_iter(vec![
        "header".to_string(),
        "part1".to_string(),
        "part2".to_string()
      ]))
    );
  }

  #[test]
  fn query_with_tags() {
    let request = Request::new_with_id("ValidId".to_string());

    let query = QueryBuilder::new(request, Bam)
      .with_tags(Some("header,part1,part2"), Some("part3"))
      .unwrap()
      .build();
    assert_eq!(
      query.tags(),
      &Tags::List(HashSet::from_iter(vec![
        "header".to_string(),
        "part1".to_string(),
        "part2".to_string()
      ]))
    );
    assert_eq!(
      query.no_tags(),
      &NoTags(Some(HashSet::from_iter(vec!["part3".to_string()])))
    );
  }

  // Tags and notags that share an entry must be rejected rather than silently accepted.
  #[test]
  fn query_with_invalid_tags() {
    let request = Request::new_with_id("ValidId".to_string());

    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_tags(Some("header,part1,part2"), Some("part2,part3"))
        .unwrap_err(),
      HtsGetError::InvalidInput(_)
    ));
  }
}