1use std::collections::HashSet;
2use std::io::{self, Read, Write};
3use std::net::{SocketAddr, TcpListener, TcpStream};
4use std::sync::mpsc;
5use std::time::Duration;
6
7use crate::api::{handle_request, NodeHandle};
8use crate::auth::check_ws_auth;
9use crate::config::CtlConfig;
10use crate::http::{parse_request, write_response};
11use crate::state::{SharedState, WsBroadcast, WsEvent};
12use crate::ws;
13
/// A single accepted client connection.
///
/// Either a raw TCP socket or, when the `tls` feature is enabled, a TCP socket wrapped in a
/// rustls server session. `Read`/`Write` are implemented on this type so request handling is
/// transport-agnostic.
pub(crate) enum ConnStream {
    /// Unencrypted TCP connection.
    Plain(TcpStream),
    /// TLS-encrypted connection (only compiled with the `tls` feature).
    #[cfg(feature = "tls")]
    Tls(rustls::StreamOwned<rustls::ServerConnection, TcpStream>),
}
20
21impl ConnStream {
22 fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
23 match self {
24 ConnStream::Plain(s) => s.set_read_timeout(dur),
25 #[cfg(feature = "tls")]
26 ConnStream::Tls(s) => s.sock.set_read_timeout(dur),
27 }
28 }
29}
30
31impl Read for ConnStream {
32 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
33 match self {
34 ConnStream::Plain(s) => s.read(buf),
35 #[cfg(feature = "tls")]
36 ConnStream::Tls(s) => s.read(buf),
37 }
38 }
39}
40
41impl Write for ConnStream {
42 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
43 match self {
44 ConnStream::Plain(s) => s.write(buf),
45 #[cfg(feature = "tls")]
46 ConnStream::Tls(s) => s.write(buf),
47 }
48 }
49
50 fn flush(&mut self) -> io::Result<()> {
51 match self {
52 ConnStream::Plain(s) => s.flush(),
53 #[cfg(feature = "tls")]
54 ConnStream::Tls(s) => s.flush(),
55 }
56 }
57}
58
59pub struct ServerContext {
61 pub node: NodeHandle,
62 pub state: SharedState,
63 pub ws_broadcast: WsBroadcast,
64 pub config: CtlConfig,
65 #[cfg(feature = "tls")]
66 pub tls_config: Option<std::sync::Arc<rustls::ServerConfig>>,
67}
68
69pub fn run_server(addr: SocketAddr, ctx: std::sync::Arc<ServerContext>) -> io::Result<()> {
71 let listener = TcpListener::bind(addr)?;
72
73 #[cfg(feature = "tls")]
74 let scheme = if ctx.tls_config.is_some() {
75 "https"
76 } else {
77 "http"
78 };
79 #[cfg(not(feature = "tls"))]
80 let scheme = "http";
81
82 log::info!("Listening on {}://{}", scheme, addr);
83
84 for stream in listener.incoming() {
85 match stream {
86 Ok(tcp_stream) => {
87 let ctx = ctx.clone();
88 std::thread::Builder::new()
89 .name("rns-ctl-conn".into())
90 .spawn(move || {
91 let conn = match wrap_stream(tcp_stream, &ctx) {
92 Ok(c) => c,
93 Err(e) => {
94 log::debug!("TLS handshake error: {}", e);
95 return;
96 }
97 };
98 if let Err(e) = handle_connection(conn, &ctx) {
99 log::debug!("Connection error: {}", e);
100 }
101 })
102 .ok();
103 }
104 Err(e) => {
105 log::warn!("Accept error: {}", e);
106 }
107 }
108 }
109
110 Ok(())
111}
112
113fn wrap_stream(tcp_stream: TcpStream, ctx: &ServerContext) -> io::Result<ConnStream> {
115 #[cfg(feature = "tls")]
116 {
117 if let Some(ref tls_config) = ctx.tls_config {
118 let server_conn = rustls::ServerConnection::new(tls_config.clone())
119 .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("TLS error: {}", e)))?;
120 return Ok(ConnStream::Tls(rustls::StreamOwned::new(
121 server_conn,
122 tcp_stream,
123 )));
124 }
125 }
126 let _ = ctx; Ok(ConnStream::Plain(tcp_stream))
128}
129
130fn handle_connection(mut stream: ConnStream, ctx: &ServerContext) -> io::Result<()> {
131 stream.set_read_timeout(Some(Duration::from_secs(30)))?;
133
134 let req = parse_request(&mut stream)?;
135
136 if ws::is_upgrade(&req) {
137 handle_ws_connection(stream, &req, ctx)
138 } else {
139 let response = handle_request(&req, &ctx.node, &ctx.state, &ctx.config);
140 write_response(&mut stream, &response)
141 }
142}
143
144fn handle_ws_connection(
145 mut stream: ConnStream,
146 req: &crate::http::HttpRequest,
147 ctx: &ServerContext,
148) -> io::Result<()> {
149 if let Err(resp) = check_ws_auth(&req.query, &ctx.config) {
151 return write_response(&mut stream, &resp);
152 }
153
154 ws::do_handshake(&mut stream, req)?;
156
157 stream.set_read_timeout(Some(Duration::from_millis(50)))?;
159
160 let (event_tx, event_rx) = mpsc::channel::<WsEvent>();
162
163 {
165 let mut senders = ctx.ws_broadcast.lock().unwrap();
166 senders.push(event_tx);
167 }
168
169 let mut topics = HashSet::<String>::new();
171 let mut ws_buf = ws::WsBuf::new();
172
173 loop {
174 match ws_buf.try_read_frame(&mut stream) {
176 Ok(Some(frame)) => match frame.opcode {
177 ws::OPCODE_TEXT => {
178 if let Ok(text) = std::str::from_utf8(&frame.payload) {
179 handle_ws_text(text, &mut topics, &mut stream);
180 }
181 }
182 ws::OPCODE_PING => {
183 let _ = ws::write_pong_frame(&mut stream, &frame.payload);
184 }
185 ws::OPCODE_CLOSE => {
186 let _ = ws::write_close_frame(&mut stream);
187 break;
188 }
189 _ => {}
190 },
191 Ok(None) => {
192 }
194 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
195 Err(e) => {
196 log::debug!("WS read error: {}", e);
197 break;
198 }
199 }
200
201 while let Ok(event) = event_rx.try_recv() {
203 if topics.contains(event.topic) {
204 let json = event.to_json();
205 if ws::write_text_frame(&mut stream, &json).is_err() {
206 return Ok(());
207 }
208 }
209 }
210 }
211
212 Ok(())
213}
214
215fn handle_ws_text(text: &str, topics: &mut HashSet<String>, stream: &mut ConnStream) {
216 if let Ok(msg) = serde_json::from_str::<serde_json::Value>(text) {
217 match msg["type"].as_str() {
218 Some("subscribe") => {
219 if let Some(arr) = msg["topics"].as_array() {
220 for t in arr {
221 if let Some(s) = t.as_str() {
222 topics.insert(s.to_string());
223 }
224 }
225 }
226 }
227 Some("unsubscribe") => {
228 if let Some(arr) = msg["topics"].as_array() {
229 for t in arr {
230 if let Some(s) = t.as_str() {
231 topics.remove(s);
232 }
233 }
234 }
235 }
236 Some("ping") => {
237 let _ =
238 ws::write_text_frame(stream, &serde_json::json!({"type": "pong"}).to_string());
239 }
240 _ => {}
241 }
242 }
243}