1use std::collections::HashSet;
2use std::io::{self, Read, Write};
3use std::net::{SocketAddr, TcpListener, TcpStream};
4use std::sync::mpsc;
5use std::time::Duration;
6
7use crate::api::{handle_request, NodeHandle};
8use crate::auth::check_ws_auth;
9use crate::http::{parse_request, write_response};
10use crate::state::{ControlPlaneConfigHandle, SharedState, WsBroadcast, WsEvent};
11use crate::ws;
12
/// A single accepted client connection, either plain TCP or (with the `tls`
/// feature) TCP wrapped in a rustls server session.
///
/// Reads and writes are forwarded to whichever transport is active.
pub(crate) enum ConnStream {
    Plain(TcpStream),
    #[cfg(feature = "tls")]
    Tls(rustls::StreamOwned<rustls::ServerConnection, TcpStream>),
}

impl ConnStream {
    /// Borrows the raw TCP socket underneath the active transport.
    fn tcp_socket(&self) -> &TcpStream {
        match self {
            Self::Plain(tcp) => tcp,
            #[cfg(feature = "tls")]
            Self::Tls(session) => &session.sock,
        }
    }

    /// Returns the active transport as a readable stream.
    fn reader_mut(&mut self) -> &mut dyn Read {
        match self {
            Self::Plain(tcp) => tcp,
            #[cfg(feature = "tls")]
            Self::Tls(session) => session,
        }
    }

    /// Returns the active transport as a writable stream.
    fn writer_mut(&mut self) -> &mut dyn Write {
        match self {
            Self::Plain(tcp) => tcp,
            #[cfg(feature = "tls")]
            Self::Tls(session) => session,
        }
    }

    /// Sets the read timeout on the underlying TCP socket (for TLS, on the
    /// socket beneath the rustls session).
    fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
        self.tcp_socket().set_read_timeout(dur)
    }
}

impl Read for ConnStream {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.reader_mut().read(buf)
    }
}

impl Write for ConnStream {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.writer_mut().write(buf)
    }

    fn flush(&mut self) -> io::Result<()> {
        self.writer_mut().flush()
    }
}
57
58pub struct ServerContext {
60 pub node: NodeHandle,
61 pub state: SharedState,
62 pub ws_broadcast: WsBroadcast,
63 pub config: ControlPlaneConfigHandle,
64 #[cfg(feature = "tls")]
65 pub tls_config: Option<std::sync::Arc<rustls::ServerConfig>>,
66}
67
68pub fn run_server(addr: SocketAddr, ctx: std::sync::Arc<ServerContext>) -> io::Result<()> {
70 let listener = TcpListener::bind(addr)?;
71
72 #[cfg(feature = "tls")]
73 let scheme = if ctx.tls_config.is_some() {
74 "https"
75 } else {
76 "http"
77 };
78 #[cfg(not(feature = "tls"))]
79 let scheme = "http";
80
81 log::info!("Listening on {}://{}", scheme, addr);
82
83 for stream in listener.incoming() {
84 match stream {
85 Ok(tcp_stream) => {
86 let ctx = ctx.clone();
87 std::thread::Builder::new()
88 .name("rns-ctl-conn".into())
89 .spawn(move || {
90 let conn = match wrap_stream(tcp_stream, &ctx) {
91 Ok(c) => c,
92 Err(e) => {
93 log::debug!("TLS handshake error: {}", e);
94 return;
95 }
96 };
97 if let Err(e) = handle_connection(conn, &ctx) {
98 log::debug!("Connection error: {}", e);
99 }
100 })
101 .ok();
102 }
103 Err(e) => {
104 log::warn!("Accept error: {}", e);
105 }
106 }
107 }
108
109 Ok(())
110}
111
112fn wrap_stream(tcp_stream: TcpStream, ctx: &ServerContext) -> io::Result<ConnStream> {
114 #[cfg(feature = "tls")]
115 {
116 if let Some(ref tls_config) = ctx.tls_config {
117 let server_conn = rustls::ServerConnection::new(tls_config.clone())
118 .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("TLS error: {}", e)))?;
119 return Ok(ConnStream::Tls(rustls::StreamOwned::new(
120 server_conn,
121 tcp_stream,
122 )));
123 }
124 }
125 let _ = ctx; Ok(ConnStream::Plain(tcp_stream))
127}
128
129fn handle_connection(mut stream: ConnStream, ctx: &ServerContext) -> io::Result<()> {
130 stream.set_read_timeout(Some(Duration::from_secs(30)))?;
132
133 let req = parse_request(&mut stream)?;
134
135 if ws::is_upgrade(&req) {
136 handle_ws_connection(stream, &req, ctx)
137 } else {
138 let response = handle_request(&req, &ctx.node, &ctx.state, &ctx.config);
139 write_response(&mut stream, &response)
140 }
141}
142
143fn handle_ws_connection(
144 mut stream: ConnStream,
145 req: &crate::http::HttpRequest,
146 ctx: &ServerContext,
147) -> io::Result<()> {
148 if let Err(resp) = check_ws_auth(&req.query, &ctx.config) {
150 return write_response(&mut stream, &resp);
151 }
152
153 ws::do_handshake(&mut stream, req)?;
155
156 stream.set_read_timeout(Some(Duration::from_millis(50)))?;
158
159 let (event_tx, event_rx) = mpsc::channel::<WsEvent>();
161
162 {
164 let mut senders = ctx.ws_broadcast.lock().unwrap();
165 senders.push(event_tx);
166 }
167
168 let mut topics = HashSet::<String>::new();
170 let mut ws_buf = ws::WsBuf::new();
171
172 loop {
173 match ws_buf.try_read_frame(&mut stream) {
175 Ok(Some(frame)) => match frame.opcode {
176 ws::OPCODE_TEXT => {
177 if let Ok(text) = std::str::from_utf8(&frame.payload) {
178 handle_ws_text(text, &mut topics, &mut stream);
179 }
180 }
181 ws::OPCODE_PING => {
182 let _ = ws::write_pong_frame(&mut stream, &frame.payload);
183 }
184 ws::OPCODE_CLOSE => {
185 let _ = ws::write_close_frame(&mut stream);
186 break;
187 }
188 _ => {}
189 },
190 Ok(None) => {
191 }
193 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
194 Err(e) => {
195 log::debug!("WS read error: {}", e);
196 break;
197 }
198 }
199
200 while let Ok(event) = event_rx.try_recv() {
202 if topics.contains(event.topic) {
203 let json = event.to_json();
204 if ws::write_text_frame(&mut stream, &json).is_err() {
205 return Ok(());
206 }
207 }
208 }
209 }
210
211 Ok(())
212}
213
214fn handle_ws_text(text: &str, topics: &mut HashSet<String>, stream: &mut ConnStream) {
215 if let Ok(msg) = serde_json::from_str::<serde_json::Value>(text) {
216 match msg["type"].as_str() {
217 Some("subscribe") => {
218 if let Some(arr) = msg["topics"].as_array() {
219 for t in arr {
220 if let Some(s) = t.as_str() {
221 topics.insert(s.to_string());
222 }
223 }
224 }
225 }
226 Some("unsubscribe") => {
227 if let Some(arr) = msg["topics"].as_array() {
228 for t in arr {
229 if let Some(s) = t.as_str() {
230 topics.remove(s);
231 }
232 }
233 }
234 }
235 Some("ping") => {
236 let _ =
237 ws::write_text_frame(stream, &serde_json::json!({"type": "pong"}).to_string());
238 }
239 _ => {}
240 }
241 }
242}