Skip to main content

rns_ctl/
server.rs

1use std::collections::HashSet;
2use std::io::{self, Read, Write};
3use std::net::{SocketAddr, TcpListener, TcpStream};
4use std::sync::mpsc;
5use std::time::Duration;
6
7use crate::api::{handle_request, NodeHandle};
8use crate::auth::check_ws_auth;
9use crate::config::CtlConfig;
10use crate::http::{parse_request, write_response};
11use crate::state::{SharedState, WsBroadcast, WsEvent};
12use crate::ws;
13
/// Transport for one accepted client connection.
///
/// Handlers read and write through this type without needing to know whether
/// the connection is encrypted.
pub(crate) enum ConnStream {
    /// Unencrypted TCP connection.
    Plain(TcpStream),
    /// TCP connection wrapped in a rustls server session (`tls` feature only).
    #[cfg(feature = "tls")]
    Tls(rustls::StreamOwned<rustls::ServerConnection, TcpStream>),
}
20
21impl ConnStream {
22    fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
23        match self {
24            ConnStream::Plain(s) => s.set_read_timeout(dur),
25            #[cfg(feature = "tls")]
26            ConnStream::Tls(s) => s.sock.set_read_timeout(dur),
27        }
28    }
29}
30
31impl Read for ConnStream {
32    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
33        match self {
34            ConnStream::Plain(s) => s.read(buf),
35            #[cfg(feature = "tls")]
36            ConnStream::Tls(s) => s.read(buf),
37        }
38    }
39}
40
41impl Write for ConnStream {
42    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
43        match self {
44            ConnStream::Plain(s) => s.write(buf),
45            #[cfg(feature = "tls")]
46            ConnStream::Tls(s) => s.write(buf),
47        }
48    }
49
50    fn flush(&mut self) -> io::Result<()> {
51        match self {
52            ConnStream::Plain(s) => s.flush(),
53            #[cfg(feature = "tls")]
54            ConnStream::Tls(s) => s.flush(),
55        }
56    }
57}
58
59/// All context needed by connection handlers.
60pub struct ServerContext {
61    pub node: NodeHandle,
62    pub state: SharedState,
63    pub ws_broadcast: WsBroadcast,
64    pub config: CtlConfig,
65    #[cfg(feature = "tls")]
66    pub tls_config: Option<std::sync::Arc<rustls::ServerConfig>>,
67}
68
69/// Run the HTTP/WS server. Blocks on the accept loop.
70pub fn run_server(addr: SocketAddr, ctx: std::sync::Arc<ServerContext>) -> io::Result<()> {
71    let listener = TcpListener::bind(addr)?;
72
73    #[cfg(feature = "tls")]
74    let scheme = if ctx.tls_config.is_some() { "https" } else { "http" };
75    #[cfg(not(feature = "tls"))]
76    let scheme = "http";
77
78    log::info!("Listening on {}://{}", scheme, addr);
79
80    for stream in listener.incoming() {
81        match stream {
82            Ok(tcp_stream) => {
83                let ctx = ctx.clone();
84                std::thread::Builder::new()
85                    .name("rns-ctl-conn".into())
86                    .spawn(move || {
87                        let conn = match wrap_stream(tcp_stream, &ctx) {
88                            Ok(c) => c,
89                            Err(e) => {
90                                log::debug!("TLS handshake error: {}", e);
91                                return;
92                            }
93                        };
94                        if let Err(e) = handle_connection(conn, &ctx) {
95                            log::debug!("Connection error: {}", e);
96                        }
97                    })
98                    .ok();
99            }
100            Err(e) => {
101                log::warn!("Accept error: {}", e);
102            }
103        }
104    }
105
106    Ok(())
107}
108
109/// Wrap a TCP stream in TLS if configured, otherwise return plain.
110fn wrap_stream(tcp_stream: TcpStream, ctx: &ServerContext) -> io::Result<ConnStream> {
111    #[cfg(feature = "tls")]
112    {
113        if let Some(ref tls_config) = ctx.tls_config {
114            let server_conn = rustls::ServerConnection::new(tls_config.clone())
115                .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("TLS error: {}", e)))?;
116            return Ok(ConnStream::Tls(rustls::StreamOwned::new(server_conn, tcp_stream)));
117        }
118    }
119    let _ = ctx; // suppress unused warning when tls feature is off
120    Ok(ConnStream::Plain(tcp_stream))
121}
122
123fn handle_connection(mut stream: ConnStream, ctx: &ServerContext) -> io::Result<()> {
124    // Set a read timeout so we don't block forever on malformed requests
125    stream.set_read_timeout(Some(Duration::from_secs(30)))?;
126
127    let req = parse_request(&mut stream)?;
128
129    if ws::is_upgrade(&req) {
130        handle_ws_connection(stream, &req, ctx)
131    } else {
132        let response = handle_request(&req, &ctx.node, &ctx.state, &ctx.config);
133        write_response(&mut stream, &response)
134    }
135}
136
137fn handle_ws_connection(
138    mut stream: ConnStream,
139    req: &crate::http::HttpRequest,
140    ctx: &ServerContext,
141) -> io::Result<()> {
142    // Auth check on the upgrade request
143    if let Err(resp) = check_ws_auth(&req.query, &ctx.config) {
144        return write_response(&mut stream, &resp);
145    }
146
147    // Complete handshake
148    ws::do_handshake(&mut stream, req)?;
149
150    // Set a short read timeout for the non-blocking event loop
151    stream.set_read_timeout(Some(Duration::from_millis(50)))?;
152
153    // Create broadcast channel for this client
154    let (event_tx, event_rx) = mpsc::channel::<WsEvent>();
155
156    // Register in broadcast list
157    {
158        let mut senders = ctx.ws_broadcast.lock().unwrap();
159        senders.push(event_tx);
160    }
161
162    // Subscribed topics for this client (no Arc/Mutex needed — single thread)
163    let mut topics = HashSet::<String>::new();
164    let mut ws_buf = ws::WsBuf::new();
165
166    loop {
167        // Try to read a frame from the client
168        match ws_buf.try_read_frame(&mut stream) {
169            Ok(Some(frame)) => match frame.opcode {
170                ws::OPCODE_TEXT => {
171                    if let Ok(text) = std::str::from_utf8(&frame.payload) {
172                        handle_ws_text(text, &mut topics, &mut stream);
173                    }
174                }
175                ws::OPCODE_PING => {
176                    let _ = ws::write_pong_frame(&mut stream, &frame.payload);
177                }
178                ws::OPCODE_CLOSE => {
179                    let _ = ws::write_close_frame(&mut stream);
180                    break;
181                }
182                _ => {}
183            },
184            Ok(None) => {
185                // No complete frame yet — fall through to drain events
186            }
187            Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
188            Err(e) => {
189                log::debug!("WS read error: {}", e);
190                break;
191            }
192        }
193
194        // Drain event channel, send matching events to client
195        while let Ok(event) = event_rx.try_recv() {
196            if topics.contains(event.topic) {
197                let json = event.to_json();
198                if ws::write_text_frame(&mut stream, &json).is_err() {
199                    return Ok(());
200                }
201            }
202        }
203    }
204
205    Ok(())
206}
207
208fn handle_ws_text(text: &str, topics: &mut HashSet<String>, stream: &mut ConnStream) {
209    if let Ok(msg) = serde_json::from_str::<serde_json::Value>(text) {
210        match msg["type"].as_str() {
211            Some("subscribe") => {
212                if let Some(arr) = msg["topics"].as_array() {
213                    for t in arr {
214                        if let Some(s) = t.as_str() {
215                            topics.insert(s.to_string());
216                        }
217                    }
218                }
219            }
220            Some("unsubscribe") => {
221                if let Some(arr) = msg["topics"].as_array() {
222                    for t in arr {
223                        if let Some(s) = t.as_str() {
224                            topics.remove(s);
225                        }
226                    }
227                }
228            }
229            Some("ping") => {
230                let _ = ws::write_text_frame(
231                    stream,
232                    &serde_json::json!({"type": "pong"}).to_string(),
233                );
234            }
235            _ => {}
236        }
237    }
238}