use std::sync::Arc;
use tonic::{Request, Response, Status};
use crate::error::{DaimonError, Result};
use super::broker::{ErasedTaskBroker, TaskBroker};
use super::types::{AgentTask, TaskResult, TaskStatus};
/// Generated protobuf/gRPC types for the `daimon.distributed` package,
/// pulled in at compile time through `tonic::include_proto!`.
pub mod proto {
tonic::include_proto!("daimon.distributed");
}
use proto::task_broker_service_client::TaskBrokerServiceClient;
use proto::task_broker_service_server::{TaskBrokerService, TaskBrokerServiceServer};
/// gRPC front-end that exposes a task broker through the generated
/// `TaskBrokerService` service.
pub struct GrpcBrokerServer {
    /// Type-erased broker that every incoming RPC is forwarded to.
    broker: Arc<dyn ErasedTaskBroker>,
}

impl GrpcBrokerServer {
    /// Wraps a concrete broker, erasing its type behind an `Arc`.
    pub fn new<B: TaskBroker + 'static>(broker: B) -> Self {
        Self::from_erased(Arc::new(broker))
    }

    /// Builds a server around a broker that is already type-erased.
    pub fn from_erased(broker: Arc<dyn ErasedTaskBroker>) -> Self {
        Self { broker }
    }

    /// Parses `addr` as a socket address and serves the broker service on it,
    /// resolving only once the underlying tonic server future finishes.
    ///
    /// # Errors
    ///
    /// Returns `DaimonError::Other` when `addr` cannot be parsed or when the
    /// tonic transport reports an error.
    pub async fn serve(self, addr: impl Into<String>) -> Result<()> {
        let socket_addr = addr
            .into()
            .parse::<std::net::SocketAddr>()
            .map_err(|e| DaimonError::Other(format!("invalid address: {e}")))?;

        let Self { broker } = self;
        let service = TaskBrokerServiceServer::new(GrpcBrokerSvc { broker });

        tonic::transport::Server::builder()
            .add_service(service)
            .serve(socket_addr)
            .await
            .map_err(|e| DaimonError::Other(format!("grpc server: {e}")))
    }
}
struct GrpcBrokerSvc {
broker: Arc<dyn ErasedTaskBroker>,
}
#[tonic::async_trait]
impl TaskBrokerService for GrpcBrokerSvc {
async fn submit(
&self,
request: Request<proto::SubmitRequest>,
) -> std::result::Result<Response<proto::SubmitResponse>, Status> {
let req = request.into_inner();
let task: AgentTask = serde_json::from_str(&req.task_json)
.map_err(|e| Status::invalid_argument(format!("invalid task json: {e}")))?;
let task_id = self
.broker
.submit_erased(task)
.await
.map_err(|e| Status::internal(e.to_string()))?;
Ok(Response::new(proto::SubmitResponse { task_id }))
}
async fn get_status(
&self,
request: Request<proto::StatusRequest>,
) -> std::result::Result<Response<proto::StatusResponse>, Status> {
let req = request.into_inner();
let status = self
.broker
.status_erased(&req.task_id)
.await
.map_err(|e| Status::internal(e.to_string()))?;
let status_json = serde_json::to_string(&status)
.map_err(|e| Status::internal(format!("serialize status: {e}")))?;
Ok(Response::new(proto::StatusResponse { status_json }))
}
async fn complete(
&self,
request: Request<proto::CompleteRequest>,
) -> std::result::Result<Response<proto::Empty>, Status> {
let req = request.into_inner();
let result: TaskResult = serde_json::from_str(&req.result_json)
.map_err(|e| Status::invalid_argument(format!("invalid result json: {e}")))?;
self.broker
.complete_erased(&req.task_id, result)
.await
.map_err(|e| Status::internal(e.to_string()))?;
Ok(Response::new(proto::Empty {}))
}
async fn fail(
&self,
request: Request<proto::FailRequest>,
) -> std::result::Result<Response<proto::Empty>, Status> {
let req = request.into_inner();
self.broker
.fail_erased(&req.task_id, req.error)
.await
.map_err(|e| Status::internal(e.to_string()))?;
Ok(Response::new(proto::Empty {}))
}
}
/// Client-side handle that talks to a remote broker over gRPC.
pub struct GrpcBrokerClient {
    /// Generated tonic client, kept behind an async mutex because its RPC
    /// methods are invoked through a shared `&self`.
    inner: tokio::sync::Mutex<TaskBrokerServiceClient<tonic::transport::Channel>>,
}

