1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use tokio::sync::{Mutex, Semaphore, mpsc, watch};
8
9#[cfg_attr(not(feature = "tray"), allow(dead_code))]
10#[derive(Clone, Debug, Serialize, Deserialize)]
11pub enum ServerStatus {
12 Starting,
13 Running,
14 Restarting,
15 Failed(String),
16 Stopped,
17}
18
19#[cfg_attr(not(feature = "tray"), allow(dead_code))]
20#[derive(Clone, Debug, Serialize, Deserialize)]
21pub struct StatusSnapshot {
22 pub service_name: String,
23 pub server_status: ServerStatus,
24 pub restarts: u64,
25 pub connected_clients: usize,
26 pub active_clients: usize,
27 pub max_active_clients: usize,
28 pub pending_requests: usize,
29 pub cached_initialize: bool,
30 pub initializing: bool,
31 pub last_reset: Option<String>,
32 pub queue_depth: usize,
33 pub child_pid: Option<u32>,
34 pub max_request_bytes: usize,
35 pub restart_backoff_ms: u64,
36 pub restart_backoff_max_ms: u64,
37 pub max_restarts: u64,
38}
39
40#[derive(Clone)]
49pub struct MuxState {
50 pub next_client_id: u64,
51 pub next_global_id: u64,
52 pub clients: HashMap<u64, mpsc::UnboundedSender<Value>>,
53 pub pending: HashMap<String, Pending>,
54 pub cached_initialize: Option<Value>,
55 pub init_waiting: Vec<(u64, Value)>,
56 pub initializing: bool,
57 pub server_status: ServerStatus,
58 pub restarts: u64,
59 pub last_reset: Option<String>,
60 pub max_active_clients: usize,
61 pub service_name: String,
62 pub max_request_bytes: usize,
63 pub request_timeout: Duration,
64 pub restart_backoff: Duration,
65 pub restart_backoff_max: Duration,
66 pub max_restarts: u64,
67 pub queue_depth: usize,
68 pub child_pid: Option<u32>,
69}
70
71#[derive(Clone, Debug)]
72pub struct Pending {
73 pub client_id: u64,
74 pub local_id: Value,
75 pub is_initialize: bool,
76 pub started_at: std::time::Instant,
77}
78
79impl MuxState {
80 #[allow(clippy::too_many_arguments)]
81 pub fn new(
82 max_active_clients: usize,
83 service_name: String,
84 max_request_bytes: usize,
85 request_timeout: Duration,
86 restart_backoff: Duration,
87 restart_backoff_max: Duration,
88 max_restarts: u64,
89 queue_depth: usize,
90 child_pid: Option<u32>,
91 ) -> Self {
92 Self {
93 next_client_id: 1,
94 next_global_id: 1,
95 clients: HashMap::new(),
96 pending: HashMap::new(),
97 cached_initialize: None,
98 init_waiting: Vec::new(),
99 initializing: false,
100 server_status: ServerStatus::Starting,
101 restarts: 0,
102 last_reset: None,
103 max_active_clients,
104 service_name,
105 max_request_bytes,
106 request_timeout,
107 restart_backoff,
108 restart_backoff_max,
109 max_restarts,
110 queue_depth,
111 child_pid,
112 }
113 }
114
115 pub fn register_client(&mut self, tx: mpsc::UnboundedSender<Value>) -> u64 {
116 let id = self.next_client_id;
117 self.next_client_id += 1;
118 self.clients.insert(id, tx);
119 id
120 }
121
122 pub fn unregister_client(&mut self, client_id: u64) {
123 self.clients.remove(&client_id);
124 self.pending.retain(|_, p| p.client_id != client_id);
125 self.init_waiting.retain(|(cid, _)| *cid != client_id);
126 }
127
128 pub fn next_request_id(&mut self) -> u64 {
129 let id = self.next_global_id;
130 self.next_global_id += 1;
131 id
132 }
133}
134
135pub fn set_id(msg: &mut Value, id: Value) {
136 if let Some(obj) = msg.as_object_mut() {
137 obj.insert("id".to_string(), id);
138 }
139}
140
141pub fn error_response(id: Value, message: &str) -> Value {
142 serde_json::json!({
143 "jsonrpc": "2.0",
144 "id": id,
145 "error": {
146 "code": -32000,
147 "message": message,
148 }
149 })
150}
151
152pub fn snapshot_for_state(st: &MuxState, active_clients: usize) -> StatusSnapshot {
153 StatusSnapshot {
154 service_name: st.service_name.clone(),
155 server_status: st.server_status.clone(),
156 restarts: st.restarts,
157 connected_clients: st.clients.len(),
158 active_clients,
159 max_active_clients: st.max_active_clients,
160 pending_requests: st.pending.len(),
161 cached_initialize: st.cached_initialize.is_some(),
162 initializing: st.initializing,
163 last_reset: st.last_reset.clone(),
164 queue_depth: st.queue_depth,
165 child_pid: st.child_pid,
166 max_request_bytes: st.max_request_bytes,
167 restart_backoff_ms: st.restart_backoff.as_millis() as u64,
168 restart_backoff_max_ms: st.restart_backoff_max.as_millis() as u64,
169 max_restarts: st.max_restarts,
170 }
171}
172
173pub async fn publish_status(
174 state: &Arc<Mutex<MuxState>>,
175 active_clients: &Arc<Semaphore>,
176 status_tx: &watch::Sender<StatusSnapshot>,
177) {
178 let st = state.lock().await;
179 let active = st
180 .max_active_clients
181 .saturating_sub(active_clients.available_permits());
182 let snapshot = snapshot_for_state(&st, active);
183 drop(st);
184 let _ = status_tx.send(snapshot);
185}
186
187pub async fn reset_state(
188 state: &Arc<Mutex<MuxState>>,
189 reason: &str,
190 active_clients: &Arc<Semaphore>,
191 status_tx: &watch::Sender<StatusSnapshot>,
192) {
193 let mut st = state.lock().await;
194 let pending = std::mem::take(&mut st.pending);
195 let waiters = std::mem::take(&mut st.init_waiting);
196 st.cached_initialize = None;
197 st.initializing = false;
198 st.last_reset = Some(reason.to_string());
199 st.queue_depth = 0;
200 st.child_pid = None;
201
202 for (_, p) in pending {
203 if let Some(tx) = st.clients.get(&p.client_id) {
204 tx.send(error_response(p.local_id, reason)).ok();
205 }
206 }
207 for (cid, lid) in waiters {
208 if let Some(tx) = st.clients.get(&cid) {
209 tx.send(error_response(lid, reason)).ok();
210 }
211 }
212 drop(st);
213 publish_status(state, active_clients, status_tx).await;
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// State with fixed, arbitrary configuration shared by the tests below.
    fn test_state() -> MuxState {
        MuxState::new(
            5,
            "test-service".into(),
            1_048_576,
            Duration::from_secs(30),
            Duration::from_millis(1_000),
            Duration::from_millis(30_000),
            5,
            0,
            None,
        )
    }

    /// A non-initialize pending entry owned by `client_id`.
    fn pending_for(client_id: u64, local_id: Value) -> Pending {
        Pending {
            client_id,
            local_id,
            is_initialize: false,
            started_at: std::time::Instant::now(),
        }
    }

    #[test]
    fn next_request_id_increments_sequentially() {
        let mut state = test_state();

        let first = state.next_request_id();
        let second = state.next_request_id();

        assert_eq!(first, 1);
        assert_eq!(first + 1, second);
    }

    #[test]
    fn register_client_assigns_distinct_sequential_ids() {
        let mut state = test_state();
        let (tx_a, _rx_a) = mpsc::unbounded_channel();
        let (tx_b, _rx_b) = mpsc::unbounded_channel();

        let a = state.register_client(tx_a);
        let b = state.register_client(tx_b);

        assert_eq!(a, 1);
        assert_eq!(a + 1, b);
        assert_eq!(state.clients.len(), 2);
    }

    #[test]
    fn unregister_client_drops_only_that_clients_work() {
        let mut state = test_state();
        let (tx_a, _rx_a) = mpsc::unbounded_channel();
        let (tx_b, _rx_b) = mpsc::unbounded_channel();
        let a = state.register_client(tx_a);
        let b = state.register_client(tx_b);
        state.pending.insert("1".into(), pending_for(a, Value::from(10)));
        state.pending.insert("2".into(), pending_for(b, Value::from(20)));
        state.init_waiting.push((a, Value::from(11)));
        state.init_waiting.push((b, Value::from(21)));

        state.unregister_client(a);

        assert!(!state.clients.contains_key(&a));
        assert!(state.clients.contains_key(&b));
        assert_eq!(state.pending.len(), 1);
        assert!(state.pending.values().all(|p| p.client_id == b));
        assert_eq!(state.init_waiting, vec![(b, Value::from(21))]);
    }

    #[test]
    fn set_id_replaces_id_on_objects_and_ignores_other_values() {
        let mut msg = serde_json::json!({"jsonrpc": "2.0", "id": 7, "method": "ping"});
        set_id(&mut msg, Value::from(42));
        assert_eq!(msg["id"], Value::from(42));
        assert_eq!(msg["method"], Value::from("ping"));

        let mut not_object = Value::from("text");
        set_id(&mut not_object, Value::from(1));
        assert_eq!(not_object, Value::from("text"));
    }

    #[test]
    fn error_response_uses_jsonrpc_2_0() {
        let resp = error_response(Value::Number(1.into()), "oops");
        assert_eq!(resp.get("jsonrpc"), Some(&Value::String("2.0".into())));
    }

    #[test]
    fn error_response_carries_id_code_and_message() {
        let resp = error_response(Value::from(1), "oops");
        assert_eq!(
            resp,
            serde_json::json!({
                "jsonrpc": "2.0",
                "id": 1,
                "error": {"code": -32000, "message": "oops"},
            })
        );
    }

    #[test]
    fn snapshot_mirrors_state() {
        let mut state = test_state();
        let (tx, _rx) = mpsc::unbounded_channel();
        state.register_client(tx);

        let snap = snapshot_for_state(&state, 3);

        assert_eq!(snap.service_name, "test-service");
        assert!(matches!(snap.server_status, ServerStatus::Starting));
        assert_eq!(snap.connected_clients, 1);
        assert_eq!(snap.active_clients, 3);
        assert_eq!(snap.max_active_clients, 5);
        assert_eq!(snap.pending_requests, 0);
        assert!(!snap.cached_initialize);
        assert_eq!(snap.last_reset, None);
        assert_eq!(snap.restart_backoff_ms, 1_000);
        assert_eq!(snap.restart_backoff_max_ms, 30_000);
        assert_eq!(snap.max_restarts, 5);
    }
}