1use std::collections::HashSet;
2use std::io::{self, Read, Write};
3use std::net::{SocketAddr, TcpListener, TcpStream};
4use std::sync::mpsc;
5use std::time::Duration;
6
7use crate::api::{handle_request, NodeHandle};
8use crate::auth::check_ws_auth;
9use crate::http::{parse_request, write_response};
10use crate::state::{
11 lock_ws_broadcast, ControlPlaneConfigHandle, SharedState, WsBroadcast, WsEvent,
12};
13use crate::ws;
14
/// One accepted client connection: either cleartext TCP or, when the `tls`
/// feature is enabled, a rustls server session layered over the TCP socket.
///
/// Implements [`Read`] and [`Write`] by delegating to whichever transport it
/// wraps, so request/response code can stay transport-agnostic.
pub(crate) enum ConnStream {
    /// Unencrypted TCP connection.
    Plain(TcpStream),
    /// TLS session that owns both the rustls connection state and the socket.
    #[cfg(feature = "tls")]
    Tls(rustls::StreamOwned<rustls::ServerConnection, TcpStream>),
}

impl ConnStream {
    /// Borrows the underlying TCP socket, whatever transport sits on top of it.
    fn socket(&self) -> &TcpStream {
        match self {
            Self::Plain(tcp) => tcp,
            #[cfg(feature = "tls")]
            Self::Tls(tls) => &tls.sock,
        }
    }

    /// Sets (or clears, with `None`) the read timeout on the underlying socket.
    fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
        self.socket().set_read_timeout(dur)
    }
}

impl Read for ConnStream {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        match self {
            Self::Plain(tcp) => tcp.read(buf),
            #[cfg(feature = "tls")]
            Self::Tls(tls) => tls.read(buf),
        }
    }
}

