1use serde_json::Value;
2
3use crate::runtime::api::summarize_sandbox_policy_wire_value;
4use crate::runtime::errors::RpcError;
5use crate::runtime::turn_output::{parse_thread_id, parse_turn_id};
6
/// String constants for the JSON-RPC method names this module knows about.
pub mod methods {
    // Method names that are listed in `KNOWN` below, in the same order.
    pub const THREAD_START: &str = "thread/start";
    pub const THREAD_RESUME: &str = "thread/resume";
    pub const THREAD_FORK: &str = "thread/fork";
    pub const THREAD_ARCHIVE: &str = "thread/archive";
    pub const THREAD_READ: &str = "thread/read";
    pub const THREAD_LIST: &str = "thread/list";
    pub const THREAD_LOADED_LIST: &str = "thread/loaded/list";
    pub const THREAD_ROLLBACK: &str = "thread/rollback";
    pub const SKILLS_LIST: &str = "skills/list";
    pub const COMMAND_EXEC: &str = "command/exec";
    pub const COMMAND_EXEC_WRITE: &str = "command/exec/write";
    pub const COMMAND_EXEC_TERMINATE: &str = "command/exec/terminate";
    pub const COMMAND_EXEC_RESIZE: &str = "command/exec/resize";
    pub const TURN_START: &str = "turn/start";
    pub const TURN_INTERRUPT: &str = "turn/interrupt";

    // Names not listed in `KNOWN`.
    // NOTE(review): judging by the names these look like peer-initiated
    // requests (approval / user input / token refresh) — confirm against the protocol.
    pub const ITEM_COMMAND_EXECUTION_REQUEST_APPROVAL: &str =
        "item/commandExecution/requestApproval";
    pub const ITEM_FILE_CHANGE_REQUEST_APPROVAL: &str = "item/fileChange/requestApproval";
    pub const ITEM_TOOL_REQUEST_USER_INPUT: &str = "item/tool/requestUserInput";
    pub const ITEM_TOOL_CALL: &str = "item/tool/call";
    pub const ACCOUNT_CHATGPT_AUTH_TOKENS_REFRESH: &str = "account/chatgptAuthTokens/refresh";

    // Names not listed in `KNOWN`.
    // NOTE(review): presumably notification/event names — confirm against the protocol.
    pub const THREAD_STARTED: &str = "thread/started";
    pub const TURN_STARTED: &str = "turn/started";
    pub const TURN_COMPLETED: &str = "turn/completed";
    pub const TURN_FAILED: &str = "turn/failed";
    pub const TURN_CANCELLED: &str = "turn/cancelled";
    pub const TURN_INTERRUPTED: &str = "turn/interrupted";
    pub const TURN_DIFF_UPDATED: &str = "turn/diff/updated";
    pub const TURN_PLAN_UPDATED: &str = "turn/plan/updated";
    pub const ITEM_STARTED: &str = "item/started";
    pub const ITEM_AGENT_MESSAGE_DELTA: &str = "item/agentMessage/delta";
    pub const ITEM_COMMAND_EXECUTION_OUTPUT_DELTA: &str = "item/commandExecution/outputDelta";
    pub const COMMAND_EXEC_OUTPUT_DELTA: &str = "command/exec/outputDelta";
    pub const ITEM_COMPLETED: &str = "item/completed";
    pub const APPROVAL_ACK: &str = "approval/ack";
    pub const SKILLS_CHANGED: &str = "skills/changed";

