1use std::collections::HashSet;
2use std::io::{self, Read, Write};
3use std::net::{SocketAddr, TcpListener, TcpStream};
4use std::sync::mpsc;
5use std::time::Duration;
6
7use crate::api::{handle_request, NodeHandle};
8use crate::auth::check_ws_auth;
9use crate::config::CtlConfig;
10use crate::http::{parse_request, write_response};
11use crate::state::{SharedState, WsBroadcast, WsEvent};
12use crate::ws;
13
/// An accepted client connection: either plain TCP or, with the `tls` feature,
/// TCP wrapped in a rustls server session.
///
/// `Read`/`Write` are implemented for this type by delegating to the wrapped
/// stream, so the HTTP and WebSocket code can treat both transports uniformly.
pub(crate) enum ConnStream {
    /// Unencrypted TCP connection.
    Plain(TcpStream),
    /// TLS connection; the `StreamOwned` owns both the rustls session and the socket.
    #[cfg(feature = "tls")]
    Tls(rustls::StreamOwned<rustls::ServerConnection, TcpStream>),
}
20
21impl ConnStream {
22 fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
23 match self {
24 ConnStream::Plain(s) => s.set_read_timeout(dur),
25 #[cfg(feature = "tls")]
26 ConnStream::Tls(s) => s.sock.set_read_timeout(dur),
27 }
28 }
29}
30
31impl Read for ConnStream {
32 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
33 match self {
34 ConnStream::Plain(s) => s.read(buf),
35 #[cfg(feature = "tls")]
36 ConnStream::Tls(s) => s.read(buf),
37 }
38 }
39}
40
41impl Write for ConnStream {
42 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
43 match self {
44 ConnStream::Plain(s) => s.write(buf),
45 #[cfg(feature = "tls")]
46 ConnStream::Tls(s) => s.write(buf),
47 }
48 }
49
50 fn flush(&mut self) -> io::Result<()> {
51 match self {
52 ConnStream::Plain(s) => s.flush(),
53 #[cfg(feature = "tls")]
54 ConnStream::Tls(s) => s.flush(),
55 }
56 }
57}
58
/// State shared (through an `Arc`) by every connection thread spawned in `run_server`.
pub struct ServerContext {
    /// Node handle passed to `handle_request` for regular HTTP API requests.
    pub node: NodeHandle,
    /// Shared state passed to `handle_request`.
    pub state: SharedState,
    /// Lock-protected collection into which each WebSocket session pushes its
    /// own event sender when it starts.
    pub ws_broadcast: WsBroadcast,
    /// Controller configuration, used for request handling and WebSocket auth.
    pub config: CtlConfig,
    /// TLS server configuration; when `Some`, accepted sockets are wrapped in TLS.
    #[cfg(feature = "tls")]
    pub tls_config: Option<std::sync::Arc<rustls::ServerConfig>>,
}
68
69pub fn run_server(addr: SocketAddr, ctx: std::sync::Arc<ServerContext>) -> io::Result<()> {
71 let listener = TcpListener::bind(addr)?;
72
73 #[cfg(feature = "tls")]
74 let scheme = if ctx.tls_config.is_some() { "https" } else { "http" };
75 #[cfg(not(feature = "tls"))]
76 let scheme = "http";
77
78 log::info!("Listening on {}://{}", scheme, addr);
79
80 for stream in listener.incoming() {
81 match stream {
82 Ok(tcp_stream) => {
83 let ctx = ctx.clone();
84 std::thread::Builder::new()
85 .name("rns-ctl-conn".into())
86 .spawn(move || {
87 let conn = match wrap_stream(tcp_stream, &ctx) {
88 Ok(c) => c,
89 Err(e) => {
90 log::debug!("TLS handshake error: {}", e);
91 return;
92 }
93 };
94 if let Err(e) = handle_connection(conn, &ctx) {
95 log::debug!("Connection error: {}", e);
96 }
97 })
98 .ok();
99 }
100 Err(e) => {
101 log::warn!("Accept error: {}", e);
102 }
103 }
104 }
105
106 Ok(())
107}
108
109fn wrap_stream(tcp_stream: TcpStream, ctx: &ServerContext) -> io::Result<ConnStream> {
111 #[cfg(feature = "tls")]
112 {
113 if let Some(ref tls_config) = ctx.tls_config {
114 let server_conn = rustls::ServerConnection::new(tls_config.clone())
115 .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("TLS error: {}", e)))?;
116 return Ok(ConnStream::Tls(rustls::StreamOwned::new(server_conn, tcp_stream)));
117 }
118 }
119 let _ = ctx; Ok(ConnStream::Plain(tcp_stream))
121}
122
123fn handle_connection(mut stream: ConnStream, ctx: &ServerContext) -> io::Result<()> {
124 stream.set_read_timeout(Some(Duration::from_secs(30)))?;
126
127 let req = parse_request(&mut stream)?;
128
129 if ws::is_upgrade(&req) {
130 handle_ws_connection(stream, &req, ctx)
131 } else {
132 let response = handle_request(&req, &ctx.node, &ctx.state, &ctx.config);
133 write_response(&mut stream, &response)
134 }
135}
136
137fn handle_ws_connection(
138 mut stream: ConnStream,
139 req: &crate::http::HttpRequest,
140 ctx: &ServerContext,
141) -> io::Result<()> {
142 if let Err(resp) = check_ws_auth(&req.query, &ctx.config) {
144 return write_response(&mut stream, &resp);
145 }
146
147 ws::do_handshake(&mut stream, req)?;
149
150 stream.set_read_timeout(Some(Duration::from_millis(50)))?;
152
153 let (event_tx, event_rx) = mpsc::channel::<WsEvent>();
155
156 {
158 let mut senders = ctx.ws_broadcast.lock().unwrap();
159 senders.push(event_tx);
160 }
161
162 let mut topics = HashSet::<String>::new();
164 let mut ws_buf = ws::WsBuf::new();
165
166 loop {
167 match ws_buf.try_read_frame(&mut stream) {
169 Ok(Some(frame)) => match frame.opcode {
170 ws::OPCODE_TEXT => {
171 if let Ok(text) = std::str::from_utf8(&frame.payload) {
172 handle_ws_text(text, &mut topics, &mut stream);
173 }
174 }
175 ws::OPCODE_PING => {
176 let _ = ws::write_pong_frame(&mut stream, &frame.payload);
177 }
178 ws::OPCODE_CLOSE => {
179 let _ = ws::write_close_frame(&mut stream);
180 break;
181 }
182 _ => {}
183 },
184 Ok(None) => {
185 }
187 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
188 Err(e) => {
189 log::debug!("WS read error: {}", e);
190 break;
191 }
192 }
193
194 while let Ok(event) = event_rx.try_recv() {
196 if topics.contains(event.topic) {
197 let json = event.to_json();
198 if ws::write_text_frame(&mut stream, &json).is_err() {
199 return Ok(());
200 }
201 }
202 }
203 }
204
205 Ok(())
206}
207
208fn handle_ws_text(text: &str, topics: &mut HashSet<String>, stream: &mut ConnStream) {
209 if let Ok(msg) = serde_json::from_str::<serde_json::Value>(text) {
210 match msg["type"].as_str() {
211 Some("subscribe") => {
212 if let Some(arr) = msg["topics"].as_array() {
213 for t in arr {
214 if let Some(s) = t.as_str() {
215 topics.insert(s.to_string());
216 }
217 }
218 }
219 }
220 Some("unsubscribe") => {
221 if let Some(arr) = msg["topics"].as_array() {
222 for t in arr {
223 if let Some(s) = t.as_str() {
224 topics.remove(s);
225 }
226 }
227 }
228 }
229 Some("ping") => {
230 let _ = ws::write_text_frame(
231 stream,
232 &serde_json::json!({"type": "pong"}).to_string(),
233 );
234 }
235 _ => {}
236 }
237 }
238}