1use crate::registry::Registry;
2use crate::telemetry::{NoopTelemetry, Telemetry};
3use crate::types::{ExecutionError, ExecutionOutcome};
4use chrono::{DateTime, Utc};
5use rrq_protocol::{CancelRequest, OutcomeStatus, PROTOCOL_VERSION, RunnerMessage, encode_frame};
6use std::collections::{HashMap, HashSet};
7use std::net::{IpAddr, Ipv4Addr, SocketAddr};
8use std::sync::{
9 Arc,
10 atomic::{AtomicBool, Ordering},
11};
12use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
13use tokio::net::TcpListener;
14use tokio::sync::{Mutex, mpsc};
15use tokio::time::{Duration, timeout};
// Upper bound on a single framed message payload (16 MiB); larger frames are rejected.
const MAX_FRAME_LEN: usize = 16 * 1024 * 1024;
// Bounded capacity of the per-connection outcome channel; also used as the
// per-connection in-flight request cap.
const RESPONSE_CHANNEL_CAPACITY: usize = 64;
// How long a sender waits on the outcome channel before dropping/erroring.
const RESPONSE_SEND_TIMEOUT: Duration = Duration::from_secs(1);
19
/// Build a boxed `InvalidInput` I/O error carrying `message`.
fn invalid_input(message: impl Into<String>) -> Box<dyn std::error::Error> {
    let err = std::io::Error::new(std::io::ErrorKind::InvalidInput, message.into());
    Box::new(err)
}
26
27pub fn parse_tcp_socket(raw: &str) -> Result<SocketAddr, Box<dyn std::error::Error>> {
28 let raw = raw.trim();
29 if raw.is_empty() {
30 return Err(invalid_input("runner tcp_socket cannot be empty"));
31 }
32
33 let (host, port_str) = if let Some(rest) = raw.strip_prefix('[') {
34 let (host, port_str) = rest
35 .split_once("]:")
36 .ok_or_else(|| invalid_input("runner tcp_socket must be in [host]:port format"))?;
37 (host, port_str)
38 } else {
39 let (host, port_str) = raw
40 .rsplit_once(':')
41 .ok_or_else(|| invalid_input("runner tcp_socket must be in host:port format"))?;
42 if host.is_empty() {
43 return Err(invalid_input("runner tcp_socket host cannot be empty"));
44 }
45 (host, port_str)
46 };
47
48 let port: u16 = port_str
49 .parse()
50 .map_err(|_| invalid_input(format!("Invalid runner tcp_socket port: {port_str}")))?;
51 if port == 0 {
52 return Err(invalid_input("runner tcp_socket port must be > 0"));
53 }
54
55 let ip = if host == "localhost" {
56 IpAddr::V4(Ipv4Addr::LOCALHOST)
57 } else {
58 let parsed: IpAddr = host
59 .parse()
60 .map_err(|_| invalid_input(format!("Invalid runner tcp_socket host: {host}")))?;
61 if !parsed.is_loopback() {
62 return Err(invalid_input("runner tcp_socket host must be localhost"));
63 }
64 parsed
65 };
66
67 Ok(SocketAddr::new(ip, port))
68}
69
/// Owns a dedicated Tokio runtime used to drive the runner's TCP accept loop.
pub struct RunnerRuntime {
    runtime: tokio::runtime::Runtime,
}
73
74impl RunnerRuntime {
75 pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
76 Ok(Self {
77 runtime: tokio::runtime::Runtime::new()?,
78 })
79 }
80
81 pub fn enter(&self) -> tokio::runtime::EnterGuard<'_> {
82 self.runtime.enter()
83 }
84
85 pub fn run_tcp(
86 &self,
87 registry: &Registry,
88 addr: SocketAddr,
89 ) -> Result<(), Box<dyn std::error::Error>> {
90 let telemetry = NoopTelemetry;
91 self.run_tcp_with(registry, addr, &telemetry)
92 }
93
94 pub fn run_tcp_with<T: Telemetry + ?Sized>(
95 &self,
96 registry: &Registry,
97 addr: SocketAddr,
98 telemetry: &T,
99 ) -> Result<(), Box<dyn std::error::Error>> {
100 run_tcp_loop(&self.runtime, registry, addr, telemetry)
101 }
102}
103
104pub fn run_tcp(registry: &Registry, addr: SocketAddr) -> Result<(), Box<dyn std::error::Error>> {
105 RunnerRuntime::new()?.run_tcp(registry, addr)
106}
107
108pub fn run_tcp_with<T: Telemetry + ?Sized>(
109 registry: &Registry,
110 addr: SocketAddr,
111 telemetry: &T,
112) -> Result<(), Box<dyn std::error::Error>> {
113 RunnerRuntime::new()?.run_tcp_with(registry, addr, telemetry)
114}
115
116fn run_tcp_loop<T: Telemetry + ?Sized>(
117 runtime: &tokio::runtime::Runtime,
118 registry: &Registry,
119 addr: SocketAddr,
120 telemetry: &T,
121) -> Result<(), Box<dyn std::error::Error>> {
122 let registry = registry.clone();
123 let in_flight: Arc<Mutex<HashMap<String, InFlightTask>>> = Arc::new(Mutex::new(HashMap::new()));
124 let job_index: Arc<Mutex<HashMap<String, HashSet<String>>>> =
125 Arc::new(Mutex::new(HashMap::new()));
126 let telemetry = telemetry.clone_box();
127 runtime.block_on(async move {
128 if !addr.ip().is_loopback() {
129 return Err(invalid_input(format!(
130 "runner tcp_socket must be loopback-only (got {addr})"
131 )));
132 }
133 let listener = TcpListener::bind(addr).await?;
134 loop {
135 let (stream, _) = listener.accept().await?;
136 let registry = registry.clone();
137 let telemetry = telemetry.clone();
138 let in_flight = in_flight.clone();
139 let job_index = job_index.clone();
140 tokio::spawn(async move {
141 if let Err(err) =
142 handle_connection(stream, ®istry, telemetry.as_ref(), in_flight, job_index)
143 .await
144 {
145 tracing::error!("runner connection error: {err}");
146 }
147 });
148 }
149 })
150}
151
/// Serve one runner connection: read framed messages, dispatch requests to the
/// registry, and stream outcomes back through a dedicated writer task.
///
/// Per-connection concurrency is capped at `RESPONSE_CHANNEL_CAPACITY`;
/// requests beyond the cap receive a busy error. When the peer disconnects,
/// every task still running for this connection is aborted and removed from
/// the shared tracking maps.
async fn handle_connection<S, T>(
    stream: S,
    registry: &Registry,
    telemetry: &T,
    in_flight: Arc<Mutex<HashMap<String, InFlightTask>>>,
    job_index: Arc<Mutex<HashMap<String, HashSet<String>>>>,
) -> Result<(), Box<dyn std::error::Error>>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    T: Telemetry + ?Sized,
{
    let (mut reader, mut writer) = tokio::io::split(stream);
    // A single writer task serializes all outcome frames onto the socket;
    // execution tasks only ever touch the channel, never the socket.
    let (response_tx, mut response_rx) =
        mpsc::channel::<ExecutionOutcome>(RESPONSE_CHANNEL_CAPACITY);
    let writer_task = tokio::spawn(async move {
        while let Some(outcome) = response_rx.recv().await {
            let response = RunnerMessage::Response { payload: outcome };
            if write_message(&mut writer, &response).await.is_err() {
                break;
            }
        }
    });
    // Request ids owned by THIS connection (a subset of `in_flight`'s keys);
    // used both for the capacity check and for teardown on disconnect.
    let connection_requests: Arc<Mutex<std::collections::HashSet<String>>> =
        Arc::new(Mutex::new(std::collections::HashSet::new()));

    loop {
        // `None` means the peer closed the connection at a frame boundary.
        let message = match read_message(&mut reader).await? {
            Some(message) => message,
            None => break,
        };
        match message {
            RunnerMessage::Request { payload } => {
                // Reject mismatched protocol versions without executing.
                if payload.protocol_version != PROTOCOL_VERSION {
                    let outcome = ExecutionOutcome::error(
                        payload.job_id.clone(),
                        payload.request_id.clone(),
                        "Unsupported protocol version",
                    );
                    let _ = response_tx.send(outcome).await;
                    continue;
                }

                let request_id = payload.request_id.clone();
                let job_id = payload.job_id.clone();
                {
                    // Capacity check and insertion happen under one lock so
                    // concurrent requests cannot both pass the check.
                    let mut active = connection_requests.lock().await;
                    if active.len() >= RESPONSE_CHANNEL_CAPACITY {
                        let outcome = ExecutionOutcome::error(
                            payload.job_id.clone(),
                            payload.request_id.clone(),
                            "Runner busy: too many in-flight requests",
                        );
                        // Release the lock before awaiting the channel send.
                        drop(active);
                        let send_result =
                            timeout(RESPONSE_SEND_TIMEOUT, response_tx.send(outcome)).await;
                        match send_result {
                            Ok(Ok(())) => {}
                            Ok(Err(_)) => {
                                return Err("runner response channel closed".into());
                            }
                            Err(_) => {
                                return Err("runner response channel stalled".into());
                            }
                        }
                        continue;
                    }
                    active.insert(request_id.clone());
                    crate::telemetry::record_runner_channel_pressure(active.len());
                }
                let response_tx = response_tx.clone();
                let registry = registry.clone();
                let telemetry = telemetry.clone_box();
                let in_flight_for_task = in_flight.clone();
                let job_index_for_task = job_index.clone();
                let active_for_task = connection_requests.clone();
                let request_id_for_task = request_id.clone();
                let job_id_for_task = job_id.clone();
                let response_tx_for_task = response_tx.clone();
                // Flipped just before the outcome is sent so a racing cancel
                // can tell "already finishing" apart from "still running".
                let completed = Arc::new(AtomicBool::new(false));
                let completed_for_task = completed.clone();

                let handle = tokio::spawn(async move {
                    let outcome =
                        execute_with_deadline(payload, registry, telemetry.as_ref()).await;
                    completed_for_task.store(true, Ordering::SeqCst);
                    // Bounded send: a stalled writer must not wedge execution
                    // tasks forever; the outcome is dropped with a warning.
                    let send_result =
                        timeout(RESPONSE_SEND_TIMEOUT, response_tx_for_task.send(outcome)).await;
                    match send_result {
                        Ok(Ok(())) => {}
                        Ok(Err(_)) => {
                            tracing::warn!("runner response channel closed; dropping outcome");
                        }
                        Err(_) => {
                            tracing::warn!("runner response channel stalled; dropping outcome");
                        }
                    }
                    // Self-cleanup from all three tracking structures. Each
                    // lock is taken in its own scope to keep hold times short.
                    {
                        let mut in_flight = in_flight_for_task.lock().await;
                        if in_flight.remove(&request_id_for_task).is_some() {
                            crate::telemetry::record_runner_inflight_delta(-1);
                        }
                    }
                    {
                        let mut job_index = job_index_for_task.lock().await;
                        if let Some(entries) = job_index.get_mut(&job_id_for_task) {
                            entries.remove(&request_id_for_task);
                            if entries.is_empty() {
                                job_index.remove(&job_id_for_task);
                            }
                        }
                    }
                    {
                        let mut active = active_for_task.lock().await;
                        active.remove(&request_id_for_task);
                        crate::telemetry::record_runner_channel_pressure(active.len());
                    }
                });

                // Register the spawned task so cancels can find it.
                {
                    let mut in_flight = in_flight.lock().await;
                    in_flight.insert(
                        request_id.clone(),
                        InFlightTask {
                            job_id: job_id.clone(),
                            handle,
                            response_tx: response_tx.clone(),
                            connection_requests: connection_requests.clone(),
                            completed,
                        },
                    );
                }
                crate::telemetry::record_runner_inflight_delta(1);
                {
                    let mut job_index = job_index.lock().await;
                    job_index
                        .entry(job_id)
                        .or_insert_with(HashSet::new)
                        .insert(request_id);
                }
            }
            RunnerMessage::Cancel { payload } => {
                handle_cancel(payload, &in_flight, &job_index).await;
            }
            RunnerMessage::Response { .. } => {
                // Clients must never send Response frames; answer with an
                // error rather than killing the connection.
                let outcome = ExecutionOutcome {
                    job_id: Some("unknown".to_string()),
                    request_id: None,
                    status: rrq_protocol::OutcomeStatus::Error,
                    result: None,
                    error: Some(ExecutionError {
                        message: "unexpected response message".to_string(),
                        error_type: None,
                        code: None,
                        details: None,
                    }),
                    retry_after_seconds: None,
                };
                let _ = response_tx.send(outcome).await;
            }
        }
    }

    // Teardown: abort everything this connection still has running and scrub
    // the shared maps so other connections see a consistent view.
    let request_ids = {
        let mut active = connection_requests.lock().await;
        active.drain().collect::<Vec<_>>()
    };
    crate::telemetry::record_runner_channel_pressure(0);
    for request_id in request_ids {
        let task = {
            let mut in_flight = in_flight.lock().await;
            in_flight.remove(&request_id)
        };
        if let Some(task) = task {
            task.handle.abort();
            crate::telemetry::record_runner_inflight_delta(-1);
            let mut job_index = job_index.lock().await;
            if let Some(entries) = job_index.get_mut(&task.job_id) {
                entries.remove(&request_id);
                if entries.is_empty() {
                    job_index.remove(&task.job_id);
                }
            }
        }
    }
    writer_task.abort();

    Ok(())
}
340
/// Bookkeeping for one request currently executing on this runner.
struct InFlightTask {
    // Job the request belongs to; used to clean up `job_index` entries.
    job_id: String,
    // Handle to the spawned execution task; aborted on cancel or teardown.
    handle: tokio::task::JoinHandle<()>,
    // Outcome channel of the connection that submitted the request, so a
    // cancel can report "Job cancelled" back to the right client.
    response_tx: mpsc::Sender<ExecutionOutcome>,
    // The owning connection's active-request set; removed on cancel so the
    // connection regains capacity.
    connection_requests: Arc<Mutex<HashSet<String>>>,
    // Set by the task just before it sends its outcome; lets handle_cancel
    // skip requests that have effectively finished.
    completed: Arc<AtomicBool>,
}
348
/// Cancel one request (when `payload.request_id` is set) or every request of a
/// job, aborting the tasks and emitting a "Job cancelled" error outcome to the
/// connection that submitted each request.
///
/// Requests whose `completed` flag is already set are skipped: their real
/// outcome is in flight and they will clean themselves up.
async fn handle_cancel(
    payload: CancelRequest,
    in_flight: &Arc<Mutex<HashMap<String, InFlightTask>>>,
    job_index: &Arc<Mutex<HashMap<String, HashSet<String>>>>,
) {
    // Silently ignore cancels speaking a different protocol version.
    if payload.protocol_version != PROTOCOL_VERSION {
        return;
    }
    // Resolve the set of request ids to cancel: either the one named, or all
    // requests currently indexed under the job.
    let request_ids = if let Some(request_id) = payload.request_id.clone() {
        vec![request_id]
    } else {
        let job_index = job_index.lock().await;
        job_index
            .get(&payload.job_id)
            .map(|ids| ids.iter().cloned().collect())
            .unwrap_or_else(Vec::new)
    };
    if request_ids.is_empty() {
        return;
    }
    let scope = if payload.request_id.is_some() {
        "request"
    } else {
        "job"
    };
    crate::telemetry::record_cancellation(scope);

    for request_id in request_ids {
        let task = {
            let mut in_flight = in_flight.lock().await;
            // Leave already-completed tasks in the map; their own cleanup
            // path removes them after the outcome is sent.
            if let Some(task) = in_flight.get(&request_id)
                && task.completed.load(Ordering::SeqCst)
            {
                None
            } else {
                in_flight.remove(&request_id)
            }
        };
        if let Some(task) = task {
            task.handle.abort();
            crate::telemetry::record_runner_inflight_delta(-1);
            // Free the owning connection's capacity slot immediately so new
            // requests are not rejected while the abort settles.
            {
                let mut active = task.connection_requests.lock().await;
                active.remove(&request_id);
                crate::telemetry::record_runner_channel_pressure(active.len());
            }
            let outcome = ExecutionOutcome {
                job_id: Some(payload.job_id.clone()),
                request_id: Some(request_id.clone()),
                status: OutcomeStatus::Error,
                result: None,
                error: Some(ExecutionError {
                    message: "Job cancelled".to_string(),
                    error_type: Some("cancelled".to_string()),
                    code: None,
                    details: None,
                }),
                retry_after_seconds: None,
            };
            // Bounded send, mirroring the execution path: never block cancel
            // handling indefinitely on a stalled writer.
            let send_result = timeout(RESPONSE_SEND_TIMEOUT, task.response_tx.send(outcome)).await;
            match send_result {
                Ok(Ok(())) => {}
                Ok(Err(_)) => {
                    tracing::warn!("runner response channel closed; dropping cancel outcome");
                }
                Err(_) => {
                    tracing::warn!("runner response channel stalled; dropping cancel outcome");
                }
            }
            let mut job_index = job_index.lock().await;
            if let Some(entries) = job_index.get_mut(&task.job_id) {
                entries.remove(&request_id);
                if entries.is_empty() {
                    job_index.remove(&task.job_id);
                }
            }
        }
    }
}
428
429async fn execute_with_deadline<T: Telemetry + ?Sized>(
430 request: rrq_protocol::ExecutionRequest,
431 registry: Registry,
432 telemetry: &T,
433) -> ExecutionOutcome {
434 let job_id = request.job_id.clone();
435 let request_id = request.request_id.clone();
436 let deadline = request.context.deadline;
437 if let Some(deadline) = deadline {
438 let now: DateTime<Utc> = Utc::now();
439 if deadline <= now {
440 crate::telemetry::record_deadline_expired();
441 return ExecutionOutcome::timeout(
442 job_id.clone(),
443 request_id.clone(),
444 "Job deadline exceeded",
445 );
446 }
447 if let Ok(remaining) = (deadline - now).to_std() {
448 match tokio::time::timeout(remaining, registry.execute_with(request, telemetry)).await {
449 Ok(outcome) => return outcome,
450 Err(_) => {
451 crate::telemetry::record_deadline_expired();
452 return ExecutionOutcome::timeout(
453 job_id.clone(),
454 request_id.clone(),
455 "Job execution timed out",
456 );
457 }
458 }
459 }
460 crate::telemetry::record_deadline_expired();
461 return ExecutionOutcome::timeout(job_id, request_id, "Job deadline exceeded");
462 }
463 registry.execute_with(request, telemetry).await
464}
465
466async fn read_message<R: AsyncRead + Unpin>(
467 stream: &mut R,
468) -> Result<Option<RunnerMessage>, Box<dyn std::error::Error>> {
469 let mut header = [0u8; 4];
470 match stream.read_exact(&mut header).await {
471 Ok(_) => {}
472 Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
473 Err(err) => return Err(err.into()),
474 }
475 let length = u32::from_be_bytes(header) as usize;
476 if length == 0 {
477 return Err("runner message payload cannot be empty".into());
478 }
479 if length > MAX_FRAME_LEN {
480 return Err("runner message payload too large".into());
481 }
482 let mut payload = vec![0u8; length];
483 stream.read_exact(&mut payload).await?;
484 let message = serde_json::from_slice(&payload)?;
485 Ok(Some(message))
486}
487
488async fn write_message<W: AsyncWrite + Unpin>(
489 stream: &mut W,
490 message: &RunnerMessage,
491) -> Result<(), Box<dyn std::error::Error>> {
492 let framed = encode_frame(message)?;
493 stream.write_all(&framed).await?;
494 stream.flush().await?;
495 Ok(())
496}
497
498#[cfg(test)]
499mod tests {
500 use super::*;
501 use crate::registry::Registry;
502 use crate::telemetry::NoopTelemetry;
503 use chrono::Utc;
504 use rrq_protocol::{ExecutionContext, ExecutionRequest, OutcomeStatus};
505 use serde_json::json;
506 use tokio::net::{TcpListener, TcpStream};
507 use tokio::time::{Duration, timeout};
508
509 fn build_request(function_name: &str) -> ExecutionRequest {
510 ExecutionRequest {
511 protocol_version: PROTOCOL_VERSION.to_string(),
512 request_id: "req-1".to_string(),
513 job_id: "job-1".to_string(),
514 function_name: function_name.to_string(),
515 params: std::collections::HashMap::new(),
516 context: ExecutionContext {
517 job_id: "job-1".to_string(),
518 attempt: 1,
519 enqueue_time: "2024-01-01T00:00:00Z".parse().unwrap(),
520 queue_name: "default".to_string(),
521 deadline: None,
522 trace_context: None,
523 correlation_context: None,
524 worker_id: None,
525 },
526 }
527 }
528
529 async fn tcp_pair() -> (TcpStream, TcpStream) {
530 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
531 let addr = listener.local_addr().unwrap();
532 let client = TcpStream::connect(addr).await.unwrap();
533 let (server, _) = listener.accept().await.unwrap();
534 (client, server)
535 }
536
537 #[tokio::test]
538 async fn handle_connection_executes_request() {
539 let mut registry = Registry::new();
540 registry.register("echo", |request| async move {
541 ExecutionOutcome::success(
542 request.job_id.clone(),
543 request.request_id.clone(),
544 json!({"ok": true}),
545 )
546 });
547 let (client, server) = tcp_pair().await;
548 let in_flight = Arc::new(Mutex::new(HashMap::new()));
549 let job_index = Arc::new(Mutex::new(HashMap::new()));
550 let server_task = tokio::spawn(async move {
551 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
552 .await
553 .unwrap();
554 });
555 let mut client = client;
556 let request = build_request("echo");
557 let message = RunnerMessage::Request { payload: request };
558 write_message(&mut client, &message).await.unwrap();
559 let response = read_message(&mut client).await.unwrap().unwrap();
560 match response {
561 RunnerMessage::Response { payload } => {
562 assert_eq!(payload.status, OutcomeStatus::Success);
563 assert_eq!(payload.result, Some(json!({"ok": true})));
564 }
565 _ => panic!("expected response"),
566 }
567 drop(client);
568 let _ = server_task.await;
569 }
570
571 #[tokio::test]
572 async fn handle_connection_rejects_bad_protocol() {
573 let registry = Registry::new();
574 let (client, server) = tcp_pair().await;
575 let in_flight = Arc::new(Mutex::new(HashMap::new()));
576 let job_index = Arc::new(Mutex::new(HashMap::new()));
577 let server_task = tokio::spawn(async move {
578 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
579 .await
580 .unwrap();
581 });
582 let mut client = client;
583 let mut request = build_request("echo");
584 request.protocol_version = "0".to_string();
585 let message = RunnerMessage::Request { payload: request };
586 write_message(&mut client, &message).await.unwrap();
587 let response = read_message(&mut client).await.unwrap().unwrap();
588 match response {
589 RunnerMessage::Response { payload } => {
590 assert_eq!(payload.status, OutcomeStatus::Error);
591 }
592 _ => panic!("expected response"),
593 }
594 drop(client);
595 let _ = server_task.await;
596 }
597
598 #[tokio::test]
599 async fn handle_connection_cancels_inflight() {
600 let mut registry = Registry::new();
601 registry.register("sleep", |request| async move {
602 tokio::time::sleep(Duration::from_millis(200)).await;
603 ExecutionOutcome::success(
604 request.job_id.clone(),
605 request.request_id.clone(),
606 json!({"ok": true}),
607 )
608 });
609 let (client, server) = tcp_pair().await;
610 let in_flight = Arc::new(Mutex::new(HashMap::new()));
611 let job_index = Arc::new(Mutex::new(HashMap::new()));
612 let server_task = tokio::spawn(async move {
613 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
614 .await
615 .unwrap();
616 });
617 let mut client = client;
618 let request = ExecutionRequest {
619 protocol_version: PROTOCOL_VERSION.to_string(),
620 request_id: "req-cancel".to_string(),
621 job_id: "job-cancel".to_string(),
622 function_name: "sleep".to_string(),
623 params: std::collections::HashMap::new(),
624 context: ExecutionContext {
625 job_id: "job-cancel".to_string(),
626 attempt: 1,
627 enqueue_time: "2024-01-01T00:00:00Z".parse().unwrap(),
628 queue_name: "default".to_string(),
629 deadline: None,
630 trace_context: None,
631 correlation_context: None,
632 worker_id: None,
633 },
634 };
635 let message = RunnerMessage::Request {
636 payload: request.clone(),
637 };
638 write_message(&mut client, &message).await.unwrap();
639 let cancel = RunnerMessage::Cancel {
640 payload: CancelRequest {
641 protocol_version: PROTOCOL_VERSION.to_string(),
642 job_id: request.job_id.clone(),
643 request_id: Some(request.request_id.clone()),
644 hard_kill: false,
645 },
646 };
647 write_message(&mut client, &cancel).await.unwrap();
648 let response = read_message(&mut client).await.unwrap().unwrap();
649 match response {
650 RunnerMessage::Response { payload } => {
651 assert_eq!(payload.status, OutcomeStatus::Error);
652 let error_type = payload
653 .error
654 .as_ref()
655 .and_then(|error| error.error_type.as_deref());
656 assert_eq!(error_type, Some("cancelled"));
657 }
658 _ => panic!("expected response"),
659 }
660 drop(client);
661 let _ = server_task.await;
662 }
663
664 #[tokio::test]
665 async fn cancel_frees_connection_capacity() {
666 let mut registry = Registry::new();
667 let gate = Arc::new(tokio::sync::Semaphore::new(0));
668 let gate_for_handler = gate.clone();
669 registry.register("block", move |request| {
670 let gate = gate_for_handler.clone();
671 async move {
672 let _permit = gate.acquire().await.expect("semaphore closed");
673 ExecutionOutcome::success(
674 request.job_id.clone(),
675 request.request_id.clone(),
676 json!({"ok": true}),
677 )
678 }
679 });
680 let (client, server) = tcp_pair().await;
681 let in_flight = Arc::new(Mutex::new(HashMap::new()));
682 let job_index = Arc::new(Mutex::new(HashMap::new()));
683 let server_task = tokio::spawn(async move {
684 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
685 .await
686 .unwrap();
687 });
688 let mut client = client;
689 let job_id = "job-capacity".to_string();
690 for i in 0..RESPONSE_CHANNEL_CAPACITY {
691 let mut request = build_request("block");
692 request.request_id = format!("req-{i}");
693 request.job_id = job_id.clone();
694 request.context.job_id = job_id.clone();
695 write_message(&mut client, &RunnerMessage::Request { payload: request })
696 .await
697 .unwrap();
698 }
699
700 let cancel = RunnerMessage::Cancel {
701 payload: CancelRequest {
702 protocol_version: PROTOCOL_VERSION.to_string(),
703 job_id: job_id.clone(),
704 request_id: Some("req-0".to_string()),
705 hard_kill: false,
706 },
707 };
708 write_message(&mut client, &cancel).await.unwrap();
709 let response = timeout(Duration::from_secs(1), read_message(&mut client))
710 .await
711 .unwrap()
712 .unwrap()
713 .unwrap();
714 match response {
715 RunnerMessage::Response { payload } => {
716 assert_eq!(payload.status, OutcomeStatus::Error);
717 let error_type = payload
718 .error
719 .as_ref()
720 .and_then(|error| error.error_type.as_deref());
721 assert_eq!(error_type, Some("cancelled"));
722 }
723 _ => panic!("expected response"),
724 }
725
726 let mut extra_request = build_request("block");
727 extra_request.request_id = "req-extra".to_string();
728 extra_request.job_id = job_id.clone();
729 extra_request.context.job_id = job_id.clone();
730 write_message(
731 &mut client,
732 &RunnerMessage::Request {
733 payload: extra_request,
734 },
735 )
736 .await
737 .unwrap();
738
739 gate.add_permits(RESPONSE_CHANNEL_CAPACITY + 1);
740
741 let mut saw_extra = false;
742 for _ in 0..RESPONSE_CHANNEL_CAPACITY {
743 let response = timeout(Duration::from_secs(1), read_message(&mut client))
744 .await
745 .unwrap()
746 .unwrap()
747 .unwrap();
748 if let RunnerMessage::Response { payload } = response
749 && payload.request_id.as_deref() == Some("req-extra")
750 {
751 assert_eq!(payload.status, OutcomeStatus::Success);
752 saw_extra = true;
753 }
754 }
755 assert!(saw_extra, "extra request never completed");
756
757 drop(client);
758 let _ = server_task.await;
759 }
760
761 #[tokio::test]
762 async fn execute_with_deadline_times_out() {
763 let mut registry = Registry::new();
764 registry.register("echo", |request| async move {
765 ExecutionOutcome::success(
766 request.job_id.clone(),
767 request.request_id.clone(),
768 json!({"ok": true}),
769 )
770 });
771 let mut request = build_request("echo");
772 request.context.deadline = Some(
773 "2020-01-01T00:00:00Z"
774 .parse::<chrono::DateTime<Utc>>()
775 .unwrap(),
776 );
777 let outcome = execute_with_deadline(request, registry, &NoopTelemetry).await;
778 assert_eq!(outcome.status, OutcomeStatus::Timeout);
779 }
780
781 #[tokio::test]
782 async fn execute_with_deadline_succeeds_before_deadline() {
783 let mut registry = Registry::new();
784 registry.register("echo", |request| async move {
785 ExecutionOutcome::success(
786 request.job_id.clone(),
787 request.request_id.clone(),
788 json!({"ok": true}),
789 )
790 });
791 let mut request = build_request("echo");
792 request.context.deadline = Some(Utc::now() + chrono::Duration::seconds(5));
793 let outcome = execute_with_deadline(request, registry, &NoopTelemetry).await;
794 assert_eq!(outcome.status, OutcomeStatus::Success);
795 }
796
797 #[tokio::test]
798 async fn handle_connection_handles_unexpected_response_message() {
799 let registry = Registry::new();
800 let (client, server) = tcp_pair().await;
801 let in_flight = Arc::new(Mutex::new(HashMap::new()));
802 let job_index = Arc::new(Mutex::new(HashMap::new()));
803 let server_task = tokio::spawn(async move {
804 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
805 .await
806 .unwrap();
807 });
808 let mut client = client;
809 let response = RunnerMessage::Response {
810 payload: ExecutionOutcome::error("job-x", "req-x", "oops"),
811 };
812 write_message(&mut client, &response).await.unwrap();
813 let reply = read_message(&mut client).await.unwrap().unwrap();
814 match reply {
815 RunnerMessage::Response { payload } => {
816 assert_eq!(payload.status, OutcomeStatus::Error);
817 assert!(
818 payload
819 .error
820 .as_ref()
821 .unwrap()
822 .message
823 .contains("unexpected response")
824 );
825 }
826 _ => panic!("expected response"),
827 }
828 drop(client);
829 let _ = server_task.await;
830 }
831
832 #[tokio::test]
833 async fn handle_connection_cancels_by_job_id() {
834 let mut registry = Registry::new();
835 registry.register("sleep", |request| async move {
836 tokio::time::sleep(Duration::from_millis(200)).await;
837 ExecutionOutcome::success(
838 request.job_id.clone(),
839 request.request_id.clone(),
840 json!({"ok": true}),
841 )
842 });
843 let (client, server) = tcp_pair().await;
844 let in_flight = Arc::new(Mutex::new(HashMap::new()));
845 let job_index = Arc::new(Mutex::new(HashMap::new()));
846 let server_task = tokio::spawn(async move {
847 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
848 .await
849 .unwrap();
850 });
851 let mut client = client;
852 let request = build_request("sleep");
853 let message = RunnerMessage::Request {
854 payload: request.clone(),
855 };
856 write_message(&mut client, &message).await.unwrap();
857 let cancel = RunnerMessage::Cancel {
858 payload: CancelRequest {
859 protocol_version: PROTOCOL_VERSION.to_string(),
860 job_id: request.job_id.clone(),
861 request_id: None,
862 hard_kill: false,
863 },
864 };
865 write_message(&mut client, &cancel).await.unwrap();
866 let response = read_message(&mut client).await.unwrap().unwrap();
867 match response {
868 RunnerMessage::Response { payload } => {
869 assert_eq!(payload.status, OutcomeStatus::Error);
870 let error_type = payload
871 .error
872 .as_ref()
873 .and_then(|error| error.error_type.as_deref());
874 assert_eq!(error_type, Some("cancelled"));
875 }
876 _ => panic!("expected response"),
877 }
878 drop(client);
879 let _ = server_task.await;
880 }
881
882 #[tokio::test]
883 async fn handle_cancel_by_job_id_cancels_all_requests() {
884 let mut registry = Registry::new();
885 registry.register("sleep", |request| async move {
886 tokio::time::sleep(Duration::from_millis(200)).await;
887 ExecutionOutcome::success(
888 request.job_id.clone(),
889 request.request_id.clone(),
890 json!({"ok": true}),
891 )
892 });
893 let (client, server) = tcp_pair().await;
894 let in_flight = Arc::new(Mutex::new(HashMap::new()));
895 let job_index = Arc::new(Mutex::new(HashMap::new()));
896 let server_task = tokio::spawn(async move {
897 handle_connection(server, ®istry, &NoopTelemetry, in_flight, job_index)
898 .await
899 .unwrap();
900 });
901 let mut client = client;
902 let mut request1 = build_request("sleep");
903 request1.request_id = "req-1".to_string();
904 request1.job_id = "job-shared".to_string();
905 let mut request2 = build_request("sleep");
906 request2.request_id = "req-2".to_string();
907 request2.job_id = "job-shared".to_string();
908 write_message(&mut client, &RunnerMessage::Request { payload: request1 })
909 .await
910 .unwrap();
911 write_message(&mut client, &RunnerMessage::Request { payload: request2 })
912 .await
913 .unwrap();
914 let cancel = RunnerMessage::Cancel {
915 payload: CancelRequest {
916 protocol_version: PROTOCOL_VERSION.to_string(),
917 job_id: "job-shared".to_string(),
918 request_id: None,
919 hard_kill: false,
920 },
921 };
922 write_message(&mut client, &cancel).await.unwrap();
923
924 let mut cancelled = 0;
925 for _ in 0..2 {
926 let response = timeout(Duration::from_millis(200), read_message(&mut client))
927 .await
928 .unwrap()
929 .unwrap()
930 .unwrap();
931 match response {
932 RunnerMessage::Response { payload } => {
933 assert_eq!(payload.status, OutcomeStatus::Error);
934 let error_type = payload
935 .error
936 .as_ref()
937 .and_then(|error| error.error_type.as_deref());
938 assert_eq!(error_type, Some("cancelled"));
939 cancelled += 1;
940 }
941 _ => panic!("expected response"),
942 }
943 }
944 assert_eq!(cancelled, 2);
945 drop(client);
946 let _ = server_task.await;
947 }
948
949 #[tokio::test]
950 async fn connection_teardown_clears_tracking_maps() {
951 let mut registry = Registry::new();
952 registry.register("sleep", |request| async move {
953 tokio::time::sleep(Duration::from_millis(200)).await;
954 ExecutionOutcome::success(
955 request.job_id.clone(),
956 request.request_id.clone(),
957 json!({"ok": true}),
958 )
959 });
960 let (client, server) = tcp_pair().await;
961 let in_flight = Arc::new(Mutex::new(HashMap::new()));
962 let job_index = Arc::new(Mutex::new(HashMap::new()));
963 let in_flight_for_server = in_flight.clone();
964 let job_index_for_server = job_index.clone();
965 let server_task = tokio::spawn(async move {
966 handle_connection(
967 server,
968 ®istry,
969 &NoopTelemetry,
970 in_flight_for_server,
971 job_index_for_server,
972 )
973 .await
974 .unwrap();
975 });
976 let mut client = client;
977 let request = build_request("sleep");
978 let message = RunnerMessage::Request {
979 payload: request.clone(),
980 };
981 write_message(&mut client, &message).await.unwrap();
982
983 let mut inserted = false;
984 for _ in 0..20 {
985 let has_in_flight = {
986 let guard = in_flight.lock().await;
987 guard.contains_key(&request.request_id)
988 };
989 let has_job_index = {
990 let guard = job_index.lock().await;
991 guard.contains_key(&request.job_id)
992 };
993 if has_in_flight && has_job_index {
994 inserted = true;
995 break;
996 }
997 tokio::time::sleep(Duration::from_millis(10)).await;
998 }
999 assert!(inserted, "request never entered tracking maps");
1000
1001 drop(client);
1002 let _ = server_task.await;
1003
1004 let in_flight = in_flight.lock().await;
1005 let job_index = job_index.lock().await;
1006 assert!(in_flight.is_empty());
1007 assert!(job_index.is_empty());
1008 }
1009
1010 #[tokio::test]
1011 async fn handle_cancel_ignores_invalid_protocol() {
1012 let in_flight = Arc::new(Mutex::new(HashMap::new()));
1013 let job_index = Arc::new(Mutex::new(HashMap::new()));
1014 let (tx, _rx) = mpsc::channel(1);
1015 let handle = tokio::spawn(async {});
1016 let connection_requests = Arc::new(Mutex::new(HashSet::new()));
1017 {
1018 let mut guard = in_flight.lock().await;
1019 guard.insert(
1020 "req-1".to_string(),
1021 InFlightTask {
1022 job_id: "job-1".to_string(),
1023 handle,
1024 response_tx: tx,
1025 connection_requests,
1026 completed: Arc::new(AtomicBool::new(false)),
1027 },
1028 );
1029 }
1030 let payload = CancelRequest {
1031 protocol_version: "0".to_string(),
1032 job_id: "job-1".to_string(),
1033 request_id: None,
1034 hard_kill: false,
1035 };
1036 handle_cancel(payload, &in_flight, &job_index).await;
1037 let guard = in_flight.lock().await;
1038 assert!(guard.contains_key("req-1"));
1039 guard.get("req-1").unwrap().handle.abort();
1040 }
1041
1042 #[tokio::test]
1043 async fn handle_cancel_skips_completed_requests() {
1044 let in_flight = Arc::new(Mutex::new(HashMap::new()));
1045 let job_index = Arc::new(Mutex::new(HashMap::new()));
1046 let (tx, mut rx) = mpsc::channel(1);
1047 let handle = tokio::spawn(async {
1048 tokio::time::sleep(Duration::from_millis(50)).await;
1049 });
1050 let connection_requests = Arc::new(Mutex::new(HashSet::new()));
1051 {
1052 let mut guard = in_flight.lock().await;
1053 guard.insert(
1054 "req-1".to_string(),
1055 InFlightTask {
1056 job_id: "job-1".to_string(),
1057 handle,
1058 response_tx: tx,
1059 connection_requests,
1060 completed: Arc::new(AtomicBool::new(true)),
1061 },
1062 );
1063 }
1064 {
1065 let mut guard = job_index.lock().await;
1066 guard.insert("job-1".to_string(), HashSet::from(["req-1".to_string()]));
1067 }
1068 let payload = CancelRequest {
1069 protocol_version: PROTOCOL_VERSION.to_string(),
1070 job_id: "job-1".to_string(),
1071 request_id: Some("req-1".to_string()),
1072 hard_kill: false,
1073 };
1074 handle_cancel(payload, &in_flight, &job_index).await;
1075 assert!(in_flight.lock().await.contains_key("req-1"));
1076 assert!(job_index.lock().await.contains_key("job-1"));
1077 assert!(rx.try_recv().is_err());
1078 let task = in_flight.lock().await.remove("req-1").unwrap();
1079 task.handle.abort();
1080 }
1081
1082 #[tokio::test]
1083 async fn read_message_handles_empty_and_invalid_payloads() {
1084 let (mut client, mut server) = tokio::io::duplex(64);
1085 client.write_all(&0u32.to_be_bytes()).await.unwrap();
1087 let err = read_message(&mut server).await.unwrap_err();
1088 assert!(err.to_string().contains("payload cannot be empty"));
1089
1090 let (mut client, mut server) = tokio::io::duplex(64);
1092 let payload = b"not-json";
1093 let len = (payload.len() as u32).to_be_bytes();
1094 client.write_all(&len).await.unwrap();
1095 client.write_all(payload).await.unwrap();
1096 let err = read_message(&mut server).await.unwrap_err();
1097 assert!(err.to_string().contains("expected"));
1098
1099 let (mut client, mut server) = tokio::io::duplex(64);
1101 let len = ((MAX_FRAME_LEN + 1) as u32).to_be_bytes();
1102 client.write_all(&len).await.unwrap();
1103 let err = read_message(&mut server).await.unwrap_err();
1104 assert!(err.to_string().contains("payload too large"));
1105 }
1106
1107 #[tokio::test]
1108 async fn read_message_returns_none_on_eof() {
1109 let (client, mut server) = tokio::io::duplex(8);
1110 drop(client);
1111 let message = read_message(&mut server).await.unwrap();
1112 assert!(message.is_none());
1113 }
1114}