1use crate::error::{Error, Result};
2use jsonrpsee::types::error::ErrorObject as JsonRpcError;
3use serde::{Deserialize, Serialize};
4use std::os::unix::io::OwnedFd;
5
/// Protocol version string stamped into every message built by the constructors in this file.
pub const JSONRPC_VERSION: &str = "2.0";

/// JSON key under which a serialized message advertises how many file descriptors accompany it.
pub const FDS_KEY: &str = "fds";
10
11pub fn get_fd_count(value: &serde_json::Value) -> usize {
14 value
15 .get(FDS_KEY)
16 .and_then(|v| v.as_u64())
17 .map(|n| n as usize)
18 .unwrap_or(0)
19}
20
/// Serde `skip_serializing_if` predicate for the `fds` fields.
///
/// Yields `true` (skip the field) unless a strictly positive count is stored. The parameter is
/// `&Option<usize>` because serde invokes the predicate with a reference to the field.
fn skip_if_zero_or_none(fds: &Option<usize>) -> bool {
    !matches!(fds, Some(count) if *count > 0)
}
25
/// A JSON-RPC request: a method call that carries an `id`.
///
/// `JsonRpcMessage::from_json_value` classifies any object containing both `method` and `id`
/// as this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcRequest {
    /// Protocol version string; `JsonRpcRequest::new` sets it to `JSONRPC_VERSION`.
    pub jsonrpc: String,
    /// Name of the method being called.
    pub method: String,
    /// Optional call parameters; the key is omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<serde_json::Value>,
    /// Request identifier, held as an arbitrary JSON value.
    pub id: serde_json::Value,
    /// Number of file descriptors that accompany this message.
    /// Omitted from serialized output when `None` or `Some(0)`.
    #[serde(skip_serializing_if = "skip_if_zero_or_none")]
    pub fds: Option<usize>,
}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct JsonRpcResponse {
46 pub jsonrpc: String,
48 #[serde(skip_serializing_if = "Option::is_none")]
50 pub result: Option<serde_json::Value>,
51 #[serde(skip_serializing_if = "Option::is_none")]
53 pub error: Option<JsonRpcError<'static>>,
54 pub id: serde_json::Value,
56 #[serde(skip_serializing_if = "skip_if_zero_or_none")]
58 pub fds: Option<usize>,
59}
60
/// A JSON-RPC notification: a method call without an `id`.
///
/// `JsonRpcMessage::from_json_value` classifies an object as this type when it has `method`
/// but neither `id` nor `result`/`error`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcNotification {
    /// Protocol version string; `JsonRpcNotification::new` sets it to `JSONRPC_VERSION`.
    pub jsonrpc: String,
    /// Name of the method being notified.
    pub method: String,
    /// Optional parameters; the key is omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<serde_json::Value>,
    /// Number of file descriptors that accompany this message.
    /// Omitted from serialized output when `None` or `Some(0)`.
    #[serde(skip_serializing_if = "skip_if_zero_or_none")]
    pub fds: Option<usize>,
}
75
76#[derive(Debug, Clone)]
78pub enum JsonRpcMessage {
79 Request(JsonRpcRequest),
81 Response(JsonRpcResponse),
83 Notification(JsonRpcNotification),
85}
86
87impl JsonRpcRequest {
88 pub fn new(method: String, params: Option<serde_json::Value>, id: serde_json::Value) -> Self {
90 Self {
91 jsonrpc: JSONRPC_VERSION.to_string(),
92 method,
93 params,
94 id,
95 fds: None,
96 }
97 }
98}
99
100impl JsonRpcResponse {
101 pub fn success(result: serde_json::Value, id: serde_json::Value) -> Self {
103 Self {
104 jsonrpc: JSONRPC_VERSION.to_string(),
105 result: Some(result),
106 error: None,
107 id,
108 fds: None,
109 }
110 }
111
112 pub fn error(error: JsonRpcError<'static>, id: serde_json::Value) -> Self {
114 Self {
115 jsonrpc: JSONRPC_VERSION.to_string(),
116 result: None,
117 error: Some(error),
118 id,
119 fds: None,
120 }
121 }
122}
123
124impl JsonRpcNotification {
125 pub fn new(method: String, params: Option<serde_json::Value>) -> Self {
127 Self {
128 jsonrpc: JSONRPC_VERSION.to_string(),
129 method,
130 params,
131 fds: None,
132 }
133 }
134}
135
136impl JsonRpcMessage {
137 pub fn to_json_value(&self) -> Result<serde_json::Value> {
139 match self {
140 JsonRpcMessage::Request(req) => Ok(serde_json::to_value(req)?),
141 JsonRpcMessage::Response(res) => Ok(serde_json::to_value(res)?),
142 JsonRpcMessage::Notification(notif) => Ok(serde_json::to_value(notif)?),
143 }
144 }
145
146 pub fn from_json_value(value: serde_json::Value) -> Result<Self> {
148 if let serde_json::Value::Object(obj) = &value {
149 if obj.contains_key("method") && obj.contains_key("id") {
150 let request: JsonRpcRequest = serde_json::from_value(value)?;
151 Ok(JsonRpcMessage::Request(request))
152 } else if obj.contains_key("result") || obj.contains_key("error") {
153 let response: JsonRpcResponse = serde_json::from_value(value)?;
154 Ok(JsonRpcMessage::Response(response))
155 } else if obj.contains_key("method") {
156 let notification: JsonRpcNotification = serde_json::from_value(value)?;
157 Ok(JsonRpcMessage::Notification(notification))
158 } else {
159 Err(Error::InvalidMessage("Invalid JSON-RPC message".into()))
160 }
161 } else {
162 Err(Error::InvalidMessage("Expected JSON object".into()))
163 }
164 }
165}
166
/// A JSON-RPC message paired with the file descriptors that travel with it.
///
/// `MessageWithFds::serialize` writes `file_descriptors.len()` into the message's `fds` field.
/// `MessageWithFds::from_json_with_fds` rejects input whose advertised count differs from the
/// descriptors supplied. Not `Clone`, since `OwnedFd` is not `Clone`.
#[derive(Debug)]
pub struct MessageWithFds {
    /// The JSON-RPC message itself.
    pub message: JsonRpcMessage,
    /// Owned descriptors; each is closed when dropped unless moved out first.
    pub file_descriptors: Vec<OwnedFd>,
}
175
176impl JsonRpcMessage {
177 pub fn set_fds(&mut self, count: usize) {
179 let fds = if count > 0 { Some(count) } else { None };
180 match self {
181 JsonRpcMessage::Request(req) => req.fds = fds,
182 JsonRpcMessage::Response(res) => res.fds = fds,
183 JsonRpcMessage::Notification(notif) => notif.fds = fds,
184 }
185 }
186
187 pub fn get_fds(&self) -> usize {
189 match self {
190 JsonRpcMessage::Request(req) => req.fds.unwrap_or(0),
191 JsonRpcMessage::Response(res) => res.fds.unwrap_or(0),
192 JsonRpcMessage::Notification(notif) => notif.fds.unwrap_or(0),
193 }
194 }
195}
196
197impl MessageWithFds {
198 pub fn new(message: JsonRpcMessage, file_descriptors: Vec<OwnedFd>) -> Self {
200 Self {
201 message,
202 file_descriptors,
203 }
204 }
205
206 pub fn serialize(&self) -> Result<String> {
208 self.serialize_impl(false)
209 }
210
211 pub fn serialize_pretty(&self) -> Result<String> {
213 self.serialize_impl(true)
214 }
215
216 fn serialize_impl(&self, pretty: bool) -> Result<String> {
217 let mut message = self.message.clone();
219 message.set_fds(self.file_descriptors.len());
220
221 let message_json = message.to_json_value()?;
222 let json_str = if pretty {
223 serde_json::to_string_pretty(&message_json)?
224 } else {
225 serde_json::to_string(&message_json)?
226 };
227 Ok(json_str)
228 }
229
230 pub fn from_json_with_fds(json_str: &str, fds: Vec<OwnedFd>) -> Result<Self> {
233 let message_json: serde_json::Value = serde_json::from_str(json_str)?;
234 let expected_count = get_fd_count(&message_json);
235
236 if expected_count != fds.len() {
237 return Err(Error::MismatchedCount {
238 expected: expected_count,
239 found: fds.len(),
240 });
241 }
242
243 let message = JsonRpcMessage::from_json_value(message_json)?;
244 Ok(Self::new(message, fds))
245 }
246}
247
248pub const FILE_DESCRIPTOR_ERROR_CODE: i32 = -32050;
250
251pub fn file_descriptor_error() -> JsonRpcError<'static> {
253 JsonRpcError::owned(
254 FILE_DESCRIPTOR_ERROR_CODE,
255 "File Descriptor Error",
256 None::<serde_json::Value>,
257 )
258}
259
/// Kani model-checking harnesses for the `fds` bookkeeping helpers.
#[cfg(kani)]
mod verification {
    use super::*;

