1use crate::registry::Registry;
2use crate::telemetry::{NoopTelemetry, Telemetry};
3use crate::types::{ExecutionError, ExecutionOutcome};
4use chrono::{DateTime, Utc};
5use rrq_protocol::{CancelRequest, OutcomeStatus, PROTOCOL_VERSION, RunnerMessage, encode_frame};
6use std::collections::{HashMap, HashSet};
7use std::net::{IpAddr, Ipv4Addr, SocketAddr};
8use std::sync::{
9 Arc,
10 atomic::{AtomicBool, Ordering},
11};
12use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
13use tokio::net::TcpListener;
14use tokio::sync::{Mutex, mpsc};
15use tokio::time::{Duration, timeout};
/// Largest accepted frame payload (16 MiB); longer length prefixes are rejected by `read_message`.
const MAX_FRAME_LEN: usize = 16 << 20;
/// Buffer size of each connection's response channel, also used as the per-connection cap on active requests.
const RESPONSE_CHANNEL_CAPACITY: usize = 64;
/// How long a sender waits for room in a response channel before treating it as stalled.
const RESPONSE_SEND_TIMEOUT: Duration = Duration::from_millis(1_000);
19
/// Builds a boxed `std::io::ErrorKind::InvalidInput` error carrying `message`.
fn invalid_input(message: impl Into<String>) -> Box<dyn std::error::Error> {
    Box::new(std::io::Error::new(
        std::io::ErrorKind::InvalidInput,
        message.into(),
    ))
}

