use crate::client;
use salesforce_core_pubsubapi::eventbus::v1::pub_sub_client::PubSubClient;
use tokio_stream::StreamExt;
/// Errors produced by the Pub/Sub [`Client`].
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
pub enum Error {
/// NOTE(review): not constructed by any code visible in this file; its
/// purpose is inferred only from the message — confirm it is still needed.
#[error("Client missing")]
MissingClient,
/// The auth client could not provide a current access token (any error from
/// `current_access_token` is collapsed into this variant).
#[error("Token response missing")]
MissingTokenResponse,
/// The auth client's `instance_url` is `None`.
#[error("Instance URL not available: call connect() on the auth client first")]
MissingInstanceUrl,
/// The auth client's `tenant_id` is `None`.
#[error("Tenant ID not available: call connect() on the auth client first")]
MissingTenantId,
/// The access token, instance URL or tenant ID could not be parsed as an
/// ASCII gRPC metadata value (e.g. it contains a newline).
#[error("Invalid metadata value for gRPC headers: {source}")]
InvalidMetadataValue {
/// The underlying metadata parse error.
#[source]
source: tonic::metadata::errors::InvalidMetadataValue,
},
/// A gRPC call returned an error status. `Client::reconnect` also reports
/// auth-client reconnect failures through this variant, wrapped in an
/// `Unauthenticated` status.
#[error("gRPC transport error: {0}")]
Tonic(Box<tonic::Status>),
}
/// gRPC interceptor that stamps the Salesforce context headers onto every
/// outgoing Pub/Sub request.
///
/// The values are parsed into metadata values before the interceptor is
/// built, so attaching them per request cannot fail.
struct ContextInterceptor {
/// Sent as the `accesstoken` metadata entry.
auth_header: tonic::metadata::AsciiMetadataValue,
/// Sent as the `instanceurl` metadata entry.
instance_url: tonic::metadata::AsciiMetadataValue,
/// Sent as the `tenantid` metadata entry.
tenant_id: tonic::metadata::AsciiMetadataValue,
}
impl tonic::service::Interceptor for ContextInterceptor {
    /// Copies the precomputed Salesforce context values into the request's
    /// metadata under the keys `accesstoken`, `instanceurl` and `tenantid`.
    ///
    /// Always returns `Ok`: the values were validated when the interceptor was
    /// constructed.
    fn call(
        &mut self,
        mut request: tonic::Request<()>,
    ) -> Result<tonic::Request<()>, tonic::Status> {
        // Same insertion order as the wire contract lists them.
        let entries: [(&'static str, &tonic::metadata::AsciiMetadataValue); 3] = [
            ("accesstoken", &self.auth_header),
            ("instanceurl", &self.instance_url),
            ("tenantid", &self.tenant_id),
        ];
        let metadata = request.metadata_mut();
        for (key, value) in entries {
            metadata.insert(key, value.clone());
        }
        Ok(request)
    }
}
/// Salesforce Pub/Sub API client.
///
/// Wraps the generated gRPC `PubSubClient` so that every request carries the
/// `accesstoken`, `instanceurl` and `tenantid` metadata read from the auth
/// [`client::Client`] at construction (or at the last `reconnect`).
#[derive(Debug)]
pub struct Client {
/// Generated gRPC stub whose interceptor adds the Salesforce context headers.
pubsub: salesforce_core_pubsubapi::eventbus::v1::pub_sub_client::PubSubClient<
tonic::service::interceptor::InterceptedService<
tonic::transport::Channel,
ContextInterceptor,
>,
>,
/// Underlying transport, kept so `reconnect` can rebuild `pubsub` on the same channel.
channel: tonic::transport::Channel,
/// Auth client that supplies the access token, instance URL and tenant ID.
client: client::Client,
}
impl Client {
#[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
pub fn new(channel: tonic::transport::Channel, client: client::Client) -> Result<Self, Error> {
let token = client
.current_access_token()
.map_err(|_| Error::MissingTokenResponse)?;
let auth_header: tonic::metadata::AsciiMetadataValue = token
.parse()
.map_err(|e| Error::InvalidMetadataValue { source: e })?;
let instance_url: tonic::metadata::AsciiMetadataValue = client
.instance_url
.as_ref()
.ok_or(Error::MissingInstanceUrl)?
.parse()
.map_err(|e| Error::InvalidMetadataValue { source: e })?;
let tenant_id: tonic::metadata::AsciiMetadataValue = client
.tenant_id
.as_ref()
.ok_or(Error::MissingTenantId)?
.parse()
.map_err(|e| Error::InvalidMetadataValue { source: e })?;
let interceptor = ContextInterceptor {
auth_header,
instance_url,
tenant_id,
};
let pubsub = PubSubClient::with_interceptor(channel.clone(), interceptor);
Ok(Client {
pubsub,
channel,
client,
})
}
#[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
pub async fn reconnect(&mut self) -> Result<(), Error> {
self.client
.reconnect()
.await
.map_err(|e| Error::Tonic(Box::new(tonic::Status::unauthenticated(e.to_string()))))?;
let token = self
.client
.current_access_token()
.map_err(|_| Error::MissingTokenResponse)?;
let auth_header: tonic::metadata::AsciiMetadataValue = token
.parse()
.map_err(|e| Error::InvalidMetadataValue { source: e })?;
let instance_url: tonic::metadata::AsciiMetadataValue = self
.client
.instance_url
.as_ref()
.ok_or(Error::MissingInstanceUrl)?
.parse()
.map_err(|e| Error::InvalidMetadataValue { source: e })?;
let tenant_id: tonic::metadata::AsciiMetadataValue = self
.client
.tenant_id
.as_ref()
.ok_or(Error::MissingTenantId)?
.parse()
.map_err(|e| Error::InvalidMetadataValue { source: e })?;
let interceptor = ContextInterceptor {
auth_header,
instance_url,
tenant_id,
};
self.pubsub = PubSubClient::with_interceptor(self.channel.clone(), interceptor);
Ok(())
}
#[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
pub async fn get_topic(
&mut self,
request: salesforce_core_pubsubapi::eventbus::v1::TopicRequest,
) -> Result<tonic::Response<salesforce_core_pubsubapi::eventbus::v1::TopicInfo>, Error> {
self.pubsub
.get_topic(tonic::Request::new(request))
.await
.map_err(|e| Error::Tonic(Box::new(e)))
}
#[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
pub async fn get_schema(
&mut self,
request: salesforce_core_pubsubapi::eventbus::v1::SchemaRequest,
) -> Result<tonic::Response<salesforce_core_pubsubapi::eventbus::v1::SchemaInfo>, Error> {
self.pubsub
.get_schema(tonic::Request::new(request))
.await
.map_err(|e| Error::Tonic(Box::new(e)))
}
#[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
pub async fn publish(
&mut self,
request: salesforce_core_pubsubapi::eventbus::v1::PublishRequest,
) -> Result<tonic::Response<salesforce_core_pubsubapi::eventbus::v1::PublishResponse>, Error>
{
self.pubsub
.publish(tonic::Request::new(request))
.await
.map_err(|e| Error::Tonic(Box::new(e)))
}
#[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
pub async fn subscribe(
&mut self,
request: salesforce_core_pubsubapi::eventbus::v1::FetchRequest,
) -> Result<
tonic::Response<
tonic::codec::Streaming<salesforce_core_pubsubapi::eventbus::v1::FetchResponse>,
>,
Error,
> {
self.pubsub
.subscribe(
tokio_stream::iter(1..usize::MAX)
.map(move |_| request.to_owned())
.throttle(std::time::Duration::from_millis(10)),
)
.await
.map_err(|e| Error::Tonic(Box::new(e)))
}
#[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
pub async fn managed_subscribe(
&mut self,
request: salesforce_core_pubsubapi::eventbus::v1::ManagedFetchRequest,
) -> Result<
tonic::Response<
tonic::codec::Streaming<salesforce_core_pubsubapi::eventbus::v1::ManagedFetchResponse>,
>,
Error,
> {
self.pubsub
.managed_subscribe(
tokio_stream::iter(1..usize::MAX)
.map(move |_| request.to_owned())
.throttle(std::time::Duration::from_millis(10)),
)
.await
.map_err(|e| Error::Tonic(Box::new(e)))
}
#[cfg_attr(feature = "trace", tracing::instrument(skip_all))]
pub async fn publish_stream(
&mut self,
request: salesforce_core_pubsubapi::eventbus::v1::PublishRequest,
) -> Result<
tonic::Response<
tonic::codec::Streaming<salesforce_core_pubsubapi::eventbus::v1::PublishResponse>,
>,
Error,
> {
self.pubsub
.publish_stream(
tokio_stream::iter(1..usize::MAX)
.map(move |_| request.to_owned())
.throttle(std::time::Duration::from_millis(10)),
)
.await
.map_err(|e| Error::Tonic(Box::new(e)))
}
}
#[cfg(test)]
mod tests {
    use std::{fs, path::PathBuf};

