1use std::{
2 collections::VecDeque,
3 sync::atomic::AtomicUsize,
4 time::{Duration, SystemTime},
5};
6
7use crate::DirectMessage;
8use actor_helper::{Action, Actor, Handle, Receiver, act, act_ok};
9use anyhow::Result;
10use iroh::{
11 Endpoint,
12 endpoint::{RecvStream, SendStream},
13};
14use iroh::{
15 NodeId,
16 endpoint::{Connection, VarInt},
17};
18use tokio::io::{AsyncReadExt, AsyncWriteExt};
19use tracing::{debug, warn};
20
/// Initial capacity of each of the actor's message queues.
///
/// This only preallocates; the queues themselves are unbounded.
const QUEUE_SIZE: usize = 16 * 1024;
/// Number of failed reconnect attempts after which the actor gives up and closes.
const MAX_RECONNECTS: usize = 5;
/// Reconnect backoff restored after every successful (re)connection; each reconnect
/// attempt triples it.
const RECONNECT_BACKOFF_BASE: Duration = Duration::from_millis(100);
24
25#[derive(Debug, Clone)]
26pub struct Conn {
27 api: Handle<ConnActor, anyhow::Error>,
28}
29
/// Lifecycle state of a [`Conn`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnState {
    /// A dial to the peer is in progress.
    Connecting,
    /// NOTE(review): not assigned anywhere in the visible code of this file — confirm
    /// whether it is still used.
    Idle,
    /// Connected with both stream halves available; outgoing messages are only
    /// written in this state.
    Open,
    /// Shut down; the actor loop no longer reconnects, reads or writes in this state.
    Closed,
    /// No usable stream: the initial state when constructed without a connection,
    /// and the state entered after certain I/O failures.
    Disconnected,
}
38
39#[derive(Debug)]
40struct ConnActor {
41 rx: Receiver<Action<ConnActor>>,
42 self_handle: Handle<ConnActor, anyhow::Error>,
43 state: ConnState,
44
45 conn: Option<Connection>,
50 conn_node_id: NodeId,
51 send_stream: Option<SendStream>,
52 recv_stream: Option<RecvStream>,
53 endpoint: Endpoint,
54
55 last_reconnect: tokio::time::Instant,
56 reconnect_backoff: Duration,
57 reconnect_count: AtomicUsize,
58
59 external_sender: tokio::sync::broadcast::Sender<DirectMessage>,
60
61 receiver_queue: VecDeque<DirectMessage>,
62 receiver_notify: tokio::sync::Notify,
63
64 sender_queue: VecDeque<DirectMessage>,
65 sender_notify: tokio::sync::Notify,
66}
67
68impl Conn {
69 pub async fn new(
70 endpoint: Endpoint,
71 conn: iroh::endpoint::Connection,
72 send_stream: SendStream,
73 recv_stream: RecvStream,
74 external_sender: tokio::sync::broadcast::Sender<DirectMessage>,
75 ) -> Result<Self> {
76 let (api, rx) = Handle::channel();
77 let mut actor = ConnActor::new(
78 rx,
79 api.clone(),
80 external_sender,
81 endpoint,
82 conn.remote_node_id()?,
83 Some(conn),
84 Some(send_stream),
85 Some(recv_stream),
86 )
87 .await;
88 tokio::spawn(async move { actor.run().await });
89 Ok(Self { api })
90 }
91
92 pub async fn connect(
93 endpoint: Endpoint,
94 node_id: NodeId,
95 external_sender: tokio::sync::broadcast::Sender<DirectMessage>,
96 ) -> Self {
97 let (api, rx) = Handle::channel();
98 let mut actor = ConnActor::new(
99 rx,
100 api.clone(),
101 external_sender,
102 endpoint.clone(),
103 node_id,
104 None,
105 None,
106 None,
107 )
108 .await;
109
110 tokio::spawn(async move {
111 actor.set_state(ConnState::Connecting);
112 actor.run().await
113 });
114 let s = Self { api };
115
116 tokio::spawn({
117 let s = s.clone();
118 async move {
119 if let Ok(conn) = endpoint.connect(node_id, crate::Direct::ALPN).await {
120 let _ = s.incoming_connection(conn, false).await;
121 }
122 }
123 });
124
125 s
126 }
127
128 pub async fn get_state(&self) -> ConnState {
129 if let Ok(state) = self
130 .api
131 .call(act_ok!(actor => async move {
132 actor.state
133 }))
134 .await
135 {
136 state
137 } else {
138 ConnState::Closed
139 }
140 }
141
142 pub async fn close(&self) -> Result<()> {
143 self.api.call(act_ok!(actor => actor.close())).await
144 }
145
146 pub async fn write(&self, pkg: DirectMessage) -> Result<()> {
147 self.api.call(act_ok!(actor => actor.write(pkg))).await
148 }
149
150 pub async fn incoming_connection(&self, conn: Connection, accept_not_open: bool) -> Result<()> {
151 self.api
152 .call(act!(actor => actor.incoming_connection(conn, accept_not_open)))
153 .await
154 }
155}
156
157impl Actor<anyhow::Error> for ConnActor {
158 async fn run(&mut self) -> Result<()> {
159 let mut reconnect_ticker = tokio::time::interval(Duration::from_millis(500));
160 let mut notification_ticker = tokio::time::interval(Duration::from_millis(500));
161
162 reconnect_ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
163 notification_ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
164
165 loop {
166 tokio::select! {
167 Ok(action) = self.rx.recv_async() => {
168 action(self).await;
169 }
170 _ = reconnect_ticker.tick(), if self.state != ConnState::Closed => {
171
172 let need_reconnect = self.send_stream.is_none()
173 || self.conn.as_ref().and_then(|c| c.close_reason()).is_some();
174
175 if need_reconnect && self.last_reconnect.elapsed() > self.reconnect_backoff {
176 if self.reconnect_count.load(std::sync::atomic::Ordering::SeqCst) < MAX_RECONNECTS {
177 warn!("Send stream stopped");
178 let _ = self.try_reconnect().await;
179 } else {
180 warn!("Max reconnects reached, closing connection to {}", self.conn_node_id);
181 break;
182 }
183 }
184 }
185 _ = notification_ticker.tick(), if self.state != ConnState::Closed
186 && (!self.sender_queue.is_empty()
187 || self.receiver_queue.is_empty()) => {
188
189 if !self.sender_queue.is_empty() {
190 self.sender_notify.notify_one();
191 }
192 if !self.receiver_queue.is_empty() {
193 self.receiver_notify.notify_one();
194 }
195 }
196 stream_recv = async {
197 let recv = self.recv_stream.as_mut().expect("checked in if via self.recv_stream.is_some()");
198 recv.read_u32_le().await
199 }, if self.state != ConnState::Closed && self.recv_stream.is_some() => {
200 if let Ok(frame_size) = stream_recv {
201 let _res = self.remote_read_next(frame_size).await;
202 }
203 }
204 _ = self.sender_notify.notified(), if self.conn.is_some() && self.state == ConnState::Open => {
205 while !self.sender_queue.is_empty() {
206 if self.remote_write_next().await.is_err() {
207 warn!("Failed to write to remote, will attempt to reconnect");
208 self.set_state(ConnState::Disconnected);
209 break;
210 }
211 }
212 }
213 _ = self.receiver_notify.notified(), if self.conn.is_some() && self.state != ConnState::Closed => {
214
215 while let Some(msg) = self.receiver_queue.pop_back() {
216 if self.external_sender.send(msg.clone()).is_err() {
217 warn!("No active receivers for incoming messages");
218 self.set_state(ConnState::Disconnected);
219 break;
220 }
221 }
222 }
223 _ = tokio::signal::ctrl_c() => {
224 break
225 }
226 }
227 }
228 self.close().await;
229 Ok(())
230 }
231}
232
233impl ConnActor {
234 #[allow(clippy::too_many_arguments)]
235 pub async fn new(
236 rx: Receiver<Action<ConnActor>>,
237 self_handle: Handle<ConnActor, anyhow::Error>,
238 external_sender: tokio::sync::broadcast::Sender<DirectMessage>,
239 endpoint: Endpoint,
240 conn_node_id: NodeId,
241 conn: Option<iroh::endpoint::Connection>,
242 send_stream: Option<SendStream>,
243 recv_stream: Option<RecvStream>,
244 ) -> Self {
245 Self {
246 rx,
247 state: if conn.is_some() && send_stream.is_some() && recv_stream.is_some() {
248 ConnState::Open
249 } else {
250 ConnState::Disconnected
251 },
252 external_sender,
253 receiver_queue: VecDeque::with_capacity(QUEUE_SIZE),
254 sender_queue: VecDeque::with_capacity(QUEUE_SIZE),
255 conn,
256 send_stream,
257 recv_stream,
258 endpoint,
259 receiver_notify: tokio::sync::Notify::new(),
260 sender_notify: tokio::sync::Notify::new(),
261 last_reconnect: tokio::time::Instant::now(),
262 reconnect_backoff: Duration::from_millis(100),
263 conn_node_id,
264 self_handle,
265 reconnect_count: AtomicUsize::new(0),
266 }
267 }
268
269 pub fn set_state(&mut self, state: ConnState) {
270 self.state = state;
271 }
272
273 pub async fn close(&mut self) {
274 self.state = ConnState::Closed;
275 if let Some(conn) = self.conn.as_mut() {
276 conn.close(VarInt::from_u32(400), b"Connection closed by user");
277 }
278 self.conn = None;
279 self.send_stream = None;
280 self.recv_stream = None;
281 }
282
283 pub async fn write(&mut self, pkg: DirectMessage) {
284 self.sender_queue.push_front(pkg);
285 self.sender_notify.notify_one();
286 }
287
288 pub async fn incoming_connection(
289 &mut self,
290 conn: Connection,
291 accept_not_open: bool,
292 ) -> Result<()> {
293 let (send_stream, recv_stream) = if accept_not_open {
294 conn.accept_bi().await?
295 } else {
296 conn.open_bi().await?
297 };
298
299 if conn.close_reason().is_some() {
300 self.state = ConnState::Closed;
301 return Err(anyhow::anyhow!("connection closed"));
302 }
303
304 self.conn = Some(conn);
305 self.send_stream = Some(send_stream);
306 self.recv_stream = Some(recv_stream);
307 self.state = ConnState::Open;
308 self.sender_notify.notify_one();
309 self.receiver_notify.notify_one();
310 self.reconnect_backoff = RECONNECT_BACKOFF_BASE;
311
312 Ok(())
316 }
317
318 async fn try_reconnect(&mut self) -> Result<()> {
319 if self.state == ConnState::Closed {
320 return Err(anyhow::anyhow!("actor closed for good"));
321 }
322
323 self.state = ConnState::Connecting;
324 self.reconnect_backoff *= 3;
325 self.last_reconnect = tokio::time::Instant::now();
326
327 self.send_stream = None;
328 self.recv_stream = None;
329 self.conn = None;
330
331 tokio::spawn({
332 let api = self.self_handle.clone();
333 let endpoint = self.endpoint.clone();
334 let conn_node_id = self.conn_node_id;
335 async move {
336 if let Ok(conn) = endpoint.connect(conn_node_id, crate::Direct::ALPN).await {
337 let _ = api
338 .call(act!(actor => actor.incoming_connection(conn, false)))
339 .await;
340 let _ = api.call(act_ok!(actor => async move { actor.reconnect_count.store(0, std::sync::atomic::Ordering::SeqCst) })).await;
341 } else {
342 let _ = api.call(act_ok!(actor => async move { actor.reconnect_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst) })).await;
343 }
344 }
345 });
346 Ok(())
347 }
348
349 async fn remote_write_next(&mut self) -> Result<()> {
350 let start = SystemTime::now();
351 let mut wrote = 0;
352 if let Some(send_stream) = &mut self.send_stream {
353 while let Some(msg) = self.sender_queue.back() {
354 let bytes = postcard::to_stdvec(msg)?;
355 send_stream.write_u32_le(bytes.len() as u32).await?;
356 send_stream.write_all(bytes.as_slice()).await?;
357 let _ = self.sender_queue.pop_back();
358 wrote += 1;
359 if wrote >= 256 {
360 break;
361 }
362 }
363 } else {
364 return Err(anyhow::anyhow!("no send stream"));
365 }
366
367 if !self.sender_queue.is_empty() {
368 self.sender_notify.notify_one();
369 }
370
371 let end = SystemTime::now();
372 let duration = end.duration_since(start).unwrap();
373 debug!("write_remote: {wrote}; elapsed: {}", duration.as_millis());
374 Ok(())
375 }
376
377 async fn remote_read_next(&mut self, frame_len: u32) -> Result<DirectMessage> {
378 if let Some(recv_stream) = &mut self.recv_stream {
379 let mut buf = vec![0; frame_len as usize];
380
381 let start = SystemTime::now();
382 recv_stream.read_exact(&mut buf).await?;
383
384 if let Ok(pkg) = postcard::from_bytes(&buf) {
385 match pkg {
386 DirectMessage::IpPacket(ip_pkg) => {
387 if let Ok(ip_pkg) = ip_pkg.to_ipv4_packet() {
388 let msg = DirectMessage::IpPacket(ip_pkg.into());
389 self.receiver_queue.push_front(msg.clone());
390 self.receiver_notify.notify_one();
391 let end = SystemTime::now();
392 let duration = end.duration_since(start).unwrap();
393 debug!("read_remote: elapsed: {}", duration.as_millis());
394 Ok(msg)
395 } else {
396 Err(anyhow::anyhow!("failed to convert to IPv4 packet"))
397 }
398 }
399 #[allow(unreachable_patterns)]
400 _ => Err(anyhow::anyhow!("unsupported message type")),
401 }
402 } else {
403 Err(anyhow::anyhow!("failed to deserialize message"))
404 }
405 } else {
406 Err(anyhow::anyhow!("no recv stream"))
407 }
408 }
409}