/// Parses a runner `tcp_socket` setting into a loopback [`SocketAddr`].
///
/// Accepted forms (surrounding whitespace is ignored):
/// - `host:port` — split at the last `:`, so an unbracketed IPv6 literal such
///   as `::1:8080` is also accepted;
/// - `[host]:port` — the bracketed form for IPv6 literals.
///
/// `host` must be `localhost` (matched case-insensitively, as DNS names are)
/// or an IP literal for which `is_loopback()` holds; `localhost` maps to
/// `127.0.0.1`. The port must be in `1..=65535`.
///
/// # Errors
///
/// Returns an `InvalidInput` I/O error when the string is empty or not in one
/// of the forms above, when the host is empty, when the port is zero or not a
/// valid `u16`, or when the host is not a loopback address.
pub fn parse_tcp_socket(raw: &str) -> Result<SocketAddr, Box<dyn std::error::Error>> {
    let raw = raw.trim();
    if raw.is_empty() {
        return Err(invalid_input("runner tcp_socket cannot be empty"));
    }

    let (host, port_str) = match raw.strip_prefix('[') {
        Some(rest) => rest
            .split_once("]:")
            .ok_or_else(|| invalid_input("runner tcp_socket must be in [host]:port format"))?,
        None => raw
            .rsplit_once(':')
            .ok_or_else(|| invalid_input("runner tcp_socket must be in host:port format"))?,
    };
    // Checked for both forms so `[]:port` reports the same clear error as `:port`.
    if host.is_empty() {
        return Err(invalid_input("runner tcp_socket host cannot be empty"));
    }

    let port: u16 = port_str
        .parse()
        .map_err(|_| invalid_input(format!("Invalid runner tcp_socket port: {port_str}")))?;
    if port == 0 {
        return Err(invalid_input("runner tcp_socket port must be > 0"));
    }

    let ip = if host.eq_ignore_ascii_case("localhost") {
        IpAddr::V4(Ipv4Addr::LOCALHOST)
    } else {
        let parsed: IpAddr = host
            .parse()
            .map_err(|_| invalid_input(format!("Invalid runner tcp_socket host: {host}")))?;
        if !parsed.is_loopback() {
            return Err(invalid_input("runner tcp_socket host must be localhost"));
        }
        parsed
    };

    Ok(SocketAddr::new(ip, port))
}
69
70pub struct RunnerRuntime {
71 runtime: tokio::runtime::Runtime,
72}
73
74impl RunnerRuntime {
75 pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
76 Ok(Self {
77 runtime: tokio::runtime::Runtime::new()?,
78 })
79 }
80
81 pub fn enter(&self) -> tokio::runtime::EnterGuard<'_> {
82 self.runtime.enter()
83 }
84
85 pub fn run_tcp(
86 &self,
87 registry: &Registry,
88 addr: SocketAddr,
89 ) -> Result<(), Box<dyn std::error::Error>> {
90 let telemetry = NoopTelemetry;
91 self.run_tcp_with(registry, addr, &telemetry)
92 }
93
94 pub fn run_tcp_with<T: Telemetry + ?Sized>(
95 &self,
96 registry: &Registry,
97 addr: SocketAddr,
98 telemetry: &T,
99 ) -> Result<(), Box<dyn std::error::Error>> {
100 run_tcp_loop(&self.runtime, registry, addr, telemetry)
101 }
102}
103
104pub fn run_tcp(registry: &Registry, addr: SocketAddr) -> Result<(), Box<dyn std::error::Error>> {
105 RunnerRuntime::new()?.run_tcp(registry, addr)
106}
107
108pub fn run_tcp_with<T: Telemetry + ?Sized>(
109 registry: &Registry,
110 addr: SocketAddr,
111 telemetry: &T,
112) -> Result<(), Box<dyn std::error::Error>> {
113 RunnerRuntime::new()?.run_tcp_with(registry, addr, telemetry)
114}
115
116fn run_tcp_loop<T: Telemetry + ?Sized>(
117 runtime: &tokio::runtime::Runtime,
118 registry: &Registry,
119 addr: SocketAddr,
120 telemetry: &T,
121) -> Result<(), Box<dyn std::error::Error>> {
122 let registry = registry.clone();
123 let in_flight: Arc<Mutex<HashMap<String, InFlightTask>>> = Arc::new(Mutex::new(HashMap::new()));
124 let job_index: Arc<Mutex<HashMap<String, HashSet<String>>>> =
125 Arc::new(Mutex::new(HashMap::new()));
126 let telemetry = telemetry.clone_box();
127 runtime.block_on(async move {
128 if !addr.ip().is_loopback() {
129 return Err(invalid_input(format!(
130 "runner tcp_socket must be loopback-only (got {addr})"
131 )));
132 }
133 let listener = TcpListener::bind(addr).await?;
134 loop {
135 let (stream, _) = listener.accept().await?;
136 let registry = registry.clone();
137 let telemetry = telemetry.clone();
138 let in_flight = in_flight.clone();
139 let job_index = job_index.clone();
140 tokio::spawn(async move {
141 if let Err(err) =
142 handle_connection(stream, ®istry, telemetry.as_ref(), in_flight, job_index)
143 .await
144 {
145 tracing::error!("runner connection error: {err}");
146 }
147 });
148 }
149 })
150}
151
152async fn handle_connection<S, T>(
153 stream: S,
154 registry: &Registry,
155 telemetry: &T,
156 in_flight: Arc<Mutex<HashMap<String, InFlightTask>>>,
157 job_index: Arc<Mutex<HashMap<String, HashSet<String>>>>,
158) -> Result<(), Box<dyn std::error::Error>>
159where
160 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
161 T: Telemetry + ?Sized,
162{
163 let (mut reader, mut writer) = tokio::io::split(stream);
164 let (response_tx, mut response_rx) =
165 mpsc::channel::<ExecutionOutcome>(RESPONSE_CHANNEL_CAPACITY);
166 let writer_task = tokio::spawn(async move {
167 while let Some(outcome) = response_rx.recv().await {
168 let response = RunnerMessage::Response { payload: outcome };
169 if write_message(&mut writer, &response).await.is_err() {
170 break;
171 }
172 }
173 });
174 let connection_requests: Arc<Mutex<std::collections::HashSet<String>>> =
175 Arc::new(Mutex::new(std::collections::HashSet::new()));
176
177 loop {
178 let message = match read_message(&mut reader).await? {
179 Some(message) => message,
180 None => break,
181 };
182 match message {
183 RunnerMessage::Request { payload } => {
184 if payload.protocol_version != PROTOCOL_VERSION {
185 let outcome = ExecutionOutcome::error(
186 payload.job_id.clone(),
187 payload.request_id.clone(),
188 "Unsupported protocol version",
189 );
190 let _ = response_tx.send(outcome).await;
191 continue;
192 }
193
194 let request_id = payload.request_id.clone();
195 let job_id = payload.job_id.clone();
196 {
197 let mut active = connection_requests.lock().await;
198 if active.len() >= RESPONSE_CHANNEL_CAPACITY {
199 let outcome = ExecutionOutcome::error(
200 payload.job_id.clone(),
201 payload.request_id.clone(),
202 "Runner busy: too many in-flight requests",
203 );
204 drop(active);
205 let send_result =
206 timeout(RESPONSE_SEND_TIMEOUT, response_tx.send(outcome)).await;
207 match send_result {
208 Ok(Ok(())) => {}
209 Ok(Err(_)) => {
210 return Err("runner response channel closed".into());
211 }
212 Err(_) => {
213 return Err("runner response channel stalled".into());
214 }
215 }
216 continue;
217 }
218 active.insert(request_id.clone());
219 }
220 let response_tx = response_tx.clone();
221 let registry = registry.clone();
222 let telemetry = telemetry.clone_box();
223 let in_flight_for_task = in_flight.clone();
224 let job_index_for_task = job_index.clone();
225 let active_for_task = connection_requests.clone();
226 let request_id_for_task = request_id.clone();
227 let job_id_for_task = job_id.clone();
228 let response_tx_for_task = response_tx.clone();
229 let completed = Arc::new(AtomicBool::new(false));
230 let completed_for_task = completed.clone();
231
232 let handle = tokio::spawn(async move {
233 let outcome =
234 execute_with_deadline(payload, registry, telemetry.as_ref()).await;
235 completed_for_task.store(true, Ordering::SeqCst);
236 let send_result =
237 timeout(RESPONSE_SEND_TIMEOUT, response_tx_for_task.send(outcome)).await;
238 match send_result {
239 Ok(Ok(())) => {}
240 Ok(Err(_)) => {
241 tracing::warn!("runner response channel closed; dropping outcome");
242 }
243 Err(_) => {
244 tracing::warn!("runner response channel stalled; dropping outcome");
245 }
246 }
247 {
248 let mut in_flight = in_flight_for_task.lock().await;
249 in_flight.remove(&request_id_for_task);
250 }
251 {
252 let mut job_index = job_index_for_task.lock().await;
253 if let Some(entries) = job_index.get_mut(&job_id_for_task) {
254 entries.remove(&request_id_for_task);
255 if entries.is_empty() {
256 job_index.remove(&job_id_for_task);
257 }
258 }
259 }
260 {
261 let mut active = active_for_task.lock().await;
262 active.remove(&request_id_for_task);
263 }
264 });
265
266 {
267 let mut in_flight = in_flight.lock().await;
268 in_flight.insert(
269 request_id.clone(),
270 InFlightTask {
271 job_id: job_id.clone(),
272 handle,
273 response_tx: response_tx.clone(),
274 connection_requests: connection_requests.clone(),
275 completed,
276 },
277 );
278 }
279 {
280 let mut job_index = job_index.lock().await;
281 job_index
282 .entry(job_id)
283 .or_insert_with(HashSet::new)
284 .insert(request_id);
285 }
286 }
287 RunnerMessage::Cancel { payload } => {
288 handle_cancel(payload, &in_flight, &job_index).await;
289 }
290 RunnerMessage::Response { .. } => {
291 let outcome = ExecutionOutcome {
292 job_id: Some("unknown".to_string()),
293 request_id: None,
294 status: rrq_protocol::OutcomeStatus::Error,
295 result: None,
296 error: Some(ExecutionError {
297 message: "unexpected response message".to_string(),
298 error_type: None,
299 code: None,
300 details: None,
301 }),
302 retry_after_seconds: None,
303 };
304 let _ = response_tx.send(outcome).await;
305 }
306 }
307 }
308
309 let request_ids = {
310 let mut active = connection_requests.lock().await;
311 active.drain().collect::<Vec<_>>()
312 };
313 for request_id in request_ids {
314 let task = {
315 let mut in_flight = in_flight.lock().await;
316 in_flight.remove(&request_id)
317 };
318 if let Some(task) = task {
319 task.handle.abort();
320 let mut job_index = job_index.lock().await;
321 if let Some(entries) = job_index.get_mut(&task.job_id) {
322 entries.remove(&request_id);
323 if entries.is_empty() {
324 job_index.remove(&task.job_id);
325 }
326 }
327 }
328 }
329 writer_task.abort();
330
331 Ok(())
332}
333
334struct InFlightTask {
335 job_id: String,
336 handle: tokio::task::JoinHandle<()>,
337 response_tx: mpsc::Sender<ExecutionOutcome>,
338 connection_requests: Arc<Mutex<HashSet<String>>>,
339 completed: Arc<AtomicBool>,
340}
341
342async fn handle_cancel(
343 payload: CancelRequest,
344 in_flight: &Arc<Mutex<HashMap<String, InFlightTask>>>,
345 job_index: &Arc<Mutex<HashMap<String, HashSet<String>>>>,
346) {
347 if payload.protocol_version != PROTOCOL_VERSION {
348 return;
349 }
350 let request_ids = if let Some(request_id) = payload.request_id.clone() {
351 vec![request_id]
352 } else {
353 let job_index = job_index.lock().await;
354 job_index
355 .get(&payload.job_id)
356 .map(|ids| ids.iter().cloned().collect())
357 .unwrap_or_else(Vec::new)
358 };
359 if request_ids.is_empty() {
360 return;
361 }
362
363 for request_id in request_ids {
364 let task = {
365 let mut in_flight = in_flight.lock().await;
366 if let Some(task) = in_flight.get(&request_id)
367 && task.completed.load(Ordering::SeqCst)
368 {
369 None
370 } else {
371 in_flight.remove(&request_id)
372 }
373 };
374 if let Some(task) = task {
375 task.handle.abort();
376 {
377 let mut active = task.connection_requests.lock().await;
378 active.remove(&request_id);
379 }
380 let outcome = ExecutionOutcome {
381 job_id: Some(payload.job_id.clone()),
382 request_id: Some(request_id.clone()),
383 status: OutcomeStatus::Error,
384 result: None,
385 error: Some(ExecutionError {
386 message: "Job cancelled".to_string(),
387 error_type: Some("cancelled".to_string()),
388 code: None,
389 details: None,
390 }),
391 retry_after_seconds: None,
392 };
393 let send_result = timeout(RESPONSE_SEND_TIMEOUT, task.response_tx.send(outcome)).await;
394 match send_result {
395 Ok(Ok(())) => {}
396 Ok(Err(_)) => {
397 tracing::warn!("runner response channel closed; dropping cancel outcome");
398 }
399 Err(_) => {
400 tracing::warn!("runner response channel stalled; dropping cancel outcome");
401 }
402 }
403 let mut job_index = job_index.lock().await;
404 if let Some(entries) = job_index.get_mut(&task.job_id) {
405 entries.remove(&request_id);
406 if entries.is_empty() {
407 job_index.remove(&task.job_id);
408 }
409 }
410 }
411 }
412}
413
414async fn execute_with_deadline<T: Telemetry + ?Sized>(
415 request: rrq_protocol::ExecutionRequest,
416 registry: Registry,
417 telemetry: &T,
418) -> ExecutionOutcome {
419 let job_id = request.job_id.clone();
420 let request_id = request.request_id.clone();
421 let deadline = request.context.deadline;
422 if let Some(deadline) = deadline {
423 let now: DateTime<Utc> = Utc::now();
424 if deadline <= now {
425 return ExecutionOutcome::timeout(
426 job_id.clone(),
427 request_id.clone(),
428 "Job deadline exceeded",
429 );
430 }
431 if let Ok(remaining) = (deadline - now).to_std() {
432 match tokio::time::timeout(remaining, registry.execute_with(request, telemetry)).await {
433 Ok(outcome) => return outcome,
434 Err(_) => {
435 return ExecutionOutcome::timeout(
436 job_id.clone(),
437 request_id.clone(),
438 "Job execution timed out",
439 );
440 }
441 }
442 }
443 return ExecutionOutcome::timeout(job_id, request_id, "Job deadline exceeded");
444 }
445 registry.execute_with(request, telemetry).await
446}
447
448async fn read_message<R: AsyncRead + Unpin>(
449 stream: &mut R,
450) -> Result<Option<RunnerMessage>, Box<dyn std::error::Error>> {
451 let mut header = [0u8; 4];
452 match stream.read_exact(&mut header).await {
453 Ok(_) => {}
454 Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
455 Err(err) => return Err(err.into()),
456 }
457 let length = u32::from_be_bytes(header) as usize;
458 if length == 0 {
459 return Err("runner message payload cannot be empty".into());
460 }
461 if length > MAX_FRAME_LEN {
462 return Err("runner message payload too large".into());
463 }
464 let mut payload = vec![0u8; length];
465 stream.read_exact(&mut payload).await?;
466 let message = serde_json::from_slice(&payload)?;
467 Ok(Some(message))
468}
469
470async fn write_message<W: AsyncWrite + Unpin>(
471 stream: &mut W,
472 message: &RunnerMessage,
473) -> Result<(), Box<dyn std::error::Error>> {
474 let framed = encode_frame(message)?;
475 stream.write_all(&framed).await?;
476 stream.flush().await?;
477 Ok(())
478}
479
480#[cfg(test)]
481mod tests {
482 use super::*;
483 use crate::registry::Registry;
484 use crate::telemetry::NoopTelemetry;
485 use chrono::Utc;
486 use rrq_protocol::{ExecutionContext, ExecutionRequest, OutcomeStatus};
487 use serde_json::json;
488 use tokio::net::{TcpListener, TcpStream};
489 use tokio::time::{Duration, timeout};
490
491 fn build_request(function_name: &str) -> ExecutionRequest {
492 ExecutionRequest {
493 protocol_version: PROTOCOL_VERSION.to_string(),
494 request_id: "req-1".to_string(),
495 job_id: "job-1".to_string(),
496 function_name: function_name.to_string(),
497 params: std::collections::HashMap::new(),
498 context: ExecutionContext {
499 job_id: "job-1".to_string(),
500 attempt: 1,
501 enqueue_time: "2024-01-01T00:00:00Z".parse().unwrap(),
502 queue_name: "default".to_string(),
503 deadline: None,
504 trace_context: None,
505 worker_id: None,
506 },
507 }
508 }
509
510 async fn tcp_pair() -> (TcpStream, TcpStream) {
511 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
512 let addr = listener.local_addr().unwrap();
513 let client = TcpStream::connect(addr).await.unwrap();
514 let (server, _) = listener.accept().await.unwrap();
515 (client, server)
516 }
517
518 #[tokio::test]
519 async fn handle_connection_executes_request() {
520 let mut registry = Registry::new();
521 registry.register("echo", |request| async move {
522 ExecutionOutcome::success(
523 request.job_id.clone(),
524 request.request_id.clone(),
525 json!({"ok": true}),
526 )
527 });
528 let (client, server) = tcp_pair().await;
529 let in_flight = Arc::new(Mutex::new(HashMap::new()));
530 let job_index = Arc::new(Mutex::new(HashMap::new()));
531 let server_task = tokio::spawn(async move {
532 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
533 .await
534 .unwrap();
535 });
536 let mut client = client;
537 let request = build_request("echo");
538 let message = RunnerMessage::Request { payload: request };
539 write_message(&mut client, &message).await.unwrap();
540 let response = read_message(&mut client).await.unwrap().unwrap();
541 match response {
542 RunnerMessage::Response { payload } => {
543 assert_eq!(payload.status, OutcomeStatus::Success);
544 assert_eq!(payload.result, Some(json!({"ok": true})));
545 }
546 _ => panic!("expected response"),
547 }
548 drop(client);
549 let _ = server_task.await;
550 }
551
552 #[tokio::test]
553 async fn handle_connection_rejects_bad_protocol() {
554 let registry = Registry::new();
555 let (client, server) = tcp_pair().await;
556 let in_flight = Arc::new(Mutex::new(HashMap::new()));
557 let job_index = Arc::new(Mutex::new(HashMap::new()));
558 let server_task = tokio::spawn(async move {
559 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
560 .await
561 .unwrap();
562 });
563 let mut client = client;
564 let mut request = build_request("echo");
565 request.protocol_version = "0".to_string();
566 let message = RunnerMessage::Request { payload: request };
567 write_message(&mut client, &message).await.unwrap();
568 let response = read_message(&mut client).await.unwrap().unwrap();
569 match response {
570 RunnerMessage::Response { payload } => {
571 assert_eq!(payload.status, OutcomeStatus::Error);
572 }
573 _ => panic!("expected response"),
574 }
575 drop(client);
576 let _ = server_task.await;
577 }
578
579 #[tokio::test]
580 async fn handle_connection_cancels_inflight() {
581 let mut registry = Registry::new();
582 registry.register("sleep", |request| async move {
583 tokio::time::sleep(Duration::from_millis(200)).await;
584 ExecutionOutcome::success(
585 request.job_id.clone(),
586 request.request_id.clone(),
587 json!({"ok": true}),
588 )
589 });
590 let (client, server) = tcp_pair().await;
591 let in_flight = Arc::new(Mutex::new(HashMap::new()));
592 let job_index = Arc::new(Mutex::new(HashMap::new()));
593 let server_task = tokio::spawn(async move {
594 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
595 .await
596 .unwrap();
597 });
598 let mut client = client;
599 let request = ExecutionRequest {
600 protocol_version: PROTOCOL_VERSION.to_string(),
601 request_id: "req-cancel".to_string(),
602 job_id: "job-cancel".to_string(),
603 function_name: "sleep".to_string(),
604 params: std::collections::HashMap::new(),
605 context: ExecutionContext {
606 job_id: "job-cancel".to_string(),
607 attempt: 1,
608 enqueue_time: "2024-01-01T00:00:00Z".parse().unwrap(),
609 queue_name: "default".to_string(),
610 deadline: None,
611 trace_context: None,
612 worker_id: None,
613 },
614 };
615 let message = RunnerMessage::Request {
616 payload: request.clone(),
617 };
618 write_message(&mut client, &message).await.unwrap();
619 let cancel = RunnerMessage::Cancel {
620 payload: CancelRequest {
621 protocol_version: PROTOCOL_VERSION.to_string(),
622 job_id: request.job_id.clone(),
623 request_id: Some(request.request_id.clone()),
624 hard_kill: false,
625 },
626 };
627 write_message(&mut client, &cancel).await.unwrap();
628 let response = read_message(&mut client).await.unwrap().unwrap();
629 match response {
630 RunnerMessage::Response { payload } => {
631 assert_eq!(payload.status, OutcomeStatus::Error);
632 let error_type = payload
633 .error
634 .as_ref()
635 .and_then(|error| error.error_type.as_deref());
636 assert_eq!(error_type, Some("cancelled"));
637 }
638 _ => panic!("expected response"),
639 }
640 drop(client);
641 let _ = server_task.await;
642 }
643
644 #[tokio::test]
645 async fn cancel_frees_connection_capacity() {
646 let mut registry = Registry::new();
647 let gate = Arc::new(tokio::sync::Semaphore::new(0));
648 let gate_for_handler = gate.clone();
649 registry.register("block", move |request| {
650 let gate = gate_for_handler.clone();
651 async move {
652 let _permit = gate.acquire().await.expect("semaphore closed");
653 ExecutionOutcome::success(
654 request.job_id.clone(),
655 request.request_id.clone(),
656 json!({"ok": true}),
657 )
658 }
659 });
660 let (client, server) = tcp_pair().await;
661 let in_flight = Arc::new(Mutex::new(HashMap::new()));
662 let job_index = Arc::new(Mutex::new(HashMap::new()));
663 let server_task = tokio::spawn(async move {
664 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
665 .await
666 .unwrap();
667 });
668 let mut client = client;
669 let job_id = "job-capacity".to_string();
670 for i in 0..RESPONSE_CHANNEL_CAPACITY {
671 let mut request = build_request("block");
672 request.request_id = format!("req-{i}");
673 request.job_id = job_id.clone();
674 request.context.job_id = job_id.clone();
675 write_message(&mut client, &RunnerMessage::Request { payload: request })
676 .await
677 .unwrap();
678 }
679
680 let cancel = RunnerMessage::Cancel {
681 payload: CancelRequest {
682 protocol_version: PROTOCOL_VERSION.to_string(),
683 job_id: job_id.clone(),
684 request_id: Some("req-0".to_string()),
685 hard_kill: false,
686 },
687 };
688 write_message(&mut client, &cancel).await.unwrap();
689 let response = timeout(Duration::from_secs(1), read_message(&mut client))
690 .await
691 .unwrap()
692 .unwrap()
693 .unwrap();
694 match response {
695 RunnerMessage::Response { payload } => {
696 assert_eq!(payload.status, OutcomeStatus::Error);
697 let error_type = payload
698 .error
699 .as_ref()
700 .and_then(|error| error.error_type.as_deref());
701 assert_eq!(error_type, Some("cancelled"));
702 }
703 _ => panic!("expected response"),
704 }
705
706 let mut extra_request = build_request("block");
707 extra_request.request_id = "req-extra".to_string();
708 extra_request.job_id = job_id.clone();
709 extra_request.context.job_id = job_id.clone();
710 write_message(
711 &mut client,
712 &RunnerMessage::Request {
713 payload: extra_request,
714 },
715 )
716 .await
717 .unwrap();
718
719 gate.add_permits(RESPONSE_CHANNEL_CAPACITY + 1);
720
721 let mut saw_extra = false;
722 for _ in 0..RESPONSE_CHANNEL_CAPACITY {
723 let response = timeout(Duration::from_secs(1), read_message(&mut client))
724 .await
725 .unwrap()
726 .unwrap()
727 .unwrap();
728 if let RunnerMessage::Response { payload } = response
729 && payload.request_id.as_deref() == Some("req-extra")
730 {
731 assert_eq!(payload.status, OutcomeStatus::Success);
732 saw_extra = true;
733 }
734 }
735 assert!(saw_extra, "extra request never completed");
736
737 drop(client);
738 let _ = server_task.await;
739 }
740
741 #[tokio::test]
742 async fn execute_with_deadline_times_out() {
743 let mut registry = Registry::new();
744 registry.register("echo", |request| async move {
745 ExecutionOutcome::success(
746 request.job_id.clone(),
747 request.request_id.clone(),
748 json!({"ok": true}),
749 )
750 });
751 let mut request = build_request("echo");
752 request.context.deadline = Some(
753 "2020-01-01T00:00:00Z"
754 .parse::<chrono::DateTime<Utc>>()
755 .unwrap(),
756 );
757 let outcome = execute_with_deadline(request, registry, &NoopTelemetry).await;
758 assert_eq!(outcome.status, OutcomeStatus::Timeout);
759 }
760
761 #[tokio::test]
762 async fn execute_with_deadline_succeeds_before_deadline() {
763 let mut registry = Registry::new();
764 registry.register("echo", |request| async move {
765 ExecutionOutcome::success(
766 request.job_id.clone(),
767 request.request_id.clone(),
768 json!({"ok": true}),
769 )
770 });
771 let mut request = build_request("echo");
772 request.context.deadline = Some(Utc::now() + chrono::Duration::seconds(5));
773 let outcome = execute_with_deadline(request, registry, &NoopTelemetry).await;
774 assert_eq!(outcome.status, OutcomeStatus::Success);
775 }
776
777 #[tokio::test]
778 async fn handle_connection_handles_unexpected_response_message() {
779 let registry = Registry::new();
780 let (client, server) = tcp_pair().await;
781 let in_flight = Arc::new(Mutex::new(HashMap::new()));
782 let job_index = Arc::new(Mutex::new(HashMap::new()));
783 let server_task = tokio::spawn(async move {
784 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
785 .await
786 .unwrap();
787 });
788 let mut client = client;
789 let response = RunnerMessage::Response {
790 payload: ExecutionOutcome::error("job-x", "req-x", "oops"),
791 };
792 write_message(&mut client, &response).await.unwrap();
793 let reply = read_message(&mut client).await.unwrap().unwrap();
794 match reply {
795 RunnerMessage::Response { payload } => {
796 assert_eq!(payload.status, OutcomeStatus::Error);
797 assert!(
798 payload
799 .error
800 .as_ref()
801 .unwrap()
802 .message
803 .contains("unexpected response")
804 );
805 }
806 _ => panic!("expected response"),
807 }
808 drop(client);
809 let _ = server_task.await;
810 }
811
812 #[tokio::test]
813 async fn handle_connection_cancels_by_job_id() {
814 let mut registry = Registry::new();
815 registry.register("sleep", |request| async move {
816 tokio::time::sleep(Duration::from_millis(200)).await;
817 ExecutionOutcome::success(
818 request.job_id.clone(),
819 request.request_id.clone(),
820 json!({"ok": true}),
821 )
822 });
823 let (client, server) = tcp_pair().await;
824 let in_flight = Arc::new(Mutex::new(HashMap::new()));
825 let job_index = Arc::new(Mutex::new(HashMap::new()));
826 let server_task = tokio::spawn(async move {
827 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
828 .await
829 .unwrap();
830 });
831 let mut client = client;
832 let request = build_request("sleep");
833 let message = RunnerMessage::Request {
834 payload: request.clone(),
835 };
836 write_message(&mut client, &message).await.unwrap();
837 let cancel = RunnerMessage::Cancel {
838 payload: CancelRequest {
839 protocol_version: PROTOCOL_VERSION.to_string(),
840 job_id: request.job_id.clone(),
841 request_id: None,
842 hard_kill: false,
843 },
844 };
845 write_message(&mut client, &cancel).await.unwrap();
846 let response = read_message(&mut client).await.unwrap().unwrap();
847 match response {
848 RunnerMessage::Response { payload } => {
849 assert_eq!(payload.status, OutcomeStatus::Error);
850 let error_type = payload
851 .error
852 .as_ref()
853 .and_then(|error| error.error_type.as_deref());
854 assert_eq!(error_type, Some("cancelled"));
855 }
856 _ => panic!("expected response"),
857 }
858 drop(client);
859 let _ = server_task.await;
860 }
861
862 #[tokio::test]
863 async fn handle_cancel_by_job_id_cancels_all_requests() {
864 let mut registry = Registry::new();
865 registry.register("sleep", |request| async move {
866 tokio::time::sleep(Duration::from_millis(200)).await;
867 ExecutionOutcome::success(
868 request.job_id.clone(),
869 request.request_id.clone(),
870 json!({"ok": true}),
871 )
872 });
873 let (client, server) = tcp_pair().await;
874 let in_flight = Arc::new(Mutex::new(HashMap::new()));
875 let job_index = Arc::new(Mutex::new(HashMap::new()));
876 let server_task = tokio::spawn(async move {
877 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
878 .await
879 .unwrap();
880 });
881 let mut client = client;
882 let mut request1 = build_request("sleep");
883 request1.request_id = "req-1".to_string();
884 request1.job_id = "job-shared".to_string();
885 let mut request2 = build_request("sleep");
886 request2.request_id = "req-2".to_string();
887 request2.job_id = "job-shared".to_string();
888 write_message(&mut client, &RunnerMessage::Request { payload: request1 })
889 .await
890 .unwrap();
891 write_message(&mut client, &RunnerMessage::Request { payload: request2 })
892 .await
893 .unwrap();
894 let cancel = RunnerMessage::Cancel {
895 payload: CancelRequest {
896 protocol_version: PROTOCOL_VERSION.to_string(),
897 job_id: "job-shared".to_string(),
898 request_id: None,
899 hard_kill: false,
900 },
901 };
902 write_message(&mut client, &cancel).await.unwrap();
903
904 let mut cancelled = 0;
905 for _ in 0..2 {
906 let response = timeout(Duration::from_millis(200), read_message(&mut client))
907 .await
908 .unwrap()
909 .unwrap()
910 .unwrap();
911 match response {
912 RunnerMessage::Response { payload } => {
913 assert_eq!(payload.status, OutcomeStatus::Error);
914 let error_type = payload
915 .error
916 .as_ref()
917 .and_then(|error| error.error_type.as_deref());
918 assert_eq!(error_type, Some("cancelled"));
919 cancelled += 1;
920 }
921 _ => panic!("expected response"),
922 }
923 }
924 assert_eq!(cancelled, 2);
925 drop(client);
926 let _ = server_task.await;
927 }
928
929 #[tokio::test]
930 async fn connection_teardown_clears_tracking_maps() {
931 let mut registry = Registry::new();
932 registry.register("sleep", |request| async move {
933 tokio::time::sleep(Duration::from_millis(200)).await;
934 ExecutionOutcome::success(
935 request.job_id.clone(),
936 request.request_id.clone(),
937 json!({"ok": true}),
938 )
939 });
940 let (client, server) = tcp_pair().await;
941 let in_flight = Arc::new(Mutex::new(HashMap::new()));
942 let job_index = Arc::new(Mutex::new(HashMap::new()));
943 let in_flight_for_server = in_flight.clone();
944 let job_index_for_server = job_index.clone();
945 let server_task = tokio::spawn(async move {
946 handle_connection(
947 server,
948 ®istry,
949 &NoopTelemetry,
950 in_flight_for_server,
951 job_index_for_server,
952 )
953 .await
954 .unwrap();
955 });
956 let mut client = client;
957 let request = build_request("sleep");
958 let message = RunnerMessage::Request {
959 payload: request.clone(),
960 };
961 write_message(&mut client, &message).await.unwrap();
962
963 let mut inserted = false;
964 for _ in 0..20 {
965 let has_in_flight = {
966 let guard = in_flight.lock().await;
967 guard.contains_key(&request.request_id)
968 };
969 let has_job_index = {
970 let guard = job_index.lock().await;
971 guard.contains_key(&request.job_id)
972 };
973 if has_in_flight && has_job_index {
974 inserted = true;
975 break;
976 }
977 tokio::time::sleep(Duration::from_millis(10)).await;
978 }
979 assert!(inserted, "request never entered tracking maps");
980
981 drop(client);
982 let _ = server_task.await;
983
984 let in_flight = in_flight.lock().await;
985 let job_index = job_index.lock().await;
986 assert!(in_flight.is_empty());
987 assert!(job_index.is_empty());
988 }
989
990 #[tokio::test]
991 async fn handle_cancel_ignores_invalid_protocol() {
992 let in_flight = Arc::new(Mutex::new(HashMap::new()));
993 let job_index = Arc::new(Mutex::new(HashMap::new()));
994 let (tx, _rx) = mpsc::channel(1);
995 let handle = tokio::spawn(async {});
996 let connection_requests = Arc::new(Mutex::new(HashSet::new()));
997 {
998 let mut guard = in_flight.lock().await;
999 guard.insert(
1000 "req-1".to_string(),
1001 InFlightTask {
1002 job_id: "job-1".to_string(),
1003 handle,
1004 response_tx: tx,
1005 connection_requests,
1006 completed: Arc::new(AtomicBool::new(false)),
1007 },
1008 );
1009 }
1010 let payload = CancelRequest {
1011 protocol_version: "0".to_string(),
1012 job_id: "job-1".to_string(),
1013 request_id: None,
1014 hard_kill: false,
1015 };
1016 handle_cancel(payload, &in_flight, &job_index).await;
1017 let guard = in_flight.lock().await;
1018 assert!(guard.contains_key("req-1"));
1019 guard.get("req-1").unwrap().handle.abort();
1020 }
1021
1022 #[tokio::test]
1023 async fn handle_cancel_skips_completed_requests() {
1024 let in_flight = Arc::new(Mutex::new(HashMap::new()));
1025 let job_index = Arc::new(Mutex::new(HashMap::new()));
1026 let (tx, mut rx) = mpsc::channel(1);
1027 let handle = tokio::spawn(async {
1028 tokio::time::sleep(Duration::from_millis(50)).await;
1029 });
1030 let connection_requests = Arc::new(Mutex::new(HashSet::new()));
1031 {
1032 let mut guard = in_flight.lock().await;
1033 guard.insert(
1034 "req-1".to_string(),
1035 InFlightTask {
1036 job_id: "job-1".to_string(),
1037 handle,
1038 response_tx: tx,
1039 connection_requests,
1040 completed: Arc::new(AtomicBool::new(true)),
1041 },
1042 );
1043 }
1044 {
1045 let mut guard = job_index.lock().await;
1046 guard.insert("job-1".to_string(), HashSet::from(["req-1".to_string()]));
1047 }
1048 let payload = CancelRequest {
1049 protocol_version: PROTOCOL_VERSION.to_string(),
1050 job_id: "job-1".to_string(),
1051 request_id: Some("req-1".to_string()),
1052 hard_kill: false,
1053 };
1054 handle_cancel(payload, &in_flight, &job_index).await;
1055 assert!(in_flight.lock().await.contains_key("req-1"));
1056 assert!(job_index.lock().await.contains_key("job-1"));
1057 assert!(rx.try_recv().is_err());
1058 let task = in_flight.lock().await.remove("req-1").unwrap();
1059 task.handle.abort();
1060 }
1061
1062 #[tokio::test]
1063 async fn read_message_handles_empty_and_invalid_payloads() {
1064 let (mut client, mut server) = tokio::io::duplex(64);
1065 client.write_all(&0u32.to_be_bytes()).await.unwrap();
1067 let err = read_message(&mut server).await.unwrap_err();
1068 assert!(err.to_string().contains("payload cannot be empty"));
1069
1070 let (mut client, mut server) = tokio::io::duplex(64);
1072 let payload = b"not-json";
1073 let len = (payload.len() as u32).to_be_bytes();
1074 client.write_all(&len).await.unwrap();
1075 client.write_all(payload).await.unwrap();
1076 let err = read_message(&mut server).await.unwrap_err();
1077 assert!(err.to_string().contains("expected"));
1078
1079 let (mut client, mut server) = tokio::io::duplex(64);
1081 let len = ((MAX_FRAME_LEN + 1) as u32).to_be_bytes();
1082 client.write_all(&len).await.unwrap();
1083 let err = read_message(&mut server).await.unwrap_err();
1084 assert!(err.to_string().contains("payload too large"));
1085 }
1086
1087 #[tokio::test]
1088 async fn read_message_returns_none_on_eof() {
1089 let (client, mut server) = tokio::io::duplex(8);
1090 drop(client);
1091 let message = read_message(&mut server).await.unwrap();
1092 assert!(message.is_none());
1093 }
1094}