use crate::client;
use crate::http::HttpClientCache;
use std::sync::Arc;
/// Bulk API client produced by [`ClientBuilder::build`].
///
/// The authentication client and the HTTP client cache are held behind `Arc`,
/// so cloning a `Client` shares them; the API version `String` is copied.
#[derive(Clone, Debug)]
pub struct Client {
    /// Authenticated client; its `instance_url` is the host for request URLs.
    auth_client: Arc<client::Client>,
    /// API version appended after `/services/data/v` when building the base URL.
    api_version: String,
    /// Connect timeout passed to the HTTP client cache.
    connect_timeout: std::time::Duration,
    /// Request timeout passed to the HTTP client cache.
    request_timeout: std::time::Duration,
    /// Cache from which `reqwest::Client` instances are obtained; shared by clones.
    http_cache: Arc<HttpClientCache>,
}
/// Errors that can occur while building a Bulk API [`Client`].
///
/// Marked `#[non_exhaustive]`, so callers matching on it must include a
/// wildcard arm; new variants may be added without a breaking change.
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
pub enum Error {
    /// The Bulk API client could not be constructed.
    ///
    /// NOTE(review): `ClientBuilder::build` as written always returns `Ok`,
    /// so this variant appears to be reserved for future validation.
    #[error("Failed to build Bulk API client")]
    Build,
}
/// Builder for a Bulk API [`Client`].
///
/// Every optional setting left as `None` is replaced by a crate-level default
/// when [`ClientBuilder::build`] is called.
#[derive(Debug)]
pub struct ClientBuilder {
    /// Authentication client moved into the built `Client` (wrapped in `Arc`).
    auth_client: client::Client,
    /// Explicit API version; `None` means `crate::DEFAULT_API_VERSION`.
    api_version: Option<String>,
    /// Explicit connect timeout; `None` means `crate::DEFAULT_CONNECT_TIMEOUT_SECS`.
    connect_timeout: Option<std::time::Duration>,
    /// Explicit request timeout; `None` means `crate::DEFAULT_REQUEST_TIMEOUT_SECS`.
    request_timeout: Option<std::time::Duration>,
}
impl ClientBuilder {
    /// Starts a builder around an authentication client.
    ///
    /// All optional settings start unset; [`ClientBuilder::build`] fills any
    /// that remain unset with crate defaults.
    #[must_use]
    #[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
    pub fn new(auth_client: client::Client) -> Self {
        Self {
            auth_client,
            api_version: None,
            connect_timeout: None,
            request_timeout: None,
        }
    }

    /// Overrides the API version used in request URLs (e.g. `"59.0"`).
    ///
    /// The value is inserted after `/services/data/v`, so pass it without a
    /// leading `v`.
    #[must_use]
    pub fn api_version(mut self, version: impl Into<String>) -> Self {
        self.api_version = Some(version.into());
        self
    }

    /// Overrides the connect timeout handed to the HTTP client cache.
    ///
    /// Defaults to `crate::DEFAULT_CONNECT_TIMEOUT_SECS` seconds when unset.
    #[must_use]
    pub fn connect_timeout(mut self, timeout: std::time::Duration) -> Self {
        self.connect_timeout = Some(timeout);
        self
    }

    /// Overrides the request timeout handed to the HTTP client cache.
    ///
    /// Defaults to `crate::DEFAULT_REQUEST_TIMEOUT_SECS` seconds when unset.
    #[must_use]
    pub fn request_timeout(mut self, timeout: std::time::Duration) -> Self {
        self.request_timeout = Some(timeout);
        self
    }

    /// Consumes the builder and produces a [`Client`].
    ///
    /// Unset options are resolved to crate defaults. A fresh HTTP client cache
    /// is created here; clones of the returned `Client` share it.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` leaves room for future validation
    /// reported via [`Error::Build`].
    pub fn build(self) -> Result<Client, Error> {
        use std::time::Duration;

        let api_version = self
            .api_version
            .unwrap_or_else(|| crate::DEFAULT_API_VERSION.to_string());
        // `Duration::from_secs` is a cheap const constructor, so eager
        // `unwrap_or` is fine here.
        let connect_timeout = self
            .connect_timeout
            .unwrap_or(Duration::from_secs(crate::DEFAULT_CONNECT_TIMEOUT_SECS));
        let request_timeout = self
            .request_timeout
            .unwrap_or(Duration::from_secs(crate::DEFAULT_REQUEST_TIMEOUT_SECS));

        Ok(Client {
            auth_client: Arc::new(self.auth_client),
            api_version,
            connect_timeout,
            request_timeout,
            http_cache: Arc::new(HttpClientCache::new()),
        })
    }
}
impl Client {
    /// Returns the authentication client this Bulk API client wraps.
    #[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
    pub fn auth_client(&self) -> &client::Client {
        &self.auth_client
    }

    /// Returns the API version inserted after `/services/data/v` in request URLs.
    #[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
    pub fn api_version(&self) -> &str {
        &self.api_version
    }

