homestar_runtime/network/
rpc.rs1use crate::{
4 channel::{AsyncChannel, AsyncChannelReceiver, AsyncChannelSender},
5 runner::{self, file::ReadWorkflow, response, RpcSender},
6 settings,
7};
8use faststr::FastStr;
9use futures::{future, StreamExt};
10use std::{io, net::SocketAddr, sync::Arc, time::Duration};
11use stream_cancel::Valved;
12use tarpc::{
13 client::{self, RpcError},
14 context,
15 server::{self, incoming::Incoming, Channel},
16};
17use tokio::{runtime::Handle, select, time};
18use tokio_serde::formats::MessagePack;
19use tracing::{info, warn};
20
21mod error;
22pub use error::Error;
23
24#[derive(Debug)]
29pub(crate) enum ServerMessage {
30 ShutdownCmd,
34 GracefulShutdown(AsyncChannelSender<()>),
38 Run((Option<FastStr>, ReadWorkflow)),
42 RunAck(Box<response::AckWorkflow>),
46 RunErr(runner::Error),
50 NodeInfo,
54 NodeInfoAck(response::AckNodeInfo),
56 Skip,
58}
59
60#[tarpc::service]
62pub(crate) trait Interface {
63 async fn run(
65 name: Option<FastStr>,
66 workflow_file: ReadWorkflow,
67 ) -> Result<Box<response::AckWorkflow>, Error>;
68 async fn ping() -> String;
70 async fn stop() -> Result<(), Error>;
72 async fn node_info() -> Result<response::AckNodeInfo, Error>;
74}
75
76#[derive(Debug, Clone)]
78pub(crate) struct Server {
79 pub(crate) addr: SocketAddr,
81 pub(crate) sender: Arc<AsyncChannelSender<ServerMessage>>,
83 pub(crate) receiver: AsyncChannelReceiver<ServerMessage>,
85 pub(crate) runner_sender: Arc<RpcSender>,
89 pub(crate) max_connections: usize,
91 pub(crate) timeout: Duration,
93}
94
95#[derive(Debug, Clone)]
97pub struct Client {
98 cli: InterfaceClient,
99 addr: SocketAddr,
100 ctx: context::Context,
101}
102
103#[derive(Debug, Clone)]
105#[allow(dead_code)]
106struct ServerHandler {
107 addr: SocketAddr,
108 runner_sender: Arc<RpcSender>,
109 timeout: Duration,
110}
111
112impl ServerHandler {
113 fn new(addr: SocketAddr, runner_sender: Arc<RpcSender>, timeout: Duration) -> Self {
114 Self {
115 addr,
116 runner_sender,
117 timeout,
118 }
119 }
120}
121
122#[tarpc::server]
123impl Interface for ServerHandler {
124 async fn run(
125 self,
126 _: context::Context,
127 name: Option<FastStr>,
128 workflow_file: ReadWorkflow,
129 ) -> Result<Box<response::AckWorkflow>, Error> {
130 let (tx, rx) = AsyncChannel::oneshot();
131 self.runner_sender
132 .send_async((ServerMessage::Run((name, workflow_file)), Some(tx)))
133 .await
134 .map_err(|e| Error::FailureToSendOnChannel(e.to_string()))?;
135
136 let now = time::Instant::now();
137 select! {
138 Ok(msg) = rx.recv_async() => {
139 match msg {
140 ServerMessage::RunAck(response) => {
141 Ok(response)
142 }
143 ServerMessage::RunErr(err) => Err(err).map_err(|e| Error::FromRunner(e.to_string()))?,
144 _ => Err(Error::FailureToSendOnChannel("unexpected message".into())),
145 }
146 },
147 _ = time::sleep_until(now + self.timeout) => {
148 let s = format!("server timeout of {} ms reached", self.timeout.as_millis());
149 info!(subject = "rpc.timeout",
150 category = "rpc",
151 "{s}");
152 Err(Error::FailureToReceiveOnChannel(s))
153 }
154
155 }
156 }
157 async fn ping(self, _: context::Context) -> String {
158 "pong".into()
159 }
160 async fn stop(self, _: context::Context) -> Result<(), Error> {
161 self.runner_sender
162 .send_async((ServerMessage::ShutdownCmd, None))
163 .await
164 .map_err(|e| Error::FailureToSendOnChannel(e.to_string()))
165 }
166 async fn node_info(self, _: context::Context) -> Result<response::AckNodeInfo, Error> {
167 let (tx, rx) = AsyncChannel::oneshot();
168 self.runner_sender
169 .send_async((ServerMessage::NodeInfo, Some(tx)))
170 .await
171 .map_err(|e| Error::FailureToSendOnChannel(e.to_string()))?;
172
173 let now = time::Instant::now();
174 select! {
175 Ok(msg) = rx.recv_async() => {
176 match msg {
177 ServerMessage::NodeInfoAck(response) => {
178 println!("response: {:?}", response);
179 Ok(response)
180 }
181 _ => Err(Error::FailureToSendOnChannel("unexpected message".into())),
182 }
183 },
184 _ = time::sleep_until(now + self.timeout) => {
185 let s = format!("server timeout of {} ms reached", self.timeout.as_millis());
186 info!(subject = "rpc.timeout",
187 category = "rpc",
188 "{s}");
189 Err(Error::FailureToReceiveOnChannel(s))
190 }
191 }
192 }
193}
194
195impl Server {
196 pub(crate) fn new(settings: &settings::Network, runner_sender: Arc<RpcSender>) -> Self {
198 let (tx, rx) = AsyncChannel::oneshot();
199 Self {
200 addr: SocketAddr::new(settings.rpc.host, settings.rpc.port),
201 sender: tx.into(),
202 receiver: rx,
203 runner_sender,
204 max_connections: settings.rpc.max_connections,
205 timeout: settings.rpc.server_timeout,
206 }
207 }
208
209 pub(crate) fn sender(&self) -> Arc<AsyncChannelSender<ServerMessage>> {
211 self.sender.clone()
212 }
213
214 pub(crate) async fn spawn(self) -> anyhow::Result<()> {
216 let mut listener =
217 tarpc::serde_transport::tcp::listen(self.addr, MessagePack::default).await?;
218 listener.config_mut().max_frame_length(usize::MAX);
219
220 info!(
221 subject = "rpc.spawn",
222 category = "rpc",
223 "RPC server listening on {}",
224 self.addr
225 );
226
227 let (exit, incoming) = Valved::new(listener);
229
230 let runtime_handle = Handle::current();
231 runtime_handle.spawn(async move {
232 let fut = incoming
233 .filter_map(|r| future::ready(r.ok()))
235 .map(server::BaseChannel::with_defaults)
236 .max_channels_per_key(1, |t| t.transport().peer_addr().unwrap_or(self.addr).ip())
238 .map(|channel| {
239 let handler =
240 ServerHandler::new(self.addr, self.runner_sender.clone(), self.timeout);
241 channel.execute(handler.serve())
242 })
243 .buffer_unordered(self.max_connections)
244 .for_each(|_| async {});
245
246 select! {
247 Ok(ServerMessage::GracefulShutdown(tx)) = self.receiver.recv_async() => {
248 info!(subject = "shutdown",
249 category = "homestar.shutdown",
250 "RPC server shutting down");
251 drop(exit);
252 let _ = tx.send_async(()).await;
253 }
254 _ = fut =>
255 warn!(subject = "rpc.spawn.err",
256 category = "rpc",
257 "RPC server exited unexpectedly"),
258 }
259 });
260
261 Ok(())
262 }
263}
264
265impl Client {
266 pub async fn new(addr: SocketAddr, ctx: context::Context) -> Result<Self, io::Error> {
271 let transport = tarpc::serde_transport::tcp::connect(addr, MessagePack::default).await?;
272 let client = InterfaceClient::new(client::Config::default(), transport).spawn();
273 Ok(Client {
274 cli: client,
275 addr,
276 ctx,
277 })
278 }
279
280 pub fn addr(&self) -> SocketAddr {
282 self.addr
283 }
284
285 pub async fn ping(&self) -> Result<String, RpcError> {
287 self.cli.ping(self.ctx).await
288 }
289
290 pub async fn stop(&self) -> Result<Result<(), Error>, RpcError> {
292 self.cli.stop(self.ctx).await
293 }
294
295 pub async fn node_info(&self) -> Result<Result<response::AckNodeInfo, Error>, RpcError> {
297 self.cli.node_info(self.ctx).await
298 }
299
300 pub async fn run(
304 &self,
305 name: Option<FastStr>,
306 workflow_file: ReadWorkflow,
307 ) -> Result<Result<Box<response::AckWorkflow>, Error>, RpcError> {
308 self.cli.run(self.ctx, name, workflow_file).await
309 }
310}