use crate::{
channel::{AsyncChannel, AsyncChannelReceiver, AsyncChannelSender},
runner::{self, file::ReadWorkflow, response, RpcSender},
settings,
};
use faststr::FastStr;
use futures::{future, StreamExt};
use std::{io, net::SocketAddr, sync::Arc, time::Duration};
use stream_cancel::Valved;
use tarpc::{
client::{self, RpcError},
context,
server::{self, incoming::Incoming, Channel},
};
use tokio::{runtime::Handle, select, time};
use tokio_serde::formats::MessagePack;
use tracing::{info, warn};
mod error;
pub use error::Error;
/// Messages exchanged between the RPC server, its request handlers, and the
/// runner.
#[derive(Debug)]
pub(crate) enum ServerMessage {
    /// Sent to the runner by the `stop` RPC to request a shutdown.
    ShutdownCmd,
    /// Asks the spawned RPC server task to stop accepting connections; the
    /// task replies on the contained sender once the listener is stopped.
    GracefulShutdown(AsyncChannelSender<()>),
    /// Sent to the runner by the `run` RPC with the optional name and the
    /// workflow to run.
    Run((Option<FastStr>, ReadWorkflow)),
    /// Runner's successful reply to [`ServerMessage::Run`].
    RunAck(Box<response::AckWorkflow>),
    /// Runner's failure reply to [`ServerMessage::Run`].
    RunErr(runner::Error),
    /// Sent to the runner by the `node_info` RPC.
    NodeInfo,
    /// Runner's reply to [`ServerMessage::NodeInfo`].
    NodeInfoAck(response::AckNodeInfo),
    // NOTE(review): not used in this module; presumably a no-op signal for
    // other senders/receivers — confirm at its use sites.
    Skip,
}
/// RPC interface exposed by a node.
///
/// The `tarpc::service` macro generates the `InterfaceClient` stub used by
/// [`Client`] and the `serve()` adapter used by the server.
#[tarpc::service]
pub(crate) trait Interface {
    /// Submits a workflow to the runner, optionally under `name`, and returns
    /// the runner's acknowledgement.
    async fn run(
        name: Option<FastStr>,
        workflow_file: ReadWorkflow,
    ) -> Result<Box<response::AckWorkflow>, Error>;
    /// Liveness check; the server answers `"pong"`.
    async fn ping() -> String;
    /// Asks the runner to shut down.
    async fn stop() -> Result<(), Error>;
    /// Returns node information as reported by the runner.
    async fn node_info() -> Result<response::AckNodeInfo, Error>;
}
/// RPC server configuration and channels, consumed by [`Server::spawn`].
#[derive(Debug, Clone)]
pub(crate) struct Server {
    /// Address the RPC listener binds to.
    pub(crate) addr: SocketAddr,
    /// Sending half of the server's command channel (handed out by
    /// [`Server::sender`]).
    pub(crate) sender: Arc<AsyncChannelSender<ServerMessage>>,
    /// Receiving half of the command channel, watched by the spawned server
    /// task for a graceful-shutdown request.
    pub(crate) receiver: AsyncChannelReceiver<ServerMessage>,
    /// Channel to the runner, cloned into every request handler.
    pub(crate) runner_sender: Arc<RpcSender>,
    /// Upper bound on the number of channels served concurrently.
    pub(crate) max_connections: usize,
    /// How long a request handler waits for the runner's reply.
    pub(crate) timeout: Duration,
}
/// Client for a node's RPC [`Interface`].
#[derive(Debug, Clone)]
pub struct Client {
    /// Generated tarpc client stub.
    cli: InterfaceClient,
    /// Address of the server this client connected to.
    addr: SocketAddr,
    /// Request context passed to every RPC call.
    ctx: context::Context,
}
/// Per-channel request handler implementing [`Interface`] by forwarding
/// requests to the runner.
#[derive(Debug, Clone)]
// `addr` is stored but never read by the handler methods.
#[allow(dead_code)]
struct ServerHandler {
    /// Address of the RPC server that created this handler.
    addr: SocketAddr,
    /// Channel used to forward requests to the runner.
    runner_sender: Arc<RpcSender>,
    /// How long to wait for the runner's reply to a request.
    timeout: Duration,
}
impl ServerHandler {
fn new(addr: SocketAddr, runner_sender: Arc<RpcSender>, timeout: Duration) -> Self {
Self {
addr,
runner_sender,
timeout,
}
}
}
#[tarpc::server]
impl Interface for ServerHandler {
async fn run(
self,
_: context::Context,
name: Option<FastStr>,
workflow_file: ReadWorkflow,
) -> Result<Box<response::AckWorkflow>, Error> {
let (tx, rx) = AsyncChannel::oneshot();
self.runner_sender
.send_async((ServerMessage::Run((name, workflow_file)), Some(tx)))
.await
.map_err(|e| Error::FailureToSendOnChannel(e.to_string()))?;
let now = time::Instant::now();
select! {
Ok(msg) = rx.recv_async() => {
match msg {
ServerMessage::RunAck(response) => {
Ok(response)
}
ServerMessage::RunErr(err) => Err(err).map_err(|e| Error::FromRunner(e.to_string()))?,
_ => Err(Error::FailureToSendOnChannel("unexpected message".into())),
}
},
_ = time::sleep_until(now + self.timeout) => {
let s = format!("server timeout of {} ms reached", self.timeout.as_millis());
info!(subject = "rpc.timeout",
category = "rpc",
"{s}");
Err(Error::FailureToReceiveOnChannel(s))
}
}
}
async fn ping(self, _: context::Context) -> String {
"pong".into()
}
async fn stop(self, _: context::Context) -> Result<(), Error> {
self.runner_sender
.send_async((ServerMessage::ShutdownCmd, None))
.await
.map_err(|e| Error::FailureToSendOnChannel(e.to_string()))
}
async fn node_info(self, _: context::Context) -> Result<response::AckNodeInfo, Error> {
let (tx, rx) = AsyncChannel::oneshot();
self.runner_sender
.send_async((ServerMessage::NodeInfo, Some(tx)))
.await
.map_err(|e| Error::FailureToSendOnChannel(e.to_string()))?;
let now = time::Instant::now();
select! {
Ok(msg) = rx.recv_async() => {
match msg {
ServerMessage::NodeInfoAck(response) => {
println!("response: {:?}", response);
Ok(response)
}
_ => Err(Error::FailureToSendOnChannel("unexpected message".into())),
}
},
_ = time::sleep_until(now + self.timeout) => {
let s = format!("server timeout of {} ms reached", self.timeout.as_millis());
info!(subject = "rpc.timeout",
category = "rpc",
"{s}");
Err(Error::FailureToReceiveOnChannel(s))
}
}
}
}
impl Server {
pub(crate) fn new(settings: &settings::Network, runner_sender: Arc<RpcSender>) -> Self {
let (tx, rx) = AsyncChannel::oneshot();
Self {
addr: SocketAddr::new(settings.rpc.host, settings.rpc.port),
sender: tx.into(),
receiver: rx,
runner_sender,
max_connections: settings.rpc.max_connections,
timeout: settings.rpc.server_timeout,
}
}
pub(crate) fn sender(&self) -> Arc<AsyncChannelSender<ServerMessage>> {
self.sender.clone()
}
pub(crate) async fn spawn(self) -> anyhow::Result<()> {
let mut listener =
tarpc::serde_transport::tcp::listen(self.addr, MessagePack::default).await?;
listener.config_mut().max_frame_length(usize::MAX);
info!(
subject = "rpc.spawn",
category = "rpc",
"RPC server listening on {}",
self.addr
);
let (exit, incoming) = Valved::new(listener);
let runtime_handle = Handle::current();
runtime_handle.spawn(async move {
let fut = incoming
.filter_map(|r| future::ready(r.ok()))
.map(server::BaseChannel::with_defaults)
.max_channels_per_key(1, |t| t.transport().peer_addr().unwrap_or(self.addr).ip())
.map(|channel| {
let handler =
ServerHandler::new(self.addr, self.runner_sender.clone(), self.timeout);
channel.execute(handler.serve())
})
.buffer_unordered(self.max_connections)
.for_each(|_| async {});
select! {
Ok(ServerMessage::GracefulShutdown(tx)) = self.receiver.recv_async() => {
info!(subject = "shutdown",
category = "homestar.shutdown",
"RPC server shutting down");
drop(exit);
let _ = tx.send_async(()).await;
}
_ = fut =>
warn!(subject = "rpc.spawn.err",
category = "rpc",
"RPC server exited unexpectedly"),
}
});
Ok(())
}
}
impl Client {
    /// Connects to the RPC server listening at `addr`.
    ///
    /// `ctx` is stored and used as the tarpc request context for every call
    /// made through this client.
    ///
    /// # Errors
    ///
    /// Returns an [`io::Error`] if the TCP connection cannot be established.
    pub async fn new(addr: SocketAddr, ctx: context::Context) -> Result<Self, io::Error> {
        let mut transport = tarpc::serde_transport::tcp::connect(addr, MessagePack::default);
        // The server lifts the frame-length cap (`max_frame_length(usize::MAX)`
        // in `Server::spawn`); mirror it so the client's codec does not reject
        // large workflow requests or acknowledgements with its default limit.
        transport.config_mut().max_frame_length(usize::MAX);
        let cli = InterfaceClient::new(client::Config::default(), transport.await?).spawn();
        Ok(Client { cli, addr, ctx })
    }