    /// Connect timeout resolved at build time; passed to the HTTP client cache.
    #[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
    pub(crate) fn connect_timeout(&self) -> std::time::Duration {
        self.connect_timeout
    }

    /// Request timeout resolved at build time; passed to the HTTP client cache.
    #[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
    pub(crate) fn request_timeout(&self) -> std::time::Duration {
        self.request_timeout
    }

    /// Returns a query client built from a clone of this client.
    ///
    /// The clone shares the `Arc`-held auth client and HTTP client cache.
    #[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
    pub fn query(&self) -> super::query::QueryClient {
        super::query::QueryClient::new(self.clone())
    }

    /// Returns an ingest client built from a clone of this client.
    ///
    /// The clone shares the `Arc`-held auth client and HTTP client cache.
    #[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
    pub fn ingest(&self) -> super::ingest::IngestClient {
        super::ingest::IngestClient::new(self.clone())
    }

    /// Obtains a `reqwest::Client` from the HTTP client cache.
    ///
    /// The cache is given the auth client and both configured timeouts.
    ///
    /// # Errors
    ///
    /// Propagates any [`crate::http::Error`] returned by the cache.
    #[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
    pub(crate) async fn get_http_client(&self) -> Result<reqwest::Client, crate::http::Error> {
        self.http_cache
            .get(
                self.auth_client.as_ref(),
                self.connect_timeout(),
                self.request_timeout(),
            )
            .await
    }

    /// Builds the REST data base URL: `{instance_url}/services/data/v{api_version}`.
    ///
    /// Trailing slashes on the instance URL are dropped, so a configured
    /// `https://host/` does not yield `https://host//services/...`.
    ///
    /// # Errors
    ///
    /// Returns [`client::Error::NotConnected`] when the auth client has no
    /// instance URL.
    #[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
    pub(crate) fn base_url(&self) -> Result<String, client::Error> {
        let instance_url = self
            .auth_client
            .instance_url
            .as_deref()
            .ok_or(client::Error::NotConnected)?;
        Ok(format!(
            "{}/services/data/v{}",
            instance_url.trim_end_matches('/'),
            self.api_version
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use oauth2::basic::BasicTokenResponse;
    use oauth2::{AccessToken, EmptyExtraTokenFields};

    /// Instance URL shared by every fixture in this module.
    const TEST_INSTANCE_URL: &str = "https://test.salesforce.com";

    /// Credentials used to build every auth client in these tests.
    fn test_credentials() -> client::Credentials {
        client::Credentials {
            client_id: "test_client_id".to_string(),
            client_secret: Some("test_secret".to_string()),
            username: None,
            password: None,
            instance_url: TEST_INSTANCE_URL.to_string(),
            tenant_id: "test_tenant".to_string(),
        }
    }

    /// Builds an auth client from the shared credentials; no token is attached here.
    fn unauthenticated_client() -> client::Client {
        client::Builder::new()
            .credentials(test_credentials())
            .build()
            .unwrap()
    }

    /// Auth client with a bearer token, instance URL and tenant id filled in.
    fn create_mock_auth_client() -> client::Client {
        let mut auth = unauthenticated_client();
        let bearer = BasicTokenResponse::new(
            AccessToken::new("test_access_token".to_string()),
            oauth2::basic::BasicTokenType::Bearer,
            EmptyExtraTokenFields {},
        );
        let state = client::TokenState::new(bearer).unwrap();
        auth.token_state = Some(Arc::new(std::sync::RwLock::new(state)));
        auth.instance_url = Some(TEST_INSTANCE_URL.to_string());
        auth.tenant_id = Some("test_tenant".to_string());
        auth
    }

    /// Base URL expected when no explicit API version is configured.
    fn default_base_url() -> String {
        format!(
            "https://test.salesforce.com/services/data/v{}",
            crate::DEFAULT_API_VERSION
        )
    }

    #[test]
    fn test_base_url_construction() {
        let bulk = ClientBuilder::new(create_mock_auth_client())
            .build()
            .unwrap();
        assert_eq!(bulk.base_url().unwrap(), default_base_url());
    }

    #[test]
    fn test_base_url_with_different_versions() {
        let auth = create_mock_auth_client();

        let with_default = ClientBuilder::new(auth.clone()).build().unwrap();
        assert_eq!(with_default.base_url().unwrap(), default_base_url());

        let pinned = ClientBuilder::new(auth)
            .api_version("59.0")
            .build()
            .unwrap();
        assert_eq!(
            pinned.base_url().unwrap(),
            "https://test.salesforce.com/services/data/v59.0"
        );
    }

    #[test]
    fn test_base_url_without_instance_url() {
        let mut auth = unauthenticated_client();
        auth.instance_url = None;
        let bulk = ClientBuilder::new(auth)
            .api_version("58.0")
            .build()
            .unwrap();
        assert!(matches!(
            bulk.base_url(),
            Err(client::Error::NotConnected)
        ));
    }
}