// rns_ctl/server.rs

1use std::collections::HashSet;
2use std::io::{self, Read, Write};
3use std::net::{SocketAddr, TcpListener, TcpStream};
4use std::sync::mpsc;
5use std::time::Duration;
6
7use crate::api::{handle_request, NodeHandle};
8use crate::auth::check_ws_auth;
9use crate::http::{parse_request, write_response};
10use crate::state::{ControlPlaneConfigHandle, SharedState, WsBroadcast, WsEvent};
11use crate::ws;
12
/// Transport underlying a single client connection.
///
/// Holds either a raw TCP socket or, when the `tls` feature is compiled in, a
/// rustls server session that owns the socket. [`Read`] and [`Write`] are
/// forwarded to whichever variant is held.
pub(crate) enum ConnStream {
    /// Unencrypted TCP connection.
    Plain(TcpStream),
    /// TLS session owning both the rustls connection state and the socket.
    #[cfg(feature = "tls")]
    Tls(rustls::StreamOwned<rustls::ServerConnection, TcpStream>),
}

impl ConnStream {
    /// Borrow the TCP socket beneath this stream, looking through any TLS layer.
    fn tcp(&self) -> &TcpStream {
        match self {
            Self::Plain(sock) => sock,
            #[cfg(feature = "tls")]
            Self::Tls(tls) => &tls.sock,
        }
    }

    /// Set the read timeout on the underlying socket (`None` blocks indefinitely).
    fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
        self.tcp().set_read_timeout(dur)
    }
}

impl Read for ConnStream {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        match self {
            Self::Plain(inner) => inner.read(buf),
            #[cfg(feature = "tls")]
            Self::Tls(inner) => inner.read(buf),
        }
    }
}