    use super::*;
    use oauth2::basic::{BasicTokenResponse, BasicTokenType};
    use oauth2::{AccessToken, EmptyExtraTokenFields};
    use tonic::service::Interceptor;

    /// Credentials file contents for the tests that exercise `credentials_path`.
    const CREDS_JSON: &str = r#"
{
    "client_id": "some_client_id",
    "client_secret": "some_client_secret",
    "instance_url": "https://mydomain.salesforce.com",
    "tenant_id": "some_tenant_id"
}"#;

    /// Well-formed instance URL used by the happy-path fixtures.
    const INSTANCE_URL: &str = "https://test.salesforce.com";
    /// Well-formed tenant ID used by the happy-path fixtures.
    const TENANT_ID: &str = "tenant123";

    /// Builds an auth client from a throw-away credentials file.
    ///
    /// The file goes in the system temp dir, not the working directory, so a
    /// developer's real `credentials.json` is never overwritten or deleted.
    /// `tag` keeps tests running in parallel (same process id) from sharing a path.
    fn client_from_creds_file(tag: &str) -> client::Client {
        let path: PathBuf = std::env::temp_dir().join(format!(
            "salesforce_pubsub_{}_{}.json",
            tag,
            std::process::id()
        ));
        fs::write(&path, CREDS_JSON).expect("failed to write temporary credentials file");
        let built = client::Builder::new().credentials_path(path.clone()).build();
        // Clean up before unwrapping so a failed build does not leak the file.
        let _ = fs::remove_file(&path);
        built.unwrap()
    }

