iroh_lan/
connection.rs

1use std::{
2    collections::VecDeque,
3    sync::atomic::AtomicUsize,
4    time::{Duration, SystemTime},
5};
6
7use crate::DirectMessage;
8use actor_helper::{Action, Actor, Handle, Receiver, act, act_ok};
9use anyhow::Result;
10use iroh::{
11    Endpoint,
12    endpoint::{RecvStream, SendStream},
13};
14use iroh::{
15    NodeId,
16    endpoint::{Connection, VarInt},
17};
18use tokio::io::{AsyncReadExt, AsyncWriteExt};
19use tracing::{debug, warn};
20
/// Number of messages reserved up front for each of the actor's message queues.
const QUEUE_SIZE: usize = 16 * 1024;
/// Number of failed reconnect attempts (as tracked by `reconnect_count`) after
/// which the actor loop gives up and closes the connection.
const MAX_RECONNECTS: usize = 5;
/// Reconnect backoff restored once a connection is (re)established;
/// each reconnect attempt triples the current backoff.
const RECONNECT_BACKOFF_BASE: Duration = Duration::from_millis(100);
24
/// Cloneable handle to a connection actor (`ConnActor`) managing one peer
/// connection. Every public operation is forwarded to the actor task via `api`.
#[derive(Debug, Clone)]
pub struct Conn {
    /// Channel handle used to run actions on the actor task.
    api: Handle<ConnActor, anyhow::Error>,
}
29
/// Lifecycle state of a connection actor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnState {
    /// A connection attempt is running in the background
    /// (e.g. right after `ConnActor::connect()` was requested).
    Connecting,
    /// No active connection; a new one may be established.
    Idle,
    /// The bidirectional stream pair is open and messages can flow.
    Open,
    /// Closed by the user or because of an error.
    Closed,
    /// The link to the remote peer broke; reconnection is attempted a limited
    /// number of times (`MAX_RECONNECTS` failures) before the actor closes.
    Disconnected,
}
38
/// Actor that owns one iroh connection to `conn_node_id` plus its single
/// bidirectional stream pair, and moves `DirectMessage`s across it.
///
/// Outgoing messages wait in `sender_queue` until the actor loop writes them;
/// incoming frames are decoded into `receiver_queue` and then published on
/// `external_sender`.
#[derive(Debug)]
struct ConnActor {
    /// Inbox of actions submitted through the `Conn` handle.
    rx: Receiver<Action<ConnActor>>,
    /// Handle to this very actor; cloned into background reconnect tasks so
    /// they can hand results back to the actor.
    self_handle: Handle<ConnActor, anyhow::Error>,
    state: ConnState,

    // The connection and its two stream halves are `Option`s so the actor can
    // be created as an empty shell (`Conn::connect` passes `None`) and filled
    // in once a connection exists, rather than making the caller wait for the
    // connection first. Per the original author, this keeps the main
    // standalone loop from hanging on router events when route_packet fails.
    conn: Option<Connection>,
    /// Remote node this actor dials and redials.
    conn_node_id: NodeId,
    send_stream: Option<SendStream>,
    recv_stream: Option<RecvStream>,
    /// Local endpoint used to dial `conn_node_id` when reconnecting.
    endpoint: Endpoint,

    // Reconnect bookkeeping: when the last attempt started, how long to wait
    // before the next one (tripled per attempt, reset on a new connection),
    // and how many attempts failed (compared against MAX_RECONNECTS).
    last_reconnect: tokio::time::Instant,
    reconnect_backoff: Duration,
    reconnect_count: AtomicUsize,

    /// Broadcast channel on which decoded incoming messages are published.
    external_sender: tokio::sync::broadcast::Sender<DirectMessage>,

    // Inbound messages: pushed at the front, drained from the back (FIFO).
    receiver_queue: VecDeque<DirectMessage>,
    receiver_notify: tokio::sync::Notify,

