Skip to main content

rrq_runner/
runtime.rs

1use crate::registry::Registry;
2use crate::telemetry::{NoopTelemetry, Telemetry};
3use crate::types::{ExecutionError, ExecutionOutcome};
4use chrono::{DateTime, Utc};
5use rrq_protocol::{CancelRequest, OutcomeStatus, PROTOCOL_VERSION, RunnerMessage, encode_frame};
6use std::collections::{HashMap, HashSet};
7use std::net::{IpAddr, Ipv4Addr, SocketAddr};
8use std::sync::{
9    Arc,
10    atomic::{AtomicBool, Ordering},
11};
12use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
13use tokio::net::TcpListener;
14use tokio::sync::{Mutex, mpsc};
15use tokio::time::{Duration, timeout};
16
/// Name of the `RRQ_RUNNER_TCP_SOCKET` setting.
/// NOTE(review): presumably the environment variable that carries the runner's listen
/// address; nothing in this file reads it — confirm against the caller.
pub const ENV_RUNNER_TCP_SOCKET: &str = "RRQ_RUNNER_TCP_SOCKET";
/// Largest frame payload, in bytes, that `read_message` accepts (16 MiB).
const MAX_FRAME_LEN: usize = 16 * 1024 * 1024;
/// Capacity of each connection's outcome channel; also used as the cap on how many requests
/// a single connection may have in flight at once.
const RESPONSE_CHANNEL_CAPACITY: usize = 64;
/// How long a finished or cancelled request waits for room in the outcome channel before its
/// outcome is dropped with a warning.
const RESPONSE_SEND_TIMEOUT: Duration = Duration::from_secs(1);
21
/// Wraps `message` in an `InvalidInput` I/O error and boxes it as a generic error.
fn invalid_input(message: impl Into<String>) -> Box<dyn std::error::Error> {
    let text: String = message.into();
    std::io::Error::new(std::io::ErrorKind::InvalidInput, text).into()
}
28
29pub fn parse_tcp_socket(raw: &str) -> Result<SocketAddr, Box<dyn std::error::Error>> {
30    let raw = raw.trim();
31    if raw.is_empty() {
32        return Err(invalid_input("runner tcp_socket cannot be empty"));
33    }
34
35    let (host, port_str) = if let Some(rest) = raw.strip_prefix('[') {
36        let (host, port_str) = rest
37            .split_once("]:")
38            .ok_or_else(|| invalid_input("runner tcp_socket must be in [host]:port format"))?;
39        (host, port_str)
40    } else {
41        let (host, port_str) = raw
42            .rsplit_once(':')
43            .ok_or_else(|| invalid_input("runner tcp_socket must be in host:port format"))?;
44        if host.is_empty() {
45            return Err(invalid_input("runner tcp_socket host cannot be empty"));
46        }
47        (host, port_str)
48    };
49
50    let port: u16 = port_str
51        .parse()
52        .map_err(|_| invalid_input(format!("Invalid runner tcp_socket port: {port_str}")))?;
53    if port == 0 {
54        return Err(invalid_input("runner tcp_socket port must be > 0"));
55    }
56
57    let ip = if host == "localhost" {
58        IpAddr::V4(Ipv4Addr::LOCALHOST)
59    } else {
60        let parsed: IpAddr = host
61            .parse()
62            .map_err(|_| invalid_input(format!("Invalid runner tcp_socket host: {host}")))?;
63        if !parsed.is_loopback() {
64            return Err(invalid_input("runner tcp_socket host must be localhost"));
65        }
66        parsed
67    };
68
69    Ok(SocketAddr::new(ip, port))
70}
71
72pub struct RunnerRuntime {
73    runtime: tokio::runtime::Runtime,
74}
75
76impl RunnerRuntime {
77    pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
78        Ok(Self {
79            runtime: tokio::runtime::Runtime::new()?,
80        })
81    }
82
83    pub fn enter(&self) -> tokio::runtime::EnterGuard<'_> {
84        self.runtime.enter()
85    }
86
87    pub fn run_tcp(
88        &self,
89        registry: &Registry,
90        addr: SocketAddr,
91    ) -> Result<(), Box<dyn std::error::Error>> {
92        let telemetry = NoopTelemetry;
93        self.run_tcp_with(registry, addr, &telemetry)
94    }
95
96    pub fn run_tcp_with<T: Telemetry + ?Sized>(
97        &self,
98        registry: &Registry,
99        addr: SocketAddr,
100        telemetry: &T,
101    ) -> Result<(), Box<dyn std::error::Error>> {
102        run_tcp_loop(&self.runtime, registry, addr, telemetry)
103    }
104}
105
106pub fn run_tcp(registry: &Registry, addr: SocketAddr) -> Result<(), Box<dyn std::error::Error>> {
107    RunnerRuntime::new()?.run_tcp(registry, addr)
108}
109
110pub fn run_tcp_with<T: Telemetry + ?Sized>(
111    registry: &Registry,
112    addr: SocketAddr,
113    telemetry: &T,
114) -> Result<(), Box<dyn std::error::Error>> {
115    RunnerRuntime::new()?.run_tcp_with(registry, addr, telemetry)
116}
117
118fn run_tcp_loop<T: Telemetry + ?Sized>(
119    runtime: &tokio::runtime::Runtime,
120    registry: &Registry,
121    addr: SocketAddr,
122    telemetry: &T,
123) -> Result<(), Box<dyn std::error::Error>> {
124    let registry = registry.clone();
125    let in_flight: Arc<Mutex<HashMap<String, InFlightTask>>> = Arc::new(Mutex::new(HashMap::new()));
126    let job_index: Arc<Mutex<HashMap<String, HashSet<String>>>> =
127        Arc::new(Mutex::new(HashMap::new()));
128    let telemetry = telemetry.clone_box();
129    runtime.block_on(async move {
130        if !addr.ip().is_loopback() {
131            return Err(invalid_input(format!(
132                "runner tcp_socket must be loopback-only (got {addr})"
133            )));
134        }
135        let listener = TcpListener::bind(addr).await?;
136        loop {
137            let (stream, _) = listener.accept().await?;
138            let registry = registry.clone();
139            let telemetry = telemetry.clone();
140            let in_flight = in_flight.clone();
141            let job_index = job_index.clone();
142            tokio::spawn(async move {
143                if let Err(err) =
144                    handle_connection(stream, &registry, telemetry.as_ref(), in_flight, job_index)
145                        .await
146                {
147                    tracing::error!("runner connection error: {err}");
148                }
149            });
150        }
151    })
152}
153
154async fn handle_connection<S, T>(
155    stream: S,
156    registry: &Registry,
157    telemetry: &T,
158    in_flight: Arc<Mutex<HashMap<String, InFlightTask>>>,
159    job_index: Arc<Mutex<HashMap<String, HashSet<String>>>>,
160) -> Result<(), Box<dyn std::error::Error>>
161where
162    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
163    T: Telemetry + ?Sized,
164{
165    let (mut reader, mut writer) = tokio::io::split(stream);
166    let (response_tx, mut response_rx) =
167        mpsc::channel::<ExecutionOutcome>(RESPONSE_CHANNEL_CAPACITY);
168    let writer_task = tokio::spawn(async move {
169        while let Some(outcome) = response_rx.recv().await {
170            let response = RunnerMessage::Response { payload: outcome };
171            if write_message(&mut writer, &response).await.is_err() {
172                break;
173            }
174        }
175    });
176    let connection_requests: Arc<Mutex<std::collections::HashSet<String>>> =
177        Arc::new(Mutex::new(std::collections::HashSet::new()));
178
179    loop {
180        let message = match read_message(&mut reader).await? {
181            Some(message) => message,
182            None => break,
183        };
184        match message {
185            RunnerMessage::Request { payload } => {
186                if payload.protocol_version != PROTOCOL_VERSION {
187                    let outcome = ExecutionOutcome::error(
188                        payload.job_id.clone(),
189                        payload.request_id.clone(),
190                        "Unsupported protocol version",
191                    );
192                    let _ = response_tx.send(outcome).await;
193                    continue;
194                }
195
196                let request_id = payload.request_id.clone();
197                let job_id = payload.job_id.clone();
198                {
199                    let mut active = connection_requests.lock().await;
200                    if active.len() >= RESPONSE_CHANNEL_CAPACITY {
201                        let outcome = ExecutionOutcome::error(
202                            payload.job_id.clone(),
203                            payload.request_id.clone(),
204                            "Runner busy: too many in-flight requests",
205                        );
206                        drop(active);
207                        let _ = response_tx.try_send(outcome);
208                        continue;
209                    }
210                    active.insert(request_id.clone());
211                }
212                let response_tx = response_tx.clone();
213                let registry = registry.clone();
214                let telemetry = telemetry.clone_box();
215                let in_flight_for_task = in_flight.clone();
216                let job_index_for_task = job_index.clone();
217                let active_for_task = connection_requests.clone();
218                let request_id_for_task = request_id.clone();
219                let job_id_for_task = job_id.clone();
220                let response_tx_for_task = response_tx.clone();
221                let completed = Arc::new(AtomicBool::new(false));
222                let completed_for_task = completed.clone();
223
224                let handle = tokio::spawn(async move {
225                    let outcome =
226                        execute_with_deadline(payload, registry, telemetry.as_ref()).await;
227                    completed_for_task.store(true, Ordering::SeqCst);
228                    let send_result =
229                        timeout(RESPONSE_SEND_TIMEOUT, response_tx_for_task.send(outcome)).await;
230                    match send_result {
231                        Ok(Ok(())) => {}
232                        Ok(Err(_)) => {
233                            tracing::warn!("runner response channel closed; dropping outcome");
234                        }
235                        Err(_) => {
236                            tracing::warn!("runner response channel stalled; dropping outcome");
237                        }
238                    }
239                    {
240                        let mut in_flight = in_flight_for_task.lock().await;
241                        in_flight.remove(&request_id_for_task);
242                    }
243                    {
244                        let mut job_index = job_index_for_task.lock().await;
245                        if let Some(entries) = job_index.get_mut(&job_id_for_task) {
246                            entries.remove(&request_id_for_task);
247                            if entries.is_empty() {
248                                job_index.remove(&job_id_for_task);
249                            }
250                        }
251                    }
252                    {
253                        let mut active = active_for_task.lock().await;
254                        active.remove(&request_id_for_task);
255                    }
256                });
257
258                {
259                    let mut in_flight = in_flight.lock().await;
260                    in_flight.insert(
261                        request_id.clone(),
262                        InFlightTask {
263                            job_id: job_id.clone(),
264                            handle,
265                            response_tx: response_tx.clone(),
266                            connection_requests: connection_requests.clone(),
267                            completed,
268                        },
269                    );
270                }
271                {
272                    let mut job_index = job_index.lock().await;
273                    job_index
274                        .entry(job_id)
275                        .or_insert_with(HashSet::new)
276                        .insert(request_id);
277                }
278            }
279            RunnerMessage::Cancel { payload } => {
280                handle_cancel(payload, &in_flight, &job_index).await;
281            }
282            RunnerMessage::Response { .. } => {
283                let outcome = ExecutionOutcome {
284                    job_id: Some("unknown".to_string()),
285                    request_id: None,
286                    status: rrq_protocol::OutcomeStatus::Error,
287                    result: None,
288                    error: Some(ExecutionError {
289                        message: "unexpected response message".to_string(),
290                        error_type: None,
291                        code: None,
292                        details: None,
293                    }),
294                    retry_after_seconds: None,
295                };
296                let _ = response_tx.send(outcome).await;
297            }
298        }
299    }
300
301    let request_ids = {
302        let mut active = connection_requests.lock().await;
303        active.drain().collect::<Vec<_>>()
304    };
305    for request_id in request_ids {
306        let task = {
307            let mut in_flight = in_flight.lock().await;
308            in_flight.remove(&request_id)
309        };
310        if let Some(task) = task {
311            task.handle.abort();
312            let mut job_index = job_index.lock().await;
313            if let Some(entries) = job_index.get_mut(&task.job_id) {
314                entries.remove(&request_id);
315                if entries.is_empty() {
316                    job_index.remove(&task.job_id);
317                }
318            }
319        }
320    }
321    writer_task.abort();
322
323    Ok(())
324}
325
/// Bookkeeping for one request whose worker task has been spawned.
struct InFlightTask {
    /// Job the request belongs to; used to prune `job_index` when the entry is removed.
    job_id: String,
    /// Handle of the spawned worker task; aborted on cancel and when its connection ends.
    handle: tokio::task::JoinHandle<()>,
    /// Sender into the originating connection's outcome channel; used to report cancellation.
    response_tx: mpsc::Sender<ExecutionOutcome>,
    /// The originating connection's set of active request ids, so a cancel can free its slot.
    connection_requests: Arc<Mutex<HashSet<String>>>,
    /// Set once execution has produced an outcome; cancel requests leave such tasks alone.
    completed: Arc<AtomicBool>,
}
333
334async fn handle_cancel(
335    payload: CancelRequest,
336    in_flight: &Arc<Mutex<HashMap<String, InFlightTask>>>,
337    job_index: &Arc<Mutex<HashMap<String, HashSet<String>>>>,
338) {
339    if payload.protocol_version != PROTOCOL_VERSION {
340        return;
341    }
342    let request_ids = if let Some(request_id) = payload.request_id.clone() {
343        vec![request_id]
344    } else {
345        let job_index = job_index.lock().await;
346        job_index
347            .get(&payload.job_id)
348            .map(|ids| ids.iter().cloned().collect())
349            .unwrap_or_else(Vec::new)
350    };
351    if request_ids.is_empty() {
352        return;
353    }
354
355    for request_id in request_ids {
356        let task = {
357            let mut in_flight = in_flight.lock().await;
358            if let Some(task) = in_flight.get(&request_id)
359                && task.completed.load(Ordering::SeqCst)
360            {
361                None
362            } else {
363                in_flight.remove(&request_id)
364            }
365        };
366        if let Some(task) = task {
367            task.handle.abort();
368            {
369                let mut active = task.connection_requests.lock().await;
370                active.remove(&request_id);
371            }
372            let outcome = ExecutionOutcome {
373                job_id: Some(payload.job_id.clone()),
374                request_id: Some(request_id.clone()),
375                status: OutcomeStatus::Error,
376                result: None,
377                error: Some(ExecutionError {
378                    message: "Job cancelled".to_string(),
379                    error_type: Some("cancelled".to_string()),
380                    code: None,
381                    details: None,
382                }),
383                retry_after_seconds: None,
384            };
385            let send_result = timeout(RESPONSE_SEND_TIMEOUT, task.response_tx.send(outcome)).await;
386            match send_result {
387                Ok(Ok(())) => {}
388                Ok(Err(_)) => {
389                    tracing::warn!("runner response channel closed; dropping cancel outcome");
390                }
391                Err(_) => {
392                    tracing::warn!("runner response channel stalled; dropping cancel outcome");
393                }
394            }
395            let mut job_index = job_index.lock().await;
396            if let Some(entries) = job_index.get_mut(&task.job_id) {
397                entries.remove(&request_id);
398                if entries.is_empty() {
399                    job_index.remove(&task.job_id);
400                }
401            }
402        }
403    }
404}
405
406async fn execute_with_deadline<T: Telemetry + ?Sized>(
407    request: rrq_protocol::ExecutionRequest,
408    registry: Registry,
409    telemetry: &T,
410) -> ExecutionOutcome {
411    let job_id = request.job_id.clone();
412    let request_id = request.request_id.clone();
413    let deadline = request.context.deadline;
414    if let Some(deadline) = deadline {
415        let now: DateTime<Utc> = Utc::now();
416        if deadline <= now {
417            return ExecutionOutcome::timeout(
418                job_id.clone(),
419                request_id.clone(),
420                "Job deadline exceeded",
421            );
422        }
423        if let Ok(remaining) = (deadline - now).to_std() {
424            match tokio::time::timeout(remaining, registry.execute_with(request, telemetry)).await {
425                Ok(outcome) => return outcome,
426                Err(_) => {
427                    return ExecutionOutcome::timeout(
428                        job_id.clone(),
429                        request_id.clone(),
430                        "Job execution timed out",
431                    );
432                }
433            }
434        }
435        return ExecutionOutcome::timeout(job_id, request_id, "Job deadline exceeded");
436    }
437    registry.execute_with(request, telemetry).await
438}
439
440async fn read_message<R: AsyncRead + Unpin>(
441    stream: &mut R,
442) -> Result<Option<RunnerMessage>, Box<dyn std::error::Error>> {
443    let mut header = [0u8; 4];
444    match stream.read_exact(&mut header).await {
445        Ok(_) => {}
446        Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
447        Err(err) => return Err(err.into()),
448    }
449    let length = u32::from_be_bytes(header) as usize;
450    if length == 0 {
451        return Err("runner message payload cannot be empty".into());
452    }
453    if length > MAX_FRAME_LEN {
454        return Err("runner message payload too large".into());
455    }
456    let mut payload = vec![0u8; length];
457    stream.read_exact(&mut payload).await?;
458    let message = serde_json::from_slice(&payload)?;
459    Ok(Some(message))
460}
461
462async fn write_message<W: AsyncWrite + Unpin>(
463    stream: &mut W,
464    message: &RunnerMessage,
465) -> Result<(), Box<dyn std::error::Error>> {
466    let framed = encode_frame(message)?;
467    stream.write_all(&framed).await?;
468    stream.flush().await?;
469    Ok(())
470}
471
472#[cfg(test)]
473mod tests {
474    use super::*;
475    use crate::registry::Registry;
476    use crate::telemetry::NoopTelemetry;
477    use chrono::Utc;
478    use rrq_protocol::{ExecutionContext, ExecutionRequest, OutcomeStatus};
479    use serde_json::json;
480    use tokio::net::{TcpListener, TcpStream};
481    use tokio::time::{Duration, timeout};
482
483    fn build_request(function_name: &str) -> ExecutionRequest {
484        ExecutionRequest {
485            protocol_version: PROTOCOL_VERSION.to_string(),
486            request_id: "req-1".to_string(),
487            job_id: "job-1".to_string(),
488            function_name: function_name.to_string(),
489            args: vec![],
490            kwargs: std::collections::HashMap::new(),
491            context: ExecutionContext {
492                job_id: "job-1".to_string(),
493                attempt: 1,
494                enqueue_time: "2024-01-01T00:00:00Z".parse().unwrap(),
495                queue_name: "default".to_string(),
496                deadline: None,
497                trace_context: None,
498                worker_id: None,
499            },
500        }
501    }
502
503    async fn tcp_pair() -> (TcpStream, TcpStream) {
504        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
505        let addr = listener.local_addr().unwrap();
506        let client = TcpStream::connect(addr).await.unwrap();
507        let (server, _) = listener.accept().await.unwrap();
508        (client, server)
509    }
510
511    #[tokio::test]
512    async fn handle_connection_executes_request() {
513        let mut registry = Registry::new();
514        registry.register("echo", |request| async move {
515            ExecutionOutcome::success(
516                request.job_id.clone(),
517                request.request_id.clone(),
518                json!({"ok": true}),
519            )
520        });
521        let (client, server) = tcp_pair().await;
522        let in_flight = Arc::new(Mutex::new(HashMap::new()));
523        let job_index = Arc::new(Mutex::new(HashMap::new()));
524        let server_task = tokio::spawn(async move {
525            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
526                .await
527                .unwrap();
528        });
529        let mut client = client;
530        let request = build_request("echo");
531        let message = RunnerMessage::Request { payload: request };
532        write_message(&mut client, &message).await.unwrap();
533        let response = read_message(&mut client).await.unwrap().unwrap();
534        match response {
535            RunnerMessage::Response { payload } => {
536                assert_eq!(payload.status, OutcomeStatus::Success);
537                assert_eq!(payload.result, Some(json!({"ok": true})));
538            }
539            _ => panic!("expected response"),
540        }
541        drop(client);
542        let _ = server_task.await;
543    }
544
545    #[tokio::test]
546    async fn handle_connection_rejects_bad_protocol() {
547        let registry = Registry::new();
548        let (client, server) = tcp_pair().await;
549        let in_flight = Arc::new(Mutex::new(HashMap::new()));
550        let job_index = Arc::new(Mutex::new(HashMap::new()));
551        let server_task = tokio::spawn(async move {
552            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
553                .await
554                .unwrap();
555        });
556        let mut client = client;
557        let mut request = build_request("echo");
558        request.protocol_version = "0".to_string();
559        let message = RunnerMessage::Request { payload: request };
560        write_message(&mut client, &message).await.unwrap();
561        let response = read_message(&mut client).await.unwrap().unwrap();
562        match response {
563            RunnerMessage::Response { payload } => {
564                assert_eq!(payload.status, OutcomeStatus::Error);
565            }
566            _ => panic!("expected response"),
567        }
568        drop(client);
569        let _ = server_task.await;
570    }
571
572    #[tokio::test]
573    async fn handle_connection_cancels_inflight() {
574        let mut registry = Registry::new();
575        registry.register("sleep", |request| async move {
576            tokio::time::sleep(Duration::from_millis(200)).await;
577            ExecutionOutcome::success(
578                request.job_id.clone(),
579                request.request_id.clone(),
580                json!({"ok": true}),
581            )
582        });
583        let (client, server) = tcp_pair().await;
584        let in_flight = Arc::new(Mutex::new(HashMap::new()));
585        let job_index = Arc::new(Mutex::new(HashMap::new()));
586        let server_task = tokio::spawn(async move {
587            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
588                .await
589                .unwrap();
590        });
591        let mut client = client;
592        let request = ExecutionRequest {
593            protocol_version: PROTOCOL_VERSION.to_string(),
594            request_id: "req-cancel".to_string(),
595            job_id: "job-cancel".to_string(),
596            function_name: "sleep".to_string(),
597            args: vec![],
598            kwargs: std::collections::HashMap::new(),
599            context: ExecutionContext {
600                job_id: "job-cancel".to_string(),
601                attempt: 1,
602                enqueue_time: "2024-01-01T00:00:00Z".parse().unwrap(),
603                queue_name: "default".to_string(),
604                deadline: None,
605                trace_context: None,
606                worker_id: None,
607            },
608        };
609        let message = RunnerMessage::Request {
610            payload: request.clone(),
611        };
612        write_message(&mut client, &message).await.unwrap();
613        let cancel = RunnerMessage::Cancel {
614            payload: CancelRequest {
615                protocol_version: PROTOCOL_VERSION.to_string(),
616                job_id: request.job_id.clone(),
617                request_id: Some(request.request_id.clone()),
618                hard_kill: false,
619            },
620        };
621        write_message(&mut client, &cancel).await.unwrap();
622        let response = read_message(&mut client).await.unwrap().unwrap();
623        match response {
624            RunnerMessage::Response { payload } => {
625                assert_eq!(payload.status, OutcomeStatus::Error);
626                let error_type = payload
627                    .error
628                    .as_ref()
629                    .and_then(|error| error.error_type.as_deref());
630                assert_eq!(error_type, Some("cancelled"));
631            }
632            _ => panic!("expected response"),
633        }
634        drop(client);
635        let _ = server_task.await;
636    }
637
638    #[tokio::test]
639    async fn cancel_frees_connection_capacity() {
640        let mut registry = Registry::new();
641        let gate = Arc::new(tokio::sync::Semaphore::new(0));
642        let gate_for_handler = gate.clone();
643        registry.register("block", move |request| {
644            let gate = gate_for_handler.clone();
645            async move {
646                let _permit = gate.acquire().await.expect("semaphore closed");
647                ExecutionOutcome::success(
648                    request.job_id.clone(),
649                    request.request_id.clone(),
650                    json!({"ok": true}),
651                )
652            }
653        });
654        let (client, server) = tcp_pair().await;
655        let in_flight = Arc::new(Mutex::new(HashMap::new()));
656        let job_index = Arc::new(Mutex::new(HashMap::new()));
657        let server_task = tokio::spawn(async move {
658            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
659                .await
660                .unwrap();
661        });
662        let mut client = client;
663        let job_id = "job-capacity".to_string();
664        for i in 0..RESPONSE_CHANNEL_CAPACITY {
665            let mut request = build_request("block");
666            request.request_id = format!("req-{i}");
667            request.job_id = job_id.clone();
668            request.context.job_id = job_id.clone();
669            write_message(&mut client, &RunnerMessage::Request { payload: request })
670                .await
671                .unwrap();
672        }
673
674        let cancel = RunnerMessage::Cancel {
675            payload: CancelRequest {
676                protocol_version: PROTOCOL_VERSION.to_string(),
677                job_id: job_id.clone(),
678                request_id: Some("req-0".to_string()),
679                hard_kill: false,
680            },
681        };
682        write_message(&mut client, &cancel).await.unwrap();
683        let response = timeout(Duration::from_secs(1), read_message(&mut client))
684            .await
685            .unwrap()
686            .unwrap()
687            .unwrap();
688        match response {
689            RunnerMessage::Response { payload } => {
690                assert_eq!(payload.status, OutcomeStatus::Error);
691                let error_type = payload
692                    .error
693                    .as_ref()
694                    .and_then(|error| error.error_type.as_deref());
695                assert_eq!(error_type, Some("cancelled"));
696            }
697            _ => panic!("expected response"),
698        }
699
700        let mut extra_request = build_request("block");
701        extra_request.request_id = "req-extra".to_string();
702        extra_request.job_id = job_id.clone();
703        extra_request.context.job_id = job_id.clone();
704        write_message(
705            &mut client,
706            &RunnerMessage::Request {
707                payload: extra_request,
708            },
709        )
710        .await
711        .unwrap();
712
713        gate.add_permits(RESPONSE_CHANNEL_CAPACITY + 1);
714
715        let mut saw_extra = false;
716        for _ in 0..RESPONSE_CHANNEL_CAPACITY {
717            let response = timeout(Duration::from_secs(1), read_message(&mut client))
718                .await
719                .unwrap()
720                .unwrap()
721                .unwrap();
722            if let RunnerMessage::Response { payload } = response
723                && payload.request_id.as_deref() == Some("req-extra")
724            {
725                assert_eq!(payload.status, OutcomeStatus::Success);
726                saw_extra = true;
727            }
728        }
729        assert!(saw_extra, "extra request never completed");
730
731        drop(client);
732        let _ = server_task.await;
733    }
734
735    #[tokio::test]
736    async fn execute_with_deadline_times_out() {
737        let mut registry = Registry::new();
738        registry.register("echo", |request| async move {
739            ExecutionOutcome::success(
740                request.job_id.clone(),
741                request.request_id.clone(),
742                json!({"ok": true}),
743            )
744        });
745        let mut request = build_request("echo");
746        request.context.deadline = Some(
747            "2020-01-01T00:00:00Z"
748                .parse::<chrono::DateTime<Utc>>()
749                .unwrap(),
750        );
751        let outcome = execute_with_deadline(request, registry, &NoopTelemetry).await;
752        assert_eq!(outcome.status, OutcomeStatus::Timeout);
753    }
754
755    #[tokio::test]
756    async fn execute_with_deadline_succeeds_before_deadline() {
757        let mut registry = Registry::new();
758        registry.register("echo", |request| async move {
759            ExecutionOutcome::success(
760                request.job_id.clone(),
761                request.request_id.clone(),
762                json!({"ok": true}),
763            )
764        });
765        let mut request = build_request("echo");
766        request.context.deadline = Some(Utc::now() + chrono::Duration::seconds(5));
767        let outcome = execute_with_deadline(request, registry, &NoopTelemetry).await;
768        assert_eq!(outcome.status, OutcomeStatus::Success);
769    }
770
771    #[tokio::test]
772    async fn handle_connection_handles_unexpected_response_message() {
773        let registry = Registry::new();
774        let (client, server) = tcp_pair().await;
775        let in_flight = Arc::new(Mutex::new(HashMap::new()));
776        let job_index = Arc::new(Mutex::new(HashMap::new()));
777        let server_task = tokio::spawn(async move {
778            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
779                .await
780                .unwrap();
781        });
782        let mut client = client;
783        let response = RunnerMessage::Response {
784            payload: ExecutionOutcome::error("job-x", "req-x", "oops"),
785        };
786        write_message(&mut client, &response).await.unwrap();
787        let reply = read_message(&mut client).await.unwrap().unwrap();
788        match reply {
789            RunnerMessage::Response { payload } => {
790                assert_eq!(payload.status, OutcomeStatus::Error);
791                assert!(
792                    payload
793                        .error
794                        .as_ref()
795                        .unwrap()
796                        .message
797                        .contains("unexpected response")
798                );
799            }
800            _ => panic!("expected response"),
801        }
802        drop(client);
803        let _ = server_task.await;
804    }
805
806    #[tokio::test]
807    async fn handle_connection_cancels_by_job_id() {
808        let mut registry = Registry::new();
809        registry.register("sleep", |request| async move {
810            tokio::time::sleep(Duration::from_millis(200)).await;
811            ExecutionOutcome::success(
812                request.job_id.clone(),
813                request.request_id.clone(),
814                json!({"ok": true}),
815            )
816        });
817        let (client, server) = tcp_pair().await;
818        let in_flight = Arc::new(Mutex::new(HashMap::new()));
819        let job_index = Arc::new(Mutex::new(HashMap::new()));
820        let server_task = tokio::spawn(async move {
821            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
822                .await
823                .unwrap();
824        });
825        let mut client = client;
826        let request = build_request("sleep");
827        let message = RunnerMessage::Request {
828            payload: request.clone(),
829        };
830        write_message(&mut client, &message).await.unwrap();
831        let cancel = RunnerMessage::Cancel {
832            payload: CancelRequest {
833                protocol_version: PROTOCOL_VERSION.to_string(),
834                job_id: request.job_id.clone(),
835                request_id: None,
836                hard_kill: false,
837            },
838        };
839        write_message(&mut client, &cancel).await.unwrap();
840        let response = read_message(&mut client).await.unwrap().unwrap();
841        match response {
842            RunnerMessage::Response { payload } => {
843                assert_eq!(payload.status, OutcomeStatus::Error);
844                let error_type = payload
845                    .error
846                    .as_ref()
847                    .and_then(|error| error.error_type.as_deref());
848                assert_eq!(error_type, Some("cancelled"));
849            }
850            _ => panic!("expected response"),
851        }
852        drop(client);
853        let _ = server_task.await;
854    }
855
856    #[tokio::test]
857    async fn handle_cancel_by_job_id_cancels_all_requests() {
858        let mut registry = Registry::new();
859        registry.register("sleep", |request| async move {
860            tokio::time::sleep(Duration::from_millis(200)).await;
861            ExecutionOutcome::success(
862                request.job_id.clone(),
863                request.request_id.clone(),
864                json!({"ok": true}),
865            )
866        });
867        let (client, server) = tcp_pair().await;
868        let in_flight = Arc::new(Mutex::new(HashMap::new()));
869        let job_index = Arc::new(Mutex::new(HashMap::new()));
870        let server_task = tokio::spawn(async move {
871            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
872                .await
873                .unwrap();
874        });
875        let mut client = client;
876        let mut request1 = build_request("sleep");
877        request1.request_id = "req-1".to_string();
878        request1.job_id = "job-shared".to_string();
879        let mut request2 = build_request("sleep");
880        request2.request_id = "req-2".to_string();
881        request2.job_id = "job-shared".to_string();
882        write_message(&mut client, &RunnerMessage::Request { payload: request1 })
883            .await
884            .unwrap();
885        write_message(&mut client, &RunnerMessage::Request { payload: request2 })
886            .await
887            .unwrap();
888        let cancel = RunnerMessage::Cancel {
889            payload: CancelRequest {
890                protocol_version: PROTOCOL_VERSION.to_string(),
891                job_id: "job-shared".to_string(),
892                request_id: None,
893                hard_kill: false,
894            },
895        };
896        write_message(&mut client, &cancel).await.unwrap();
897
898        let mut cancelled = 0;
899        for _ in 0..2 {
900            let response = timeout(Duration::from_millis(200), read_message(&mut client))
901                .await
902                .unwrap()
903                .unwrap()
904                .unwrap();
905            match response {
906                RunnerMessage::Response { payload } => {
907                    assert_eq!(payload.status, OutcomeStatus::Error);
908                    let error_type = payload
909                        .error
910                        .as_ref()
911                        .and_then(|error| error.error_type.as_deref());
912                    assert_eq!(error_type, Some("cancelled"));
913                    cancelled += 1;
914                }
915                _ => panic!("expected response"),
916            }
917        }
918        assert_eq!(cancelled, 2);
919        drop(client);
920        let _ = server_task.await;
921    }
922
923    #[tokio::test]
924    async fn connection_teardown_clears_tracking_maps() {
925        let mut registry = Registry::new();
926        registry.register("sleep", |request| async move {
927            tokio::time::sleep(Duration::from_millis(200)).await;
928            ExecutionOutcome::success(
929                request.job_id.clone(),
930                request.request_id.clone(),
931                json!({"ok": true}),
932            )
933        });
934        let (client, server) = tcp_pair().await;
935        let in_flight = Arc::new(Mutex::new(HashMap::new()));
936        let job_index = Arc::new(Mutex::new(HashMap::new()));
937        let in_flight_for_server = in_flight.clone();
938        let job_index_for_server = job_index.clone();
939        let server_task = tokio::spawn(async move {
940            handle_connection(
941                server,
942                &registry,
943                &NoopTelemetry,
944                in_flight_for_server,
945                job_index_for_server,
946            )
947            .await
948            .unwrap();
949        });
950        let mut client = client;
951        let request = build_request("sleep");
952        let message = RunnerMessage::Request {
953            payload: request.clone(),
954        };
955        write_message(&mut client, &message).await.unwrap();
956
957        let mut inserted = false;
958        for _ in 0..20 {
959            let has_in_flight = {
960                let guard = in_flight.lock().await;
961                guard.contains_key(&request.request_id)
962            };
963            let has_job_index = {
964                let guard = job_index.lock().await;
965                guard.contains_key(&request.job_id)
966            };
967            if has_in_flight && has_job_index {
968                inserted = true;
969                break;
970            }
971            tokio::time::sleep(Duration::from_millis(10)).await;
972        }
973        assert!(inserted, "request never entered tracking maps");
974
975        drop(client);
976        let _ = server_task.await;
977
978        let in_flight = in_flight.lock().await;
979        let job_index = job_index.lock().await;
980        assert!(in_flight.is_empty());
981        assert!(job_index.is_empty());
982    }
983
984    #[tokio::test]
985    async fn handle_cancel_ignores_invalid_protocol() {
986        let in_flight = Arc::new(Mutex::new(HashMap::new()));
987        let job_index = Arc::new(Mutex::new(HashMap::new()));
988        let (tx, _rx) = mpsc::channel(1);
989        let handle = tokio::spawn(async {});
990        let connection_requests = Arc::new(Mutex::new(HashSet::new()));
991        {
992            let mut guard = in_flight.lock().await;
993            guard.insert(
994                "req-1".to_string(),
995                InFlightTask {
996                    job_id: "job-1".to_string(),
997                    handle,
998                    response_tx: tx,
999                    connection_requests,
1000                    completed: Arc::new(AtomicBool::new(false)),
1001                },
1002            );
1003        }
1004        let payload = CancelRequest {
1005            protocol_version: "0".to_string(),
1006            job_id: "job-1".to_string(),
1007            request_id: None,
1008            hard_kill: false,
1009        };
1010        handle_cancel(payload, &in_flight, &job_index).await;
1011        let guard = in_flight.lock().await;
1012        assert!(guard.contains_key("req-1"));
1013        guard.get("req-1").unwrap().handle.abort();
1014    }
1015
1016    #[tokio::test]
1017    async fn handle_cancel_skips_completed_requests() {
1018        let in_flight = Arc::new(Mutex::new(HashMap::new()));
1019        let job_index = Arc::new(Mutex::new(HashMap::new()));
1020        let (tx, mut rx) = mpsc::channel(1);
1021        let handle = tokio::spawn(async {
1022            tokio::time::sleep(Duration::from_millis(50)).await;
1023        });
1024        let connection_requests = Arc::new(Mutex::new(HashSet::new()));
1025        {
1026            let mut guard = in_flight.lock().await;
1027            guard.insert(
1028                "req-1".to_string(),
1029                InFlightTask {
1030                    job_id: "job-1".to_string(),
1031                    handle,
1032                    response_tx: tx,
1033                    connection_requests,
1034                    completed: Arc::new(AtomicBool::new(true)),
1035                },
1036            );
1037        }
1038        {
1039            let mut guard = job_index.lock().await;
1040            guard.insert("job-1".to_string(), HashSet::from(["req-1".to_string()]));
1041        }
1042        let payload = CancelRequest {
1043            protocol_version: PROTOCOL_VERSION.to_string(),
1044            job_id: "job-1".to_string(),
1045            request_id: Some("req-1".to_string()),
1046            hard_kill: false,
1047        };
1048        handle_cancel(payload, &in_flight, &job_index).await;
1049        assert!(in_flight.lock().await.contains_key("req-1"));
1050        assert!(job_index.lock().await.contains_key("job-1"));
1051        assert!(rx.try_recv().is_err());
1052        let task = in_flight.lock().await.remove("req-1").unwrap();
1053        task.handle.abort();
1054    }
1055
1056    #[tokio::test]
1057    async fn read_message_handles_empty_and_invalid_payloads() {
1058        let (mut client, mut server) = tokio::io::duplex(64);
1059        // length = 0
1060        client.write_all(&0u32.to_be_bytes()).await.unwrap();
1061        let err = read_message(&mut server).await.unwrap_err();
1062        assert!(err.to_string().contains("payload cannot be empty"));
1063
1064        // invalid json
1065        let (mut client, mut server) = tokio::io::duplex(64);
1066        let payload = b"not-json";
1067        let len = (payload.len() as u32).to_be_bytes();
1068        client.write_all(&len).await.unwrap();
1069        client.write_all(payload).await.unwrap();
1070        let err = read_message(&mut server).await.unwrap_err();
1071        assert!(err.to_string().contains("expected"));
1072
1073        // oversized payload
1074        let (mut client, mut server) = tokio::io::duplex(64);
1075        let len = ((MAX_FRAME_LEN + 1) as u32).to_be_bytes();
1076        client.write_all(&len).await.unwrap();
1077        let err = read_message(&mut server).await.unwrap_err();
1078        assert!(err.to_string().contains("payload too large"));
1079    }
1080
1081    #[tokio::test]
1082    async fn read_message_returns_none_on_eof() {
1083        let (client, mut server) = tokio::io::duplex(8);
1084        drop(client);
1085        let message = read_message(&mut server).await.unwrap();
1086        assert!(message.is_none());
1087    }
1088}