    /// Returns the address this client connected to.
    pub fn addr(&self) -> SocketAddr {
        self.addr
    }

    /// Sends a liveness probe; the server answers `"pong"`.
    ///
    /// # Errors
    ///
    /// Returns [`RpcError`] if the call fails at the RPC/transport level.
    pub async fn ping(&self) -> Result<String, RpcError> {
        self.cli.ping(self.ctx).await
    }

    /// Asks the node to shut down its runner.
    ///
    /// # Errors
    ///
    /// The outer [`RpcError`] reports RPC/transport failures; the inner
    /// [`Error`] is the server-side result.
    pub async fn stop(&self) -> Result<Result<(), Error>, RpcError> {
        self.cli.stop(self.ctx).await
    }

    /// Fetches node information from the server.
    ///
    /// # Errors
    ///
    /// The outer [`RpcError`] reports RPC/transport failures; the inner
    /// [`Error`] is the server-side result.
    pub async fn node_info(&self) -> Result<Result<response::AckNodeInfo, Error>, RpcError> {
        self.cli.node_info(self.ctx).await
    }

    /// Submits `workflow_file` for execution, optionally under `name`.
    ///
    /// # Errors
    ///
    /// The outer [`RpcError`] reports RPC/transport failures; the inner
    /// [`Error`] is the server-side result.
    pub async fn run(
        &self,
        name: Option<FastStr>,
        workflow_file: ReadWorkflow,
    ) -> Result<Result<Box<response::AckWorkflow>, Error>, RpcError> {
        self.cli.run(self.ctx, name, workflow_file).await
    }
}