foctet_net/
endpoint.rs

1use crate::{
2    config::TransportConfig, device, transport::{
3        connection::{Connection, ConnectionEvent}, quic::transport::QuicTransport, tcp::transport::TcpTransport, Transport,
4    }
5};
6use anyhow::{anyhow, Result};
7use bytes::Bytes;
8use foctet_core::{addr::node::{NodeAddr, RelayAddr}, default, id::NodeId, ip, key::Keypair, transport::{ListenerId, TransportKind}};
9use stackaddr::{segment::protocol::TransportProtocol, Identity, Protocol, StackAddr};
10use tokio_util::sync::CancellationToken;
11use std::{
12    collections::{BTreeMap, HashMap, HashSet}, net::{IpAddr, Ipv4Addr}, sync::Arc
13};
14use tokio::sync::{mpsc, Mutex};
15
16pub struct ListenerHandle {
17    conn_receiver: Arc<Mutex<mpsc::Receiver<Connection>>>,
18}
19
20impl ListenerHandle {
21    pub fn new(conn_receiver: Arc<Mutex<mpsc::Receiver<Connection>>>) -> Self {
22        Self { conn_receiver }
23    }
24
25    pub async fn accept(&self) -> Option<Connection> {
26        self.conn_receiver.lock().await.recv().await
27    }
28
29    pub async fn clone(&self) -> Self {
30        Self {
31            conn_receiver: Arc::clone(&self.conn_receiver),
32        }
33    }
34}
35
/// Placeholder for a relay actor; it currently carries no state and has no behavior.
pub struct RelayActor {}
39
40pub struct EndpointActor {
41    config: TransportConfig,
42    addrs: HashSet<StackAddr>,
43    conn_sender: mpsc::Sender<Connection>,
44    event_sender: mpsc::Sender<EndpointEvent>,
45    cmd_receiver: mpsc::Receiver<EndpointCommand>,
46    cancel: CancellationToken,
47    listen_enabled: bool,
48}
49
50impl EndpointActor {
51    pub async fn run(mut self) -> Result<()> {
52        // Create a listener for each address
53        if self.listen_enabled {
54            let mut listerner_id = ListenerId::new(1);
55            for addr in &self.addrs {
56                let config = self.config.clone();
57                let mut transport: Transport = match addr.transport() {
58                    Some(transport) => match transport {
59                        TransportProtocol::Quic(_) | TransportProtocol::Udp(_) => {
60                            let t = QuicTransport::new(config)?;
61                            Transport::Quic(t)
62                        },
63                        TransportProtocol::TlsOverTcp(_) | TransportProtocol::Tcp(_) => {
64                            let t = TcpTransport::new(config)?;
65                            Transport::Tcp(t)
66                        },
67                        _ => return Err(anyhow::anyhow!("Unsupported transport protocol: {:?}", transport)),
68                    },
69                    None => {
70                        return Err(anyhow::anyhow!("Invalid transport protocol"));
71                    }
72                };
73                // Listen for incoming connections
74                let event_sender = self.event_sender.clone();
75                let conn_sender = self.conn_sender.clone();
76                let mut listener = transport.listen_on(listerner_id.fetch_add(1), addr.clone()).await?;
77                tokio::spawn(async move {
78                    while let Some(conn_event) = listener.accept().await {
79                        match conn_event {
80                            ConnectionEvent::Accepted(conn) => {
81                                match conn_sender.send(conn).await {
82                                    Ok(_) => {}
83                                    Err(e) => {
84                                        event_sender
85                                            .send(EndpointEvent::Error(anyhow!("Error sending connection event: {:?}", e)))
86                                            .await
87                                            .unwrap_or_else(|e| {
88                                                tracing::error!("Error sending connection event: {:?}", e);
89                                            });
90                                    }
91                                }
92                            }
93                            _ => {},
94                        }
95                    }
96                });
97            }
98        }
99        
100        // Handle commands
101        loop {
102            tokio::select! {
103                _ = self.cancel.cancelled() => {
104                    tracing::info!("EndpointActor loop cancelled, closing loop");
105                    break;
106                }
107                Some(cmd) = self.cmd_receiver.recv() => {
108                    match cmd {
109                        EndpointCommand::Connect(_addr) => {
110                            // Handle connect command
111                            // TODO!: Additional logic to handle connection
112                            // For now, connect via Endpoint::connect
113                        }
114                        EndpointCommand::Listen(_addr) => {
115                            // Handle listen command
116                            // TODO!: Additional logic to handle listening on a new address
117                        }
118                        EndpointCommand::Shutdown => {
119                            // Handle shutdown command
120                            break;
121                        }
122                    }
123                }
124            }
125        }
126        Ok(())
127    }
128}
129
130/// The endpoint for network communication.
131/// This is the main entry point for establishing connections and listening for incoming connections.
132pub struct Endpoint {
133    config: TransportConfig,
134    addrs: HashSet<StackAddr>,
135    relay_addrs: Option<RelayAddr>,
136    priority_map: BTreeMap<u8, TransportKind>,
137    transports: HashMap<TransportKind, Transport>,
138    listener: ListenerHandle,
139    event_receiver: mpsc::Receiver<EndpointEvent>,
140    cmd_sender: mpsc::Sender<EndpointCommand>,
141    cancel: CancellationToken,
142    allow_loopback: bool,
143}
144
145impl Endpoint {
146    /// Create a new endpoint builder for building an endpoint.
147    pub fn builder() -> EndpointBuilder {
148        EndpointBuilder::new()
149    }
150
151    /// Create a new endpoint with the default configuration.
152    pub fn default_builder() -> EndpointBuilder {
153        EndpointBuilder::default()
154    }
155
156    /// Get the node ID of the endpoint.
157    /// This is the public key of the endpoint's keypair.
158    pub fn node_id(&self) -> NodeId {
159        self.config.keypair().public().into()
160    }
161
162    /// Return current node address for the endpoint.
163    pub fn node_addr(&self) -> NodeAddr {
164        NodeAddr {
165            node_id: self.node_id(),
166            addresses: self.addrs.iter().cloned().collect(),
167            relay_addr: self.relay_addrs.clone(),
168        }
169    }
170
171    /// Return global-only node address for the endpoint.
172    pub fn global_node_addr(&self) -> NodeAddr {
173        let global_addrs: Vec<StackAddr> = self
174            .addrs
175            .iter()
176            .cloned()
177            .filter(|addr| {
178                if let Some(ip) = addr.ip() {
179                    ip::is_global_ip(&ip)
180                } else {
181                    false
182                }
183            })
184            .collect();
185
186        NodeAddr {
187            node_id: self.node_id(),
188            addresses: global_addrs.into_iter().collect(),
189            relay_addr: self.relay_addrs.clone(),
190        }
191    }
192
193    /// Connect to a remote node using the given StackAddr.
194    pub async fn connect(&mut self, addr: StackAddr) -> Result<Connection> {
195        match addr.transport() {
196            Some(transport) => {
197                match transport {
198                    TransportProtocol::Quic(_) | TransportProtocol::Udp(_) => {
199                        let t = self.transports.get_mut(&TransportKind::Quic).ok_or_else(|| anyhow!("QUIC transport not found"))?;
200                        t.connect(addr).await
201                    },
202                    TransportProtocol::TlsOverTcp(_) | TransportProtocol::Tcp(_) => {
203                        let t = self.transports.get_mut(&TransportKind::TlsOverTcp).ok_or_else(|| anyhow!("TCP transport not found"))?;
204                        t.connect(addr).await
205                    },
206                    _ => Err(anyhow!("Unsupported transport protocol: {:?}", transport)),
207                }
208            }
209            None => Err(anyhow!("Missing transport protocol in address")),
210        }
211    }
212
213    /// Connect to a remote node using the given NodeAddr.
214    pub async fn connect_node(&mut self, addr: NodeAddr) -> Result<Connection> {
215        let iface = &netdev::get_default_interface()
216            .map_err(|e| anyhow!("Failed to get default interface: {:?}", e))?;
217        for proto in self.priority_map.values() {
218            let t = self.transports.get_mut(proto).ok_or_else(|| anyhow!("Transport not found"))?;
219            let addrs = addr.get_direct_addrs(proto, self.allow_loopback);
220            let sorted_addrs = device::sort_addrs_by_reachability(&addrs, iface);
221            for addr in sorted_addrs {
222                match t.connect(addr.clone()).await {
223                    Ok(conn) => {
224                        return Ok(conn);
225                    }
226                    Err(e) => {
227                        tracing::error!("Error connecting to {}: {:?}", addr, e);
228                    }
229                }
230            }
231        }
232        Err(anyhow!("No direct address found for node"))
233    }
234
235    pub async fn accept(&mut self) -> Option<Connection> {
236        self.listener.accept().await
237    }
238
239    pub async fn get_listener(&self) -> ListenerHandle {
240        self.listener.clone().await
241    }
242
243    pub async fn shutdown(&self) -> Result<()> {
244        self.cmd_sender.send(EndpointCommand::Shutdown).await?;
245        self.cancel.cancel();
246        Ok(())
247    }
248
249    pub async fn send_command(&self, cmd: EndpointCommand) -> Result<()> {
250        self.cmd_sender.send(cmd).await?;
251        Ok(())
252    }
253
254    pub async fn next_event(&mut self) -> Option<EndpointEvent> {
255        self.event_receiver.recv().await
256    }
257}
258
259/// Represents an event that occurs in the endpoint.
260#[derive(Debug)]
261pub enum EndpointEvent {
262    ConnectionEstablished {
263        node_id: NodeId,
264        addr: StackAddr,
265    },
266    ConnectionClosed {
267        node_id: NodeId,
268    },
269    NewListenAddr {
270        listener_id: ListenerId,
271        addr: StackAddr,
272    },
273    PeerDiscovered {
274        node_id: NodeId,
275        addr: StackAddr,
276    },
277    Error(anyhow::Error),
278}
279
280pub enum EndpointCommand {
281    Connect(StackAddr),
282    Listen(StackAddr),
283    Shutdown,
284}
285
286/// The builder for building an endpoint.
287/// Provides methods for configuring the endpoint with builder pattern.
288pub struct EndpointBuilder {
289    config: TransportConfig,
290    protocols: Vec<TransportKind>,
291    addrs: HashSet<StackAddr>,
292    listen_enabled: bool,
293    allow_loopback: bool,
294}
295
296impl Default for EndpointBuilder {
297    fn default() -> Self {
298        let keypair = Keypair::generate();
299        let config = TransportConfig::new(keypair.clone()).unwrap();
300
301        let mut protocols = Vec::new();
302        protocols.push(TransportKind::Quic);
303
304        // Default stack address
305        let mut addrs = HashSet::new();
306        let addr = StackAddr::empty()
307        .with_protocol(Protocol::Ip4(Ipv4Addr::UNSPECIFIED))
308        .with_protocol(Protocol::Udp(default::DEFAULT_SERVER_PORT))
309        .with_protocol(Protocol::Quic)
310        .with_identity(Identity::NodeId(Bytes::copy_from_slice(&keypair.public().to_bytes())));
311
312        addrs.insert(addr);
313
314        Self { 
315            config, 
316            protocols, 
317            addrs: addrs,
318            listen_enabled: true,
319            allow_loopback: false,
320        }
321    }
322}
323
324impl EndpointBuilder {
325    /// Create a new endpoint builder with the given keypair.
326    pub fn new() -> Self {
327        let keypair = Keypair::generate();
328        let config = TransportConfig::new(keypair.clone()).unwrap();
329        Self { 
330            config, 
331            protocols: Vec::new(), 
332            addrs: HashSet::new(), 
333            listen_enabled: true, 
334            allow_loopback: false 
335        }
336    }
337    pub fn with_keypair(mut self, keypair: Keypair) -> Self {
338        self.config.set_keypair(keypair).unwrap();
339        self
340    }
341
342    fn push_protocol(&mut self, proto: TransportKind) {
343        if !self.protocols.contains(&proto) {
344            self.protocols.push(proto);
345        }
346    }
347
348    /// Add QUIC support to the endpoint.
349    /// If `addr` is not set, the default address will be used for listening.
350    pub fn with_quic(mut self) -> Self {
351        self.push_protocol(TransportKind::Quic);
352        self
353    }
354    /// Add TCP support to the endpoint.
355    /// If `addr` is not set, the default address will be used for listening.
356    pub fn with_tcp(mut self) -> Self {
357        self.push_protocol(TransportKind::TlsOverTcp);
358        self
359    }
360
361    /// Add listen address to the endpoint.
362    /// This address will be used for listening for incoming connections.
363    pub fn with_addr(mut self, addr: StackAddr) -> Result<Self> {
364        let transport = addr.transport().ok_or_else(|| anyhow!("Missing transport protocol in address"))?;
365        self.push_protocol(TransportKind::from_protocol(transport)?);
366        self.addrs.insert(addr);
367        Ok(self)
368    }
369
370    /// Set listener disabled.
371    /// This will disable the listener for incoming connections.
372    pub fn without_listen(mut self) -> Self {
373        self.listen_enabled = false;
374        self
375    }
376
377    /// Set allow loopback.
378    pub fn allow_loopback(mut self, allow: bool) -> Self {
379        self.allow_loopback = allow;
380        self
381    }
382
383    /// Set the read buffer size for the endpoint.
384    pub fn with_read_buffer_size(mut self, size: usize) -> Self {
385        self.config.read_buffer_size = size;
386        self
387    }
388
389    /// Set the write buffer size for the endpoint.
390    pub fn with_write_buffer_size(mut self, size: usize) -> Self {
391        self.config.write_buffer_size = size;
392        self
393    }
394
395    /// Set the maximum read buffer size for the endpoint.
396    pub fn with_max_read_buffer_size(mut self) -> Self {
397        self.config.read_buffer_size = default::MAX_READ_BUFFER_SIZE;
398        self
399    }
400
401    /// Set the maximum write buffer size for the endpoint.
402    pub fn with_max_write_buffer_size(mut self) -> Self {
403        self.config.write_buffer_size = default::MAX_WRITE_BUFFER_SIZE;
404        self
405    }
406
407    /// Build and spawn the endpoint.
408    /// This will create the endpoint and start listening for incoming connections.
409    pub fn build(self) -> Result<Endpoint> {
410        let mut priority_map = BTreeMap::new();
411        let mut transports = HashMap::new();
412        for (i, proto) in self.protocols.iter().enumerate() {
413            let priority = (i + 1) as u8;
414            match proto {
415                TransportKind::Quic => {
416                    let t = QuicTransport::new(self.config.clone())?;
417                    transports.insert(TransportKind::Quic, Transport::Quic(t));
418                    priority_map.insert(priority, TransportKind::Quic);
419                },
420                TransportKind::TlsOverTcp => {
421                    let t = TcpTransport::new(self.config.clone())?;
422                    transports.insert(TransportKind::TlsOverTcp, Transport::Tcp(t));
423                    priority_map.insert(priority, TransportKind::TlsOverTcp);
424                },
425            }
426        }
427
428        let addrs = if self.addrs.is_empty() {
429            get_unspecified_stack_addrs(&self.protocols)
430        } else {
431            self.addrs.clone()
432        };
433
434        // Create a channel for connection events
435        let (conn_sender, conn_receiver) = mpsc::channel(100);
436        // Create a channel for endpoint events
437        let (event_sender, event_receiver) = mpsc::channel(100);
438        // Create a channel for endpoint commands
439        let (cmd_sender, cmd_receiver) = mpsc::channel(100);
440        // Create a cancellation token for the endpoint
441        let cancel = CancellationToken::new();
442        // Create the endpoint actor
443        let actor = EndpointActor {
444            config: self.config.clone(),
445            addrs: addrs,
446            conn_sender,
447            event_sender,
448            cmd_receiver,
449            cancel: cancel.clone(),
450            listen_enabled: self.listen_enabled,
451        };
452        // Spawn the endpoint actor
453        tokio::spawn(async move {
454            if let Err(e) = actor.run().await {
455                tracing::error!("Endpoint actor error: {:?}", e);
456            }
457        });
458
459        let direct_addrs = if self.addrs.is_empty() {
460            get_default_stack_addrs(&self.protocols, self.allow_loopback)
461        } else {
462            replace_with_actual_addrs(&self.addrs, &self.protocols, self.allow_loopback)
463        };
464
465        Ok(Endpoint {
466            config: self.config,
467            addrs: direct_addrs,
468            relay_addrs: None,
469            priority_map,
470            transports,
471            listener: ListenerHandle::new(Arc::new(Mutex::new(conn_receiver))),
472            event_receiver,
473            cmd_sender,
474            cancel,
475            allow_loopback: self.allow_loopback,
476        })
477    }
478}
479
480fn get_unspecified_stack_addrs(protocols: &[TransportKind]) -> HashSet<StackAddr> {
481    let unspecified_addr = device::get_unspecified_server_addr();
482    let mut addrs = HashSet::new();
483    for proto in protocols.iter() {
484        match proto {
485            TransportKind::Quic => {
486                match unspecified_addr.ip() {
487                    IpAddr::V4(ipv4) => {
488                        addrs.insert(StackAddr::empty()
489                            .with_protocol(Protocol::Ip4(ipv4))
490                            .with_protocol(Protocol::Udp(unspecified_addr.port()))
491                            .with_protocol(Protocol::Quic));
492                    }
493                    IpAddr::V6(ipv6) => {
494                        addrs.insert(StackAddr::empty()
495                            .with_protocol(Protocol::Ip6(ipv6))
496                            .with_protocol(Protocol::Udp(unspecified_addr.port()))
497                            .with_protocol(Protocol::Quic));
498                    }
499                }
500            }
501            TransportKind::TlsOverTcp => {
502                match unspecified_addr.ip() {
503                    IpAddr::V4(ipv4) => {
504                        addrs.insert(StackAddr::empty()
505                            .with_protocol(Protocol::Ip4(ipv4))
506                            .with_protocol(Protocol::Tcp(unspecified_addr.port()))
507                            .with_protocol(Protocol::Tls));
508                    }
509                    IpAddr::V6(ipv6) => {
510                        addrs.insert(StackAddr::empty()
511                            .with_protocol(Protocol::Ip6(ipv6))
512                            .with_protocol(Protocol::Tcp(unspecified_addr.port()))
513                            .with_protocol(Protocol::Tls));
514                    }
515                }
516            }
517        }
518    }
519    addrs
520}
521
522fn get_default_stack_addrs(protocols: &[TransportKind], allow_loopback: bool) -> HashSet<StackAddr> {
523    let socket_addrs = crate::device::get_default_server_addrs(default::DEFAULT_SERVER_PORT, allow_loopback);
524    let mut addrs = HashSet::new();
525    for proto in protocols.iter() {
526        for addr in socket_addrs.iter() {
527            match proto {
528                TransportKind::Quic => {
529                    match addr.ip() {
530                        IpAddr::V4(ipv4) => {
531                            addrs.insert(StackAddr::empty()
532                                .with_protocol(Protocol::Ip4(ipv4))
533                                .with_protocol(Protocol::Udp(addr.port()))
534                                .with_protocol(Protocol::Quic));
535                        }
536                        IpAddr::V6(ipv6) => {
537                            addrs.insert(StackAddr::empty()
538                                .with_protocol(Protocol::Ip6(ipv6))
539                                .with_protocol(Protocol::Udp(addr.port()))
540                                .with_protocol(Protocol::Quic));
541                        }
542                    }
543                }
544                TransportKind::TlsOverTcp => {
545                    match addr.ip() {
546                        IpAddr::V4(ipv4) => {
547                            addrs.insert(StackAddr::empty()
548                                .with_protocol(Protocol::Ip4(ipv4))
549                                .with_protocol(Protocol::Tcp(addr.port()))
550                                .with_protocol(Protocol::Tls));
551                        }
552                        IpAddr::V6(ipv6) => {
553                            addrs.insert(StackAddr::empty()
554                                .with_protocol(Protocol::Ip6(ipv6))
555                                .with_protocol(Protocol::Tcp(addr.port()))
556                                .with_protocol(Protocol::Tls));
557                        }
558                    }
559                }
560            }
561        }
562    }
563    addrs
564}
565
566fn replace_with_actual_addrs(
567    input_addrs: &HashSet<StackAddr>,
568    protocols: &[TransportKind],
569    allow_loopback: bool
570) -> HashSet<StackAddr> {
571    let mut result = HashSet::new();
572
573    let actual_addrs = crate::device::get_default_server_addrs(default::DEFAULT_SERVER_PORT, allow_loopback);
574
575    for addr in input_addrs {
576        let sock_addr = match addr.socket_addr() {
577            Some(sock_addr) => sock_addr,
578            None => {
579                tracing::error!("Invalid address: {:?}", addr);
580                continue;
581            }
582        };
583        let is_unspecified = match sock_addr.ip() {
584            IpAddr::V4(ip) => ip.is_unspecified(),
585            IpAddr::V6(ip) => ip.is_unspecified(),
586        };
587
588        if is_unspecified {
589            for actual in &actual_addrs {
590                for proto in protocols {
591                    match proto {
592                        TransportKind::Quic => {
593                            match actual.ip() {
594                                IpAddr::V4(ipv4) => {
595                                    if sock_addr.ip().is_ipv4() {
596                                        result.insert(StackAddr::empty()
597                                            .with_protocol(Protocol::Ip4(ipv4))
598                                            .with_protocol(Protocol::Udp(sock_addr.port()))
599                                            .with_protocol(Protocol::Quic));
600                                    }
601                                }
602                                IpAddr::V6(ipv6) => {
603                                    if sock_addr.ip().is_ipv6() {
604                                        result.insert(StackAddr::empty()
605                                            .with_protocol(Protocol::Ip6(ipv6))
606                                            .with_protocol(Protocol::Udp(sock_addr.port()))
607                                            .with_protocol(Protocol::Quic));
608                                    }
609                                }
610                            }
611                        }
612                        TransportKind::TlsOverTcp => {
613                            match actual.ip() {
614                                IpAddr::V4(ipv4) => {
615                                    if sock_addr.ip().is_ipv4() {
616                                        result.insert(StackAddr::empty()
617                                            .with_protocol(Protocol::Ip4(ipv4))
618                                            .with_protocol(Protocol::Tcp(sock_addr.port()))
619                                            .with_protocol(Protocol::Tls));
620                                    }
621                                }
622                                IpAddr::V6(ipv6) => {
623                                    if sock_addr.ip().is_ipv6() {
624                                        result.insert(StackAddr::empty()
625                                            .with_protocol(Protocol::Ip6(ipv6))
626                                            .with_protocol(Protocol::Tcp(sock_addr.port()))
627                                            .with_protocol(Protocol::Tls));
628                                    }
629                                }
630                            }
631                        }
632                    }
633                }
634            }
635        } else {
636            result.insert(addr.clone());
637        }
638    }
639    result
640}