use crate::grpc::conversions::{
datetime_to_timestamp, json_to_struct, parse_uuid, proto_to_task_state,
task_service_error_to_status, task_state_to_proto,
};
use crate::grpc::interceptors::AuthInterceptor;
use crate::grpc::state::GrpcState;
use crate::services::TaskServiceError;
use tasker_shared::models::core::task::TaskListQuery;
use tasker_shared::models::core::task_request::TaskRequest;
use tasker_shared::proto::v1::{
self as proto, task_service_server::TaskService as TaskServiceTrait,
};
use tasker_shared::types::{Permission, SecurityContext};
use tokio_stream::wrappers::ReceiverStream;
use tonic::{Request, Response, Status};
use tracing::{debug, info, warn};
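
/// gRPC implementation of the task service.
///
/// Wraps the shared [`GrpcState`] together with an [`AuthInterceptor`] so every
/// RPC can authenticate and authorize the caller before doing any work.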
#[derive(Debug)]
pub struct TaskServiceImpl {
state: GrpcState,
auth_interceptor: AuthInterceptor,
}
impl TaskServiceImpl {
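/// Build the service from shared gRPC state, deriving the auth interceptor
/// from the configured security service.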
pub fn new(state: GrpcState) -> Self {
let auth_interceptor = AuthInterceptor::new(state.services.security_service.clone());
Self {
state,
auth_interceptor,
}
}
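
/// Authenticate the request and verify the caller holds the required
/// permission, returning `PERMISSION_DENIED` otherwise.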
async fn authenticate_and_authorize<T>(
&self,
request: &Request<T>,
required_permission: Permission,
) -> Result<SecurityContext, Status> {
let ctx = self.auth_interceptor.authenticate(request).await?;
if !ctx.has_permission(&required_permission) {
return Err(Status::permission_denied(
"Insufficient permissions for this operation",
));
}
Ok(ctx)
}
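
/// Reject the request with `UNAVAILABLE` (plus a `retry-after` metadata hint)
/// when the server reports active backpressure.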
fn check_backpressure(&self) -> Result<(), Status> {
if let Some(reason) = self.state.check_backpressure() {
let mut status = Status::unavailable(format!("Backpressure active: {reason}"));
status
.metadata_mut()
.insert("retry-after", "5".parse().unwrap());
return Err(status);
}
Ok(())
}
}
#[tonic::async_trait]
impl TaskServiceTrait for TaskServiceImpl {
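/// Create a new task from a gRPC request, defaulting the initiator and reason
/// when the caller omits them.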
async fn create_task(
&self,
request: Request<proto::CreateTaskRequest>,
) -> Result<Response<proto::CreateTaskResponse>, Status> {
let _ctx = self
.authenticate_and_authorize(&request, Permission::TasksCreate)
.await?;
self.check_backpressure()?;
let req = request.into_inner();
debug!(
name = %req.name,
namespace = %req.namespace,
version = %req.version,
"gRPC create task"
);
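// Convert the optional protobuf context into JSON, defaulting to an empty object.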
let context = req
.context
.map(crate::grpc::conversions::struct_to_json)
.unwrap_or_else(|| serde_json::json!({}));
let task_request = TaskRequest::builder()
.name(req.name)
.namespace(req.namespace)
.version(req.version)
.context(context)
.initiator(req.initiator.unwrap_or_else(|| "grpc".to_string()))
.reason(
req.reason
.unwrap_or_else(|| "Task created via gRPC".to_string()),
)
.tags(req.tags)
.build();
let result = self
.state
.services
.task_service
.create_task(task_request)
.await;
match result {
Ok(response) => {
info!(task_uuid = %response.task_uuid, "Task created via gRPC");
Ok(Response::new(proto::CreateTaskResponse {
task: Some(proto::Task::from(&response)),
backpressure: None,
}))
}
Err(e) => Err(task_service_error_to_status(&e)),
}
}
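
/// Fetch a single task by UUID, optionally including its context.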
async fn get_task(
&self,
request: Request<proto::GetTaskRequest>,
) -> Result<Response<proto::GetTaskResponse>, Status> {
let _ctx = self
.authenticate_and_authorize(&request, Permission::TasksRead)
.await?;
let req = request.into_inner();
let task_id = parse_uuid(&req.task_uuid)?;
debug!(task_id = %task_id, "gRPC get task");
let result = self.state.services.task_service.get_task(task_id).await;
match result {
Ok(response) => {
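// The stored task context is returned as both inputs and merged view;
// outputs are left empty.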
let context = if req.include_context {
Some(proto::TaskContext {
inputs: json_to_struct(response.context.clone()),
outputs: json_to_struct(serde_json::json!({})),
merged: json_to_struct(response.context.clone()),
})
} else {
None
};
Ok(Response::new(proto::GetTaskResponse {
task: Some(proto::Task::from(&response)),
steps: vec![],
context,
}))
}
Err(e) => Err(task_service_error_to_status(&e)),
}
}
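
/// List tasks, translating offset/limit pagination and the optional state
/// filter into a [`TaskListQuery`].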
async fn list_tasks(
&self,
request: Request<proto::ListTasksRequest>,
) -> Result<Response<proto::ListTasksResponse>, Status> {
let _ctx = self
.authenticate_and_authorize(&request, Permission::TasksList)
.await?;
let req = request.into_inner();
debug!("gRPC list tasks");
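// Translate offset/limit pagination from the request into the page/per_page
// model expected by the task service.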
let limit = req.pagination.as_ref().and_then(|p| p.limit).unwrap_or(50);
let page = req
.pagination
.as_ref()
.and_then(|p| p.offset)
.map(|o| (o / limit + 1) as u32)
.unwrap_or(1);
let per_page = limit as u32;
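// The query supports a single status filter, so only the first requested state is applied.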
let status = if !req.states.is_empty() {
req.states
.first()
.and_then(|s| proto::TaskState::try_from(*s).ok().map(proto_to_task_state))
} else {
None
};
let query = TaskListQuery {
page,
per_page,
namespace: req.namespace,
status,
initiator: None,
source_system: None,
};
let result = self.state.services.task_service.list_tasks(query).await;
match result {
Ok(response) => {
let tasks: Vec<proto::Task> =
response.tasks.iter().map(proto::Task::from).collect();
let count = tasks.len() as i32;
let total = response.pagination.total_count as i64;
let offset = ((response.pagination.page - 1) * response.pagination.per_page) as i32;
Ok(Response::new(proto::ListTasksResponse {
tasks,
pagination: Some(proto::PaginationResponse {
total,
count,
offset,
has_more: response.pagination.has_next,
}),
}))
}
Err(e) => Err(task_service_error_to_status(&e)),
}
}
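
/// Cancel a task by UUID.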
async fn cancel_task(
&self,
request: Request<proto::CancelTaskRequest>,
) -> Result<Response<proto::CancelTaskResponse>, Status> {
let _ctx = self
.authenticate_and_authorize(&request, Permission::TasksCancel)
.await?;
let req = request.into_inner();
let task_id = parse_uuid(&req.task_uuid)?;
info!(task_id = %task_id, reason = ?req.reason, "gRPC cancel task");
let result = self.state.services.task_service.cancel_task(task_id).await;
match result {
Ok(response) => Ok(Response::new(proto::CancelTaskResponse {
task: Some(proto::Task::from(&response)),
success: true,
message: None,
})),
Err(e) => {
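// A task that can no longer be cancelled is reported as `success: false`
// rather than as an RPC error.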
if matches!(&e, TaskServiceError::CannotCancel(_)) {
Ok(Response::new(proto::CancelTaskResponse {
task: None,
success: false,
message: Some(e.to_string()),
}))
} else {
Err(task_service_error_to_status(&e))
}
}
}
}
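
/// Return the task's stored context as a `TaskContext` (outputs are returned empty).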
async fn get_task_context(
&self,
request: Request<proto::GetTaskContextRequest>,
) -> Result<Response<proto::GetTaskContextResponse>, Status> {
let _ctx = self
.authenticate_and_authorize(&request, Permission::TasksContextRead)
.await?;
let req = request.into_inner();
let task_id = parse_uuid(&req.task_uuid)?;
debug!(task_id = %task_id, "gRPC get task context");
let result = self.state.services.task_service.get_task(task_id).await;
match result {
Ok(response) => Ok(Response::new(proto::GetTaskContextResponse {
context: Some(proto::TaskContext {
inputs: json_to_struct(response.context.clone()),
outputs: json_to_struct(serde_json::json!({})),
merged: json_to_struct(response.context),
}),
})),
Err(e) => Err(task_service_error_to_status(&e)),
}
}
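
/// Bounded channel-backed stream of task status updates.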
type StreamTaskStatusStream = ReceiverStream<Result<proto::TaskStatusUpdate, Status>>;
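
/// Stream task status changes by polling the task once per second until it
/// reaches a terminal state or the client disconnects.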
async fn stream_task_status(
&self,
request: Request<proto::StreamTaskStatusRequest>,
) -> Result<Response<Self::StreamTaskStatusStream>, Status> {
let _ctx = self
.authenticate_and_authorize(&request, Permission::TasksRead)
.await?;
let req = request.into_inner();
let task_id = parse_uuid(&req.task_uuid)?;
info!(task_id = %task_id, "Starting task status stream");
let (tx, rx) = tokio::sync::mpsc::channel(10);
let state = self.state.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(1));
let mut last_state: Option<String> = None;
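// Poll the task and emit an update whenever its state changes.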
loop {
interval.tick().await;
let result = state.services.task_service.get_task(task_id).await;
match result {
Ok(response) => {
let current_state = response.status.clone();
if last_state.as_ref() != Some(&current_state) {
let proto_state = task_state_to_proto(&current_state);
let update = proto::TaskStatusUpdate {
update_type: proto::task_status_update::UpdateType::TaskStateChange
as i32,
timestamp: Some(datetime_to_timestamp(chrono::Utc::now())),
task_state: Some(proto_state as i32),
step_update: None,
error_message: None,
};
if tx.send(Ok(update)).await.is_err() {
break;
}
last_state = Some(current_state.clone());
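// Terminal states end the stream.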
if current_state == "complete"
|| current_state == "error"
|| current_state == "cancelled"
{
break;
}
}
}
Err(e) => {
warn!(error = %e, "Error getting task for status stream");
let _ = tx.send(Err(task_service_error_to_status(&e))).await;
break;
}
}
}
});
Ok(Response::new(ReceiverStream::new(rx)))
}
}