1use async_trait::async_trait;
16use bytes::Bytes;
17use std::borrow::Cow;
21use std::collections::HashMap;
22
23use tracing::{instrument, trace};
24use url::Url;
25
26use http::{
27 self, header, request::Builder, HeaderMap, HeaderValue, Method, Request, Response, Uri,
28};
29use serde::de::DeserializeOwned;
30use serde_json::json;
33
34use crate::api::{query, ApiError, BodyError, QueryParams, RestClient};
35#[cfg(feature = "async")]
36use crate::api::{AsyncClient, QueryAsync, RawQueryAsync};
37#[cfg(feature = "sync")]
38use crate::api::{Client, Query, RawQuery};
39use crate::catalog::ServiceEndpoint;
40use crate::types::ApiVersion;
41use crate::types::BoxedAsyncRead;
42use crate::types::ServiceType;
43
44pub trait RestEndpoint {
46 fn method(&self) -> Method;
48 fn endpoint(&self) -> Cow<'static, str>;
50 fn service_type(&self) -> ServiceType;
52
53 fn parameters(&self) -> QueryParams<'_> {
55 QueryParams::default()
56 }
57
58 fn body(&self) -> Result<Option<(&'static str, Vec<u8>)>, BodyError> {
62 Ok(None)
63 }
64
65 fn response_key(&self) -> Option<Cow<'static, str>> {
67 None
68 }
69
70 fn response_list_item_key(&self) -> Option<Cow<'static, str>> {
72 None
73 }
74
75 fn response_headers(&self) -> HashMap<&str, &str> {
77 HashMap::new()
78 }
79
80 fn request_headers(&self) -> Option<&HeaderMap> {
82 None
83 }
84
85 fn api_version(&self) -> Option<ApiVersion> {
92 ApiVersion::from_endpoint_url(self.endpoint())
93 }
94}
95
96pub(crate) fn set_latest_microversion<E>(
99 request: &mut Builder,
100 service_endpoint: &ServiceEndpoint,
101 endpoint: &E,
102) where
103 E: RestEndpoint,
104{
105 let mh_service_type = match endpoint.service_type() {
106 ServiceType::BlockStorage => Some("volume"),
107 ServiceType::Compute => Some("compute"),
108 ServiceType::ContainerInfrastructureManagement => Some("container-infra"),
109 ServiceType::Placement => Some("placement"),
110 _ => None,
111 };
112 if let Some(st) = mh_service_type {
113 if let Some(hdrs) = request.headers_mut() {
115 let ver = service_endpoint.version();
116 if ver.major == 0 {
117 return;
118 }
119 if let Ok(val) =
120 HeaderValue::from_str(format!("{} {}.{}", st, ver.major, ver.minor).as_str())
121 {
122 hdrs.insert("Openstack-API-Version", val);
123 }
124 }
125 }
126}
127
128pub(crate) fn prepare_request<C, E>(
129 service_endpoint: &ServiceEndpoint,
130 mut url: Url,
131 endpoint: &E,
132) -> Result<(Builder, Vec<u8>), ApiError<C::Error>>
133where
134 E: RestEndpoint,
135 C: RestClient,
136{
137 endpoint.parameters().add_to_url(&mut url);
138 let mut req = Request::builder()
139 .method(endpoint.method())
140 .uri(query::url_to_http_uri(url))
141 .header(header::ACCEPT, HeaderValue::from_static("application/json"));
142 set_latest_microversion(&mut req, service_endpoint, endpoint);
143 if let Some(request_headers) = endpoint.request_headers() {
144 let headers = req.headers_mut().unwrap();
145 for (k, v) in request_headers.iter() {
146 headers.insert(k, v.clone());
147 }
148 }
149 if let Some((mime, data)) = endpoint.body()? {
150 let req = req.header(header::CONTENT_TYPE, mime);
151 Ok((req, data))
152 } else {
153 Ok((req, Vec::new()))
154 }
155}
156
157pub(super) fn get_json<C>(
159 rsp: &Response<Bytes>,
160 uri: Option<Uri>,
161) -> Result<serde_json::Value, ApiError<C::Error>>
162where
163 C: RestClient,
164{
165 let status = rsp.status();
166 let v = if let Ok(v) = serde_json::from_slice(rsp.body()) {
167 v
168 } else {
169 return Err(ApiError::server_error(uri, rsp, rsp.body()));
170 };
171 if !status.is_success() {
172 return Err(ApiError::from_openstack(uri, rsp, v));
173 }
174 Ok(v)
175}
176
177pub fn check_response_error<C>(
179 rsp: &Response<Bytes>,
180 uri: Option<Uri>,
181) -> Result<(), ApiError<C::Error>>
182where
183 C: RestClient,
184{
185 let status = rsp.status();
186 if !status.is_success() {
187 let v = if let Ok(v) = serde_json::from_slice(rsp.body()) {
188 v
189 } else {
190 return Err(ApiError::server_error(uri, rsp, rsp.body()));
191 };
192 return Err(ApiError::from_openstack(uri, rsp, v));
193 }
194 Ok(())
195}
196
/// Convert a raw API response into the caller's requested type `T`.
///
/// Shared by the sync and async `Query` implementations. Steps:
/// 1. Parse the body with `get_json`, so bad JSON and HTTP errors become
///    `ApiError`s.
/// 2. Descend into `endpoint.response_key()` when one is set.
/// 3. Copy the headers registered via `endpoint.response_headers()` into
///    the JSON object.
/// 4. Deserialize the result into `T`; failures become `ApiError::data_type`.
#[cfg(any(feature = "sync", feature = "async"))]
fn parse_query_response<E, T, C>(
    endpoint: &E,
    rsp: &Response<Bytes>,
    query_uri: Option<Uri>,
) -> Result<T, ApiError<C::Error>>
where
    E: RestEndpoint,
    T: DeserializeOwned,
    C: RestClient,
{
    let mut v = get_json::<C>(rsp, query_uri)?;
    if let Some(root_key) = endpoint.response_key() {
        // `get_mut` rather than `v[key]`: `IndexMut` panics when `v` is
        // neither an object nor null (e.g. a JSON array). A missing key or a
        // non-object body yields `Null`. That then surfaces as a `DataType`
        // error during deserialization instead of a panic.
        v = v
            .get_mut(&*root_key)
            .map(serde_json::Value::take)
            .unwrap_or(serde_json::Value::Null);
    }

    let headers = rsp.headers();
    for (header_key, target_val) in endpoint.response_headers().iter() {
        if let Some(val) = headers.get(*header_key) {
            trace!("Registered Header {} was found", header_key);
            // Header values may carry non-ASCII bytes (obs-text), which
            // `to_str()` rejects. Decode lossily instead of panicking on
            // server-controlled data.
            let text = String::from_utf8_lossy(val.as_bytes()).into_owned();
            if v.is_null() {
                // Same as the `v[key] = ...` form: `Null` becomes an object.
                v = json!({});
            }
            if let Some(obj) = v.as_object_mut() {
                obj.insert((*target_val).to_string(), json!(text));
            } else {
                trace!(
                    "Registered Header {} not attached: response is not a JSON object",
                    header_key
                );
            }
        }
    }
    match serde_json::from_value::<T>(v) {
        Ok(r) => Ok(r),
        Err(e) => Err(ApiError::data_type::<T>(e)),
    }
}