impl GrpcBrokerClient {
    /// Connects to the broker service at `addr`.
    ///
    /// # Errors
    ///
    /// Returns `DaimonError::Other` when the connection attempt fails.
    pub async fn connect(addr: impl Into<String>) -> Result<Self> {
        let endpoint: String = addr.into();
        match TaskBrokerServiceClient::connect(endpoint).await {
            Ok(client) => Ok(Self {
                inner: tokio::sync::Mutex::new(client),
            }),
            Err(e) => Err(DaimonError::Other(format!("grpc connect: {e}"))),
        }
    }
}
impl TaskBroker for GrpcBrokerClient {
async fn submit(&self, task: AgentTask) -> Result<String> {
let task_json = serde_json::to_string(&task)
.map_err(|e| DaimonError::Other(format!("serialize task: {e}")))?;
let resp = self
.inner
.lock()
.await
.submit(Request::new(proto::SubmitRequest { task_json }))
.await
.map_err(|e| DaimonError::Other(format!("grpc submit: {e}")))?;
Ok(resp.into_inner().task_id)
}
async fn status(&self, task_id: &str) -> Result<TaskStatus> {
let resp = self
.inner
.lock()
.await
.get_status(Request::new(proto::StatusRequest {
task_id: task_id.to_string(),
}))
.await
.map_err(|e| DaimonError::Other(format!("grpc status: {e}")))?;
let status: TaskStatus = serde_json::from_str(&resp.into_inner().status_json)
.map_err(|e| DaimonError::Other(format!("deserialize status: {e}")))?;
Ok(status)
}
async fn receive(&self) -> Result<Option<AgentTask>> {
Err(DaimonError::Other(
"receive() is not supported over gRPC; use TaskWorker on the server side".into(),
))
}
async fn complete(&self, task_id: &str, result: TaskResult) -> Result<()> {
let result_json = serde_json::to_string(&result)
.map_err(|e| DaimonError::Other(format!("serialize result: {e}")))?;
self.inner
.lock()
.await
.complete(Request::new(proto::CompleteRequest {
task_id: task_id.to_string(),
result_json,
}))
.await
.map_err(|e| DaimonError::Other(format!("grpc complete: {e}")))?;
Ok(())
}
async fn fail(&self, task_id: &str, error: String) -> Result<()> {
self.inner
.lock()
.await
.fail(Request::new(proto::FailRequest {
task_id: task_id.to_string(),
error,
}))
.await
.map_err(|e| DaimonError::Other(format!("grpc fail: {e}")))?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed::InProcessBroker;
    use std::time::Duration;

    /// Starts the server on an ephemeral port, lets it run briefly, then
    /// aborts the task. Any error from `serve` is ignored.
    #[tokio::test]
    async fn test_grpc_roundtrip() {
        let broker = InProcessBroker::new(32);
        let server_broker = broker.clone();

        let server_task = tokio::spawn(async move {
            let server = GrpcBrokerServer::new(server_broker);
            let _ = server.serve("[::1]:0").await;
        });

        tokio::time::sleep(Duration::from_millis(100)).await;
        server_task.abort();
    }

    /// Constructing every generated message checks that the proto types and
    /// their field names are what this module expects.
    #[test]
    fn test_proto_types_compile() {
        let _submit_req = proto::SubmitRequest {
            task_json: "{}".into(),
        };
        let _submit_resp = proto::SubmitResponse {
            task_id: "t-1".into(),
        };
        let _status_req = proto::StatusRequest {
            task_id: "t-1".into(),
        };
        let _status_resp = proto::StatusResponse {
            status_json: "\"Pending\"".into(),
        };
        let _complete_req = proto::CompleteRequest {
            task_id: "t-1".into(),
            result_json: "{}".into(),
        };
        let _fail_req = proto::FailRequest {
            task_id: "t-1".into(),
            error: "oops".into(),
        };
        let _empty = proto::Empty {};
    }

    /// Submits a task through a real client/server pair and checks that the
    /// reported status is pending or running.
    #[tokio::test]
    async fn test_grpc_server_and_client() {
        let broker = InProcessBroker::new(32);
        let server_broker = broker.clone();

        // Bind an ephemeral port only to learn a free address; the listener is
        // released at the end of this scope so the server can bind it.
        let addr = {
            let probe = tokio::net::TcpListener::bind("[::1]:0").await.unwrap();
            probe.local_addr().unwrap()
        };

        let _server_task = tokio::spawn(async move {
            let server = GrpcBrokerServer::new(server_broker);
            let _ = server.serve(addr.to_string()).await;
        });

        // Give the server a moment to start listening before connecting.
        tokio::time::sleep(Duration::from_millis(200)).await;

        let endpoint = format!("http://{addr}");
        let client = GrpcBrokerClient::connect(endpoint).await.unwrap();

        let task_id = client
            .submit(AgentTask::new("test via grpc"))
            .await
            .unwrap();
        assert!(!task_id.is_empty());

        let status = client.status(&task_id).await.unwrap();
        assert!(
            matches!(status, TaskStatus::Pending | TaskStatus::Running),
            "expected pending or running, got {status:?}"
        );
    }
}