1use std::collections::HashSet;
2
3#[cfg(feature = "experimental")]
4use htsget_config::encryption_scheme::EncryptionScheme;
5use htsget_config::types::{Class, Fields, Format, Query, Request, Tags};
6use tracing::instrument;
7
8use crate::error::{HtsGetError, Result};
9
10#[derive(Debug)]
12pub struct QueryBuilder {
13 query: Query,
14}
15
16impl QueryBuilder {
17 pub fn new(request: Request, format: Format) -> Self {
18 let id = request.path().to_string();
19
20 Self {
21 query: Query::new(id, format, request),
22 }
23 }
24
25 pub fn build(self) -> Query {
26 self.query
27 }
28
29 #[instrument(level = "trace", skip_all, ret)]
30 pub fn with_class(mut self, class: Option<impl Into<String>>) -> Result<Self> {
31 let class = class.map(Into::into);
32
33 self.query = self.query.with_class(match class {
34 None => Class::Body,
35 Some(class) if class == "header" => Class::Header,
36 Some(class) => {
37 return Err(HtsGetError::InvalidInput(format!(
38 "invalid class `{class}`"
39 )));
40 }
41 });
42
43 Ok(self)
44 }
45
46 #[instrument(level = "trace", skip_all, ret)]
47 pub fn with_reference_name(mut self, reference_name: Option<impl Into<String>>) -> Self {
48 if let Some(reference_name) = reference_name {
49 self.query = self.query.with_reference_name(reference_name);
50 }
51 self
52 }
53
54 #[instrument(level = "trace", skip_all, ret)]
55 pub fn with_range(
56 self,
57 start: Option<impl Into<String>>,
58 end: Option<impl Into<String>>,
59 ) -> Result<Self> {
60 let start = start
61 .map(Into::into)
62 .map(|start| {
63 start
64 .parse::<u32>()
65 .map_err(|err| HtsGetError::InvalidInput(format!("`{start}` isn't a valid start: {err}")))
66 })
67 .transpose()?;
68 let end = end
69 .map(Into::into)
70 .map(|end| {
71 end
72 .parse::<u32>()
73 .map_err(|err| HtsGetError::InvalidInput(format!("`{end}` isn't a valid end: {err}")))
74 })
75 .transpose()?;
76
77 self.with_range_from_u32(start, end)
78 }
79
80 pub fn with_range_from_u32(mut self, start: Option<u32>, end: Option<u32>) -> Result<Self> {
81 if let Some(start) = start {
82 self.query = self.query.with_start(start);
83 }
84 if let Some(end) = end {
85 self.query = self.query.with_end(end);
86 }
87
88 if (self.query.interval().start().is_some() || self.query.interval().end().is_some())
89 && self
90 .query
91 .reference_name()
92 .filter(|name| *name != "*")
93 .is_none()
94 {
95 return Err(HtsGetError::InvalidInput(
96 "reference name must be specified with start or end range".to_string(),
97 ));
98 }
99
100 if let (Some(start), Some(end)) = &(self.query.interval().start(), self.query.interval().end())
101 {
102 if start > end {
103 return Err(HtsGetError::InvalidRange(format!(
104 "end is greater than start (`{start}` > `{end}`)"
105 )));
106 }
107 }
108
109 Ok(self)
110 }
111
112 #[instrument(level = "trace", skip_all, ret)]
113 pub fn with_fields(self, fields: Option<impl Into<String>>) -> Self {
114 self.with_fields_from_vec(
115 fields.map(|fields| fields.into().split(',').map(|s| s.to_string()).collect()),
116 )
117 }
118
119 pub fn with_fields_from_vec(mut self, fields: Option<Vec<impl Into<String>>>) -> Self {
120 if let Some(fields) = fields {
121 self.query = self
122 .query
123 .with_fields(Fields::List(fields.into_iter().map(Into::into).collect()));
124 }
125
126 self
127 }
128
129 #[instrument(level = "trace", skip_all, ret)]
130 pub fn with_tags(
131 self,
132 tags: Option<impl Into<String>>,
133 notags: Option<impl Into<String>>,
134 ) -> Result<Self> {
135 self.with_tags_from_vec(
136 tags.map(|tags| tags.into().split(',').map(|s| s.to_string()).collect()),
137 notags.map(|notags| notags.into().split(',').map(|s| s.to_string()).collect()),
138 )
139 }
140
141 pub fn with_tags_from_vec(
142 mut self,
143 tags: Option<Vec<impl Into<String>>>,
144 notags: Option<Vec<impl Into<String>>>,
145 ) -> Result<Self> {
146 let notags = match notags {
147 Some(notags) => notags.into_iter().map(Into::into).collect(),
148 None => vec![],
149 };
150
151 if let Some(tags) = tags {
152 let tags: HashSet<String> = tags.into_iter().map(Into::into).collect();
153 if tags.iter().any(|tag| notags.contains(tag)) {
154 return Err(HtsGetError::InvalidInput(
155 "tags and notags can't intersect".to_string(),
156 ));
157 }
158 self.query = self.query.with_tags(Tags::List(tags));
159 };
160
161 if !notags.is_empty() {
162 self.query = self.query.with_no_tags(notags);
163 }
164
165 Ok(self)
166 }
167
168 #[cfg(feature = "experimental")]
170 pub fn with_encryption_scheme(
171 mut self,
172 encryption_scheme: Option<impl Into<String>>,
173 ) -> Result<Self> {
174 if let Some(scheme) = encryption_scheme {
175 let scheme = match scheme.into().to_lowercase().as_str() {
176 "c4gh" => Ok(EncryptionScheme::C4GH),
177 scheme => Err(HtsGetError::UnsupportedFormat(format!(
178 "invalid encryption scheme `{scheme}`"
179 ))),
180 }?;
181
182 self.query = self.query.with_encryption_scheme(scheme);
183 }
184
185 Ok(self)
186 }
187}
188
#[cfg(test)]
mod tests {
  use htsget_config::types::Format::{Bam, Vcf};
  use htsget_config::types::NoTags;

