1use crate::registry::Registry;
2use crate::telemetry::{NoopTelemetry, Telemetry};
3use crate::types::{ExecutionError, ExecutionOutcome};
4use chrono::{DateTime, Utc};
5use rrq_protocol::{CancelRequest, OutcomeStatus, PROTOCOL_VERSION, RunnerMessage, encode_frame};
6use std::collections::{HashMap, HashSet};
7use std::net::{IpAddr, Ipv4Addr, SocketAddr};
8use std::sync::{
9 Arc,
10 atomic::{AtomicBool, Ordering},
11};
12use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
13use tokio::net::TcpListener;
14use tokio::sync::{Mutex, mpsc};
15use tokio::time::{Duration, timeout};
16
/// Name of the environment variable holding the runner TCP socket setting
/// (presumably a `host:port` string suitable for `parse_tcp_socket` — confirm at the call site).
pub const ENV_RUNNER_TCP_SOCKET: &str = "RRQ_RUNNER_TCP_SOCKET";
/// Largest frame payload accepted by `read_message`: 16 MiB.
const MAX_FRAME_LEN: usize = 16 << 20;
/// Capacity of each connection's response queue; also reused as the
/// per-connection limit on concurrently running requests.
const RESPONSE_CHANNEL_CAPACITY: usize = 64;
/// Upper bound on how long a producer waits for room in the response queue.
const RESPONSE_SEND_TIMEOUT: Duration = Duration::from_millis(1_000);
21
/// Builds a boxed `InvalidInput` I/O error carrying `message`.
fn invalid_input(message: impl Into<String>) -> Box<dyn std::error::Error> {
    Box::new(std::io::Error::new(
        std::io::ErrorKind::InvalidInput,
        message.into(),
    ))
}