impl Write for ConnStream {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        match self {
            Self::Plain(tcp) => tcp.write(buf),
            #[cfg(feature = "tls")]
            Self::Tls(tls) => tls.write(buf),
        }
    }

    fn flush(&mut self) -> io::Result<()> {
        match self {
            Self::Plain(tcp) => tcp.flush(),
            #[cfg(feature = "tls")]
            Self::Tls(tls) => tls.flush(),
        }
    }
}
59
/// State shared by every connection handled by the server.
///
/// The server takes it as `Arc<ServerContext>` and clones the `Arc` for each
/// accepted connection's thread; handlers receive it as `&ServerContext`.
pub struct ServerContext {
    /// Node handle passed to `handle_request` for plain HTTP requests.
    pub node: NodeHandle,
    /// Shared state passed to `handle_request` for plain HTTP requests.
    pub state: SharedState,
    /// Registry of WebSocket event senders; every upgraded connection pushes
    /// its own `mpsc` sender here (via `lock_ws_broadcast`) to receive events.
    pub ws_broadcast: WsBroadcast,
    /// Control-plane configuration, used both by `handle_request` and by
    /// `check_ws_auth` when a WebSocket upgrade is attempted.
    pub config: ControlPlaneConfigHandle,
    /// TLS server configuration. When `Some`, accepted sockets are wrapped in
    /// a rustls session and the listener logs an `https` scheme.
    #[cfg(feature = "tls")]
    pub tls_config: Option<std::sync::Arc<rustls::ServerConfig>>,
}
69
70pub fn run_server(addr: SocketAddr, ctx: std::sync::Arc<ServerContext>) -> io::Result<()> {
72 let listener = TcpListener::bind(addr)?;
73 run_server_with_listener(listener, ctx)
74}
75
76pub fn run_server_with_listener(
78 listener: TcpListener,
79 ctx: std::sync::Arc<ServerContext>,
80) -> io::Result<()> {
81 let addr = listener.local_addr()?;
82
83 #[cfg(feature = "tls")]
84 let scheme = if ctx.tls_config.is_some() {
85 "https"
86 } else {
87 "http"
88 };
89 #[cfg(not(feature = "tls"))]
90 let scheme = "http";
91
92 log::info!("Listening on {}://{}", scheme, addr);
93
94 for stream in listener.incoming() {
95 match stream {
96 Ok(tcp_stream) => {
97 let ctx = ctx.clone();
98 std::thread::Builder::new()
99 .name("rns-ctl-conn".into())
100 .spawn(move || {
101 let conn = match wrap_stream(tcp_stream, &ctx) {
102 Ok(c) => c,
103 Err(e) => {
104 log::debug!("TLS handshake error: {}", e);
105 return;
106 }
107 };
108 if let Err(e) = handle_connection(conn, &ctx) {
109 log::debug!("Connection error: {}", e);
110 }
111 })
112 .ok();
113 }
114 Err(e) => {
115 log::warn!("Accept error: {}", e);
116 }
117 }
118 }
119
120 Ok(())
121}
122
123fn wrap_stream(tcp_stream: TcpStream, ctx: &ServerContext) -> io::Result<ConnStream> {
125 #[cfg(feature = "tls")]
126 {
127 if let Some(ref tls_config) = ctx.tls_config {
128 let server_conn = rustls::ServerConnection::new(tls_config.clone())
129 .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("TLS error: {}", e)))?;
130 return Ok(ConnStream::Tls(rustls::StreamOwned::new(
131 server_conn,
132 tcp_stream,
133 )));
134 }
135 }
136 let _ = ctx; Ok(ConnStream::Plain(tcp_stream))
138}
139
140fn handle_connection(mut stream: ConnStream, ctx: &ServerContext) -> io::Result<()> {
141 stream.set_read_timeout(Some(Duration::from_secs(30)))?;
143
144 let req = parse_request(&mut stream)?;
145
146 if ws::is_upgrade(&req) {
147 handle_ws_connection(stream, &req, ctx)
148 } else {
149 let response = handle_request(&req, &ctx.node, &ctx.state, &ctx.config);
150 write_response(&mut stream, &response)
151 }
152}
153
154fn handle_ws_connection(
155 mut stream: ConnStream,
156 req: &crate::http::HttpRequest,
157 ctx: &ServerContext,
158) -> io::Result<()> {
159 if let Err(resp) = check_ws_auth(&req.query, &ctx.config) {
161 return write_response(&mut stream, &resp);
162 }
163
164 ws::do_handshake(&mut stream, req)?;
166
167 stream.set_read_timeout(Some(Duration::from_millis(50)))?;
169
170 let (event_tx, event_rx) = mpsc::channel::<WsEvent>();
172
173 {
175 let mut senders = lock_ws_broadcast(&ctx.ws_broadcast);
176 senders.push(event_tx);
177 }
178
179 let mut topics = HashSet::<String>::new();
181 let mut ws_buf = ws::WsBuf::new();
182
183 loop {
184 match ws_buf.try_read_frame(&mut stream) {
186 Ok(Some(frame)) => match frame.opcode {
187 ws::OPCODE_TEXT => {
188 if let Ok(text) = std::str::from_utf8(&frame.payload) {
189 handle_ws_text(text, &mut topics, &mut stream);
190 }
191 }
192 ws::OPCODE_PING => {
193 let _ = ws::write_pong_frame(&mut stream, &frame.payload);
194 }
195 ws::OPCODE_CLOSE => {
196 let _ = ws::write_close_frame(&mut stream);
197 break;
198 }
199 _ => {}
200 },
201 Ok(None) => {
202 }
204 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
205 Err(e) => {
206 log::debug!("WS read error: {}", e);
207 break;
208 }
209 }
210
211 while let Ok(event) = event_rx.try_recv() {
213 if topics.contains(event.topic) {
214 let json = event.to_json();
215 if ws::write_text_frame(&mut stream, &json).is_err() {
216 return Ok(());
217 }
218 }
219 }
220 }
221
222 Ok(())
223}
224
225fn handle_ws_text(text: &str, topics: &mut HashSet<String>, stream: &mut ConnStream) {
226 if let Ok(msg) = serde_json::from_str::<serde_json::Value>(text) {
227 match msg["type"].as_str() {
228 Some("subscribe") => {
229 if let Some(arr) = msg["topics"].as_array() {
230 for t in arr {
231 if let Some(s) = t.as_str() {
232 topics.insert(s.to_string());
233 }
234 }
235 }
236 }
237 Some("unsubscribe") => {
238 if let Some(arr) = msg["topics"].as_array() {
239 for t in arr {
240 if let Some(s) = t.as_str() {
241 topics.remove(s);
242 }
243 }
244 }
245 }
246 Some("ping") => {
247 let _ =
248 ws::write_text_frame(stream, &serde_json::json!({"type": "pong"}).to_string());
249 }
250 _ => {}
251 }
252 }
253}