  use super::*;

  #[test]
  fn query_with_id() {
    let request = Request::new_with_id("ValidId".to_string());
    assert_eq!(
      QueryBuilder::new(request, Bam).build().id(),
      "ValidId".to_string()
    );
  }

  #[test]
  fn query_with_format() {
    let request = Request::new_with_id("ValidId".to_string());
    assert_eq!(QueryBuilder::new(request, Vcf).build().format(), Vcf);
  }

  #[test]
  fn query_with_class() {
    let request = Request::new_with_id("ValidId".to_string());

    assert_eq!(
      QueryBuilder::new(request, Bam)
        .with_class(Some("header"))
        .unwrap()
        .build()
        .class(),
      Class::Header
    );
  }

  #[test]
  fn query_with_default_class() {
    let request = Request::new_with_id("ValidId".to_string());

    assert_eq!(
      QueryBuilder::new(request, Bam)
        .with_class(None::<&str>)
        .unwrap()
        .build()
        .class(),
      Class::Body
    );
  }

  #[test]
  fn query_with_invalid_class() {
    let request = Request::new_with_id("ValidId".to_string());

    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_class(Some("body"))
        .unwrap_err(),
      HtsGetError::InvalidInput(_)
    ));
  }

  #[test]
  fn query_with_reference_name() {
    let request = Request::new_with_id("ValidId".to_string());

    assert_eq!(
      QueryBuilder::new(request, Bam)
        .with_reference_name(Some("ValidName"))
        .build()
        .reference_name(),
      Some("ValidName")
    );
  }

  #[test]
  fn query_with_range() {
    let request = Request::new_with_id("ValidId".to_string());

    let query = QueryBuilder::new(request, Bam)
      .with_reference_name(Some("ValidName"))
      .with_range(Some("3"), Some("5"))
      .unwrap()
      .build();
    assert_eq!(
      (query.interval().start(), query.interval().end()),
      (Some(3), Some(5))
    );
  }

  #[test]
  fn query_with_equal_start_and_end() {
    let request = Request::new_with_id("ValidId".to_string());

    let query = QueryBuilder::new(request, Bam)
      .with_reference_name(Some("ValidName"))
      .with_range(Some("3"), Some("3"))
      .unwrap()
      .build();
    assert_eq!(
      (query.interval().start(), query.interval().end()),
      (Some(3), Some(3))
    );
  }

  #[test]
  fn query_with_range_but_without_reference_name() {
    let request = Request::new_with_id("ValidId".to_string());

    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_range(Some("3"), Some("5"))
        .unwrap_err(),
      HtsGetError::InvalidInput(_)
    ));
  }

  #[test]
  fn query_with_range_and_unmapped_reference_name() {
    let request = Request::new_with_id("ValidId".to_string());

    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_reference_name(Some("*"))
        .with_range(Some("3"), Some("5"))
        .unwrap_err(),
      HtsGetError::InvalidInput(_)
    ));
  }

  #[test]
  fn query_with_invalid_start() {
    let request = Request::new_with_id("ValidId".to_string());

    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_reference_name(Some("ValidName"))
        .with_range(Some("a"), Some("5"))
        .unwrap_err(),
      HtsGetError::InvalidInput(_)
    ));
  }

  #[test]
  fn query_with_invalid_end() {
    let request = Request::new_with_id("ValidId".to_string());

    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_reference_name(Some("ValidName"))
        .with_range(Some("5"), Some("a"))
        .unwrap_err(),
      HtsGetError::InvalidInput(_)
    ));
  }

  #[test]
  fn query_with_invalid_range() {
    let request = Request::new_with_id("ValidId".to_string());

    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_reference_name(Some("ValidName"))
        .with_range(Some("5"), Some("3"))
        .unwrap_err(),
      HtsGetError::InvalidRange(_)
    ));
  }

  #[test]
  fn query_with_fields() {
    let request = Request::new_with_id("ValidId".to_string());

    assert_eq!(
      QueryBuilder::new(request, Bam)
        .with_fields(Some("header,part1,part2"))
        .build()
        .fields(),
      &Fields::List(HashSet::from_iter(vec![
        "header".to_string(),
        "part1".to_string(),
        "part2".to_string()
      ]))
    );
  }

  #[test]
  fn query_with_tags() {
    let request = Request::new_with_id("ValidId".to_string());

    let query = QueryBuilder::new(request, Bam)
      .with_tags(Some("header,part1,part2"), Some("part3"))
      .unwrap()
      .build();
    assert_eq!(
      query.tags(),
      &Tags::List(HashSet::from_iter(vec![
        "header".to_string(),
        "part1".to_string(),
        "part2".to_string()
      ]))
    );
    assert_eq!(
      query.no_tags(),
      &NoTags(Some(HashSet::from_iter(vec!["part3".to_string()])))
    );
  }

  #[test]
  fn query_with_invalid_tags() {
    let request = Request::new_with_id("ValidId".to_string());

    // `part1` is both included and excluded, which the builder must reject.
    assert!(matches!(
      QueryBuilder::new(request, Bam)
        .with_tags(Some("header,part1,part2"), Some("part1"))
        .unwrap_err(),
      HtsGetError::InvalidInput(_)
    ));
  }
}