#![deny(missing_docs)]
use std::sync::Arc;
use reqwest::header::{CONTENT_TYPE, HeaderName};
use reqwest::{Client, Method, StatusCode};
use serde::de::DeserializeOwned;
use snafu::ResultExt as _;
use tokio::sync::RwLock;
use url::Url;
use crate::NifiError;
use crate::config::auth::AuthProvider;
use crate::error::{AuthSnafu, HttpSnafu};
// `Content-Type` value attached by the octet-stream upload helpers in this file.
const APPLICATION_OCTET_STREAM: &str = "application/octet-stream";
// Header name used by `build_request` to forward the configured proxied-entities chain.
const PROXIED_ENTITIES_CHAIN: HeaderName = HeaderName::from_static("x-proxiedentitieschain");
/// Asynchronous HTTP client for the NiFi REST API.
///
/// Clones share state: the bearer token and the re-authentication lock live
/// behind [`Arc`]s, so a token stored through one clone is seen by every
/// other clone.
#[derive(Clone)]
pub struct NifiClient {
    /// Base URL of the NiFi instance; request paths are placed under `/nifi-api`.
    base_url: Url,
    /// Underlying `reqwest` client used to send every request.
    http: Client,
    /// Current bearer token, if any; wrapped in `Zeroizing` so it is wiped on drop.
    token: Arc<RwLock<Option<zeroize::Zeroizing<String>>>>,
    /// Optional provider used by `authenticate` to obtain a fresh token.
    auth_provider: Option<Arc<dyn AuthProvider>>,
    /// Optional value sent in the proxied-entities-chain header on each request.
    proxied_entities_chain: Option<String>,
    /// Optional policy controlling retries of retryable errors.
    retry_policy: Option<crate::config::retry::RetryPolicy>,
    /// Optional name of the header that carries a generated per-request UUID.
    request_id_header: Option<String>,
    /// Serializes re-authentication so concurrent tasks do not refresh the token at once.
    auth_lock: Arc<tokio::sync::Mutex<()>>,
}
impl std::fmt::Debug for NifiClient {
    /// Formats the client's configuration fields.
    ///
    /// The HTTP client, the token and the auth lock are not printed; the
    /// output is marked as non-exhaustive instead.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The provider is rendered through its own `Debug` output, as a string.
        let provider = self.auth_provider.as_ref().map(|p| format!("{p:?}"));
        let mut out = f.debug_struct("NifiClient");
        out.field("base_url", &self.base_url);
        out.field("auth_provider", &provider);
        out.field("proxied_entities_chain", &self.proxied_entities_chain);
        out.field("retry_policy", &self.retry_policy);
        out.field("request_id_header", &self.request_id_header);
        out.finish_non_exhaustive()
    }
}
impl NifiClient {
/// Assembles a client from already-built parts.
///
/// The client starts without a token and with a fresh re-authentication lock.
pub(crate) fn from_parts(
    base_url: Url,
    http: Client,
    auth_provider: Option<Arc<dyn AuthProvider>>,
    proxied_entities_chain: Option<String>,
    retry_policy: Option<crate::config::retry::RetryPolicy>,
    request_id_header: Option<String>,
) -> Self {
    let token = Arc::new(RwLock::new(None));
    let auth_lock = Arc::new(tokio::sync::Mutex::new(()));
    Self {
        base_url,
        http,
        token,
        auth_provider,
        proxied_entities_chain,
        retry_policy,
        request_id_header,
        auth_lock,
    }
}
/// Returns a copy of the bearer token currently held by the client, if any.
pub async fn token(&self) -> Option<String> {
    let guard = self.token.read().await;
    guard.as_deref().cloned()
}
/// Replaces the stored bearer token with `token`.
pub async fn set_token(&self, token: String) {
    let mut slot = self.token.write().await;
    *slot = Some(zeroize::Zeroizing::new(token));
}
/// Sends `DELETE /access/logout` and forgets the local token.
///
/// The stored token is cleared whether or not the server call succeeded;
/// the server call's result is what gets returned.
#[tracing::instrument(skip(self), fields(request_id = tracing::field::Empty))]
pub async fn logout(&self) -> Result<(), NifiError> {
    let outcome = self.delete_inner("/access/logout").await;
    {
        // Scoped so the write guard is released before logging.
        let mut slot = self.token.write().await;
        *slot = None;
    }
    if outcome.is_ok() {
        tracing::info!("NiFi logout successful");
    }
    outcome
}
/// Exchanges `username` / `password` for a token via `POST /access/token`.
///
/// On a success status the response body is stored as the bearer token.
///
/// # Errors
///
/// Transport failures are wrapped through `HttpSnafu`. A non-success status
/// yields an auth error whose message comes from `extract_error_message`.
#[tracing::instrument(skip(self, username, password), fields(request_id = tracing::field::Empty))]
pub async fn login(&self, username: &str, password: &str) -> Result<(), NifiError> {
    const TOKEN_PATH: &str = "/access/token";
    let method = Method::POST;
    tracing::debug!(method = %method, path = TOKEN_PATH, "NiFi API request");
    let credentials = [("username", username), ("password", password)];
    let response = self
        .apply_request_id(self.http.post(self.api_url(TOKEN_PATH)))
        .form(&credentials)
        .send()
        .await
        .context(HttpSnafu)?;
    let status = response.status();
    tracing::debug!(
        method = %method,
        path = TOKEN_PATH,
        status = status.as_u16(),
        "NiFi API response"
    );
    if status.is_success() {
        let token = response.text().await.context(HttpSnafu)?;
        *self.token.write().await = Some(zeroize::Zeroizing::new(token));
        tracing::info!("NiFi login successful for {username}");
        return Ok(());
    }
    // If the error body cannot be read, fall back to the status line.
    let body = response
        .text()
        .await
        .unwrap_or_else(|_| status.to_string());
    tracing::debug!(
        method = %method,
        path = TOKEN_PATH,
        status = status.as_u16(),
        %body,
        "NiFi API raw error body"
    );
    let message = extract_error_message(&body);
    tracing::warn!(
        method = %method,
        path = TOKEN_PATH,
        status = status.as_u16(),
        %message,
        "NiFi API error"
    );
    AuthSnafu { message }.fail()
}
/// Authenticates through the configured auth provider.
///
/// # Errors
///
/// Returns an auth error when no provider is configured; otherwise returns
/// whatever the provider's `authenticate` call returns.
#[tracing::instrument(skip(self), fields(request_id = tracing::field::Empty))]
pub async fn authenticate(&self) -> Result<(), NifiError> {
    let Some(provider) = self.auth_provider.as_ref() else {
        return Err(NifiError::Auth {
            message: "no auth provider configured".to_string(),
        });
    };
    provider.authenticate(self).await
}
/// Runs `f` once and, if it fails with [`NifiError::Unauthorized`] while an
/// auth provider is configured, refreshes the token and runs `f` again.
///
/// Re-authentication is serialized through `auth_lock`. A task that gets
/// the lock after another task has already replaced the token skips its
/// own re-authentication and just retries the request.
///
/// The span declares an empty `request_id` field because the request
/// closure executes inside this span; without the field the
/// `Span::current().record("request_id", ..)` call made while building the
/// request would have nowhere to land.
///
/// # Errors
///
/// Returns the error of the second `f` call, a failed `authenticate`, or
/// any non-`Unauthorized` error of the first call unchanged.
#[tracing::instrument(skip_all, fields(request_id = tracing::field::Empty))]
async fn with_auth_retry<T, F, Fut>(&self, f: F) -> Result<T, NifiError>
where
    F: Fn() -> Fut,
    Fut: std::future::Future<Output = Result<T, NifiError>>,
{
    // Keep the snapshot wrapped in `Zeroizing` so this extra copy of the
    // secret is wiped when it goes out of scope.
    let token_before = self.token.read().await.clone();
    match f().await {
        Err(NifiError::Unauthorized { .. }) if self.auth_provider.is_some() => {
            let guard = self.auth_lock.lock().await;
            // Compare in place: no second copy of the token is made, and the
            // read guard is released at the end of this statement, before
            // `authenticate` may need to write a new token.
            let token_unchanged = *self.token.read().await == token_before;
            if token_unchanged {
                tracing::info!("received 401, refreshing token via auth provider");
                self.authenticate().await?;
            } else {
                tracing::debug!("token already refreshed by concurrent task, skipping re-auth");
            }
            drop(guard);
            f().await
        }
        other => other,
    }
}
#[tracing::instrument(skip_all)]
async fn with_retry<T, F, Fut>(&self, f: F) -> Result<T, NifiError>
where
F: Fn() -> Fut,
Fut: std::future::Future<Output = Result<T, NifiError>>,
{
let Some(policy) = &self.retry_policy else {
return self.with_auth_retry(&f).await;
};
let mut last_err: Option<NifiError> = None;
for attempt in 0..=policy.max_retries {
if attempt > 0 {
let backoff = policy.backoff_for(attempt - 1);
tracing::info!(
attempt,
backoff_ms = backoff.as_millis() as u64,
"retrying after transient error"
);
tokio::time::sleep(backoff).await;
}
match self.with_auth_retry(&f).await {
Ok(v) => return Ok(v),
Err(e) if e.is_retryable() => {
tracing::warn!(attempt, error = %e, "transient error, will retry");
last_err = Some(e);
}
Err(e) => return Err(e),
}
}
match last_err {
Some(e) => Err(e),
None => self.with_auth_retry(&f).await,
}
}
/// Sends a `GET` to `path` and decodes the JSON response body as `T`.
///
/// `extra_headers` are added after the headers set by `build_request`; the
/// exchange runs under `with_retry`.
#[tracing::instrument(skip(self), fields(request_id = tracing::field::Empty))]
pub(crate) async fn get<T: DeserializeOwned>(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
) -> Result<T, NifiError> {
    self.with_retry(|| async {
        let url = self.api_url(path);
        let builder = self
            .build_request(&Method::GET, path, self.http.get(url))
            .await;
        let response = apply_extra_headers(builder, extra_headers)
            .send()
            .await
            .context(HttpSnafu)?;
        Self::deserialize(&Method::GET, path, response).await
    })
    .await
}
/// Sends a `POST` to `path` with `body` serialized as JSON and decodes the
/// JSON response body as `T`.
///
/// `extra_headers` are added after the headers set by `build_request`; the
/// exchange runs under `with_retry`.
#[tracing::instrument(skip(self, body), fields(request_id = tracing::field::Empty))]
pub(crate) async fn post<B, T>(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
    body: &B,
) -> Result<T, NifiError>
where
    B: serde::Serialize,
    T: DeserializeOwned,
{
    self.with_retry(|| async {
        let url = self.api_url(path);
        let builder = self
            .build_request(&Method::POST, path, self.http.post(url))
            .await;
        let response = apply_extra_headers(builder, extra_headers)
            .json(body)
            .send()
            .await
            .context(HttpSnafu)?;
        Self::deserialize(&Method::POST, path, response).await
    })
    .await
}
/// Sends a `PUT` to `path` with `body` serialized as JSON and decodes the
/// JSON response body as `T`.
///
/// `extra_headers` are added after the headers set by `build_request`; the
/// exchange runs under `with_retry`.
#[tracing::instrument(skip(self, body), fields(request_id = tracing::field::Empty))]
pub(crate) async fn put<B, T>(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
    body: &B,
) -> Result<T, NifiError>
where
    B: serde::Serialize,
    T: DeserializeOwned,
{
    self.with_retry(|| async {
        let url = self.api_url(path);
        let builder = self
            .build_request(&Method::PUT, path, self.http.put(url))
            .await;
        let response = apply_extra_headers(builder, extra_headers)
            .json(body)
            .send()
            .await
            .context(HttpSnafu)?;
        Self::deserialize(&Method::PUT, path, response).await
    })
    .await
}
/// Sends a body-less `POST` to `path` and decodes the JSON response as `T`.
///
/// `extra_headers` are added after the headers set by `build_request`; the
/// exchange runs under `with_retry`.
#[tracing::instrument(skip(self), fields(request_id = tracing::field::Empty))]
pub(crate) async fn post_no_body<T: DeserializeOwned>(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
) -> Result<T, NifiError> {
    self.with_retry(|| async {
        let url = self.api_url(path);
        let builder = self
            .build_request(&Method::POST, path, self.http.post(url))
            .await;
        let response = apply_extra_headers(builder, extra_headers)
            .send()
            .await
            .context(HttpSnafu)?;
        Self::deserialize(&Method::POST, path, response).await
    })
    .await
}
/// Sends a body-less `POST` to `path` and only checks the response status.
///
/// `extra_headers` are added after the headers set by `build_request`; the
/// exchange runs under `with_retry`.
#[allow(dead_code)]
#[tracing::instrument(skip(self), fields(request_id = tracing::field::Empty))]
pub(crate) async fn post_void_no_body(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
) -> Result<(), NifiError> {
    self.with_retry(|| async {
        let url = self.api_url(path);
        let builder = self
            .build_request(&Method::POST, path, self.http.post(url))
            .await;
        let response = apply_extra_headers(builder, extra_headers)
            .send()
            .await
            .context(HttpSnafu)?;
        Self::check_void(&Method::POST, path, response).await
    })
    .await
}
/// Sends a `POST` to `path` with `body` serialized as JSON and only checks
/// the response status.
///
/// `extra_headers` are added after the headers set by `build_request`; the
/// exchange runs under `with_retry`.
#[allow(dead_code)]
#[tracing::instrument(skip(self, body), fields(request_id = tracing::field::Empty))]
pub(crate) async fn post_void<B: serde::Serialize>(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
    body: &B,
) -> Result<(), NifiError> {
    self.with_retry(|| async {
        let url = self.api_url(path);
        let builder = self
            .build_request(&Method::POST, path, self.http.post(url))
            .await;
        let response = apply_extra_headers(builder, extra_headers)
            .json(body)
            .send()
            .await
            .context(HttpSnafu)?;
        Self::check_void(&Method::POST, path, response).await
    })
    .await
}
/// Sends a body-less `PUT` to `path` and decodes the JSON response as `T`.
///
/// `extra_headers` are added after the headers set by `build_request`; the
/// exchange runs under `with_retry`.
#[tracing::instrument(skip(self), fields(request_id = tracing::field::Empty))]
pub(crate) async fn put_no_body<T: DeserializeOwned>(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
) -> Result<T, NifiError> {
    self.with_retry(|| async {
        let url = self.api_url(path);
        let builder = self
            .build_request(&Method::PUT, path, self.http.put(url))
            .await;
        let response = apply_extra_headers(builder, extra_headers)
            .send()
            .await
            .context(HttpSnafu)?;
        Self::deserialize(&Method::PUT, path, response).await
    })
    .await
}
/// Sends a `PUT` to `path` with `body` serialized as JSON and only checks
/// the response status.
///
/// `extra_headers` are added after the headers set by `build_request`; the
/// exchange runs under `with_retry`.
#[allow(dead_code)]
#[tracing::instrument(skip(self, body), fields(request_id = tracing::field::Empty))]
pub(crate) async fn put_void<B: serde::Serialize>(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
    body: &B,
) -> Result<(), NifiError> {
    self.with_retry(|| async {
        let url = self.api_url(path);
        let builder = self
            .build_request(&Method::PUT, path, self.http.put(url))
            .await;
        let response = apply_extra_headers(builder, extra_headers)
            .json(body)
            .send()
            .await
            .context(HttpSnafu)?;
        Self::check_void(&Method::PUT, path, response).await
    })
    .await
}
/// Sends a body-less `PUT` to `path` and only checks the response status.
///
/// `extra_headers` are added after the headers set by `build_request`; the
/// exchange runs under `with_retry`.
#[allow(dead_code)]
#[tracing::instrument(skip(self), fields(request_id = tracing::field::Empty))]
pub(crate) async fn put_void_no_body(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
) -> Result<(), NifiError> {
    self.with_retry(|| async {
        let url = self.api_url(path);
        let builder = self
            .build_request(&Method::PUT, path, self.http.put(url))
            .await;
        let response = apply_extra_headers(builder, extra_headers)
            .send()
            .await
            .context(HttpSnafu)?;
        Self::check_void(&Method::PUT, path, response).await
    })
    .await
}
/// Sends `data` as an `application/octet-stream` `POST` to `path` and
/// decodes the JSON response body as `T`.
///
/// `data` is cloned for every attempt made by `with_retry`. `extra_headers`
/// are added after the content type and body have been set.
#[tracing::instrument(skip(self, data), fields(request_id = tracing::field::Empty))]
pub(crate) async fn post_octet_stream<T: DeserializeOwned>(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
    data: bytes::Bytes,
) -> Result<T, NifiError> {
    self.with_retry(|| async {
        let url = self.api_url(path);
        let authed = self
            .build_request(&Method::POST, path, self.http.post(url))
            .await;
        let with_payload = authed
            .header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
            .body(data.clone());
        let response = apply_extra_headers(with_payload, extra_headers)
            .send()
            .await
            .context(HttpSnafu)?;
        Self::deserialize(&Method::POST, path, response).await
    })
    .await
}
/// Uploads `data` as the `file` part of a multipart `POST` to `path` and
/// decodes the JSON response body as `T`.
///
/// The part is named with `filename`. `data` is cloned and the form is
/// rebuilt for every attempt made by `with_retry`.
#[allow(dead_code)]
#[tracing::instrument(skip(self, data), fields(request_id = tracing::field::Empty))]
pub(crate) async fn post_multipart<T: DeserializeOwned>(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
    filename: &str,
    data: bytes::Bytes,
) -> Result<T, NifiError> {
    self.with_retry(|| async {
        let url = self.api_url(path);
        let authed = self
            .build_request(&Method::POST, path, self.http.post(url))
            .await;
        let builder = apply_extra_headers(authed, extra_headers);
        let size = data.len() as u64;
        let file_part = reqwest::multipart::Part::stream_with_length(data.clone(), size)
            .file_name(filename.to_string());
        let form = reqwest::multipart::Form::new().part("file", file_part);
        let response = builder.multipart(form).send().await.context(HttpSnafu)?;
        Self::deserialize(&Method::POST, path, response).await
    })
    .await
}
/// Uploads a multipart `POST` to `path` made of `text_fields` followed by
/// `data` as the `file` part, and decodes the JSON response body as `T`.
///
/// The file part is named with `filename`. `data` is cloned and the form is
/// rebuilt for every attempt made by `with_retry`.
#[tracing::instrument(
    skip(self, text_fields, data),
    fields(request_id = tracing::field::Empty)
)]
pub(crate) async fn post_multipart_with_fields<T: DeserializeOwned>(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
    text_fields: &[(&str, String)],
    filename: &str,
    data: bytes::Bytes,
) -> Result<T, NifiError> {
    self.with_retry(|| async {
        let url = self.api_url(path);
        let authed = self
            .build_request(&Method::POST, path, self.http.post(url))
            .await;
        let builder = apply_extra_headers(authed, extra_headers);
        // Text fields go first, in the order given; the file part is last.
        let text_only = text_fields
            .iter()
            .fold(reqwest::multipart::Form::new(), |form, (name, value)| {
                form.text((*name).to_string(), value.clone())
            });
        let size = data.len() as u64;
        let file_part = reqwest::multipart::Part::stream_with_length(data.clone(), size)
            .file_name(filename.to_string());
        let form = text_only.part("file", file_part);
        let response = builder.multipart(form).send().await.context(HttpSnafu)?;
        Self::deserialize(&Method::POST, path, response).await
    })
    .await
}
/// Sends a `GET` to `path` and only checks the status, treating `302 Found`
/// as success in addition to the success codes.
///
/// `extra_headers` are added after the headers set by `build_request`; the
/// exchange runs under `with_retry`.
#[allow(dead_code)]
#[tracing::instrument(skip(self), fields(request_id = tracing::field::Empty))]
pub(crate) async fn get_void(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
) -> Result<(), NifiError> {
    self.with_retry(|| async {
        let url = self.api_url(path);
        let builder = self
            .build_request(&Method::GET, path, self.http.get(url))
            .await;
        let response = apply_extra_headers(builder, extra_headers)
            .send()
            .await
            .context(HttpSnafu)?;
        Self::check_void_with_redirect(&Method::GET, path, response).await
    })
    .await
}
/// Sends a `GET` to `path` with `query` parameters and only checks the
/// status, treating `302 Found` as success in addition to the success codes.
///
/// `extra_headers` are added after the headers set by `build_request`; the
/// exchange runs under `with_retry`.
#[allow(dead_code)]
#[tracing::instrument(skip(self, query), fields(request_id = tracing::field::Empty))]
pub(crate) async fn get_void_with_query(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
    query: &[(&str, String)],
) -> Result<(), NifiError> {
    self.with_retry(|| async {
        let base = self.http.get(self.api_url(path)).query(query);
        let builder = self.build_request(&Method::GET, path, base).await;
        let response = apply_extra_headers(builder, extra_headers)
            .send()
            .await
            .context(HttpSnafu)?;
        Self::check_void_with_redirect(&Method::GET, path, response).await
    })
    .await
}
/// Sends a `GET` to `path` with `query` parameters and decodes the JSON
/// response body as `T`.
///
/// `extra_headers` are added after the headers set by `build_request`; the
/// exchange runs under `with_retry`.
#[tracing::instrument(skip(self, query), fields(request_id = tracing::field::Empty))]
pub(crate) async fn get_with_query<T: DeserializeOwned>(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
    query: &[(&str, String)],
) -> Result<T, NifiError> {
    self.with_retry(|| async {
        let base = self.http.get(self.api_url(path)).query(query);
        let builder = self.build_request(&Method::GET, path, base).await;
        let response = apply_extra_headers(builder, extra_headers)
            .send()
            .await
            .context(HttpSnafu)?;
        Self::deserialize(&Method::GET, path, response).await
    })
    .await
}
/// Sends a `GET` to `path` and returns the response body as text.
///
/// `extra_headers` are added after the headers set by `build_request`; the
/// exchange runs under `with_retry`.
#[tracing::instrument(skip(self), fields(request_id = tracing::field::Empty))]
pub(crate) async fn get_text(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
) -> Result<String, NifiError> {
    self.with_retry(|| async {
        let url = self.api_url(path);
        let builder = self
            .build_request(&Method::GET, path, self.http.get(url))
            .await;
        let response = apply_extra_headers(builder, extra_headers)
            .send()
            .await
            .context(HttpSnafu)?;
        Self::text(&Method::GET, path, response).await
    })
    .await
}
/// Sends a `GET` to `path` and returns the whole response body as bytes.
///
/// `extra_headers` are added after the headers set by `build_request`; the
/// exchange runs under `with_retry`.
#[tracing::instrument(skip(self), fields(request_id = tracing::field::Empty))]
pub(crate) async fn get_bytes(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
) -> Result<Vec<u8>, NifiError> {
    self.with_retry(|| async {
        let url = self.api_url(path);
        let builder = self
            .build_request(&Method::GET, path, self.http.get(url))
            .await;
        let response = apply_extra_headers(builder, extra_headers)
            .send()
            .await
            .context(HttpSnafu)?;
        Self::bytes(&Method::GET, path, response).await
    })
    .await
}
/// Sends a `GET` to `path` with `query` parameters and returns the whole
/// response body as bytes.
///
/// `extra_headers` are added after the headers set by `build_request`; the
/// exchange runs under `with_retry`.
#[tracing::instrument(skip(self, query), fields(request_id = tracing::field::Empty))]
pub(crate) async fn get_bytes_with_query(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
    query: &[(&str, String)],
) -> Result<Vec<u8>, NifiError> {
    self.with_retry(|| async {
        let base = self.http.get(self.api_url(path)).query(query);
        let builder = self.build_request(&Method::GET, path, base).await;
        let response = apply_extra_headers(builder, extra_headers)
            .send()
            .await
            .context(HttpSnafu)?;
        Self::bytes(&Method::GET, path, response).await
    })
    .await
}
/// Sends a `GET` to `path` and returns the response body as a byte stream.
///
/// Only obtaining the response and checking its status happen under
/// `with_retry`; the returned stream is handed to the caller afterwards.
#[tracing::instrument(skip(self), fields(request_id = tracing::field::Empty))]
pub(crate) async fn get_bytes_stream(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
) -> Result<crate::BytesStream, NifiError> {
    self.with_retry(|| async {
        let url = self.api_url(path);
        let builder = self
            .build_request(&Method::GET, path, self.http.get(url))
            .await;
        let response = apply_extra_headers(builder, extra_headers)
            .send()
            .await
            .context(HttpSnafu)?;
        Self::bytes_stream(&Method::GET, path, response).await
    })
    .await
}
/// Sends a `GET` to `path` with `query` parameters and returns the response
/// body as a byte stream.
///
/// Only obtaining the response and checking its status happen under
/// `with_retry`; the returned stream is handed to the caller afterwards.
#[tracing::instrument(skip(self, query), fields(request_id = tracing::field::Empty))]
pub(crate) async fn get_bytes_stream_with_query(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
    query: &[(&str, String)],
) -> Result<crate::BytesStream, NifiError> {
    self.with_retry(|| async {
        let base = self.http.get(self.api_url(path)).query(query);
        let builder = self.build_request(&Method::GET, path, base).await;
        let response = apply_extra_headers(builder, extra_headers)
            .send()
            .await
            .context(HttpSnafu)?;
        Self::bytes_stream(&Method::GET, path, response).await
    })
    .await
}
/// Sends `data` as an `application/octet-stream` `POST` to `path` and only
/// checks the response status.
///
/// `data` is cloned for every attempt made by `with_retry`. `extra_headers`
/// are added after the content type and body have been set.
#[allow(dead_code)]
#[tracing::instrument(skip(self, data), fields(request_id = tracing::field::Empty))]
pub(crate) async fn post_void_octet_stream(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
    data: bytes::Bytes,
) -> Result<(), NifiError> {
    self.with_retry(|| async {
        let url = self.api_url(path);
        let authed = self
            .build_request(&Method::POST, path, self.http.post(url))
            .await;
        let with_payload = authed
            .header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
            .body(data.clone());
        let response = apply_extra_headers(with_payload, extra_headers)
            .send()
            .await
            .context(HttpSnafu)?;
        Self::check_void(&Method::POST, path, response).await
    })
    .await
}
/// Uploads `data` as the `file` part of a multipart `POST` to `path` and
/// only checks the response status.
///
/// The part is named with `filename`. `data` is cloned and the form is
/// rebuilt for every attempt made by `with_retry`.
#[allow(dead_code)]
#[tracing::instrument(skip(self, data), fields(request_id = tracing::field::Empty))]
pub(crate) async fn post_void_multipart(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
    filename: &str,
    data: bytes::Bytes,
) -> Result<(), NifiError> {
    self.with_retry(|| async {
        let url = self.api_url(path);
        let authed = self
            .build_request(&Method::POST, path, self.http.post(url))
            .await;
        let builder = apply_extra_headers(authed, extra_headers);
        let size = data.len() as u64;
        let file_part = reqwest::multipart::Part::stream_with_length(data.clone(), size)
            .file_name(filename.to_string());
        let form = reqwest::multipart::Form::new().part("file", file_part);
        let response = builder.multipart(form).send().await.context(HttpSnafu)?;
        Self::check_void(&Method::POST, path, response).await
    })
    .await
}
/// Sends a `POST` to `path` with `body` serialized as JSON and returns the
/// response body as text.
///
/// `extra_headers` are added after the headers set by `build_request`; the
/// exchange runs under `with_retry`.
#[allow(dead_code)]
#[tracing::instrument(skip(self, body), fields(request_id = tracing::field::Empty))]
pub(crate) async fn post_returning_text<B: serde::Serialize>(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
    body: &B,
) -> Result<String, NifiError> {
    self.with_retry(|| async {
        let url = self.api_url(path);
        let builder = self
            .build_request(&Method::POST, path, self.http.post(url))
            .await;
        let response = apply_extra_headers(builder, extra_headers)
            .json(body)
            .send()
            .await
            .context(HttpSnafu)?;
        Self::text(&Method::POST, path, response).await
    })
    .await
}
/// Sends `data` as an `application/octet-stream` `POST` to `path` and
/// returns the response body as text.
///
/// `data` is cloned for every attempt made by `with_retry`. `extra_headers`
/// are added after the content type and body have been set.
#[allow(dead_code)]
#[tracing::instrument(skip(self, data), fields(request_id = tracing::field::Empty))]
pub(crate) async fn post_octet_stream_returning_text(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
    data: bytes::Bytes,
) -> Result<String, NifiError> {
    self.with_retry(|| async {
        let url = self.api_url(path);
        let authed = self
            .build_request(&Method::POST, path, self.http.post(url))
            .await;
        let with_payload = authed
            .header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
            .body(data.clone());
        let response = apply_extra_headers(with_payload, extra_headers)
            .send()
            .await
            .context(HttpSnafu)?;
        Self::text(&Method::POST, path, response).await
    })
    .await
}
/// Sends a `DELETE` to `path` with `query` parameters and decodes the JSON
/// response body as `T`.
///
/// `extra_headers` are added after the headers set by `build_request`; the
/// exchange runs under `with_retry`.
#[tracing::instrument(skip(self, query), fields(request_id = tracing::field::Empty))]
pub(crate) async fn delete_returning_with_query<T: DeserializeOwned>(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
    query: &[(&str, String)],
) -> Result<T, NifiError> {
    self.with_retry(|| async {
        let base = self.http.delete(self.api_url(path)).query(query);
        let builder = self.build_request(&Method::DELETE, path, base).await;
        let response = apply_extra_headers(builder, extra_headers)
            .send()
            .await
            .context(HttpSnafu)?;
        Self::deserialize(&Method::DELETE, path, response).await
    })
    .await
}
/// Sends a `DELETE` to `path` with `query` parameters and only checks the
/// response status.
///
/// `extra_headers` are added after the headers set by `build_request`; the
/// exchange runs under `with_retry`.
#[tracing::instrument(skip(self, query), fields(request_id = tracing::field::Empty))]
pub(crate) async fn delete_with_query(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
    query: &[(&str, String)],
) -> Result<(), NifiError> {
    self.with_retry(|| async {
        let base = self.http.delete(self.api_url(path)).query(query);
        let builder = self.build_request(&Method::DELETE, path, base).await;
        let response = apply_extra_headers(builder, extra_headers)
            .send()
            .await
            .context(HttpSnafu)?;
        Self::check_void(&Method::DELETE, path, response).await
    })
    .await
}
/// Sends a `POST` to `path` with `query` parameters and `body` serialized
/// as JSON, and only checks the response status.
///
/// `extra_headers` are added after the headers set by `build_request`; the
/// exchange runs under `with_retry`.
#[allow(dead_code)]
#[tracing::instrument(skip(self, body, query), fields(request_id = tracing::field::Empty))]
pub(crate) async fn post_void_with_query<B: serde::Serialize>(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
    body: &B,
    query: &[(&str, String)],
) -> Result<(), NifiError> {
    self.with_retry(|| async {
        let base = self.http.post(self.api_url(path)).query(query);
        let builder = self.build_request(&Method::POST, path, base).await;
        let response = apply_extra_headers(builder, extra_headers)
            .json(body)
            .send()
            .await
            .context(HttpSnafu)?;
        Self::check_void(&Method::POST, path, response).await
    })
    .await
}
/// Sends a `POST` to `path` with `query` parameters and `body` serialized
/// as JSON, and decodes the JSON response body as `T`.
///
/// `extra_headers` are added after the headers set by `build_request`; the
/// exchange runs under `with_retry`.
#[tracing::instrument(skip(self, body, query), fields(request_id = tracing::field::Empty))]
pub(crate) async fn post_with_query<B, T>(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
    body: &B,
    query: &[(&str, String)],
) -> Result<T, NifiError>
where
    B: serde::Serialize,
    T: DeserializeOwned,
{
    self.with_retry(|| async {
        let base = self.http.post(self.api_url(path)).query(query);
        let builder = self.build_request(&Method::POST, path, base).await;
        let response = apply_extra_headers(builder, extra_headers)
            .json(body)
            .send()
            .await
            .context(HttpSnafu)?;
        Self::deserialize(&Method::POST, path, response).await
    })
    .await
}
/// Sends a `PUT` to `path` with `query` parameters and `body` serialized as
/// JSON, and decodes the JSON response body as `T`.
///
/// `extra_headers` are added after the headers set by `build_request`; the
/// exchange runs under `with_retry`.
#[allow(dead_code)]
#[tracing::instrument(skip(self, body, query), fields(request_id = tracing::field::Empty))]
pub(crate) async fn put_with_query<B, T>(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
    body: &B,
    query: &[(&str, String)],
) -> Result<T, NifiError>
where
    B: serde::Serialize,
    T: DeserializeOwned,
{
    self.with_retry(|| async {
        let base = self.http.put(self.api_url(path)).query(query);
        let builder = self.build_request(&Method::PUT, path, base).await;
        let response = apply_extra_headers(builder, extra_headers)
            .json(body)
            .send()
            .await
            .context(HttpSnafu)?;
        Self::deserialize(&Method::PUT, path, response).await
    })
    .await
}
/// Sends a `DELETE` to `path` and decodes the JSON response body as `T`.
///
/// `extra_headers` are added after the headers set by `build_request`; the
/// exchange runs under `with_retry`.
#[tracing::instrument(skip(self), fields(request_id = tracing::field::Empty))]
pub(crate) async fn delete_returning<T: DeserializeOwned>(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
) -> Result<T, NifiError> {
    self.with_retry(|| async {
        let url = self.api_url(path);
        let builder = self
            .build_request(&Method::DELETE, path, self.http.delete(url))
            .await;
        let response = apply_extra_headers(builder, extra_headers)
            .send()
            .await
            .context(HttpSnafu)?;
        Self::deserialize(&Method::DELETE, path, response).await
    })
    .await
}
/// Sends a `DELETE` to `path` and only checks the response status.
///
/// `extra_headers` are added after the headers set by `build_request`; the
/// exchange runs under `with_retry`.
#[tracing::instrument(skip(self), fields(request_id = tracing::field::Empty))]
pub(crate) async fn delete(
    &self,
    path: &str,
    extra_headers: &[(&str, &str)],
) -> Result<(), NifiError> {
    self.with_retry(|| async {
        let url = self.api_url(path);
        let builder = self
            .build_request(&Method::DELETE, path, self.http.delete(url))
            .await;
        let response = apply_extra_headers(builder, extra_headers)
            .send()
            .await
            .context(HttpSnafu)?;
        Self::check_void(&Method::DELETE, path, response).await
    })
    .await
}
/// Sends a single `DELETE` to `path` and checks the response status.
///
/// Unlike `delete`, this does not go through `with_retry`, so there is no
/// retry and no re-authentication.
async fn delete_inner(&self, path: &str) -> Result<(), NifiError> {
    let url = self.api_url(path);
    let builder = self
        .build_request(&Method::DELETE, path, self.http.delete(url))
        .await;
    let response = builder.send().await.context(HttpSnafu)?;
    Self::check_void(&Method::DELETE, path, response).await
}
/// Attaches a freshly generated UUID v4 to `req` under the configured
/// request-id header and records it on the current span's `request_id`
/// field. Returns `req` untouched when no header name is configured.
fn apply_request_id(&self, req: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
    match self.request_id_header.as_deref() {
        None => req,
        Some(header_name) => {
            let request_id = uuid::Uuid::new_v4().to_string();
            tracing::Span::current().record("request_id", request_id.as_str());
            req.header(header_name, request_id)
        }
    }
}
/// Decorates `req` with the headers every API call carries.
///
/// In order: the request-id header (if configured), bearer authentication
/// from the stored token (a warning is logged when there is none), and the
/// proxied-entities-chain header (if configured). `method` and `path` are
/// used for logging only.
async fn build_request(
    &self,
    method: &Method,
    path: &str,
    req: reqwest::RequestBuilder,
) -> reqwest::RequestBuilder {
    let tagged = self.apply_request_id(req);
    tracing::debug!(method = %method, path, "NiFi API request");
    let token_guard = self.token.read().await;
    let authed = if let Some(token) = token_guard.as_deref() {
        tagged.bearer_auth(token.as_str())
    } else {
        tracing::warn!(
            "sending NiFi API request without a bearer token — call login() first"
        );
        tagged
    };
    drop(token_guard);
    match &self.proxied_entities_chain {
        Some(chain) => authed.header(PROXIED_ENTITIES_CHAIN, chain),
        None => authed,
    }
}
/// Checks the status of `resp` and decodes its body as JSON into `T`.
///
/// A non-success status is turned into an error by `handle_response_status`;
/// a body that cannot be read or decoded is wrapped through `HttpSnafu`.
async fn deserialize<T: DeserializeOwned>(
    method: &Method,
    path: &str,
    resp: reqwest::Response,
) -> Result<T, NifiError> {
    let checked = handle_response_status(method, path, resp).await?;
    let parsed = checked.json::<T>().await.context(HttpSnafu)?;
    Ok(parsed)
}
/// Checks the status of `resp` and discards the response.
///
/// A non-success status is turned into an error by `handle_response_status`.
async fn check_void(
    method: &Method,
    path: &str,
    resp: reqwest::Response,
) -> Result<(), NifiError> {
    handle_response_status(method, path, resp)
        .await
        .map(|_| ())
}
/// Checks the status of `resp` and returns its body as text.
///
/// A non-success status is turned into an error by `handle_response_status`;
/// a body that cannot be read is wrapped through `HttpSnafu`.
async fn text(
    method: &Method,
    path: &str,
    resp: reqwest::Response,
) -> Result<String, NifiError> {
    let checked = handle_response_status(method, path, resp).await?;
    let content = checked.text().await.context(HttpSnafu)?;
    Ok(content)
}
/// Checks the status of `resp` and returns its whole body as a byte vector.
///
/// A non-success status is turned into an error by `handle_response_status`;
/// a body that cannot be read is wrapped through `HttpSnafu`.
async fn bytes(
    method: &Method,
    path: &str,
    resp: reqwest::Response,
) -> Result<Vec<u8>, NifiError> {
    let resp = handle_response_status(method, path, resp).await?;
    let body = resp.bytes().await.context(HttpSnafu)?;
    // `Vec::from(Bytes)` can take over the existing allocation when the
    // buffer is uniquely owned, whereas `to_vec()` always copies the body.
    Ok(Vec::from(body))
}
/// Checks the status of `resp` and returns its body as a pinned, boxed
/// stream whose transport errors are mapped to `NifiError::Http`.
///
/// A non-success status is turned into an error by `handle_response_status`
/// before any stream is created.
async fn bytes_stream(
    method: &Method,
    path: &str,
    resp: reqwest::Response,
) -> Result<crate::BytesStream, NifiError> {
    use futures_util::TryStreamExt;
    let checked = handle_response_status(method, path, resp).await?;
    let mapped = TryStreamExt::map_err(checked.bytes_stream(), |source| NifiError::Http {
        source,
    });
    Ok(Box::pin(mapped))
}
async fn check_void_with_redirect(
method: &Method,
path: &str,
resp: reqwest::Response,
) -> Result<(), NifiError> {
let status = resp.status();
tracing::debug!(method = %method, path, status = status.as_u16(), "NiFi API response");
if status.is_success() || status == StatusCode::FOUND {
return Ok(());
}
let body = resp.text().await.unwrap_or_else(|_| status.to_string());
tracing::debug!(method = %method, path, status = status.as_u16(), %body, "NiFi API raw error body");
let message = extract_error_message(&body);
tracing::warn!(method = %method, path, status = status.as_u16(), %message, "NiFi API error");
Err(crate::error::api_error(status.as_u16(), message))
}
/// Builds the full URL for an API `path` by replacing the path of the base
/// URL with `/nifi-api` followed by `path`.
pub(crate) fn api_url(&self, path: &str) -> Url {
    let api_path = format!("/nifi-api{path}");
    let mut target = self.base_url.clone();
    target.set_path(&api_path);
    target
}
}
/// Passes `resp` through when its status is a success and converts it into
/// an error otherwise.
///
/// On failure the body is read (falling back to the status text if that
/// fails), logged, reduced with `extract_error_message` and returned as the
/// error built by `crate::error::api_error`. `method` and `path` are used
/// for logging only.
async fn handle_response_status(
    method: &Method,
    path: &str,
    resp: reqwest::Response,
) -> Result<reqwest::Response, NifiError> {
    let status = resp.status();
    tracing::debug!(method = %method, path, status = status.as_u16(), "NiFi API response");
    if status.is_success() {
        Ok(resp)
    } else {
        let body = resp.text().await.unwrap_or_else(|_| status.to_string());
        tracing::debug!(method = %method, path, status = status.as_u16(), %body, "NiFi API raw error body");
        let message = extract_error_message(&body);
        tracing::warn!(method = %method, path, status = status.as_u16(), %message, "NiFi API error");
        Err(crate::error::api_error(status.as_u16(), message))
    }
}
/// Adds every `(name, value)` pair in `extra` to `req` as a header, in order.
fn apply_extra_headers(
    req: reqwest::RequestBuilder,
    extra: &[(&str, &str)],
) -> reqwest::RequestBuilder {
    extra
        .iter()
        .fold(req, |builder, (name, value)| builder.header(*name, *value))
}
/// Extracts a human-readable message from an error response body.
///
/// When `body` parses as JSON and has a string `message` member, that
/// string is returned; in every other case `body` is returned unchanged.
pub fn extract_error_message(body: &str) -> String {
    let parsed: Option<serde_json::Value> = serde_json::from_str(body).ok();
    match parsed.as_ref().and_then(|value| value["message"].as_str()) {
        Some(message) => message.to_owned(),
        None => body.to_owned(),
    }
}
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    /// Cloning a `Bytes` must hand out a view of the same buffer, which is
    /// what makes the per-attempt `data.clone()` in the upload helpers cheap.
    #[test]
    fn bytes_clone_is_refcount_only() {
        let data = Bytes::from(vec![0u8; 1024]);
        let before = data.len();
        for copy in [data.clone(), data.clone()] {
            assert_eq!(copy.len(), before);
            assert_eq!(
                data.as_ptr(),
                copy.as_ptr(),
                "Bytes::clone should share buffer"
            );
        }
    }
}