1use std::collections::HashSet;
2use std::io;
3use std::net::{SocketAddr, TcpListener, TcpStream};
4use std::sync::mpsc;
5use std::thread;
6
7use crate::api::{handle_request, NodeHandle};
8use crate::auth::check_ws_auth;
9use crate::config::CtlConfig;
10use crate::http::{parse_request, write_response};
11use crate::state::{SharedState, WsBroadcast, WsEvent};
12use crate::ws;
13
14pub struct ServerContext {
16 pub node: NodeHandle,
17 pub state: SharedState,
18 pub ws_broadcast: WsBroadcast,
19 pub config: CtlConfig,
20}
21
22pub fn run_server(addr: SocketAddr, ctx: std::sync::Arc<ServerContext>) -> io::Result<()> {
24 let listener = TcpListener::bind(addr)?;
25 log::info!("Listening on http://{}", addr);
26
27 for stream in listener.incoming() {
28 match stream {
29 Ok(stream) => {
30 let ctx = ctx.clone();
31 thread::Builder::new()
32 .name("rns-ctl-conn".into())
33 .spawn(move || {
34 if let Err(e) = handle_connection(stream, &ctx) {
35 log::debug!("Connection error: {}", e);
36 }
37 })
38 .ok();
39 }
40 Err(e) => {
41 log::warn!("Accept error: {}", e);
42 }
43 }
44 }
45
46 Ok(())
47}
48
49fn handle_connection(mut stream: TcpStream, ctx: &ServerContext) -> io::Result<()> {
50 stream.set_read_timeout(Some(std::time::Duration::from_secs(30)))?;
52
53 let req = parse_request(&mut stream)?;
54
55 if ws::is_upgrade(&req) {
56 handle_ws_connection(stream, &req, ctx)
57 } else {
58 let response = handle_request(&req, &ctx.node, &ctx.state, &ctx.config);
59 write_response(&mut stream, &response)
60 }
61}
62
63fn handle_ws_connection(
64 mut stream: TcpStream,
65 req: &crate::http::HttpRequest,
66 ctx: &ServerContext,
67) -> io::Result<()> {
68 if let Err(resp) = check_ws_auth(&req.query, &ctx.config) {
70 return write_response(&mut stream, &resp);
71 }
72
73 ws::do_handshake(&mut stream, req)?;
75
76 stream.set_read_timeout(None)?;
78
79 let (event_tx, event_rx) = mpsc::channel::<WsEvent>();
81
82 {
84 let mut senders = ctx.ws_broadcast.lock().unwrap();
85 senders.push(event_tx);
86 }
87
88 let topics = std::sync::Arc::new(std::sync::Mutex::new(HashSet::<String>::new()));
90 let topics_writer = topics.clone();
91
92 let mut write_stream = stream.try_clone()?;
94 let writer_handle = thread::Builder::new()
95 .name("rns-ctl-ws-writer".into())
96 .spawn(move || {
97 while let Ok(event) = event_rx.recv() {
98 let subs = topics_writer.lock().unwrap();
99 if !subs.contains(event.topic) {
100 continue;
101 }
102 drop(subs);
103 let json = event.to_json();
104 if ws::write_text_frame(&mut write_stream, &json).is_err() {
105 break;
106 }
107 }
108 })?;
109
110 let mut read_stream = stream.try_clone()?;
113 let mut ctrl_stream = stream.try_clone()?;
114 let pong_stream = std::sync::Mutex::new(stream);
115
116 ws::run_ws_loop(&mut read_stream, &mut ctrl_stream, |text| {
117 if let Ok(msg) = serde_json::from_str::<serde_json::Value>(text) {
118 match msg["type"].as_str() {
119 Some("subscribe") => {
120 if let Some(arr) = msg["topics"].as_array() {
121 let mut subs = topics.lock().unwrap();
122 for t in arr {
123 if let Some(s) = t.as_str() {
124 subs.insert(s.to_string());
125 }
126 }
127 }
128 }
129 Some("unsubscribe") => {
130 if let Some(arr) = msg["topics"].as_array() {
131 let mut subs = topics.lock().unwrap();
132 for t in arr {
133 if let Some(s) = t.as_str() {
134 subs.remove(s);
135 }
136 }
137 }
138 }
139 Some("ping") => {
140 if let Ok(mut s) = pong_stream.lock() {
141 let _ = ws::write_text_frame(
142 &mut *s,
143 &serde_json::json!({"type": "pong"}).to_string(),
144 );
145 }
146 }
147 _ => {}
148 }
149 }
150 })?;
151
152 drop(writer_handle);
154 Ok(())
155}