oli_server/communication/
rpc.rs1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::io::{BufRead, BufReader, Write};
5use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
6use std::sync::mpsc::{channel, Receiver, Sender};
7use std::sync::{Arc, Mutex, Once};
8
9#[derive(Debug, Deserialize)]
11struct Request {
12 #[allow(dead_code)]
14 jsonrpc: String,
15 id: Option<u64>,
16 method: String,
17 params: serde_json::Value,
18}
19
20#[derive(Debug, Serialize)]
22struct Response {
23 jsonrpc: String,
24 id: Option<u64>,
25 result: Option<serde_json::Value>,
26 error: Option<RpcError>,
27}
28
29#[derive(Debug, Serialize)]
31struct RpcError {
32 code: i32,
33 message: String,
34 data: Option<serde_json::Value>,
35}
36
37#[derive(Debug, Serialize)]
39struct Notification {
40 jsonrpc: String,
41 method: String,
42 params: serde_json::Value,
43}
44
45type MethodHandler =
47 Box<dyn Fn(serde_json::Value) -> Result<serde_json::Value, anyhow::Error> + Send + Sync>;
48
/// Tracks which subscription ids are registered for each event type.
pub struct SubscriptionManager {
    /// Active subscription ids, keyed by event type. An event type with no remaining
    /// subscribers has no entry at all.
    subscribers: HashMap<String, Vec<u64>>,
    /// Source of the next subscription id; starts at 1.
    subscription_counter: AtomicU64,
}

impl Default for SubscriptionManager {
    /// Creates a manager with no subscriptions whose first issued id will be `1`.
    fn default() -> Self {
        Self {
            subscribers: HashMap::new(),
            subscription_counter: AtomicU64::new(1),
        }
    }
}

impl SubscriptionManager {
    /// Creates an empty manager; identical to [`SubscriptionManager::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers a new subscription for `event_type` and returns its id.
    ///
    /// Ids are handed out from a single counter, so they are unique across all event
    /// types of this manager.
    pub fn subscribe(&mut self, event_type: &str) -> u64 {
        let sub_id = self.subscription_counter.fetch_add(1, Ordering::SeqCst);
        self.subscribers
            .entry(event_type.to_string())
            .or_default()
            .push(sub_id);
        sub_id
    }

    /// Removes subscription `sub_id` from `event_type`.
    ///
    /// Returns `true` when the subscription existed and was removed, `false` when
    /// either the event type or the id is unknown.
    pub fn unsubscribe(&mut self, event_type: &str, sub_id: u64) -> bool {
        let Some(subs) = self.subscribers.get_mut(event_type) else {
            return false;
        };
        let Some(idx) = subs.iter().position(|&id| id == sub_id) else {
            return false;
        };
        subs.remove(idx);

        // Drop the key once its last subscriber is gone. Event types are arbitrary
        // caller-supplied strings, so keeping empty entries would let the map grow
        // without bound.
        if subs.is_empty() {
            self.subscribers.remove(event_type);
        }
        true
    }

    /// Reports whether at least one subscription is registered for `event_type`.
    pub fn has_subscribers(&self, event_type: &str) -> bool {
        self.subscribers
            .get(event_type)
            .is_some_and(|subs| !subs.is_empty())
    }