#[cfg(feature = "sync")]
impl<E, T, C> Query<T, C> for E
where
    E: RestEndpoint,
    T: DeserializeOwned,
    C: Client,
{
    /// Execute the endpoint synchronously and deserialize the reply into `T`.
    #[instrument(name = "query", level = "debug", skip_all)]
    fn query(&self, client: &C) -> Result<T, ApiError<C::Error>> {
        let ep = client.get_service_endpoint(&self.service_type(), self.api_version().as_ref())?;
        let url = ep.build_request_url(&self.endpoint())?;
        let (req, data) = prepare_request::<C, E>(ep, url, self)?;

        let query_uri = req.uri_ref().cloned();
        let rsp = client.rest(req, data)?;
        parse_query_response::<E, T, C>(self, &rsp, query_uri)
    }
}

#[cfg(feature = "async")]
#[async_trait]
impl<E, T, C> QueryAsync<T, C> for E
where
    E: RestEndpoint + Sync,
    C: AsyncClient + Sync,
    T: DeserializeOwned + 'static,
{
    /// Execute the endpoint asynchronously and deserialize the reply into `T`.
    #[instrument(name = "query", level = "debug", skip_all)]
    async fn query_async(&self, client: &C) -> Result<T, ApiError<C::Error>> {
        let ep = client.get_service_endpoint(&self.service_type(), self.api_version().as_ref())?;
        let (req, data) =
            prepare_request::<C, E>(ep, ep.build_request_url(&self.endpoint())?, self)?;

        let query_uri = req.uri_ref().cloned();
        let rsp = client.rest_async(req, data).await?;
        parse_query_response::<E, T, C>(self, &rsp, query_uri)
    }
}
270
#[cfg(feature = "sync")]
impl<E, C> RawQuery<C> for E
where
    E: RestEndpoint,
    C: Client,
{
    /// Execute the endpoint synchronously and hand back the raw response.
    ///
    /// HTTP error statuses are NOT converted into `ApiError`s here; callers
    /// must inspect `Response::status()` themselves.
    #[instrument(name = "query", level = "debug", skip_all)]
    fn raw_query(&self, client: &C) -> Result<Response<Bytes>, ApiError<C::Error>> {
        let service_type = self.service_type();
        let api_version = self.api_version();
        let service_ep = client.get_service_endpoint(&service_type, api_version.as_ref())?;
        let request_url = service_ep.build_request_url(&self.endpoint())?;
        let (request, body) = prepare_request::<C, E>(service_ep, request_url, self)?;
        let response = client.rest(request, body)?;
        Ok(response)
    }
}
289
#[cfg(feature = "async")]
#[async_trait]
impl<E, C> RawQueryAsync<C> for E
where
    E: RestEndpoint + Sync,
    C: AsyncClient + Sync,
{
    /// Send the request and return the raw response.
    ///
    /// When `inspect_error` is `Some(true)` or `None`, a non-success status is
    /// turned into an `ApiError` via `check_response_error`. `Some(false)`
    /// returns the response untouched.
    #[instrument(name = "query", level = "debug", skip_all)]
    async fn raw_query_async_ll(
        &self,
        client: &C,
        inspect_error: Option<bool>,
    ) -> Result<Response<Bytes>, ApiError<C::Error>> {
        let ep = client.get_service_endpoint(&self.service_type(), self.api_version().as_ref())?;
        let (req, data) =
            prepare_request::<C, E>(ep, ep.build_request_url(&self.endpoint())?, self)?;

        let query_uri = req.uri_ref().cloned();
        let rsp = client.rest_async(req, data).await?;

        if inspect_error.unwrap_or(true) {
            check_response_error::<C>(&rsp, query_uri)?;
        }
        Ok(rsp)
    }

    /// Send the request and return the raw response, failing on error statuses.
    async fn raw_query_async(&self, client: &C) -> Result<Response<Bytes>, ApiError<C::Error>> {
        self.raw_query_async_ll(client, Some(true)).await
    }

    /// Send the request with a body streamed from `data`.
    ///
    /// Unlike `prepare_request`, no `Accept` or `Content-Type` header is added
    /// here. Only the microversion header and the endpoint's own request
    /// headers are set. Non-success statuses become an `ApiError`.
    #[instrument(name = "query", level = "debug", skip_all)]
    async fn raw_query_read_body_async(
        &self,
        client: &C,
        data: BoxedAsyncRead,
    ) -> Result<Response<Bytes>, ApiError<C::Error>> {
        let ep = client.get_service_endpoint(&self.service_type(), self.api_version().as_ref())?;
        let mut url = ep.build_request_url(&self.endpoint())?;
        self.parameters().add_to_url(&mut url);
        let mut req = Request::builder()
            .method(self.method())
            .uri(query::url_to_http_uri(url));
        set_latest_microversion(&mut req, ep, self);
        if let Some(request_headers) = self.request_headers() {
            // `headers_mut` is `None` once the builder holds an error. That
            // error is reported when the request is built, so don't panic here.
            if let Some(headers) = req.headers_mut() {
                // Replace same-named headers set above...
                for name in request_headers.keys() {
                    headers.remove(name);
                }
                // ...while keeping every value of multi-valued headers.
                for (name, value) in request_headers.iter() {
                    headers.append(name, value.clone());
                }
            }
        }

        let query_uri = req.uri_ref().cloned();
        let rsp = client.rest_read_body_async(req, data).await?;

        check_response_error::<C>(&rsp, query_uri)?;

        Ok(rsp)
    }

    /// Send the request via `client.download_async`.
    ///
    /// Returns the response headers and a streaming body reader. No status
    /// check is performed in this method; any such handling is up to the
    /// client implementation.
    #[instrument(name = "query", level = "debug", skip_all)]
    async fn download_async(
        &self,
        client: &C,
    ) -> Result<(HeaderMap, BoxedAsyncRead), ApiError<C::Error>> {
        let ep = client.get_service_endpoint(&self.service_type(), self.api_version().as_ref())?;
        let (req, data) =
            prepare_request::<C, E>(ep, ep.build_request_url(&self.endpoint())?, self)?;

        let rsp = client.download_async(req, data).await?;

        Ok(rsp)
    }
}
364
#[cfg(feature = "sync")]
#[cfg(test)]
mod tests {
    use http::StatusCode;
    use httpmock::MockServer;
    use serde::Deserialize;
    use serde_json::json;

