1use std::net::SocketAddr;
2use std::sync::atomic::{AtomicBool, Ordering};
3use std::sync::mpsc;
4use std::sync::Arc;
5use std::thread::JoinHandle;
6
7use optic_core::{NetworkConfig, NetworkEvents, OpticError, OpticErrorKind, OpticResult, PeerId, NetworkMode};
8use tokio::runtime;
9use tokio::sync::mpsc as tokio_mpsc;
10
11use crate::channels::{inbound_data_channel, lifecycle_channel, outbound_channel, LifecycleEvent, TransportCommand};
12use crate::transport::run_transport;
13
14pub struct NetworkHandle {
20 thread: Option<JoinHandle<()>>,
21 inbound_data_rx: tokio_mpsc::UnboundedReceiver<(PeerId, Vec<u8>)>,
22 lifecycle_rx: tokio_mpsc::UnboundedReceiver<LifecycleEvent>,
23 outbound_tx: tokio_mpsc::UnboundedSender<TransportCommand>,
24 local_addr: Option<SocketAddr>,
25 shutdown_flag: Arc<AtomicBool>,
26}
27
28impl NetworkHandle {
29 pub fn new(config: NetworkConfig) -> OpticResult<Self> {
35 let (inbound_data_tx, inbound_data_rx) = inbound_data_channel();
36 let (lifecycle_tx, lifecycle_rx) = lifecycle_channel();
37 let (outbound_tx, outbound_rx) = outbound_channel();
38
39 let rt = runtime::Builder::new_current_thread()
40 .enable_io()
41 .enable_time()
42 .build()
43 .map_err(|e| OpticError::new(OpticErrorKind::Custom, &format!("failed to build tokio runtime: {e}")))?;
44
45 let (bound_addr_tx, bound_addr_rx) = mpsc::channel();
47
48 let shutdown_flag = Arc::new(AtomicBool::new(false));
49 let shutdown_flag_clone = shutdown_flag.clone();
50
51 let is_host = matches!(&config.mode, NetworkMode::Host { .. });
53 let config_port = match &config.mode {
54 NetworkMode::Host { port } => Some(*port),
55 NetworkMode::Client { .. } => None,
56 };
57
58 let thread = std::thread::Builder::new()
59 .name("optic-network".into())
60 .spawn(move || {
61 rt.block_on(async move {
62 run_transport(config, inbound_data_tx, lifecycle_tx, outbound_rx, bound_addr_tx).await;
63 shutdown_flag_clone.store(true, Ordering::SeqCst);
64 });
65 })
66 .map_err(|e| OpticError::new(OpticErrorKind::Custom, &format!("failed to spawn network thread: {e}")))?;
67
68 let local_addr = if is_host {
70 let addr = bound_addr_rx
71 .recv()
72 .map_err(|_| OpticError::new(OpticErrorKind::Custom, "network thread exited before binding"))?
73 .unwrap_or_else(|| {
74 let port = config_port.unwrap_or(0);
75 ([0, 0, 0, 0], port).into()
76 });
77 Some(addr)
78 } else {
79 None
80 };
81
82 Ok(Self {
83 thread: Some(thread),
84 inbound_data_rx,
85 lifecycle_rx,
86 outbound_tx,
87 local_addr,
88 shutdown_flag,
89 })
90 }
91
92 pub fn poll(&mut self, out: &mut NetworkEvents) {
95 while let Ok(event) = self.lifecycle_rx.try_recv() {
97 match event {
98 LifecycleEvent::Connected(pid) => out.peers_connected.push(pid),
99 LifecycleEvent::Disconnected(pid) => out.peers_disconnected.push(pid),
100 }
101 }
102 while let Ok((pid, data)) = self.inbound_data_rx.try_recv() {
104 out.packets.push((pid, data));
105 }
106 }
107
108 pub fn send(&self, peer: PeerId, bytes: &[u8]) -> OpticResult<()> {
113 self.outbound_tx
114 .send(TransportCommand::SendTo(peer, bytes.to_vec()))
115 .map_err(|_| OpticError::new(OpticErrorKind::Custom, "outbound channel closed"))?;
116 Ok(())
117 }
118
119 pub fn send_all(&self, bytes: &[u8]) -> OpticResult<()> {
121 self.outbound_tx
122 .send(TransportCommand::SendAll(bytes.to_vec()))
123 .map_err(|_| OpticError::new(OpticErrorKind::Custom, "outbound channel closed"))?;
124 Ok(())
125 }
126
127 pub fn send_all_except(&self, exclude: PeerId, bytes: &[u8]) -> OpticResult<()> {
129 self.outbound_tx
130 .send(TransportCommand::SendAllExcept(exclude, bytes.to_vec()))
131 .map_err(|_| OpticError::new(OpticErrorKind::Custom, "outbound channel closed"))?;
132 Ok(())
133 }
134
135 pub fn disconnect(&self, peer: PeerId) {
139 let _ = self.outbound_tx.send(TransportCommand::DisconnectPeer(peer));
140 }
141
142 pub fn peers(&self) -> Vec<PeerId> {
148 Vec::new()
149 }
150
151 pub fn local_addr(&self) -> Option<SocketAddr> {
153 self.local_addr
154 }
155
156 pub fn shutdown(&mut self) {
158 let _ = self.outbound_tx.send(TransportCommand::Shutdown);
159 if let Some(thread) = self.thread.take() {
160 let _ = thread.join();
161 }
162 }
163
164 pub fn is_shutdown(&self) -> bool {
166 self.shutdown_flag.load(Ordering::SeqCst)
167 }
168}
169
170impl Drop for NetworkHandle {
171 fn drop(&mut self) {
172 self.shutdown();
173 }
174}