    /// The fifteen method names that have a contract descriptor.
    ///
    /// The order is significant: the module's tests compare this array
    /// element-by-element with the descriptor table.
    pub const KNOWN: [&str; 15] = [
        THREAD_START,
        THREAD_RESUME,
        THREAD_FORK,
        THREAD_ARCHIVE,
        THREAD_READ,
        THREAD_LIST,
        THREAD_LOADED_LIST,
        THREAD_ROLLBACK,
        SKILLS_LIST,
        COMMAND_EXEC,
        COMMAND_EXEC_WRITE,
        COMMAND_EXEC_TERMINATE,
        COMMAND_EXEC_RESIZE,
        TURN_START,
        TURN_INTERRUPT,
    ];
}
68
/// Controls how much checking `validate_rpc_request` and
/// `validate_rpc_response` perform.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum RpcValidationMode {
    /// Only the method name is checked; payload shapes are not inspected.
    None,
    /// Payloads of methods that have a contract descriptor are validated;
    /// methods without a descriptor pass through. This is the default.
    #[default]
    KnownMethods,
}
78
/// Shape rule applied to the `params` of a request.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RpcRequestContract {
    /// `params` must be a JSON object.
    Object,
    /// Checked by `validate_thread_start_request` (`params` must be an object).
    ThreadStart,
    /// `params.threadId` must be a non-empty string.
    ThreadId,
    /// `params.threadId` and `params.turnId` must both be non-empty strings.
    ThreadIdAndTurnId,
    /// `params.processId` must be a non-empty string.
    ProcessId,
    /// Checked by `validate_command_exec_request`.
    CommandExec,
    /// Checked by `validate_command_exec_write_request`.
    CommandExecWrite,
    /// Checked by `validate_command_exec_resize_request`.
    CommandExecResize,
}
91
/// Shape rule applied to the `result` of a response.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RpcResponseContract {
    /// `result` must be a JSON object.
    Object,
    /// `parse_thread_id` must be able to extract a thread id from `result`.
    ThreadId,
    /// `parse_turn_id` must be able to extract a turn id from `result`.
    TurnId,
    /// `result` must be an object whose `data` member is an array.
    DataArray,
    /// Checked by `validate_command_exec_response`.
    CommandExec,
}
101
/// Associates one method name with its request and response shape rules.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct RpcContractDescriptor {
    /// Wire-level method name, e.g. `"thread/start"`.
    pub method: &'static str,
    /// Rule applied to the request `params`.
    pub request: RpcRequestContract,
    /// Rule applied to the response `result`.
    pub response: RpcResponseContract,
}
109
// Field labels that appear in violation messages.
const FIELD_PARAMS: &str = "params";
const FIELD_RESULT: &str = "result";
// Label passed to `summarize_sandbox_policy_wire_value` for the sandbox policy value.
const FIELD_PARAMS_SANDBOX_POLICY: &str = "params.sandboxPolicy";
// JSON object keys looked up during validation.
const KEY_DATA: &str = "data";
const KEY_PROCESS_ID: &str = "processId";
const KEY_SIZE: &str = "size";
116
117const RPC_CONTRACT_DESCRIPTORS: [RpcContractDescriptor; 15] = [
118 RpcContractDescriptor {
119 method: methods::THREAD_START,
120 request: RpcRequestContract::ThreadStart,
121 response: RpcResponseContract::ThreadId,
122 },
123 RpcContractDescriptor {
124 method: methods::THREAD_RESUME,
125 request: RpcRequestContract::ThreadId,
126 response: RpcResponseContract::ThreadId,
127 },
128 RpcContractDescriptor {
129 method: methods::THREAD_FORK,
130 request: RpcRequestContract::ThreadId,
131 response: RpcResponseContract::ThreadId,
132 },
133 RpcContractDescriptor {
134 method: methods::THREAD_ARCHIVE,
135 request: RpcRequestContract::ThreadId,
136 response: RpcResponseContract::Object,
137 },
138 RpcContractDescriptor {
139 method: methods::THREAD_READ,
140 request: RpcRequestContract::ThreadId,
141 response: RpcResponseContract::ThreadId,
142 },
143 RpcContractDescriptor {
144 method: methods::THREAD_LIST,
145 request: RpcRequestContract::Object,
146 response: RpcResponseContract::DataArray,
147 },
148 RpcContractDescriptor {
149 method: methods::THREAD_LOADED_LIST,
150 request: RpcRequestContract::Object,
151 response: RpcResponseContract::DataArray,
152 },
153 RpcContractDescriptor {
154 method: methods::THREAD_ROLLBACK,
155 request: RpcRequestContract::ThreadId,
156 response: RpcResponseContract::ThreadId,
157 },
158 RpcContractDescriptor {
159 method: methods::SKILLS_LIST,
160 request: RpcRequestContract::Object,
161 response: RpcResponseContract::DataArray,
162 },
163 RpcContractDescriptor {
164 method: methods::COMMAND_EXEC,
165 request: RpcRequestContract::CommandExec,
166 response: RpcResponseContract::CommandExec,
167 },
168 RpcContractDescriptor {
169 method: methods::COMMAND_EXEC_WRITE,
170 request: RpcRequestContract::CommandExecWrite,
171 response: RpcResponseContract::Object,
172 },
173 RpcContractDescriptor {
174 method: methods::COMMAND_EXEC_TERMINATE,
175 request: RpcRequestContract::ProcessId,
176 response: RpcResponseContract::Object,
177 },
178 RpcContractDescriptor {
179 method: methods::COMMAND_EXEC_RESIZE,
180 request: RpcRequestContract::CommandExecResize,
181 response: RpcResponseContract::Object,
182 },
183 RpcContractDescriptor {
184 method: methods::TURN_START,
185 request: RpcRequestContract::ThreadId,
186 response: RpcResponseContract::TurnId,
187 },
188 RpcContractDescriptor {
189 method: methods::TURN_INTERRUPT,
190 request: RpcRequestContract::ThreadIdAndTurnId,
191 response: RpcResponseContract::Object,
192 },
193];
194
195pub fn rpc_contract_descriptors() -> &'static [RpcContractDescriptor] {
197 &RPC_CONTRACT_DESCRIPTORS
198}
199
200pub fn rpc_contract_descriptor(method: &str) -> Option<&'static RpcContractDescriptor> {
202 RPC_CONTRACT_DESCRIPTORS
203 .iter()
204 .find(|descriptor| descriptor.method == method)
205}
206
207pub fn validate_rpc_request(
212 method: &str,
213 params: &Value,
214 mode: RpcValidationMode,
215) -> Result<(), RpcError> {
216 validate_method_name(method)?;
217
218 if mode == RpcValidationMode::None {
219 return Ok(());
220 }
221
222 match rpc_contract_descriptor(method) {
223 Some(descriptor) => validate_request_by_descriptor(method, params, *descriptor),
224 None => Ok(()),
225 }
226}
227
228pub fn validate_rpc_response(
232 method: &str,
233 result: &Value,
234 mode: RpcValidationMode,
235) -> Result<(), RpcError> {
236 validate_method_name(method)?;
237
238 if mode == RpcValidationMode::None {
239 return Ok(());
240 }
241
242 match rpc_contract_descriptor(method) {
243 Some(descriptor) => validate_response_by_descriptor(method, result, *descriptor),
244 None => Ok(()),
245 }
246}
247
/// Which side of an exchange a violation refers to; rendered as "request" or
/// "response" in the error message built by `project_contract_violation`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum RpcContractSurface {
    Request,
    Response,
}
253
/// A single contract rule that a payload failed.
///
/// The human-readable text for each variant is produced by
/// `RpcContractViolation::reason`.
#[derive(Clone, Debug, PartialEq, Eq)]
enum RpcContractViolation {
    /// The method name was empty or whitespace-only.
    EmptyMethod,
    /// The named field (e.g. `params` or `result`) was not a JSON object.
    FieldMustBeObject { field_name: String },
    /// `<field_name>.<key>` was missing, not a string, or blank.
    FieldMustBeNonEmptyString { field_name: String, key: String },
    MissingThreadId,
    MissingTurnId,
    ResultDataMustBeArray,
    CommandMustBeArray,
    CommandMustNotBeEmpty,
    CommandItemsMustBeStrings,
    ProcessIdRequiredForStreaming,
    DisableOutputCapConflictsWithOutputBytesCap,
    DisableTimeoutConflictsWithTimeoutMs,
    TimeoutMsMustBeNonNegative,
    OutputBytesCapMustBePositive,
    SizeRequiresTty,
    SizeMustBeObject,
    SizeRowsMustBePositive,
    SizeColsMustBePositive,
    WriteRequestMustIncludeDeltaOrCloseStdin,
    ExitCodeMustBeI32CompatibleInteger,
    StdoutMustBeString,
    StderrMustBeString,
    /// `params.<key>` was present but not a string.
    ParamsFieldMustBeString { key: String },
    /// Free-form reason text supplied by the caller (see `invalid_request`).
    Custom(String),
}
281
282impl RpcContractViolation {
283 fn reason(&self) -> String {
284 match self {
285 Self::EmptyMethod => "json-rpc method must not be empty".to_owned(),
286 Self::FieldMustBeObject { field_name } => format!("{field_name} must be an object"),
287 Self::FieldMustBeNonEmptyString { field_name, key } => {
288 format!("{field_name}.{key} must be a non-empty string")
289 }
290 Self::MissingThreadId => "result is missing thread id".to_owned(),
291 Self::MissingTurnId => "result is missing turn id".to_owned(),
292 Self::ResultDataMustBeArray => "result.data must be an array".to_owned(),
293 Self::CommandMustBeArray => "params.command must be an array".to_owned(),
294 Self::CommandMustNotBeEmpty => "params.command must not be empty".to_owned(),
295 Self::CommandItemsMustBeStrings => "params.command items must be strings".to_owned(),
296 Self::ProcessIdRequiredForStreaming => {
297 "params.processId is required when tty or streaming is enabled".to_owned()
298 }
299 Self::DisableOutputCapConflictsWithOutputBytesCap => {
300 "params.disableOutputCap cannot be combined with params.outputBytesCap".to_owned()
301 }
302 Self::DisableTimeoutConflictsWithTimeoutMs => {
303 "params.disableTimeout cannot be combined with params.timeoutMs".to_owned()
304 }
305 Self::TimeoutMsMustBeNonNegative => "params.timeoutMs must be >= 0".to_owned(),
306 Self::OutputBytesCapMustBePositive => "params.outputBytesCap must be > 0".to_owned(),
307 Self::SizeRequiresTty => "params.size is only valid when params.tty is true".to_owned(),
308 Self::SizeMustBeObject => "params.size must be an object".to_owned(),
309 Self::SizeRowsMustBePositive => "params.size.rows must be > 0".to_owned(),
310 Self::SizeColsMustBePositive => "params.size.cols must be > 0".to_owned(),
311 Self::WriteRequestMustIncludeDeltaOrCloseStdin => {
312 "params must include deltaBase64, closeStdin, or both".to_owned()
313 }
314 Self::ExitCodeMustBeI32CompatibleInteger => {
315 "result.exitCode must be an i32-compatible integer".to_owned()
316 }
317 Self::StdoutMustBeString => "result.stdout must be a string".to_owned(),
318 Self::StderrMustBeString => "result.stderr must be a string".to_owned(),
319 Self::ParamsFieldMustBeString { key } => format!("params.{key} must be a string"),
320 Self::Custom(reason) => reason.clone(),
321 }
322 }
323}
324
325fn validate_request_by_descriptor(
326 method: &str,
327 params: &Value,
328 descriptor: RpcContractDescriptor,
329) -> Result<(), RpcError> {
330 match descriptor.request {
331 RpcRequestContract::Object => {
332 require_object(params, method, FIELD_PARAMS)?;
333 Ok(())
334 }
335 RpcRequestContract::ThreadStart => validate_thread_start_request(params, method),
336 RpcRequestContract::ThreadId => require_string(params, method, "threadId", FIELD_PARAMS),
337 RpcRequestContract::ThreadIdAndTurnId => {
338 require_string(params, method, "threadId", FIELD_PARAMS)?;
339 require_string(params, method, "turnId", FIELD_PARAMS)
340 }
341 RpcRequestContract::ProcessId => {
342 require_string(params, method, KEY_PROCESS_ID, FIELD_PARAMS)
343 }
344 RpcRequestContract::CommandExec => validate_command_exec_request(params, method),
345 RpcRequestContract::CommandExecWrite => validate_command_exec_write_request(params, method),
346 RpcRequestContract::CommandExecResize => {
347 validate_command_exec_resize_request(params, method)
348 }
349 }
350}
351
352fn validate_response_by_descriptor(
353 method: &str,
354 result: &Value,
355 descriptor: RpcContractDescriptor,
356) -> Result<(), RpcError> {
357 match descriptor.response {
358 RpcResponseContract::Object => {
359 require_response_object(result, method, FIELD_RESULT)?;
360 Ok(())
361 }
362 RpcResponseContract::ThreadId => {
363 if parse_thread_id(result).is_none() {
364 Err(project_contract_violation(
365 method,
366 RpcContractSurface::Response,
367 &RpcContractViolation::MissingThreadId,
368 result,
369 ))
370 } else {
371 Ok(())
372 }
373 }
374 RpcResponseContract::TurnId => {
375 if parse_turn_id(result).is_none() {
376 Err(project_contract_violation(
377 method,
378 RpcContractSurface::Response,
379 &RpcContractViolation::MissingTurnId,
380 result,
381 ))
382 } else {
383 Ok(())
384 }
385 }
386 RpcResponseContract::DataArray => {
387 let obj = require_response_object(result, method, FIELD_RESULT)?;
388 match obj.get(KEY_DATA) {
389 Some(Value::Array(_)) => Ok(()),
390 _ => Err(project_contract_violation(
391 method,
392 RpcContractSurface::Response,
393 &RpcContractViolation::ResultDataMustBeArray,
394 result,
395 )),
396 }
397 }
398 RpcResponseContract::CommandExec => validate_command_exec_response(result, method),
399 }
400}
401
402fn validate_method_name(method: &str) -> Result<(), RpcError> {
403 if method.trim().is_empty() {
404 return Err(project_contract_violation(
405 method,
406 RpcContractSurface::Request,
407 &RpcContractViolation::EmptyMethod,
408 &Value::Null,
409 ));
410 }
411 Ok(())
412}
413
/// Request-side variant of `require_object_on`: returns the object map of
/// `value`, or a request violation naming `field_name` when `value` is not a
/// JSON object.
fn require_object<'a>(
    value: &'a Value,
    method: &str,
    field_name: &str,
) -> Result<&'a serde_json::Map<String, Value>, RpcError> {
    require_object_on(RpcContractSurface::Request, value, method, field_name)
}
421
/// Response-side variant of `require_object_on`: returns the object map of
/// `value`, or a response violation naming `field_name` when `value` is not a
/// JSON object.
fn require_response_object<'a>(
    value: &'a Value,
    method: &str,
    field_name: &str,
) -> Result<&'a serde_json::Map<String, Value>, RpcError> {
    require_object_on(RpcContractSurface::Response, value, method, field_name)
}
429
430fn require_object_on<'a>(
431 surface: RpcContractSurface,
432 value: &'a Value,
433 method: &str,
434 field_name: &str,
435) -> Result<&'a serde_json::Map<String, Value>, RpcError> {
436 value.as_object().ok_or_else(|| {
437 project_contract_violation(
438 method,
439 surface,
440 &RpcContractViolation::FieldMustBeObject {
441 field_name: field_name.to_owned(),
442 },
443 value,
444 )
445 })
446}
447
448fn require_string(
449 value: &Value,
450 method: &str,
451 key: &str,
452 field_name: &str,
453) -> Result<(), RpcError> {
454 let obj = require_object(value, method, field_name)?;
455 match obj.get(key).and_then(Value::as_str) {
456 Some(v) if !v.trim().is_empty() => Ok(()),
457 _ => Err(project_contract_violation(
458 method,
459 RpcContractSurface::Request,
460 &RpcContractViolation::FieldMustBeNonEmptyString {
461 field_name: field_name.to_owned(),
462 key: key.to_owned(),
463 },
464 value,
465 )),
466 }
467}
468
469fn validate_thread_start_request(params: &Value, method: &str) -> Result<(), RpcError> {
470 require_object(params, method, FIELD_PARAMS)?;
471 Ok(())
472}
473
474fn validate_command_exec_request(params: &Value, method: &str) -> Result<(), RpcError> {
475 let obj = require_object(params, method, FIELD_PARAMS)?;
476 let command = obj
477 .get("command")
478 .and_then(Value::as_array)
479 .ok_or_else(|| {
480 project_contract_violation(
481 method,
482 RpcContractSurface::Request,
483 &RpcContractViolation::CommandMustBeArray,
484 params,
485 )
486 })?;
487 if command.is_empty() {
488 return Err(project_contract_violation(
489 method,
490 RpcContractSurface::Request,
491 &RpcContractViolation::CommandMustNotBeEmpty,
492 params,
493 ));
494 }
495 if command.iter().any(|value| value.as_str().is_none()) {
496 return Err(project_contract_violation(
497 method,
498 RpcContractSurface::Request,
499 &RpcContractViolation::CommandItemsMustBeStrings,
500 params,
501 ));
502 }
503
504 let process_id = get_optional_non_empty_string(obj, KEY_PROCESS_ID).map_err(|violation| {
505 project_contract_violation(method, RpcContractSurface::Request, &violation, params)
506 })?;
507 let tty = get_bool(obj, "tty");
508 let stream_stdin = get_bool(obj, "streamStdin");
509 let stream_stdout_stderr = get_bool(obj, "streamStdoutStderr");
510 let effective_stream_stdin = tty || stream_stdin;
511 let effective_stream_stdout_stderr = tty || stream_stdout_stderr;
512
513 if (tty || effective_stream_stdin || effective_stream_stdout_stderr) && process_id.is_none() {
514 return Err(project_contract_violation(
515 method,
516 RpcContractSurface::Request,
517 &RpcContractViolation::ProcessIdRequiredForStreaming,
518 params,
519 ));
520 }
521 if get_bool(obj, "disableOutputCap") && obj.get("outputBytesCap").is_some() {
522 return Err(project_contract_violation(
523 method,
524 RpcContractSurface::Request,
525 &RpcContractViolation::DisableOutputCapConflictsWithOutputBytesCap,
526 params,
527 ));
528 }
529 if get_bool(obj, "disableTimeout") && obj.get("timeoutMs").is_some() {
530 return Err(project_contract_violation(
531 method,
532 RpcContractSurface::Request,
533 &RpcContractViolation::DisableTimeoutConflictsWithTimeoutMs,
534 params,
535 ));
536 }
537 if let Some(timeout_ms) = obj.get("timeoutMs").and_then(Value::as_i64) {
538 if timeout_ms < 0 {
539 return Err(project_contract_violation(
540 method,
541 RpcContractSurface::Request,
542 &RpcContractViolation::TimeoutMsMustBeNonNegative,
543 params,
544 ));
545 }
546 }
547 if let Some(output_bytes_cap) = obj.get("outputBytesCap").and_then(Value::as_u64) {
548 if output_bytes_cap == 0 {
549 return Err(project_contract_violation(
550 method,
551 RpcContractSurface::Request,
552 &RpcContractViolation::OutputBytesCapMustBePositive,
553 params,
554 ));
555 }
556 }
557 if let Some(size) = obj.get(KEY_SIZE) {
558 if !tty {
559 return Err(project_contract_violation(
560 method,
561 RpcContractSurface::Request,
562 &RpcContractViolation::SizeRequiresTty,
563 params,
564 ));
565 }
566 validate_command_exec_size(size, method, params)?;
567 }
568 if let Some(sandbox_policy) = obj.get("sandboxPolicy") {
569 summarize_sandbox_policy_wire_value(sandbox_policy, FIELD_PARAMS_SANDBOX_POLICY)
570 .map_err(|reason| invalid_request(method, &reason, params))?;
571 }
572
573 Ok(())
574}
575
576fn validate_command_exec_write_request(params: &Value, method: &str) -> Result<(), RpcError> {
577 require_string(params, method, KEY_PROCESS_ID, FIELD_PARAMS)?;
578 let obj = require_object(params, method, FIELD_PARAMS)?;
579 let has_delta = obj.get("deltaBase64").and_then(Value::as_str).is_some();
580 let close_stdin = get_bool(obj, "closeStdin");
581 if !has_delta && !close_stdin {
582 return Err(project_contract_violation(
583 method,
584 RpcContractSurface::Request,
585 &RpcContractViolation::WriteRequestMustIncludeDeltaOrCloseStdin,
586 params,
587 ));
588 }
589 Ok(())
590}
591
592fn validate_command_exec_resize_request(params: &Value, method: &str) -> Result<(), RpcError> {
593 require_string(params, method, KEY_PROCESS_ID, FIELD_PARAMS)?;
594 let obj = require_object(params, method, FIELD_PARAMS)?;
595 let size = obj.get(KEY_SIZE).ok_or_else(|| {
596 project_contract_violation(
597 method,
598 RpcContractSurface::Request,
599 &RpcContractViolation::SizeMustBeObject,
600 params,
601 )
602 })?;
603 validate_command_exec_size(size, method, params)
604}
605
606fn validate_command_exec_response(result: &Value, method: &str) -> Result<(), RpcError> {
607 let obj = require_response_object(result, method, FIELD_RESULT)?;
608 match obj.get("exitCode").and_then(Value::as_i64) {
609 Some(code) if i32::try_from(code).is_ok() => {}
610 _ => {
611 return Err(project_contract_violation(
612 method,
613 RpcContractSurface::Response,
614 &RpcContractViolation::ExitCodeMustBeI32CompatibleInteger,
615 result,
616 ));
617 }
618 }
619 if obj.get("stdout").and_then(Value::as_str).is_none() {
620 return Err(project_contract_violation(
621 method,
622 RpcContractSurface::Response,
623 &RpcContractViolation::StdoutMustBeString,
624 result,
625 ));
626 }
627 if obj.get("stderr").and_then(Value::as_str).is_none() {
628 return Err(project_contract_violation(
629 method,
630 RpcContractSurface::Response,
631 &RpcContractViolation::StderrMustBeString,
632 result,
633 ));
634 }
635 Ok(())
636}
637
638fn validate_command_exec_size(size: &Value, method: &str, payload: &Value) -> Result<(), RpcError> {
639 let size_obj = size.as_object().ok_or_else(|| {
640 project_contract_violation(
641 method,
642 RpcContractSurface::Request,
643 &RpcContractViolation::SizeMustBeObject,
644 payload,
645 )
646 })?;
647 let rows = size_obj.get("rows").and_then(Value::as_u64).unwrap_or(0);
648 let cols = size_obj.get("cols").and_then(Value::as_u64).unwrap_or(0);
649 if rows == 0 {
650 return Err(project_contract_violation(
651 method,
652 RpcContractSurface::Request,
653 &RpcContractViolation::SizeRowsMustBePositive,
654 payload,
655 ));
656 }
657 if cols == 0 {
658 return Err(project_contract_violation(
659 method,
660 RpcContractSurface::Request,
661 &RpcContractViolation::SizeColsMustBePositive,
662 payload,
663 ));
664 }
665 Ok(())
666}
667
668fn get_optional_non_empty_string<'a>(
669 obj: &'a serde_json::Map<String, Value>,
670 key: &str,
671) -> Result<Option<&'a str>, RpcContractViolation> {
672 match obj.get(key) {
673 Some(Value::String(text)) if !text.trim().is_empty() => Ok(Some(text)),
674 Some(Value::String(_)) => Err(RpcContractViolation::FieldMustBeNonEmptyString {
675 field_name: FIELD_PARAMS.to_owned(),
676 key: key.to_owned(),
677 }),
678 Some(_) => Err(RpcContractViolation::ParamsFieldMustBeString {
679 key: key.to_owned(),
680 }),
681 None => Ok(None),
682 }
683}
684
685fn get_bool(obj: &serde_json::Map<String, Value>, key: &str) -> bool {
686 obj.get(key).and_then(Value::as_bool).unwrap_or(false)
687}
688
689fn invalid_request(method: &str, reason: &str, payload: &Value) -> RpcError {
690 project_contract_violation(
691 method,
692 RpcContractSurface::Request,
693 &RpcContractViolation::Custom(reason.to_owned()),
694 payload,
695 )
696}
697
698fn project_contract_violation(
699 method: &str,
700 surface: RpcContractSurface,
701 violation: &RpcContractViolation,
702 payload: &Value,
703) -> RpcError {
704 let side = match surface {
705 RpcContractSurface::Request => "request",
706 RpcContractSurface::Response => "response",
707 };
708 RpcError::InvalidRequest(format!(
709 "invalid json-rpc {side} for {method}: {}; payload={}",
710 violation.reason(),
711 payload_summary(payload),
712 ))
713}
714
715pub(crate) fn payload_summary(payload: &Value) -> String {
716 const MAX_KEYS: usize = 6;
717 match payload {
718 Value::Object(map) => {
719 let mut keys: Vec<&str> = map.keys().map(|key| key.as_str()).collect();
720 keys.sort_unstable();
721 let preview: Vec<&str> = keys.into_iter().take(MAX_KEYS).collect();
722 let more = if map.len() > MAX_KEYS { ",..." } else { "" };
723 format!("object(keys=[{}{}])", preview.join(","), more)
724 }
725 Value::Array(items) => format!("array(len={})", items.len()),
726 Value::String(text) => format!("string(len={})", text.len()),
727 Value::Number(_) => "number".to_owned(),
728 Value::Bool(_) => "bool".to_owned(),
729 Value::Null => "null".to_owned(),
730 }
731}
732
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Validates a request in `KnownMethods` mode.
    fn check_request(method: &str, params: Value) -> Result<(), RpcError> {
        validate_rpc_request(method, &params, RpcValidationMode::KnownMethods)
    }

    /// Validates a response in `KnownMethods` mode.
    fn check_response(method: &str, result: Value) -> Result<(), RpcError> {
        validate_rpc_response(method, &result, RpcValidationMode::KnownMethods)
    }

    /// Asserts that `err` is the `InvalidRequest` variant.
    fn assert_invalid_request(err: &RpcError) {
        assert!(matches!(err, RpcError::InvalidRequest(_)));
    }

    /// Extracts the message of an `InvalidRequest` error.
    fn invalid_request_message(err: RpcError) -> String {
        let RpcError::InvalidRequest(message) = err else {
            panic!("expected invalid request");
        };
        message
    }

    #[test]
    fn rejects_empty_method() {
        let err = check_request("", json!({})).expect_err("empty method must fail");
        assert_invalid_request(&err);
    }

    #[test]
    fn validates_turn_interrupt_params_shape() {
        let err = check_request("turn/interrupt", json!({"threadId":"thr"}))
            .expect_err("missing turnId must fail");
        assert_invalid_request(&err);

        check_request("turn/interrupt", json!({"threadId":"thr", "turnId":"turn"}))
            .expect("valid params");
    }

    #[test]
    fn validates_thread_start_response_thread_id() {
        let err = check_response("thread/start", json!({"thread": {}}))
            .expect_err("missing thread id must fail");
        assert_invalid_request(&err);

        check_response("thread/start", json!({"thread": {"id":"thr_1"}}))
            .expect("valid response");
    }

    #[test]
    fn validates_turn_start_response_turn_id() {
        let err = check_response("turn/start", json!({"turn": {}}))
            .expect_err("missing turn id must fail");
        assert_invalid_request(&err);

        check_response("turn/start", json!({"turn": {"id":"turn_1"}})).expect("valid response");
    }

    #[test]
    fn validates_skills_list_response_shape() {
        let err = check_response("skills/list", json!({"skills":[]}))
            .expect_err("missing result.data must fail");
        assert_invalid_request(&err);

        check_response("skills/list", json!({"data":[]})).expect("valid response");
    }

    #[test]
    fn validates_command_exec_request_constraints() {
        let err = check_request("command/exec", json!({"command":["bash"],"tty":true}))
            .expect_err("tty without processId must fail");
        assert_invalid_request(&err);

        let err = check_request(
            "command/exec",
            json!({"command":["bash"],"disableTimeout":true,"timeoutMs":1}),
        )
        .expect_err("disableTimeout + timeoutMs must fail");
        assert_invalid_request(&err);

        check_request(
            "command/exec",
            json!({"command":["bash"],"processId":"proc-1","tty":true}),
        )
        .expect("tty with processId should pass");
    }

    #[test]
    fn validates_command_exec_request_rejects_non_string_process_id() {
        let err = check_request("command/exec", json!({"command":["bash"],"processId":123}))
            .expect_err("non-string processId must fail");

        let message = invalid_request_message(err);
        assert!(message.contains("params.processId must be a string"));
    }

    #[test]
    fn validates_command_exec_response_shape() {
        let err = check_response("command/exec", json!({"exitCode":0,"stdout":"ok"}))
            .expect_err("stderr missing must fail");
        assert_invalid_request(&err);

        check_response(
            "command/exec",
            json!({"exitCode":0,"stdout":"ok","stderr":""}),
        )
        .expect("valid command exec response");
    }

    #[test]
    fn passes_unknown_method_in_known_mode() {
        check_request("echo/custom", json!({"k":"v"}))
            .expect("unknown method request should pass");
        check_response("echo/custom", json!({"ok":true}))
            .expect("unknown method response should pass");
    }

    #[test]
    fn known_method_catalog_is_stable() {
        let expected: [&str; 15] = [
            methods::THREAD_START,
            methods::THREAD_RESUME,
            methods::THREAD_FORK,
            methods::THREAD_ARCHIVE,
            methods::THREAD_READ,
            methods::THREAD_LIST,
            methods::THREAD_LOADED_LIST,
            methods::THREAD_ROLLBACK,
            methods::SKILLS_LIST,
            methods::COMMAND_EXEC,
            methods::COMMAND_EXEC_WRITE,
            methods::COMMAND_EXEC_TERMINATE,
            methods::COMMAND_EXEC_RESIZE,
            methods::TURN_START,
            methods::TURN_INTERRUPT,
        ];
        assert_eq!(methods::KNOWN, expected);
    }

    #[test]
    fn descriptor_catalog_matches_known_method_catalog() {
        let mut descriptor_methods: Vec<&'static str> = Vec::new();
        for descriptor in rpc_contract_descriptors() {
            descriptor_methods.push(descriptor.method);
        }
        assert_eq!(descriptor_methods, methods::KNOWN);
    }

    #[test]
    fn default_validation_mode_is_known_methods() {
        let default_mode = RpcValidationMode::default();
        assert_eq!(default_mode, RpcValidationMode::KnownMethods);
    }

    #[test]
    fn skips_validation_in_none_mode() {
        let null = json!(null);

        // The method name is still checked even when payload checks are off.
        validate_rpc_request("", &null, RpcValidationMode::None)
            .expect_err("empty method must still fail");

        validate_rpc_request("turn/start", &null, RpcValidationMode::None)
            .expect("none mode skips params shape");
        validate_rpc_response("turn/start", &null, RpcValidationMode::None)
            .expect("none mode skips result shape");
    }

    #[test]
    fn invalid_request_error_redacts_payload_values() {
        let err = check_request(
            "turn/interrupt",
            json!({"threadId":"thr_sensitive","secret":"token-123"}),
        )
        .expect_err("missing turnId must fail");

        let message = invalid_request_message(err);
        assert!(message.contains("invalid json-rpc request for turn/interrupt"));
        assert!(message.contains("params.turnId must be a non-empty string"));
        assert!(message.contains("payload=object(keys=[secret,threadId])"));
        // Only key names may appear; values must never leak into the message.
        assert!(!message.contains("token-123"));
        assert!(!message.contains("thr_sensitive"));
    }

    #[test]
    fn invalid_response_error_redacts_payload_values() {
        let err = check_response(
            "thread/start",
            json!({"thread": {}, "secret": {"token":"abc"}}),
        )
        .expect_err("missing thread id must fail");

        let message = invalid_request_message(err);
        assert!(message.contains("invalid json-rpc response for thread/start"));
        assert!(message.contains("result is missing thread id"));
        assert!(message.contains("payload=object(keys=[secret,thread])"));
        assert!(!message.contains("abc"));
    }

    #[test]
    fn rejects_response_scalar_id_fallback() {
        let err = check_response("thread/start", json!("thr_scalar"))
            .expect_err("scalar id fallback must not be accepted");
        assert_invalid_request(&err);
    }
}