use std::sync::Arc;
use tonic::service::interceptor::InterceptedService;
use tonic::transport::Channel;
use tracing::{debug, instrument};
use crate::auth::{ChannelAuthenticator, ChannelIdInterceptor, SaslStreamGuard};
use crate::client::master_inquire::{create_master_inquire_client, MasterInquireClient};
use crate::config::GooseFsConfig;
use crate::error::{Error, Result};
use crate::proto::grpc::block::{
worker_manager_master_client_service_client::WorkerManagerMasterClientServiceClient,
GetWorkerInfoListPOptions, WorkerInfo,
};
/// Generated `WorkerManagerMasterClientService` gRPC client whose transport is a
/// [`Channel`] wrapped in an [`InterceptedService`] with a [`ChannelIdInterceptor`].
type AuthenticatedWorkerMgrClient =
    WorkerManagerMasterClientServiceClient<InterceptedService<Channel, ChannelIdInterceptor>>;

/// Client for the master's `WorkerManagerMasterClientService`.
///
/// Clones share the same `Arc`-held SASL guard.
#[derive(Clone)]
pub struct WorkerManagerClient {
    // Generated gRPC client; cloned per call because the RPC methods need a mutable receiver.
    inner: AuthenticatedWorkerMgrClient,
    // SASL guard taken from the authenticated channel (`None` when built via `from_channel`).
    // Never read; held so it is dropped only with the last clone of this client
    // (presumably this keeps the SASL stream open — NOTE(review): confirm in `crate::auth`).
    _sasl_guard: Arc<Option<SaslStreamGuard>>,
}
impl WorkerManagerClient {
pub async fn connect(config: &GooseFsConfig) -> Result<Self> {
let inquire_client = create_master_inquire_client(config);
Self::connect_with_inquire(config, inquire_client).await
}
pub async fn connect_with_inquire(
config: &GooseFsConfig,
inquire_client: Arc<dyn MasterInquireClient>,
) -> Result<Self> {
let primary_addr = inquire_client.get_primary_rpc_address().await?;
let endpoint_uri = format!("http://{}", primary_addr);
let endpoint = Channel::from_shared(endpoint_uri)
.map_err(|e| Error::ConfigError {
message: format!("invalid master endpoint: {}", e),
})?
.connect_timeout(config.connect_timeout)
.timeout(config.request_timeout);
let channel = endpoint.connect().await?;
let authenticator =
ChannelAuthenticator::new(config.auth_type, config.auth_username.clone(), None)
.with_auth_timeout(config.auth_timeout);
let mut auth_channel = authenticator.authenticate(channel).await?;
let sasl_guard = auth_channel.take_sasl_guard();
debug!(addr = %primary_addr, auth_type = %config.auth_type, "connected to WorkerManagerMasterClientService");
Ok(Self {
inner: WorkerManagerMasterClientServiceClient::new(auth_channel.channel),
_sasl_guard: Arc::new(sasl_guard),
})
}
pub fn from_channel(channel: Channel) -> Self {
let interceptor = ChannelIdInterceptor::new("test-no-auth".to_string());
let intercepted = InterceptedService::new(channel, interceptor);
Self {
inner: WorkerManagerMasterClientServiceClient::new(intercepted),
_sasl_guard: Arc::new(None),
}
}
#[instrument(skip(self))]
pub async fn get_worker_info_list(&self) -> Result<Vec<WorkerInfo>> {
let req = GetWorkerInfoListPOptions {};
let resp = self.inner.clone().get_worker_info_list(req).await?;
Ok(resp.into_inner().worker_infos)
}
}