    /// Builds an auth client from in-memory credentials, without any token.
    fn client_from_credentials() -> client::Client {
        client::Builder::new()
            .credentials(client::Credentials {
                client_id: "test_id".to_string(),
                client_secret: Some("test_secret".to_string()),
                username: None,
                password: None,
                instance_url: INSTANCE_URL.to_string(),
                tenant_id: "test_tenant".to_string(),
            })
            .build()
            .unwrap()
    }

    /// Installs a bearer token with the given access-token string on `client`.
    fn install_token(client: &mut client::Client, access_token: &str) {
        let token = BasicTokenResponse::new(
            AccessToken::new(access_token.to_string()),
            BasicTokenType::Bearer,
            EmptyExtraTokenFields {},
        );
        let token_state = crate::client::TokenState::new(token).unwrap();
        client.token_state = Some(std::sync::Arc::new(std::sync::RwLock::new(token_state)));
    }

    /// Auth client with a token installed and the given instance URL / tenant ID.
    fn authed_client(
        access_token: &str,
        instance_url: Option<&str>,
        tenant_id: Option<&str>,
    ) -> client::Client {
        let mut client = client_from_credentials();
        install_token(&mut client, access_token);
        client.instance_url = instance_url.map(String::from);
        client.tenant_id = tenant_id.map(String::from);
        client
    }

    /// Lazily-connected channel: building it does not contact the endpoint,
    /// which keeps the `Client::new` tests offline.
    fn lazy_channel() -> tonic::transport::Channel {
        tonic::transport::Endpoint::from_static("http://localhost:50051").connect_lazy()
    }

    /// Pub/Sub client built from a fully populated auth client.
    fn connected_context() -> Client {
        Client::new(
            lazy_channel(),
            authed_client("valid_token", Some(INSTANCE_URL), Some(TENANT_ID)),
        )
        .unwrap()
    }

    #[tokio::test]
    async fn test_new_missing_token() {
        let client = client_from_creds_file("missing_token");
        let result = Client::new(lazy_channel(), client);
        assert!(matches!(result, Err(Error::MissingTokenResponse)));
    }

    #[tokio::test]
    async fn test_new_missing_instance_url() {
        let mut client = client_from_creds_file("missing_instance_url");
        install_token(&mut client, "test_token");
        client.instance_url = None;
        client.tenant_id = Some("test_tenant".to_string());
        let result = Client::new(lazy_channel(), client);
        assert!(matches!(result, Err(Error::MissingInstanceUrl)));
    }

    #[tokio::test]
    async fn test_new_missing_tenant_id() {
        let mut client = client_from_creds_file("missing_tenant_id");
        install_token(&mut client, "test_token");
        client.instance_url = Some("https://mydomain.salesforce.com".to_string());
        client.tenant_id = None;
        let result = Client::new(lazy_channel(), client);
        assert!(matches!(result, Err(Error::MissingTenantId)));
    }

    #[tokio::test]
    async fn test_new_with_valid_client() {
        let client = authed_client("valid_token", Some(INSTANCE_URL), Some(TENANT_ID));
        assert!(Client::new(lazy_channel(), client).is_ok());
    }

    #[tokio::test]
    async fn test_new_with_invalid_token_characters() {
        let client = authed_client("token\nwith\nnewlines", Some(INSTANCE_URL), Some(TENANT_ID));
        let result = Client::new(lazy_channel(), client);
        assert!(matches!(result, Err(Error::InvalidMetadataValue { .. })));
    }

