// arbor_server/protocol.rs

//! JSON-RPC protocol types.
//!
//! Defines the message format for the Arbor Protocol, which is based on
//! JSON-RPC 2.0 with some custom extensions.
5
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8
9/// A JSON-RPC request.
10#[derive(Debug, Deserialize)]
11pub struct Request {
12    /// JSON-RPC version (always "2.0").
13    pub jsonrpc: String,
14
15    /// Request ID for matching responses.
16    pub id: Option<Value>,
17
18    /// Method name to invoke.
19    pub method: String,
20
21    /// Method parameters.
22    #[serde(default)]
23    pub params: Value,
24}
25
26/// A JSON-RPC response.
27#[derive(Debug, Serialize)]
28pub struct Response {
29    /// JSON-RPC version.
30    pub jsonrpc: &'static str,
31
32    /// Request ID this is responding to.
33    #[serde(skip_serializing_if = "Option::is_none")]
34    pub id: Option<Value>,
35
36    /// Result on success.
37    #[serde(skip_serializing_if = "Option::is_none")]
38    pub result: Option<Value>,
39
40    /// Error on failure.
41    #[serde(skip_serializing_if = "Option::is_none")]
42    pub error: Option<RpcError>,
43}
44
45impl Response {
46    /// Creates a success response.
47    pub fn success(id: Option<Value>, result: impl Serialize) -> Self {
48        Self {
49            jsonrpc: "2.0",
50            id,
51            result: Some(serde_json::to_value(result).unwrap_or(Value::Null)),
52            error: None,
53        }
54    }
55
56    /// Creates an error response.
57    pub fn error(id: Option<Value>, code: i32, message: impl Into<String>) -> Self {
58        Self {
59            jsonrpc: "2.0",
60            id,
61            result: None,
62            error: Some(RpcError {
63                code,
64                message: message.into(),
65                data: None,
66            }),
67        }
68    }
69
70    /// Predefined error: Parse error.
71    pub fn parse_error() -> Self {
72        Self::error(None, -32700, "Parse error")
73    }
74
75    /// Predefined error: Invalid request.
76    pub fn invalid_request(id: Option<Value>) -> Self {
77        Self::error(id, -32600, "Invalid request")
78    }
79
80    /// Predefined error: Method not found.
81    pub fn method_not_found(id: Option<Value>, method: &str) -> Self {
82        Self::error(id, -32601, format!("Method not found: {}", method))
83    }
84
85    /// Predefined error: Invalid params.
86    pub fn invalid_params(id: Option<Value>, message: impl Into<String>) -> Self {
87        Self::error(id, -32602, message)
88    }
89
90    /// Predefined error: Internal error.
91    pub fn internal_error(id: Option<Value>, message: impl Into<String>) -> Self {
92        Self::error(id, -32603, message)
93    }
94}
95
96/// A JSON-RPC error.
97#[derive(Debug, Serialize)]
98pub struct RpcError {
99    /// Error code.
100    pub code: i32,
101
102    /// Error message.
103    pub message: String,
104
105    /// Optional additional data.
106    #[serde(skip_serializing_if = "Option::is_none")]
107    pub data: Option<Value>,
108}
109
110/// Params for the discover method.
111#[derive(Debug, Deserialize)]
112pub struct DiscoverParams {
113    pub query: String,
114    #[serde(default = "default_limit")]
115    pub limit: usize,
116}
117
118/// Params for the impact method.
119#[derive(Debug, Deserialize)]
120pub struct ImpactParams {
121    pub node: String,
122    #[serde(default = "default_depth")]
123    pub depth: usize,
124}
125
126/// Params for the context method.
127#[derive(Debug, Deserialize)]
128pub struct ContextParams {
129    pub task: String,
130    #[serde(default = "default_max_tokens", rename = "maxTokens")]
131    pub max_tokens: usize,
132    #[serde(default, rename = "includeSource")]
133    pub _include_source: bool,
134}
135
136/// Params for the search method.
137#[derive(Debug, Deserialize)]
138pub struct SearchParams {
139    pub query: String,
140    pub kind: Option<String>,
141    #[serde(default = "default_limit")]
142    pub limit: usize,
143}
144
145/// Params for node.get method.
146#[derive(Debug, Deserialize)]
147pub struct NodeGetParams {
148    pub id: String,
149}
150
/// Serde default for `limit` fields when the client omits them.
fn default_limit() -> usize {
    const DEFAULT_LIMIT: usize = 10;
    DEFAULT_LIMIT
}
154
/// Serde default for `ImpactParams::depth` when the client omits it.
fn default_depth() -> usize {
    const DEFAULT_DEPTH: usize = 3;
    DEFAULT_DEPTH
}
158
/// Serde default for `ContextParams::max_tokens` when `maxTokens` is omitted.
fn default_max_tokens() -> usize {
    const DEFAULT_MAX_TOKENS: usize = 8000;
    DEFAULT_MAX_TOKENS
}
162
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // ---- Response constructors -------------------------------------------

    #[test]
    fn test_response_success() {
        let response = Response::success(Some(json!(1)), json!({"ok": true}));
        assert_eq!(response.jsonrpc, "2.0");
        assert_eq!(response.id, Some(json!(1)));
        assert!(response.error.is_none());
        assert!(response.result.is_some());
    }

    #[test]
    fn test_response_error() {
        let response = Response::error(Some(json!(2)), -32600, "Bad request");
        assert!(response.result.is_none());
        let RpcError { code, message, .. } =
            response.error.expect("error() must populate `error`");
        assert_eq!(code, -32600);
        assert_eq!(message, "Bad request");
    }

    #[test]
    fn test_response_parse_error() {
        let Response { id, error, .. } = Response::parse_error();
        assert_eq!(error.expect("parse_error() must populate `error`").code, -32700);
        assert!(id.is_none());
    }

    #[test]
    fn test_response_method_not_found() {
        let err = Response::method_not_found(Some(json!(5)), "foo.bar")
            .error
            .expect("method_not_found() must populate `error`");
        assert_eq!(err.code, -32601);
        assert!(err.message.contains("foo.bar"));
    }

    #[test]
    fn test_response_invalid_params() {
        let err = Response::invalid_params(Some(json!(3)), "missing field")
            .error
            .expect("invalid_params() must populate `error`");
        assert_eq!(err.code, -32602);
        assert!(err.message.contains("missing field"));
    }

    #[test]
    fn test_response_internal_error() {
        let err = Response::internal_error(None, "something broke")
            .error
            .expect("internal_error() must populate `error`");
        assert_eq!(err.code, -32603);
    }

    // ---- serde defaults --------------------------------------------------

    #[test]
    fn test_default_limit() {
        assert_eq!(10, default_limit());
    }

    #[test]
    fn test_default_depth() {
        assert_eq!(3, default_depth());
    }

    #[test]
    fn test_default_max_tokens() {
        assert_eq!(8000, default_max_tokens());
    }

    // ---- params deserialization ------------------------------------------

    #[test]
    fn test_discover_params_deserialization() {
        let params = serde_json::from_value::<DiscoverParams>(json!({"query": "foo"}))
            .expect("valid discover params");
        assert_eq!(params.query, "foo");
        assert_eq!(params.limit, 10, "omitted limit should use default_limit()");
    }

    #[test]
    fn test_impact_params_deserialization() {
        let raw = json!({"node": "main", "depth": 5});
        let ImpactParams { node, depth } =
            serde_json::from_value::<ImpactParams>(raw).expect("valid impact params");
        assert_eq!(node, "main");
        assert_eq!(depth, 5);
    }

    #[test]
    fn test_search_params_with_kind_filter() {
        let raw = json!({"query": "user", "kind": "function", "limit": 20});
        let params = serde_json::from_value::<SearchParams>(raw).expect("valid search params");
        assert_eq!(params.query, "user");
        assert_eq!(params.kind.as_deref(), Some("function"));
        assert_eq!(params.limit, 20);
    }

    // ---- wire format -----------------------------------------------------

    #[test]
    fn test_response_serialization_roundtrip() {
        let response = Response::success(Some(json!(1)), "hello");
        let encoded = serde_json::to_string(&response).expect("response serializes");
        assert!(encoded.contains(r#""jsonrpc":"2.0""#));
        assert!(encoded.contains(r#""hello""#));
        // A success response must not carry an `error` member.
        assert!(!encoded.contains(r#""error""#));
    }
}