oli_server/communication/
rpc.rs

1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::io::{BufRead, BufReader, Write};
5use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
6use std::sync::mpsc::{channel, Receiver, Sender};
7use std::sync::{Arc, Mutex, Once};
8
9/// JSON-RPC 2.0 request structure
10#[derive(Debug, Deserialize)]
11struct Request {
12    // jsonrpc field is required by the JSON-RPC 2.0 spec
13    #[allow(dead_code)]
14    jsonrpc: String,
15    id: Option<u64>,
16    method: String,
17    params: serde_json::Value,
18}
19
20/// JSON-RPC 2.0 response structure
21#[derive(Debug, Serialize)]
22struct Response {
23    jsonrpc: String,
24    id: Option<u64>,
25    result: Option<serde_json::Value>,
26    error: Option<RpcError>,
27}
28
29/// JSON-RPC 2.0 error structure
30#[derive(Debug, Serialize)]
31struct RpcError {
32    code: i32,
33    message: String,
34    data: Option<serde_json::Value>,
35}
36
37/// JSON-RPC 2.0 notification structure
38#[derive(Debug, Serialize)]
39struct Notification {
40    jsonrpc: String,
41    method: String,
42    params: serde_json::Value,
43}
44
45/// Method handler type
46type MethodHandler =
47    Box<dyn Fn(serde_json::Value) -> Result<serde_json::Value, anyhow::Error> + Send + Sync>;
48
/// Subscription manager for event-based communication
///
/// Tracks, per event type, the IDs of the subscriptions currently registered
/// for it. IDs are handed out from a single counter, so they are unique
/// across all event types of one manager.
pub struct SubscriptionManager {
    /// event_type -> list of subscription IDs, in subscription order.
    /// An event type with no remaining subscriptions has no entry.
    subscribers: HashMap<String, Vec<u64>>,
    /// Next subscription ID to hand out.
    subscription_counter: AtomicU64,
}

impl Default for SubscriptionManager {
    /// Empty manager whose first issued subscription ID is 1.
    fn default() -> Self {
        Self {
            subscribers: HashMap::new(),
            subscription_counter: AtomicU64::new(1),
        }
    }
}

impl SubscriptionManager {
    /// Create an empty manager (same as [`Default::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Register a new subscription for `event_type` and return its ID.
    pub fn subscribe(&mut self, event_type: &str) -> u64 {
        let sub_id = self.subscription_counter.fetch_add(1, Ordering::SeqCst);
        self.subscribers
            .entry(event_type.to_string())
            .or_default()
            .push(sub_id);
        sub_id
    }

    /// Remove subscription `sub_id` from `event_type`.
    ///
    /// Returns `true` if the subscription existed and was removed, `false`
    /// if the event type is unknown or does not hold that ID.
    pub fn unsubscribe(&mut self, event_type: &str, sub_id: u64) -> bool {
        let Some(subs) = self.subscribers.get_mut(event_type) else {
            return false;
        };
        let Some(idx) = subs.iter().position(|&id| id == sub_id) else {
            return false;
        };
        subs.remove(idx);

        // Drop the key once its last subscription is gone; otherwise every
        // event type ever subscribed to would keep an empty list forever.
        let now_empty = subs.is_empty();
        if now_empty {
            self.subscribers.remove(event_type);
        }
        true
    }

    /// Whether at least one subscription is registered for `event_type`.
    pub fn has_subscribers(&self, event_type: &str) -> bool {
        self.subscribers
            .get(event_type)
            .is_some_and(|subs| !subs.is_empty())
    }