    /// Returns a copy of the subscription ids registered for `event_type`, in
    /// registration order; empty when there are none.
    pub fn get_subscribers(&self, event_type: &str) -> Vec<u64> {
        self.subscribers
            .get(event_type)
            .cloned()
            .unwrap_or_default()
    }
}
102
103pub struct RpcServer {
105 methods: Arc<Mutex<HashMap<String, MethodHandler>>>,
106 event_sender: Sender<(String, serde_json::Value)>,
107 event_receiver: Arc<Mutex<Receiver<(String, serde_json::Value)>>>,
109 is_running: Arc<AtomicBool>,
110 subscription_manager: Arc<Mutex<SubscriptionManager>>,
112}
113
114static mut GLOBAL_RPC_SERVER: Option<Arc<RpcServer>> = None;
116static INIT: Once = Once::new();
117
118impl Clone for RpcServer {
120 fn clone(&self) -> Self {
121 let (event_sender, event_receiver) = channel();
123
124 Self {
125 methods: self.methods.clone(),
126 event_sender,
127 event_receiver: Arc::new(Mutex::new(event_receiver)),
128 is_running: self.is_running.clone(),
129 subscription_manager: self.subscription_manager.clone(),
130 }
131 }
132}
133
134#[allow(static_mut_refs)]
136pub fn get_global_rpc_server() -> Option<Arc<RpcServer>> {
137 unsafe { GLOBAL_RPC_SERVER.clone() }
138}
139
140fn set_global_rpc_server(server: Arc<RpcServer>) {
142 INIT.call_once(|| unsafe {
143 GLOBAL_RPC_SERVER = Some(server);
144 });
145}
146
147impl RpcServer {
148 pub fn new() -> Self {
150 let (event_sender, event_receiver) = channel();
151 let server = Self {
152 methods: Arc::new(Mutex::new(HashMap::new())),
153 event_sender,
154 event_receiver: Arc::new(Mutex::new(event_receiver)),
155 is_running: Arc::new(AtomicBool::new(false)),
156 subscription_manager: Arc::new(Mutex::new(SubscriptionManager::new())),
157 };
158
159 let server_clone = server.clone();
161
162 #[allow(clippy::arc_with_non_send_sync)]
164 let server_arc = Arc::new(server_clone);
165 set_global_rpc_server(server_arc);
166
167 server
169 }
170
171 pub fn register_method<F>(&mut self, name: &str, handler: F)
173 where
174 F: Fn(serde_json::Value) -> Result<serde_json::Value, anyhow::Error>
175 + Send
176 + Sync
177 + 'static,
178 {
179 self.methods
180 .lock()
181 .unwrap()
182 .insert(name.to_string(), Box::new(handler));
183 }
184
185 pub fn event_sender(&self) -> Sender<(String, serde_json::Value)> {
187 self.event_sender.clone()
188 }
189
190 pub fn send_notification(&self, method: &str, params: serde_json::Value) -> Result<()> {
192 let has_subscribers = {
194 let manager = self.subscription_manager.lock().unwrap();
195 manager.has_subscribers(method)
197 };
198
199 self.event_sender
201 .send((method.to_string(), params.clone()))?;
202
203 let always_send = true;
206
207 if !has_subscribers && !always_send {
209 return Ok(());
210 }
211
212 let notification = Notification {
214 jsonrpc: "2.0".to_string(),
215 method: method.to_string(),
216 params,
217 };
218
219 let stdout = std::io::stdout();
221 let mut stdout = stdout.lock();
222 serde_json::to_writer(&mut stdout, ¬ification)?;
223 stdout.write_all(b"\n")?;
224 stdout.flush()?;
225
226 Ok(())
227 }
228
229 pub fn register_subscription_handlers(&mut self) {
231 let sub_manager = self.subscription_manager.clone();
233 self.register_method("subscribe", move |params| {
234 let event_type = params
235 .get("event_type")
236 .and_then(|v| v.as_str())
237 .ok_or_else(|| anyhow::anyhow!("Missing event_type parameter"))?;
238
239 let mut manager = sub_manager.lock().unwrap();
240 let sub_id = manager.subscribe(event_type);
241
242 Ok(serde_json::json!({ "subscription_id": sub_id }))
243 });
244
245 let sub_manager = self.subscription_manager.clone();
247 self.register_method("unsubscribe", move |params| {
248 let event_type = params
249 .get("event_type")
250 .and_then(|v| v.as_str())
251 .ok_or_else(|| anyhow::anyhow!("Missing event_type parameter"))?;
252
253 let sub_id = params
254 .get("subscription_id")
255 .and_then(|v| v.as_u64())
256 .ok_or_else(|| anyhow::anyhow!("Missing subscription_id parameter"))?;
257
258 let mut manager = sub_manager.lock().unwrap();
259 let success = manager.unsubscribe(event_type, sub_id);
260
261 Ok(serde_json::json!({ "success": success }))
262 });
263 }
264
265 pub fn is_running(&self) -> bool {
267 self.is_running.load(Ordering::SeqCst)
268 }
269
270 pub fn run(&self) -> Result<()> {
272 self.is_running.store(true, Ordering::SeqCst);
274
275 let stdin = std::io::stdin();
276 let stdout = std::io::stdout();
277 let mut stdout = stdout.lock();
278
279 let reader = BufReader::new(stdin.lock());
280 let methods = self.methods.clone();
281
282 for line in reader.lines() {
284 let line = line?;
285 if line.trim().is_empty() {
286 continue;
287 }
288
289 let request: Request = match serde_json::from_str(&line) {
291 Ok(request) => request,
292 Err(e) => {
293 let response = Response {
295 jsonrpc: "2.0".to_string(),
296 id: None,
297 result: None,
298 error: Some(RpcError {
299 code: -32700,
300 message: "Parse error".to_string(),
301 data: Some(serde_json::Value::String(e.to_string())),
302 }),
303 };
304 serde_json::to_writer(&mut stdout, &response)?;
305 stdout.write_all(b"\n")?;
306 stdout.flush()?;
307 continue;
308 }
309 };
310
311 let methods = methods.lock().unwrap();
313 let handler = match methods.get(&request.method) {
314 Some(handler) => handler,
315 None => {
316 let response = Response {
318 jsonrpc: "2.0".to_string(),
319 id: request.id,
320 result: None,
321 error: Some(RpcError {
322 code: -32601,
323 message: "Method not found".to_string(),
324 data: None,
325 }),
326 };
327 serde_json::to_writer(&mut stdout, &response)?;
328 stdout.write_all(b"\n")?;
329 stdout.flush()?;
330 continue;
331 }
332 };
333
334 match handler(request.params.clone()) {
336 Ok(result) => {
337 let response = Response {
339 jsonrpc: "2.0".to_string(),
340 id: request.id,
341 result: Some(result),
342 error: None,
343 };
344 serde_json::to_writer(&mut stdout, &response)?;
345 stdout.write_all(b"\n")?;
346 stdout.flush()?;
347 }
348 Err(e) => {
349 let response = Response {
351 jsonrpc: "2.0".to_string(),
352 id: request.id,
353 result: None,
354 error: Some(RpcError {
355 code: -32603,
356 message: "Internal error".to_string(),
357 data: Some(serde_json::Value::String(e.to_string())),
358 }),
359 };
360 serde_json::to_writer(&mut stdout, &response)?;
361 stdout.write_all(b"\n")?;
362 stdout.flush()?;
363 }
364 };
365
366 if let Ok(receiver) = self.event_receiver.try_lock() {
368 while let Ok((method, params)) = receiver.try_recv() {
369 let notification = Notification {
370 jsonrpc: "2.0".to_string(),
371 method,
372 params,
373 };
374 serde_json::to_writer(&mut stdout, ¬ification)?;
375 stdout.write_all(b"\n")?;
376 stdout.flush()?;
377 }
378 }
379 }
380
381 self.is_running.store(false, Ordering::SeqCst);
383
384 Ok(())
385 }
386}
387
388impl Default for RpcServer {
389 fn default() -> Self {
390 Self::new()
391 }
392}