auto_lsp_server/
request_registry.rs1use super::{main_loop::Task, Session};
20use lsp_server::{Message, Request, RequestId, Response};
21use serde::{de::DeserializeOwned, Serialize};
22use std::{collections::HashMap, panic::RefUnwindSafe, sync::Arc};
23
24type Callback<Db> = Arc<
26 dyn Fn(&Db, serde_json::Value) -> anyhow::Result<serde_json::Value>
27 + Send
28 + Sync
29 + RefUnwindSafe
30 + 'static,
31>;
32
33type SyncMutCallback<Db> =
35 Box<dyn Fn(&mut Session<Db>, serde_json::Value) -> anyhow::Result<serde_json::Value>>;
36
37#[derive(Default)]
45pub struct RequestRegistry<Db: salsa::Database> {
46 handlers: HashMap<String, Callback<Db>>,
47 sync_mut_handlers: HashMap<String, SyncMutCallback<Db>>,
48}
49
50impl<Db: salsa::Database + Clone + Send + RefUnwindSafe> RequestRegistry<Db> {
51 pub fn on<R, F>(&mut self, handler: F) -> &mut Self
52 where
53 R: lsp_types::request::Request,
54 R::Params: DeserializeOwned,
55 R::Result: Serialize,
56 F: Fn(&Db, R::Params) -> anyhow::Result<R::Result> + Send + Sync + RefUnwindSafe + 'static,
57 {
58 let method = R::METHOD.to_string();
59 let callback: Callback<Db> = Arc::new(move |session, params| {
60 let parsed_params: R::Params = serde_json::from_value(params)?;
61 let result = handler(session, parsed_params)?;
62 Ok(serde_json::to_value(result)?)
63 });
64
65 self.handlers.insert(method, callback);
66 self
67 }
68
69 pub fn on_mut<R, F>(&mut self, handler: F) -> &mut Self
75 where
76 R: lsp_types::request::Request,
77 R::Params: DeserializeOwned,
78 R::Result: Serialize,
79 F: Fn(&mut Session<Db>, R::Params) -> anyhow::Result<R::Result> + Send + Sync + 'static,
80 {
81 let method = R::METHOD.to_string();
82 let callback: SyncMutCallback<Db> = Box::new(move |session, params| {
83 let parsed_params: R::Params = serde_json::from_value(params)?;
84 let result = handler(session, parsed_params)?;
85 Ok(serde_json::to_value(result)?)
86 });
87
88 self.sync_mut_handlers.insert(method, callback);
89 self
90 }
91
92 pub(crate) fn get(&self, req: &Request) -> Option<&Callback<Db>> {
93 self.handlers.get(&req.method)
94 }
95
96 pub(crate) fn get_sync_mut(&self, req: &Request) -> Option<&SyncMutCallback<Db>> {
97 self.sync_mut_handlers.get(&req.method)
98 }
99
100 pub(crate) fn exec(session: &Session<Db>, callback: &Callback<Db>, req: Request) {
102 let params = req.params;
103 let id = req.id.clone();
104
105 let snapshot = session.snapshot();
106 let cb = Arc::clone(callback);
107 session.task_pool.spawn(move |sender| {
108 let cb = cb.clone();
109 match snapshot.with_db(|db| cb(db, params)) {
110 Err(e) => {
111 log::warn!("Cancelled request: {e}");
112 }
113 Ok(result) => match result {
114 Ok(result) => sender
115 .send(Task::Response(Response {
116 id,
117 result: Some(result),
118 error: None,
119 }))
120 .unwrap(),
121 Err(e) => {
122 sender
123 .send(Task::Response(Self::response_error(id, e)))
124 .unwrap();
125 }
126 },
127 }
128 });
129 }
130
131 pub(crate) fn exec_sync_mut(
135 session: &mut Session<Db>,
136 callback: &SyncMutCallback<Db>,
137 req: Request,
138 ) -> anyhow::Result<()> {
139 if let Err(e) = callback(session, req.params.clone()) {
140 Self::complete(session, Self::response_error(req.id, e))
141 } else {
142 Ok(())
143 }
144 }
145
146 pub(crate) fn complete(
147 session: &mut Session<Db>,
148 response: lsp_server::Response,
149 ) -> anyhow::Result<()> {
150 let id = response.id.clone();
151 if !session.req_queue.incoming.is_completed(&id) {
152 session.req_queue.incoming.complete(&id);
153 }
154 Ok(session
155 .connection
156 .sender
157 .send(Message::Response(response))?)
158 }
159
160 pub(crate) fn response_error(id: RequestId, error: anyhow::Error) -> lsp_server::Response {
161 Response {
162 id,
163 result: None,
164 error: Some(lsp_server::ResponseError {
165 code: -32803, message: error.to_string(),
167 data: None,
168 }),
169 }
170 }
171
172 pub(crate) fn request_mismatch(id: RequestId, error: anyhow::Error) -> lsp_server::Response {
173 Response {
174 id,
175 result: None,
176 error: Some(lsp_server::ResponseError {
177 code: -32601, message: format!("Method mismatch for request '{error}'"),
179 data: None,
180 }),
181 }
182 }
183}