Skip to main content

codex_runtime/runtime/
rpc_contract.rs

1use serde_json::Value;
2
3use crate::runtime::api::summarize_sandbox_policy_wire_value;
4use crate::runtime::errors::RpcError;
5use crate::runtime::turn_output::{parse_thread_id, parse_turn_id};
6
/// Canonical method catalog shared by facade constants and known-method validation.
///
/// Three families of wire method names live here:
/// client-issued requests (the only ones listed in [`methods::KNOWN`]), server-issued
/// requests that expect a client response, and server-pushed notifications.
pub mod methods {
    // ---- Client-issued requests -------------------------------------------------

    pub const THREAD_START: &str = "thread/start";
    pub const THREAD_RESUME: &str = "thread/resume";
    pub const THREAD_FORK: &str = "thread/fork";
    pub const THREAD_ARCHIVE: &str = "thread/archive";
    pub const THREAD_READ: &str = "thread/read";
    pub const THREAD_LIST: &str = "thread/list";
    pub const THREAD_LOADED_LIST: &str = "thread/loaded/list";
    pub const THREAD_ROLLBACK: &str = "thread/rollback";

    pub const SKILLS_LIST: &str = "skills/list";

    pub const COMMAND_EXEC: &str = "command/exec";
    pub const COMMAND_EXEC_WRITE: &str = "command/exec/write";
    pub const COMMAND_EXEC_TERMINATE: &str = "command/exec/terminate";
    pub const COMMAND_EXEC_RESIZE: &str = "command/exec/resize";

    pub const TURN_START: &str = "turn/start";
    pub const TURN_INTERRUPT: &str = "turn/interrupt";

    /// Every client-issued request method above, in declaration order.
    ///
    /// Server-issued requests and notifications below are deliberately absent.
    pub const KNOWN: [&str; 15] = [
        THREAD_START,
        THREAD_RESUME,
        THREAD_FORK,
        THREAD_ARCHIVE,
        THREAD_READ,
        THREAD_LIST,
        THREAD_LOADED_LIST,
        THREAD_ROLLBACK,
        SKILLS_LIST,
        COMMAND_EXEC,
        COMMAND_EXEC_WRITE,
        COMMAND_EXEC_TERMINATE,
        COMMAND_EXEC_RESIZE,
        TURN_START,
        TURN_INTERRUPT,
    ];

    // ---- Server-issued requests (inbound; the client must send a response) -------

    pub const ITEM_COMMAND_EXECUTION_REQUEST_APPROVAL: &str =
        "item/commandExecution/requestApproval";
    pub const ITEM_FILE_CHANGE_REQUEST_APPROVAL: &str = "item/fileChange/requestApproval";
    pub const ITEM_TOOL_REQUEST_USER_INPUT: &str = "item/tool/requestUserInput";
    pub const ITEM_TOOL_CALL: &str = "item/tool/call";
    pub const ACCOUNT_CHATGPT_AUTH_TOKENS_REFRESH: &str = "account/chatgptAuthTokens/refresh";

    // ---- Server-pushed notifications (no response expected) ---------------------

    pub const THREAD_STARTED: &str = "thread/started";

    pub const TURN_STARTED: &str = "turn/started";
    pub const TURN_COMPLETED: &str = "turn/completed";
    pub const TURN_FAILED: &str = "turn/failed";
    pub const TURN_CANCELLED: &str = "turn/cancelled";
    pub const TURN_INTERRUPTED: &str = "turn/interrupted";
    pub const TURN_DIFF_UPDATED: &str = "turn/diff/updated";
    pub const TURN_PLAN_UPDATED: &str = "turn/plan/updated";

    pub const ITEM_STARTED: &str = "item/started";
    pub const ITEM_AGENT_MESSAGE_DELTA: &str = "item/agentMessage/delta";
    pub const ITEM_COMMAND_EXECUTION_OUTPUT_DELTA: &str = "item/commandExecution/outputDelta";
    pub const ITEM_COMPLETED: &str = "item/completed";