    // Outbound messages: pushed at the front, written from the back (FIFO).
    sender_queue: VecDeque<DirectMessage>,
    sender_notify: tokio::sync::Notify,
}
67
68impl Conn {
69    pub async fn new(
70        endpoint: Endpoint,
71        conn: iroh::endpoint::Connection,
72        send_stream: SendStream,
73        recv_stream: RecvStream,
74        external_sender: tokio::sync::broadcast::Sender<DirectMessage>,
75    ) -> Result<Self> {
76        let (api, rx) = Handle::channel();
77        let mut actor = ConnActor::new(
78            rx,
79            api.clone(),
80            external_sender,
81            endpoint,
82            conn.remote_node_id()?,
83            Some(conn),
84            Some(send_stream),
85            Some(recv_stream),
86        )
87        .await;
88        tokio::spawn(async move { actor.run().await });
89        Ok(Self { api })
90    }
91
92    pub async fn connect(
93        endpoint: Endpoint,
94        node_id: NodeId,
95        external_sender: tokio::sync::broadcast::Sender<DirectMessage>,
96    ) -> Self {
97        let (api, rx) = Handle::channel();
98        let mut actor = ConnActor::new(
99            rx,
100            api.clone(),
101            external_sender,
102            endpoint.clone(),
103            node_id,
104            None,
105            None,
106            None,
107        )
108        .await;
109
110        tokio::spawn(async move {
111            actor.set_state(ConnState::Connecting);
112            actor.run().await
113        });
114        let s = Self { api };
115
116        tokio::spawn({
117            let s = s.clone();
118            async move {
119                if let Ok(conn) = endpoint.connect(node_id, crate::Direct::ALPN).await {
120                    let _ = s.incoming_connection(conn, false).await;
121                }
122            }
123        });
124
125        s
126    }
127
128    pub async fn get_state(&self) -> ConnState {
129        if let Ok(state) = self
130            .api
131            .call(act_ok!(actor => async move {
132                actor.state
133            }))
134            .await
135        {
136            state
137        } else {
138            ConnState::Closed
139        }
140    }
141
142    pub async fn close(&self) -> Result<()> {
143        self.api.call(act_ok!(actor => actor.close())).await
144    }
145
146    pub async fn write(&self, pkg: DirectMessage) -> Result<()> {
147        self.api.call(act_ok!(actor => actor.write(pkg))).await
148    }
149
150    pub async fn incoming_connection(&self, conn: Connection, accept_not_open: bool) -> Result<()> {
151        self.api
152            .call(act!(actor => actor.incoming_connection(conn, accept_not_open)))
153            .await
154    }
155}
156
157impl Actor<anyhow::Error> for ConnActor {
158    async fn run(&mut self) -> Result<()> {
159        let mut reconnect_ticker = tokio::time::interval(Duration::from_millis(500));
160        let mut notification_ticker = tokio::time::interval(Duration::from_millis(500));
161
162        reconnect_ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
163        notification_ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
164
165        loop {
166            tokio::select! {
167                Ok(action) = self.rx.recv_async() => {
168                    action(self).await;
169                }
170                _ = reconnect_ticker.tick(), if self.state != ConnState::Closed => {
171
172                    let need_reconnect = self.send_stream.is_none()
173                        || self.conn.as_ref().and_then(|c| c.close_reason()).is_some();
174
175                    if need_reconnect && self.last_reconnect.elapsed() > self.reconnect_backoff {
176                        if self.reconnect_count.load(std::sync::atomic::Ordering::SeqCst) < MAX_RECONNECTS {
177                            warn!("Send stream stopped");
178                            let _ = self.try_reconnect().await;
179                        } else {
180                            warn!("Max reconnects reached, closing connection to {}", self.conn_node_id);
181                            break;
182                        }
183                    }
184                }
185                _ = notification_ticker.tick(), if self.state != ConnState::Closed
186                        && (!self.sender_queue.is_empty()
187                            || self.receiver_queue.is_empty()) => {
188
189                    if !self.sender_queue.is_empty() {
190                        self.sender_notify.notify_one();
191                    }
192                    if !self.receiver_queue.is_empty() {
193                        self.receiver_notify.notify_one();
194                    }
195                }
196                stream_recv = async {
197                    let recv = self.recv_stream.as_mut().expect("checked in if via self.recv_stream.is_some()");
198                    recv.read_u32_le().await
199                }, if self.state != ConnState::Closed && self.recv_stream.is_some() => {
200                    if let Ok(frame_size) = stream_recv {
201                        let _res = self.remote_read_next(frame_size).await;
202                    }
203                }
204                _ = self.sender_notify.notified(), if self.conn.is_some() && self.state == ConnState::Open => {
205                    while !self.sender_queue.is_empty() {
206                        if self.remote_write_next().await.is_err() {
207                            warn!("Failed to write to remote, will attempt to reconnect");
208                            self.set_state(ConnState::Disconnected);
209                            break;
210                        }
211                    }
212                }
213                _ = self.receiver_notify.notified(), if self.conn.is_some() && self.state != ConnState::Closed => {
214
215                    while let Some(msg) = self.receiver_queue.pop_back() {
216                        if self.external_sender.send(msg.clone()).is_err() {
217                            warn!("No active receivers for incoming messages");
218                            self.set_state(ConnState::Disconnected);
219                            break;
220                        }
221                    }
222                }
223                _ = tokio::signal::ctrl_c() => {
224                    break
225                }
226            }
227        }
228        self.close().await;
229        Ok(())
230    }
231}
232
233impl ConnActor {
234    #[allow(clippy::too_many_arguments)]
235    pub async fn new(
236        rx: Receiver<Action<ConnActor>>,
237        self_handle: Handle<ConnActor, anyhow::Error>,
238        external_sender: tokio::sync::broadcast::Sender<DirectMessage>,
239        endpoint: Endpoint,
240        conn_node_id: NodeId,
241        conn: Option<iroh::endpoint::Connection>,
242        send_stream: Option<SendStream>,
243        recv_stream: Option<RecvStream>,
244    ) -> Self {
245        Self {
246            rx,
247            state: if conn.is_some() && send_stream.is_some() && recv_stream.is_some() {
248                ConnState::Open
249            } else {
250                ConnState::Disconnected
251            },
252            external_sender,
253            receiver_queue: VecDeque::with_capacity(QUEUE_SIZE),
254            sender_queue: VecDeque::with_capacity(QUEUE_SIZE),
255            conn,
256            send_stream,
257            recv_stream,
258            endpoint,
259            receiver_notify: tokio::sync::Notify::new(),
260            sender_notify: tokio::sync::Notify::new(),
261            last_reconnect: tokio::time::Instant::now(),
262            reconnect_backoff: Duration::from_millis(100),
263            conn_node_id,
264            self_handle,
265            reconnect_count: AtomicUsize::new(0),
266        }
267    }
268
269    pub fn set_state(&mut self, state: ConnState) {
270        self.state = state;
271    }
272
273    pub async fn close(&mut self) {
274        self.state = ConnState::Closed;
275        if let Some(conn) = self.conn.as_mut() {
276            conn.close(VarInt::from_u32(400), b"Connection closed by user");
277        }
278        self.conn = None;
279        self.send_stream = None;
280        self.recv_stream = None;
281    }
282
283    pub async fn write(&mut self, pkg: DirectMessage) {
284        self.sender_queue.push_front(pkg);
285        self.sender_notify.notify_one();
286    }
287
288    pub async fn incoming_connection(
289        &mut self,
290        conn: Connection,
291        accept_not_open: bool,
292    ) -> Result<()> {
293        let (send_stream, recv_stream) = if accept_not_open {
294            conn.accept_bi().await?
295        } else {
296            conn.open_bi().await?
297        };
298
299        if conn.close_reason().is_some() {
300            self.state = ConnState::Closed;
301            return Err(anyhow::anyhow!("connection closed"));
302        }
303
304        self.conn = Some(conn);
305        self.send_stream = Some(send_stream);
306        self.recv_stream = Some(recv_stream);
307        self.state = ConnState::Open;
308        self.sender_notify.notify_one();
309        self.receiver_notify.notify_one();
310        self.reconnect_backoff = RECONNECT_BACKOFF_BASE;
311
312        // SHOULD NOT CHANGE but just for sanity
313        //self.conn_node_id = self.conn.clone().expect("new_conn").remote_node_id()?;
314
315        Ok(())
316    }
317
318    async fn try_reconnect(&mut self) -> Result<()> {
319        if self.state == ConnState::Closed {
320            return Err(anyhow::anyhow!("actor closed for good"));
321        }
322
323        self.state = ConnState::Connecting;
324        self.reconnect_backoff *= 3;
325        self.last_reconnect = tokio::time::Instant::now();
326
327        self.send_stream = None;
328        self.recv_stream = None;
329        self.conn = None;
330
331        tokio::spawn({
332            let api = self.self_handle.clone();
333            let endpoint = self.endpoint.clone();
334            let conn_node_id = self.conn_node_id;
335            async move {
336                if let Ok(conn) = endpoint.connect(conn_node_id, crate::Direct::ALPN).await {
337                    let _ = api
338                        .call(act!(actor => actor.incoming_connection(conn, false)))
339                        .await;
340                    let _ = api.call(act_ok!(actor => async move { actor.reconnect_count.store(0, std::sync::atomic::Ordering::SeqCst) })).await;
341                } else {
342                    let _ = api.call(act_ok!(actor => async move { actor.reconnect_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst) })).await;
343                }
344            }
345        });
346        Ok(())
347    }
348
349    async fn remote_write_next(&mut self) -> Result<()> {
350        let start = SystemTime::now();
351        let mut wrote = 0;
352        if let Some(send_stream) = &mut self.send_stream {
353            while let Some(msg) = self.sender_queue.back() {
354                let bytes = postcard::to_stdvec(msg)?;
355                send_stream.write_u32_le(bytes.len() as u32).await?;
356                send_stream.write_all(bytes.as_slice()).await?;
357                let _ = self.sender_queue.pop_back();
358                wrote += 1;
359                if wrote >= 256 {
360                    break;
361                }
362            }
363        } else {
364            return Err(anyhow::anyhow!("no send stream"));
365        }
366
367        if !self.sender_queue.is_empty() {
368            self.sender_notify.notify_one();
369        }
370
371        let end = SystemTime::now();
372        let duration = end.duration_since(start).unwrap();
373        debug!("write_remote: {wrote}; elapsed: {}", duration.as_millis());
374        Ok(())
375    }
376
377    async fn remote_read_next(&mut self, frame_len: u32) -> Result<DirectMessage> {
378        if let Some(recv_stream) = &mut self.recv_stream {
379            let mut buf = vec![0; frame_len as usize];
380
381            let start = SystemTime::now();
382            recv_stream.read_exact(&mut buf).await?;
383
384            if let Ok(pkg) = postcard::from_bytes(&buf) {
385                match pkg {
386                    DirectMessage::IpPacket(ip_pkg) => {
387                        if let Ok(ip_pkg) = ip_pkg.to_ipv4_packet() {
388                            let msg = DirectMessage::IpPacket(ip_pkg.into());
389                            self.receiver_queue.push_front(msg.clone());
390                            self.receiver_notify.notify_one();
391                            let end = SystemTime::now();
392                            let duration = end.duration_since(start).unwrap();
393                            debug!("read_remote: elapsed: {}", duration.as_millis());
394                            Ok(msg)
395                        } else {
396                            Err(anyhow::anyhow!("failed to convert to IPv4 packet"))
397                        }
398                    }
399                    #[allow(unreachable_patterns)]
400                    _ => Err(anyhow::anyhow!("unsupported message type")),
401                }
402            } else {
403                Err(anyhow::anyhow!("failed to deserialize message"))
404            }
405        } else {
406            Err(anyhow::anyhow!("no recv stream"))
407        }
408    }
409}