    #[tokio::test]
    async fn test_new_with_invalid_instance_url_characters() {
        let client = authed_client("valid_token", Some("url\nwith\nnewlines"), Some(TENANT_ID));
        let result = Client::new(lazy_channel(), client);
        assert!(matches!(result, Err(Error::InvalidMetadataValue { .. })));
    }

    #[tokio::test]
    async fn test_new_with_invalid_tenant_id_characters() {
        let client = authed_client(
            "valid_token",
            Some(INSTANCE_URL),
            Some("tenant\nwith\nnewlines"),
        );
        let result = Client::new(lazy_channel(), client);
        assert!(matches!(result, Err(Error::InvalidMetadataValue { .. })));
    }

    #[test]
    fn test_interceptor_adds_headers() {
        let auth_header = tonic::metadata::AsciiMetadataValue::try_from("test_token").unwrap();
        let instance_url =
            tonic::metadata::AsciiMetadataValue::try_from("https://test.salesforce.com").unwrap();
        let tenant_id = tonic::metadata::AsciiMetadataValue::try_from("test_tenant").unwrap();
        let mut interceptor = ContextInterceptor {
            auth_header,
            instance_url,
            tenant_id,
        };
        let result = interceptor.call(tonic::Request::new(()));
        assert!(result.is_ok());
        let request = result.unwrap();
        let metadata = request.metadata();
        assert_eq!(metadata.get("accesstoken").unwrap(), "test_token");
        assert_eq!(
            metadata.get("instanceurl").unwrap(),
            "https://test.salesforce.com"
        );
        assert_eq!(metadata.get("tenantid").unwrap(), "test_tenant");
    }

    #[tokio::test]
    async fn test_context_creation_success_path() {
        let client = authed_client(
            "valid_access_token_123",
            Some("https://login.salesforce.com"),
            Some("00Dxx0000001gPL"),
        );
        assert!(Client::new(lazy_channel(), client).is_ok());
    }

    #[tokio::test]
    async fn test_context_with_special_characters_in_token() {
        let client = authed_client("abc123-xyz_789.token", Some(INSTANCE_URL), Some(TENANT_ID));
        assert!(Client::new(lazy_channel(), client).is_ok());
    }

    #[tokio::test]
    async fn test_all_missing_fields() {
        let mut client = client_from_credentials();
        client.instance_url = None;
        client.tenant_id = None;
        // The token is checked first, so that is the error reported.
        let result = Client::new(lazy_channel(), client);
        assert!(matches!(result, Err(Error::MissingTokenResponse)));
    }

    // NOTE(review): the reconnect tests below call the auth client's
    // `reconnect`, which presumably contacts Salesforce. They pass because that
    // step fails, so they may never reach the field checks their names suggest.

    #[tokio::test]
    async fn test_reconnect_without_token_state() {
        let mut context = connected_context();
        context.client.token_state = None;
        assert!(context.reconnect().await.is_err());
    }

    #[tokio::test]
    async fn test_reconnect_verifies_fields_after_reconnect() {
        let mut context = connected_context();
        context.client.instance_url = None;
        assert!(context.reconnect().await.is_err());
    }

    #[tokio::test]
    async fn test_reconnect_with_invalid_metadata_after_reconnect() {
        let mut context = connected_context();
        context.client.tenant_id = Some("tenant\nwith\nnewlines".to_string());
        assert!(context.reconnect().await.is_err());
    }

    // NOTE(review): despite its name this only checks construction; verifying
    // that `reconnect` keeps the channel needs a mock auth server.
    #[tokio::test]
    async fn test_reconnect_preserves_channel() {
        let client = authed_client("initial_token", Some(INSTANCE_URL), Some(TENANT_ID));
        assert!(Client::new(lazy_channel(), client).is_ok());
    }

    #[tokio::test]
    async fn test_reconnect_without_credentials() {
        let mut context = connected_context();
        let result = context.reconnect().await;
        assert!(matches!(result, Err(Error::Tonic(_))));
    }

    #[tokio::test]
    async fn test_context_stores_client_and_channel() {
        let client = authed_client(
            "valid_access_token_123",
            Some("https://login.salesforce.com"),
            Some("00Dxx0000001gPL"),
        );
        let expected_tenant = client.tenant_id.clone().unwrap();
        let context = Client::new(lazy_channel(), client).unwrap();
        assert_eq!(context.client.tenant_id.as_ref().unwrap(), &expected_tenant);
        assert!(context.client.instance_url.is_some());
        assert!(context.client.token_state.is_some());
    }
}