1use crate::parent_guard;
2use crate::registry::Registry;
3use crate::telemetry::{NoopTelemetry, Telemetry};
4use crate::types::{ExecutionError, ExecutionOutcome};
5use chrono::{DateTime, Utc};
6use rrq_protocol::{CancelRequest, OutcomeStatus, PROTOCOL_VERSION, RunnerMessage, encode_frame};
7use std::collections::{HashMap, HashSet};
8use std::net::{IpAddr, Ipv4Addr, SocketAddr};
9use std::sync::{
10 Arc,
11 atomic::{AtomicBool, Ordering},
12};
13use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
14use tokio::net::TcpListener;
15use tokio::sync::{Mutex, mpsc};
16use tokio::time::{Duration, timeout};
/// Upper bound on a single framed message payload (16 MiB); larger frames are rejected.
const MAX_FRAME_LEN: usize = 16 * 1024 * 1024;
/// Capacity of the per-connection outbound response channel; also used as the cap
/// on concurrently in-flight requests per connection.
const RESPONSE_CHANNEL_CAPACITY: usize = 64;
/// How long an execution task waits for space in the response channel before
/// dropping an outcome.
const RESPONSE_SEND_TIMEOUT: Duration = Duration::from_secs(1);
20
/// Builds a boxed `InvalidInput` I/O error carrying `message`.
fn invalid_input(message: impl Into<String>) -> Box<dyn std::error::Error> {
    let err = std::io::Error::new(std::io::ErrorKind::InvalidInput, message.into());
    Box::new(err)
}
27
/// Parses a `host:port` (or bracketed `[v6]:port`) string into a loopback-only
/// `SocketAddr`.
///
/// `localhost` is accepted as an alias for `127.0.0.1`; any other host must parse
/// as a loopback IP address. Port `0` is rejected.
///
/// # Errors
/// Returns an `InvalidInput` error for empty input, malformed host/port syntax,
/// an unparsable or non-loopback host, or a zero port.
pub fn parse_tcp_socket(raw: &str) -> Result<SocketAddr, Box<dyn std::error::Error>> {
    // Local error helper so this function is self-contained.
    fn bad(message: impl Into<String>) -> Box<dyn std::error::Error> {
        Box::new(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            message.into(),
        ))
    }

    let raw = raw.trim();
    if raw.is_empty() {
        return Err(bad("runner tcp_socket cannot be empty"));
    }

    let (host, port_str) = if let Some(rest) = raw.strip_prefix('[') {
        let (host, port_str) = rest
            .split_once("]:")
            .ok_or_else(|| bad("runner tcp_socket must be in [host]:port format"))?;
        // Fix: validate the bracketed host like the bare form does, instead of
        // falling through to a less specific IP-parse error on `[]:port`.
        if host.is_empty() {
            return Err(bad("runner tcp_socket host cannot be empty"));
        }
        (host, port_str)
    } else {
        let (host, port_str) = raw
            .rsplit_once(':')
            .ok_or_else(|| bad("runner tcp_socket must be in host:port format"))?;
        if host.is_empty() {
            return Err(bad("runner tcp_socket host cannot be empty"));
        }
        (host, port_str)
    };

    let port: u16 = port_str
        .parse()
        .map_err(|_| bad(format!("Invalid runner tcp_socket port: {port_str}")))?;
    if port == 0 {
        return Err(bad("runner tcp_socket port must be > 0"));
    }

    let ip = if host == "localhost" {
        IpAddr::V4(Ipv4Addr::LOCALHOST)
    } else {
        let parsed: IpAddr = host
            .parse()
            .map_err(|_| bad(format!("Invalid runner tcp_socket host: {host}")))?;
        // Security posture: the runner only ever listens on loopback.
        if !parsed.is_loopback() {
            return Err(bad("runner tcp_socket host must be localhost"));
        }
        parsed
    };

    Ok(SocketAddr::new(ip, port))
}
70
/// Owns the Tokio runtime on which all runner TCP connections are served.
pub struct RunnerRuntime {
    // Runtime created in `new`; `run_tcp_loop` blocks on it.
    runtime: tokio::runtime::Runtime,
}
74
75impl RunnerRuntime {
76 pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
77 parent_guard::install_parent_lifecycle_guard()?;
78 Ok(Self {
79 runtime: tokio::runtime::Runtime::new()?,
80 })
81 }
82
83 pub fn enter(&self) -> tokio::runtime::EnterGuard<'_> {
84 self.runtime.enter()
85 }
86
87 pub fn run_tcp(
88 &self,
89 registry: &Registry,
90 addr: SocketAddr,
91 ) -> Result<(), Box<dyn std::error::Error>> {
92 let telemetry = NoopTelemetry;
93 self.run_tcp_with(registry, addr, &telemetry)
94 }
95
96 pub fn run_tcp_with<T: Telemetry + ?Sized>(
97 &self,
98 registry: &Registry,
99 addr: SocketAddr,
100 telemetry: &T,
101 ) -> Result<(), Box<dyn std::error::Error>> {
102 run_tcp_loop(&self.runtime, registry, addr, telemetry)
103 }
104}
105
106pub fn run_tcp(registry: &Registry, addr: SocketAddr) -> Result<(), Box<dyn std::error::Error>> {
107 RunnerRuntime::new()?.run_tcp(registry, addr)
108}
109
110pub fn run_tcp_with<T: Telemetry + ?Sized>(
111 registry: &Registry,
112 addr: SocketAddr,
113 telemetry: &T,
114) -> Result<(), Box<dyn std::error::Error>> {
115 RunnerRuntime::new()?.run_tcp_with(registry, addr, telemetry)
116}
117
118fn run_tcp_loop<T: Telemetry + ?Sized>(
119 runtime: &tokio::runtime::Runtime,
120 registry: &Registry,
121 addr: SocketAddr,
122 telemetry: &T,
123) -> Result<(), Box<dyn std::error::Error>> {
124 let registry = registry.clone();
125 let in_flight: Arc<Mutex<HashMap<String, InFlightTask>>> = Arc::new(Mutex::new(HashMap::new()));
126 let job_index: Arc<Mutex<HashMap<String, HashSet<String>>>> =
127 Arc::new(Mutex::new(HashMap::new()));
128 let telemetry = telemetry.clone_box();
129 runtime.block_on(async move {
130 if !addr.ip().is_loopback() {
131 return Err(invalid_input(format!(
132 "runner tcp_socket must be loopback-only (got {addr})"
133 )));
134 }
135 let listener = TcpListener::bind(addr).await?;
136 loop {
137 let (stream, _) = listener.accept().await?;
138 let registry = registry.clone();
139 let telemetry = telemetry.clone();
140 let in_flight = in_flight.clone();
141 let job_index = job_index.clone();
142 tokio::spawn(async move {
143 if let Err(err) =
144 handle_connection(stream, ®istry, telemetry.as_ref(), in_flight, job_index)
145 .await
146 {
147 tracing::error!("runner connection error: {err}");
148 }
149 });
150 }
151 })
152}
153
/// Serves one runner connection end-to-end: reads framed messages, spawns one
/// task per execution request, and funnels every outcome through a dedicated
/// writer task so socket writes are serialized.
///
/// Per-connection concurrency is capped at `RESPONSE_CHANNEL_CAPACITY`; excess
/// requests receive a "Runner busy" error. On disconnect (clean EOF), all tasks
/// still tracked for this connection are aborted and the shared maps cleaned up.
async fn handle_connection<S, T>(
    stream: S,
    registry: &Registry,
    telemetry: &T,
    in_flight: Arc<Mutex<HashMap<String, InFlightTask>>>,
    job_index: Arc<Mutex<HashMap<String, HashSet<String>>>>,
) -> Result<(), Box<dyn std::error::Error>>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    T: Telemetry + ?Sized,
{
    let (mut reader, mut writer) = tokio::io::split(stream);
    // Bounded channel: the writer task is the only place responses hit the socket.
    let (response_tx, mut response_rx) =
        mpsc::channel::<ExecutionOutcome>(RESPONSE_CHANNEL_CAPACITY);
    let writer_task = tokio::spawn(async move {
        while let Some(outcome) = response_rx.recv().await {
            let response = RunnerMessage::Response { payload: outcome };
            if write_message(&mut writer, &response).await.is_err() {
                break;
            }
        }
    });
    // Request ids currently executing on behalf of THIS connection (subset of
    // `in_flight`); drained on teardown to abort leftover tasks.
    let connection_requests: Arc<Mutex<std::collections::HashSet<String>>> =
        Arc::new(Mutex::new(std::collections::HashSet::new()));

    loop {
        // `None` means clean EOF: fall through to teardown below.
        let message = match read_message(&mut reader).await? {
            Some(message) => message,
            None => break,
        };
        match message {
            RunnerMessage::Request { payload } => {
                // Reject mismatched protocol versions with an error outcome.
                if payload.protocol_version != PROTOCOL_VERSION {
                    let outcome = ExecutionOutcome::error(
                        payload.job_id.clone(),
                        payload.request_id.clone(),
                        "Unsupported protocol version",
                    );
                    let _ = response_tx.send(outcome).await;
                    continue;
                }

                let request_id = payload.request_id.clone();
                let job_id = payload.job_id.clone();
                {
                    // Backpressure: cap concurrent requests per connection.
                    let mut active = connection_requests.lock().await;
                    if active.len() >= RESPONSE_CHANNEL_CAPACITY {
                        let outcome = ExecutionOutcome::error(
                            payload.job_id.clone(),
                            payload.request_id.clone(),
                            "Runner busy: too many in-flight requests",
                        );
                        // Release the lock before the (potentially slow) send.
                        drop(active);
                        let send_result =
                            timeout(RESPONSE_SEND_TIMEOUT, response_tx.send(outcome)).await;
                        match send_result {
                            Ok(Ok(())) => {}
                            Ok(Err(_)) => {
                                return Err("runner response channel closed".into());
                            }
                            Err(_) => {
                                return Err("runner response channel stalled".into());
                            }
                        }
                        continue;
                    }
                    active.insert(request_id.clone());
                    crate::telemetry::record_runner_channel_pressure(active.len());
                }
                let response_tx = response_tx.clone();
                let registry = registry.clone();
                let telemetry = telemetry.clone_box();
                let in_flight_for_task = in_flight.clone();
                let job_index_for_task = job_index.clone();
                let active_for_task = connection_requests.clone();
                let request_id_for_task = request_id.clone();
                let job_id_for_task = job_id.clone();
                let response_tx_for_task = response_tx.clone();
                // Set by the task once the handler returns; lets cancels skip
                // already-finished work.
                let completed = Arc::new(AtomicBool::new(false));
                let completed_for_task = completed.clone();

                let handle = tokio::spawn(async move {
                    let outcome =
                        execute_with_deadline(payload, registry, telemetry.as_ref()).await;
                    // Mark completion BEFORE sending, so a racing cancel won't
                    // abort a task whose outcome is already determined.
                    completed_for_task.store(true, Ordering::SeqCst);
                    let send_result =
                        timeout(RESPONSE_SEND_TIMEOUT, response_tx_for_task.send(outcome)).await;
                    match send_result {
                        Ok(Ok(())) => {}
                        Ok(Err(_)) => {
                            tracing::warn!("runner response channel closed; dropping outcome");
                        }
                        Err(_) => {
                            tracing::warn!("runner response channel stalled; dropping outcome");
                        }
                    }
                    // Remove ourselves from all three tracking structures; each in
                    // its own scope so no two locks are held at once.
                    {
                        let mut in_flight = in_flight_for_task.lock().await;
                        if in_flight.remove(&request_id_for_task).is_some() {
                            crate::telemetry::record_runner_inflight_delta(-1);
                        }
                    }
                    {
                        let mut job_index = job_index_for_task.lock().await;
                        if let Some(entries) = job_index.get_mut(&job_id_for_task) {
                            entries.remove(&request_id_for_task);
                            if entries.is_empty() {
                                job_index.remove(&job_id_for_task);
                            }
                        }
                    }
                    {
                        let mut active = active_for_task.lock().await;
                        active.remove(&request_id_for_task);
                        crate::telemetry::record_runner_channel_pressure(active.len());
                    }
                });

                // Register the spawned task so cancels and teardown can find it.
                {
                    let mut in_flight = in_flight.lock().await;
                    in_flight.insert(
                        request_id.clone(),
                        InFlightTask {
                            job_id: job_id.clone(),
                            handle,
                            response_tx: response_tx.clone(),
                            connection_requests: connection_requests.clone(),
                            completed,
                        },
                    );
                }
                crate::telemetry::record_runner_inflight_delta(1);
                {
                    let mut job_index = job_index.lock().await;
                    job_index
                        .entry(job_id)
                        .or_insert_with(HashSet::new)
                        .insert(request_id);
                }
            }
            RunnerMessage::Cancel { payload } => {
                handle_cancel(payload, &in_flight, &job_index).await;
            }
            RunnerMessage::Response { .. } => {
                // Clients should never send responses; reply with an error outcome.
                let outcome = ExecutionOutcome {
                    job_id: Some("unknown".to_string()),
                    request_id: None,
                    status: rrq_protocol::OutcomeStatus::Error,
                    result: None,
                    error: Some(ExecutionError {
                        message: "unexpected response message".to_string(),
                        error_type: None,
                        code: None,
                        details: None,
                    }),
                    retry_after_seconds: None,
                };
                let _ = response_tx.send(outcome).await;
            }
        }
    }

    // Connection closed: abort every task still tracked for this connection and
    // scrub it from the shared maps.
    let request_ids = {
        let mut active = connection_requests.lock().await;
        active.drain().collect::<Vec<_>>()
    };
    crate::telemetry::record_runner_channel_pressure(0);
    for request_id in request_ids {
        let task = {
            let mut in_flight = in_flight.lock().await;
            in_flight.remove(&request_id)
        };
        if let Some(task) = task {
            task.handle.abort();
            crate::telemetry::record_runner_inflight_delta(-1);
            let mut job_index = job_index.lock().await;
            if let Some(entries) = job_index.get_mut(&task.job_id) {
                entries.remove(&request_id);
                if entries.is_empty() {
                    job_index.remove(&task.job_id);
                }
            }
        }
    }
    writer_task.abort();

    Ok(())
}
342
/// Bookkeeping for one spawned execution task, kept until the task completes,
/// is cancelled, or its connection tears down.
struct InFlightTask {
    // Job this request belongs to; used to clean up `job_index` on removal.
    job_id: String,
    // Handle to the spawned execution task; aborted on cancel/teardown.
    handle: tokio::task::JoinHandle<()>,
    // Sender into the owning connection's writer task, used for cancel outcomes.
    response_tx: mpsc::Sender<ExecutionOutcome>,
    // Per-connection active-request set; freed here when the task is cancelled.
    connection_requests: Arc<Mutex<HashSet<String>>>,
    // Set by the task once its handler finished; cancels skip completed tasks.
    completed: Arc<AtomicBool>,
}
350
/// Cancels one request (when `payload.request_id` is set) or every in-flight
/// request for `payload.job_id`.
///
/// Already-completed tasks are left alone so their real outcome is delivered;
/// cancelled tasks are aborted, removed from all tracking structures, and a
/// "cancelled" error outcome is sent back on the owning connection.
async fn handle_cancel(
    payload: CancelRequest,
    in_flight: &Arc<Mutex<HashMap<String, InFlightTask>>>,
    job_index: &Arc<Mutex<HashMap<String, HashSet<String>>>>,
) {
    // Silently ignore cancels from mismatched protocol versions.
    if payload.protocol_version != PROTOCOL_VERSION {
        return;
    }
    // Resolve the target set: a single request id, or all ids under the job.
    let request_ids = if let Some(request_id) = payload.request_id.clone() {
        vec![request_id]
    } else {
        let job_index = job_index.lock().await;
        job_index
            .get(&payload.job_id)
            .map(|ids| ids.iter().cloned().collect())
            .unwrap_or_else(Vec::new)
    };
    if request_ids.is_empty() {
        return;
    }
    let scope = if payload.request_id.is_some() {
        "request"
    } else {
        "job"
    };
    crate::telemetry::record_cancellation(scope);

    for request_id in request_ids {
        let task = {
            let mut in_flight = in_flight.lock().await;
            // Leave completed tasks in place: their own cleanup path removes
            // them and delivers the real outcome.
            if let Some(task) = in_flight.get(&request_id)
                && task.completed.load(Ordering::SeqCst)
            {
                None
            } else {
                in_flight.remove(&request_id)
            }
        };
        if let Some(task) = task {
            task.handle.abort();
            crate::telemetry::record_runner_inflight_delta(-1);
            // Free the connection's capacity slot for this request.
            {
                let mut active = task.connection_requests.lock().await;
                active.remove(&request_id);
                crate::telemetry::record_runner_channel_pressure(active.len());
            }
            let outcome = ExecutionOutcome {
                job_id: Some(payload.job_id.clone()),
                request_id: Some(request_id.clone()),
                status: OutcomeStatus::Error,
                result: None,
                error: Some(ExecutionError {
                    message: "Job cancelled".to_string(),
                    error_type: Some("cancelled".to_string()),
                    code: None,
                    details: None,
                }),
                retry_after_seconds: None,
            };
            let send_result = timeout(RESPONSE_SEND_TIMEOUT, task.response_tx.send(outcome)).await;
            match send_result {
                Ok(Ok(())) => {}
                Ok(Err(_)) => {
                    tracing::warn!("runner response channel closed; dropping cancel outcome");
                }
                Err(_) => {
                    tracing::warn!("runner response channel stalled; dropping cancel outcome");
                }
            }
            let mut job_index = job_index.lock().await;
            if let Some(entries) = job_index.get_mut(&task.job_id) {
                entries.remove(&request_id);
                if entries.is_empty() {
                    job_index.remove(&task.job_id);
                }
            }
        }
    }
}
430
431async fn execute_with_deadline<T: Telemetry + ?Sized>(
432 request: rrq_protocol::ExecutionRequest,
433 registry: Registry,
434 telemetry: &T,
435) -> ExecutionOutcome {
436 let job_id = request.job_id.clone();
437 let request_id = request.request_id.clone();
438 let deadline = request.context.deadline;
439 if let Some(deadline) = deadline {
440 let now: DateTime<Utc> = Utc::now();
441 if deadline <= now {
442 crate::telemetry::record_deadline_expired();
443 return ExecutionOutcome::timeout(
444 job_id.clone(),
445 request_id.clone(),
446 "Job deadline exceeded",
447 );
448 }
449 if let Ok(remaining) = (deadline - now).to_std() {
450 match tokio::time::timeout(remaining, registry.execute_with(request, telemetry)).await {
451 Ok(outcome) => return outcome,
452 Err(_) => {
453 crate::telemetry::record_deadline_expired();
454 return ExecutionOutcome::timeout(
455 job_id.clone(),
456 request_id.clone(),
457 "Job execution timed out",
458 );
459 }
460 }
461 }
462 crate::telemetry::record_deadline_expired();
463 return ExecutionOutcome::timeout(job_id, request_id, "Job deadline exceeded");
464 }
465 registry.execute_with(request, telemetry).await
466}
467
468async fn read_message<R: AsyncRead + Unpin>(
469 stream: &mut R,
470) -> Result<Option<RunnerMessage>, Box<dyn std::error::Error>> {
471 let mut header = [0u8; 4];
472 match stream.read_exact(&mut header).await {
473 Ok(_) => {}
474 Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
475 Err(err) => return Err(err.into()),
476 }
477 let length = u32::from_be_bytes(header) as usize;
478 if length == 0 {
479 return Err("runner message payload cannot be empty".into());
480 }
481 if length > MAX_FRAME_LEN {
482 return Err("runner message payload too large".into());
483 }
484 let mut payload = vec![0u8; length];
485 stream.read_exact(&mut payload).await?;
486 let message = serde_json::from_slice(&payload)?;
487 Ok(Some(message))
488}
489
490async fn write_message<W: AsyncWrite + Unpin>(
491 stream: &mut W,
492 message: &RunnerMessage,
493) -> Result<(), Box<dyn std::error::Error>> {
494 let framed = encode_frame(message)?;
495 stream.write_all(&framed).await?;
496 stream.flush().await?;
497 Ok(())
498}
499
500#[cfg(test)]
501mod tests {
502 use super::*;
503 use crate::registry::Registry;
504 use crate::telemetry::NoopTelemetry;
505 use chrono::Utc;
506 use rrq_protocol::{ExecutionContext, ExecutionRequest, OutcomeStatus};
507 use serde_json::json;
508 use tokio::net::{TcpListener, TcpStream};
509 use tokio::time::{Duration, timeout};
510
511 fn build_request(function_name: &str) -> ExecutionRequest {
512 ExecutionRequest {
513 protocol_version: PROTOCOL_VERSION.to_string(),
514 request_id: "req-1".to_string(),
515 job_id: "job-1".to_string(),
516 function_name: function_name.to_string(),
517 params: std::collections::HashMap::new(),
518 context: ExecutionContext {
519 job_id: "job-1".to_string(),
520 attempt: 1,
521 enqueue_time: "2024-01-01T00:00:00Z".parse().unwrap(),
522 queue_name: "default".to_string(),
523 deadline: None,
524 trace_context: None,
525 correlation_context: None,
526 worker_id: None,
527 },
528 }
529 }
530
531 async fn tcp_pair() -> (TcpStream, TcpStream) {
532 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
533 let addr = listener.local_addr().unwrap();
534 let client = TcpStream::connect(addr).await.unwrap();
535 let (server, _) = listener.accept().await.unwrap();
536 (client, server)
537 }
538
539 #[tokio::test]
540 async fn handle_connection_executes_request() {
541 let mut registry = Registry::new();
542 registry.register("echo", |request| async move {
543 ExecutionOutcome::success(
544 request.job_id.clone(),
545 request.request_id.clone(),
546 json!({"ok": true}),
547 )
548 });
549 let (client, server) = tcp_pair().await;
550 let in_flight = Arc::new(Mutex::new(HashMap::new()));
551 let job_index = Arc::new(Mutex::new(HashMap::new()));
552 let server_task = tokio::spawn(async move {
553 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
554 .await
555 .unwrap();
556 });
557 let mut client = client;
558 let request = build_request("echo");
559 let message = RunnerMessage::Request { payload: request };
560 write_message(&mut client, &message).await.unwrap();
561 let response = read_message(&mut client).await.unwrap().unwrap();
562 match response {
563 RunnerMessage::Response { payload } => {
564 assert_eq!(payload.status, OutcomeStatus::Success);
565 assert_eq!(payload.result, Some(json!({"ok": true})));
566 }
567 _ => panic!("expected response"),
568 }
569 drop(client);
570 let _ = server_task.await;
571 }
572
573 #[tokio::test]
574 async fn handle_connection_rejects_bad_protocol() {
575 let registry = Registry::new();
576 let (client, server) = tcp_pair().await;
577 let in_flight = Arc::new(Mutex::new(HashMap::new()));
578 let job_index = Arc::new(Mutex::new(HashMap::new()));
579 let server_task = tokio::spawn(async move {
580 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
581 .await
582 .unwrap();
583 });
584 let mut client = client;
585 let mut request = build_request("echo");
586 request.protocol_version = "0".to_string();
587 let message = RunnerMessage::Request { payload: request };
588 write_message(&mut client, &message).await.unwrap();
589 let response = read_message(&mut client).await.unwrap().unwrap();
590 match response {
591 RunnerMessage::Response { payload } => {
592 assert_eq!(payload.status, OutcomeStatus::Error);
593 }
594 _ => panic!("expected response"),
595 }
596 drop(client);
597 let _ = server_task.await;
598 }
599
600 #[tokio::test]
601 async fn handle_connection_cancels_inflight() {
602 let mut registry = Registry::new();
603 registry.register("sleep", |request| async move {
604 tokio::time::sleep(Duration::from_millis(200)).await;
605 ExecutionOutcome::success(
606 request.job_id.clone(),
607 request.request_id.clone(),
608 json!({"ok": true}),
609 )
610 });
611 let (client, server) = tcp_pair().await;
612 let in_flight = Arc::new(Mutex::new(HashMap::new()));
613 let job_index = Arc::new(Mutex::new(HashMap::new()));
614 let server_task = tokio::spawn(async move {
615 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
616 .await
617 .unwrap();
618 });
619 let mut client = client;
620 let request = ExecutionRequest {
621 protocol_version: PROTOCOL_VERSION.to_string(),
622 request_id: "req-cancel".to_string(),
623 job_id: "job-cancel".to_string(),
624 function_name: "sleep".to_string(),
625 params: std::collections::HashMap::new(),
626 context: ExecutionContext {
627 job_id: "job-cancel".to_string(),
628 attempt: 1,
629 enqueue_time: "2024-01-01T00:00:00Z".parse().unwrap(),
630 queue_name: "default".to_string(),
631 deadline: None,
632 trace_context: None,
633 correlation_context: None,
634 worker_id: None,
635 },
636 };
637 let message = RunnerMessage::Request {
638 payload: request.clone(),
639 };
640 write_message(&mut client, &message).await.unwrap();
641 let cancel = RunnerMessage::Cancel {
642 payload: CancelRequest {
643 protocol_version: PROTOCOL_VERSION.to_string(),
644 job_id: request.job_id.clone(),
645 request_id: Some(request.request_id.clone()),
646 hard_kill: false,
647 },
648 };
649 write_message(&mut client, &cancel).await.unwrap();
650 let response = read_message(&mut client).await.unwrap().unwrap();
651 match response {
652 RunnerMessage::Response { payload } => {
653 assert_eq!(payload.status, OutcomeStatus::Error);
654 let error_type = payload
655 .error
656 .as_ref()
657 .and_then(|error| error.error_type.as_deref());
658 assert_eq!(error_type, Some("cancelled"));
659 }
660 _ => panic!("expected response"),
661 }
662 drop(client);
663 let _ = server_task.await;
664 }
665
666 #[tokio::test]
667 async fn cancel_frees_connection_capacity() {
668 let mut registry = Registry::new();
669 let gate = Arc::new(tokio::sync::Semaphore::new(0));
670 let gate_for_handler = gate.clone();
671 registry.register("block", move |request| {
672 let gate = gate_for_handler.clone();
673 async move {
674 let _permit = gate.acquire().await.expect("semaphore closed");
675 ExecutionOutcome::success(
676 request.job_id.clone(),
677 request.request_id.clone(),
678 json!({"ok": true}),
679 )
680 }
681 });
682 let (client, server) = tcp_pair().await;
683 let in_flight = Arc::new(Mutex::new(HashMap::new()));
684 let job_index = Arc::new(Mutex::new(HashMap::new()));
685 let server_task = tokio::spawn(async move {
686 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
687 .await
688 .unwrap();
689 });
690 let mut client = client;
691 let job_id = "job-capacity".to_string();
692 for i in 0..RESPONSE_CHANNEL_CAPACITY {
693 let mut request = build_request("block");
694 request.request_id = format!("req-{i}");
695 request.job_id = job_id.clone();
696 request.context.job_id = job_id.clone();
697 write_message(&mut client, &RunnerMessage::Request { payload: request })
698 .await
699 .unwrap();
700 }
701
702 let cancel = RunnerMessage::Cancel {
703 payload: CancelRequest {
704 protocol_version: PROTOCOL_VERSION.to_string(),
705 job_id: job_id.clone(),
706 request_id: Some("req-0".to_string()),
707 hard_kill: false,
708 },
709 };
710 write_message(&mut client, &cancel).await.unwrap();
711 let response = timeout(Duration::from_secs(1), read_message(&mut client))
712 .await
713 .unwrap()
714 .unwrap()
715 .unwrap();
716 match response {
717 RunnerMessage::Response { payload } => {
718 assert_eq!(payload.status, OutcomeStatus::Error);
719 let error_type = payload
720 .error
721 .as_ref()
722 .and_then(|error| error.error_type.as_deref());
723 assert_eq!(error_type, Some("cancelled"));
724 }
725 _ => panic!("expected response"),
726 }
727
728 let mut extra_request = build_request("block");
729 extra_request.request_id = "req-extra".to_string();
730 extra_request.job_id = job_id.clone();
731 extra_request.context.job_id = job_id.clone();
732 write_message(
733 &mut client,
734 &RunnerMessage::Request {
735 payload: extra_request,
736 },
737 )
738 .await
739 .unwrap();
740
741 gate.add_permits(RESPONSE_CHANNEL_CAPACITY + 1);
742
743 let mut saw_extra = false;
744 for _ in 0..RESPONSE_CHANNEL_CAPACITY {
745 let response = timeout(Duration::from_secs(1), read_message(&mut client))
746 .await
747 .unwrap()
748 .unwrap()
749 .unwrap();
750 if let RunnerMessage::Response { payload } = response
751 && payload.request_id.as_deref() == Some("req-extra")
752 {
753 assert_eq!(payload.status, OutcomeStatus::Success);
754 saw_extra = true;
755 }
756 }
757 assert!(saw_extra, "extra request never completed");
758
759 drop(client);
760 let _ = server_task.await;
761 }
762
763 #[tokio::test]
764 async fn execute_with_deadline_times_out() {
765 let mut registry = Registry::new();
766 registry.register("echo", |request| async move {
767 ExecutionOutcome::success(
768 request.job_id.clone(),
769 request.request_id.clone(),
770 json!({"ok": true}),
771 )
772 });
773 let mut request = build_request("echo");
774 request.context.deadline = Some(
775 "2020-01-01T00:00:00Z"
776 .parse::<chrono::DateTime<Utc>>()
777 .unwrap(),
778 );
779 let outcome = execute_with_deadline(request, registry, &NoopTelemetry).await;
780 assert_eq!(outcome.status, OutcomeStatus::Timeout);
781 }
782
783 #[tokio::test]
784 async fn execute_with_deadline_succeeds_before_deadline() {
785 let mut registry = Registry::new();
786 registry.register("echo", |request| async move {
787 ExecutionOutcome::success(
788 request.job_id.clone(),
789 request.request_id.clone(),
790 json!({"ok": true}),
791 )
792 });
793 let mut request = build_request("echo");
794 request.context.deadline = Some(Utc::now() + chrono::Duration::seconds(5));
795 let outcome = execute_with_deadline(request, registry, &NoopTelemetry).await;
796 assert_eq!(outcome.status, OutcomeStatus::Success);
797 }
798
799 #[tokio::test]
800 async fn handle_connection_handles_unexpected_response_message() {
801 let registry = Registry::new();
802 let (client, server) = tcp_pair().await;
803 let in_flight = Arc::new(Mutex::new(HashMap::new()));
804 let job_index = Arc::new(Mutex::new(HashMap::new()));
805 let server_task = tokio::spawn(async move {
806 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
807 .await
808 .unwrap();
809 });
810 let mut client = client;
811 let response = RunnerMessage::Response {
812 payload: ExecutionOutcome::error("job-x", "req-x", "oops"),
813 };
814 write_message(&mut client, &response).await.unwrap();
815 let reply = read_message(&mut client).await.unwrap().unwrap();
816 match reply {
817 RunnerMessage::Response { payload } => {
818 assert_eq!(payload.status, OutcomeStatus::Error);
819 assert!(
820 payload
821 .error
822 .as_ref()
823 .unwrap()
824 .message
825 .contains("unexpected response")
826 );
827 }
828 _ => panic!("expected response"),
829 }
830 drop(client);
831 let _ = server_task.await;
832 }
833
834 #[tokio::test]
835 async fn handle_connection_cancels_by_job_id() {
836 let mut registry = Registry::new();
837 registry.register("sleep", |request| async move {
838 tokio::time::sleep(Duration::from_millis(200)).await;
839 ExecutionOutcome::success(
840 request.job_id.clone(),
841 request.request_id.clone(),
842 json!({"ok": true}),
843 )
844 });
845 let (client, server) = tcp_pair().await;
846 let in_flight = Arc::new(Mutex::new(HashMap::new()));
847 let job_index = Arc::new(Mutex::new(HashMap::new()));
848 let server_task = tokio::spawn(async move {
849 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
850 .await
851 .unwrap();
852 });
853 let mut client = client;
854 let request = build_request("sleep");
855 let message = RunnerMessage::Request {
856 payload: request.clone(),
857 };
858 write_message(&mut client, &message).await.unwrap();
859 let cancel = RunnerMessage::Cancel {
860 payload: CancelRequest {
861 protocol_version: PROTOCOL_VERSION.to_string(),
862 job_id: request.job_id.clone(),
863 request_id: None,
864 hard_kill: false,
865 },
866 };
867 write_message(&mut client, &cancel).await.unwrap();
868 let response = read_message(&mut client).await.unwrap().unwrap();
869 match response {
870 RunnerMessage::Response { payload } => {
871 assert_eq!(payload.status, OutcomeStatus::Error);
872 let error_type = payload
873 .error
874 .as_ref()
875 .and_then(|error| error.error_type.as_deref());
876 assert_eq!(error_type, Some("cancelled"));
877 }
878 _ => panic!("expected response"),
879 }
880 drop(client);
881 let _ = server_task.await;
882 }
883
884 #[tokio::test]
885 async fn handle_cancel_by_job_id_cancels_all_requests() {
886 let mut registry = Registry::new();
887 registry.register("sleep", |request| async move {
888 tokio::time::sleep(Duration::from_millis(200)).await;
889 ExecutionOutcome::success(
890 request.job_id.clone(),
891 request.request_id.clone(),
892 json!({"ok": true}),
893 )
894 });
895 let (client, server) = tcp_pair().await;
896 let in_flight = Arc::new(Mutex::new(HashMap::new()));
897 let job_index = Arc::new(Mutex::new(HashMap::new()));
898 let server_task = tokio::spawn(async move {
899 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
900 .await
901 .unwrap();
902 });
903 let mut client = client;
904 let mut request1 = build_request("sleep");
905 request1.request_id = "req-1".to_string();
906 request1.job_id = "job-shared".to_string();
907 let mut request2 = build_request("sleep");
908 request2.request_id = "req-2".to_string();
909 request2.job_id = "job-shared".to_string();
910 write_message(&mut client, &RunnerMessage::Request { payload: request1 })
911 .await
912 .unwrap();
913 write_message(&mut client, &RunnerMessage::Request { payload: request2 })
914 .await
915 .unwrap();
916 let cancel = RunnerMessage::Cancel {
917 payload: CancelRequest {
918 protocol_version: PROTOCOL_VERSION.to_string(),
919 job_id: "job-shared".to_string(),
920 request_id: None,
921 hard_kill: false,
922 },
923 };
924 write_message(&mut client, &cancel).await.unwrap();
925
926 let mut cancelled = 0;
927 for _ in 0..2 {
928 let response = timeout(Duration::from_millis(200), read_message(&mut client))
929 .await
930 .unwrap()
931 .unwrap()
932 .unwrap();
933 match response {
934 RunnerMessage::Response { payload } => {
935 assert_eq!(payload.status, OutcomeStatus::Error);
936 let error_type = payload
937 .error
938 .as_ref()
939 .and_then(|error| error.error_type.as_deref());
940 assert_eq!(error_type, Some("cancelled"));
941 cancelled += 1;
942 }
943 _ => panic!("expected response"),
944 }
945 }
946 assert_eq!(cancelled, 2);
947 drop(client);
948 let _ = server_task.await;
949 }
950
951 #[tokio::test]
952 async fn connection_teardown_clears_tracking_maps() {
953 let mut registry = Registry::new();
954 registry.register("sleep", |request| async move {
955 tokio::time::sleep(Duration::from_millis(200)).await;
956 ExecutionOutcome::success(
957 request.job_id.clone(),
958 request.request_id.clone(),
959 json!({"ok": true}),
960 )
961 });
962 let (client, server) = tcp_pair().await;
963 let in_flight = Arc::new(Mutex::new(HashMap::new()));
964 let job_index = Arc::new(Mutex::new(HashMap::new()));
965 let in_flight_for_server = in_flight.clone();
966 let job_index_for_server = job_index.clone();
967 let server_task = tokio::spawn(async move {
968 handle_connection(
969 server,
970 ®istry,
971 &NoopTelemetry,
972 in_flight_for_server,
973 job_index_for_server,
974 )
975 .await
976 .unwrap();
977 });
978 let mut client = client;
979 let request = build_request("sleep");
980 let message = RunnerMessage::Request {
981 payload: request.clone(),
982 };
983 write_message(&mut client, &message).await.unwrap();
984
985 let mut inserted = false;
986 for _ in 0..20 {
987 let has_in_flight = {
988 let guard = in_flight.lock().await;
989 guard.contains_key(&request.request_id)
990 };
991 let has_job_index = {
992 let guard = job_index.lock().await;
993 guard.contains_key(&request.job_id)
994 };
995 if has_in_flight && has_job_index {
996 inserted = true;
997 break;
998 }
999 tokio::time::sleep(Duration::from_millis(10)).await;
1000 }
1001 assert!(inserted, "request never entered tracking maps");
1002
1003 drop(client);
1004 let _ = server_task.await;
1005
1006 let in_flight = in_flight.lock().await;
1007 let job_index = job_index.lock().await;
1008 assert!(in_flight.is_empty());
1009 assert!(job_index.is_empty());
1010 }
1011
1012 #[tokio::test]
1013 async fn handle_cancel_ignores_invalid_protocol() {
1014 let in_flight = Arc::new(Mutex::new(HashMap::new()));
1015 let job_index = Arc::new(Mutex::new(HashMap::new()));
1016 let (tx, _rx) = mpsc::channel(1);
1017 let handle = tokio::spawn(async {});
1018 let connection_requests = Arc::new(Mutex::new(HashSet::new()));
1019 {
1020 let mut guard = in_flight.lock().await;
1021 guard.insert(
1022 "req-1".to_string(),
1023 InFlightTask {
1024 job_id: "job-1".to_string(),
1025 handle,
1026 response_tx: tx,
1027 connection_requests,
1028 completed: Arc::new(AtomicBool::new(false)),
1029 },
1030 );
1031 }
1032 let payload = CancelRequest {
1033 protocol_version: "0".to_string(),
1034 job_id: "job-1".to_string(),
1035 request_id: None,
1036 hard_kill: false,
1037 };
1038 handle_cancel(payload, &in_flight, &job_index).await;
1039 let guard = in_flight.lock().await;
1040 assert!(guard.contains_key("req-1"));
1041 guard.get("req-1").unwrap().handle.abort();
1042 }
1043
1044 #[tokio::test]
1045 async fn handle_cancel_skips_completed_requests() {
1046 let in_flight = Arc::new(Mutex::new(HashMap::new()));
1047 let job_index = Arc::new(Mutex::new(HashMap::new()));
1048 let (tx, mut rx) = mpsc::channel(1);
1049 let handle = tokio::spawn(async {
1050 tokio::time::sleep(Duration::from_millis(50)).await;
1051 });
1052 let connection_requests = Arc::new(Mutex::new(HashSet::new()));
1053 {
1054 let mut guard = in_flight.lock().await;
1055 guard.insert(
1056 "req-1".to_string(),
1057 InFlightTask {
1058 job_id: "job-1".to_string(),
1059 handle,
1060 response_tx: tx,
1061 connection_requests,
1062 completed: Arc::new(AtomicBool::new(true)),
1063 },
1064 );
1065 }
1066 {
1067 let mut guard = job_index.lock().await;
1068 guard.insert("job-1".to_string(), HashSet::from(["req-1".to_string()]));
1069 }
1070 let payload = CancelRequest {
1071 protocol_version: PROTOCOL_VERSION.to_string(),
1072 job_id: "job-1".to_string(),
1073 request_id: Some("req-1".to_string()),
1074 hard_kill: false,
1075 };
1076 handle_cancel(payload, &in_flight, &job_index).await;
1077 assert!(in_flight.lock().await.contains_key("req-1"));
1078 assert!(job_index.lock().await.contains_key("job-1"));
1079 assert!(rx.try_recv().is_err());
1080 let task = in_flight.lock().await.remove("req-1").unwrap();
1081 task.handle.abort();
1082 }
1083
1084 #[tokio::test]
1085 async fn read_message_handles_empty_and_invalid_payloads() {
1086 let (mut client, mut server) = tokio::io::duplex(64);
1087 client.write_all(&0u32.to_be_bytes()).await.unwrap();
1089 let err = read_message(&mut server).await.unwrap_err();
1090 assert!(err.to_string().contains("payload cannot be empty"));
1091
1092 let (mut client, mut server) = tokio::io::duplex(64);
1094 let payload = b"not-json";
1095 let len = (payload.len() as u32).to_be_bytes();
1096 client.write_all(&len).await.unwrap();
1097 client.write_all(payload).await.unwrap();
1098 let err = read_message(&mut server).await.unwrap_err();
1099 assert!(err.to_string().contains("expected"));
1100
1101 let (mut client, mut server) = tokio::io::duplex(64);
1103 let len = ((MAX_FRAME_LEN + 1) as u32).to_be_bytes();
1104 client.write_all(&len).await.unwrap();
1105 let err = read_message(&mut server).await.unwrap_err();
1106 assert!(err.to_string().contains("payload too large"));
1107 }
1108
1109 #[tokio::test]
1110 async fn read_message_returns_none_on_eof() {
1111 let (client, mut server) = tokio::io::duplex(8);
1112 drop(client);
1113 let message = read_message(&mut server).await.unwrap();
1114 assert!(message.is_none());
1115 }
1116}