    pub const COMMAND_EXEC_OUTPUT_DELTA: &str = "command/exec/outputDelta";
    pub const APPROVAL_ACK: &str = "approval/ack";
    pub const SKILLS_CHANGED: &str = "skills/changed";
}
68
/// How strictly JSON-RPC payloads are checked against the app-server contract.
///
/// The default is [`RpcValidationMode::KnownMethods`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum RpcValidationMode {
    /// Skip shape checks entirely; only the non-empty method-name check still runs.
    None,
    /// Check request/response shapes for methods that have a contract descriptor;
    /// methods outside the catalog pass through unchecked.
    #[default]
    KnownMethods,
}
78
/// Which request-shape rule a contract descriptor applies to outgoing `params`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RpcRequestContract {
    /// `params` only has to be a JSON object.
    Object,
    /// `thread/start` params (currently: must be a JSON object).
    ThreadStart,
    /// `params.threadId` must be a non-blank string.
    ThreadId,
    /// Both `params.threadId` and `params.turnId` must be non-blank strings.
    ThreadIdAndTurnId,
    /// `params.processId` must be a non-blank string.
    ProcessId,
    /// Full `command/exec` validation (command argv, streaming, limits, size, sandbox).
    CommandExec,
    /// `processId` plus at least one of `deltaBase64` / `closeStdin`.
    CommandExecWrite,
    /// `processId` plus a `size` object with positive `rows` and `cols`.
    CommandExecResize,
}
91
/// Which response-shape rule a contract descriptor applies to an incoming `result`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RpcResponseContract {
    /// `result` only has to be a JSON object.
    Object,
    /// A thread id must be extractable from `result`.
    ThreadId,
    /// A turn id must be extractable from `result`.
    TurnId,
    /// `result` must be an object whose `data` field is an array.
    DataArray,
    /// `command/exec` result: i32-compatible `exitCode`, string `stdout` and `stderr`.
    CommandExec,
}
101
102/// Single-source descriptor for one app-server RPC contract method.
103#[derive(Clone, Copy, Debug, PartialEq, Eq)]
104pub struct RpcContractDescriptor {
105    pub method: &'static str,
106    pub request: RpcRequestContract,
107    pub response: RpcResponseContract,
108}
109
// Field paths used as labels in violation messages.
const FIELD_PARAMS: &str = "params";
const FIELD_PARAMS_SANDBOX_POLICY: &str = "params.sandboxPolicy";
const FIELD_RESULT: &str = "result";

