Skip to main content

codex_runtime/runtime/
rpc.rs

1use serde_json::Value;
2use std::sync::Arc;
3
4use crate::runtime::errors::{RpcError, RpcErrorObject};
5use crate::runtime::events::{JsonRpcId, MsgKind};
6use crate::runtime::id::{extract_item_id, extract_thread_id, extract_turn_id};
7
/// Owned thread/turn/item identifiers pulled out of a JSON-RPC message.
///
/// Produced by [`extract_ids`]; each field is `None` when no usable id was found.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExtractedIds {
    /// Thread identifier, if one was found.
    pub thread_id: Option<String>,
    /// Turn identifier, if one was found.
    pub turn_id: Option<String>,
    /// Item identifier, if one was found.
    pub item_id: Option<String>,
}
14
15#[derive(Clone, Debug, PartialEq, Eq)]
16pub struct MsgMetadata {
17    pub kind: MsgKind,
18    pub response_id: Option<u64>,
19    pub rpc_id: Option<JsonRpcId>,
20    pub method: Option<Arc<str>>,
21    pub thread_id: Option<Arc<str>>,
22    pub turn_id: Option<Arc<str>>,
23    pub item_id: Option<Arc<str>>,
24}
25
26/// Classify a raw JSON message with constant-time key presence checks.
27/// Allocation: none. Complexity: O(1).
28pub fn classify_message(json: &Value) -> MsgKind {
29    let has_id = json.get("id").is_some();
30    let has_method = json.get("method").is_some();
31    let has_result = json.get("result").is_some();
32    let has_error = json.get("error").is_some();
33
34    classify_jsonrpc_shape(has_id, has_method, has_result, has_error)
35}
36
37/// Best-effort identifier extraction from known shallow JSON-RPC slots.
38/// Allocation: up to 3 Strings (only when ids exist). Complexity: O(1).
39pub fn extract_ids(json: &Value) -> ExtractedIds {
40    let meta = extract_message_metadata(json);
41
42    ExtractedIds {
43        thread_id: meta.thread_id.map(|s| s.to_string()),
44        turn_id: meta.turn_id.map(|s| s.to_string()),
45        item_id: meta.item_id.map(|s| s.to_string()),
46    }
47}
48
49/// Extract commonly used dispatch metadata in one pass over top-level keys.
50/// Allocation: owned method/id strings only when present. Complexity: O(1).
51pub fn extract_message_metadata(json: &Value) -> MsgMetadata {
52    let obj = json.as_object();
53    let id_value = obj.and_then(|value| value.get("id"));
54    let method_value = obj.and_then(|value| value.get("method"));
55    let result_value = obj.and_then(|value| value.get("result"));
56    let error_value = obj.and_then(|value| value.get("error"));
57
58    let has_id = id_value.is_some();
59    let has_method = method_value.is_some();
60    let has_result = result_value.is_some();
61    let has_error = error_value.is_some();
62    let kind = classify_jsonrpc_shape(has_id, has_method, has_result, has_error);
63
64    let method = method_value.and_then(Value::as_str).map(Arc::from);
65    let response_id = parse_response_rpc_id_value(id_value);
66    let rpc_id = parse_jsonrpc_id_value(id_value);
67
68    let roots = [
69        obj.and_then(|value| value.get("params")),
70        result_value,
71        error_value.and_then(|value| value.get("data")),
72        Some(json),
73    ];
74
75    let mut thread_id = None;
76    let mut turn_id = None;
77    let mut item_id = None;
78    for root in roots.into_iter().flatten() {
79        if thread_id.is_none() {
80            thread_id = extract_thread_id(root).map(Arc::from);
81        }
82        if turn_id.is_none() {
83            turn_id = extract_turn_id(root).map(Arc::from);
84        }
85        if item_id.is_none() {
86            item_id = extract_item_id(root).map(Arc::from);
87        }
88        if thread_id.is_some() && turn_id.is_some() && item_id.is_some() {
89            break;
90        }
91    }
92
93    MsgMetadata {
94        kind,
95        response_id,
96        rpc_id,
97        method,
98        thread_id,
99        turn_id,
100        item_id,
101    }
102}
103
104/// Map a JSON-RPC error object into a typed error enum.
105/// Allocation: message clone + optional data clone. Complexity: O(1).
106pub fn map_rpc_error(json_error: &Value) -> RpcError {
107    let code = json_error.get("code").and_then(Value::as_i64);
108    let message = json_error
109        .get("message")
110        .and_then(Value::as_str)
111        .unwrap_or("unknown rpc error")
112        .to_owned();
113    let data = json_error.get("data").cloned();
114
115    match code {
116        Some(-32001) => RpcError::Overloaded,
117        Some(-32600) => RpcError::InvalidRequest(message),
118        Some(-32601) => RpcError::MethodNotFound(message),
119        Some(code) => RpcError::ServerError(RpcErrorObject {
120            code,
121            message,
122            data,
123        }),
124        None => RpcError::InvalidRequest("invalid rpc error payload".to_owned()),
125    }
126}
127
128fn parse_response_rpc_id_value(id_value: Option<&Value>) -> Option<u64> {
129    match id_value {
130        Some(Value::Number(number)) => number.as_u64(),
131        Some(Value::String(text)) => text.parse::<u64>().ok(),
132        _ => None,
133    }
134}
135
136fn parse_jsonrpc_id_value(id_value: Option<&Value>) -> Option<JsonRpcId> {
137    match id_value {
138        Some(Value::Number(number)) => number.as_u64().map(JsonRpcId::Number),
139        Some(Value::String(text)) => Some(JsonRpcId::Text(text.clone())),
140        _ => None,
141    }
142}
143
144fn classify_jsonrpc_shape(
145    has_id: bool,
146    has_method: bool,
147    has_result: bool,
148    has_error: bool,
149) -> MsgKind {
150    if has_id && !has_method && (has_result || has_error) {
151        return MsgKind::Response;
152    }
153    if has_id && has_method && !has_result && !has_error {
154        return MsgKind::ServerRequest;
155    }
156    if has_method && !has_id {
157        return MsgKind::Notification;
158    }
159    MsgKind::Unknown
160}
161
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    #[test]
    fn classify_response() {
        let v = json!({"id":1,"result":{}});
        assert_eq!(classify_message(&v), MsgKind::Response);
    }

    #[test]
    fn classify_server_request() {
        let v = json!({"id":2,"method":"item/fileChange/requestApproval","params":{}});
        assert_eq!(classify_message(&v), MsgKind::ServerRequest);
    }

    #[test]
    fn classify_notification() {
        let v = json!({"method":"turn/started","params":{}});
        assert_eq!(classify_message(&v), MsgKind::Notification);
    }

    #[test]
    fn classify_unknown() {
        let v = json!({"foo":"bar"});
        assert_eq!(classify_message(&v), MsgKind::Unknown);
    }

    #[test]
    fn extract_ids_prefers_params() {
        let v = json!({
            "params": {
                "threadId": "thr_1",
                "turnId": "turn_1",
                "itemId": "item_1"
            }
        });
        let ids = extract_ids(&v);
        assert_eq!(ids.thread_id.as_deref(), Some("thr_1"));
        assert_eq!(ids.turn_id.as_deref(), Some("turn_1"));
        assert_eq!(ids.item_id.as_deref(), Some("item_1"));
    }

    #[test]
    fn extract_ids_supports_nested_struct_ids() {
        let v = json!({
            "params": {
                "thread": {"id": "thr_nested"},
                "turn": {"id": "turn_nested"},
                "item": {"id": "item_nested"}
            }
        });
        let ids = extract_ids(&v);
        assert_eq!(ids.thread_id.as_deref(), Some("thr_nested"));
        assert_eq!(ids.turn_id.as_deref(), Some("turn_nested"));
        assert_eq!(ids.item_id.as_deref(), Some("item_nested"));
    }

    #[test]
    fn extract_ids_ignores_legacy_conversation_id() {
        let v = json!({
            "params": {
                "conversationId": "thr_conv"
            }
        });
        let ids = extract_ids(&v);
        assert_eq!(ids.thread_id, None);
        assert_eq!(ids.turn_id, None);
        assert_eq!(ids.item_id, None);
    }

    #[test]
    fn extract_ids_rejects_non_canonical_id_values() {
        let v = json!({
            "params": {
                "threadId": " thr_space ",
                "turn": {"id": ""},
                "itemId": "item_ok"
            }
        });
        let ids = extract_ids(&v);
        assert_eq!(ids.thread_id, None);
        assert_eq!(ids.turn_id, None);
        assert_eq!(ids.item_id.as_deref(), Some("item_ok"));
    }

    #[test]
    fn map_overloaded_error() {
        let v = json!({"code": -32001, "message": "ingress overload"});
        assert_eq!(map_rpc_error(&v), RpcError::Overloaded);
    }

    #[test]
    fn map_method_not_found_and_invalid_request_keep_message() {
        let not_found = json!({"code": -32601, "message": "no such method"});
        assert_eq!(
            map_rpc_error(&not_found),
            RpcError::MethodNotFound("no such method".to_owned())
        );

        // A missing message falls back to the generic text.
        let invalid = json!({"code": -32600});
        assert_eq!(
            map_rpc_error(&invalid),
            RpcError::InvalidRequest("unknown rpc error".to_owned())
        );
    }

    #[test]
    fn map_other_codes_to_server_error_with_data() {
        let v = json!({"code": -32000, "message": "boom", "data": {"retry": true}});
        assert_eq!(
            map_rpc_error(&v),
            RpcError::ServerError(RpcErrorObject {
                code: -32000,
                message: "boom".to_owned(),
                data: Some(json!({"retry": true})),
            })
        );
    }

    #[test]
    fn map_missing_code_is_invalid_payload() {
        let v = json!({"message": "no code"});
        assert_eq!(
            map_rpc_error(&v),
            RpcError::InvalidRequest("invalid rpc error payload".to_owned())
        );
    }

    #[test]
    fn extract_message_metadata_parses_ids_and_method() {
        let numeric = extract_message_metadata(&json!({"id": 7, "result": {}}));
        assert_eq!(numeric.kind, MsgKind::Response);
        assert_eq!(numeric.response_id, Some(7));
        assert_eq!(numeric.rpc_id, Some(JsonRpcId::Number(7)));
        assert_eq!(numeric.method, None);

        let text = extract_message_metadata(&json!({
            "id": "42",
            "method": "item/fileChange/requestApproval",
            "params": {}
        }));
        assert_eq!(text.kind, MsgKind::ServerRequest);
        assert_eq!(text.response_id, Some(42));
        assert_eq!(text.rpc_id, Some(JsonRpcId::Text("42".to_owned())));
        assert_eq!(
            text.method.as_deref(),
            Some("item/fileChange/requestApproval")
        );

        // Non-numeric textual ids keep their text but have no numeric response id.
        let opaque = extract_message_metadata(&json!({"id": "req-a", "error": {"code": -1}}));
        assert_eq!(opaque.kind, MsgKind::Response);
        assert_eq!(opaque.response_id, None);
        assert_eq!(opaque.rpc_id, Some(JsonRpcId::Text("req-a".to_owned())));
    }

    #[test]
    fn extract_message_metadata_matches_legacy_helpers() {
        let fixtures = vec![
            json!({
                "id": 1,
                "result": {
                    "thread": {"id": "thr_1"},
                    "turn": {"id": "turn_1"},
                    "item": {"id": "item_1"}
                }
            }),
            json!({
                "id": "42",
                "method": "item/fileChange/requestApproval",
                "params": {
                    "threadId": "thr_2",
                    "turnId": "turn_2",
                    "itemId": "item_2"
                }
            }),
            json!({
                "method": "turn/started",
                "params": {
                    "thread": {"id": "thr_3"},
                    "turn": {"id": "turn_3"}
                }
            }),
        ];

        for fixture in fixtures {
            let meta = extract_message_metadata(&fixture);
            let ids = extract_ids(&fixture);

            assert_eq!(meta.kind, classify_message(&fixture));
            assert_eq!(meta.thread_id.map(|s| s.to_string()), ids.thread_id);
            assert_eq!(meta.turn_id.map(|s| s.to_string()), ids.turn_id);
            assert_eq!(meta.item_id.map(|s| s.to_string()), ids.item_id);
        }
    }
}