// rns_ctl/server.rs

1use std::collections::HashSet;
2use std::io::{self, Read, Write};
3use std::net::{SocketAddr, TcpListener, TcpStream};
4use std::sync::mpsc;
5use std::time::Duration;
6
7use crate::api::{handle_request, NodeHandle};
8use crate::auth::check_ws_auth;
9use crate::http::{parse_request, write_response};
10use crate::state::{
11    lock_ws_broadcast, ControlPlaneConfigHandle, SharedState, WsBroadcast, WsEvent,
12};
13use crate::ws;
14
/// A connection stream that is either plain TCP or TLS-wrapped.
///
/// Request parsing, response writing and the WebSocket loop all work through
/// the `Read`/`Write` impls below, so the same handler code serves both kinds.
pub(crate) enum ConnStream {
    /// Unencrypted TCP socket.
    Plain(TcpStream),
    /// rustls server session that owns the TCP socket it runs over.
    #[cfg(feature = "tls")]
    Tls(rustls::StreamOwned<rustls::ServerConnection, TcpStream>),
}

impl ConnStream {
    /// The raw TCP socket beneath this connection (for TLS, the one rustls uses).
    fn socket(&self) -> &TcpStream {
        match self {
            ConnStream::Plain(sock) => sock,
            #[cfg(feature = "tls")]
            ConnStream::Tls(tls) => &tls.sock,
        }
    }

    /// Apply `dur` as the read timeout of the underlying TCP socket.
    ///
    /// `None` clears the timeout. For TLS connections the timeout applies to
    /// the raw socket reads, not to whole decrypted records.
    fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
        self.socket().set_read_timeout(dur)
    }

    /// The active variant viewed as a `Read` trait object.
    fn reader(&mut self) -> &mut dyn Read {
        match self {
            ConnStream::Plain(sock) => sock,
            #[cfg(feature = "tls")]
            ConnStream::Tls(tls) => tls,
        }
    }

    /// The active variant viewed as a `Write` trait object.
    fn writer(&mut self) -> &mut dyn Write {
        match self {
            ConnStream::Plain(sock) => sock,
            #[cfg(feature = "tls")]
            ConnStream::Tls(tls) => tls,
        }
    }
}

impl Read for ConnStream {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.reader().read(buf)
    }
}

