1use serde_json::Value;
2use std::sync::Arc;
3
4use crate::runtime::errors::{RpcError, RpcErrorObject};
5use crate::runtime::events::{JsonRpcId, MsgKind};
6use crate::runtime::id::{extract_item_id, extract_thread_id, extract_turn_id};
7
/// Thread, turn, and item ids found in a message, as owned strings.
///
/// Returned by [`extract_ids`]. Each field is `None` when the corresponding id
/// was not found. `Default` yields the "nothing found" value with every field
/// set to `None`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ExtractedIds {
    /// Thread id, if one was found.
    pub thread_id: Option<String>,
    /// Turn id, if one was found.
    pub turn_id: Option<String>,
    /// Item id, if one was found.
    pub item_id: Option<String>,
}
14
15#[derive(Clone, Debug, PartialEq, Eq)]
16pub struct MsgMetadata {
17 pub kind: MsgKind,
18 pub response_id: Option<u64>,
19 pub rpc_id: Option<JsonRpcId>,
20 pub method: Option<Arc<str>>,
21 pub thread_id: Option<Arc<str>>,
22 pub turn_id: Option<Arc<str>>,
23 pub item_id: Option<Arc<str>>,
24}
25
26pub fn classify_message(json: &Value) -> MsgKind {
29 let has_id = json.get("id").is_some();
30 let has_method = json.get("method").is_some();
31 let has_result = json.get("result").is_some();
32 let has_error = json.get("error").is_some();
33
34 classify_jsonrpc_shape(has_id, has_method, has_result, has_error)
35}
36
37pub fn extract_ids(json: &Value) -> ExtractedIds {
40 let meta = extract_message_metadata(json);
41
42 ExtractedIds {
43 thread_id: meta.thread_id.map(|s| s.to_string()),
44 turn_id: meta.turn_id.map(|s| s.to_string()),
45 item_id: meta.item_id.map(|s| s.to_string()),
46 }
47}
48
49pub fn extract_message_metadata(json: &Value) -> MsgMetadata {
52 let obj = json.as_object();
53 let id_value = obj.and_then(|value| value.get("id"));
54 let method_value = obj.and_then(|value| value.get("method"));
55 let result_value = obj.and_then(|value| value.get("result"));
56 let error_value = obj.and_then(|value| value.get("error"));
57
58 let has_id = id_value.is_some();
59 let has_method = method_value.is_some();
60 let has_result = result_value.is_some();
61 let has_error = error_value.is_some();
62 let kind = classify_jsonrpc_shape(has_id, has_method, has_result, has_error);
63
64 let method = method_value.and_then(Value::as_str).map(Arc::from);
65 let response_id = parse_response_rpc_id_value(id_value);
66 let rpc_id = parse_jsonrpc_id_value(id_value);
67
68 let roots = [
69 obj.and_then(|value| value.get("params")),
70 result_value,
71 error_value.and_then(|value| value.get("data")),
72 Some(json),
73 ];
74
75 let mut thread_id = None;
76 let mut turn_id = None;
77 let mut item_id = None;
78 for root in roots.into_iter().flatten() {
79 if thread_id.is_none() {
80 thread_id = extract_thread_id(root).map(Arc::from);
81 }
82 if turn_id.is_none() {
83 turn_id = extract_turn_id(root).map(Arc::from);
84 }
85 if item_id.is_none() {
86 item_id = extract_item_id(root).map(Arc::from);
87 }
88 if thread_id.is_some() && turn_id.is_some() && item_id.is_some() {
89 break;
90 }
91 }
92
93 MsgMetadata {
94 kind,
95 response_id,
96 rpc_id,
97 method,
98 thread_id,
99 turn_id,
100 item_id,
101 }
102}
103
104pub fn map_rpc_error(json_error: &Value) -> RpcError {
107 let code = json_error.get("code").and_then(Value::as_i64);
108 let message = json_error
109 .get("message")
110 .and_then(Value::as_str)
111 .unwrap_or("unknown rpc error")
112 .to_owned();
113 let data = json_error.get("data").cloned();
114
115 match code {
116 Some(-32001) => RpcError::Overloaded,
117 Some(-32600) => RpcError::InvalidRequest(message),
118 Some(-32601) => RpcError::MethodNotFound(message),
119 Some(code) => RpcError::ServerError(RpcErrorObject {
120 code,
121 message,
122 data,
123 }),
124 None => RpcError::InvalidRequest("invalid rpc error payload".to_owned()),
125 }
126}
127
128fn parse_response_rpc_id_value(id_value: Option<&Value>) -> Option<u64> {
129 match id_value {
130 Some(Value::Number(number)) => number.as_u64(),
131 Some(Value::String(text)) => text.parse::<u64>().ok(),
132 _ => None,
133 }
134}
135
136fn parse_jsonrpc_id_value(id_value: Option<&Value>) -> Option<JsonRpcId> {
137 match id_value {
138 Some(Value::Number(number)) => number.as_u64().map(JsonRpcId::Number),
139 Some(Value::String(text)) => Some(JsonRpcId::Text(text.clone())),
140 _ => None,
141 }
142}
143
144fn classify_jsonrpc_shape(
145 has_id: bool,
146 has_method: bool,
147 has_result: bool,
148 has_error: bool,
149) -> MsgKind {
150 if has_id && !has_method && (has_result || has_error) {
151 return MsgKind::Response;
152 }
153 if has_id && has_method && !has_result && !has_error {
154 return MsgKind::ServerRequest;
155 }
156 if has_method && !has_id {
157 return MsgKind::Notification;
158 }
159 MsgKind::Unknown
160}
161
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    #[test]
    fn classify_response() {
        let v = json!({"id":1,"result":{}});
        assert_eq!(classify_message(&v), MsgKind::Response);
    }

    #[test]
    fn classify_server_request() {
        let v = json!({"id":2,"method":"item/fileChange/requestApproval","params":{}});
        assert_eq!(classify_message(&v), MsgKind::ServerRequest);
    }

    #[test]
    fn classify_notification() {
        let v = json!({"method":"turn/started","params":{}});
        assert_eq!(classify_message(&v), MsgKind::Notification);
    }

    #[test]
    fn classify_unknown() {
        let v = json!({"foo":"bar"});
        assert_eq!(classify_message(&v), MsgKind::Unknown);
    }

    #[test]
    fn classify_error_response_and_ambiguous_shapes() {
        // An `error` reply counts as a response just like a `result` reply.
        assert_eq!(
            classify_message(&json!({"id": 3, "error": {"code": -1}})),
            MsgKind::Response
        );
        // `id` + `method` + `result` matches no valid shape.
        assert_eq!(
            classify_message(&json!({"id": 4, "method": "m", "result": {}})),
            MsgKind::Unknown
        );
        // A bare `id` is neither a request nor a response.
        assert_eq!(classify_message(&json!({"id": 5})), MsgKind::Unknown);
        // Non-object payloads carry no keys at all.
        assert_eq!(classify_message(&json!([1, 2])), MsgKind::Unknown);
    }

    #[test]
    fn extract_ids_prefers_params() {
        let v = json!({
            "params": {
                "threadId": "thr_1",
                "turnId": "turn_1",
                "itemId": "item_1"
            }
        });
        let ids = extract_ids(&v);
        assert_eq!(ids.thread_id.as_deref(), Some("thr_1"));
        assert_eq!(ids.turn_id.as_deref(), Some("turn_1"));
        assert_eq!(ids.item_id.as_deref(), Some("item_1"));
    }

    #[test]
    fn extract_ids_supports_nested_struct_ids() {
        let v = json!({
            "params": {
                "thread": {"id": "thr_nested"},
                "turn": {"id": "turn_nested"},
                "item": {"id": "item_nested"}
            }
        });
        let ids = extract_ids(&v);
        assert_eq!(ids.thread_id.as_deref(), Some("thr_nested"));
        assert_eq!(ids.turn_id.as_deref(), Some("turn_nested"));
        assert_eq!(ids.item_id.as_deref(), Some("item_nested"));
    }

    #[test]
    fn extract_ids_ignores_legacy_conversation_id() {
        let v = json!({
            "params": {
                "conversationId": "thr_conv"
            }
        });
        let ids = extract_ids(&v);
        assert_eq!(ids.thread_id, None);
        assert_eq!(ids.turn_id, None);
        assert_eq!(ids.item_id, None);
    }

    #[test]
    fn extract_ids_rejects_non_canonical_id_values() {
        let v = json!({
            "params": {
                "threadId": " thr_space ",
                "turn": {"id": ""},
                "itemId": "item_ok"
            }
        });
        let ids = extract_ids(&v);
        assert_eq!(ids.thread_id, None);
        assert_eq!(ids.turn_id, None);
        assert_eq!(ids.item_id.as_deref(), Some("item_ok"));
    }

    #[test]
    fn map_overloaded_error() {
        let v = json!({"code": -32001, "message": "ingress overload"});
        assert_eq!(map_rpc_error(&v), RpcError::Overloaded);
    }

    #[test]
    fn map_standard_and_server_errors() {
        assert_eq!(
            map_rpc_error(&json!({"code": -32600, "message": "bad request"})),
            RpcError::InvalidRequest("bad request".to_owned())
        );
        // A missing `message` falls back to the generic text.
        assert_eq!(
            map_rpc_error(&json!({"code": -32601})),
            RpcError::MethodNotFound("unknown rpc error".to_owned())
        );
        // Any other integer code keeps code, message, and data.
        assert_eq!(
            map_rpc_error(&json!({"code": -32603, "message": "boom", "data": {"k": 1}})),
            RpcError::ServerError(RpcErrorObject {
                code: -32603,
                message: "boom".to_owned(),
                data: Some(json!({"k": 1})),
            })
        );
        // Without a usable `code` the payload itself is treated as invalid.
        assert_eq!(
            map_rpc_error(&json!({"message": "no code"})),
            RpcError::InvalidRequest("invalid rpc error payload".to_owned())
        );
    }

    // `extract_ids` is built on `extract_message_metadata`, so this parity check
    // only guards against the two drifting apart. The tests after it pin the
    // actual field values.
    #[test]
    fn extract_message_metadata_matches_legacy_helpers() {
        let fixtures = vec![
            json!({
                "id": 1,
                "result": {
                    "thread": {"id": "thr_1"},
                    "turn": {"id": "turn_1"},
                    "item": {"id": "item_1"}
                }
            }),
            json!({
                "id": "42",
                "method": "item/fileChange/requestApproval",
                "params": {
                    "threadId": "thr_2",
                    "turnId": "turn_2",
                    "itemId": "item_2"
                }
            }),
            json!({
                "method": "turn/started",
                "params": {
                    "thread": {"id": "thr_3"},
                    "turn": {"id": "turn_3"}
                }
            }),
        ];

        for fixture in fixtures {
            let meta = extract_message_metadata(&fixture);
            let ids = extract_ids(&fixture);

            assert_eq!(meta.kind, classify_message(&fixture));
            assert_eq!(meta.thread_id.map(|s| s.to_string()), ids.thread_id);
            assert_eq!(meta.turn_id.map(|s| s.to_string()), ids.turn_id);
            assert_eq!(meta.item_id.map(|s| s.to_string()), ids.item_id);
        }
    }

    #[test]
    fn extract_message_metadata_reports_rpc_fields() {
        let response = extract_message_metadata(&json!({"id": 7, "result": {}}));
        assert_eq!(response.kind, MsgKind::Response);
        assert_eq!(response.response_id, Some(7));
        assert_eq!(response.rpc_id, Some(JsonRpcId::Number(7)));
        assert_eq!(response.method, None);

        // A numeric string id correlates as a number but keeps its text form.
        let request = extract_message_metadata(&json!({
            "id": "42",
            "method": "item/fileChange/requestApproval",
            "params": {}
        }));
        assert_eq!(request.kind, MsgKind::ServerRequest);
        assert_eq!(request.response_id, Some(42));
        assert_eq!(request.rpc_id, Some(JsonRpcId::Text("42".to_owned())));
        assert_eq!(
            request.method.as_deref(),
            Some("item/fileChange/requestApproval")
        );

        let text_id = extract_message_metadata(&json!({"id": "req-9", "result": {}}));
        assert_eq!(text_id.response_id, None);
        assert_eq!(text_id.rpc_id, Some(JsonRpcId::Text("req-9".to_owned())));

        let notification =
            extract_message_metadata(&json!({"method": "turn/started", "params": {}}));
        assert_eq!(notification.kind, MsgKind::Notification);
        assert_eq!(notification.response_id, None);
        assert_eq!(notification.rpc_id, None);
        assert_eq!(notification.method.as_deref(), Some("turn/started"));
    }

    #[test]
    fn extract_message_metadata_searches_roots_in_order() {
        let from_result = extract_message_metadata(&json!({
            "id": 1,
            "result": {"thread": {"id": "thr_r"}}
        }));
        assert_eq!(from_result.thread_id.as_deref(), Some("thr_r"));

        let from_error_data = extract_message_metadata(&json!({
            "id": 2,
            "error": {"code": -32603, "message": "x", "data": {"threadId": "thr_e"}}
        }));
        assert_eq!(from_error_data.thread_id.as_deref(), Some("thr_e"));

        // `params` is searched before `result`.
        let both = extract_message_metadata(&json!({
            "method": "m",
            "params": {"threadId": "thr_p"},
            "result": {"threadId": "thr_r"}
        }));
        assert_eq!(both.thread_id.as_deref(), Some("thr_p"));
    }
}