1use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::time::Duration;
9use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13#[serde(tag = "type")]
14pub enum ParentMessage {
15 Execute {
17 code: String,
19 manifest: Option<Value>,
21 config: WorkerConfig,
23 },
24 ToolCallResult {
26 request_id: u64,
28 result: Result<Value, String>,
30 },
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35#[serde(tag = "type")]
36pub enum ChildMessage {
37 ToolCallRequest {
39 request_id: u64,
41 server: String,
43 tool: String,
45 args: Value,
47 },
48 ExecutionComplete {
50 result: Result<Value, String>,
52 },
53 Log {
55 message: String,
57 },
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct WorkerConfig {
63 pub timeout_ms: u64,
65 pub max_heap_size: usize,
67 pub max_tool_calls: usize,
69 pub max_tool_call_args_size: usize,
71 pub max_output_size: usize,
73 pub max_code_size: usize,
75 #[serde(default = "default_max_ipc_message_size")]
77 pub max_ipc_message_size: usize,
78}
79
80fn default_max_ipc_message_size() -> usize {
81 DEFAULT_MAX_IPC_MESSAGE_SIZE
82}
83
84impl From<&crate::SandboxConfig> for WorkerConfig {
85 fn from(config: &crate::SandboxConfig) -> Self {
86 Self {
87 timeout_ms: config.timeout.as_millis() as u64,
88 max_heap_size: config.max_heap_size,
89 max_tool_calls: config.max_tool_calls,
90 max_tool_call_args_size: config.max_tool_call_args_size,
91 max_output_size: config.max_output_size,
92 max_code_size: config.max_code_size,
93 max_ipc_message_size: DEFAULT_MAX_IPC_MESSAGE_SIZE,
94 }
95 }
96}
97
98impl WorkerConfig {
99 pub fn to_sandbox_config(&self) -> crate::SandboxConfig {
101 crate::SandboxConfig {
102 timeout: Duration::from_millis(self.timeout_ms),
103 max_code_size: self.max_code_size,
104 max_output_size: self.max_output_size,
105 max_heap_size: self.max_heap_size,
106 max_concurrent: 1, max_tool_calls: self.max_tool_calls,
108 max_tool_call_args_size: self.max_tool_call_args_size,
109 execution_mode: crate::executor::ExecutionMode::InProcess, }
111 }
112}
113
114pub async fn write_message<T: Serialize, W: AsyncWrite + Unpin>(
118 writer: &mut W,
119 msg: &T,
120) -> Result<(), std::io::Error> {
121 let payload = serde_json::to_vec(msg)
122 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
123 let len = u32::try_from(payload.len()).map_err(|_| {
124 std::io::Error::new(
125 std::io::ErrorKind::InvalidData,
126 format!(
127 "IPC payload too large: {} bytes (max {} bytes)",
128 payload.len(),
129 u32::MAX
130 ),
131 )
132 })?;
133 writer.write_all(&len.to_be_bytes()).await?;
134 writer.write_all(&payload).await?;
135 writer.flush().await?;
136 Ok(())
137}
138
/// Default upper bound on the payload size of one IPC message: 64 MiB.
///
/// Used by [`read_message`] and as the fallback for
/// [`WorkerConfig::max_ipc_message_size`].
pub const DEFAULT_MAX_IPC_MESSAGE_SIZE: usize = 64 << 20;
141
142pub async fn read_message<T: for<'de> Deserialize<'de>, R: AsyncRead + Unpin>(
147 reader: &mut R,
148) -> Result<Option<T>, std::io::Error> {
149 read_message_with_limit(reader, DEFAULT_MAX_IPC_MESSAGE_SIZE).await
150}
151
152pub async fn read_message_with_limit<T: for<'de> Deserialize<'de>, R: AsyncRead + Unpin>(
157 reader: &mut R,
158 max_size: usize,
159) -> Result<Option<T>, std::io::Error> {
160 let mut len_buf = [0u8; 4];
161 match reader.read_exact(&mut len_buf).await {
162 Ok(_) => {}
163 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
164 Err(e) => return Err(e),
165 }
166
167 let len = u32::from_be_bytes(len_buf) as usize;
168
169 if len > max_size {
171 return Err(std::io::Error::new(
172 std::io::ErrorKind::InvalidData,
173 format!(
174 "IPC message too large: {} bytes (limit: {} bytes)",
175 len, max_size
176 ),
177 ));
178 }
179
180 let mut payload = vec![0u8; len];
181 reader.read_exact(&mut payload).await?;
182
183 let msg: T = serde_json::from_slice(&payload)
184 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
185 Ok(Some(msg))
186}
187
/// Unit tests for the message types and the length-prefixed framing.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Frames `msg` into a fresh in-memory buffer.
    async fn encode<T: serde::Serialize>(msg: &T) -> Vec<u8> {
        let mut buf = Vec::new();
        write_message(&mut buf, msg).await.unwrap();
        buf
    }

    /// Frames `msg` and reads it straight back with the default size limit.
    async fn roundtrip<T>(msg: &T) -> T
    where
        T: serde::Serialize + serde::de::DeserializeOwned,
    {
        let mut cursor = Cursor::new(encode(msg).await);
        read_message(&mut cursor).await.unwrap().unwrap()
    }

    #[tokio::test]
    async fn roundtrip_parent_execute_message() {
        let msg = ParentMessage::Execute {
            code: "async () => { return 42; }".into(),
            manifest: Some(serde_json::json!({"servers": []})),
            config: WorkerConfig {
                timeout_ms: 5000,
                max_heap_size: 64 * 1024 * 1024,
                max_tool_calls: 50,
                max_tool_call_args_size: 1024 * 1024,
                max_output_size: 1024 * 1024,
                max_code_size: 64 * 1024,
                max_ipc_message_size: DEFAULT_MAX_IPC_MESSAGE_SIZE,
            },
        };

        match roundtrip(&msg).await {
            ParentMessage::Execute {
                code,
                manifest,
                config,
            } => {
                assert_eq!(code, "async () => { return 42; }");
                assert!(manifest.is_some());
                assert_eq!(config.timeout_ms, 5000);
            }
            other => panic!("expected Execute, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn roundtrip_parent_tool_result() {
        let msg = ParentMessage::ToolCallResult {
            request_id: 42,
            result: Ok(serde_json::json!({"status": "ok"})),
        };

        match roundtrip(&msg).await {
            ParentMessage::ToolCallResult { request_id, result } => {
                assert_eq!(request_id, 42);
                assert!(result.is_ok());
            }
            other => panic!("expected ToolCallResult, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn roundtrip_parent_tool_result_error() {
        let msg = ParentMessage::ToolCallResult {
            request_id: 7,
            result: Err("connection refused".into()),
        };

        match roundtrip(&msg).await {
            ParentMessage::ToolCallResult { request_id, result } => {
                assert_eq!(request_id, 7);
                assert_eq!(result.unwrap_err(), "connection refused");
            }
            other => panic!("expected ToolCallResult, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn roundtrip_child_tool_request() {
        let msg = ChildMessage::ToolCallRequest {
            request_id: 1,
            server: "narsil".into(),
            tool: "ast.parse".into(),
            args: serde_json::json!({"file": "test.rs"}),
        };

        match roundtrip(&msg).await {
            ChildMessage::ToolCallRequest {
                request_id,
                server,
                tool,
                args,
            } => {
                assert_eq!(request_id, 1);
                assert_eq!(server, "narsil");
                assert_eq!(tool, "ast.parse");
                assert_eq!(args["file"], "test.rs");
            }
            other => panic!("expected ToolCallRequest, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn roundtrip_child_execution_complete() {
        let msg = ChildMessage::ExecutionComplete {
            result: Ok(serde_json::json!([1, 2, 3])),
        };

        match roundtrip(&msg).await {
            ChildMessage::ExecutionComplete { result } => {
                assert_eq!(result.unwrap(), serde_json::json!([1, 2, 3]));
            }
            other => panic!("expected ExecutionComplete, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn roundtrip_child_log() {
        let msg = ChildMessage::Log {
            message: "processing step 3".into(),
        };

        match roundtrip(&msg).await {
            ChildMessage::Log { message } => {
                assert_eq!(message, "processing step 3");
            }
            other => panic!("expected Log, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn multiple_messages_in_stream() {
        let msg1 = ChildMessage::Log {
            message: "first".into(),
        };
        let msg2 = ChildMessage::ToolCallRequest {
            request_id: 1,
            server: "s".into(),
            tool: "t".into(),
            args: serde_json::json!({}),
        };
        let msg3 = ChildMessage::ExecutionComplete {
            result: Ok(serde_json::json!("done")),
        };

        // All three frames go back-to-back into the same buffer.
        let mut buf = Vec::new();
        write_message(&mut buf, &msg1).await.unwrap();
        write_message(&mut buf, &msg2).await.unwrap();
        write_message(&mut buf, &msg3).await.unwrap();

        let mut cursor = Cursor::new(buf);
        let d1: ChildMessage = read_message(&mut cursor).await.unwrap().unwrap();
        let d2: ChildMessage = read_message(&mut cursor).await.unwrap().unwrap();
        let d3: ChildMessage = read_message(&mut cursor).await.unwrap().unwrap();

        assert!(matches!(d1, ChildMessage::Log { .. }));
        assert!(matches!(d2, ChildMessage::ToolCallRequest { .. }));
        assert!(matches!(d3, ChildMessage::ExecutionComplete { .. }));

        // Once every frame is consumed the reader reports end-of-stream.
        let d4: Option<ChildMessage> = read_message(&mut cursor).await.unwrap();
        assert!(d4.is_none());
    }

    #[tokio::test]
    async fn execution_complete_error_roundtrip() {
        let msg = ChildMessage::ExecutionComplete {
            result: Err("failed to create tokio runtime: resource unavailable".into()),
        };

        match roundtrip(&msg).await {
            ChildMessage::ExecutionComplete { result } => {
                let err = result.unwrap_err();
                assert!(
                    err.contains("tokio runtime"),
                    "expected runtime error: {err}"
                );
            }
            other => panic!("expected ExecutionComplete, got: {:?}", other),
        }
    }

    #[tokio::test]
    async fn eof_returns_none() {
        let mut cursor = Cursor::new(Vec::<u8>::new());
        let result: Option<ParentMessage> = read_message(&mut cursor).await.unwrap();
        assert!(result.is_none());
    }

    // `u32::MAX as usize + 1` is a constant overflow where `usize` is 32 bits
    // wide, so this check only makes sense (and only compiles) on 64-bit
    // targets.
    #[test]
    #[cfg(target_pointer_width = "64")]
    fn u32_try_from_overflow() {
        let overflow_size = u32::MAX as usize + 1;
        assert!(u32::try_from(overflow_size).is_err());
    }

    #[tokio::test]
    async fn write_message_normal_size_succeeds() {
        let msg = ChildMessage::Log {
            message: "a".repeat(1024),
        };
        let buf = encode(&msg).await;
        assert!(buf.len() > 1024);
    }

    #[tokio::test]
    async fn large_message_roundtrip() {
        let large_data = "x".repeat(1_000_000);
        let msg = ChildMessage::ExecutionComplete {
            result: Ok(serde_json::json!(large_data)),
        };

        match roundtrip(&msg).await {
            ChildMessage::ExecutionComplete { result } => {
                assert_eq!(result.unwrap().as_str().unwrap().len(), 1_000_000);
            }
            other => panic!("expected ExecutionComplete, got: {:?}", other),
        }
    }

    #[test]
    fn worker_config_roundtrip_from_sandbox_config() {
        let sandbox = crate::SandboxConfig::default();
        let worker = WorkerConfig::from(&sandbox);
        let back = worker.to_sandbox_config();

        // Every limit that travels through `WorkerConfig` must survive.
        assert_eq!(sandbox.timeout, back.timeout);
        assert_eq!(sandbox.max_heap_size, back.max_heap_size);
        assert_eq!(sandbox.max_tool_calls, back.max_tool_calls);
        assert_eq!(
            sandbox.max_tool_call_args_size,
            back.max_tool_call_args_size
        );
        assert_eq!(sandbox.max_output_size, back.max_output_size);
        assert_eq!(sandbox.max_code_size, back.max_code_size);
        assert_eq!(worker.max_ipc_message_size, DEFAULT_MAX_IPC_MESSAGE_SIZE);
    }

    #[tokio::test]
    async fn read_message_with_limit_rejects_oversized() {
        let msg = ChildMessage::Log {
            message: "x".repeat(1024),
        };

        let mut cursor = Cursor::new(encode(&msg).await);
        let result: Result<Option<ChildMessage>, _> =
            read_message_with_limit(&mut cursor, 64).await;
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("too large"), "error: {err_msg}");
    }

    #[tokio::test]
    async fn read_message_with_limit_accepts_within_limit() {
        let msg = ChildMessage::Log {
            message: "hello".into(),
        };

        let mut cursor = Cursor::new(encode(&msg).await);
        let result: Option<ChildMessage> =
            read_message_with_limit(&mut cursor, 1024).await.unwrap();
        assert!(result.is_some());
    }

    #[tokio::test]
    async fn read_message_with_limit_accepts_exact_limit() {
        let msg = ChildMessage::Log {
            message: "hello".into(),
        };

        let buf = encode(&msg).await;
        // The payload is everything after the 4-byte length header; a limit
        // equal to the payload size must still be accepted.
        let payload_len = buf.len() - 4;

        let mut cursor = Cursor::new(buf);
        let result: Option<ChildMessage> = read_message_with_limit(&mut cursor, payload_len)
            .await
            .unwrap();
        assert!(result.is_some());
    }

    #[tokio::test]
    async fn truncated_payload_is_error() {
        let msg = ChildMessage::Log {
            message: "hello".into(),
        };

        let mut buf = encode(&msg).await;
        // Drop the last payload byte so the header promises more than exists.
        buf.pop();

        let mut cursor = Cursor::new(buf);
        let result: Result<Option<ChildMessage>, _> = read_message(&mut cursor).await;
        let err = result.unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn worker_config_ipc_limit_serde_default() {
        // `max_ipc_message_size` is deliberately missing from the input.
        let json = r#"{
            "timeout_ms": 5000,
            "max_heap_size": 67108864,
            "max_tool_calls": 50,
            "max_tool_call_args_size": 1048576,
            "max_output_size": 1048576,
            "max_code_size": 65536
        }"#;
        let config: WorkerConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.max_ipc_message_size, DEFAULT_MAX_IPC_MESSAGE_SIZE);
    }
}