    /// Builds a notification with empty strings and the given `fds` field.
    fn notification_with_fds(fds: Option<usize>) -> JsonRpcMessage {
        JsonRpcMessage::Notification(JsonRpcNotification {
            jsonrpc: String::new(),
            method: String::new(),
            params: None,
            fds,
        })
    }

    /// An unset count must be skipped during serialization.
    #[kani::proof]
    fn check_skip_none() {
        let result = skip_if_zero_or_none(&None);
        kani::assert(result, "None should be skipped");
    }

    /// An explicit zero count must also be skipped during serialization.
    #[kani::proof]
    fn check_skip_zero() {
        let result = skip_if_zero_or_none(&Some(0));
        kani::assert(result, "Some(0) should be skipped");
    }

    /// Every positive count must be serialized.
    #[kani::proof]
    fn check_skip_nonzero() {
        let n: usize = kani::any();
        kani::assume(n > 0);
        let result = skip_if_zero_or_none(&Some(n));
        kani::assert(!result, "Some(n > 0) should not be skipped");
    }

    /// `get_fds` reports an unset count as 0.
    #[kani::proof]
    fn check_get_fds_none() {
        let msg = notification_with_fds(None);
        let result = msg.get_fds();
        kani::assert(result == 0, "None fds should return 0");
    }

    /// `get_fds` returns any stored count unchanged.
    #[kani::proof]
    fn check_get_fds_some() {
        let n: usize = kani::any();
        let msg = notification_with_fds(Some(n));
        let result = msg.get_fds();
        kani::assert(result == n, "get_fds should return the fds value");
    }

    /// `set_fds` followed by `get_fds` round-trips every count, including 0.
    #[kani::proof]
    fn check_set_then_get_fds() {
        let n: usize = kani::any();
        let mut msg = notification_with_fds(None);
        msg.set_fds(n);
        kani::assert(msg.get_fds() == n, "set_fds/get_fds should round-trip");
    }

    /// `set_fds(0)` clears any previous count to `None`, so the field is skipped on output.
    #[kani::proof]
    fn check_set_fds_zero_clears() {
        let mut msg = notification_with_fds(Some(kani::any()));
        msg.set_fds(0);
        let JsonRpcMessage::Notification(notif) = msg else {
            unreachable!()
        };
        kani::assert(notif.fds.is_none(), "set_fds(0) should store None");
    }
}