    use crate::api::rest_endpoint_prelude::*;
    use crate::api::{ApiError, Query};
    use crate::test::client::FakeOpenStackClient;
    use crate::types::ServiceType;

    /// Minimal endpoint: `GET dummy` on a made-up "dummy" service.
    struct Dummy;

    impl RestEndpoint for Dummy {
        fn method(&self) -> http::Method {
            http::Method::GET
        }

        fn endpoint(&self) -> Cow<'static, str> {
            Cow::Borrowed("dummy")
        }

        fn service_type(&self) -> ServiceType {
            ServiceType::from("dummy")
        }
    }

    /// Expected shape of a successful `Dummy` reply.
    #[derive(Debug, Deserialize)]
    struct DummyResult {
        value: u8,
    }

    /// A 200 reply whose body is not JSON is reported as a service error.
    #[test]
    fn test_non_json_response() {
        let server = MockServer::start();
        let client = FakeOpenStackClient::new(server.base_url());
        let mock = server.mock(|when, then| {
            when.method(httpmock::Method::GET).path("/dummy");
            then.status(200).body("not json");
        });

        let outcome: Result<DummyResult, _> = Dummy.query(&client);
        match outcome.unwrap_err() {
            ApiError::OpenStackService { status, .. } => assert_eq!(status, StatusCode::OK),
            other => panic!("unexpected error: {other}"),
        }
        mock.assert();
    }

    /// A 200 reply with an empty body is reported as a service error.
    #[test]
    fn test_empty_response() {
        let server = MockServer::start();
        let client = FakeOpenStackClient::new(server.base_url());
        let mock = server.mock(|when, then| {
            when.method(httpmock::Method::GET).path("/dummy");
            then.status(200);
        });

        let outcome: Result<DummyResult, _> = Dummy.query(&client);
        match outcome.unwrap_err() {
            ApiError::OpenStackService { status, .. } => assert_eq!(status, StatusCode::OK),
            other => panic!("unexpected error: {other}"),
        }
        mock.assert();
    }

    /// A bare 404 is reported as an OpenStack error carrying that status.
    #[test]
    fn test_error_not_found() {
        let server = MockServer::start();
        let client = FakeOpenStackClient::new(server.base_url());
        let mock = server.mock(|when, then| {
            when.method(httpmock::Method::GET).path("/dummy");
            then.status(404);
        });
        let outcome: Result<DummyResult, _> = Dummy.query(&client);
        match outcome.unwrap_err() {
            ApiError::OpenStack { status, .. } => assert_eq!(status, StatusCode::NOT_FOUND),
            other => panic!("unexpected error: {other}"),
        }
        mock.assert();
    }

    /// An error status without a JSON body is reported as a service error.
    #[test]
    fn test_error_bad_json() {
        let server = MockServer::start();
        let client = FakeOpenStackClient::new(server.base_url());
        let mock = server.mock(|when, then| {
            when.method(httpmock::Method::GET).path("/dummy");
            then.status(StatusCode::CONFLICT);
        });

        let outcome: Result<DummyResult, _> = Dummy.query(&client);
        match outcome.unwrap_err() {
            ApiError::OpenStackService { status, .. } => {
                assert_eq!(status, StatusCode::CONFLICT)
            }
            other => panic!("unexpected error: {other}"),
        }
        mock.assert();
    }

    /// A `message` field in an error body ends up in the error's `msg`.
    #[test]
    fn test_error_detection() {
        let server = MockServer::start();
        let client = FakeOpenStackClient::new(server.base_url());
        let mock = server.mock(|when, then| {
            when.method(httpmock::Method::GET).path("/dummy");
            then.status(StatusCode::CONFLICT)
                .json_body(json!({"message": "dummy error message"}));
        });

        let outcome: Result<DummyResult, _> = Dummy.query(&client);
        match outcome.unwrap_err() {
            ApiError::OpenStack { msg, .. } => assert_eq!(msg, "dummy error message"),
            other => panic!("unexpected error: {other}"),
        }
        mock.assert();
    }

    /// An error body without a recognised message is surfaced verbatim.
    #[test]
    fn test_error_detection_unknown() {
        let server = MockServer::start();
        let client = FakeOpenStackClient::new(server.base_url());
        let err_obj = json!({"bogus": "dummy error message"});
        let mock = server.mock(|when, then| {
            when.method(httpmock::Method::GET).path("/dummy");
            then.status(StatusCode::CONFLICT).json_body(err_obj.clone());
        });

        let outcome: Result<DummyResult, _> = Dummy.query(&client);
        match outcome.unwrap_err() {
            ApiError::OpenStackUnrecognized { obj, .. } => assert_eq!(obj, err_obj),
            other => panic!("unexpected error: {other}"),
        }
        mock.assert();
    }

    /// JSON that does not match `DummyResult` yields a `DataType` error.
    #[test]
    fn test_bad_deserialization() {
        let server = MockServer::start();
        let client = FakeOpenStackClient::new(server.base_url());
        let mock = server.mock(|when, then| {
            when.method(httpmock::Method::GET).path("/dummy");
            then.status(200).json_body(json!({"not_value": 0}));
        });

        let outcome: Result<DummyResult, _> = Dummy.query(&client);
        match outcome.unwrap_err() {
            ApiError::DataType { source, typename } => {
                assert_eq!(source.to_string(), "missing field `value`");
                assert_eq!(
                    typename,
                    "openstack_sdk::api::rest_endpoint::tests::DummyResult"
                );
            }
            other => panic!("unexpected error: {other}"),
        }
        mock.assert();
    }

    /// Well-formed JSON deserializes into `DummyResult`.
    #[test]
    fn test_good_deserialization() {
        let server = MockServer::start();
        let client = FakeOpenStackClient::new(server.base_url());
        let mock = server.mock(|when, then| {
            when.method(httpmock::Method::GET).path("/dummy");
            then.status(200).json_body(json!({"value": 0}));
        });

        let parsed: DummyResult = Dummy.query(&client).unwrap();
        assert_eq!(parsed.value, 0);
        mock.assert();
    }
}