homestar_runtime/network/
rpc.rs

//! CLI-focused RPC server implementation.
2
3use crate::{
4    channel::{AsyncChannel, AsyncChannelReceiver, AsyncChannelSender},
5    runner::{self, file::ReadWorkflow, response, RpcSender},
6    settings,
7};
8use faststr::FastStr;
9use futures::{future, StreamExt};
10use std::{io, net::SocketAddr, sync::Arc, time::Duration};
11use stream_cancel::Valved;
12use tarpc::{
13    client::{self, RpcError},
14    context,
15    server::{self, incoming::Incoming, Channel},
16};
17use tokio::{runtime::Handle, select, time};
18use tokio_serde::formats::MessagePack;
19use tracing::{info, warn};
20
21mod error;
22pub use error::Error;
23
24/// Message type for messages sent back from the
25/// websocket server to the [runner] for example.
26///
27/// [runner]: crate::Runner
28#[derive(Debug)]
29pub(crate) enum ServerMessage {
30    /// Notify the [Runner] that the RPC server was given a `stop` command.
31    ///
32    /// [Runner]: crate::Runner
33    ShutdownCmd,
34    /// Message sent by the [Runner] to start a graceful shutdown.
35    ///
36    /// [Runner]: crate::Runner
37    GracefulShutdown(AsyncChannelSender<()>),
38    /// Message sent to start a [Workflow] run by reading a [Workflow] file.
39    ///
40    /// [Workflow]: homestar_workflow::Workflow
41    Run((Option<FastStr>, ReadWorkflow)),
42    /// Acknowledgement of a [Workflow] run.
43    ///
44    /// [Workflow]: homestar_workflow::Workflow
45    RunAck(Box<response::AckWorkflow>),
46    /// Error attempting to run a [Workflow].
47    ///
48    /// [Workflow]: homestar_workflow::Workflow
49    RunErr(runner::Error),
50    /// Message sent to the [Runner] to identify the node.
51    ///
52    /// [Runner]: crate::Runner
53    NodeInfo,
54    /// Acknowledgement of the node's identity/info.
55    NodeInfoAck(response::AckNodeInfo),
56    /// For skipping server messages.
57    Skip,
58}
59
60/// RPC interface definition for CLI-server interaction.
61#[tarpc::service]
62pub(crate) trait Interface {
63    /// Returns a greeting for name.
64    async fn run(
65        name: Option<FastStr>,
66        workflow_file: ReadWorkflow,
67    ) -> Result<Box<response::AckWorkflow>, Error>;
68    /// Ping the server.
69    async fn ping() -> String;
70    /// Stop the server.
71    async fn stop() -> Result<(), Error>;
72    /// Identify the node.
73    async fn node_info() -> Result<response::AckNodeInfo, Error>;
74}
75
76/// RPC server state information.
77#[derive(Debug, Clone)]
78pub(crate) struct Server {
79    /// [SocketAddr] of the RPC server.
80    pub(crate) addr: SocketAddr,
81    /// Sender for messages to be sent to the RPC server.
82    pub(crate) sender: Arc<AsyncChannelSender<ServerMessage>>,
83    /// Receiver for messages sent to the RPC server.
84    pub(crate) receiver: AsyncChannelReceiver<ServerMessage>,
85    /// Sender for messages to be sent to the [Runner].
86    ///
87    /// [Runner]: crate::Runner
88    pub(crate) runner_sender: Arc<RpcSender>,
89    /// Maximum number of connections to the RPC server.
90    pub(crate) max_connections: usize,
91    /// Timeout for the RPC server.
92    pub(crate) timeout: Duration,
93}
94
95/// RPC client wrapper.
96#[derive(Debug, Clone)]
97pub struct Client {
98    cli: InterfaceClient,
99    addr: SocketAddr,
100    ctx: context::Context,
101}
102
103/// RPC server state information.
104#[derive(Debug, Clone)]
105#[allow(dead_code)]
106struct ServerHandler {
107    addr: SocketAddr,
108    runner_sender: Arc<RpcSender>,
109    timeout: Duration,
110}
111
112impl ServerHandler {
113    fn new(addr: SocketAddr, runner_sender: Arc<RpcSender>, timeout: Duration) -> Self {
114        Self {
115            addr,
116            runner_sender,
117            timeout,
118        }
119    }
120}
121
122#[tarpc::server]
123impl Interface for ServerHandler {
124    async fn run(
125        self,
126        _: context::Context,
127        name: Option<FastStr>,
128        workflow_file: ReadWorkflow,
129    ) -> Result<Box<response::AckWorkflow>, Error> {
130        let (tx, rx) = AsyncChannel::oneshot();
131        self.runner_sender
132            .send_async((ServerMessage::Run((name, workflow_file)), Some(tx)))
133            .await
134            .map_err(|e| Error::FailureToSendOnChannel(e.to_string()))?;
135
136        let now = time::Instant::now();
137        select! {
138            Ok(msg) = rx.recv_async() => {
139                match msg {
140                    ServerMessage::RunAck(response) => {
141                        Ok(response)
142                    }
143                    ServerMessage::RunErr(err) => Err(err).map_err(|e| Error::FromRunner(e.to_string()))?,
144                    _ => Err(Error::FailureToSendOnChannel("unexpected message".into())),
145                }
146            },
147            _ = time::sleep_until(now + self.timeout) => {
148                let s = format!("server timeout of {} ms reached", self.timeout.as_millis());
149                info!(subject = "rpc.timeout",
150                      category = "rpc",
151                      "{s}");
152                Err(Error::FailureToReceiveOnChannel(s))
153            }
154
155        }
156    }
157    async fn ping(self, _: context::Context) -> String {
158        "pong".into()
159    }
160    async fn stop(self, _: context::Context) -> Result<(), Error> {
161        self.runner_sender
162            .send_async((ServerMessage::ShutdownCmd, None))
163            .await
164            .map_err(|e| Error::FailureToSendOnChannel(e.to_string()))
165    }
166    async fn node_info(self, _: context::Context) -> Result<response::AckNodeInfo, Error> {
167        let (tx, rx) = AsyncChannel::oneshot();
168        self.runner_sender
169            .send_async((ServerMessage::NodeInfo, Some(tx)))
170            .await
171            .map_err(|e| Error::FailureToSendOnChannel(e.to_string()))?;
172
173        let now = time::Instant::now();
174        select! {
175            Ok(msg) = rx.recv_async() => {
176                match msg {
177                    ServerMessage::NodeInfoAck(response) => {
178                        println!("response: {:?}", response);
179                        Ok(response)
180                    }
181                    _ => Err(Error::FailureToSendOnChannel("unexpected message".into())),
182                }
183            },
184            _ = time::sleep_until(now + self.timeout) => {
185                let s = format!("server timeout of {} ms reached", self.timeout.as_millis());
186                info!(subject = "rpc.timeout",
187                      category = "rpc",
188                      "{s}");
189                Err(Error::FailureToReceiveOnChannel(s))
190            }
191        }
192    }
193}
194
195impl Server {
196    /// Create a new instance of the RPC server.
197    pub(crate) fn new(settings: &settings::Network, runner_sender: Arc<RpcSender>) -> Self {
198        let (tx, rx) = AsyncChannel::oneshot();
199        Self {
200            addr: SocketAddr::new(settings.rpc.host, settings.rpc.port),
201            sender: tx.into(),
202            receiver: rx,
203            runner_sender,
204            max_connections: settings.rpc.max_connections,
205            timeout: settings.rpc.server_timeout,
206        }
207    }
208
209    /// Return a RPC server channel sender.
210    pub(crate) fn sender(&self) -> Arc<AsyncChannelSender<ServerMessage>> {
211        self.sender.clone()
212    }
213
214    /// Start the RPC server and connect the client.
215    pub(crate) async fn spawn(self) -> anyhow::Result<()> {
216        let mut listener =
217            tarpc::serde_transport::tcp::listen(self.addr, MessagePack::default).await?;
218        listener.config_mut().max_frame_length(usize::MAX);
219
220        info!(
221            subject = "rpc.spawn",
222            category = "rpc",
223            "RPC server listening on {}",
224            self.addr
225        );
226
227        // setup valved listener for cancellation
228        let (exit, incoming) = Valved::new(listener);
229
230        let runtime_handle = Handle::current();
231        runtime_handle.spawn(async move {
232            let fut = incoming
233                // Ignore accept errors.
234                .filter_map(|r| future::ready(r.ok()))
235                .map(server::BaseChannel::with_defaults)
236                // Limit channels to 1 per IP.
237                .max_channels_per_key(1, |t| t.transport().peer_addr().unwrap_or(self.addr).ip())
238                .map(|channel| {
239                    let handler =
240                        ServerHandler::new(self.addr, self.runner_sender.clone(), self.timeout);
241                    channel.execute(handler.serve())
242                })
243                .buffer_unordered(self.max_connections)
244                .for_each(|_| async {});
245
246            select! {
247                Ok(ServerMessage::GracefulShutdown(tx)) = self.receiver.recv_async() => {
248                    info!(subject = "shutdown",
249                          category = "homestar.shutdown",
250                          "RPC server shutting down");
251                    drop(exit);
252                    let _ = tx.send_async(()).await;
253                }
254                _ = fut =>
255                    warn!(subject = "rpc.spawn.err",
256                          category = "rpc",
257                          "RPC server exited unexpectedly"),
258            }
259        });
260
261        Ok(())
262    }
263}
264
265impl Client {
266    /// Instantiate a new [Client] with a [tcp] connection to a running Homestar
267    /// runner/server.
268    ///
269    /// [tcp]: tarpc::serde_transport::tcp
270    pub async fn new(addr: SocketAddr, ctx: context::Context) -> Result<Self, io::Error> {
271        let transport = tarpc::serde_transport::tcp::connect(addr, MessagePack::default).await?;
272        let client = InterfaceClient::new(client::Config::default(), transport).spawn();
273        Ok(Client {
274            cli: client,
275            addr,
276            ctx,
277        })
278    }
279
280    /// Return the [SocketAddr] of the RPC server.
281    pub fn addr(&self) -> SocketAddr {
282        self.addr
283    }
284
285    /// Ping the server.
286    pub async fn ping(&self) -> Result<String, RpcError> {
287        self.cli.ping(self.ctx).await
288    }
289
290    /// Stop the server.
291    pub async fn stop(&self) -> Result<Result<(), Error>, RpcError> {
292        self.cli.stop(self.ctx).await
293    }
294
295    /// Identify the node.
296    pub async fn node_info(&self) -> Result<Result<response::AckNodeInfo, Error>, RpcError> {
297        self.cli.node_info(self.ctx).await
298    }
299
300    /// Run a [Workflow].
301    ///
302    /// [Workflow]: homestar_workflow::Workflow
303    pub async fn run(
304        &self,
305        name: Option<FastStr>,
306        workflow_file: ReadWorkflow,
307    ) -> Result<Result<Box<response::AckWorkflow>, Error>, RpcError> {
308        self.cli.run(self.ctx, name, workflow_file).await
309    }
310}