impl Write for ConnStream {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        match self {
            Self::Plain(inner) => inner.write(buf),
            #[cfg(feature = "tls")]
            Self::Tls(inner) => inner.write(buf),
        }
    }

    fn flush(&mut self) -> io::Result<()> {
        match self {
            Self::Plain(inner) => inner.flush(),
            #[cfg(feature = "tls")]
            Self::Tls(inner) => inner.flush(),
        }
    }
}
57
58/// All context needed by connection handlers.
59pub struct ServerContext {
60    pub node: NodeHandle,
61    pub state: SharedState,
62    pub ws_broadcast: WsBroadcast,
63    pub config: ControlPlaneConfigHandle,
64    #[cfg(feature = "tls")]
65    pub tls_config: Option<std::sync::Arc<rustls::ServerConfig>>,
66}
67
68/// Run the HTTP/WS server. Blocks on the accept loop.
69pub fn run_server(addr: SocketAddr, ctx: std::sync::Arc<ServerContext>) -> io::Result<()> {
70    let listener = TcpListener::bind(addr)?;
71
72    #[cfg(feature = "tls")]
73    let scheme = if ctx.tls_config.is_some() {
74        "https"
75    } else {
76        "http"
77    };
78    #[cfg(not(feature = "tls"))]
79    let scheme = "http";
80
81    log::info!("Listening on {}://{}", scheme, addr);
82
83    for stream in listener.incoming() {
84        match stream {
85            Ok(tcp_stream) => {
86                let ctx = ctx.clone();
87                std::thread::Builder::new()
88                    .name("rns-ctl-conn".into())
89                    .spawn(move || {
90                        let conn = match wrap_stream(tcp_stream, &ctx) {
91                            Ok(c) => c,
92                            Err(e) => {
93                                log::debug!("TLS handshake error: {}", e);
94                                return;
95                            }
96                        };
97                        if let Err(e) = handle_connection(conn, &ctx) {
98                            log::debug!("Connection error: {}", e);
99                        }
100                    })
101                    .ok();
102            }
103            Err(e) => {
104                log::warn!("Accept error: {}", e);
105            }
106        }
107    }
108
109    Ok(())
110}
111
112/// Wrap a TCP stream in TLS if configured, otherwise return plain.
113fn wrap_stream(tcp_stream: TcpStream, ctx: &ServerContext) -> io::Result<ConnStream> {
114    #[cfg(feature = "tls")]
115    {
116        if let Some(ref tls_config) = ctx.tls_config {
117            let server_conn = rustls::ServerConnection::new(tls_config.clone())
118                .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("TLS error: {}", e)))?;
119            return Ok(ConnStream::Tls(rustls::StreamOwned::new(
120                server_conn,
121                tcp_stream,
122            )));
123        }
124    }
125    let _ = ctx; // suppress unused warning when tls feature is off
126    Ok(ConnStream::Plain(tcp_stream))
127}
128
129fn handle_connection(mut stream: ConnStream, ctx: &ServerContext) -> io::Result<()> {
130    // Set a read timeout so we don't block forever on malformed requests
131    stream.set_read_timeout(Some(Duration::from_secs(30)))?;
132
133    let req = parse_request(&mut stream)?;
134
135    if ws::is_upgrade(&req) {
136        handle_ws_connection(stream, &req, ctx)
137    } else {
138        let response = handle_request(&req, &ctx.node, &ctx.state, &ctx.config);
139        write_response(&mut stream, &response)
140    }
141}
142
143fn handle_ws_connection(
144    mut stream: ConnStream,
145    req: &crate::http::HttpRequest,
146    ctx: &ServerContext,
147) -> io::Result<()> {
148    // Auth check on the upgrade request
149    if let Err(resp) = check_ws_auth(&req.query, &ctx.config) {
150        return write_response(&mut stream, &resp);
151    }
152
153    // Complete handshake
154    ws::do_handshake(&mut stream, req)?;
155
156    // Set a short read timeout for the non-blocking event loop
157    stream.set_read_timeout(Some(Duration::from_millis(50)))?;
158
159    // Create broadcast channel for this client
160    let (event_tx, event_rx) = mpsc::channel::<WsEvent>();
161
162    // Register in broadcast list
163    {
164        let mut senders = ctx.ws_broadcast.lock().unwrap();
165        senders.push(event_tx);
166    }
167
168    // Subscribed topics for this client (no Arc/Mutex needed — single thread)
169    let mut topics = HashSet::<String>::new();
170    let mut ws_buf = ws::WsBuf::new();
171
172    loop {
173        // Try to read a frame from the client
174        match ws_buf.try_read_frame(&mut stream) {
175            Ok(Some(frame)) => match frame.opcode {
176                ws::OPCODE_TEXT => {
177                    if let Ok(text) = std::str::from_utf8(&frame.payload) {
178                        handle_ws_text(text, &mut topics, &mut stream);
179                    }
180                }
181                ws::OPCODE_PING => {
182                    let _ = ws::write_pong_frame(&mut stream, &frame.payload);
183                }
184                ws::OPCODE_CLOSE => {
185                    let _ = ws::write_close_frame(&mut stream);
186                    break;
187                }
188                _ => {}
189            },
190            Ok(None) => {
191                // No complete frame yet — fall through to drain events
192            }
193            Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
194            Err(e) => {
195                log::debug!("WS read error: {}", e);
196                break;
197            }
198        }
199
200        // Drain event channel, send matching events to client
201        while let Ok(event) = event_rx.try_recv() {
202            if topics.contains(event.topic) {
203                let json = event.to_json();
204                if ws::write_text_frame(&mut stream, &json).is_err() {
205                    return Ok(());
206                }
207            }
208        }
209    }
210
211    Ok(())
212}
213
214fn handle_ws_text(text: &str, topics: &mut HashSet<String>, stream: &mut ConnStream) {
215    if let Ok(msg) = serde_json::from_str::<serde_json::Value>(text) {
216        match msg["type"].as_str() {
217            Some("subscribe") => {
218                if let Some(arr) = msg["topics"].as_array() {
219                    for t in arr {
220                        if let Some(s) = t.as_str() {
221                            topics.insert(s.to_string());
222                        }
223                    }
224                }
225            }
226            Some("unsubscribe") => {
227                if let Some(arr) = msg["topics"].as_array() {
228                    for t in arr {
229                        if let Some(s) = t.as_str() {
230                            topics.remove(s);
231                        }
232                    }
233                }
234            }
235            Some("ping") => {
236                let _ =
237                    ws::write_text_frame(stream, &serde_json::json!({"type": "pong"}).to_string());
238            }
239            _ => {}
240        }
241    }
242}