/// Parses a runner TCP socket specification into a loopback [`SocketAddr`].
///
/// Accepted forms are `host:port` and `[host]:port` (the bracketed form is how
/// an IPv6 literal such as `[::1]:9000` is written). `host` must be `localhost`
/// (matched case-insensitively and mapped to `127.0.0.1`) or a loopback IP
/// literal. `port` must parse as a non-zero `u16`. Leading and trailing
/// whitespace is ignored.
///
/// # Errors
///
/// Returns an `InvalidInput` I/O error in these cases:
/// - the input is empty;
/// - it lacks the `:port` separator;
/// - the host is empty;
/// - the host is neither `localhost` nor a loopback IP;
/// - the port is invalid or zero.
pub fn parse_tcp_socket(raw: &str) -> Result<SocketAddr, Box<dyn std::error::Error>> {
    let raw = raw.trim();
    if raw.is_empty() {
        return Err(invalid_input("runner tcp_socket cannot be empty"));
    }

    let (host, port_str) = if let Some(rest) = raw.strip_prefix('[') {
        rest.split_once("]:")
            .ok_or_else(|| invalid_input("runner tcp_socket must be in [host]:port format"))?
    } else {
        raw.rsplit_once(':')
            .ok_or_else(|| invalid_input("runner tcp_socket must be in host:port format"))?
    };
    // Checked for both forms so `[]:port` is reported the same way as `:port`.
    if host.is_empty() {
        return Err(invalid_input("runner tcp_socket host cannot be empty"));
    }

    let port: u16 = port_str
        .parse()
        .map_err(|_| invalid_input(format!("Invalid runner tcp_socket port: {port_str}")))?;
    if port == 0 {
        return Err(invalid_input("runner tcp_socket port must be > 0"));
    }

    // Host names are case-insensitive, so `LOCALHOST` / `Localhost` are accepted too.
    let ip = if host.eq_ignore_ascii_case("localhost") {
        IpAddr::V4(Ipv4Addr::LOCALHOST)
    } else {
        let parsed: IpAddr = host
            .parse()
            .map_err(|_| invalid_input(format!("Invalid runner tcp_socket host: {host}")))?;
        if !parsed.is_loopback() {
            return Err(invalid_input("runner tcp_socket host must be localhost"));
        }
        parsed
    };

    Ok(SocketAddr::new(ip, port))
}
71
72pub struct RunnerRuntime {
73 runtime: tokio::runtime::Runtime,
74}
75
76impl RunnerRuntime {
77 pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
78 Ok(Self {
79 runtime: tokio::runtime::Runtime::new()?,
80 })
81 }
82
83 pub fn enter(&self) -> tokio::runtime::EnterGuard<'_> {
84 self.runtime.enter()
85 }
86
87 pub fn run_tcp(
88 &self,
89 registry: &Registry,
90 addr: SocketAddr,
91 ) -> Result<(), Box<dyn std::error::Error>> {
92 let telemetry = NoopTelemetry;
93 self.run_tcp_with(registry, addr, &telemetry)
94 }
95
96 pub fn run_tcp_with<T: Telemetry + ?Sized>(
97 &self,
98 registry: &Registry,
99 addr: SocketAddr,
100 telemetry: &T,
101 ) -> Result<(), Box<dyn std::error::Error>> {
102 run_tcp_loop(&self.runtime, registry, addr, telemetry)
103 }
104}
105
106pub fn run_tcp(registry: &Registry, addr: SocketAddr) -> Result<(), Box<dyn std::error::Error>> {
107 RunnerRuntime::new()?.run_tcp(registry, addr)
108}
109
110pub fn run_tcp_with<T: Telemetry + ?Sized>(
111 registry: &Registry,
112 addr: SocketAddr,
113 telemetry: &T,
114) -> Result<(), Box<dyn std::error::Error>> {
115 RunnerRuntime::new()?.run_tcp_with(registry, addr, telemetry)
116}
117
118fn run_tcp_loop<T: Telemetry + ?Sized>(
119 runtime: &tokio::runtime::Runtime,
120 registry: &Registry,
121 addr: SocketAddr,
122 telemetry: &T,
123) -> Result<(), Box<dyn std::error::Error>> {
124 let registry = registry.clone();
125 let in_flight: Arc<Mutex<HashMap<String, InFlightTask>>> = Arc::new(Mutex::new(HashMap::new()));
126 let job_index: Arc<Mutex<HashMap<String, HashSet<String>>>> =
127 Arc::new(Mutex::new(HashMap::new()));
128 let telemetry = telemetry.clone_box();
129 runtime.block_on(async move {
130 if !addr.ip().is_loopback() {
131 return Err(invalid_input(format!(
132 "runner tcp_socket must be loopback-only (got {addr})"
133 )));
134 }
135 let listener = TcpListener::bind(addr).await?;
136 loop {
137 let (stream, _) = listener.accept().await?;
138 let registry = registry.clone();
139 let telemetry = telemetry.clone();
140 let in_flight = in_flight.clone();
141 let job_index = job_index.clone();
142 tokio::spawn(async move {
143 if let Err(err) =
144 handle_connection(stream, ®istry, telemetry.as_ref(), in_flight, job_index)
145 .await
146 {
147 tracing::error!("runner connection error: {err}");
148 }
149 });
150 }
151 })
152}
153
154async fn handle_connection<S, T>(
155 stream: S,
156 registry: &Registry,
157 telemetry: &T,
158 in_flight: Arc<Mutex<HashMap<String, InFlightTask>>>,
159 job_index: Arc<Mutex<HashMap<String, HashSet<String>>>>,
160) -> Result<(), Box<dyn std::error::Error>>
161where
162 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
163 T: Telemetry + ?Sized,
164{
165 let (mut reader, mut writer) = tokio::io::split(stream);
166 let (response_tx, mut response_rx) =
167 mpsc::channel::<ExecutionOutcome>(RESPONSE_CHANNEL_CAPACITY);
168 let writer_task = tokio::spawn(async move {
169 while let Some(outcome) = response_rx.recv().await {
170 let response = RunnerMessage::Response { payload: outcome };
171 if write_message(&mut writer, &response).await.is_err() {
172 break;
173 }
174 }
175 });
176 let connection_requests: Arc<Mutex<std::collections::HashSet<String>>> =
177 Arc::new(Mutex::new(std::collections::HashSet::new()));
178
179 loop {
180 let message = match read_message(&mut reader).await? {
181 Some(message) => message,
182 None => break,
183 };
184 match message {
185 RunnerMessage::Request { payload } => {
186 if payload.protocol_version != PROTOCOL_VERSION {
187 let outcome = ExecutionOutcome::error(
188 payload.job_id.clone(),
189 payload.request_id.clone(),
190 "Unsupported protocol version",
191 );
192 let _ = response_tx.send(outcome).await;
193 continue;
194 }
195
196 let request_id = payload.request_id.clone();
197 let job_id = payload.job_id.clone();
198 {
199 let mut active = connection_requests.lock().await;
200 if active.len() >= RESPONSE_CHANNEL_CAPACITY {
201 let outcome = ExecutionOutcome::error(
202 payload.job_id.clone(),
203 payload.request_id.clone(),
204 "Runner busy: too many in-flight requests",
205 );
206 drop(active);
207 let send_result =
208 timeout(RESPONSE_SEND_TIMEOUT, response_tx.send(outcome)).await;
209 match send_result {
210 Ok(Ok(())) => {}
211 Ok(Err(_)) => {
212 return Err("runner response channel closed".into());
213 }
214 Err(_) => {
215 return Err("runner response channel stalled".into());
216 }
217 }
218 continue;
219 }
220 active.insert(request_id.clone());
221 }
222 let response_tx = response_tx.clone();
223 let registry = registry.clone();
224 let telemetry = telemetry.clone_box();
225 let in_flight_for_task = in_flight.clone();
226 let job_index_for_task = job_index.clone();
227 let active_for_task = connection_requests.clone();
228 let request_id_for_task = request_id.clone();
229 let job_id_for_task = job_id.clone();
230 let response_tx_for_task = response_tx.clone();
231 let completed = Arc::new(AtomicBool::new(false));
232 let completed_for_task = completed.clone();
233
234 let handle = tokio::spawn(async move {
235 let outcome =
236 execute_with_deadline(payload, registry, telemetry.as_ref()).await;
237 completed_for_task.store(true, Ordering::SeqCst);
238 let send_result =
239 timeout(RESPONSE_SEND_TIMEOUT, response_tx_for_task.send(outcome)).await;
240 match send_result {
241 Ok(Ok(())) => {}
242 Ok(Err(_)) => {
243 tracing::warn!("runner response channel closed; dropping outcome");
244 }
245 Err(_) => {
246 tracing::warn!("runner response channel stalled; dropping outcome");
247 }
248 }
249 {
250 let mut in_flight = in_flight_for_task.lock().await;
251 in_flight.remove(&request_id_for_task);
252 }
253 {
254 let mut job_index = job_index_for_task.lock().await;
255 if let Some(entries) = job_index.get_mut(&job_id_for_task) {
256 entries.remove(&request_id_for_task);
257 if entries.is_empty() {
258 job_index.remove(&job_id_for_task);
259 }
260 }
261 }
262 {
263 let mut active = active_for_task.lock().await;
264 active.remove(&request_id_for_task);
265 }
266 });
267
268 {
269 let mut in_flight = in_flight.lock().await;
270 in_flight.insert(
271 request_id.clone(),
272 InFlightTask {
273 job_id: job_id.clone(),
274 handle,
275 response_tx: response_tx.clone(),
276 connection_requests: connection_requests.clone(),
277 completed,
278 },
279 );
280 }
281 {
282 let mut job_index = job_index.lock().await;
283 job_index
284 .entry(job_id)
285 .or_insert_with(HashSet::new)
286 .insert(request_id);
287 }
288 }
289 RunnerMessage::Cancel { payload } => {
290 handle_cancel(payload, &in_flight, &job_index).await;
291 }
292 RunnerMessage::Response { .. } => {
293 let outcome = ExecutionOutcome {
294 job_id: Some("unknown".to_string()),
295 request_id: None,
296 status: rrq_protocol::OutcomeStatus::Error,
297 result: None,
298 error: Some(ExecutionError {
299 message: "unexpected response message".to_string(),
300 error_type: None,
301 code: None,
302 details: None,
303 }),
304 retry_after_seconds: None,
305 };
306 let _ = response_tx.send(outcome).await;
307 }
308 }
309 }
310
311 let request_ids = {
312 let mut active = connection_requests.lock().await;
313 active.drain().collect::<Vec<_>>()
314 };
315 for request_id in request_ids {
316 let task = {
317 let mut in_flight = in_flight.lock().await;
318 in_flight.remove(&request_id)
319 };
320 if let Some(task) = task {
321 task.handle.abort();
322 let mut job_index = job_index.lock().await;
323 if let Some(entries) = job_index.get_mut(&task.job_id) {
324 entries.remove(&request_id);
325 if entries.is_empty() {
326 job_index.remove(&task.job_id);
327 }
328 }
329 }
330 }
331 writer_task.abort();
332
333 Ok(())
334}
335
336struct InFlightTask {
337 job_id: String,
338 handle: tokio::task::JoinHandle<()>,
339 response_tx: mpsc::Sender<ExecutionOutcome>,
340 connection_requests: Arc<Mutex<HashSet<String>>>,
341 completed: Arc<AtomicBool>,
342}
343
344async fn handle_cancel(
345 payload: CancelRequest,
346 in_flight: &Arc<Mutex<HashMap<String, InFlightTask>>>,
347 job_index: &Arc<Mutex<HashMap<String, HashSet<String>>>>,
348) {
349 if payload.protocol_version != PROTOCOL_VERSION {
350 return;
351 }
352 let request_ids = if let Some(request_id) = payload.request_id.clone() {
353 vec![request_id]
354 } else {
355 let job_index = job_index.lock().await;
356 job_index
357 .get(&payload.job_id)
358 .map(|ids| ids.iter().cloned().collect())
359 .unwrap_or_else(Vec::new)
360 };
361 if request_ids.is_empty() {
362 return;
363 }
364
365 for request_id in request_ids {
366 let task = {
367 let mut in_flight = in_flight.lock().await;
368 if let Some(task) = in_flight.get(&request_id)
369 && task.completed.load(Ordering::SeqCst)
370 {
371 None
372 } else {
373 in_flight.remove(&request_id)
374 }
375 };
376 if let Some(task) = task {
377 task.handle.abort();
378 {
379 let mut active = task.connection_requests.lock().await;
380 active.remove(&request_id);
381 }
382 let outcome = ExecutionOutcome {
383 job_id: Some(payload.job_id.clone()),
384 request_id: Some(request_id.clone()),
385 status: OutcomeStatus::Error,
386 result: None,
387 error: Some(ExecutionError {
388 message: "Job cancelled".to_string(),
389 error_type: Some("cancelled".to_string()),
390 code: None,
391 details: None,
392 }),
393 retry_after_seconds: None,
394 };
395 let send_result = timeout(RESPONSE_SEND_TIMEOUT, task.response_tx.send(outcome)).await;
396 match send_result {
397 Ok(Ok(())) => {}
398 Ok(Err(_)) => {
399 tracing::warn!("runner response channel closed; dropping cancel outcome");
400 }
401 Err(_) => {
402 tracing::warn!("runner response channel stalled; dropping cancel outcome");
403 }
404 }
405 let mut job_index = job_index.lock().await;
406 if let Some(entries) = job_index.get_mut(&task.job_id) {
407 entries.remove(&request_id);
408 if entries.is_empty() {
409 job_index.remove(&task.job_id);
410 }
411 }
412 }
413 }
414}
415
416async fn execute_with_deadline<T: Telemetry + ?Sized>(
417 request: rrq_protocol::ExecutionRequest,
418 registry: Registry,
419 telemetry: &T,
420) -> ExecutionOutcome {
421 let job_id = request.job_id.clone();
422 let request_id = request.request_id.clone();
423 let deadline = request.context.deadline;
424 if let Some(deadline) = deadline {
425 let now: DateTime<Utc> = Utc::now();
426 if deadline <= now {
427 return ExecutionOutcome::timeout(
428 job_id.clone(),
429 request_id.clone(),
430 "Job deadline exceeded",
431 );
432 }
433 if let Ok(remaining) = (deadline - now).to_std() {
434 match tokio::time::timeout(remaining, registry.execute_with(request, telemetry)).await {
435 Ok(outcome) => return outcome,
436 Err(_) => {
437 return ExecutionOutcome::timeout(
438 job_id.clone(),
439 request_id.clone(),
440 "Job execution timed out",
441 );
442 }
443 }
444 }
445 return ExecutionOutcome::timeout(job_id, request_id, "Job deadline exceeded");
446 }
447 registry.execute_with(request, telemetry).await
448}
449
450async fn read_message<R: AsyncRead + Unpin>(
451 stream: &mut R,
452) -> Result<Option<RunnerMessage>, Box<dyn std::error::Error>> {
453 let mut header = [0u8; 4];
454 match stream.read_exact(&mut header).await {
455 Ok(_) => {}
456 Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
457 Err(err) => return Err(err.into()),
458 }
459 let length = u32::from_be_bytes(header) as usize;
460 if length == 0 {
461 return Err("runner message payload cannot be empty".into());
462 }
463 if length > MAX_FRAME_LEN {
464 return Err("runner message payload too large".into());
465 }
466 let mut payload = vec![0u8; length];
467 stream.read_exact(&mut payload).await?;
468 let message = serde_json::from_slice(&payload)?;
469 Ok(Some(message))
470}
471
472async fn write_message<W: AsyncWrite + Unpin>(
473 stream: &mut W,
474 message: &RunnerMessage,
475) -> Result<(), Box<dyn std::error::Error>> {
476 let framed = encode_frame(message)?;
477 stream.write_all(&framed).await?;
478 stream.flush().await?;
479 Ok(())
480}
481
#[cfg(test)]
mod tests {
    use super::*;
    use crate::registry::Registry;
    use crate::telemetry::NoopTelemetry;
    use chrono::Utc;
    use rrq_protocol::{ExecutionContext, ExecutionRequest, OutcomeStatus};
    use serde_json::json;
    use tokio::net::{TcpListener, TcpStream};
    use tokio::time::{Duration, timeout};

