// rmcp_mux/state.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use tokio::sync::{Mutex, Semaphore, mpsc, watch};
8
/// Lifecycle state of the managed child server process, as reported in
/// [`StatusSnapshot::server_status`].
// Only the tray feature reads this for display; silence dead_code otherwise.
#[cfg_attr(not(feature = "tray"), allow(dead_code))]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum ServerStatus {
    /// Child is being spawned (initial state set by `MuxState::new`).
    Starting,
    /// Child is up and serving requests.
    Running,
    /// A restart of the child is underway.
    Restarting,
    /// Child could not be (re)started; payload is a human-readable reason.
    Failed(String),
    /// Child has been shut down deliberately.
    Stopped,
}
18
/// Point-in-time, serializable view of the mux, built by
/// [`snapshot_for_state`] and broadcast over a `watch` channel.
// Only the tray feature consumes snapshots; silence dead_code otherwise.
#[cfg_attr(not(feature = "tray"), allow(dead_code))]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StatusSnapshot {
    /// Name of the proxied service (copied from `MuxState::service_name`).
    pub service_name: String,
    /// Current lifecycle state of the child server.
    pub server_status: ServerStatus,
    /// Number of child restarts performed so far.
    pub restarts: u64,
    /// Clients currently registered (size of the client sender map).
    pub connected_clients: usize,
    /// Clients currently holding an active slot; derived from the semaphore
    /// by the caller (see `publish_status`).
    pub active_clients: usize,
    /// Configured upper bound on concurrently active clients.
    pub max_active_clients: usize,
    /// Requests forwarded to the child and still awaiting a response.
    pub pending_requests: usize,
    /// Whether an `initialize` response is cached for replay.
    pub cached_initialize: bool,
    /// Whether an `initialize` round-trip is currently in flight.
    pub initializing: bool,
    /// Reason string recorded by the most recent `reset_state`, if any.
    pub last_reset: Option<String>,
    /// Configured cap on queued client messages.
    pub queue_depth: usize,
    /// OS pid of the child process, when one is known.
    pub child_pid: Option<u32>,
    /// Configured per-request size limit, in bytes.
    pub max_request_bytes: usize,
    /// Initial restart backoff, in milliseconds.
    pub restart_backoff_ms: u64,
    /// Ceiling for restart backoff, in milliseconds.
    pub restart_backoff_max_ms: u64,
    /// Restart limit gating child respawns.
    pub max_restarts: u64,
}
39
/// Central runtime state shared between the async mux loops.
///
/// - `queue_depth` caps queued client messages to avoid unbounded memory growth
///   under bursty hosts.
/// - `max_request_bytes` and `request_timeout` are enforced per forwarded
///   request to prevent slowloris/DoS patterns.
/// - Restart backoff (`restart_backoff`..`restart_backoff_max`) and
///   `max_restarts` gate child respawns so a flapping server cannot burn CPU.
#[derive(Clone)]
pub struct MuxState {
    /// Next client id to hand out; starts at 1, monotonically increasing.
    pub next_client_id: u64,
    /// Next rewritten (global) request id; starts at 1, monotonically increasing.
    pub next_global_id: u64,
    /// Per-client outbound channels, keyed by client id.
    pub clients: HashMap<u64, mpsc::UnboundedSender<Value>>,
    /// In-flight requests forwarded to the child, keyed by the global id
    /// (stored as a string).
    pub pending: HashMap<String, Pending>,
    /// Cached `initialize` response, replayed to late-joining clients.
    pub cached_initialize: Option<Value>,
    /// Clients (id, local request id) parked until `initialize` completes.
    pub init_waiting: Vec<(u64, Value)>,
    /// True while an `initialize` round-trip with the child is in flight.
    pub initializing: bool,
    /// Current lifecycle state of the child server.
    pub server_status: ServerStatus,
    /// Number of child restarts performed so far.
    pub restarts: u64,
    /// Reason recorded by the most recent `reset_state`, if any.
    pub last_reset: Option<String>,
    /// Configured cap on concurrently active clients.
    pub max_active_clients: usize,
    /// Human-readable name of the proxied service.
    pub service_name: String,
    /// Per-request size limit, in bytes (configuration).
    pub max_request_bytes: usize,
    /// Per-request timeout (configuration).
    pub request_timeout: Duration,
    /// Initial delay before respawning a dead child (configuration).
    pub restart_backoff: Duration,
    /// Ceiling for the respawn delay (configuration).
    pub restart_backoff_max: Duration,
    /// Restart limit gating child respawns (configuration).
    pub max_restarts: u64,
    /// Cap on queued client messages (configuration — see module doc above).
    pub queue_depth: usize,
    /// OS pid of the current child process, when one is running.
    pub child_pid: Option<u32>,
}
70
/// Bookkeeping for one request forwarded to the child, stored in
/// `MuxState::pending` under the rewritten (global) request id.
#[derive(Clone, Debug)]
pub struct Pending {
    /// Client that originated the request; the response routes back here.
    pub client_id: u64,
    /// The request id exactly as the client sent it (presumably restored on
    /// the response before delivery — the rewrite happens outside this file).
    pub local_id: Value,
    /// Whether this is an `initialize` request (its response gets cached).
    pub is_initialize: bool,
    /// When the request was forwarded.
    // NOTE(review): presumably checked against `request_timeout`; the
    // enforcement loop is not in this file — confirm at the call site.
    pub started_at: std::time::Instant,
}
78
79impl MuxState {
80    #[allow(clippy::too_many_arguments)]
81    pub fn new(
82        max_active_clients: usize,
83        service_name: String,
84        max_request_bytes: usize,
85        request_timeout: Duration,
86        restart_backoff: Duration,
87        restart_backoff_max: Duration,
88        max_restarts: u64,
89        queue_depth: usize,
90        child_pid: Option<u32>,
91    ) -> Self {
92        Self {
93            next_client_id: 1,
94            next_global_id: 1,
95            clients: HashMap::new(),
96            pending: HashMap::new(),
97            cached_initialize: None,
98            init_waiting: Vec::new(),
99            initializing: false,
100            server_status: ServerStatus::Starting,
101            restarts: 0,
102            last_reset: None,
103            max_active_clients,
104            service_name,
105            max_request_bytes,
106            request_timeout,
107            restart_backoff,
108            restart_backoff_max,
109            max_restarts,
110            queue_depth,
111            child_pid,
112        }
113    }
114
115    pub fn register_client(&mut self, tx: mpsc::UnboundedSender<Value>) -> u64 {
116        let id = self.next_client_id;
117        self.next_client_id += 1;
118        self.clients.insert(id, tx);
119        id
120    }
121
122    pub fn unregister_client(&mut self, client_id: u64) {
123        self.clients.remove(&client_id);
124        self.pending.retain(|_, p| p.client_id != client_id);
125        self.init_waiting.retain(|(cid, _)| *cid != client_id);
126    }
127
128    pub fn next_request_id(&mut self) -> u64 {
129        let id = self.next_global_id;
130        self.next_global_id += 1;
131        id
132    }
133}
134
135pub fn set_id(msg: &mut Value, id: Value) {
136    if let Some(obj) = msg.as_object_mut() {
137        obj.insert("id".to_string(), id);
138    }
139}
140
141pub fn error_response(id: Value, message: &str) -> Value {
142    serde_json::json!({
143        "jsonrpc": "2.0",
144        "id": id,
145        "error": {
146            "code": -32000,
147            "message": message,
148        }
149    })
150}
151
152pub fn snapshot_for_state(st: &MuxState, active_clients: usize) -> StatusSnapshot {
153    StatusSnapshot {
154        service_name: st.service_name.clone(),
155        server_status: st.server_status.clone(),
156        restarts: st.restarts,
157        connected_clients: st.clients.len(),
158        active_clients,
159        max_active_clients: st.max_active_clients,
160        pending_requests: st.pending.len(),
161        cached_initialize: st.cached_initialize.is_some(),
162        initializing: st.initializing,
163        last_reset: st.last_reset.clone(),
164        queue_depth: st.queue_depth,
165        child_pid: st.child_pid,
166        max_request_bytes: st.max_request_bytes,
167        restart_backoff_ms: st.restart_backoff.as_millis() as u64,
168        restart_backoff_max_ms: st.restart_backoff_max.as_millis() as u64,
169        max_restarts: st.max_restarts,
170    }
171}
172
173pub async fn publish_status(
174    state: &Arc<Mutex<MuxState>>,
175    active_clients: &Arc<Semaphore>,
176    status_tx: &watch::Sender<StatusSnapshot>,
177) {
178    let st = state.lock().await;
179    let active = st
180        .max_active_clients
181        .saturating_sub(active_clients.available_permits());
182    let snapshot = snapshot_for_state(&st, active);
183    drop(st);
184    let _ = status_tx.send(snapshot);
185}
186
187pub async fn reset_state(
188    state: &Arc<Mutex<MuxState>>,
189    reason: &str,
190    active_clients: &Arc<Semaphore>,
191    status_tx: &watch::Sender<StatusSnapshot>,
192) {
193    let mut st = state.lock().await;
194    let pending = std::mem::take(&mut st.pending);
195    let waiters = std::mem::take(&mut st.init_waiting);
196    st.cached_initialize = None;
197    st.initializing = false;
198    st.last_reset = Some(reason.to_string());
199    st.queue_depth = 0;
200    st.child_pid = None;
201
202    for (_, p) in pending {
203        if let Some(tx) = st.clients.get(&p.client_id) {
204            tx.send(error_response(p.local_id, reason)).ok();
205        }
206    }
207    for (cid, lid) in waiters {
208        if let Some(tx) = st.clients.get(&cid) {
209            tx.send(error_response(lid, reason)).ok();
210        }
211    }
212    drop(st);
213    publish_status(state, active_clients, status_tx).await;
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// Global request ids must be handed out strictly sequentially.
    #[test]
    fn next_request_id_increments_sequentially() {
        let mut state = MuxState::new(
            5,
            "test-service".into(),
            1_048_576,
            Duration::from_secs(30),
            Duration::from_millis(1_000),
            Duration::from_millis(30_000),
            5,
            0,
            None,
        );

        let ids: Vec<u64> = (0..2).map(|_| state.next_request_id()).collect();

        assert_eq!(ids[1], ids[0] + 1);
    }

    /// Error responses must carry the JSON-RPC 2.0 version marker.
    #[test]
    fn error_response_uses_jsonrpc_2_0() {
        let resp = error_response(Value::from(1), "oops");
        assert_eq!(resp.get("jsonrpc"), Some(&Value::from("2.0")));
    }
}
245}