1use crate::HtsGetError::{InternalError, InvalidInput};
2use crate::middleware::auth::Auth;
3use crate::{
4 Endpoint, HtsGetError, PostRequest, Result, convert_to_query, match_format_from_query,
5 merge_responses,
6};
7use cfg_if::cfg_if;
8use futures::StreamExt;
9use futures::stream::FuturesOrdered;
10use htsget_config::config::advanced::auth::AuthorizationRestrictions;
11use htsget_config::config::service_info::PackageInfo;
12use htsget_config::types::{JsonResponse, Query, Request, Response};
13use htsget_search::HtsGet;
14use http::HeaderMap;
15use serde_json::Value;
16use tokio::select;
17use tracing::debug;
18use tracing::instrument;
19
20async fn authenticate(headers: &HeaderMap, auth: Option<Auth>) -> Result<Option<Auth>> {
21 if let Some(mut auth) = auth {
22 if auth.config().auth_mode().is_some() {
23 auth.validate_jwt(headers).await?;
24 Ok(Some(auth))
25 } else {
26 Ok(Some(auth))
27 }
28 } else {
29 Ok(auth)
30 }
31}
32
33async fn authorize(
34 headers: &HeaderMap,
35 path: &str,
36 queries: &mut [Query],
37 auth: Option<Auth>,
38 extensions: Option<Value>,
39 endpoint: &Endpoint,
40) -> Result<Option<(AuthorizationRestrictions, bool)>> {
41 if let Some(mut auth) = auth {
42 let _rules = auth
43 .validate_authorization(headers, path, queries, extensions, endpoint)
44 .await?;
45 cfg_if! {
46 if #[cfg(feature = "experimental")] {
47 if auth.config().add_hint() {
48 Ok(_rules.map(|rules| (rules, true)))
49 } else {
50 Ok(_rules.map(|rules| (rules, false)))
51 }
52 } else {
53 Ok(_rules.map(|rules| (rules, false)))
54 }
55 }
56 } else {
57 Ok(None)
58 }
59}
60
61#[instrument(level = "debug", skip_all, ret)]
65pub async fn get(
66 searcher: impl HtsGet + Send + Sync + 'static,
67 request: Request,
68 endpoint: Endpoint,
69 auth: Option<Auth>,
70 package_info: Option<&PackageInfo>,
71 extensions: Option<Value>,
72) -> Result<JsonResponse> {
73 let path = request.path().to_string();
74 let headers = request.headers().clone();
75
76 let auth = authenticate(&headers, auth).await?;
77 debug!(auth = ?auth, "auth");
78
79 let format = match_format_from_query(&endpoint, request.query())?;
80 let mut query = vec![convert_to_query(request, format)?];
81 let rules = authorize(
82 &headers,
83 &path,
84 query.as_mut_slice(),
85 auth,
86 extensions,
87 &endpoint,
88 )
89 .await?;
90
91 debug!(endpoint = ?endpoint, query = ?query, "getting GET response");
92
93 let query = query.into_iter().next().expect("single element vector");
94
95 debug!(rules = ?rules, "rules");
96 let response = if let Some((ref rules, _)) = rules {
97 let mut remote_locations = rules.clone().into_remote_locations();
98 if let Some(package_info) = package_info {
99 remote_locations
100 .set_from_package_info(package_info)
101 .map_err(|_| InternalError("invalid remote locations".to_string()))?;
102 }
103 debug!(remote_locations = ?remote_locations, "remote locations");
104
105 match remote_locations
107 .search(query.clone())
108 .await
109 .map(JsonResponse::from)
110 {
111 Ok(response) => response,
112 Err(_) => searcher.search(query).await.map(JsonResponse::from)?,
113 }
114 } else {
115 searcher.search(query).await.map(JsonResponse::from)?
116 };
117
118 cfg_if! {
119 if #[cfg(feature = "experimental")] {
120 let allowed = match rules {
121 Some((rules, add_hint)) if add_hint => Some(rules.into_rules()),
122 _ => None
123 };
124 Ok(response.with_allowed(allowed))
125 } else {
126 Ok(response)
127 }
128 }
129}
130
131#[instrument(level = "debug", skip_all, ret)]
134pub async fn post(
135 searcher: impl HtsGet + Clone + Send + Sync + 'static,
136 body: PostRequest,
137 request: Request,
138 endpoint: Endpoint,
139 auth: Option<Auth>,
140 package_info: Option<&PackageInfo>,
141 extensions: Option<Value>,
142) -> Result<JsonResponse> {
143 let path = request.path().to_string();
144 let headers = request.headers().clone();
145
146 let auth = authenticate(&headers, auth).await?;
147 debug!(auth = ?auth, "auth");
148
149 if !request.query().is_empty() {
150 return Err(InvalidInput(
151 "query parameters should be empty for a POST request".to_string(),
152 ));
153 }
154
155 let mut queries = body.get_queries(request, &endpoint)?;
156 let rules = authorize(
157 &headers,
158 &path,
159 queries.as_mut_slice(),
160 auth,
161 extensions,
162 &endpoint,
163 )
164 .await?;
165
166 debug!(endpoint = ?endpoint, queries = ?queries, "getting POST response");
167
168 let mut futures = FuturesOrdered::new();
169 debug!(rules = ?rules, "rules");
170
171 if let Some((ref rules, _)) = rules {
172 for query in queries {
173 let mut remote_locations = rules.clone().into_remote_locations();
174 if let Some(package_info) = package_info {
175 remote_locations
176 .set_from_package_info(package_info)
177 .map_err(|_| InternalError("invalid remote locations".to_string()))?;
178 }
179 let owned_searcher = searcher.clone();
180 debug!(remote_locations = ?remote_locations, "remote locations");
181
182 futures.push_back(tokio::spawn(async move {
184 match remote_locations.search(query.clone()).await {
185 Ok(response) => Ok(response),
186 Err(_) => owned_searcher.search(query).await,
187 }
188 }));
189 }
190 } else {
191 for query in queries {
192 let owned_searcher = searcher.clone();
193 futures.push_back(tokio::spawn(
194 async move { owned_searcher.search(query).await },
195 ));
196 }
197 };
198
199 let mut responses: Vec<Response> = Vec::new();
200 loop {
201 select! {
202 Some(next) = futures.next() => responses.push(next.map_err(|err| HtsGetError::InternalError(err.to_string()))?.map_err(HtsGetError::from)?),
203 else => break
204 }
205 }
206
207 let response =
208 JsonResponse::from(merge_responses(responses).expect("expected at least one response"));
209 cfg_if! {
210 if #[cfg(feature = "experimental")] {
211 let allowed = match rules {
212 Some((rules, add_hint)) if add_hint => Some(rules.into_rules()),
213 _ => None
214 };
215 Ok(response.with_allowed(allowed))
216 } else {
217 Ok(response)
218 }
219 }
220}