1mod aliases;
4mod auth;
5mod builder;
6mod error;
7mod fhir;
8mod misc;
9mod request;
10mod search;
11
12use std::{marker::PhantomData, sync::Arc};
13
14use ::std::any::type_name;
15use misc::parse_major_fhir_version;
16use reqwest::{StatusCode, Url, header};
17
18pub use self::{
19 aliases::*, auth::LoginManager, builder::ClientBuilder, error::Error, fhir::*,
20 request::RequestSettings, search::SearchParameters,
21};
22use self::{auth::AuthCallback, misc::make_uuid_header_value};
23use crate::version::{DefaultVersion, FhirR4B, FhirR5, FhirStu3, FhirVersion};
24
25pub struct Client<Version = DefaultVersion>(Arc<ClientData>, PhantomData<Version>);
27
28struct ClientData {
30 base_url: Url,
32 client: reqwest::Client,
34 request_settings: std::sync::Mutex<RequestSettings>,
36 auth_callback: tokio::sync::Mutex<Option<AuthCallback>>,
38
39 error_on_version_mismatch: bool,
42 error_on_origin_mismatch: bool,
47}
48
49impl<V: FhirVersion> From<ClientData> for Client<V> {
50 fn from(data: ClientData) -> Self {
51 Self(Arc::new(data), PhantomData)
52 }
53}
54
55impl<V: FhirVersion> Client<V> {
56 #[must_use]
58 pub fn builder() -> ClientBuilder<V> {
59 ClientBuilder::default()
60 }
61
62 pub fn new(base_url: Url) -> Result<Self, Error> {
64 Self::builder().base_url(base_url).build()
65 }
66
67 #[must_use]
69 pub fn base_url(&self) -> &Url {
70 &self.0.base_url
71 }
72
73 fn url(&self, segments: &[&str]) -> Url {
75 let mut url = self.0.base_url.clone();
76 #[allow(clippy::expect_used, reason = "We made sure of it in the constructor")]
77 url.path_segments_mut().expect("Base URL cannot be base").pop_if_empty().extend(segments);
78 url
79 }
80
81 #[must_use]
83 pub fn request_settings(&self) -> RequestSettings {
84 #[allow(clippy::expect_used, reason = "only happens on panics, so we can panic again")]
85 self.0.request_settings.lock().expect("mutex poisened").clone()
86 }
87
88 pub fn set_request_settings(&self, settings: RequestSettings) {
93 tracing::debug!("Setting new request settings");
94 #[allow(clippy::expect_used, reason = "only happens on panics, so we can panic again")]
95 let mut request_settings = self.0.request_settings.lock().expect("mutex poisened");
96 *request_settings = settings;
97 }
98
99 pub fn patch_request_settings<F>(&self, mutator: F)
102 where
103 F: FnOnce(RequestSettings) -> RequestSettings,
104 {
105 tracing::debug!("Patching request settings");
106 #[allow(clippy::expect_used, reason = "only happens on panics, so we can panic again")]
107 let mut request_settings = self.0.request_settings.lock().expect("mutex poisened");
108 let patched = mutator(request_settings.clone());
109 *request_settings = patched;
110 }
111
112 fn convert_version<Version>(self) -> Client<Version> {
114 Client(self.0, PhantomData)
115 }
116
117 #[must_use]
119 pub fn stu3(self) -> Client<FhirStu3> {
120 self.convert_version()
121 }
122
123 #[must_use]
125 pub fn r4b(self) -> Client<FhirR4B> {
126 self.convert_version()
127 }
128
129 #[must_use]
131 pub fn r5(self) -> Client<FhirR5> {
132 self.convert_version()
133 }
134
135 #[tracing::instrument(level = "info", skip_all, fields(x_correlation_id))]
139 async fn run_request(
140 &self,
141 mut request: reqwest::RequestBuilder,
142 ) -> Result<reqwest::Response, Error> {
143 let (client, info_request_result) = request.build_split();
144 let mut info_request = info_request_result?;
145 let req_method = info_request.method().clone();
146 let req_url = info_request.url().clone();
147
148 if self.0.error_on_origin_mismatch {
150 if info_request.url().origin() != self.0.base_url.origin() {
152 return Err(Error::DifferentOrigin(info_request.url().to_string()));
153 }
154 }
155
156 let correlation_id = info_request
159 .headers_mut()
160 .entry("X-Correlation-Id")
161 .or_insert_with(make_uuid_header_value);
162 let x_correlation_id = correlation_id.to_str().ok().map(ToOwned::to_owned);
163 request = reqwest::RequestBuilder::from_parts(client, info_request);
164 tracing::Span::current().record("x_correlation_id", x_correlation_id);
165
166 let mut request_settings = self.request_settings();
168 tracing::info!("Sending {req_method} request to {req_url} (potentially with retries)");
169 let mut response = request_settings
170 .make_request(request.try_clone().ok_or(Error::RequestNotClone)?)
171 .await?;
172
173 if response.status() == StatusCode::UNAUTHORIZED {
175 if let Ok(mut auth_callback) = self.0.auth_callback.try_lock() {
176 if let Some(auth_callback) = auth_callback.as_mut() {
177 tracing::info!("Hit unauthorized response, calling auth_callback");
178 let auth_value = auth_callback
179 .authenticate(self.0.client.clone())
180 .await
181 .map_err(|err| Error::AuthCallback(format!("{err:#}")))?;
182 self.patch_request_settings(move |settings| {
183 settings.header(header::AUTHORIZATION, auth_value)
184 });
185 } else {
186 return Ok(response);
188 }
189 } else {
190 _ = self.0.auth_callback.lock().await;
193 }
194 request_settings = self.request_settings();
196 tracing::info!("Retrying request after authorization refresh");
197 response = request_settings.make_request(request).await?;
198 }
199
200 tracing::info!("Got response: {}", response.status());
201
202 if self.0.error_on_version_mismatch {
204 if let Some(version) = parse_major_fhir_version(response.headers())? {
205 let expected = V::VERSION.split_once('.').map_or(V::VERSION, |(major, _)| major);
206 if version != expected {
207 return Err(Error::DifferentFhirVersion(version.to_owned()));
208 }
209 }
210 }
211
212 Ok(response)
213 }
214
215 pub async fn send_custom_request<F>(&self, make_request: F) -> Result<reqwest::Response, Error>
226 where
227 F: FnOnce(&reqwest::Client) -> reqwest::RequestBuilder + Send,
228 {
229 let request = (make_request)(&self.0.client);
230 self.run_request(request).await
231 }
232}
233
234impl<V> Clone for Client<V> {
235 fn clone(&self) -> Self {
236 Self(self.0.clone(), self.1)
237 }
238}
239
240impl std::fmt::Debug for ClientData {
241 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
242 let auth_callback = match self.auth_callback.try_lock() {
243 Ok(inside) => {
244 if inside.is_some() {
245 "Some(<login_manager>)"
246 } else {
247 "None"
248 }
249 }
250 Err(_) => "<blocked>",
251 };
252
253 f.debug_struct("ClientData")
254 .field("base_url", &self.base_url)
255 .field("client", &self.client)
256 .field("request_settings", &self.request_settings)
257 .field("auth_callback", &auth_callback)
258 .field("error_on_version_mismatch", &self.error_on_version_mismatch)
259 .field("error_on_origin_mismatch", &self.error_on_origin_mismatch)
260 .finish()
261 }
262}
263
264impl<V> std::fmt::Debug for Client<V> {
265 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
266 f.debug_struct("Client").field("data", &self.0).field("version", &type_name::<V>()).finish()
267 }
268}
269
270#[cfg(test)]
271mod tests;