impl Write for ConnStream {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.writer().write(buf)
    }

    fn flush(&mut self) -> io::Result<()> {
        self.writer().flush()
    }
}
59
60/// All context needed by connection handlers.
61pub struct ServerContext {
62    pub node: NodeHandle,
63    pub state: SharedState,
64    pub ws_broadcast: WsBroadcast,
65    pub config: ControlPlaneConfigHandle,
66    #[cfg(feature = "tls")]
67    pub tls_config: Option<std::sync::Arc<rustls::ServerConfig>>,
68}
69
70/// Run the HTTP/WS server. Blocks on the accept loop.
71pub fn run_server(addr: SocketAddr, ctx: std::sync::Arc<ServerContext>) -> io::Result<()> {
72    let listener = TcpListener::bind(addr)?;
73    run_server_with_listener(listener, ctx)
74}
75
76/// Run the HTTP/WS server using a pre-bound listener. Blocks on the accept loop.
77pub fn run_server_with_listener(
78    listener: TcpListener,
79    ctx: std::sync::Arc<ServerContext>,
80) -> io::Result<()> {
81    let addr = listener.local_addr()?;
82
83    #[cfg(feature = "tls")]
84    let scheme = if ctx.tls_config.is_some() {
85        "https"
86    } else {
87        "http"
88    };
89    #[cfg(not(feature = "tls"))]
90    let scheme = "http";
91
92    log::info!("Listening on {}://{}", scheme, addr);
93
94    for stream in listener.incoming() {
95        match stream {
96            Ok(tcp_stream) => {
97                let ctx = ctx.clone();
98                std::thread::Builder::new()
99                    .name("rns-ctl-conn".into())
100                    .spawn(move || {
101                        let conn = match wrap_stream(tcp_stream, &ctx) {
102                            Ok(c) => c,
103                            Err(e) => {
104                                log::debug!("TLS handshake error: {}", e);
105                                return;
106                            }
107                        };
108                        if let Err(e) = handle_connection(conn, &ctx) {
109                            log::debug!("Connection error: {}", e);
110                        }
111                    })
112                    .ok();
113            }
114            Err(e) => {
115                log::warn!("Accept error: {}", e);
116            }
117        }
118    }
119
120    Ok(())
121}
122
123/// Wrap a TCP stream in TLS if configured, otherwise return plain.
124fn wrap_stream(tcp_stream: TcpStream, ctx: &ServerContext) -> io::Result<ConnStream> {
125    #[cfg(feature = "tls")]
126    {
127        if let Some(ref tls_config) = ctx.tls_config {
128            let server_conn = rustls::ServerConnection::new(tls_config.clone())
129                .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("TLS error: {}", e)))?;
130            return Ok(ConnStream::Tls(rustls::StreamOwned::new(
131                server_conn,
132                tcp_stream,
133            )));
134        }
135    }
136    let _ = ctx; // suppress unused warning when tls feature is off
137    Ok(ConnStream::Plain(tcp_stream))
138}
139
140fn handle_connection(mut stream: ConnStream, ctx: &ServerContext) -> io::Result<()> {
141    // Set a read timeout so we don't block forever on malformed requests
142    stream.set_read_timeout(Some(Duration::from_secs(30)))?;
143
144    let req = parse_request(&mut stream)?;
145
146    if ws::is_upgrade(&req) {
147        handle_ws_connection(stream, &req, ctx)
148    } else {
149        let response = handle_request(&req, &ctx.node, &ctx.state, &ctx.config);
150        write_response(&mut stream, &response)
151    }
152}
153
154fn handle_ws_connection(
155    mut stream: ConnStream,
156    req: &crate::http::HttpRequest,
157    ctx: &ServerContext,
158) -> io::Result<()> {
159    // Auth check on the upgrade request
160    if let Err(resp) = check_ws_auth(&req.query, &ctx.config) {
161        return write_response(&mut stream, &resp);
162    }
163
164    // Complete handshake
165    ws::do_handshake(&mut stream, req)?;
166
167    // Set a short read timeout for the non-blocking event loop
168    stream.set_read_timeout(Some(Duration::from_millis(50)))?;
169
170    // Create broadcast channel for this client
171    let (event_tx, event_rx) = mpsc::channel::<WsEvent>();
172
173    // Register in broadcast list
174    {
175        let mut senders = lock_ws_broadcast(&ctx.ws_broadcast);
176        senders.push(event_tx);
177    }
178
179    // Subscribed topics for this client (no Arc/Mutex needed — single thread)
180    let mut topics = HashSet::<String>::new();
181    let mut ws_buf = ws::WsBuf::new();
182
183    loop {
184        // Try to read a frame from the client
185        match ws_buf.try_read_frame(&mut stream) {
186            Ok(Some(frame)) => match frame.opcode {
187                ws::OPCODE_TEXT => {
188                    if let Ok(text) = std::str::from_utf8(&frame.payload) {
189                        handle_ws_text(text, &mut topics, &mut stream);
190                    }
191                }
192                ws::OPCODE_PING => {
193                    let _ = ws::write_pong_frame(&mut stream, &frame.payload);
194                }
195                ws::OPCODE_CLOSE => {
196                    let _ = ws::write_close_frame(&mut stream);
197                    break;
198                }
199                _ => {}
200            },
201            Ok(None) => {
202                // No complete frame yet — fall through to drain events
203            }
204            Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
205            Err(e) => {
206                log::debug!("WS read error: {}", e);
207                break;
208            }
209        }
210
211        // Drain event channel, send matching events to client
212        while let Ok(event) = event_rx.try_recv() {
213            if topics.contains(event.topic) {
214                let json = event.to_json();
215                if ws::write_text_frame(&mut stream, &json).is_err() {
216                    return Ok(());
217                }
218            }
219        }
220    }
221
222    Ok(())
223}
224
225fn handle_ws_text(text: &str, topics: &mut HashSet<String>, stream: &mut ConnStream) {
226    if let Ok(msg) = serde_json::from_str::<serde_json::Value>(text) {
227        match msg["type"].as_str() {
228            Some("subscribe") => {
229                if let Some(arr) = msg["topics"].as_array() {
230                    for t in arr {
231                        if let Some(s) = t.as_str() {
232                            topics.insert(s.to_string());
233                        }
234                    }
235                }
236            }
237            Some("unsubscribe") => {
238                if let Some(arr) = msg["topics"].as_array() {
239                    for t in arr {
240                        if let Some(s) = t.as_str() {
241                            topics.remove(s);
242                        }
243                    }
244                }
245            }
246            Some("ping") => {
247                let _ =
248                    ws::write_text_frame(stream, &serde_json::json!({"type": "pong"}).to_string());
249            }
250            _ => {}
251        }
252    }
253}