    type InFlightMap = Arc<Mutex<HashMap<String, InFlightTask>>>;
    type JobIndexMap = Arc<Mutex<HashMap<String, HashSet<String>>>>;

    /// Builds a current-protocol request (`job-1` / `req-1`) for `function_name`.
    fn build_request(function_name: &str) -> ExecutionRequest {
        ExecutionRequest {
            protocol_version: PROTOCOL_VERSION.to_string(),
            request_id: "req-1".to_string(),
            job_id: "job-1".to_string(),
            function_name: function_name.to_string(),
            params: std::collections::HashMap::new(),
            context: ExecutionContext {
                job_id: "job-1".to_string(),
                attempt: 1,
                enqueue_time: "2024-01-01T00:00:00Z".parse().unwrap(),
                queue_name: "default".to_string(),
                deadline: None,
                trace_context: None,
                worker_id: None,
            },
        }
    }

    /// Like `build_request`, but with the given job id (also on the context) and request id.
    fn build_request_with_ids(
        function_name: &str,
        job_id: &str,
        request_id: &str,
    ) -> ExecutionRequest {
        let mut request = build_request(function_name);
        request.job_id = job_id.to_string();
        request.context.job_id = job_id.to_string();
        request.request_id = request_id.to_string();
        request
    }

    /// Wraps `payload` in a `Request` frame.
    fn request_message(payload: ExecutionRequest) -> RunnerMessage {
        RunnerMessage::Request { payload }
    }

