1use crate::registry::Registry;
2use crate::telemetry::{NoopTelemetry, Telemetry};
3use crate::types::{ExecutionError, ExecutionOutcome};
4use chrono::{DateTime, Utc};
5use rrq_protocol::{CancelRequest, OutcomeStatus, PROTOCOL_VERSION, RunnerMessage, encode_frame};
6use std::collections::{HashMap, HashSet};
7use std::net::{IpAddr, Ipv4Addr, SocketAddr};
8use std::sync::{
9 Arc,
10 atomic::{AtomicBool, Ordering},
11};
12use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
13use tokio::net::TcpListener;
14use tokio::sync::{Mutex, mpsc};
15use tokio::time::{Duration, timeout};
// Hard cap on a single framed message payload (16 MiB) so a bad or malicious
// length prefix cannot make us allocate unbounded memory.
const MAX_FRAME_LEN: usize = 16 * 1024 * 1024;
// Capacity of the per-connection outcome channel; doubles as the cap on
// concurrently in-flight requests for one connection.
const RESPONSE_CHANNEL_CAPACITY: usize = 64;
// How long a producer waits on a full/stalled response channel before
// giving up on delivering an outcome.
const RESPONSE_SEND_TIMEOUT: Duration = Duration::from_secs(1);
19
/// Wraps `message` in a boxed `std::io::Error` of kind `InvalidInput`,
/// the error shape used for all socket-address validation failures here.
fn invalid_input(message: impl Into<String>) -> Box<dyn std::error::Error> {
    let err = std::io::Error::new(std::io::ErrorKind::InvalidInput, message.into());
    Box::new(err)
}
26
27pub fn parse_tcp_socket(raw: &str) -> Result<SocketAddr, Box<dyn std::error::Error>> {
28 let raw = raw.trim();
29 if raw.is_empty() {
30 return Err(invalid_input("runner tcp_socket cannot be empty"));
31 }
32
33 let (host, port_str) = if let Some(rest) = raw.strip_prefix('[') {
34 let (host, port_str) = rest
35 .split_once("]:")
36 .ok_or_else(|| invalid_input("runner tcp_socket must be in [host]:port format"))?;
37 (host, port_str)
38 } else {
39 let (host, port_str) = raw
40 .rsplit_once(':')
41 .ok_or_else(|| invalid_input("runner tcp_socket must be in host:port format"))?;
42 if host.is_empty() {
43 return Err(invalid_input("runner tcp_socket host cannot be empty"));
44 }
45 (host, port_str)
46 };
47
48 let port: u16 = port_str
49 .parse()
50 .map_err(|_| invalid_input(format!("Invalid runner tcp_socket port: {port_str}")))?;
51 if port == 0 {
52 return Err(invalid_input("runner tcp_socket port must be > 0"));
53 }
54
55 let ip = if host == "localhost" {
56 IpAddr::V4(Ipv4Addr::LOCALHOST)
57 } else {
58 let parsed: IpAddr = host
59 .parse()
60 .map_err(|_| invalid_input(format!("Invalid runner tcp_socket host: {host}")))?;
61 if !parsed.is_loopback() {
62 return Err(invalid_input("runner tcp_socket host must be localhost"));
63 }
64 parsed
65 };
66
67 Ok(SocketAddr::new(ip, port))
68}
69
/// Owns a dedicated Tokio runtime used to drive the runner's TCP server.
pub struct RunnerRuntime {
    // Runtime the `run_tcp*` methods block on; also exposed via `enter`.
    runtime: tokio::runtime::Runtime,
}
73
74impl RunnerRuntime {
75 pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
76 Ok(Self {
77 runtime: tokio::runtime::Runtime::new()?,
78 })
79 }
80
81 pub fn enter(&self) -> tokio::runtime::EnterGuard<'_> {
82 self.runtime.enter()
83 }
84
85 pub fn run_tcp(
86 &self,
87 registry: &Registry,
88 addr: SocketAddr,
89 ) -> Result<(), Box<dyn std::error::Error>> {
90 let telemetry = NoopTelemetry;
91 self.run_tcp_with(registry, addr, &telemetry)
92 }
93
94 pub fn run_tcp_with<T: Telemetry + ?Sized>(
95 &self,
96 registry: &Registry,
97 addr: SocketAddr,
98 telemetry: &T,
99 ) -> Result<(), Box<dyn std::error::Error>> {
100 run_tcp_loop(&self.runtime, registry, addr, telemetry)
101 }
102}
103
104pub fn run_tcp(registry: &Registry, addr: SocketAddr) -> Result<(), Box<dyn std::error::Error>> {
105 RunnerRuntime::new()?.run_tcp(registry, addr)
106}
107
108pub fn run_tcp_with<T: Telemetry + ?Sized>(
109 registry: &Registry,
110 addr: SocketAddr,
111 telemetry: &T,
112) -> Result<(), Box<dyn std::error::Error>> {
113 RunnerRuntime::new()?.run_tcp_with(registry, addr, telemetry)
114}
115
116fn run_tcp_loop<T: Telemetry + ?Sized>(
117 runtime: &tokio::runtime::Runtime,
118 registry: &Registry,
119 addr: SocketAddr,
120 telemetry: &T,
121) -> Result<(), Box<dyn std::error::Error>> {
122 let registry = registry.clone();
123 let in_flight: Arc<Mutex<HashMap<String, InFlightTask>>> = Arc::new(Mutex::new(HashMap::new()));
124 let job_index: Arc<Mutex<HashMap<String, HashSet<String>>>> =
125 Arc::new(Mutex::new(HashMap::new()));
126 let telemetry = telemetry.clone_box();
127 runtime.block_on(async move {
128 if !addr.ip().is_loopback() {
129 return Err(invalid_input(format!(
130 "runner tcp_socket must be loopback-only (got {addr})"
131 )));
132 }
133 let listener = TcpListener::bind(addr).await?;
134 loop {
135 let (stream, _) = listener.accept().await?;
136 let registry = registry.clone();
137 let telemetry = telemetry.clone();
138 let in_flight = in_flight.clone();
139 let job_index = job_index.clone();
140 tokio::spawn(async move {
141 if let Err(err) =
142 handle_connection(stream, ®istry, telemetry.as_ref(), in_flight, job_index)
143 .await
144 {
145 tracing::error!("runner connection error: {err}");
146 }
147 });
148 }
149 })
150}
151
152async fn handle_connection<S, T>(
153 stream: S,
154 registry: &Registry,
155 telemetry: &T,
156 in_flight: Arc<Mutex<HashMap<String, InFlightTask>>>,
157 job_index: Arc<Mutex<HashMap<String, HashSet<String>>>>,
158) -> Result<(), Box<dyn std::error::Error>>
159where
160 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
161 T: Telemetry + ?Sized,
162{
163 let (mut reader, mut writer) = tokio::io::split(stream);
164 let (response_tx, mut response_rx) =
165 mpsc::channel::<ExecutionOutcome>(RESPONSE_CHANNEL_CAPACITY);
166 let writer_task = tokio::spawn(async move {
167 while let Some(outcome) = response_rx.recv().await {
168 let response = RunnerMessage::Response { payload: outcome };
169 if write_message(&mut writer, &response).await.is_err() {
170 break;
171 }
172 }
173 });
174 let connection_requests: Arc<Mutex<std::collections::HashSet<String>>> =
175 Arc::new(Mutex::new(std::collections::HashSet::new()));
176
177 loop {
178 let message = match read_message(&mut reader).await? {
179 Some(message) => message,
180 None => break,
181 };
182 match message {
183 RunnerMessage::Request { payload } => {
184 if payload.protocol_version != PROTOCOL_VERSION {
185 let outcome = ExecutionOutcome::error(
186 payload.job_id.clone(),
187 payload.request_id.clone(),
188 "Unsupported protocol version",
189 );
190 let _ = response_tx.send(outcome).await;
191 continue;
192 }
193
194 let request_id = payload.request_id.clone();
195 let job_id = payload.job_id.clone();
196 {
197 let mut active = connection_requests.lock().await;
198 if active.len() >= RESPONSE_CHANNEL_CAPACITY {
199 let outcome = ExecutionOutcome::error(
200 payload.job_id.clone(),
201 payload.request_id.clone(),
202 "Runner busy: too many in-flight requests",
203 );
204 drop(active);
205 let send_result =
206 timeout(RESPONSE_SEND_TIMEOUT, response_tx.send(outcome)).await;
207 match send_result {
208 Ok(Ok(())) => {}
209 Ok(Err(_)) => {
210 return Err("runner response channel closed".into());
211 }
212 Err(_) => {
213 return Err("runner response channel stalled".into());
214 }
215 }
216 continue;
217 }
218 active.insert(request_id.clone());
219 crate::telemetry::record_runner_channel_pressure(active.len());
220 }
221 let response_tx = response_tx.clone();
222 let registry = registry.clone();
223 let telemetry = telemetry.clone_box();
224 let in_flight_for_task = in_flight.clone();
225 let job_index_for_task = job_index.clone();
226 let active_for_task = connection_requests.clone();
227 let request_id_for_task = request_id.clone();
228 let job_id_for_task = job_id.clone();
229 let response_tx_for_task = response_tx.clone();
230 let completed = Arc::new(AtomicBool::new(false));
231 let completed_for_task = completed.clone();
232
233 let handle = tokio::spawn(async move {
234 let outcome =
235 execute_with_deadline(payload, registry, telemetry.as_ref()).await;
236 completed_for_task.store(true, Ordering::SeqCst);
237 let send_result =
238 timeout(RESPONSE_SEND_TIMEOUT, response_tx_for_task.send(outcome)).await;
239 match send_result {
240 Ok(Ok(())) => {}
241 Ok(Err(_)) => {
242 tracing::warn!("runner response channel closed; dropping outcome");
243 }
244 Err(_) => {
245 tracing::warn!("runner response channel stalled; dropping outcome");
246 }
247 }
248 {
249 let mut in_flight = in_flight_for_task.lock().await;
250 if in_flight.remove(&request_id_for_task).is_some() {
251 crate::telemetry::record_runner_inflight_delta(-1);
252 }
253 }
254 {
255 let mut job_index = job_index_for_task.lock().await;
256 if let Some(entries) = job_index.get_mut(&job_id_for_task) {
257 entries.remove(&request_id_for_task);
258 if entries.is_empty() {
259 job_index.remove(&job_id_for_task);
260 }
261 }
262 }
263 {
264 let mut active = active_for_task.lock().await;
265 active.remove(&request_id_for_task);
266 crate::telemetry::record_runner_channel_pressure(active.len());
267 }
268 });
269
270 {
271 let mut in_flight = in_flight.lock().await;
272 in_flight.insert(
273 request_id.clone(),
274 InFlightTask {
275 job_id: job_id.clone(),
276 handle,
277 response_tx: response_tx.clone(),
278 connection_requests: connection_requests.clone(),
279 completed,
280 },
281 );
282 }
283 crate::telemetry::record_runner_inflight_delta(1);
284 {
285 let mut job_index = job_index.lock().await;
286 job_index
287 .entry(job_id)
288 .or_insert_with(HashSet::new)
289 .insert(request_id);
290 }
291 }
292 RunnerMessage::Cancel { payload } => {
293 handle_cancel(payload, &in_flight, &job_index).await;
294 }
295 RunnerMessage::Response { .. } => {
296 let outcome = ExecutionOutcome {
297 job_id: Some("unknown".to_string()),
298 request_id: None,
299 status: rrq_protocol::OutcomeStatus::Error,
300 result: None,
301 error: Some(ExecutionError {
302 message: "unexpected response message".to_string(),
303 error_type: None,
304 code: None,
305 details: None,
306 }),
307 retry_after_seconds: None,
308 };
309 let _ = response_tx.send(outcome).await;
310 }
311 }
312 }
313
314 let request_ids = {
315 let mut active = connection_requests.lock().await;
316 active.drain().collect::<Vec<_>>()
317 };
318 crate::telemetry::record_runner_channel_pressure(0);
319 for request_id in request_ids {
320 let task = {
321 let mut in_flight = in_flight.lock().await;
322 in_flight.remove(&request_id)
323 };
324 if let Some(task) = task {
325 task.handle.abort();
326 crate::telemetry::record_runner_inflight_delta(-1);
327 let mut job_index = job_index.lock().await;
328 if let Some(entries) = job_index.get_mut(&task.job_id) {
329 entries.remove(&request_id);
330 if entries.is_empty() {
331 job_index.remove(&task.job_id);
332 }
333 }
334 }
335 }
336 writer_task.abort();
337
338 Ok(())
339}
340
/// Bookkeeping for one spawned request-execution task, stored in the shared
/// `in_flight` map so cancellation and teardown can reach it.
struct InFlightTask {
    // Job this request belongs to; key into `job_index` for cleanup.
    job_id: String,
    // Handle of the spawned execution task; aborted on cancel/teardown.
    handle: tokio::task::JoinHandle<()>,
    // Outcome channel of the owning connection, used to emit the
    // "Job cancelled" outcome when this task is cancelled.
    response_tx: mpsc::Sender<ExecutionOutcome>,
    // The owning connection's active-request set; removing the request id
    // here frees one slot of that connection's capacity.
    connection_requests: Arc<Mutex<HashSet<String>>>,
    // Set by the task just before it publishes its outcome; cancellation
    // skips tasks that have already completed.
    completed: Arc<AtomicBool>,
}
348
/// Cancels in-flight work in response to a `Cancel` frame.
///
/// A cancel carrying a `request_id` targets that single request; one without
/// targets every request currently tracked for `payload.job_id`. Cancels
/// with a mismatched protocol version are silently ignored.
///
/// Tasks whose `completed` flag is already set are left in the map so their
/// own cleanup path runs; everything else is aborted, scrubbed from the
/// tracking maps, and answered with a "cancelled" error outcome.
async fn handle_cancel(
    payload: CancelRequest,
    in_flight: &Arc<Mutex<HashMap<String, InFlightTask>>>,
    job_index: &Arc<Mutex<HashMap<String, HashSet<String>>>>,
) {
    if payload.protocol_version != PROTOCOL_VERSION {
        return;
    }
    // Resolve the set of request ids to cancel: the explicit one, or every
    // request currently indexed under the job id.
    let request_ids = if let Some(request_id) = payload.request_id.clone() {
        vec![request_id]
    } else {
        let job_index = job_index.lock().await;
        job_index
            .get(&payload.job_id)
            .map(|ids| ids.iter().cloned().collect())
            .unwrap_or_else(Vec::new)
    };
    if request_ids.is_empty() {
        return;
    }
    let scope = if payload.request_id.is_some() {
        "request"
    } else {
        "job"
    };
    crate::telemetry::record_cancellation(scope);

    for request_id in request_ids {
        // Take the task out of the map — unless it already completed, in
        // which case leave it so its own cleanup runs and we abort nothing.
        let task = {
            let mut in_flight = in_flight.lock().await;
            if let Some(task) = in_flight.get(&request_id)
                && task.completed.load(Ordering::SeqCst)
            {
                None
            } else {
                in_flight.remove(&request_id)
            }
        };
        if let Some(task) = task {
            task.handle.abort();
            crate::telemetry::record_runner_inflight_delta(-1);
            {
                // Free the owning connection's capacity slot.
                let mut active = task.connection_requests.lock().await;
                active.remove(&request_id);
                crate::telemetry::record_runner_channel_pressure(active.len());
            }
            let outcome = ExecutionOutcome {
                job_id: Some(payload.job_id.clone()),
                request_id: Some(request_id.clone()),
                status: OutcomeStatus::Error,
                result: None,
                error: Some(ExecutionError {
                    message: "Job cancelled".to_string(),
                    error_type: Some("cancelled".to_string()),
                    code: None,
                    details: None,
                }),
                retry_after_seconds: None,
            };
            // Best-effort delivery: the channel may already be gone if the
            // owning connection closed.
            let send_result = timeout(RESPONSE_SEND_TIMEOUT, task.response_tx.send(outcome)).await;
            match send_result {
                Ok(Ok(())) => {}
                Ok(Err(_)) => {
                    tracing::warn!("runner response channel closed; dropping cancel outcome");
                }
                Err(_) => {
                    tracing::warn!("runner response channel stalled; dropping cancel outcome");
                }
            }
            let mut job_index = job_index.lock().await;
            if let Some(entries) = job_index.get_mut(&task.job_id) {
                entries.remove(&request_id);
                if entries.is_empty() {
                    job_index.remove(&task.job_id);
                }
            }
        }
    }
}
428
429async fn execute_with_deadline<T: Telemetry + ?Sized>(
430 request: rrq_protocol::ExecutionRequest,
431 registry: Registry,
432 telemetry: &T,
433) -> ExecutionOutcome {
434 let job_id = request.job_id.clone();
435 let request_id = request.request_id.clone();
436 let deadline = request.context.deadline;
437 if let Some(deadline) = deadline {
438 let now: DateTime<Utc> = Utc::now();
439 if deadline <= now {
440 crate::telemetry::record_deadline_expired();
441 return ExecutionOutcome::timeout(
442 job_id.clone(),
443 request_id.clone(),
444 "Job deadline exceeded",
445 );
446 }
447 if let Ok(remaining) = (deadline - now).to_std() {
448 match tokio::time::timeout(remaining, registry.execute_with(request, telemetry)).await {
449 Ok(outcome) => return outcome,
450 Err(_) => {
451 crate::telemetry::record_deadline_expired();
452 return ExecutionOutcome::timeout(
453 job_id.clone(),
454 request_id.clone(),
455 "Job execution timed out",
456 );
457 }
458 }
459 }
460 crate::telemetry::record_deadline_expired();
461 return ExecutionOutcome::timeout(job_id, request_id, "Job deadline exceeded");
462 }
463 registry.execute_with(request, telemetry).await
464}
465
466async fn read_message<R: AsyncRead + Unpin>(
467 stream: &mut R,
468) -> Result<Option<RunnerMessage>, Box<dyn std::error::Error>> {
469 let mut header = [0u8; 4];
470 match stream.read_exact(&mut header).await {
471 Ok(_) => {}
472 Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
473 Err(err) => return Err(err.into()),
474 }
475 let length = u32::from_be_bytes(header) as usize;
476 if length == 0 {
477 return Err("runner message payload cannot be empty".into());
478 }
479 if length > MAX_FRAME_LEN {
480 return Err("runner message payload too large".into());
481 }
482 let mut payload = vec![0u8; length];
483 stream.read_exact(&mut payload).await?;
484 let message = serde_json::from_slice(&payload)?;
485 Ok(Some(message))
486}
487
488async fn write_message<W: AsyncWrite + Unpin>(
489 stream: &mut W,
490 message: &RunnerMessage,
491) -> Result<(), Box<dyn std::error::Error>> {
492 let framed = encode_frame(message)?;
493 stream.write_all(&framed).await?;
494 stream.flush().await?;
495 Ok(())
496}
497
498#[cfg(test)]
499mod tests {
500 use super::*;
501 use crate::registry::Registry;
502 use crate::telemetry::NoopTelemetry;
503 use chrono::Utc;
504 use rrq_protocol::{ExecutionContext, ExecutionRequest, OutcomeStatus};
505 use serde_json::json;
506 use tokio::net::{TcpListener, TcpStream};
507 use tokio::time::{Duration, timeout};
508
509 fn build_request(function_name: &str) -> ExecutionRequest {
510 ExecutionRequest {
511 protocol_version: PROTOCOL_VERSION.to_string(),
512 request_id: "req-1".to_string(),
513 job_id: "job-1".to_string(),
514 function_name: function_name.to_string(),
515 params: std::collections::HashMap::new(),
516 context: ExecutionContext {
517 job_id: "job-1".to_string(),
518 attempt: 1,
519 enqueue_time: "2024-01-01T00:00:00Z".parse().unwrap(),
520 queue_name: "default".to_string(),
521 deadline: None,
522 trace_context: None,
523 worker_id: None,
524 },
525 }
526 }
527
528 async fn tcp_pair() -> (TcpStream, TcpStream) {
529 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
530 let addr = listener.local_addr().unwrap();
531 let client = TcpStream::connect(addr).await.unwrap();
532 let (server, _) = listener.accept().await.unwrap();
533 (client, server)
534 }
535
536 #[tokio::test]
537 async fn handle_connection_executes_request() {
538 let mut registry = Registry::new();
539 registry.register("echo", |request| async move {
540 ExecutionOutcome::success(
541 request.job_id.clone(),
542 request.request_id.clone(),
543 json!({"ok": true}),
544 )
545 });
546 let (client, server) = tcp_pair().await;
547 let in_flight = Arc::new(Mutex::new(HashMap::new()));
548 let job_index = Arc::new(Mutex::new(HashMap::new()));
549 let server_task = tokio::spawn(async move {
550 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
551 .await
552 .unwrap();
553 });
554 let mut client = client;
555 let request = build_request("echo");
556 let message = RunnerMessage::Request { payload: request };
557 write_message(&mut client, &message).await.unwrap();
558 let response = read_message(&mut client).await.unwrap().unwrap();
559 match response {
560 RunnerMessage::Response { payload } => {
561 assert_eq!(payload.status, OutcomeStatus::Success);
562 assert_eq!(payload.result, Some(json!({"ok": true})));
563 }
564 _ => panic!("expected response"),
565 }
566 drop(client);
567 let _ = server_task.await;
568 }
569
570 #[tokio::test]
571 async fn handle_connection_rejects_bad_protocol() {
572 let registry = Registry::new();
573 let (client, server) = tcp_pair().await;
574 let in_flight = Arc::new(Mutex::new(HashMap::new()));
575 let job_index = Arc::new(Mutex::new(HashMap::new()));
576 let server_task = tokio::spawn(async move {
577 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
578 .await
579 .unwrap();
580 });
581 let mut client = client;
582 let mut request = build_request("echo");
583 request.protocol_version = "0".to_string();
584 let message = RunnerMessage::Request { payload: request };
585 write_message(&mut client, &message).await.unwrap();
586 let response = read_message(&mut client).await.unwrap().unwrap();
587 match response {
588 RunnerMessage::Response { payload } => {
589 assert_eq!(payload.status, OutcomeStatus::Error);
590 }
591 _ => panic!("expected response"),
592 }
593 drop(client);
594 let _ = server_task.await;
595 }
596
597 #[tokio::test]
598 async fn handle_connection_cancels_inflight() {
599 let mut registry = Registry::new();
600 registry.register("sleep", |request| async move {
601 tokio::time::sleep(Duration::from_millis(200)).await;
602 ExecutionOutcome::success(
603 request.job_id.clone(),
604 request.request_id.clone(),
605 json!({"ok": true}),
606 )
607 });
608 let (client, server) = tcp_pair().await;
609 let in_flight = Arc::new(Mutex::new(HashMap::new()));
610 let job_index = Arc::new(Mutex::new(HashMap::new()));
611 let server_task = tokio::spawn(async move {
612 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
613 .await
614 .unwrap();
615 });
616 let mut client = client;
617 let request = ExecutionRequest {
618 protocol_version: PROTOCOL_VERSION.to_string(),
619 request_id: "req-cancel".to_string(),
620 job_id: "job-cancel".to_string(),
621 function_name: "sleep".to_string(),
622 params: std::collections::HashMap::new(),
623 context: ExecutionContext {
624 job_id: "job-cancel".to_string(),
625 attempt: 1,
626 enqueue_time: "2024-01-01T00:00:00Z".parse().unwrap(),
627 queue_name: "default".to_string(),
628 deadline: None,
629 trace_context: None,
630 worker_id: None,
631 },
632 };
633 let message = RunnerMessage::Request {
634 payload: request.clone(),
635 };
636 write_message(&mut client, &message).await.unwrap();
637 let cancel = RunnerMessage::Cancel {
638 payload: CancelRequest {
639 protocol_version: PROTOCOL_VERSION.to_string(),
640 job_id: request.job_id.clone(),
641 request_id: Some(request.request_id.clone()),
642 hard_kill: false,
643 },
644 };
645 write_message(&mut client, &cancel).await.unwrap();
646 let response = read_message(&mut client).await.unwrap().unwrap();
647 match response {
648 RunnerMessage::Response { payload } => {
649 assert_eq!(payload.status, OutcomeStatus::Error);
650 let error_type = payload
651 .error
652 .as_ref()
653 .and_then(|error| error.error_type.as_deref());
654 assert_eq!(error_type, Some("cancelled"));
655 }
656 _ => panic!("expected response"),
657 }
658 drop(client);
659 let _ = server_task.await;
660 }
661
662 #[tokio::test]
663 async fn cancel_frees_connection_capacity() {
664 let mut registry = Registry::new();
665 let gate = Arc::new(tokio::sync::Semaphore::new(0));
666 let gate_for_handler = gate.clone();
667 registry.register("block", move |request| {
668 let gate = gate_for_handler.clone();
669 async move {
670 let _permit = gate.acquire().await.expect("semaphore closed");
671 ExecutionOutcome::success(
672 request.job_id.clone(),
673 request.request_id.clone(),
674 json!({"ok": true}),
675 )
676 }
677 });
678 let (client, server) = tcp_pair().await;
679 let in_flight = Arc::new(Mutex::new(HashMap::new()));
680 let job_index = Arc::new(Mutex::new(HashMap::new()));
681 let server_task = tokio::spawn(async move {
682 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
683 .await
684 .unwrap();
685 });
686 let mut client = client;
687 let job_id = "job-capacity".to_string();
688 for i in 0..RESPONSE_CHANNEL_CAPACITY {
689 let mut request = build_request("block");
690 request.request_id = format!("req-{i}");
691 request.job_id = job_id.clone();
692 request.context.job_id = job_id.clone();
693 write_message(&mut client, &RunnerMessage::Request { payload: request })
694 .await
695 .unwrap();
696 }
697
698 let cancel = RunnerMessage::Cancel {
699 payload: CancelRequest {
700 protocol_version: PROTOCOL_VERSION.to_string(),
701 job_id: job_id.clone(),
702 request_id: Some("req-0".to_string()),
703 hard_kill: false,
704 },
705 };
706 write_message(&mut client, &cancel).await.unwrap();
707 let response = timeout(Duration::from_secs(1), read_message(&mut client))
708 .await
709 .unwrap()
710 .unwrap()
711 .unwrap();
712 match response {
713 RunnerMessage::Response { payload } => {
714 assert_eq!(payload.status, OutcomeStatus::Error);
715 let error_type = payload
716 .error
717 .as_ref()
718 .and_then(|error| error.error_type.as_deref());
719 assert_eq!(error_type, Some("cancelled"));
720 }
721 _ => panic!("expected response"),
722 }
723
724 let mut extra_request = build_request("block");
725 extra_request.request_id = "req-extra".to_string();
726 extra_request.job_id = job_id.clone();
727 extra_request.context.job_id = job_id.clone();
728 write_message(
729 &mut client,
730 &RunnerMessage::Request {
731 payload: extra_request,
732 },
733 )
734 .await
735 .unwrap();
736
737 gate.add_permits(RESPONSE_CHANNEL_CAPACITY + 1);
738
739 let mut saw_extra = false;
740 for _ in 0..RESPONSE_CHANNEL_CAPACITY {
741 let response = timeout(Duration::from_secs(1), read_message(&mut client))
742 .await
743 .unwrap()
744 .unwrap()
745 .unwrap();
746 if let RunnerMessage::Response { payload } = response
747 && payload.request_id.as_deref() == Some("req-extra")
748 {
749 assert_eq!(payload.status, OutcomeStatus::Success);
750 saw_extra = true;
751 }
752 }
753 assert!(saw_extra, "extra request never completed");
754
755 drop(client);
756 let _ = server_task.await;
757 }
758
759 #[tokio::test]
760 async fn execute_with_deadline_times_out() {
761 let mut registry = Registry::new();
762 registry.register("echo", |request| async move {
763 ExecutionOutcome::success(
764 request.job_id.clone(),
765 request.request_id.clone(),
766 json!({"ok": true}),
767 )
768 });
769 let mut request = build_request("echo");
770 request.context.deadline = Some(
771 "2020-01-01T00:00:00Z"
772 .parse::<chrono::DateTime<Utc>>()
773 .unwrap(),
774 );
775 let outcome = execute_with_deadline(request, registry, &NoopTelemetry).await;
776 assert_eq!(outcome.status, OutcomeStatus::Timeout);
777 }
778
779 #[tokio::test]
780 async fn execute_with_deadline_succeeds_before_deadline() {
781 let mut registry = Registry::new();
782 registry.register("echo", |request| async move {
783 ExecutionOutcome::success(
784 request.job_id.clone(),
785 request.request_id.clone(),
786 json!({"ok": true}),
787 )
788 });
789 let mut request = build_request("echo");
790 request.context.deadline = Some(Utc::now() + chrono::Duration::seconds(5));
791 let outcome = execute_with_deadline(request, registry, &NoopTelemetry).await;
792 assert_eq!(outcome.status, OutcomeStatus::Success);
793 }
794
795 #[tokio::test]
796 async fn handle_connection_handles_unexpected_response_message() {
797 let registry = Registry::new();
798 let (client, server) = tcp_pair().await;
799 let in_flight = Arc::new(Mutex::new(HashMap::new()));
800 let job_index = Arc::new(Mutex::new(HashMap::new()));
801 let server_task = tokio::spawn(async move {
802 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
803 .await
804 .unwrap();
805 });
806 let mut client = client;
807 let response = RunnerMessage::Response {
808 payload: ExecutionOutcome::error("job-x", "req-x", "oops"),
809 };
810 write_message(&mut client, &response).await.unwrap();
811 let reply = read_message(&mut client).await.unwrap().unwrap();
812 match reply {
813 RunnerMessage::Response { payload } => {
814 assert_eq!(payload.status, OutcomeStatus::Error);
815 assert!(
816 payload
817 .error
818 .as_ref()
819 .unwrap()
820 .message
821 .contains("unexpected response")
822 );
823 }
824 _ => panic!("expected response"),
825 }
826 drop(client);
827 let _ = server_task.await;
828 }
829
830 #[tokio::test]
831 async fn handle_connection_cancels_by_job_id() {
832 let mut registry = Registry::new();
833 registry.register("sleep", |request| async move {
834 tokio::time::sleep(Duration::from_millis(200)).await;
835 ExecutionOutcome::success(
836 request.job_id.clone(),
837 request.request_id.clone(),
838 json!({"ok": true}),
839 )
840 });
841 let (client, server) = tcp_pair().await;
842 let in_flight = Arc::new(Mutex::new(HashMap::new()));
843 let job_index = Arc::new(Mutex::new(HashMap::new()));
844 let server_task = tokio::spawn(async move {
845 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
846 .await
847 .unwrap();
848 });
849 let mut client = client;
850 let request = build_request("sleep");
851 let message = RunnerMessage::Request {
852 payload: request.clone(),
853 };
854 write_message(&mut client, &message).await.unwrap();
855 let cancel = RunnerMessage::Cancel {
856 payload: CancelRequest {
857 protocol_version: PROTOCOL_VERSION.to_string(),
858 job_id: request.job_id.clone(),
859 request_id: None,
860 hard_kill: false,
861 },
862 };
863 write_message(&mut client, &cancel).await.unwrap();
864 let response = read_message(&mut client).await.unwrap().unwrap();
865 match response {
866 RunnerMessage::Response { payload } => {
867 assert_eq!(payload.status, OutcomeStatus::Error);
868 let error_type = payload
869 .error
870 .as_ref()
871 .and_then(|error| error.error_type.as_deref());
872 assert_eq!(error_type, Some("cancelled"));
873 }
874 _ => panic!("expected response"),
875 }
876 drop(client);
877 let _ = server_task.await;
878 }
879
880 #[tokio::test]
881 async fn handle_cancel_by_job_id_cancels_all_requests() {
882 let mut registry = Registry::new();
883 registry.register("sleep", |request| async move {
884 tokio::time::sleep(Duration::from_millis(200)).await;
885 ExecutionOutcome::success(
886 request.job_id.clone(),
887 request.request_id.clone(),
888 json!({"ok": true}),
889 )
890 });
891 let (client, server) = tcp_pair().await;
892 let in_flight = Arc::new(Mutex::new(HashMap::new()));
893 let job_index = Arc::new(Mutex::new(HashMap::new()));
894 let server_task = tokio::spawn(async move {
895 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
896 .await
897 .unwrap();
898 });
899 let mut client = client;
900 let mut request1 = build_request("sleep");
901 request1.request_id = "req-1".to_string();
902 request1.job_id = "job-shared".to_string();
903 let mut request2 = build_request("sleep");
904 request2.request_id = "req-2".to_string();
905 request2.job_id = "job-shared".to_string();
906 write_message(&mut client, &RunnerMessage::Request { payload: request1 })
907 .await
908 .unwrap();
909 write_message(&mut client, &RunnerMessage::Request { payload: request2 })
910 .await
911 .unwrap();
912 let cancel = RunnerMessage::Cancel {
913 payload: CancelRequest {
914 protocol_version: PROTOCOL_VERSION.to_string(),
915 job_id: "job-shared".to_string(),
916 request_id: None,
917 hard_kill: false,
918 },
919 };
920 write_message(&mut client, &cancel).await.unwrap();
921
922 let mut cancelled = 0;
923 for _ in 0..2 {
924 let response = timeout(Duration::from_millis(200), read_message(&mut client))
925 .await
926 .unwrap()
927 .unwrap()
928 .unwrap();
929 match response {
930 RunnerMessage::Response { payload } => {
931 assert_eq!(payload.status, OutcomeStatus::Error);
932 let error_type = payload
933 .error
934 .as_ref()
935 .and_then(|error| error.error_type.as_deref());
936 assert_eq!(error_type, Some("cancelled"));
937 cancelled += 1;
938 }
939 _ => panic!("expected response"),
940 }
941 }
942 assert_eq!(cancelled, 2);
943 drop(client);
944 let _ = server_task.await;
945 }
946
947 #[tokio::test]
948 async fn connection_teardown_clears_tracking_maps() {
949 let mut registry = Registry::new();
950 registry.register("sleep", |request| async move {
951 tokio::time::sleep(Duration::from_millis(200)).await;
952 ExecutionOutcome::success(
953 request.job_id.clone(),
954 request.request_id.clone(),
955 json!({"ok": true}),
956 )
957 });
958 let (client, server) = tcp_pair().await;
959 let in_flight = Arc::new(Mutex::new(HashMap::new()));
960 let job_index = Arc::new(Mutex::new(HashMap::new()));
961 let in_flight_for_server = in_flight.clone();
962 let job_index_for_server = job_index.clone();
963 let server_task = tokio::spawn(async move {
964 handle_connection(
965 server,
966 ®istry,
967 &NoopTelemetry,
968 in_flight_for_server,
969 job_index_for_server,
970 )
971 .await
972 .unwrap();
973 });
974 let mut client = client;
975 let request = build_request("sleep");
976 let message = RunnerMessage::Request {
977 payload: request.clone(),
978 };
979 write_message(&mut client, &message).await.unwrap();
980
981 let mut inserted = false;
982 for _ in 0..20 {
983 let has_in_flight = {
984 let guard = in_flight.lock().await;
985 guard.contains_key(&request.request_id)
986 };
987 let has_job_index = {
988 let guard = job_index.lock().await;
989 guard.contains_key(&request.job_id)
990 };
991 if has_in_flight && has_job_index {
992 inserted = true;
993 break;
994 }
995 tokio::time::sleep(Duration::from_millis(10)).await;
996 }
997 assert!(inserted, "request never entered tracking maps");
998
999 drop(client);
1000 let _ = server_task.await;
1001
1002 let in_flight = in_flight.lock().await;
1003 let job_index = job_index.lock().await;
1004 assert!(in_flight.is_empty());
1005 assert!(job_index.is_empty());
1006 }
1007
1008 #[tokio::test]
1009 async fn handle_cancel_ignores_invalid_protocol() {
1010 let in_flight = Arc::new(Mutex::new(HashMap::new()));
1011 let job_index = Arc::new(Mutex::new(HashMap::new()));
1012 let (tx, _rx) = mpsc::channel(1);
1013 let handle = tokio::spawn(async {});
1014 let connection_requests = Arc::new(Mutex::new(HashSet::new()));
1015 {
1016 let mut guard = in_flight.lock().await;
1017 guard.insert(
1018 "req-1".to_string(),
1019 InFlightTask {
1020 job_id: "job-1".to_string(),
1021 handle,
1022 response_tx: tx,
1023 connection_requests,
1024 completed: Arc::new(AtomicBool::new(false)),
1025 },
1026 );
1027 }
1028 let payload = CancelRequest {
1029 protocol_version: "0".to_string(),
1030 job_id: "job-1".to_string(),
1031 request_id: None,
1032 hard_kill: false,
1033 };
1034 handle_cancel(payload, &in_flight, &job_index).await;
1035 let guard = in_flight.lock().await;
1036 assert!(guard.contains_key("req-1"));
1037 guard.get("req-1").unwrap().handle.abort();
1038 }
1039
1040 #[tokio::test]
1041 async fn handle_cancel_skips_completed_requests() {
1042 let in_flight = Arc::new(Mutex::new(HashMap::new()));
1043 let job_index = Arc::new(Mutex::new(HashMap::new()));
1044 let (tx, mut rx) = mpsc::channel(1);
1045 let handle = tokio::spawn(async {
1046 tokio::time::sleep(Duration::from_millis(50)).await;
1047 });
1048 let connection_requests = Arc::new(Mutex::new(HashSet::new()));
1049 {
1050 let mut guard = in_flight.lock().await;
1051 guard.insert(
1052 "req-1".to_string(),
1053 InFlightTask {
1054 job_id: "job-1".to_string(),
1055 handle,
1056 response_tx: tx,
1057 connection_requests,
1058 completed: Arc::new(AtomicBool::new(true)),
1059 },
1060 );
1061 }
1062 {
1063 let mut guard = job_index.lock().await;
1064 guard.insert("job-1".to_string(), HashSet::from(["req-1".to_string()]));
1065 }
1066 let payload = CancelRequest {
1067 protocol_version: PROTOCOL_VERSION.to_string(),
1068 job_id: "job-1".to_string(),
1069 request_id: Some("req-1".to_string()),
1070 hard_kill: false,
1071 };
1072 handle_cancel(payload, &in_flight, &job_index).await;
1073 assert!(in_flight.lock().await.contains_key("req-1"));
1074 assert!(job_index.lock().await.contains_key("job-1"));
1075 assert!(rx.try_recv().is_err());
1076 let task = in_flight.lock().await.remove("req-1").unwrap();
1077 task.handle.abort();
1078 }
1079
1080 #[tokio::test]
1081 async fn read_message_handles_empty_and_invalid_payloads() {
1082 let (mut client, mut server) = tokio::io::duplex(64);
1083 client.write_all(&0u32.to_be_bytes()).await.unwrap();
1085 let err = read_message(&mut server).await.unwrap_err();
1086 assert!(err.to_string().contains("payload cannot be empty"));
1087
1088 let (mut client, mut server) = tokio::io::duplex(64);
1090 let payload = b"not-json";
1091 let len = (payload.len() as u32).to_be_bytes();
1092 client.write_all(&len).await.unwrap();
1093 client.write_all(payload).await.unwrap();
1094 let err = read_message(&mut server).await.unwrap_err();
1095 assert!(err.to_string().contains("expected"));
1096
1097 let (mut client, mut server) = tokio::io::duplex(64);
1099 let len = ((MAX_FRAME_LEN + 1) as u32).to_be_bytes();
1100 client.write_all(&len).await.unwrap();
1101 let err = read_message(&mut server).await.unwrap_err();
1102 assert!(err.to_string().contains("payload too large"));
1103 }
1104
1105 #[tokio::test]
1106 async fn read_message_returns_none_on_eof() {
1107 let (client, mut server) = tokio::io::duplex(8);
1108 drop(client);
1109 let message = read_message(&mut server).await.unwrap();
1110 assert!(message.is_none());
1111 }
1112}