    /// IDs subscribed to `event_type` (a copy); empty if there are none.
    pub fn get_subscribers(&self, event_type: &str) -> Vec<u64> {
        self.subscribers
            .get(event_type)
            .cloned()
            .unwrap_or_default()
    }
}
102
103/// JSON-RPC server over stdio
104pub struct RpcServer {
105    methods: Arc<Mutex<HashMap<String, MethodHandler>>>,
106    event_sender: Sender<(String, serde_json::Value)>,
107    // Replace the standard mpsc::Receiver with an Arc<Mutex<>> wrapper to make it thread-safe
108    event_receiver: Arc<Mutex<Receiver<(String, serde_json::Value)>>>,
109    is_running: Arc<AtomicBool>,
110    // Add subscription manager for real-time event streaming
111    subscription_manager: Arc<Mutex<SubscriptionManager>>,
112}
113
114// Global RPC server instance
115static mut GLOBAL_RPC_SERVER: Option<Arc<RpcServer>> = None;
116static INIT: Once = Once::new();
117
118// Clone implementation for RpcServer
119impl Clone for RpcServer {
120    fn clone(&self) -> Self {
121        // Create a new channel for the cloned instance
122        let (event_sender, event_receiver) = channel();
123
124        Self {
125            methods: self.methods.clone(),
126            event_sender,
127            event_receiver: Arc::new(Mutex::new(event_receiver)),
128            is_running: self.is_running.clone(),
129            subscription_manager: self.subscription_manager.clone(),
130        }
131    }
132}
133
134/// Get global RPC server instance
135#[allow(static_mut_refs)]
136pub fn get_global_rpc_server() -> Option<Arc<RpcServer>> {
137    unsafe { GLOBAL_RPC_SERVER.clone() }
138}
139
140/// Set global RPC server instance
141fn set_global_rpc_server(server: Arc<RpcServer>) {
142    INIT.call_once(|| unsafe {
143        GLOBAL_RPC_SERVER = Some(server);
144    });
145}
146
147impl RpcServer {
148    /// Create a new RPC server
149    pub fn new() -> Self {
150        let (event_sender, event_receiver) = channel();
151        let server = Self {
152            methods: Arc::new(Mutex::new(HashMap::new())),
153            event_sender,
154            event_receiver: Arc::new(Mutex::new(event_receiver)),
155            is_running: Arc::new(AtomicBool::new(false)),
156            subscription_manager: Arc::new(Mutex::new(SubscriptionManager::new())),
157        };
158
159        // Create a clone for global registration
160        let server_clone = server.clone();
161
162        // Register as global RPC server
163        #[allow(clippy::arc_with_non_send_sync)]
164        let server_arc = Arc::new(server_clone);
165        set_global_rpc_server(server_arc);
166
167        // Return the original server
168        server
169    }
170
171    /// Register a method handler
172    pub fn register_method<F>(&mut self, name: &str, handler: F)
173    where
174        F: Fn(serde_json::Value) -> Result<serde_json::Value, anyhow::Error>
175            + Send
176            + Sync
177            + 'static,
178    {
179        self.methods
180            .lock()
181            .unwrap()
182            .insert(name.to_string(), Box::new(handler));
183    }
184
185    /// Get event sender for emitting events
186    pub fn event_sender(&self) -> Sender<(String, serde_json::Value)> {
187        self.event_sender.clone()
188    }
189
190    /// Send a notification event - will send to all subscribers of this event type
191    pub fn send_notification(&self, method: &str, params: serde_json::Value) -> Result<()> {
192        // First, check if anyone is subscribed to this event
193        let has_subscribers = {
194            let manager = self.subscription_manager.lock().unwrap();
195            // No need to log every notification
196            manager.has_subscribers(method)
197        };
198
199        // Always send through the event channel for internal event processing
200        self.event_sender
201            .send((method.to_string(), params.clone()))?;
202
203        // IMPORTANT: For now, always send notifications directly to ensure delivery
204        // This is a temporary fix to ensure notifications reach the UI
205        let always_send = true;
206
207        // If this is not a subscribed event or there are no subscribers, we're done
208        if !has_subscribers && !always_send {
209            return Ok(());
210        }
211
212        // For events with subscribers, we'll immediately send a notification through stdout
213        let notification = Notification {
214            jsonrpc: "2.0".to_string(),
215            method: method.to_string(),
216            params,
217        };
218
219        // Send directly to stdout to ensure immediate delivery
220        let stdout = std::io::stdout();
221        let mut stdout = stdout.lock();
222        serde_json::to_writer(&mut stdout, &notification)?;
223        stdout.write_all(b"\n")?;
224        stdout.flush()?;
225
226        Ok(())
227    }
228
229    /// Register subscription method handlers
230    pub fn register_subscription_handlers(&mut self) {
231        // Handle subscribe requests
232        let sub_manager = self.subscription_manager.clone();
233        self.register_method("subscribe", move |params| {
234            let event_type = params
235                .get("event_type")
236                .and_then(|v| v.as_str())
237                .ok_or_else(|| anyhow::anyhow!("Missing event_type parameter"))?;
238
239            let mut manager = sub_manager.lock().unwrap();
240            let sub_id = manager.subscribe(event_type);
241
242            Ok(serde_json::json!({ "subscription_id": sub_id }))
243        });
244
245        // Handle unsubscribe requests
246        let sub_manager = self.subscription_manager.clone();
247        self.register_method("unsubscribe", move |params| {
248            let event_type = params
249                .get("event_type")
250                .and_then(|v| v.as_str())
251                .ok_or_else(|| anyhow::anyhow!("Missing event_type parameter"))?;
252
253            let sub_id = params
254                .get("subscription_id")
255                .and_then(|v| v.as_u64())
256                .ok_or_else(|| anyhow::anyhow!("Missing subscription_id parameter"))?;
257
258            let mut manager = sub_manager.lock().unwrap();
259            let success = manager.unsubscribe(event_type, sub_id);
260
261            Ok(serde_json::json!({ "success": success }))
262        });
263    }
264
265    /// Check if the server is running
266    pub fn is_running(&self) -> bool {
267        self.is_running.load(Ordering::SeqCst)
268    }
269
270    /// Run the RPC server, processing stdin and writing to stdout
271    pub fn run(&self) -> Result<()> {
272        // Set running state
273        self.is_running.store(true, Ordering::SeqCst);
274
275        let stdin = std::io::stdin();
276        let stdout = std::io::stdout();
277        let mut stdout = stdout.lock();
278
279        let reader = BufReader::new(stdin.lock());
280        let methods = self.methods.clone();
281
282        // Process each line of input as a JSON-RPC request
283        for line in reader.lines() {
284            let line = line?;
285            if line.trim().is_empty() {
286                continue;
287            }
288
289            // Parse the request
290            let request: Request = match serde_json::from_str(&line) {
291                Ok(request) => request,
292                Err(e) => {
293                    // Send parse error
294                    let response = Response {
295                        jsonrpc: "2.0".to_string(),
296                        id: None,
297                        result: None,
298                        error: Some(RpcError {
299                            code: -32700,
300                            message: "Parse error".to_string(),
301                            data: Some(serde_json::Value::String(e.to_string())),
302                        }),
303                    };
304                    serde_json::to_writer(&mut stdout, &response)?;
305                    stdout.write_all(b"\n")?;
306                    stdout.flush()?;
307                    continue;
308                }
309            };
310
311            // Check for method
312            let methods = methods.lock().unwrap();
313            let handler = match methods.get(&request.method) {
314                Some(handler) => handler,
315                None => {
316                    // Send method not found error
317                    let response = Response {
318                        jsonrpc: "2.0".to_string(),
319                        id: request.id,
320                        result: None,
321                        error: Some(RpcError {
322                            code: -32601,
323                            message: "Method not found".to_string(),
324                            data: None,
325                        }),
326                    };
327                    serde_json::to_writer(&mut stdout, &response)?;
328                    stdout.write_all(b"\n")?;
329                    stdout.flush()?;
330                    continue;
331                }
332            };
333
334            // Execute the method
335            match handler(request.params.clone()) {
336                Ok(result) => {
337                    // Send success response
338                    let response = Response {
339                        jsonrpc: "2.0".to_string(),
340                        id: request.id,
341                        result: Some(result),
342                        error: None,
343                    };
344                    serde_json::to_writer(&mut stdout, &response)?;
345                    stdout.write_all(b"\n")?;
346                    stdout.flush()?;
347                }
348                Err(e) => {
349                    // Send error response
350                    let response = Response {
351                        jsonrpc: "2.0".to_string(),
352                        id: request.id,
353                        result: None,
354                        error: Some(RpcError {
355                            code: -32603,
356                            message: "Internal error".to_string(),
357                            data: Some(serde_json::Value::String(e.to_string())),
358                        }),
359                    };
360                    serde_json::to_writer(&mut stdout, &response)?;
361                    stdout.write_all(b"\n")?;
362                    stdout.flush()?;
363                }
364            };
365
366            // Check for any events to send
367            if let Ok(receiver) = self.event_receiver.try_lock() {
368                while let Ok((method, params)) = receiver.try_recv() {
369                    let notification = Notification {
370                        jsonrpc: "2.0".to_string(),
371                        method,
372                        params,
373                    };
374                    serde_json::to_writer(&mut stdout, &notification)?;
375                    stdout.write_all(b"\n")?;
376                    stdout.flush()?;
377                }
378            }
379        }
380
381        // Set running state to false
382        self.is_running.store(false, Ordering::SeqCst);
383
384        Ok(())
385    }
386}
387
388impl Default for RpcServer {
389    fn default() -> Self {
390        Self::new()
391    }
392}