use std::sync::Arc;
use std::time::Duration;
use futures::StreamExt;
use futures::stream::BoxStream;
use tonic::transport::{Channel, ClientTlsConfig};
use tracing::instrument;
use crate::api::{
CancellationResult, EventStreamMessage, JobCancelRequest, JobSetCancelRequest, JobSetRequest,
JobSubmitRequest, JobSubmitResponse, event_client::EventClient, submit_client::SubmitClient,
};
use crate::auth::TokenProvider;
use crate::error::Error;
/// Client for the Armada submit and event gRPC services.
///
/// Cloning a client clones both service stubs and shares the same token provider
/// (it is held behind an `Arc`).
#[derive(Clone)]
pub struct ArmadaClient {
    /// Source of the credential attached to outgoing requests; shared by all clones.
    token_provider: Arc<dyn TokenProvider + Send + Sync>,
    /// Optional request deadline configured via `with_timeout`; `None` attaches none.
    timeout: Option<Duration>,
    /// Stub for the submit service (job submission and cancellation RPCs).
    submit_client: SubmitClient<Channel>,
    /// Stub for the event service (job-set event streams).
    event_client: EventClient<Channel>,
}
impl ArmadaClient {
pub async fn connect(
endpoint: impl Into<String>,
token_provider: impl TokenProvider + 'static,
) -> Result<Self, Error> {
let channel = Channel::from_shared(endpoint.into())
.map_err(|e| Error::InvalidUri(e.to_string()))?
.connect()
.await?;
Ok(Self::from_parts(channel, token_provider))
}
pub async fn connect_tls(
endpoint: impl Into<String>,
token_provider: impl TokenProvider + 'static,
) -> Result<Self, Error> {
let channel = Channel::from_shared(endpoint.into())
.map_err(|e| Error::InvalidUri(e.to_string()))?
.tls_config(ClientTlsConfig::new())?
.connect()
.await?;
Ok(Self::from_parts(channel, token_provider))
}
fn from_parts(channel: Channel, token_provider: impl TokenProvider + 'static) -> Self {
Self {
submit_client: SubmitClient::new(channel.clone()),
event_client: EventClient::new(channel),
token_provider: Arc::new(token_provider),
timeout: None,
}
}
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.timeout = Some(timeout);
self
}
fn apply_timeout<T>(&self, req: &mut tonic::Request<T>) {
if let Some(t) = self.timeout {
req.set_timeout(t);
}
}
#[instrument(skip(self, request), fields(queue = %request.queue, job_set_id = %request.job_set_id))]
pub async fn submit(&self, request: JobSubmitRequest) -> Result<JobSubmitResponse, Error> {
let token = self.token_provider.token().await?;
let mut req = tonic::Request::new(request);
if !token.is_empty() {
req.metadata_mut().insert("authorization", token.parse()?);
}
self.apply_timeout(&mut req);
let resp = self.submit_client.clone().submit_jobs(req).await?;
Ok(resp.into_inner())
}
#[instrument(skip(self, request), fields(queue = %request.queue, job_set_id = %request.job_set_id))]
pub async fn cancel_jobs(
&self,
request: JobCancelRequest,
) -> Result<CancellationResult, Error> {
let token = self.token_provider.token().await?;
let mut req = tonic::Request::new(request);
if !token.is_empty() {
req.metadata_mut().insert("authorization", token.parse()?);
}
self.apply_timeout(&mut req);
let resp = self.submit_client.clone().cancel_jobs(req).await?;
Ok(resp.into_inner())
}
#[instrument(skip(self, request), fields(queue = %request.queue, job_set_id = %request.job_set_id))]
pub async fn cancel_job_set(&self, request: JobSetCancelRequest) -> Result<(), Error> {
let token = self.token_provider.token().await?;
let mut req = tonic::Request::new(request);
if !token.is_empty() {
req.metadata_mut().insert("authorization", token.parse()?);
}
self.apply_timeout(&mut req);
self.submit_client.clone().cancel_job_set(req).await?;
Ok(())
}
#[instrument(skip_all, fields(queue, job_set_id))]
pub async fn watch(
&self,
queue: impl Into<String>,
job_set_id: impl Into<String>,
from_message_id: Option<String>,
) -> Result<BoxStream<'static, Result<EventStreamMessage, Error>>, Error> {
let queue: String = queue.into();
let job_set_id: String = job_set_id.into();
tracing::Span::current()
.record("queue", queue.as_str())
.record("job_set_id", job_set_id.as_str());
let token = self.token_provider.token().await?;
let job_set_request = JobSetRequest {
id: job_set_id,
queue,
from_message_id: from_message_id.unwrap_or_default(),
watch: true,
error_if_missing: false,
};
let mut req = tonic::Request::new(job_set_request);
if !token.is_empty() {
req.metadata_mut().insert("authorization", token.parse()?);
}
self.apply_timeout(&mut req);
let stream = self
.event_client
.clone()
.get_job_set_events(req)
.await?
.into_inner();
Ok(Box::pin(stream.map(|r| r.map_err(Error::from))))
}
}