// Object keys read from params/result payloads.
const KEY_DATA: &str = "data";
const KEY_PROCESS_ID: &str = "processId";
const KEY_SIZE: &str = "size";
116
117const RPC_CONTRACT_DESCRIPTORS: [RpcContractDescriptor; 15] = [
118    RpcContractDescriptor {
119        method: methods::THREAD_START,
120        request: RpcRequestContract::ThreadStart,
121        response: RpcResponseContract::ThreadId,
122    },
123    RpcContractDescriptor {
124        method: methods::THREAD_RESUME,
125        request: RpcRequestContract::ThreadId,
126        response: RpcResponseContract::ThreadId,
127    },
128    RpcContractDescriptor {
129        method: methods::THREAD_FORK,
130        request: RpcRequestContract::ThreadId,
131        response: RpcResponseContract::ThreadId,
132    },
133    RpcContractDescriptor {
134        method: methods::THREAD_ARCHIVE,
135        request: RpcRequestContract::ThreadId,
136        response: RpcResponseContract::Object,
137    },
138    RpcContractDescriptor {
139        method: methods::THREAD_READ,
140        request: RpcRequestContract::ThreadId,
141        response: RpcResponseContract::ThreadId,
142    },
143    RpcContractDescriptor {
144        method: methods::THREAD_LIST,
145        request: RpcRequestContract::Object,
146        response: RpcResponseContract::DataArray,
147    },
148    RpcContractDescriptor {
149        method: methods::THREAD_LOADED_LIST,
150        request: RpcRequestContract::Object,
151        response: RpcResponseContract::DataArray,
152    },
153    RpcContractDescriptor {
154        method: methods::THREAD_ROLLBACK,
155        request: RpcRequestContract::ThreadId,
156        response: RpcResponseContract::ThreadId,
157    },
158    RpcContractDescriptor {
159        method: methods::SKILLS_LIST,
160        request: RpcRequestContract::Object,
161        response: RpcResponseContract::DataArray,
162    },
163    RpcContractDescriptor {
164        method: methods::COMMAND_EXEC,
165        request: RpcRequestContract::CommandExec,
166        response: RpcResponseContract::CommandExec,
167    },
168    RpcContractDescriptor {
169        method: methods::COMMAND_EXEC_WRITE,
170        request: RpcRequestContract::CommandExecWrite,
171        response: RpcResponseContract::Object,
172    },
173    RpcContractDescriptor {
174        method: methods::COMMAND_EXEC_TERMINATE,
175        request: RpcRequestContract::ProcessId,
176        response: RpcResponseContract::Object,
177    },
178    RpcContractDescriptor {
179        method: methods::COMMAND_EXEC_RESIZE,
180        request: RpcRequestContract::CommandExecResize,
181        response: RpcResponseContract::Object,
182    },
183    RpcContractDescriptor {
184        method: methods::TURN_START,
185        request: RpcRequestContract::ThreadId,
186        response: RpcResponseContract::TurnId,
187    },
188    RpcContractDescriptor {
189        method: methods::TURN_INTERRUPT,
190        request: RpcRequestContract::ThreadIdAndTurnId,
191        response: RpcResponseContract::Object,
192    },
193];
194
195/// Canonical RPC contract descriptor list (single source of truth).
196pub fn rpc_contract_descriptors() -> &'static [RpcContractDescriptor] {
197    &RPC_CONTRACT_DESCRIPTORS
198}
199
200/// Contract descriptor for one method, when the method is known.
201pub fn rpc_contract_descriptor(method: &str) -> Option<&'static RpcContractDescriptor> {
202    RPC_CONTRACT_DESCRIPTORS
203        .iter()
204        .find(|descriptor| descriptor.method == method)
205}
206
207/// Validate outgoing JSON-RPC request payload for one method.
208///
209/// - Always validates that method name is non-empty.
210/// - In `KnownMethods` mode, validates request shape for known methods.
211pub fn validate_rpc_request(
212    method: &str,
213    params: &Value,
214    mode: RpcValidationMode,
215) -> Result<(), RpcError> {
216    validate_method_name(method)?;
217
218    if mode == RpcValidationMode::None {
219        return Ok(());
220    }
221
222    match rpc_contract_descriptor(method) {
223        Some(descriptor) => validate_request_by_descriptor(method, params, *descriptor),
224        None => Ok(()),
225    }
226}
227
228/// Validate incoming JSON-RPC result payload for one method.
229///
230/// In `KnownMethods` mode this enforces minimum shape invariants for known methods.
231pub fn validate_rpc_response(
232    method: &str,
233    result: &Value,
234    mode: RpcValidationMode,
235) -> Result<(), RpcError> {
236    validate_method_name(method)?;
237
238    if mode == RpcValidationMode::None {
239        return Ok(());
240    }
241
242    match rpc_contract_descriptor(method) {
243        Some(descriptor) => validate_response_by_descriptor(method, result, *descriptor),
244        None => Ok(()),
245    }
246}
247
/// Which side of the exchange a contract violation was found on; it selects the
/// "request"/"response" word in the projected error message.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum RpcContractSurface {
    /// Outgoing `params` (and the method name itself).
    Request,
    /// Incoming `result`.
    Response,
}
253
/// Every way a payload can break the RPC contract; rendered into text by `reason()`.
#[derive(Clone, Debug, PartialEq, Eq)]
enum RpcContractViolation {
    /// Method name is empty or whitespace-only.
    EmptyMethod,
    /// The named field is not a JSON object.
    FieldMustBeObject { field_name: String },
    /// `field_name.key` is missing, not a string, or blank.
    FieldMustBeNonEmptyString { field_name: String, key: String },
    /// No thread id could be parsed from the result.
    MissingThreadId,
    /// No turn id could be parsed from the result.
    MissingTurnId,
    /// `result.data` is absent or not an array.
    ResultDataMustBeArray,
    /// `params.command` is absent or not an array.
    CommandMustBeArray,
    /// `params.command` is an empty array.
    CommandMustNotBeEmpty,
    /// `params.command` contains a non-string item.
    CommandItemsMustBeStrings,
    /// tty or stdin/stdout streaming was requested without `params.processId`.
    ProcessIdRequiredForStreaming,
    /// `disableOutputCap` was set together with `outputBytesCap`.
    DisableOutputCapConflictsWithOutputBytesCap,
    /// `disableTimeout` was set together with `timeoutMs`.
    DisableTimeoutConflictsWithTimeoutMs,
    /// `timeoutMs` is negative.
    TimeoutMsMustBeNonNegative,
    /// `outputBytesCap` is not strictly positive.
    OutputBytesCapMustBePositive,
    /// `params.size` was given while `params.tty` is not true.
    SizeRequiresTty,
    /// `params.size` is absent (where required) or not an object.
    SizeMustBeObject,
    /// `params.size.rows` is missing or not a positive integer.
    SizeRowsMustBePositive,
    /// `params.size.cols` is missing or not a positive integer.
    SizeColsMustBePositive,
    /// A write request carries neither `deltaBase64` nor `closeStdin`.
    WriteRequestMustIncludeDeltaOrCloseStdin,
    /// `result.exitCode` is missing or does not fit in an `i32`.
    ExitCodeMustBeI32CompatibleInteger,
    /// `result.stdout` is missing or not a string.
    StdoutMustBeString,
    /// `result.stderr` is missing or not a string.
    StderrMustBeString,
    /// `params.key` is present but not a string.
    ParamsFieldMustBeString { key: String },
    /// Free-form reason supplied by another validator.
    Custom(String),
}
281
282impl RpcContractViolation {
283    fn reason(&self) -> String {
284        match self {
285            Self::EmptyMethod => "json-rpc method must not be empty".to_owned(),
286            Self::FieldMustBeObject { field_name } => format!("{field_name} must be an object"),
287            Self::FieldMustBeNonEmptyString { field_name, key } => {
288                format!("{field_name}.{key} must be a non-empty string")
289            }
290            Self::MissingThreadId => "result is missing thread id".to_owned(),
291            Self::MissingTurnId => "result is missing turn id".to_owned(),
292            Self::ResultDataMustBeArray => "result.data must be an array".to_owned(),
293            Self::CommandMustBeArray => "params.command must be an array".to_owned(),
294            Self::CommandMustNotBeEmpty => "params.command must not be empty".to_owned(),
295            Self::CommandItemsMustBeStrings => "params.command items must be strings".to_owned(),
296            Self::ProcessIdRequiredForStreaming => {
297                "params.processId is required when tty or streaming is enabled".to_owned()
298            }
299            Self::DisableOutputCapConflictsWithOutputBytesCap => {
300                "params.disableOutputCap cannot be combined with params.outputBytesCap".to_owned()
301            }
302            Self::DisableTimeoutConflictsWithTimeoutMs => {
303                "params.disableTimeout cannot be combined with params.timeoutMs".to_owned()
304            }
305            Self::TimeoutMsMustBeNonNegative => "params.timeoutMs must be >= 0".to_owned(),
306            Self::OutputBytesCapMustBePositive => "params.outputBytesCap must be > 0".to_owned(),
307            Self::SizeRequiresTty => "params.size is only valid when params.tty is true".to_owned(),
308            Self::SizeMustBeObject => "params.size must be an object".to_owned(),
309            Self::SizeRowsMustBePositive => "params.size.rows must be > 0".to_owned(),
310            Self::SizeColsMustBePositive => "params.size.cols must be > 0".to_owned(),
311            Self::WriteRequestMustIncludeDeltaOrCloseStdin => {
312                "params must include deltaBase64, closeStdin, or both".to_owned()
313            }
314            Self::ExitCodeMustBeI32CompatibleInteger => {
315                "result.exitCode must be an i32-compatible integer".to_owned()
316            }
317            Self::StdoutMustBeString => "result.stdout must be a string".to_owned(),
318            Self::StderrMustBeString => "result.stderr must be a string".to_owned(),
319            Self::ParamsFieldMustBeString { key } => format!("params.{key} must be a string"),
320            Self::Custom(reason) => reason.clone(),
321        }
322    }
323}
324
325fn validate_request_by_descriptor(
326    method: &str,
327    params: &Value,
328    descriptor: RpcContractDescriptor,
329) -> Result<(), RpcError> {
330    match descriptor.request {
331        RpcRequestContract::Object => {
332            require_object(params, method, FIELD_PARAMS)?;
333            Ok(())
334        }
335        RpcRequestContract::ThreadStart => validate_thread_start_request(params, method),
336        RpcRequestContract::ThreadId => require_string(params, method, "threadId", FIELD_PARAMS),
337        RpcRequestContract::ThreadIdAndTurnId => {
338            require_string(params, method, "threadId", FIELD_PARAMS)?;
339            require_string(params, method, "turnId", FIELD_PARAMS)
340        }
341        RpcRequestContract::ProcessId => {
342            require_string(params, method, KEY_PROCESS_ID, FIELD_PARAMS)
343        }
344        RpcRequestContract::CommandExec => validate_command_exec_request(params, method),
345        RpcRequestContract::CommandExecWrite => validate_command_exec_write_request(params, method),
346        RpcRequestContract::CommandExecResize => {
347            validate_command_exec_resize_request(params, method)
348        }
349    }
350}
351
352fn validate_response_by_descriptor(
353    method: &str,
354    result: &Value,
355    descriptor: RpcContractDescriptor,
356) -> Result<(), RpcError> {
357    match descriptor.response {
358        RpcResponseContract::Object => {
359            require_response_object(result, method, FIELD_RESULT)?;
360            Ok(())
361        }
362        RpcResponseContract::ThreadId => {
363            if parse_thread_id(result).is_none() {
364                Err(project_contract_violation(
365                    method,
366                    RpcContractSurface::Response,
367                    &RpcContractViolation::MissingThreadId,
368                    result,
369                ))
370            } else {
371                Ok(())
372            }
373        }
374        RpcResponseContract::TurnId => {
375            if parse_turn_id(result).is_none() {
376                Err(project_contract_violation(
377                    method,
378                    RpcContractSurface::Response,
379                    &RpcContractViolation::MissingTurnId,
380                    result,
381                ))
382            } else {
383                Ok(())
384            }
385        }
386        RpcResponseContract::DataArray => {
387            let obj = require_response_object(result, method, FIELD_RESULT)?;
388            match obj.get(KEY_DATA) {
389                Some(Value::Array(_)) => Ok(()),
390                _ => Err(project_contract_violation(
391                    method,
392                    RpcContractSurface::Response,
393                    &RpcContractViolation::ResultDataMustBeArray,
394                    result,
395                )),
396            }
397        }
398        RpcResponseContract::CommandExec => validate_command_exec_response(result, method),
399    }
400}
401
402fn validate_method_name(method: &str) -> Result<(), RpcError> {
403    if method.trim().is_empty() {
404        return Err(project_contract_violation(
405            method,
406            RpcContractSurface::Request,
407            &RpcContractViolation::EmptyMethod,
408            &Value::Null,
409        ));
410    }
411    Ok(())
412}
413
414fn require_object<'a>(
415    value: &'a Value,
416    method: &str,
417    field_name: &str,
418) -> Result<&'a serde_json::Map<String, Value>, RpcError> {
419    require_object_on(RpcContractSurface::Request, value, method, field_name)
420}
421
422fn require_response_object<'a>(
423    value: &'a Value,
424    method: &str,
425    field_name: &str,
426) -> Result<&'a serde_json::Map<String, Value>, RpcError> {
427    require_object_on(RpcContractSurface::Response, value, method, field_name)
428}
429
430fn require_object_on<'a>(
431    surface: RpcContractSurface,
432    value: &'a Value,
433    method: &str,
434    field_name: &str,
435) -> Result<&'a serde_json::Map<String, Value>, RpcError> {
436    value.as_object().ok_or_else(|| {
437        project_contract_violation(
438            method,
439            surface,
440            &RpcContractViolation::FieldMustBeObject {
441                field_name: field_name.to_owned(),
442            },
443            value,
444        )
445    })
446}
447
448fn require_string(
449    value: &Value,
450    method: &str,
451    key: &str,
452    field_name: &str,
453) -> Result<(), RpcError> {
454    let obj = require_object(value, method, field_name)?;
455    match obj.get(key).and_then(Value::as_str) {
456        Some(v) if !v.trim().is_empty() => Ok(()),
457        _ => Err(project_contract_violation(
458            method,
459            RpcContractSurface::Request,
460            &RpcContractViolation::FieldMustBeNonEmptyString {
461                field_name: field_name.to_owned(),
462                key: key.to_owned(),
463            },
464            value,
465        )),
466    }
467}
468
469fn validate_thread_start_request(params: &Value, method: &str) -> Result<(), RpcError> {
470    require_object(params, method, FIELD_PARAMS)?;
471    Ok(())
472}
473
474fn validate_command_exec_request(params: &Value, method: &str) -> Result<(), RpcError> {
475    let obj = require_object(params, method, FIELD_PARAMS)?;
476    let command = obj
477        .get("command")
478        .and_then(Value::as_array)
479        .ok_or_else(|| {
480            project_contract_violation(
481                method,
482                RpcContractSurface::Request,
483                &RpcContractViolation::CommandMustBeArray,
484                params,
485            )
486        })?;
487    if command.is_empty() {
488        return Err(project_contract_violation(
489            method,
490            RpcContractSurface::Request,
491            &RpcContractViolation::CommandMustNotBeEmpty,
492            params,
493        ));
494    }
495    if command.iter().any(|value| value.as_str().is_none()) {
496        return Err(project_contract_violation(
497            method,
498            RpcContractSurface::Request,
499            &RpcContractViolation::CommandItemsMustBeStrings,
500            params,
501        ));
502    }
503
504    let process_id = get_optional_non_empty_string(obj, KEY_PROCESS_ID).map_err(|violation| {
505        project_contract_violation(method, RpcContractSurface::Request, &violation, params)
506    })?;
507    let tty = get_bool(obj, "tty");
508    let stream_stdin = get_bool(obj, "streamStdin");
509    let stream_stdout_stderr = get_bool(obj, "streamStdoutStderr");
510    let effective_stream_stdin = tty || stream_stdin;
511    let effective_stream_stdout_stderr = tty || stream_stdout_stderr;
512
513    if (tty || effective_stream_stdin || effective_stream_stdout_stderr) && process_id.is_none() {
514        return Err(project_contract_violation(
515            method,
516            RpcContractSurface::Request,
517            &RpcContractViolation::ProcessIdRequiredForStreaming,
518            params,
519        ));
520    }
521    if get_bool(obj, "disableOutputCap") && obj.get("outputBytesCap").is_some() {
522        return Err(project_contract_violation(
523            method,
524            RpcContractSurface::Request,
525            &RpcContractViolation::DisableOutputCapConflictsWithOutputBytesCap,
526            params,
527        ));
528    }
529    if get_bool(obj, "disableTimeout") && obj.get("timeoutMs").is_some() {
530        return Err(project_contract_violation(
531            method,
532            RpcContractSurface::Request,
533            &RpcContractViolation::DisableTimeoutConflictsWithTimeoutMs,
534            params,
535        ));
536    }
537    if let Some(timeout_ms) = obj.get("timeoutMs").and_then(Value::as_i64) {
538        if timeout_ms < 0 {
539            return Err(project_contract_violation(
540                method,
541                RpcContractSurface::Request,
542                &RpcContractViolation::TimeoutMsMustBeNonNegative,
543                params,
544            ));
545        }
546    }
547    if let Some(output_bytes_cap) = obj.get("outputBytesCap").and_then(Value::as_u64) {
548        if output_bytes_cap == 0 {
549            return Err(project_contract_violation(
550                method,
551                RpcContractSurface::Request,
552                &RpcContractViolation::OutputBytesCapMustBePositive,
553                params,
554            ));
555        }
556    }
557    if let Some(size) = obj.get(KEY_SIZE) {
558        if !tty {
559            return Err(project_contract_violation(
560                method,
561                RpcContractSurface::Request,
562                &RpcContractViolation::SizeRequiresTty,
563                params,
564            ));
565        }
566        validate_command_exec_size(size, method, params)?;
567    }
568    if let Some(sandbox_policy) = obj.get("sandboxPolicy") {
569        summarize_sandbox_policy_wire_value(sandbox_policy, FIELD_PARAMS_SANDBOX_POLICY)
570            .map_err(|reason| invalid_request(method, &reason, params))?;
571    }
572
573    Ok(())
574}
575
576fn validate_command_exec_write_request(params: &Value, method: &str) -> Result<(), RpcError> {
577    require_string(params, method, KEY_PROCESS_ID, FIELD_PARAMS)?;
578    let obj = require_object(params, method, FIELD_PARAMS)?;
579    let has_delta = obj.get("deltaBase64").and_then(Value::as_str).is_some();
580    let close_stdin = get_bool(obj, "closeStdin");
581    if !has_delta && !close_stdin {
582        return Err(project_contract_violation(
583            method,
584            RpcContractSurface::Request,
585            &RpcContractViolation::WriteRequestMustIncludeDeltaOrCloseStdin,
586            params,
587        ));
588    }
589    Ok(())
590}
591
592fn validate_command_exec_resize_request(params: &Value, method: &str) -> Result<(), RpcError> {
593    require_string(params, method, KEY_PROCESS_ID, FIELD_PARAMS)?;
594    let obj = require_object(params, method, FIELD_PARAMS)?;
595    let size = obj.get(KEY_SIZE).ok_or_else(|| {
596        project_contract_violation(
597            method,
598            RpcContractSurface::Request,
599            &RpcContractViolation::SizeMustBeObject,
600            params,
601        )
602    })?;
603    validate_command_exec_size(size, method, params)
604}
605
606fn validate_command_exec_response(result: &Value, method: &str) -> Result<(), RpcError> {
607    let obj = require_response_object(result, method, FIELD_RESULT)?;
608    match obj.get("exitCode").and_then(Value::as_i64) {
609        Some(code) if i32::try_from(code).is_ok() => {}
610        _ => {
611            return Err(project_contract_violation(
612                method,
613                RpcContractSurface::Response,
614                &RpcContractViolation::ExitCodeMustBeI32CompatibleInteger,
615                result,
616            ));
617        }
618    }
619    if obj.get("stdout").and_then(Value::as_str).is_none() {
620        return Err(project_contract_violation(
621            method,
622            RpcContractSurface::Response,
623            &RpcContractViolation::StdoutMustBeString,
624            result,
625        ));
626    }
627    if obj.get("stderr").and_then(Value::as_str).is_none() {
628        return Err(project_contract_violation(
629            method,
630            RpcContractSurface::Response,
631            &RpcContractViolation::StderrMustBeString,
632            result,
633        ));
634    }
635    Ok(())
636}
637
638fn validate_command_exec_size(size: &Value, method: &str, payload: &Value) -> Result<(), RpcError> {
639    let size_obj = size.as_object().ok_or_else(|| {
640        project_contract_violation(
641            method,
642            RpcContractSurface::Request,
643            &RpcContractViolation::SizeMustBeObject,
644            payload,
645        )
646    })?;
647    let rows = size_obj.get("rows").and_then(Value::as_u64).unwrap_or(0);
648    let cols = size_obj.get("cols").and_then(Value::as_u64).unwrap_or(0);
649    if rows == 0 {
650        return Err(project_contract_violation(
651            method,
652            RpcContractSurface::Request,
653            &RpcContractViolation::SizeRowsMustBePositive,
654            payload,
655        ));
656    }
657    if cols == 0 {
658        return Err(project_contract_violation(
659            method,
660            RpcContractSurface::Request,
661            &RpcContractViolation::SizeColsMustBePositive,
662            payload,
663        ));
664    }
665    Ok(())
666}
667
668fn get_optional_non_empty_string<'a>(
669    obj: &'a serde_json::Map<String, Value>,
670    key: &str,
671) -> Result<Option<&'a str>, RpcContractViolation> {
672    match obj.get(key) {
673        Some(Value::String(text)) if !text.trim().is_empty() => Ok(Some(text)),
674        Some(Value::String(_)) => Err(RpcContractViolation::FieldMustBeNonEmptyString {
675            field_name: FIELD_PARAMS.to_owned(),
676            key: key.to_owned(),
677        }),
678        Some(_) => Err(RpcContractViolation::ParamsFieldMustBeString {
679            key: key.to_owned(),
680        }),
681        None => Ok(None),
682    }
683}
684
685fn get_bool(obj: &serde_json::Map<String, Value>, key: &str) -> bool {
686    obj.get(key).and_then(Value::as_bool).unwrap_or(false)
687}
688
689fn invalid_request(method: &str, reason: &str, payload: &Value) -> RpcError {
690    project_contract_violation(
691        method,
692        RpcContractSurface::Request,
693        &RpcContractViolation::Custom(reason.to_owned()),
694        payload,
695    )
696}
697
698fn project_contract_violation(
699    method: &str,
700    surface: RpcContractSurface,
701    violation: &RpcContractViolation,
702    payload: &Value,
703) -> RpcError {
704    let side = match surface {
705        RpcContractSurface::Request => "request",
706        RpcContractSurface::Response => "response",
707    };
708    RpcError::InvalidRequest(format!(
709        "invalid json-rpc {side} for {method}: {}; payload={}",
710        violation.reason(),
711        payload_summary(payload),
712    ))
713}
714
715pub(crate) fn payload_summary(payload: &Value) -> String {
716    const MAX_KEYS: usize = 6;
717    match payload {
718        Value::Object(map) => {
719            let mut keys: Vec<&str> = map.keys().map(|key| key.as_str()).collect();
720            keys.sort_unstable();
721            let preview: Vec<&str> = keys.into_iter().take(MAX_KEYS).collect();
722            let more = if map.len() > MAX_KEYS { ",..." } else { "" };
723            format!("object(keys=[{}{}])", preview.join(","), more)
724        }
725        Value::Array(items) => format!("array(len={})", items.len()),
726        Value::String(text) => format!("string(len={})", text.len()),
727        Value::Number(_) => "number".to_owned(),
728        Value::Bool(_) => "bool".to_owned(),
729        Value::Null => "null".to_owned(),
730    }
731}
732
733#[cfg(test)]
734mod tests {
735    use super::*;
736    use serde_json::json;
737
738    #[test]
739    fn rejects_empty_method() {
740        let err = validate_rpc_request("", &json!({}), RpcValidationMode::KnownMethods)
741            .expect_err("empty method must fail");
742        assert!(matches!(err, RpcError::InvalidRequest(_)));
743    }
744
745    #[test]
746    fn validates_turn_interrupt_params_shape() {
747        let err = validate_rpc_request(
748            "turn/interrupt",
749            &json!({"threadId":"thr"}),
750            RpcValidationMode::KnownMethods,
751        )
752        .expect_err("missing turnId must fail");
753        assert!(matches!(err, RpcError::InvalidRequest(_)));
754
755        validate_rpc_request(
756            "turn/interrupt",
757            &json!({"threadId":"thr", "turnId":"turn"}),
758            RpcValidationMode::KnownMethods,
759        )
760        .expect("valid params");
761    }
762
763    #[test]
764    fn validates_thread_start_response_thread_id() {
765        let err = validate_rpc_response(
766            "thread/start",
767            &json!({"thread": {}}),
768            RpcValidationMode::KnownMethods,
769        )
770        .expect_err("missing thread id must fail");
771        assert!(matches!(err, RpcError::InvalidRequest(_)));
772
773        validate_rpc_response(
774            "thread/start",
775            &json!({"thread": {"id":"thr_1"}}),
776            RpcValidationMode::KnownMethods,
777        )
778        .expect("valid response");
779    }
780
781    #[test]
782    fn validates_turn_start_response_turn_id() {
783        let err = validate_rpc_response(
784            "turn/start",
785            &json!({"turn": {}}),
786            RpcValidationMode::KnownMethods,
787        )
788        .expect_err("missing turn id must fail");
789        assert!(matches!(err, RpcError::InvalidRequest(_)));
790
791        validate_rpc_response(
792            "turn/start",
793            &json!({"turn": {"id":"turn_1"}}),
794            RpcValidationMode::KnownMethods,
795        )
796        .expect("valid response");
797    }
798
799    #[test]
800    fn validates_skills_list_response_shape() {
801        let err = validate_rpc_response(
802            "skills/list",
803            &json!({"skills":[]}),
804            RpcValidationMode::KnownMethods,
805        )
806        .expect_err("missing result.data must fail");
807        assert!(matches!(err, RpcError::InvalidRequest(_)));
808
809        validate_rpc_response(
810            "skills/list",
811            &json!({"data":[]}),
812            RpcValidationMode::KnownMethods,
813        )
814        .expect("valid response");
815    }
816
817    #[test]
818    fn validates_command_exec_request_constraints() {
819        let err = validate_rpc_request(
820            "command/exec",
821            &json!({"command":["bash"],"tty":true}),
822            RpcValidationMode::KnownMethods,
823        )
824        .expect_err("tty without processId must fail");
825        assert!(matches!(err, RpcError::InvalidRequest(_)));
826
827        let err = validate_rpc_request(
828            "command/exec",
829            &json!({"command":["bash"],"disableTimeout":true,"timeoutMs":1}),
830            RpcValidationMode::KnownMethods,
831        )
832        .expect_err("disableTimeout + timeoutMs must fail");
833        assert!(matches!(err, RpcError::InvalidRequest(_)));
834
835        validate_rpc_request(
836            "command/exec",
837            &json!({"command":["bash"],"processId":"proc-1","tty":true}),
838            RpcValidationMode::KnownMethods,
839        )
840        .expect("tty with processId should pass");
841    }
842
843    #[test]
844    fn validates_command_exec_request_rejects_non_string_process_id() {
845        let err = validate_rpc_request(
846            "command/exec",
847            &json!({"command":["bash"],"processId":123}),
848            RpcValidationMode::KnownMethods,
849        )
850        .expect_err("non-string processId must fail");
851
852        let RpcError::InvalidRequest(message) = err else {
853            panic!("expected invalid request");
854        };
855        assert!(message.contains("params.processId must be a string"));
856    }
857
858    #[test]
859    fn validates_command_exec_response_shape() {
860        let err = validate_rpc_response(
861            "command/exec",
862            &json!({"exitCode":0,"stdout":"ok"}),
863            RpcValidationMode::KnownMethods,
864        )
865        .expect_err("stderr missing must fail");
866        assert!(matches!(err, RpcError::InvalidRequest(_)));
867
868        validate_rpc_response(
869            "command/exec",
870            &json!({"exitCode":0,"stdout":"ok","stderr":""}),
871            RpcValidationMode::KnownMethods,
872        )
873        .expect("valid command exec response");
874    }
875
876    #[test]
877    fn passes_unknown_method_in_known_mode() {
878        validate_rpc_request(
879            "echo/custom",
880            &json!({"k":"v"}),
881            RpcValidationMode::KnownMethods,
882        )
883        .expect("unknown method request should pass");
884        validate_rpc_response(
885            "echo/custom",
886            &json!({"ok":true}),
887            RpcValidationMode::KnownMethods,
888        )
889        .expect("unknown method response should pass");
890    }
891
892    #[test]
893    fn known_method_catalog_is_stable() {
894        assert_eq!(
895            methods::KNOWN,
896            [
897                methods::THREAD_START,
898                methods::THREAD_RESUME,
899                methods::THREAD_FORK,
900                methods::THREAD_ARCHIVE,
901                methods::THREAD_READ,
902                methods::THREAD_LIST,
903                methods::THREAD_LOADED_LIST,
904                methods::THREAD_ROLLBACK,
905                methods::SKILLS_LIST,
906                methods::COMMAND_EXEC,
907                methods::COMMAND_EXEC_WRITE,
908                methods::COMMAND_EXEC_TERMINATE,
909                methods::COMMAND_EXEC_RESIZE,
910                methods::TURN_START,
911                methods::TURN_INTERRUPT,
912            ]
913        );
914    }
915
916    #[test]
917    fn descriptor_catalog_matches_known_method_catalog() {
918        let descriptor_methods: Vec<&'static str> = rpc_contract_descriptors()
919            .iter()
920            .map(|descriptor| descriptor.method)
921            .collect();
922        assert_eq!(descriptor_methods, methods::KNOWN);
923    }
924
925    #[test]
926    fn default_validation_mode_is_known_methods() {
927        assert_eq!(
928            RpcValidationMode::default(),
929            RpcValidationMode::KnownMethods
930        );
931    }
932
933    #[test]
934    fn skips_validation_in_none_mode() {
935        validate_rpc_request("", &json!(null), RpcValidationMode::None)
936            .expect_err("empty method must still fail");
937
938        validate_rpc_request("turn/start", &json!(null), RpcValidationMode::None)
939            .expect("none mode skips params shape");
940        validate_rpc_response("turn/start", &json!(null), RpcValidationMode::None)
941            .expect("none mode skips result shape");
942    }
943
944    #[test]
945    fn invalid_request_error_redacts_payload_values() {
946        let err = validate_rpc_request(
947            "turn/interrupt",
948            &json!({"threadId":"thr_sensitive","secret":"token-123"}),
949            RpcValidationMode::KnownMethods,
950        )
951        .expect_err("missing turnId must fail");
952
953        let RpcError::InvalidRequest(message) = err else {
954            panic!("expected invalid request");
955        };
956        assert!(message.contains("invalid json-rpc request for turn/interrupt"));
957        assert!(message.contains("params.turnId must be a non-empty string"));
958        assert!(message.contains("payload=object(keys=[secret,threadId])"));
959        assert!(!message.contains("token-123"));
960        assert!(!message.contains("thr_sensitive"));
961    }
962
963    #[test]
964    fn invalid_response_error_redacts_payload_values() {
965        let err = validate_rpc_response(
966            "thread/start",
967            &json!({"thread": {}, "secret": {"token":"abc"}}),
968            RpcValidationMode::KnownMethods,
969        )
970        .expect_err("missing thread id must fail");
971
972        let RpcError::InvalidRequest(message) = err else {
973            panic!("expected invalid request");
974        };
975        assert!(message.contains("invalid json-rpc response for thread/start"));
976        assert!(message.contains("result is missing thread id"));
977        assert!(message.contains("payload=object(keys=[secret,thread])"));
978        assert!(!message.contains("abc"));
979    }
980
981    #[test]
982    fn rejects_response_scalar_id_fallback() {
983        let err = validate_rpc_response(
984            "thread/start",
985            &json!("thr_scalar"),
986            RpcValidationMode::KnownMethods,
987        )
988        .expect_err("scalar id fallback must not be accepted");
989        assert!(matches!(err, RpcError::InvalidRequest(_)));
990    }
991}