1use crate::{
2 config::TransportConfig, device, transport::{
3 connection::{Connection, ConnectionEvent}, quic::transport::QuicTransport, tcp::transport::TcpTransport, Transport,
4 }
5};
6use anyhow::{anyhow, Result};
7use bytes::Bytes;
8use foctet_core::{addr::node::{NodeAddr, RelayAddr}, default, id::NodeId, ip, key::Keypair, transport::{ListenerId, TransportKind}};
9use stackaddr::{segment::protocol::TransportProtocol, Identity, Protocol, StackAddr};
10use tokio_util::sync::CancellationToken;
11use std::{
12 collections::{BTreeMap, HashMap, HashSet}, net::{IpAddr, Ipv4Addr}, sync::Arc
13};
14use tokio::sync::{mpsc, Mutex};
15
16pub struct ListenerHandle {
17 conn_receiver: Arc<Mutex<mpsc::Receiver<Connection>>>,
18}
19
20impl ListenerHandle {
21 pub fn new(conn_receiver: Arc<Mutex<mpsc::Receiver<Connection>>>) -> Self {
22 Self { conn_receiver }
23 }
24
25 pub async fn accept(&self) -> Option<Connection> {
26 self.conn_receiver.lock().await.recv().await
27 }
28
29 pub async fn clone(&self) -> Self {
30 Self {
31 conn_receiver: Arc::clone(&self.conn_receiver),
32 }
33 }
34}
35
/// Placeholder for a relay actor.
///
/// It currently carries no state, and no behavior is attached to it in this module.
pub struct RelayActor {}
39
40pub struct EndpointActor {
41 config: TransportConfig,
42 addrs: HashSet<StackAddr>,
43 conn_sender: mpsc::Sender<Connection>,
44 event_sender: mpsc::Sender<EndpointEvent>,
45 cmd_receiver: mpsc::Receiver<EndpointCommand>,
46 cancel: CancellationToken,
47 listen_enabled: bool,
48}
49
50impl EndpointActor {
51 pub async fn run(mut self) -> Result<()> {
52 if self.listen_enabled {
54 let mut listerner_id = ListenerId::new(1);
55 for addr in &self.addrs {
56 let config = self.config.clone();
57 let mut transport: Transport = match addr.transport() {
58 Some(transport) => match transport {
59 TransportProtocol::Quic(_) | TransportProtocol::Udp(_) => {
60 let t = QuicTransport::new(config)?;
61 Transport::Quic(t)
62 },
63 TransportProtocol::TlsOverTcp(_) | TransportProtocol::Tcp(_) => {
64 let t = TcpTransport::new(config)?;
65 Transport::Tcp(t)
66 },
67 _ => return Err(anyhow::anyhow!("Unsupported transport protocol: {:?}", transport)),
68 },
69 None => {
70 return Err(anyhow::anyhow!("Invalid transport protocol"));
71 }
72 };
73 let event_sender = self.event_sender.clone();
75 let conn_sender = self.conn_sender.clone();
76 let mut listener = transport.listen_on(listerner_id.fetch_add(1), addr.clone()).await?;
77 tokio::spawn(async move {
78 while let Some(conn_event) = listener.accept().await {
79 match conn_event {
80 ConnectionEvent::Accepted(conn) => {
81 match conn_sender.send(conn).await {
82 Ok(_) => {}
83 Err(e) => {
84 event_sender
85 .send(EndpointEvent::Error(anyhow!("Error sending connection event: {:?}", e)))
86 .await
87 .unwrap_or_else(|e| {
88 tracing::error!("Error sending connection event: {:?}", e);
89 });
90 }
91 }
92 }
93 _ => {},
94 }
95 }
96 });
97 }
98 }
99
100 loop {
102 tokio::select! {
103 _ = self.cancel.cancelled() => {
104 tracing::info!("EndpointActor loop cancelled, closing loop");
105 break;
106 }
107 Some(cmd) = self.cmd_receiver.recv() => {
108 match cmd {
109 EndpointCommand::Connect(_addr) => {
110 }
114 EndpointCommand::Listen(_addr) => {
115 }
118 EndpointCommand::Shutdown => {
119 break;
121 }
122 }
123 }
124 }
125 }
126 Ok(())
127 }
128}
129
130pub struct Endpoint {
133 config: TransportConfig,
134 addrs: HashSet<StackAddr>,
135 relay_addrs: Option<RelayAddr>,
136 priority_map: BTreeMap<u8, TransportKind>,
137 transports: HashMap<TransportKind, Transport>,
138 listener: ListenerHandle,
139 event_receiver: mpsc::Receiver<EndpointEvent>,
140 cmd_sender: mpsc::Sender<EndpointCommand>,
141 cancel: CancellationToken,
142 allow_loopback: bool,
143}
144
145impl Endpoint {
146 pub fn builder() -> EndpointBuilder {
148 EndpointBuilder::new()
149 }
150
151 pub fn default_builder() -> EndpointBuilder {
153 EndpointBuilder::default()
154 }
155
156 pub fn node_id(&self) -> NodeId {
159 self.config.keypair().public().into()
160 }
161
162 pub fn node_addr(&self) -> NodeAddr {
164 NodeAddr {
165 node_id: self.node_id(),
166 addresses: self.addrs.iter().cloned().collect(),
167 relay_addr: self.relay_addrs.clone(),
168 }
169 }
170
171 pub fn global_node_addr(&self) -> NodeAddr {
173 let global_addrs: Vec<StackAddr> = self
174 .addrs
175 .iter()
176 .cloned()
177 .filter(|addr| {
178 if let Some(ip) = addr.ip() {
179 ip::is_global_ip(&ip)
180 } else {
181 false
182 }
183 })
184 .collect();
185
186 NodeAddr {
187 node_id: self.node_id(),
188 addresses: global_addrs.into_iter().collect(),
189 relay_addr: self.relay_addrs.clone(),
190 }
191 }
192
193 pub async fn connect(&mut self, addr: StackAddr) -> Result<Connection> {
195 match addr.transport() {
196 Some(transport) => {
197 match transport {
198 TransportProtocol::Quic(_) | TransportProtocol::Udp(_) => {
199 let t = self.transports.get_mut(&TransportKind::Quic).ok_or_else(|| anyhow!("QUIC transport not found"))?;
200 t.connect(addr).await
201 },
202 TransportProtocol::TlsOverTcp(_) | TransportProtocol::Tcp(_) => {
203 let t = self.transports.get_mut(&TransportKind::TlsOverTcp).ok_or_else(|| anyhow!("TCP transport not found"))?;
204 t.connect(addr).await
205 },
206 _ => Err(anyhow!("Unsupported transport protocol: {:?}", transport)),
207 }
208 }
209 None => Err(anyhow!("Missing transport protocol in address")),
210 }
211 }
212
213 pub async fn connect_node(&mut self, addr: NodeAddr) -> Result<Connection> {
215 let iface = &netdev::get_default_interface()
216 .map_err(|e| anyhow!("Failed to get default interface: {:?}", e))?;
217 for proto in self.priority_map.values() {
218 let t = self.transports.get_mut(proto).ok_or_else(|| anyhow!("Transport not found"))?;
219 let addrs = addr.get_direct_addrs(proto, self.allow_loopback);
220 let sorted_addrs = device::sort_addrs_by_reachability(&addrs, iface);
221 for addr in sorted_addrs {
222 match t.connect(addr.clone()).await {
223 Ok(conn) => {
224 return Ok(conn);
225 }
226 Err(e) => {
227 tracing::error!("Error connecting to {}: {:?}", addr, e);
228 }
229 }
230 }
231 }
232 Err(anyhow!("No direct address found for node"))
233 }
234
235 pub async fn accept(&mut self) -> Option<Connection> {
236 self.listener.accept().await
237 }
238
239 pub async fn get_listener(&self) -> ListenerHandle {
240 self.listener.clone().await
241 }
242
243 pub async fn shutdown(&self) -> Result<()> {
244 self.cmd_sender.send(EndpointCommand::Shutdown).await?;
245 self.cancel.cancel();
246 Ok(())
247 }
248
249 pub async fn send_command(&self, cmd: EndpointCommand) -> Result<()> {
250 self.cmd_sender.send(cmd).await?;
251 Ok(())
252 }
253
254 pub async fn next_event(&mut self) -> Option<EndpointEvent> {
255 self.event_receiver.recv().await
256 }
257}
258
259#[derive(Debug)]
261pub enum EndpointEvent {
262 ConnectionEstablished {
263 node_id: NodeId,
264 addr: StackAddr,
265 },
266 ConnectionClosed {
267 node_id: NodeId,
268 },
269 NewListenAddr {
270 listener_id: ListenerId,
271 addr: StackAddr,
272 },
273 PeerDiscovered {
274 node_id: NodeId,
275 addr: StackAddr,
276 },
277 Error(anyhow::Error),
278}
279
280pub enum EndpointCommand {
281 Connect(StackAddr),
282 Listen(StackAddr),
283 Shutdown,
284}
285
286pub struct EndpointBuilder {
289 config: TransportConfig,
290 protocols: Vec<TransportKind>,
291 addrs: HashSet<StackAddr>,
292 listen_enabled: bool,
293 allow_loopback: bool,
294}
295
296impl Default for EndpointBuilder {
297 fn default() -> Self {
298 let keypair = Keypair::generate();
299 let config = TransportConfig::new(keypair.clone()).unwrap();
300
301 let mut protocols = Vec::new();
302 protocols.push(TransportKind::Quic);
303
304 let mut addrs = HashSet::new();
306 let addr = StackAddr::empty()
307 .with_protocol(Protocol::Ip4(Ipv4Addr::UNSPECIFIED))
308 .with_protocol(Protocol::Udp(default::DEFAULT_SERVER_PORT))
309 .with_protocol(Protocol::Quic)
310 .with_identity(Identity::NodeId(Bytes::copy_from_slice(&keypair.public().to_bytes())));
311
312 addrs.insert(addr);
313
314 Self {
315 config,
316 protocols,
317 addrs: addrs,
318 listen_enabled: true,
319 allow_loopback: false,
320 }
321 }
322}
323
324impl EndpointBuilder {
325 pub fn new() -> Self {
327 let keypair = Keypair::generate();
328 let config = TransportConfig::new(keypair.clone()).unwrap();
329 Self {
330 config,
331 protocols: Vec::new(),
332 addrs: HashSet::new(),
333 listen_enabled: true,
334 allow_loopback: false
335 }
336 }
337 pub fn with_keypair(mut self, keypair: Keypair) -> Self {
338 self.config.set_keypair(keypair).unwrap();
339 self
340 }
341
342 fn push_protocol(&mut self, proto: TransportKind) {
343 if !self.protocols.contains(&proto) {
344 self.protocols.push(proto);
345 }
346 }
347
348 pub fn with_quic(mut self) -> Self {
351 self.push_protocol(TransportKind::Quic);
352 self
353 }
354 pub fn with_tcp(mut self) -> Self {
357 self.push_protocol(TransportKind::TlsOverTcp);
358 self
359 }
360
361 pub fn with_addr(mut self, addr: StackAddr) -> Result<Self> {
364 let transport = addr.transport().ok_or_else(|| anyhow!("Missing transport protocol in address"))?;
365 self.push_protocol(TransportKind::from_protocol(transport)?);
366 self.addrs.insert(addr);
367 Ok(self)
368 }
369
370 pub fn without_listen(mut self) -> Self {
373 self.listen_enabled = false;
374 self
375 }
376
377 pub fn allow_loopback(mut self, allow: bool) -> Self {
379 self.allow_loopback = allow;
380 self
381 }
382
383 pub fn with_read_buffer_size(mut self, size: usize) -> Self {
385 self.config.read_buffer_size = size;
386 self
387 }
388
389 pub fn with_write_buffer_size(mut self, size: usize) -> Self {
391 self.config.write_buffer_size = size;
392 self
393 }
394
395 pub fn with_max_read_buffer_size(mut self) -> Self {
397 self.config.read_buffer_size = default::MAX_READ_BUFFER_SIZE;
398 self
399 }
400
401 pub fn with_max_write_buffer_size(mut self) -> Self {
403 self.config.write_buffer_size = default::MAX_WRITE_BUFFER_SIZE;
404 self
405 }
406
407 pub fn build(self) -> Result<Endpoint> {
410 let mut priority_map = BTreeMap::new();
411 let mut transports = HashMap::new();
412 for (i, proto) in self.protocols.iter().enumerate() {
413 let priority = (i + 1) as u8;
414 match proto {
415 TransportKind::Quic => {
416 let t = QuicTransport::new(self.config.clone())?;
417 transports.insert(TransportKind::Quic, Transport::Quic(t));
418 priority_map.insert(priority, TransportKind::Quic);
419 },
420 TransportKind::TlsOverTcp => {
421 let t = TcpTransport::new(self.config.clone())?;
422 transports.insert(TransportKind::TlsOverTcp, Transport::Tcp(t));
423 priority_map.insert(priority, TransportKind::TlsOverTcp);
424 },
425 }
426 }
427
428 let addrs = if self.addrs.is_empty() {
429 get_unspecified_stack_addrs(&self.protocols)
430 } else {
431 self.addrs.clone()
432 };
433
434 let (conn_sender, conn_receiver) = mpsc::channel(100);
436 let (event_sender, event_receiver) = mpsc::channel(100);
438 let (cmd_sender, cmd_receiver) = mpsc::channel(100);
440 let cancel = CancellationToken::new();
442 let actor = EndpointActor {
444 config: self.config.clone(),
445 addrs: addrs,
446 conn_sender,
447 event_sender,
448 cmd_receiver,
449 cancel: cancel.clone(),
450 listen_enabled: self.listen_enabled,
451 };
452 tokio::spawn(async move {
454 if let Err(e) = actor.run().await {
455 tracing::error!("Endpoint actor error: {:?}", e);
456 }
457 });
458
459 let direct_addrs = if self.addrs.is_empty() {
460 get_default_stack_addrs(&self.protocols, self.allow_loopback)
461 } else {
462 replace_with_actual_addrs(&self.addrs, &self.protocols, self.allow_loopback)
463 };
464
465 Ok(Endpoint {
466 config: self.config,
467 addrs: direct_addrs,
468 relay_addrs: None,
469 priority_map,
470 transports,
471 listener: ListenerHandle::new(Arc::new(Mutex::new(conn_receiver))),
472 event_receiver,
473 cmd_sender,
474 cancel,
475 allow_loopback: self.allow_loopback,
476 })
477 }
478}
479
480fn get_unspecified_stack_addrs(protocols: &[TransportKind]) -> HashSet<StackAddr> {
481 let unspecified_addr = device::get_unspecified_server_addr();
482 let mut addrs = HashSet::new();
483 for proto in protocols.iter() {
484 match proto {
485 TransportKind::Quic => {
486 match unspecified_addr.ip() {
487 IpAddr::V4(ipv4) => {
488 addrs.insert(StackAddr::empty()
489 .with_protocol(Protocol::Ip4(ipv4))
490 .with_protocol(Protocol::Udp(unspecified_addr.port()))
491 .with_protocol(Protocol::Quic));
492 }
493 IpAddr::V6(ipv6) => {
494 addrs.insert(StackAddr::empty()
495 .with_protocol(Protocol::Ip6(ipv6))
496 .with_protocol(Protocol::Udp(unspecified_addr.port()))
497 .with_protocol(Protocol::Quic));
498 }
499 }
500 }
501 TransportKind::TlsOverTcp => {
502 match unspecified_addr.ip() {
503 IpAddr::V4(ipv4) => {
504 addrs.insert(StackAddr::empty()
505 .with_protocol(Protocol::Ip4(ipv4))
506 .with_protocol(Protocol::Tcp(unspecified_addr.port()))
507 .with_protocol(Protocol::Tls));
508 }
509 IpAddr::V6(ipv6) => {
510 addrs.insert(StackAddr::empty()
511 .with_protocol(Protocol::Ip6(ipv6))
512 .with_protocol(Protocol::Tcp(unspecified_addr.port()))
513 .with_protocol(Protocol::Tls));
514 }
515 }
516 }
517 }
518 }
519 addrs
520}
521
522fn get_default_stack_addrs(protocols: &[TransportKind], allow_loopback: bool) -> HashSet<StackAddr> {
523 let socket_addrs = crate::device::get_default_server_addrs(default::DEFAULT_SERVER_PORT, allow_loopback);
524 let mut addrs = HashSet::new();
525 for proto in protocols.iter() {
526 for addr in socket_addrs.iter() {
527 match proto {
528 TransportKind::Quic => {
529 match addr.ip() {
530 IpAddr::V4(ipv4) => {
531 addrs.insert(StackAddr::empty()
532 .with_protocol(Protocol::Ip4(ipv4))
533 .with_protocol(Protocol::Udp(addr.port()))
534 .with_protocol(Protocol::Quic));
535 }
536 IpAddr::V6(ipv6) => {
537 addrs.insert(StackAddr::empty()
538 .with_protocol(Protocol::Ip6(ipv6))
539 .with_protocol(Protocol::Udp(addr.port()))
540 .with_protocol(Protocol::Quic));
541 }
542 }
543 }
544 TransportKind::TlsOverTcp => {
545 match addr.ip() {
546 IpAddr::V4(ipv4) => {
547 addrs.insert(StackAddr::empty()
548 .with_protocol(Protocol::Ip4(ipv4))
549 .with_protocol(Protocol::Tcp(addr.port()))
550 .with_protocol(Protocol::Tls));
551 }
552 IpAddr::V6(ipv6) => {
553 addrs.insert(StackAddr::empty()
554 .with_protocol(Protocol::Ip6(ipv6))
555 .with_protocol(Protocol::Tcp(addr.port()))
556 .with_protocol(Protocol::Tls));
557 }
558 }
559 }
560 }
561 }
562 }
563 addrs
564}
565
566fn replace_with_actual_addrs(
567 input_addrs: &HashSet<StackAddr>,
568 protocols: &[TransportKind],
569 allow_loopback: bool
570) -> HashSet<StackAddr> {
571 let mut result = HashSet::new();
572
573 let actual_addrs = crate::device::get_default_server_addrs(default::DEFAULT_SERVER_PORT, allow_loopback);
574
575 for addr in input_addrs {
576 let sock_addr = match addr.socket_addr() {
577 Some(sock_addr) => sock_addr,
578 None => {
579 tracing::error!("Invalid address: {:?}", addr);
580 continue;
581 }
582 };
583 let is_unspecified = match sock_addr.ip() {
584 IpAddr::V4(ip) => ip.is_unspecified(),
585 IpAddr::V6(ip) => ip.is_unspecified(),
586 };
587
588 if is_unspecified {
589 for actual in &actual_addrs {
590 for proto in protocols {
591 match proto {
592 TransportKind::Quic => {
593 match actual.ip() {
594 IpAddr::V4(ipv4) => {
595 if sock_addr.ip().is_ipv4() {
596 result.insert(StackAddr::empty()
597 .with_protocol(Protocol::Ip4(ipv4))
598 .with_protocol(Protocol::Udp(sock_addr.port()))
599 .with_protocol(Protocol::Quic));
600 }
601 }
602 IpAddr::V6(ipv6) => {
603 if sock_addr.ip().is_ipv6() {
604 result.insert(StackAddr::empty()
605 .with_protocol(Protocol::Ip6(ipv6))
606 .with_protocol(Protocol::Udp(sock_addr.port()))
607 .with_protocol(Protocol::Quic));
608 }
609 }
610 }
611 }
612 TransportKind::TlsOverTcp => {
613 match actual.ip() {
614 IpAddr::V4(ipv4) => {
615 if sock_addr.ip().is_ipv4() {
616 result.insert(StackAddr::empty()
617 .with_protocol(Protocol::Ip4(ipv4))
618 .with_protocol(Protocol::Tcp(sock_addr.port()))
619 .with_protocol(Protocol::Tls));
620 }
621 }
622 IpAddr::V6(ipv6) => {
623 if sock_addr.ip().is_ipv6() {
624 result.insert(StackAddr::empty()
625 .with_protocol(Protocol::Ip6(ipv6))
626 .with_protocol(Protocol::Tcp(sock_addr.port()))
627 .with_protocol(Protocol::Tls));
628 }
629 }
630 }
631 }
632 }
633 }
634 }
635 } else {
636 result.insert(addr.clone());
637 }
638 }
639 result
640}