1use std::{collections::BTreeMap, net::SocketAddr, sync::Arc, time::Duration};
3
4use anyhow::Result;
5use futures_lite::Stream;
6use iroh::{Endpoint, NodeAddr, NodeId, RelayUrl};
7use quic_rpc::server::{ChannelTypes, RpcChannel, RpcServerError};
8use tracing::{debug, info};
9
10use super::proto::{net, node::CounterStats, Request};
11use crate::rpc::{
12 client::net::NodeStatus,
13 proto::{
14 node::{self, ShutdownRequest, StatsRequest, StatsResponse, StatusRequest},
15 RpcError, RpcResult, RpcService,
16 },
17};
18
19pub trait AbstractNode: Send + Sync + 'static {
21 fn endpoint(&self) -> &Endpoint;
23
24 fn shutdown(&self);
26
27 fn rpc_addr(&self) -> Option<SocketAddr> {
29 None
30 }
31
32 fn stats(&self) -> anyhow::Result<BTreeMap<String, CounterStats>> {
34 anyhow::bail!("metrics are disabled");
35 }
36}
37
38struct Handler(Arc<dyn AbstractNode>);
39
40pub async fn handle_rpc_request<C: ChannelTypes<RpcService>>(
42 node: Arc<dyn AbstractNode>,
43 msg: Request,
44 chan: RpcChannel<RpcService, C>,
45) -> Result<(), RpcServerError<C>> {
46 use Request::*;
47 match msg {
48 Node(msg) => Handler(node).handle_node_request(msg, chan).await,
49 Net(msg) => Handler(node).handle_net_request(msg, chan).await,
50 }
51}
52
53impl Handler {
54 fn endpoint(&self) -> &Endpoint {
55 self.0.endpoint()
56 }
57
58 async fn handle_node_request<C: ChannelTypes<RpcService>>(
59 self,
60 msg: node::Request,
61 chan: RpcChannel<RpcService, C>,
62 ) -> Result<(), RpcServerError<C>> {
63 use node::Request::*;
64 debug!("handling node request: {msg}");
65 match msg {
66 Status(msg) => chan.rpc(msg, self, Self::node_status).await,
67 Shutdown(msg) => chan.rpc(msg, self, Self::node_shutdown).await,
68 Stats(msg) => chan.rpc(msg, self, Self::node_stats).await,
69 }
70 }
71
72 async fn handle_net_request<C: ChannelTypes<RpcService>>(
73 self,
74 msg: net::Request,
75 chan: RpcChannel<RpcService, C>,
76 ) -> Result<(), RpcServerError<C>> {
77 use net::Request::*;
78 debug!("handling net request: {msg}");
79 match msg {
80 Watch(msg) => chan.server_streaming(msg, self, Self::node_watch).await,
81 Id(msg) => chan.rpc(msg, self, Self::node_id).await,
82 Addr(msg) => chan.rpc(msg, self, Self::node_addr).await,
83 Relay(msg) => chan.rpc(msg, self, Self::node_relay).await,
84 RemoteInfosIter(msg) => {
85 chan.server_streaming(msg, self, Self::remote_infos_iter)
86 .await
87 }
88 RemoteInfo(msg) => chan.rpc(msg, self, Self::remote_info).await,
89 AddAddr(msg) => chan.rpc(msg, self, Self::node_add_addr).await,
90 }
91 }
92
93 #[allow(clippy::unused_async)]
94 async fn node_shutdown(self, request: ShutdownRequest) {
95 if request.force {
96 info!("hard shutdown requested");
97 std::process::exit(0);
98 } else {
99 info!("graceful shutdown requested");
101 self.0.shutdown();
102 }
103 }
104
105 #[allow(clippy::unused_async)]
106 async fn node_stats(self, _req: StatsRequest) -> RpcResult<StatsResponse> {
107 Ok(StatsResponse {
108 stats: self.0.stats().map_err(|e| RpcError::new(&*e))?,
109 })
110 }
111
112 async fn node_status(self, _: StatusRequest) -> RpcResult<NodeStatus> {
113 Ok(NodeStatus {
114 addr: self
115 .endpoint()
116 .node_addr()
117 .await
118 .map_err(|e| RpcError::new(&*e))?,
119 listen_addrs: self.local_endpoint_addresses().await.unwrap_or_default(),
120 version: env!("CARGO_PKG_VERSION").to_string(),
121 rpc_addr: self.0.rpc_addr(),
122 })
123 }
124
125 async fn local_endpoint_addresses(&self) -> Result<Vec<SocketAddr>> {
126 let endpoints = self.endpoint().direct_addresses().initialized().await?;
127
128 Ok(endpoints.into_iter().map(|x| x.addr).collect())
129 }
130
131 async fn node_addr(self, _: net::AddrRequest) -> RpcResult<NodeAddr> {
132 let addr = self
133 .endpoint()
134 .node_addr()
135 .await
136 .map_err(|e| RpcError::new(&*e))?;
137 Ok(addr)
138 }
139
140 fn remote_infos_iter(
141 self,
142 _: net::RemoteInfosIterRequest,
143 ) -> impl Stream<Item = RpcResult<net::RemoteInfosIterResponse>> + Send + 'static {
144 let mut infos: Vec<_> = self.endpoint().remote_info_iter().collect();
145 infos.sort_by_key(|n| n.node_id.to_string());
146 futures_lite::stream::iter(
147 infos
148 .into_iter()
149 .map(|info| Ok(net::RemoteInfosIterResponse { info })),
150 )
151 }
152
153 #[allow(clippy::unused_async)]
154 async fn node_id(self, _: net::IdRequest) -> RpcResult<NodeId> {
155 Ok(self.endpoint().secret_key().public())
156 }
157
158 #[allow(clippy::unused_async)]
160 async fn remote_info(self, req: net::RemoteInfoRequest) -> RpcResult<net::RemoteInfoResponse> {
161 let net::RemoteInfoRequest { node_id } = req;
162 let info = self.endpoint().remote_info(node_id);
163 Ok(net::RemoteInfoResponse { info })
164 }
165
166 #[allow(clippy::unused_async)]
168 async fn node_add_addr(self, req: net::AddAddrRequest) -> RpcResult<()> {
169 let net::AddAddrRequest { addr } = req;
170 self.endpoint()
171 .add_node_addr(addr)
172 .map_err(|e| RpcError::new(&*e))?;
173 Ok(())
174 }
175
176 #[allow(clippy::unused_async)]
177 async fn node_relay(self, _: net::RelayRequest) -> RpcResult<Option<RelayUrl>> {
178 let res = self
179 .endpoint()
180 .home_relay()
181 .get()
182 .map_err(|e| RpcError::new(&e))?;
183 Ok(res)
184 }
185
186 fn node_watch(self, _: net::NodeWatchRequest) -> impl Stream<Item = net::WatchResponse> + Send {
187 futures_lite::stream::unfold((), |()| async move {
188 tokio::time::sleep(HEALTH_POLL_WAIT).await;
189 Some((
190 net::WatchResponse {
191 version: env!("CARGO_PKG_VERSION").to_string(),
192 },
193 (),
194 ))
195 })
196 }
197}
198
/// Pause taken before each item emitted by the node watch stream.
const HEALTH_POLL_WAIT: Duration = Duration::from_millis(1_000);