    /// Builds a current-protocol, non-hard-kill `Cancel` frame.
    fn cancel_message(job_id: &str, request_id: Option<&str>) -> RunnerMessage {
        RunnerMessage::Cancel {
            payload: CancelRequest {
                protocol_version: PROTOCOL_VERSION.to_string(),
                job_id: job_id.to_string(),
                request_id: request_id.map(String::from),
                hard_kill: false,
            },
        }
    }

    /// Registry whose `echo` handler succeeds immediately with `{"ok": true}`.
    fn echo_registry() -> Registry {
        let mut registry = Registry::new();
        registry.register("echo", |request| async move {
            ExecutionOutcome::success(
                request.job_id.clone(),
                request.request_id.clone(),
                json!({"ok": true}),
            )
        });
        registry
    }

    /// Registry whose `sleep` handler succeeds with `{"ok": true}` after 200 ms.
    fn sleep_registry() -> Registry {
        let mut registry = Registry::new();
        registry.register("sleep", |request| async move {
            tokio::time::sleep(Duration::from_millis(200)).await;
            ExecutionOutcome::success(
                request.job_id.clone(),
                request.request_id.clone(),
                json!({"ok": true}),
            )
        });
        registry
    }

    /// Returns a connected `(client, server)` pair of loopback TCP streams.
    async fn tcp_pair() -> (TcpStream, TcpStream) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let client = TcpStream::connect(addr).await.unwrap();
        let (server, _) = listener.accept().await.unwrap();
        (client, server)
    }

    /// Runs `handle_connection` for `server` on a background task, expecting `Ok`.
    fn serve(
        server: TcpStream,
        registry: Registry,
        in_flight: InFlightMap,
        job_index: JobIndexMap,
    ) -> tokio::task::JoinHandle<()> {
        tokio::spawn(async move {
            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
                .await
                .unwrap();
        })
    }

    /// Serves a fresh connection with its own empty tracking maps; returns the client side.
    async fn connect(registry: Registry) -> (TcpStream, tokio::task::JoinHandle<()>) {
        let (client, server) = tcp_pair().await;
        let in_flight: InFlightMap = Arc::new(Mutex::new(HashMap::new()));
        let job_index: JobIndexMap = Arc::new(Mutex::new(HashMap::new()));
        (client, serve(server, registry, in_flight, job_index))
    }

    /// Reads the next frame and returns its `Response` payload.
    async fn next_response(client: &mut TcpStream) -> ExecutionOutcome {
        match read_message(client).await.unwrap().unwrap() {
            RunnerMessage::Response { payload } => payload,
            _ => panic!("expected response"),
        }
    }

    /// Like `next_response`, but fails if nothing arrives within `limit`.
    async fn next_response_within(client: &mut TcpStream, limit: Duration) -> ExecutionOutcome {
        timeout(limit, next_response(client)).await.unwrap()
    }

    /// Asserts that `outcome` is an error whose `error_type` is `cancelled`.
    fn assert_cancelled(outcome: &ExecutionOutcome) {
        assert_eq!(outcome.status, OutcomeStatus::Error);
        let error_type = outcome
            .error
            .as_ref()
            .and_then(|error| error.error_type.as_deref());
        assert_eq!(error_type, Some("cancelled"));
    }

    #[tokio::test]
    async fn handle_connection_executes_request() {
        let (mut client, server_task) = connect(echo_registry()).await;
        write_message(&mut client, &request_message(build_request("echo")))
            .await
            .unwrap();
        let payload = next_response(&mut client).await;
        assert_eq!(payload.status, OutcomeStatus::Success);
        assert_eq!(payload.result, Some(json!({"ok": true})));
        drop(client);
        let _ = server_task.await;
    }

    #[tokio::test]
    async fn handle_connection_rejects_bad_protocol() {
        let (mut client, server_task) = connect(Registry::new()).await;
        let mut request = build_request("echo");
        request.protocol_version = "0".to_string();
        write_message(&mut client, &request_message(request))
            .await
            .unwrap();
        let payload = next_response(&mut client).await;
        assert_eq!(payload.status, OutcomeStatus::Error);
        drop(client);
        let _ = server_task.await;
    }

    #[tokio::test]
    async fn handle_connection_cancels_inflight() {
        let (mut client, server_task) = connect(sleep_registry()).await;
        let request = build_request_with_ids("sleep", "job-cancel", "req-cancel");
        write_message(&mut client, &request_message(request.clone()))
            .await
            .unwrap();
        write_message(
            &mut client,
            &cancel_message(&request.job_id, Some(&request.request_id)),
        )
        .await
        .unwrap();
        let payload = next_response(&mut client).await;
        assert_cancelled(&payload);
        drop(client);
        let _ = server_task.await;
    }

    #[tokio::test]
    async fn cancel_frees_connection_capacity() {
        let gate = Arc::new(tokio::sync::Semaphore::new(0));
        let gate_for_handler = gate.clone();
        let mut registry = Registry::new();
        registry.register("block", move |request| {
            let gate = gate_for_handler.clone();
            async move {
                let _permit = gate.acquire().await.expect("semaphore closed");
                ExecutionOutcome::success(
                    request.job_id.clone(),
                    request.request_id.clone(),
                    json!({"ok": true}),
                )
            }
        });
        let (mut client, server_task) = connect(registry).await;
        let job_id = "job-capacity";

        // Fill the connection up to its in-flight limit with blocked requests.
        for i in 0..RESPONSE_CHANNEL_CAPACITY {
            let request = build_request_with_ids("block", job_id, &format!("req-{i}"));
            write_message(&mut client, &request_message(request))
                .await
                .unwrap();
        }

        write_message(&mut client, &cancel_message(job_id, Some("req-0")))
            .await
            .unwrap();
        let cancelled = next_response_within(&mut client, Duration::from_secs(1)).await;
        assert_cancelled(&cancelled);

        // The slot released by the cancel must admit a new request.
        let extra = build_request_with_ids("block", job_id, "req-extra");
        write_message(&mut client, &request_message(extra))
            .await
            .unwrap();

        gate.add_permits(RESPONSE_CHANNEL_CAPACITY + 1);

        let mut saw_extra = false;
        for _ in 0..RESPONSE_CHANNEL_CAPACITY {
            let payload = next_response_within(&mut client, Duration::from_secs(1)).await;
            if payload.request_id.as_deref() == Some("req-extra") {
                assert_eq!(payload.status, OutcomeStatus::Success);
                saw_extra = true;
            }
        }
        assert!(saw_extra, "extra request never completed");

        drop(client);
        let _ = server_task.await;
    }

    #[tokio::test]
    async fn execute_with_deadline_times_out() {
        let mut request = build_request("echo");
        request.context.deadline = Some(
            "2020-01-01T00:00:00Z"
                .parse::<chrono::DateTime<Utc>>()
                .unwrap(),
        );
        let outcome = execute_with_deadline(request, echo_registry(), &NoopTelemetry).await;
        assert_eq!(outcome.status, OutcomeStatus::Timeout);
    }

    #[tokio::test]
    async fn execute_with_deadline_succeeds_before_deadline() {
        let mut request = build_request("echo");
        request.context.deadline = Some(Utc::now() + chrono::Duration::seconds(5));
        let outcome = execute_with_deadline(request, echo_registry(), &NoopTelemetry).await;
        assert_eq!(outcome.status, OutcomeStatus::Success);
    }

    #[tokio::test]
    async fn handle_connection_handles_unexpected_response_message() {
        let (mut client, server_task) = connect(Registry::new()).await;
        let stray = RunnerMessage::Response {
            payload: ExecutionOutcome::error("job-x", "req-x", "oops"),
        };
        write_message(&mut client, &stray).await.unwrap();
        let payload = next_response(&mut client).await;
        assert_eq!(payload.status, OutcomeStatus::Error);
        assert!(
            payload
                .error
                .as_ref()
                .unwrap()
                .message
                .contains("unexpected response")
        );
        drop(client);
        let _ = server_task.await;
    }

    #[tokio::test]
    async fn handle_connection_cancels_by_job_id() {
        let (mut client, server_task) = connect(sleep_registry()).await;
        let request = build_request("sleep");
        write_message(&mut client, &request_message(request.clone()))
            .await
            .unwrap();
        write_message(&mut client, &cancel_message(&request.job_id, None))
            .await
            .unwrap();
        let payload = next_response(&mut client).await;
        assert_cancelled(&payload);
        drop(client);
        let _ = server_task.await;
    }

    #[tokio::test]
    async fn handle_cancel_by_job_id_cancels_all_requests() {
        let (mut client, server_task) = connect(sleep_registry()).await;
        for request_id in ["req-1", "req-2"] {
            let request = build_request_with_ids("sleep", "job-shared", request_id);
            write_message(&mut client, &request_message(request))
                .await
                .unwrap();
        }
        write_message(&mut client, &cancel_message("job-shared", None))
            .await
            .unwrap();

        let mut cancelled = 0;
        for _ in 0..2 {
            let payload = next_response_within(&mut client, Duration::from_millis(200)).await;
            assert_cancelled(&payload);
            cancelled += 1;
        }
        assert_eq!(cancelled, 2);
        drop(client);
        let _ = server_task.await;
    }

    #[tokio::test]
    async fn connection_teardown_clears_tracking_maps() {
        let (mut client, server) = tcp_pair().await;
        let in_flight: InFlightMap = Arc::new(Mutex::new(HashMap::new()));
        let job_index: JobIndexMap = Arc::new(Mutex::new(HashMap::new()));
        let server_task = serve(server, sleep_registry(), in_flight.clone(), job_index.clone());

        let request = build_request("sleep");
        write_message(&mut client, &request_message(request.clone()))
            .await
            .unwrap();

        let mut inserted = false;
        for _ in 0..20 {
            let has_in_flight = in_flight.lock().await.contains_key(&request.request_id);
            let has_job_index = job_index.lock().await.contains_key(&request.job_id);
            if has_in_flight && has_job_index {
                inserted = true;
                break;
            }
            tokio::time::sleep(Duration::from_millis(10)).await;
        }
        assert!(inserted, "request never entered tracking maps");

        drop(client);
        let _ = server_task.await;

        assert!(in_flight.lock().await.is_empty());
        assert!(job_index.lock().await.is_empty());
    }

    #[tokio::test]
    async fn handle_cancel_ignores_invalid_protocol() {
        let in_flight: InFlightMap = Arc::new(Mutex::new(HashMap::new()));
        let job_index: JobIndexMap = Arc::new(Mutex::new(HashMap::new()));
        let (tx, _rx) = mpsc::channel(1);
        let handle = tokio::spawn(async {});
        in_flight.lock().await.insert(
            "req-1".to_string(),
            InFlightTask {
                job_id: "job-1".to_string(),
                handle,
                response_tx: tx,
                connection_requests: Arc::new(Mutex::new(HashSet::new())),
                completed: Arc::new(AtomicBool::new(false)),
            },
        );
        let payload = CancelRequest {
            protocol_version: "0".to_string(),
            job_id: "job-1".to_string(),
            request_id: None,
            hard_kill: false,
        };
        handle_cancel(payload, &in_flight, &job_index).await;
        let guard = in_flight.lock().await;
        assert!(guard.contains_key("req-1"));
        guard.get("req-1").unwrap().handle.abort();
    }

    #[tokio::test]
    async fn handle_cancel_skips_completed_requests() {
        let in_flight: InFlightMap = Arc::new(Mutex::new(HashMap::new()));
        let job_index: JobIndexMap = Arc::new(Mutex::new(HashMap::new()));
        let (tx, mut rx) = mpsc::channel(1);
        let handle = tokio::spawn(async {
            tokio::time::sleep(Duration::from_millis(50)).await;
        });
        in_flight.lock().await.insert(
            "req-1".to_string(),
            InFlightTask {
                job_id: "job-1".to_string(),
                handle,
                response_tx: tx,
                connection_requests: Arc::new(Mutex::new(HashSet::new())),
                completed: Arc::new(AtomicBool::new(true)),
            },
        );
        job_index
            .lock()
            .await
            .insert("job-1".to_string(), HashSet::from(["req-1".to_string()]));
        let payload = CancelRequest {
            protocol_version: PROTOCOL_VERSION.to_string(),
            job_id: "job-1".to_string(),
            request_id: Some("req-1".to_string()),
            hard_kill: false,
        };
        handle_cancel(payload, &in_flight, &job_index).await;
        assert!(in_flight.lock().await.contains_key("req-1"));
        assert!(job_index.lock().await.contains_key("job-1"));
        assert!(rx.try_recv().is_err());
        let task = in_flight.lock().await.remove("req-1").unwrap();
        task.handle.abort();
    }

    #[tokio::test]
    async fn read_message_handles_empty_and_invalid_payloads() {
        let (mut client, mut server) = tokio::io::duplex(64);
        client.write_all(&0u32.to_be_bytes()).await.unwrap();
        let err = read_message(&mut server).await.unwrap_err();
        assert!(err.to_string().contains("payload cannot be empty"));

        let (mut client, mut server) = tokio::io::duplex(64);
        let payload = b"not-json";
        let len = (payload.len() as u32).to_be_bytes();
        client.write_all(&len).await.unwrap();
        client.write_all(payload).await.unwrap();
        let err = read_message(&mut server).await.unwrap_err();
        assert!(err.to_string().contains("expected"));

        let (mut client, mut server) = tokio::io::duplex(64);
        let len = ((MAX_FRAME_LEN + 1) as u32).to_be_bytes();
        client.write_all(&len).await.unwrap();
        let err = read_message(&mut server).await.unwrap_err();
        assert!(err.to_string().contains("payload too large"));
    }

    #[tokio::test]
    async fn read_message_returns_none_on_eof() {
        let (client, mut server) = tokio::io::duplex(8);
        drop(client);
        let message = read_message(&mut server).await.unwrap();
        assert!(message.is_none());
    }
}