Skip to main content

rns_ctl/
server.rs

1use std::collections::HashSet;
2use std::io::{self, Read, Write};
3use std::net::{SocketAddr, TcpListener, TcpStream};
4use std::sync::mpsc;
5use std::time::Duration;
6
7use crate::api::{handle_request, NodeHandle};
8use crate::auth::check_ws_auth;
9use crate::config::CtlConfig;
10use crate::http::{parse_request, write_response};
11use crate::state::{SharedState, WsBroadcast, WsEvent};
12use crate::ws;
13
/// Transport for one accepted client connection.
///
/// Handlers work against this single type regardless of whether the bytes
/// travel over a bare TCP socket or through a TLS session layered on one.
pub(crate) enum ConnStream {
    /// Unencrypted TCP socket.
    Plain(TcpStream),
    /// TLS server session that owns its underlying TCP socket.
    /// Only compiled in when the `tls` feature is enabled.
    #[cfg(feature = "tls")]
    Tls(rustls::StreamOwned<rustls::ServerConnection, TcpStream>),
}
20
21impl ConnStream {
22    fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
23        match self {
24            ConnStream::Plain(s) => s.set_read_timeout(dur),
25            #[cfg(feature = "tls")]
26            ConnStream::Tls(s) => s.sock.set_read_timeout(dur),
27        }
28    }
29}
30
31impl Read for ConnStream {
32    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
33        match self {
34            ConnStream::Plain(s) => s.read(buf),
35            #[cfg(feature = "tls")]
36            ConnStream::Tls(s) => s.read(buf),
37        }
38    }
39}
40
41impl Write for ConnStream {
42    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
43        match self {
44            ConnStream::Plain(s) => s.write(buf),
45            #[cfg(feature = "tls")]
46            ConnStream::Tls(s) => s.write(buf),
47        }
48    }
49
50    fn flush(&mut self) -> io::Result<()> {
51        match self {
52            ConnStream::Plain(s) => s.flush(),
53            #[cfg(feature = "tls")]
54            ConnStream::Tls(s) => s.flush(),
55        }
56    }
57}
58
59/// All context needed by connection handlers.
60pub struct ServerContext {
61    pub node: NodeHandle,
62    pub state: SharedState,
63    pub ws_broadcast: WsBroadcast,
64    pub config: CtlConfig,
65    #[cfg(feature = "tls")]
66    pub tls_config: Option<std::sync::Arc<rustls::ServerConfig>>,
67}
68
69/// Run the HTTP/WS server. Blocks on the accept loop.
70pub fn run_server(addr: SocketAddr, ctx: std::sync::Arc<ServerContext>) -> io::Result<()> {
71    let listener = TcpListener::bind(addr)?;
72
73    #[cfg(feature = "tls")]
74    let scheme = if ctx.tls_config.is_some() {
75        "https"
76    } else {
77        "http"
78    };
79    #[cfg(not(feature = "tls"))]
80    let scheme = "http";
81
82    log::info!("Listening on {}://{}", scheme, addr);
83
84    for stream in listener.incoming() {
85        match stream {
86            Ok(tcp_stream) => {
87                let ctx = ctx.clone();
88                std::thread::Builder::new()
89                    .name("rns-ctl-conn".into())
90                    .spawn(move || {
91                        let conn = match wrap_stream(tcp_stream, &ctx) {
92                            Ok(c) => c,
93                            Err(e) => {
94                                log::debug!("TLS handshake error: {}", e);
95                                return;
96                            }
97                        };
98                        if let Err(e) = handle_connection(conn, &ctx) {
99                            log::debug!("Connection error: {}", e);
100                        }
101                    })
102                    .ok();
103            }
104            Err(e) => {
105                log::warn!("Accept error: {}", e);
106            }
107        }
108    }
109
110    Ok(())
111}
112
113/// Wrap a TCP stream in TLS if configured, otherwise return plain.
114fn wrap_stream(tcp_stream: TcpStream, ctx: &ServerContext) -> io::Result<ConnStream> {
115    #[cfg(feature = "tls")]
116    {
117        if let Some(ref tls_config) = ctx.tls_config {
118            let server_conn = rustls::ServerConnection::new(tls_config.clone())
119                .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("TLS error: {}", e)))?;
120            return Ok(ConnStream::Tls(rustls::StreamOwned::new(
121                server_conn,
122                tcp_stream,
123            )));
124        }
125    }
126    let _ = ctx; // suppress unused warning when tls feature is off
127    Ok(ConnStream::Plain(tcp_stream))
128}
129
130fn handle_connection(mut stream: ConnStream, ctx: &ServerContext) -> io::Result<()> {
131    // Set a read timeout so we don't block forever on malformed requests
132    stream.set_read_timeout(Some(Duration::from_secs(30)))?;
133
134    let req = parse_request(&mut stream)?;
135
136    if ws::is_upgrade(&req) {
137        handle_ws_connection(stream, &req, ctx)
138    } else {
139        let response = handle_request(&req, &ctx.node, &ctx.state, &ctx.config);
140        write_response(&mut stream, &response)
141    }
142}
143
144fn handle_ws_connection(
145    mut stream: ConnStream,
146    req: &crate::http::HttpRequest,
147    ctx: &ServerContext,
148) -> io::Result<()> {
149    // Auth check on the upgrade request
150    if let Err(resp) = check_ws_auth(&req.query, &ctx.config) {
151        return write_response(&mut stream, &resp);
152    }
153
154    // Complete handshake
155    ws::do_handshake(&mut stream, req)?;
156
157    // Set a short read timeout for the non-blocking event loop
158    stream.set_read_timeout(Some(Duration::from_millis(50)))?;
159
160    // Create broadcast channel for this client
161    let (event_tx, event_rx) = mpsc::channel::<WsEvent>();
162
163    // Register in broadcast list
164    {
165        let mut senders = ctx.ws_broadcast.lock().unwrap();
166        senders.push(event_tx);
167    }
168
169    // Subscribed topics for this client (no Arc/Mutex needed — single thread)
170    let mut topics = HashSet::<String>::new();
171    let mut ws_buf = ws::WsBuf::new();
172
173    loop {
174        // Try to read a frame from the client
175        match ws_buf.try_read_frame(&mut stream) {
176            Ok(Some(frame)) => match frame.opcode {
177                ws::OPCODE_TEXT => {
178                    if let Ok(text) = std::str::from_utf8(&frame.payload) {
179                        handle_ws_text(text, &mut topics, &mut stream);
180                    }
181                }
182                ws::OPCODE_PING => {
183                    let _ = ws::write_pong_frame(&mut stream, &frame.payload);
184                }
185                ws::OPCODE_CLOSE => {
186                    let _ = ws::write_close_frame(&mut stream);
187                    break;
188                }
189                _ => {}
190            },
191            Ok(None) => {
192                // No complete frame yet — fall through to drain events
193            }
194            Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
195            Err(e) => {
196                log::debug!("WS read error: {}", e);
197                break;
198            }
199        }
200
201        // Drain event channel, send matching events to client
202        while let Ok(event) = event_rx.try_recv() {
203            if topics.contains(event.topic) {
204                let json = event.to_json();
205                if ws::write_text_frame(&mut stream, &json).is_err() {
206                    return Ok(());
207                }
208            }
209        }
210    }
211
212    Ok(())
213}
214
215fn handle_ws_text(text: &str, topics: &mut HashSet<String>, stream: &mut ConnStream) {
216    if let Ok(msg) = serde_json::from_str::<serde_json::Value>(text) {
217        match msg["type"].as_str() {
218            Some("subscribe") => {
219                if let Some(arr) = msg["topics"].as_array() {
220                    for t in arr {
221                        if let Some(s) = t.as_str() {
222                            topics.insert(s.to_string());
223                        }
224                    }
225                }
226            }
227            Some("unsubscribe") => {
228                if let Some(arr) = msg["topics"].as_array() {
229                    for t in arr {
230                        if let Some(s) = t.as_str() {
231                            topics.remove(s);
232                        }
233                    }
234                }
235            }
236            Some("ping") => {
237                let _ =
238                    ws::write_text_frame(stream, &serde_json::json!({"type": "pong"}).to_string());
239            }
240            _ => {}
241        }
242    }
243}