1use crate::registry::Registry;
2use crate::telemetry::{NoopTelemetry, Telemetry};
3use crate::types::{ExecutionError, ExecutionOutcome};
4use chrono::{DateTime, Utc};
5use rrq_protocol::{CancelRequest, OutcomeStatus, PROTOCOL_VERSION, RunnerMessage, encode_frame};
6use std::collections::{HashMap, HashSet};
7use std::net::{IpAddr, Ipv4Addr, SocketAddr};
8use std::sync::{
9 Arc,
10 atomic::{AtomicBool, Ordering},
11};
12use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
13use tokio::net::TcpListener;
14use tokio::sync::{Mutex, mpsc};
15use tokio::time::{Duration, timeout};
16
/// Environment variable that carries the loopback `host:port` the runner should bind.
pub const ENV_RUNNER_TCP_SOCKET: &str = "RRQ_RUNNER_TCP_SOCKET";
/// Upper bound on a single length-prefixed frame (16 MiB); larger frames are rejected.
const MAX_FRAME_LEN: usize = 16 * 1024 * 1024;
/// Capacity of the per-connection outcome channel; also reused as the per-connection
/// in-flight request cap in `handle_connection`.
const RESPONSE_CHANNEL_CAPACITY: usize = 64;
/// How long a finishing task waits to enqueue its outcome before dropping it.
const RESPONSE_SEND_TIMEOUT: Duration = Duration::from_secs(1);
21
/// Build a boxed `io::Error` of kind `InvalidInput` carrying `message`.
fn invalid_input(message: impl Into<String>) -> Box<dyn std::error::Error> {
    let text: String = message.into();
    let error = std::io::Error::new(std::io::ErrorKind::InvalidInput, text);
    Box::new(error)
}
28
/// Parse a `host:port` (or bracketed `[host]:port`) string into a loopback-only
/// `SocketAddr`.
///
/// `localhost` is mapped to `127.0.0.1`; any other host must parse as an IP address
/// and be a loopback address. Empty input, a missing/zero/invalid port, an empty
/// host, or a non-loopback host all produce an `InvalidInput` error.
pub fn parse_tcp_socket(raw: &str) -> Result<SocketAddr, Box<dyn std::error::Error>> {
    // Local equivalent of the module's `invalid_input` helper so error construction
    // stays identical (io::Error, kind InvalidInput).
    let bad = |message: String| -> Box<dyn std::error::Error> {
        Box::new(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            message,
        ))
    };

    let trimmed = raw.trim();
    if trimmed.is_empty() {
        return Err(bad("runner tcp_socket cannot be empty".into()));
    }

    // Split into host and port, treating the bracketed IPv6 form separately.
    let (host, port_str) = match trimmed.strip_prefix('[') {
        Some(rest) => rest
            .split_once("]:")
            .ok_or_else(|| bad("runner tcp_socket must be in [host]:port format".into()))?,
        None => {
            let (host, port_str) = trimmed
                .rsplit_once(':')
                .ok_or_else(|| bad("runner tcp_socket must be in host:port format".into()))?;
            if host.is_empty() {
                return Err(bad("runner tcp_socket host cannot be empty".into()));
            }
            (host, port_str)
        }
    };

    let port: u16 = port_str
        .parse()
        .map_err(|_| bad(format!("Invalid runner tcp_socket port: {port_str}")))?;
    if port == 0 {
        return Err(bad("runner tcp_socket port must be > 0".into()));
    }

    let ip = if host == "localhost" {
        IpAddr::V4(Ipv4Addr::LOCALHOST)
    } else {
        let parsed: IpAddr = host
            .parse()
            .map_err(|_| bad(format!("Invalid runner tcp_socket host: {host}")))?;
        if !parsed.is_loopback() {
            return Err(bad("runner tcp_socket host must be localhost".into()));
        }
        parsed
    };

    Ok(SocketAddr::new(ip, port))
}
71
72pub struct RunnerRuntime {
73 runtime: tokio::runtime::Runtime,
74}
75
76impl RunnerRuntime {
77 pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
78 Ok(Self {
79 runtime: tokio::runtime::Runtime::new()?,
80 })
81 }
82
83 pub fn enter(&self) -> tokio::runtime::EnterGuard<'_> {
84 self.runtime.enter()
85 }
86
87 pub fn run_tcp(
88 &self,
89 registry: &Registry,
90 addr: SocketAddr,
91 ) -> Result<(), Box<dyn std::error::Error>> {
92 let telemetry = NoopTelemetry;
93 self.run_tcp_with(registry, addr, &telemetry)
94 }
95
96 pub fn run_tcp_with<T: Telemetry + ?Sized>(
97 &self,
98 registry: &Registry,
99 addr: SocketAddr,
100 telemetry: &T,
101 ) -> Result<(), Box<dyn std::error::Error>> {
102 run_tcp_loop(&self.runtime, registry, addr, telemetry)
103 }
104}
105
106pub fn run_tcp(registry: &Registry, addr: SocketAddr) -> Result<(), Box<dyn std::error::Error>> {
107 RunnerRuntime::new()?.run_tcp(registry, addr)
108}
109
110pub fn run_tcp_with<T: Telemetry + ?Sized>(
111 registry: &Registry,
112 addr: SocketAddr,
113 telemetry: &T,
114) -> Result<(), Box<dyn std::error::Error>> {
115 RunnerRuntime::new()?.run_tcp_with(registry, addr, telemetry)
116}
117
118fn run_tcp_loop<T: Telemetry + ?Sized>(
119 runtime: &tokio::runtime::Runtime,
120 registry: &Registry,
121 addr: SocketAddr,
122 telemetry: &T,
123) -> Result<(), Box<dyn std::error::Error>> {
124 let registry = registry.clone();
125 let in_flight: Arc<Mutex<HashMap<String, InFlightTask>>> = Arc::new(Mutex::new(HashMap::new()));
126 let job_index: Arc<Mutex<HashMap<String, HashSet<String>>>> =
127 Arc::new(Mutex::new(HashMap::new()));
128 let telemetry = telemetry.clone_box();
129 runtime.block_on(async move {
130 if !addr.ip().is_loopback() {
131 return Err(invalid_input(format!(
132 "runner tcp_socket must be loopback-only (got {addr})"
133 )));
134 }
135 let listener = TcpListener::bind(addr).await?;
136 loop {
137 let (stream, _) = listener.accept().await?;
138 let registry = registry.clone();
139 let telemetry = telemetry.clone();
140 let in_flight = in_flight.clone();
141 let job_index = job_index.clone();
142 tokio::spawn(async move {
143 if let Err(err) =
144 handle_connection(stream, ®istry, telemetry.as_ref(), in_flight, job_index)
145 .await
146 {
147 tracing::error!("runner connection error: {err}");
148 }
149 });
150 }
151 })
152}
153
/// Serve one runner connection: read framed messages from `stream`, execute requests
/// against `registry`, and stream outcomes back through a dedicated writer task.
///
/// `in_flight` and `job_index` are the tracking maps shared across connections (see
/// `run_tcp_loop`), so cancels can target any request. When the peer closes the
/// connection, every request this connection started is aborted and removed from both
/// maps before returning.
async fn handle_connection<S, T>(
    stream: S,
    registry: &Registry,
    telemetry: &T,
    in_flight: Arc<Mutex<HashMap<String, InFlightTask>>>,
    job_index: Arc<Mutex<HashMap<String, HashSet<String>>>>,
) -> Result<(), Box<dyn std::error::Error>>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    T: Telemetry + ?Sized,
{
    let (mut reader, mut writer) = tokio::io::split(stream);
    // All responses funnel through one channel; a single writer task owns the write
    // half, serializing frames from concurrently finishing request tasks.
    let (response_tx, mut response_rx) =
        mpsc::channel::<ExecutionOutcome>(RESPONSE_CHANNEL_CAPACITY);
    let writer_task = tokio::spawn(async move {
        while let Some(outcome) = response_rx.recv().await {
            let response = RunnerMessage::Response { payload: outcome };
            if write_message(&mut writer, &response).await.is_err() {
                break;
            }
        }
    });
    // request_ids submitted by THIS connection; bounds concurrency and drives teardown.
    let connection_requests: Arc<Mutex<std::collections::HashSet<String>>> =
        Arc::new(Mutex::new(std::collections::HashSet::new()));

    loop {
        // `None` means clean EOF: the peer closed the connection.
        let message = match read_message(&mut reader).await? {
            Some(message) => message,
            None => break,
        };
        match message {
            RunnerMessage::Request { payload } => {
                // Reject requests from a mismatched protocol version without executing.
                if payload.protocol_version != PROTOCOL_VERSION {
                    let outcome = ExecutionOutcome::error(
                        payload.job_id.clone(),
                        payload.request_id.clone(),
                        "Unsupported protocol version",
                    );
                    let _ = response_tx.send(outcome).await;
                    continue;
                }

                let request_id = payload.request_id.clone();
                let job_id = payload.job_id.clone();
                {
                    // Enforce the per-connection in-flight cap before spawning.
                    let mut active = connection_requests.lock().await;
                    if active.len() >= RESPONSE_CHANNEL_CAPACITY {
                        let outcome = ExecutionOutcome::error(
                            payload.job_id.clone(),
                            payload.request_id.clone(),
                            "Runner busy: too many in-flight requests",
                        );
                        drop(active);
                        // try_send on purpose: if the response channel is also
                        // saturated, drop the busy reply rather than stalling reads.
                        let _ = response_tx.try_send(outcome);
                        continue;
                    }
                    active.insert(request_id.clone());
                }
                let response_tx = response_tx.clone();
                let registry = registry.clone();
                let telemetry = telemetry.clone_box();
                let in_flight_for_task = in_flight.clone();
                let job_index_for_task = job_index.clone();
                let active_for_task = connection_requests.clone();
                let request_id_for_task = request_id.clone();
                let job_id_for_task = job_id.clone();
                let response_tx_for_task = response_tx.clone();
                // Set just before the task sends its own outcome so a concurrent
                // cancel can distinguish "already finished" from "still running".
                let completed = Arc::new(AtomicBool::new(false));
                let completed_for_task = completed.clone();

                // Execute off the read loop so further frames (e.g. cancels) keep
                // flowing while the handler runs.
                let handle = tokio::spawn(async move {
                    let outcome =
                        execute_with_deadline(payload, registry, telemetry.as_ref()).await;
                    completed_for_task.store(true, Ordering::SeqCst);
                    // Best-effort delivery with a bounded wait; never block forever on
                    // a stalled writer.
                    let send_result =
                        timeout(RESPONSE_SEND_TIMEOUT, response_tx_for_task.send(outcome)).await;
                    match send_result {
                        Ok(Ok(())) => {}
                        Ok(Err(_)) => {
                            tracing::warn!("runner response channel closed; dropping outcome");
                        }
                        Err(_) => {
                            tracing::warn!("runner response channel stalled; dropping outcome");
                        }
                    }
                    // Scrub this request from all tracking structures on completion.
                    {
                        let mut in_flight = in_flight_for_task.lock().await;
                        in_flight.remove(&request_id_for_task);
                    }
                    {
                        let mut job_index = job_index_for_task.lock().await;
                        if let Some(entries) = job_index.get_mut(&job_id_for_task) {
                            entries.remove(&request_id_for_task);
                            if entries.is_empty() {
                                job_index.remove(&job_id_for_task);
                            }
                        }
                    }
                    {
                        let mut active = active_for_task.lock().await;
                        active.remove(&request_id_for_task);
                    }
                });

                // Register the task so cancels (by request or by job) can find it.
                {
                    let mut in_flight = in_flight.lock().await;
                    in_flight.insert(
                        request_id.clone(),
                        InFlightTask {
                            job_id: job_id.clone(),
                            handle,
                            response_tx: response_tx.clone(),
                            connection_requests: connection_requests.clone(),
                            completed,
                        },
                    );
                }
                {
                    let mut job_index = job_index.lock().await;
                    job_index
                        .entry(job_id)
                        .or_insert_with(HashSet::new)
                        .insert(request_id);
                }
            }
            RunnerMessage::Cancel { payload } => {
                handle_cancel(payload, &in_flight, &job_index).await;
            }
            RunnerMessage::Response { .. } => {
                // A runner never expects a Response frame from its peer; surface the
                // protocol misuse as an error outcome instead of dropping it silently.
                let outcome = ExecutionOutcome {
                    job_id: Some("unknown".to_string()),
                    request_id: None,
                    status: rrq_protocol::OutcomeStatus::Error,
                    result: None,
                    error: Some(ExecutionError {
                        message: "unexpected response message".to_string(),
                        error_type: None,
                        code: None,
                        details: None,
                    }),
                    retry_after_seconds: None,
                };
                let _ = response_tx.send(outcome).await;
            }
        }
    }

    // Connection teardown: abort everything this connection started and remove it from
    // the shared maps so later cancels for these requests become no-ops.
    let request_ids = {
        let mut active = connection_requests.lock().await;
        active.drain().collect::<Vec<_>>()
    };
    for request_id in request_ids {
        let task = {
            let mut in_flight = in_flight.lock().await;
            in_flight.remove(&request_id)
        };
        if let Some(task) = task {
            task.handle.abort();
            let mut job_index = job_index.lock().await;
            if let Some(entries) = job_index.get_mut(&task.job_id) {
                entries.remove(&request_id);
                if entries.is_empty() {
                    job_index.remove(&task.job_id);
                }
            }
        }
    }
    writer_task.abort();

    Ok(())
}
325
/// Bookkeeping for a spawned request task so it can be cancelled and reported on.
struct InFlightTask {
    // Job this request belongs to (key into `job_index`).
    job_id: String,
    // Handle used to abort the executing task on cancel or connection teardown.
    handle: tokio::task::JoinHandle<()>,
    // Sender into the owning connection's writer task (used for cancel outcomes).
    response_tx: mpsc::Sender<ExecutionOutcome>,
    // The owning connection's active-request set; removing the id frees a capacity slot.
    connection_requests: Arc<Mutex<HashSet<String>>>,
    // Set by the task just before it sends its own outcome; cancel skips such tasks.
    completed: Arc<AtomicBool>,
}
333
/// Cancel in-flight work: a single request when `payload.request_id` is set, otherwise
/// every request currently tracked for `payload.job_id`.
///
/// Tasks that already flagged themselves `completed` are left alone — their own
/// outcome is on the way and they clean up after themselves. Cancelled tasks are
/// aborted, removed from their connection's active set, reported to the client as an
/// error with `error_type == "cancelled"`, and scrubbed from `job_index`.
async fn handle_cancel(
    payload: CancelRequest,
    in_flight: &Arc<Mutex<HashMap<String, InFlightTask>>>,
    job_index: &Arc<Mutex<HashMap<String, HashSet<String>>>>,
) {
    // Cancels from a mismatched protocol version are ignored entirely.
    if payload.protocol_version != PROTOCOL_VERSION {
        return;
    }
    // Resolve targets: the explicit request id, or all ids recorded for the job.
    let request_ids = if let Some(request_id) = payload.request_id.clone() {
        vec![request_id]
    } else {
        let job_index = job_index.lock().await;
        job_index
            .get(&payload.job_id)
            .map(|ids| ids.iter().cloned().collect())
            .unwrap_or_else(Vec::new)
    };
    if request_ids.is_empty() {
        return;
    }

    for request_id in request_ids {
        let task = {
            let mut in_flight = in_flight.lock().await;
            // Leave completed tasks in place; they remove themselves after sending.
            if let Some(task) = in_flight.get(&request_id)
                && task.completed.load(Ordering::SeqCst)
            {
                None
            } else {
                in_flight.remove(&request_id)
            }
        };
        if let Some(task) = task {
            task.handle.abort();
            {
                // Free the owning connection's capacity slot immediately.
                let mut active = task.connection_requests.lock().await;
                active.remove(&request_id);
            }
            let outcome = ExecutionOutcome {
                job_id: Some(payload.job_id.clone()),
                request_id: Some(request_id.clone()),
                status: OutcomeStatus::Error,
                result: None,
                error: Some(ExecutionError {
                    message: "Job cancelled".to_string(),
                    error_type: Some("cancelled".to_string()),
                    code: None,
                    details: None,
                }),
                retry_after_seconds: None,
            };
            // Best-effort delivery: drop the outcome if the writer is gone or stalled.
            let send_result = timeout(RESPONSE_SEND_TIMEOUT, task.response_tx.send(outcome)).await;
            match send_result {
                Ok(Ok(())) => {}
                Ok(Err(_)) => {
                    tracing::warn!("runner response channel closed; dropping cancel outcome");
                }
                Err(_) => {
                    tracing::warn!("runner response channel stalled; dropping cancel outcome");
                }
            }
            let mut job_index = job_index.lock().await;
            if let Some(entries) = job_index.get_mut(&task.job_id) {
                entries.remove(&request_id);
                if entries.is_empty() {
                    job_index.remove(&task.job_id);
                }
            }
        }
    }
}
405
406async fn execute_with_deadline<T: Telemetry + ?Sized>(
407 request: rrq_protocol::ExecutionRequest,
408 registry: Registry,
409 telemetry: &T,
410) -> ExecutionOutcome {
411 let job_id = request.job_id.clone();
412 let request_id = request.request_id.clone();
413 let deadline = request.context.deadline;
414 if let Some(deadline) = deadline {
415 let now: DateTime<Utc> = Utc::now();
416 if deadline <= now {
417 return ExecutionOutcome::timeout(
418 job_id.clone(),
419 request_id.clone(),
420 "Job deadline exceeded",
421 );
422 }
423 if let Ok(remaining) = (deadline - now).to_std() {
424 match tokio::time::timeout(remaining, registry.execute_with(request, telemetry)).await {
425 Ok(outcome) => return outcome,
426 Err(_) => {
427 return ExecutionOutcome::timeout(
428 job_id.clone(),
429 request_id.clone(),
430 "Job execution timed out",
431 );
432 }
433 }
434 }
435 return ExecutionOutcome::timeout(job_id, request_id, "Job deadline exceeded");
436 }
437 registry.execute_with(request, telemetry).await
438}
439
440async fn read_message<R: AsyncRead + Unpin>(
441 stream: &mut R,
442) -> Result<Option<RunnerMessage>, Box<dyn std::error::Error>> {
443 let mut header = [0u8; 4];
444 match stream.read_exact(&mut header).await {
445 Ok(_) => {}
446 Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
447 Err(err) => return Err(err.into()),
448 }
449 let length = u32::from_be_bytes(header) as usize;
450 if length == 0 {
451 return Err("runner message payload cannot be empty".into());
452 }
453 if length > MAX_FRAME_LEN {
454 return Err("runner message payload too large".into());
455 }
456 let mut payload = vec![0u8; length];
457 stream.read_exact(&mut payload).await?;
458 let message = serde_json::from_slice(&payload)?;
459 Ok(Some(message))
460}
461
462async fn write_message<W: AsyncWrite + Unpin>(
463 stream: &mut W,
464 message: &RunnerMessage,
465) -> Result<(), Box<dyn std::error::Error>> {
466 let framed = encode_frame(message)?;
467 stream.write_all(&framed).await?;
468 stream.flush().await?;
469 Ok(())
470}
471
472#[cfg(test)]
473mod tests {
474 use super::*;
475 use crate::registry::Registry;
476 use crate::telemetry::NoopTelemetry;
477 use chrono::Utc;
478 use rrq_protocol::{ExecutionContext, ExecutionRequest, OutcomeStatus};
479 use serde_json::json;
480 use tokio::net::{TcpListener, TcpStream};
481 use tokio::time::{Duration, timeout};
482
483 fn build_request(function_name: &str) -> ExecutionRequest {
484 ExecutionRequest {
485 protocol_version: PROTOCOL_VERSION.to_string(),
486 request_id: "req-1".to_string(),
487 job_id: "job-1".to_string(),
488 function_name: function_name.to_string(),
489 args: vec![],
490 kwargs: std::collections::HashMap::new(),
491 context: ExecutionContext {
492 job_id: "job-1".to_string(),
493 attempt: 1,
494 enqueue_time: "2024-01-01T00:00:00Z".parse().unwrap(),
495 queue_name: "default".to_string(),
496 deadline: None,
497 trace_context: None,
498 worker_id: None,
499 },
500 }
501 }
502
503 async fn tcp_pair() -> (TcpStream, TcpStream) {
504 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
505 let addr = listener.local_addr().unwrap();
506 let client = TcpStream::connect(addr).await.unwrap();
507 let (server, _) = listener.accept().await.unwrap();
508 (client, server)
509 }
510
511 #[tokio::test]
512 async fn handle_connection_executes_request() {
513 let mut registry = Registry::new();
514 registry.register("echo", |request| async move {
515 ExecutionOutcome::success(
516 request.job_id.clone(),
517 request.request_id.clone(),
518 json!({"ok": true}),
519 )
520 });
521 let (client, server) = tcp_pair().await;
522 let in_flight = Arc::new(Mutex::new(HashMap::new()));
523 let job_index = Arc::new(Mutex::new(HashMap::new()));
524 let server_task = tokio::spawn(async move {
525 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
526 .await
527 .unwrap();
528 });
529 let mut client = client;
530 let request = build_request("echo");
531 let message = RunnerMessage::Request { payload: request };
532 write_message(&mut client, &message).await.unwrap();
533 let response = read_message(&mut client).await.unwrap().unwrap();
534 match response {
535 RunnerMessage::Response { payload } => {
536 assert_eq!(payload.status, OutcomeStatus::Success);
537 assert_eq!(payload.result, Some(json!({"ok": true})));
538 }
539 _ => panic!("expected response"),
540 }
541 drop(client);
542 let _ = server_task.await;
543 }
544
545 #[tokio::test]
546 async fn handle_connection_rejects_bad_protocol() {
547 let registry = Registry::new();
548 let (client, server) = tcp_pair().await;
549 let in_flight = Arc::new(Mutex::new(HashMap::new()));
550 let job_index = Arc::new(Mutex::new(HashMap::new()));
551 let server_task = tokio::spawn(async move {
552 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
553 .await
554 .unwrap();
555 });
556 let mut client = client;
557 let mut request = build_request("echo");
558 request.protocol_version = "0".to_string();
559 let message = RunnerMessage::Request { payload: request };
560 write_message(&mut client, &message).await.unwrap();
561 let response = read_message(&mut client).await.unwrap().unwrap();
562 match response {
563 RunnerMessage::Response { payload } => {
564 assert_eq!(payload.status, OutcomeStatus::Error);
565 }
566 _ => panic!("expected response"),
567 }
568 drop(client);
569 let _ = server_task.await;
570 }
571
572 #[tokio::test]
573 async fn handle_connection_cancels_inflight() {
574 let mut registry = Registry::new();
575 registry.register("sleep", |request| async move {
576 tokio::time::sleep(Duration::from_millis(200)).await;
577 ExecutionOutcome::success(
578 request.job_id.clone(),
579 request.request_id.clone(),
580 json!({"ok": true}),
581 )
582 });
583 let (client, server) = tcp_pair().await;
584 let in_flight = Arc::new(Mutex::new(HashMap::new()));
585 let job_index = Arc::new(Mutex::new(HashMap::new()));
586 let server_task = tokio::spawn(async move {
587 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
588 .await
589 .unwrap();
590 });
591 let mut client = client;
592 let request = ExecutionRequest {
593 protocol_version: PROTOCOL_VERSION.to_string(),
594 request_id: "req-cancel".to_string(),
595 job_id: "job-cancel".to_string(),
596 function_name: "sleep".to_string(),
597 args: vec![],
598 kwargs: std::collections::HashMap::new(),
599 context: ExecutionContext {
600 job_id: "job-cancel".to_string(),
601 attempt: 1,
602 enqueue_time: "2024-01-01T00:00:00Z".parse().unwrap(),
603 queue_name: "default".to_string(),
604 deadline: None,
605 trace_context: None,
606 worker_id: None,
607 },
608 };
609 let message = RunnerMessage::Request {
610 payload: request.clone(),
611 };
612 write_message(&mut client, &message).await.unwrap();
613 let cancel = RunnerMessage::Cancel {
614 payload: CancelRequest {
615 protocol_version: PROTOCOL_VERSION.to_string(),
616 job_id: request.job_id.clone(),
617 request_id: Some(request.request_id.clone()),
618 hard_kill: false,
619 },
620 };
621 write_message(&mut client, &cancel).await.unwrap();
622 let response = read_message(&mut client).await.unwrap().unwrap();
623 match response {
624 RunnerMessage::Response { payload } => {
625 assert_eq!(payload.status, OutcomeStatus::Error);
626 let error_type = payload
627 .error
628 .as_ref()
629 .and_then(|error| error.error_type.as_deref());
630 assert_eq!(error_type, Some("cancelled"));
631 }
632 _ => panic!("expected response"),
633 }
634 drop(client);
635 let _ = server_task.await;
636 }
637
638 #[tokio::test]
639 async fn cancel_frees_connection_capacity() {
640 let mut registry = Registry::new();
641 let gate = Arc::new(tokio::sync::Semaphore::new(0));
642 let gate_for_handler = gate.clone();
643 registry.register("block", move |request| {
644 let gate = gate_for_handler.clone();
645 async move {
646 let _permit = gate.acquire().await.expect("semaphore closed");
647 ExecutionOutcome::success(
648 request.job_id.clone(),
649 request.request_id.clone(),
650 json!({"ok": true}),
651 )
652 }
653 });
654 let (client, server) = tcp_pair().await;
655 let in_flight = Arc::new(Mutex::new(HashMap::new()));
656 let job_index = Arc::new(Mutex::new(HashMap::new()));
657 let server_task = tokio::spawn(async move {
658 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
659 .await
660 .unwrap();
661 });
662 let mut client = client;
663 let job_id = "job-capacity".to_string();
664 for i in 0..RESPONSE_CHANNEL_CAPACITY {
665 let mut request = build_request("block");
666 request.request_id = format!("req-{i}");
667 request.job_id = job_id.clone();
668 request.context.job_id = job_id.clone();
669 write_message(&mut client, &RunnerMessage::Request { payload: request })
670 .await
671 .unwrap();
672 }
673
674 let cancel = RunnerMessage::Cancel {
675 payload: CancelRequest {
676 protocol_version: PROTOCOL_VERSION.to_string(),
677 job_id: job_id.clone(),
678 request_id: Some("req-0".to_string()),
679 hard_kill: false,
680 },
681 };
682 write_message(&mut client, &cancel).await.unwrap();
683 let response = timeout(Duration::from_secs(1), read_message(&mut client))
684 .await
685 .unwrap()
686 .unwrap()
687 .unwrap();
688 match response {
689 RunnerMessage::Response { payload } => {
690 assert_eq!(payload.status, OutcomeStatus::Error);
691 let error_type = payload
692 .error
693 .as_ref()
694 .and_then(|error| error.error_type.as_deref());
695 assert_eq!(error_type, Some("cancelled"));
696 }
697 _ => panic!("expected response"),
698 }
699
700 let mut extra_request = build_request("block");
701 extra_request.request_id = "req-extra".to_string();
702 extra_request.job_id = job_id.clone();
703 extra_request.context.job_id = job_id.clone();
704 write_message(
705 &mut client,
706 &RunnerMessage::Request {
707 payload: extra_request,
708 },
709 )
710 .await
711 .unwrap();
712
713 gate.add_permits(RESPONSE_CHANNEL_CAPACITY + 1);
714
715 let mut saw_extra = false;
716 for _ in 0..RESPONSE_CHANNEL_CAPACITY {
717 let response = timeout(Duration::from_secs(1), read_message(&mut client))
718 .await
719 .unwrap()
720 .unwrap()
721 .unwrap();
722 if let RunnerMessage::Response { payload } = response
723 && payload.request_id.as_deref() == Some("req-extra")
724 {
725 assert_eq!(payload.status, OutcomeStatus::Success);
726 saw_extra = true;
727 }
728 }
729 assert!(saw_extra, "extra request never completed");
730
731 drop(client);
732 let _ = server_task.await;
733 }
734
735 #[tokio::test]
736 async fn execute_with_deadline_times_out() {
737 let mut registry = Registry::new();
738 registry.register("echo", |request| async move {
739 ExecutionOutcome::success(
740 request.job_id.clone(),
741 request.request_id.clone(),
742 json!({"ok": true}),
743 )
744 });
745 let mut request = build_request("echo");
746 request.context.deadline = Some(
747 "2020-01-01T00:00:00Z"
748 .parse::<chrono::DateTime<Utc>>()
749 .unwrap(),
750 );
751 let outcome = execute_with_deadline(request, registry, &NoopTelemetry).await;
752 assert_eq!(outcome.status, OutcomeStatus::Timeout);
753 }
754
755 #[tokio::test]
756 async fn execute_with_deadline_succeeds_before_deadline() {
757 let mut registry = Registry::new();
758 registry.register("echo", |request| async move {
759 ExecutionOutcome::success(
760 request.job_id.clone(),
761 request.request_id.clone(),
762 json!({"ok": true}),
763 )
764 });
765 let mut request = build_request("echo");
766 request.context.deadline = Some(Utc::now() + chrono::Duration::seconds(5));
767 let outcome = execute_with_deadline(request, registry, &NoopTelemetry).await;
768 assert_eq!(outcome.status, OutcomeStatus::Success);
769 }
770
771 #[tokio::test]
772 async fn handle_connection_handles_unexpected_response_message() {
773 let registry = Registry::new();
774 let (client, server) = tcp_pair().await;
775 let in_flight = Arc::new(Mutex::new(HashMap::new()));
776 let job_index = Arc::new(Mutex::new(HashMap::new()));
777 let server_task = tokio::spawn(async move {
778 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
779 .await
780 .unwrap();
781 });
782 let mut client = client;
783 let response = RunnerMessage::Response {
784 payload: ExecutionOutcome::error("job-x", "req-x", "oops"),
785 };
786 write_message(&mut client, &response).await.unwrap();
787 let reply = read_message(&mut client).await.unwrap().unwrap();
788 match reply {
789 RunnerMessage::Response { payload } => {
790 assert_eq!(payload.status, OutcomeStatus::Error);
791 assert!(
792 payload
793 .error
794 .as_ref()
795 .unwrap()
796 .message
797 .contains("unexpected response")
798 );
799 }
800 _ => panic!("expected response"),
801 }
802 drop(client);
803 let _ = server_task.await;
804 }
805
806 #[tokio::test]
807 async fn handle_connection_cancels_by_job_id() {
808 let mut registry = Registry::new();
809 registry.register("sleep", |request| async move {
810 tokio::time::sleep(Duration::from_millis(200)).await;
811 ExecutionOutcome::success(
812 request.job_id.clone(),
813 request.request_id.clone(),
814 json!({"ok": true}),
815 )
816 });
817 let (client, server) = tcp_pair().await;
818 let in_flight = Arc::new(Mutex::new(HashMap::new()));
819 let job_index = Arc::new(Mutex::new(HashMap::new()));
820 let server_task = tokio::spawn(async move {
821 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
822 .await
823 .unwrap();
824 });
825 let mut client = client;
826 let request = build_request("sleep");
827 let message = RunnerMessage::Request {
828 payload: request.clone(),
829 };
830 write_message(&mut client, &message).await.unwrap();
831 let cancel = RunnerMessage::Cancel {
832 payload: CancelRequest {
833 protocol_version: PROTOCOL_VERSION.to_string(),
834 job_id: request.job_id.clone(),
835 request_id: None,
836 hard_kill: false,
837 },
838 };
839 write_message(&mut client, &cancel).await.unwrap();
840 let response = read_message(&mut client).await.unwrap().unwrap();
841 match response {
842 RunnerMessage::Response { payload } => {
843 assert_eq!(payload.status, OutcomeStatus::Error);
844 let error_type = payload
845 .error
846 .as_ref()
847 .and_then(|error| error.error_type.as_deref());
848 assert_eq!(error_type, Some("cancelled"));
849 }
850 _ => panic!("expected response"),
851 }
852 drop(client);
853 let _ = server_task.await;
854 }
855
856 #[tokio::test]
857 async fn handle_cancel_by_job_id_cancels_all_requests() {
858 let mut registry = Registry::new();
859 registry.register("sleep", |request| async move {
860 tokio::time::sleep(Duration::from_millis(200)).await;
861 ExecutionOutcome::success(
862 request.job_id.clone(),
863 request.request_id.clone(),
864 json!({"ok": true}),
865 )
866 });
867 let (client, server) = tcp_pair().await;
868 let in_flight = Arc::new(Mutex::new(HashMap::new()));
869 let job_index = Arc::new(Mutex::new(HashMap::new()));
870 let server_task = tokio::spawn(async move {
871 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
872 .await
873 .unwrap();
874 });
875 let mut client = client;
876 let mut request1 = build_request("sleep");
877 request1.request_id = "req-1".to_string();
878 request1.job_id = "job-shared".to_string();
879 let mut request2 = build_request("sleep");
880 request2.request_id = "req-2".to_string();
881 request2.job_id = "job-shared".to_string();
882 write_message(&mut client, &RunnerMessage::Request { payload: request1 })
883 .await
884 .unwrap();
885 write_message(&mut client, &RunnerMessage::Request { payload: request2 })
886 .await
887 .unwrap();
888 let cancel = RunnerMessage::Cancel {
889 payload: CancelRequest {
890 protocol_version: PROTOCOL_VERSION.to_string(),
891 job_id: "job-shared".to_string(),
892 request_id: None,
893 hard_kill: false,
894 },
895 };
896 write_message(&mut client, &cancel).await.unwrap();
897
898 let mut cancelled = 0;
899 for _ in 0..2 {
900 let response = timeout(Duration::from_millis(200), read_message(&mut client))
901 .await
902 .unwrap()
903 .unwrap()
904 .unwrap();
905 match response {
906 RunnerMessage::Response { payload } => {
907 assert_eq!(payload.status, OutcomeStatus::Error);
908 let error_type = payload
909 .error
910 .as_ref()
911 .and_then(|error| error.error_type.as_deref());
912 assert_eq!(error_type, Some("cancelled"));
913 cancelled += 1;
914 }
915 _ => panic!("expected response"),
916 }
917 }
918 assert_eq!(cancelled, 2);
919 drop(client);
920 let _ = server_task.await;
921 }
922
923 #[tokio::test]
924 async fn connection_teardown_clears_tracking_maps() {
925 let mut registry = Registry::new();
926 registry.register("sleep", |request| async move {
927 tokio::time::sleep(Duration::from_millis(200)).await;
928 ExecutionOutcome::success(
929 request.job_id.clone(),
930 request.request_id.clone(),
931 json!({"ok": true}),
932 )
933 });
934 let (client, server) = tcp_pair().await;
935 let in_flight = Arc::new(Mutex::new(HashMap::new()));
936 let job_index = Arc::new(Mutex::new(HashMap::new()));
937 let in_flight_for_server = in_flight.clone();
938 let job_index_for_server = job_index.clone();
939 let server_task = tokio::spawn(async move {
940 handle_connection(
941 server,
942 ®istry,
943 &NoopTelemetry,
944 in_flight_for_server,
945 job_index_for_server,
946 )
947 .await
948 .unwrap();
949 });
950 let mut client = client;
951 let request = build_request("sleep");
952 let message = RunnerMessage::Request {
953 payload: request.clone(),
954 };
955 write_message(&mut client, &message).await.unwrap();
956
957 let mut inserted = false;
958 for _ in 0..20 {
959 let has_in_flight = {
960 let guard = in_flight.lock().await;
961 guard.contains_key(&request.request_id)
962 };
963 let has_job_index = {
964 let guard = job_index.lock().await;
965 guard.contains_key(&request.job_id)
966 };
967 if has_in_flight && has_job_index {
968 inserted = true;
969 break;
970 }
971 tokio::time::sleep(Duration::from_millis(10)).await;
972 }
973 assert!(inserted, "request never entered tracking maps");
974
975 drop(client);
976 let _ = server_task.await;
977
978 let in_flight = in_flight.lock().await;
979 let job_index = job_index.lock().await;
980 assert!(in_flight.is_empty());
981 assert!(job_index.is_empty());
982 }
983
984 #[tokio::test]
985 async fn handle_cancel_ignores_invalid_protocol() {
986 let in_flight = Arc::new(Mutex::new(HashMap::new()));
987 let job_index = Arc::new(Mutex::new(HashMap::new()));
988 let (tx, _rx) = mpsc::channel(1);
989 let handle = tokio::spawn(async {});
990 let connection_requests = Arc::new(Mutex::new(HashSet::new()));
991 {
992 let mut guard = in_flight.lock().await;
993 guard.insert(
994 "req-1".to_string(),
995 InFlightTask {
996 job_id: "job-1".to_string(),
997 handle,
998 response_tx: tx,
999 connection_requests,
1000 completed: Arc::new(AtomicBool::new(false)),
1001 },
1002 );
1003 }
1004 let payload = CancelRequest {
1005 protocol_version: "0".to_string(),
1006 job_id: "job-1".to_string(),
1007 request_id: None,
1008 hard_kill: false,
1009 };
1010 handle_cancel(payload, &in_flight, &job_index).await;
1011 let guard = in_flight.lock().await;
1012 assert!(guard.contains_key("req-1"));
1013 guard.get("req-1").unwrap().handle.abort();
1014 }
1015
1016 #[tokio::test]
1017 async fn handle_cancel_skips_completed_requests() {
1018 let in_flight = Arc::new(Mutex::new(HashMap::new()));
1019 let job_index = Arc::new(Mutex::new(HashMap::new()));
1020 let (tx, mut rx) = mpsc::channel(1);
1021 let handle = tokio::spawn(async {
1022 tokio::time::sleep(Duration::from_millis(50)).await;
1023 });
1024 let connection_requests = Arc::new(Mutex::new(HashSet::new()));
1025 {
1026 let mut guard = in_flight.lock().await;
1027 guard.insert(
1028 "req-1".to_string(),
1029 InFlightTask {
1030 job_id: "job-1".to_string(),
1031 handle,
1032 response_tx: tx,
1033 connection_requests,
1034 completed: Arc::new(AtomicBool::new(true)),
1035 },
1036 );
1037 }
1038 {
1039 let mut guard = job_index.lock().await;
1040 guard.insert("job-1".to_string(), HashSet::from(["req-1".to_string()]));
1041 }
1042 let payload = CancelRequest {
1043 protocol_version: PROTOCOL_VERSION.to_string(),
1044 job_id: "job-1".to_string(),
1045 request_id: Some("req-1".to_string()),
1046 hard_kill: false,
1047 };
1048 handle_cancel(payload, &in_flight, &job_index).await;
1049 assert!(in_flight.lock().await.contains_key("req-1"));
1050 assert!(job_index.lock().await.contains_key("job-1"));
1051 assert!(rx.try_recv().is_err());
1052 let task = in_flight.lock().await.remove("req-1").unwrap();
1053 task.handle.abort();
1054 }
1055
1056 #[tokio::test]
1057 async fn read_message_handles_empty_and_invalid_payloads() {
1058 let (mut client, mut server) = tokio::io::duplex(64);
1059 client.write_all(&0u32.to_be_bytes()).await.unwrap();
1061 let err = read_message(&mut server).await.unwrap_err();
1062 assert!(err.to_string().contains("payload cannot be empty"));
1063
1064 let (mut client, mut server) = tokio::io::duplex(64);
1066 let payload = b"not-json";
1067 let len = (payload.len() as u32).to_be_bytes();
1068 client.write_all(&len).await.unwrap();
1069 client.write_all(payload).await.unwrap();
1070 let err = read_message(&mut server).await.unwrap_err();
1071 assert!(err.to_string().contains("expected"));
1072
1073 let (mut client, mut server) = tokio::io::duplex(64);
1075 let len = ((MAX_FRAME_LEN + 1) as u32).to_be_bytes();
1076 client.write_all(&len).await.unwrap();
1077 let err = read_message(&mut server).await.unwrap_err();
1078 assert!(err.to_string().contains("payload too large"));
1079 }
1080
1081 #[tokio::test]
1082 async fn read_message_returns_none_on_eof() {
1083 let (client, mut server) = tokio::io::duplex(8);
1084 drop(client);
1085 let message = read_message(&mut server).await.unwrap();
1086 assert!(message.is_none());
1087 }
1088}