Skip to main content

rrq_runner/
runtime.rs

1use crate::registry::Registry;
2use crate::telemetry::{NoopTelemetry, Telemetry};
3use crate::types::{ExecutionError, ExecutionOutcome};
4use chrono::{DateTime, Utc};
5use rrq_protocol::{CancelRequest, OutcomeStatus, PROTOCOL_VERSION, RunnerMessage, encode_frame};
6use std::collections::{HashMap, HashSet};
7use std::net::{IpAddr, Ipv4Addr, SocketAddr};
8use std::sync::{
9    Arc,
10    atomic::{AtomicBool, Ordering},
11};
12use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
13use tokio::net::TcpListener;
14use tokio::sync::{Mutex, mpsc};
15use tokio::time::{Duration, timeout};
/// Largest frame payload `read_message` accepts, in bytes (16 MiB).
const MAX_FRAME_LEN: usize = 16 << 20;
/// Capacity of each connection's response channel; also the maximum number of requests a
/// single connection may have active at once.
const RESPONSE_CHANNEL_CAPACITY: usize = 64;
/// Upper bound on how long an outcome send may wait for room in a response channel.
const RESPONSE_SEND_TIMEOUT: Duration = Duration::from_secs(1);
19
/// Wraps `message` in a boxed `std::io::Error` whose kind is `InvalidInput`.
///
/// Used for configuration/validation failures so callers receive a uniform
/// `Box<dyn std::error::Error>` that can still be downcast to `std::io::Error`.
fn invalid_input(message: impl Into<String>) -> Box<dyn std::error::Error> {
    let kind = std::io::ErrorKind::InvalidInput;
    std::io::Error::new(kind, message.into()).into()
}
26
27pub fn parse_tcp_socket(raw: &str) -> Result<SocketAddr, Box<dyn std::error::Error>> {
28    let raw = raw.trim();
29    if raw.is_empty() {
30        return Err(invalid_input("runner tcp_socket cannot be empty"));
31    }
32
33    let (host, port_str) = if let Some(rest) = raw.strip_prefix('[') {
34        let (host, port_str) = rest
35            .split_once("]:")
36            .ok_or_else(|| invalid_input("runner tcp_socket must be in [host]:port format"))?;
37        (host, port_str)
38    } else {
39        let (host, port_str) = raw
40            .rsplit_once(':')
41            .ok_or_else(|| invalid_input("runner tcp_socket must be in host:port format"))?;
42        if host.is_empty() {
43            return Err(invalid_input("runner tcp_socket host cannot be empty"));
44        }
45        (host, port_str)
46    };
47
48    let port: u16 = port_str
49        .parse()
50        .map_err(|_| invalid_input(format!("Invalid runner tcp_socket port: {port_str}")))?;
51    if port == 0 {
52        return Err(invalid_input("runner tcp_socket port must be > 0"));
53    }
54
55    let ip = if host == "localhost" {
56        IpAddr::V4(Ipv4Addr::LOCALHOST)
57    } else {
58        let parsed: IpAddr = host
59            .parse()
60            .map_err(|_| invalid_input(format!("Invalid runner tcp_socket host: {host}")))?;
61        if !parsed.is_loopback() {
62            return Err(invalid_input("runner tcp_socket host must be localhost"));
63        }
64        parsed
65    };
66
67    Ok(SocketAddr::new(ip, port))
68}
69
70pub struct RunnerRuntime {
71    runtime: tokio::runtime::Runtime,
72}
73
74impl RunnerRuntime {
75    pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
76        Ok(Self {
77            runtime: tokio::runtime::Runtime::new()?,
78        })
79    }
80
81    pub fn enter(&self) -> tokio::runtime::EnterGuard<'_> {
82        self.runtime.enter()
83    }
84
85    pub fn run_tcp(
86        &self,
87        registry: &Registry,
88        addr: SocketAddr,
89    ) -> Result<(), Box<dyn std::error::Error>> {
90        let telemetry = NoopTelemetry;
91        self.run_tcp_with(registry, addr, &telemetry)
92    }
93
94    pub fn run_tcp_with<T: Telemetry + ?Sized>(
95        &self,
96        registry: &Registry,
97        addr: SocketAddr,
98        telemetry: &T,
99    ) -> Result<(), Box<dyn std::error::Error>> {
100        run_tcp_loop(&self.runtime, registry, addr, telemetry)
101    }
102}
103
104pub fn run_tcp(registry: &Registry, addr: SocketAddr) -> Result<(), Box<dyn std::error::Error>> {
105    RunnerRuntime::new()?.run_tcp(registry, addr)
106}
107
108pub fn run_tcp_with<T: Telemetry + ?Sized>(
109    registry: &Registry,
110    addr: SocketAddr,
111    telemetry: &T,
112) -> Result<(), Box<dyn std::error::Error>> {
113    RunnerRuntime::new()?.run_tcp_with(registry, addr, telemetry)
114}
115
116fn run_tcp_loop<T: Telemetry + ?Sized>(
117    runtime: &tokio::runtime::Runtime,
118    registry: &Registry,
119    addr: SocketAddr,
120    telemetry: &T,
121) -> Result<(), Box<dyn std::error::Error>> {
122    let registry = registry.clone();
123    let in_flight: Arc<Mutex<HashMap<String, InFlightTask>>> = Arc::new(Mutex::new(HashMap::new()));
124    let job_index: Arc<Mutex<HashMap<String, HashSet<String>>>> =
125        Arc::new(Mutex::new(HashMap::new()));
126    let telemetry = telemetry.clone_box();
127    runtime.block_on(async move {
128        if !addr.ip().is_loopback() {
129            return Err(invalid_input(format!(
130                "runner tcp_socket must be loopback-only (got {addr})"
131            )));
132        }
133        let listener = TcpListener::bind(addr).await?;
134        loop {
135            let (stream, _) = listener.accept().await?;
136            let registry = registry.clone();
137            let telemetry = telemetry.clone();
138            let in_flight = in_flight.clone();
139            let job_index = job_index.clone();
140            tokio::spawn(async move {
141                if let Err(err) =
142                    handle_connection(stream, &registry, telemetry.as_ref(), in_flight, job_index)
143                        .await
144                {
145                    tracing::error!("runner connection error: {err}");
146                }
147            });
148        }
149    })
150}
151
152async fn handle_connection<S, T>(
153    stream: S,
154    registry: &Registry,
155    telemetry: &T,
156    in_flight: Arc<Mutex<HashMap<String, InFlightTask>>>,
157    job_index: Arc<Mutex<HashMap<String, HashSet<String>>>>,
158) -> Result<(), Box<dyn std::error::Error>>
159where
160    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
161    T: Telemetry + ?Sized,
162{
163    let (mut reader, mut writer) = tokio::io::split(stream);
164    let (response_tx, mut response_rx) =
165        mpsc::channel::<ExecutionOutcome>(RESPONSE_CHANNEL_CAPACITY);
166    let writer_task = tokio::spawn(async move {
167        while let Some(outcome) = response_rx.recv().await {
168            let response = RunnerMessage::Response { payload: outcome };
169            if write_message(&mut writer, &response).await.is_err() {
170                break;
171            }
172        }
173    });
174    let connection_requests: Arc<Mutex<std::collections::HashSet<String>>> =
175        Arc::new(Mutex::new(std::collections::HashSet::new()));
176
177    loop {
178        let message = match read_message(&mut reader).await? {
179            Some(message) => message,
180            None => break,
181        };
182        match message {
183            RunnerMessage::Request { payload } => {
184                if payload.protocol_version != PROTOCOL_VERSION {
185                    let outcome = ExecutionOutcome::error(
186                        payload.job_id.clone(),
187                        payload.request_id.clone(),
188                        "Unsupported protocol version",
189                    );
190                    let _ = response_tx.send(outcome).await;
191                    continue;
192                }
193
194                let request_id = payload.request_id.clone();
195                let job_id = payload.job_id.clone();
196                {
197                    let mut active = connection_requests.lock().await;
198                    if active.len() >= RESPONSE_CHANNEL_CAPACITY {
199                        let outcome = ExecutionOutcome::error(
200                            payload.job_id.clone(),
201                            payload.request_id.clone(),
202                            "Runner busy: too many in-flight requests",
203                        );
204                        drop(active);
205                        let send_result =
206                            timeout(RESPONSE_SEND_TIMEOUT, response_tx.send(outcome)).await;
207                        match send_result {
208                            Ok(Ok(())) => {}
209                            Ok(Err(_)) => {
210                                return Err("runner response channel closed".into());
211                            }
212                            Err(_) => {
213                                return Err("runner response channel stalled".into());
214                            }
215                        }
216                        continue;
217                    }
218                    active.insert(request_id.clone());
219                    crate::telemetry::record_runner_channel_pressure(active.len());
220                }
221                let response_tx = response_tx.clone();
222                let registry = registry.clone();
223                let telemetry = telemetry.clone_box();
224                let in_flight_for_task = in_flight.clone();
225                let job_index_for_task = job_index.clone();
226                let active_for_task = connection_requests.clone();
227                let request_id_for_task = request_id.clone();
228                let job_id_for_task = job_id.clone();
229                let response_tx_for_task = response_tx.clone();
230                let completed = Arc::new(AtomicBool::new(false));
231                let completed_for_task = completed.clone();
232
233                let handle = tokio::spawn(async move {
234                    let outcome =
235                        execute_with_deadline(payload, registry, telemetry.as_ref()).await;
236                    completed_for_task.store(true, Ordering::SeqCst);
237                    let send_result =
238                        timeout(RESPONSE_SEND_TIMEOUT, response_tx_for_task.send(outcome)).await;
239                    match send_result {
240                        Ok(Ok(())) => {}
241                        Ok(Err(_)) => {
242                            tracing::warn!("runner response channel closed; dropping outcome");
243                        }
244                        Err(_) => {
245                            tracing::warn!("runner response channel stalled; dropping outcome");
246                        }
247                    }
248                    {
249                        let mut in_flight = in_flight_for_task.lock().await;
250                        if in_flight.remove(&request_id_for_task).is_some() {
251                            crate::telemetry::record_runner_inflight_delta(-1);
252                        }
253                    }
254                    {
255                        let mut job_index = job_index_for_task.lock().await;
256                        if let Some(entries) = job_index.get_mut(&job_id_for_task) {
257                            entries.remove(&request_id_for_task);
258                            if entries.is_empty() {
259                                job_index.remove(&job_id_for_task);
260                            }
261                        }
262                    }
263                    {
264                        let mut active = active_for_task.lock().await;
265                        active.remove(&request_id_for_task);
266                        crate::telemetry::record_runner_channel_pressure(active.len());
267                    }
268                });
269
270                {
271                    let mut in_flight = in_flight.lock().await;
272                    in_flight.insert(
273                        request_id.clone(),
274                        InFlightTask {
275                            job_id: job_id.clone(),
276                            handle,
277                            response_tx: response_tx.clone(),
278                            connection_requests: connection_requests.clone(),
279                            completed,
280                        },
281                    );
282                }
283                crate::telemetry::record_runner_inflight_delta(1);
284                {
285                    let mut job_index = job_index.lock().await;
286                    job_index
287                        .entry(job_id)
288                        .or_insert_with(HashSet::new)
289                        .insert(request_id);
290                }
291            }
292            RunnerMessage::Cancel { payload } => {
293                handle_cancel(payload, &in_flight, &job_index).await;
294            }
295            RunnerMessage::Response { .. } => {
296                let outcome = ExecutionOutcome {
297                    job_id: Some("unknown".to_string()),
298                    request_id: None,
299                    status: rrq_protocol::OutcomeStatus::Error,
300                    result: None,
301                    error: Some(ExecutionError {
302                        message: "unexpected response message".to_string(),
303                        error_type: None,
304                        code: None,
305                        details: None,
306                    }),
307                    retry_after_seconds: None,
308                };
309                let _ = response_tx.send(outcome).await;
310            }
311        }
312    }
313
314    let request_ids = {
315        let mut active = connection_requests.lock().await;
316        active.drain().collect::<Vec<_>>()
317    };
318    crate::telemetry::record_runner_channel_pressure(0);
319    for request_id in request_ids {
320        let task = {
321            let mut in_flight = in_flight.lock().await;
322            in_flight.remove(&request_id)
323        };
324        if let Some(task) = task {
325            task.handle.abort();
326            crate::telemetry::record_runner_inflight_delta(-1);
327            let mut job_index = job_index.lock().await;
328            if let Some(entries) = job_index.get_mut(&task.job_id) {
329                entries.remove(&request_id);
330                if entries.is_empty() {
331                    job_index.remove(&task.job_id);
332                }
333            }
334        }
335    }
336    writer_task.abort();
337
338    Ok(())
339}
340
341struct InFlightTask {
342    job_id: String,
343    handle: tokio::task::JoinHandle<()>,
344    response_tx: mpsc::Sender<ExecutionOutcome>,
345    connection_requests: Arc<Mutex<HashSet<String>>>,
346    completed: Arc<AtomicBool>,
347}
348
349async fn handle_cancel(
350    payload: CancelRequest,
351    in_flight: &Arc<Mutex<HashMap<String, InFlightTask>>>,
352    job_index: &Arc<Mutex<HashMap<String, HashSet<String>>>>,
353) {
354    if payload.protocol_version != PROTOCOL_VERSION {
355        return;
356    }
357    let request_ids = if let Some(request_id) = payload.request_id.clone() {
358        vec![request_id]
359    } else {
360        let job_index = job_index.lock().await;
361        job_index
362            .get(&payload.job_id)
363            .map(|ids| ids.iter().cloned().collect())
364            .unwrap_or_else(Vec::new)
365    };
366    if request_ids.is_empty() {
367        return;
368    }
369    let scope = if payload.request_id.is_some() {
370        "request"
371    } else {
372        "job"
373    };
374    crate::telemetry::record_cancellation(scope);
375
376    for request_id in request_ids {
377        let task = {
378            let mut in_flight = in_flight.lock().await;
379            if let Some(task) = in_flight.get(&request_id)
380                && task.completed.load(Ordering::SeqCst)
381            {
382                None
383            } else {
384                in_flight.remove(&request_id)
385            }
386        };
387        if let Some(task) = task {
388            task.handle.abort();
389            crate::telemetry::record_runner_inflight_delta(-1);
390            {
391                let mut active = task.connection_requests.lock().await;
392                active.remove(&request_id);
393                crate::telemetry::record_runner_channel_pressure(active.len());
394            }
395            let outcome = ExecutionOutcome {
396                job_id: Some(payload.job_id.clone()),
397                request_id: Some(request_id.clone()),
398                status: OutcomeStatus::Error,
399                result: None,
400                error: Some(ExecutionError {
401                    message: "Job cancelled".to_string(),
402                    error_type: Some("cancelled".to_string()),
403                    code: None,
404                    details: None,
405                }),
406                retry_after_seconds: None,
407            };
408            let send_result = timeout(RESPONSE_SEND_TIMEOUT, task.response_tx.send(outcome)).await;
409            match send_result {
410                Ok(Ok(())) => {}
411                Ok(Err(_)) => {
412                    tracing::warn!("runner response channel closed; dropping cancel outcome");
413                }
414                Err(_) => {
415                    tracing::warn!("runner response channel stalled; dropping cancel outcome");
416                }
417            }
418            let mut job_index = job_index.lock().await;
419            if let Some(entries) = job_index.get_mut(&task.job_id) {
420                entries.remove(&request_id);
421                if entries.is_empty() {
422                    job_index.remove(&task.job_id);
423                }
424            }
425        }
426    }
427}
428
429async fn execute_with_deadline<T: Telemetry + ?Sized>(
430    request: rrq_protocol::ExecutionRequest,
431    registry: Registry,
432    telemetry: &T,
433) -> ExecutionOutcome {
434    let job_id = request.job_id.clone();
435    let request_id = request.request_id.clone();
436    let deadline = request.context.deadline;
437    if let Some(deadline) = deadline {
438        let now: DateTime<Utc> = Utc::now();
439        if deadline <= now {
440            crate::telemetry::record_deadline_expired();
441            return ExecutionOutcome::timeout(
442                job_id.clone(),
443                request_id.clone(),
444                "Job deadline exceeded",
445            );
446        }
447        if let Ok(remaining) = (deadline - now).to_std() {
448            match tokio::time::timeout(remaining, registry.execute_with(request, telemetry)).await {
449                Ok(outcome) => return outcome,
450                Err(_) => {
451                    crate::telemetry::record_deadline_expired();
452                    return ExecutionOutcome::timeout(
453                        job_id.clone(),
454                        request_id.clone(),
455                        "Job execution timed out",
456                    );
457                }
458            }
459        }
460        crate::telemetry::record_deadline_expired();
461        return ExecutionOutcome::timeout(job_id, request_id, "Job deadline exceeded");
462    }
463    registry.execute_with(request, telemetry).await
464}
465
466async fn read_message<R: AsyncRead + Unpin>(
467    stream: &mut R,
468) -> Result<Option<RunnerMessage>, Box<dyn std::error::Error>> {
469    let mut header = [0u8; 4];
470    match stream.read_exact(&mut header).await {
471        Ok(_) => {}
472        Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
473        Err(err) => return Err(err.into()),
474    }
475    let length = u32::from_be_bytes(header) as usize;
476    if length == 0 {
477        return Err("runner message payload cannot be empty".into());
478    }
479    if length > MAX_FRAME_LEN {
480        return Err("runner message payload too large".into());
481    }
482    let mut payload = vec![0u8; length];
483    stream.read_exact(&mut payload).await?;
484    let message = serde_json::from_slice(&payload)?;
485    Ok(Some(message))
486}
487
488async fn write_message<W: AsyncWrite + Unpin>(
489    stream: &mut W,
490    message: &RunnerMessage,
491) -> Result<(), Box<dyn std::error::Error>> {
492    let framed = encode_frame(message)?;
493    stream.write_all(&framed).await?;
494    stream.flush().await?;
495    Ok(())
496}
497
498#[cfg(test)]
499mod tests {
500    use super::*;
501    use crate::registry::Registry;
502    use crate::telemetry::NoopTelemetry;
503    use chrono::Utc;
504    use rrq_protocol::{ExecutionContext, ExecutionRequest, OutcomeStatus};
505    use serde_json::json;
506    use tokio::net::{TcpListener, TcpStream};
507    use tokio::time::{Duration, timeout};
508
509    fn build_request(function_name: &str) -> ExecutionRequest {
510        ExecutionRequest {
511            protocol_version: PROTOCOL_VERSION.to_string(),
512            request_id: "req-1".to_string(),
513            job_id: "job-1".to_string(),
514            function_name: function_name.to_string(),
515            params: std::collections::HashMap::new(),
516            context: ExecutionContext {
517                job_id: "job-1".to_string(),
518                attempt: 1,
519                enqueue_time: "2024-01-01T00:00:00Z".parse().unwrap(),
520                queue_name: "default".to_string(),
521                deadline: None,
522                trace_context: None,
523                worker_id: None,
524            },
525        }
526    }
527
528    async fn tcp_pair() -> (TcpStream, TcpStream) {
529        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
530        let addr = listener.local_addr().unwrap();
531        let client = TcpStream::connect(addr).await.unwrap();
532        let (server, _) = listener.accept().await.unwrap();
533        (client, server)
534    }
535
536    #[tokio::test]
537    async fn handle_connection_executes_request() {
538        let mut registry = Registry::new();
539        registry.register("echo", |request| async move {
540            ExecutionOutcome::success(
541                request.job_id.clone(),
542                request.request_id.clone(),
543                json!({"ok": true}),
544            )
545        });
546        let (client, server) = tcp_pair().await;
547        let in_flight = Arc::new(Mutex::new(HashMap::new()));
548        let job_index = Arc::new(Mutex::new(HashMap::new()));
549        let server_task = tokio::spawn(async move {
550            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
551                .await
552                .unwrap();
553        });
554        let mut client = client;
555        let request = build_request("echo");
556        let message = RunnerMessage::Request { payload: request };
557        write_message(&mut client, &message).await.unwrap();
558        let response = read_message(&mut client).await.unwrap().unwrap();
559        match response {
560            RunnerMessage::Response { payload } => {
561                assert_eq!(payload.status, OutcomeStatus::Success);
562                assert_eq!(payload.result, Some(json!({"ok": true})));
563            }
564            _ => panic!("expected response"),
565        }
566        drop(client);
567        let _ = server_task.await;
568    }
569
570    #[tokio::test]
571    async fn handle_connection_rejects_bad_protocol() {
572        let registry = Registry::new();
573        let (client, server) = tcp_pair().await;
574        let in_flight = Arc::new(Mutex::new(HashMap::new()));
575        let job_index = Arc::new(Mutex::new(HashMap::new()));
576        let server_task = tokio::spawn(async move {
577            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
578                .await
579                .unwrap();
580        });
581        let mut client = client;
582        let mut request = build_request("echo");
583        request.protocol_version = "0".to_string();
584        let message = RunnerMessage::Request { payload: request };
585        write_message(&mut client, &message).await.unwrap();
586        let response = read_message(&mut client).await.unwrap().unwrap();
587        match response {
588            RunnerMessage::Response { payload } => {
589                assert_eq!(payload.status, OutcomeStatus::Error);
590            }
591            _ => panic!("expected response"),
592        }
593        drop(client);
594        let _ = server_task.await;
595    }
596
597    #[tokio::test]
598    async fn handle_connection_cancels_inflight() {
599        let mut registry = Registry::new();
600        registry.register("sleep", |request| async move {
601            tokio::time::sleep(Duration::from_millis(200)).await;
602            ExecutionOutcome::success(
603                request.job_id.clone(),
604                request.request_id.clone(),
605                json!({"ok": true}),
606            )
607        });
608        let (client, server) = tcp_pair().await;
609        let in_flight = Arc::new(Mutex::new(HashMap::new()));
610        let job_index = Arc::new(Mutex::new(HashMap::new()));
611        let server_task = tokio::spawn(async move {
612            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
613                .await
614                .unwrap();
615        });
616        let mut client = client;
617        let request = ExecutionRequest {
618            protocol_version: PROTOCOL_VERSION.to_string(),
619            request_id: "req-cancel".to_string(),
620            job_id: "job-cancel".to_string(),
621            function_name: "sleep".to_string(),
622            params: std::collections::HashMap::new(),
623            context: ExecutionContext {
624                job_id: "job-cancel".to_string(),
625                attempt: 1,
626                enqueue_time: "2024-01-01T00:00:00Z".parse().unwrap(),
627                queue_name: "default".to_string(),
628                deadline: None,
629                trace_context: None,
630                worker_id: None,
631            },
632        };
633        let message = RunnerMessage::Request {
634            payload: request.clone(),
635        };
636        write_message(&mut client, &message).await.unwrap();
637        let cancel = RunnerMessage::Cancel {
638            payload: CancelRequest {
639                protocol_version: PROTOCOL_VERSION.to_string(),
640                job_id: request.job_id.clone(),
641                request_id: Some(request.request_id.clone()),
642                hard_kill: false,
643            },
644        };
645        write_message(&mut client, &cancel).await.unwrap();
646        let response = read_message(&mut client).await.unwrap().unwrap();
647        match response {
648            RunnerMessage::Response { payload } => {
649                assert_eq!(payload.status, OutcomeStatus::Error);
650                let error_type = payload
651                    .error
652                    .as_ref()
653                    .and_then(|error| error.error_type.as_deref());
654                assert_eq!(error_type, Some("cancelled"));
655            }
656            _ => panic!("expected response"),
657        }
658        drop(client);
659        let _ = server_task.await;
660    }
661
662    #[tokio::test]
663    async fn cancel_frees_connection_capacity() {
664        let mut registry = Registry::new();
665        let gate = Arc::new(tokio::sync::Semaphore::new(0));
666        let gate_for_handler = gate.clone();
667        registry.register("block", move |request| {
668            let gate = gate_for_handler.clone();
669            async move {
670                let _permit = gate.acquire().await.expect("semaphore closed");
671                ExecutionOutcome::success(
672                    request.job_id.clone(),
673                    request.request_id.clone(),
674                    json!({"ok": true}),
675                )
676            }
677        });
678        let (client, server) = tcp_pair().await;
679        let in_flight = Arc::new(Mutex::new(HashMap::new()));
680        let job_index = Arc::new(Mutex::new(HashMap::new()));
681        let server_task = tokio::spawn(async move {
682            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
683                .await
684                .unwrap();
685        });
686        let mut client = client;
687        let job_id = "job-capacity".to_string();
688        for i in 0..RESPONSE_CHANNEL_CAPACITY {
689            let mut request = build_request("block");
690            request.request_id = format!("req-{i}");
691            request.job_id = job_id.clone();
692            request.context.job_id = job_id.clone();
693            write_message(&mut client, &RunnerMessage::Request { payload: request })
694                .await
695                .unwrap();
696        }
697
698        let cancel = RunnerMessage::Cancel {
699            payload: CancelRequest {
700                protocol_version: PROTOCOL_VERSION.to_string(),
701                job_id: job_id.clone(),
702                request_id: Some("req-0".to_string()),
703                hard_kill: false,
704            },
705        };
706        write_message(&mut client, &cancel).await.unwrap();
707        let response = timeout(Duration::from_secs(1), read_message(&mut client))
708            .await
709            .unwrap()
710            .unwrap()
711            .unwrap();
712        match response {
713            RunnerMessage::Response { payload } => {
714                assert_eq!(payload.status, OutcomeStatus::Error);
715                let error_type = payload
716                    .error
717                    .as_ref()
718                    .and_then(|error| error.error_type.as_deref());
719                assert_eq!(error_type, Some("cancelled"));
720            }
721            _ => panic!("expected response"),
722        }
723
724        let mut extra_request = build_request("block");
725        extra_request.request_id = "req-extra".to_string();
726        extra_request.job_id = job_id.clone();
727        extra_request.context.job_id = job_id.clone();
728        write_message(
729            &mut client,
730            &RunnerMessage::Request {
731                payload: extra_request,
732            },
733        )
734        .await
735        .unwrap();
736
737        gate.add_permits(RESPONSE_CHANNEL_CAPACITY + 1);
738
739        let mut saw_extra = false;
740        for _ in 0..RESPONSE_CHANNEL_CAPACITY {
741            let response = timeout(Duration::from_secs(1), read_message(&mut client))
742                .await
743                .unwrap()
744                .unwrap()
745                .unwrap();
746            if let RunnerMessage::Response { payload } = response
747                && payload.request_id.as_deref() == Some("req-extra")
748            {
749                assert_eq!(payload.status, OutcomeStatus::Success);
750                saw_extra = true;
751            }
752        }
753        assert!(saw_extra, "extra request never completed");
754
755        drop(client);
756        let _ = server_task.await;
757    }
758
759    #[tokio::test]
760    async fn execute_with_deadline_times_out() {
761        let mut registry = Registry::new();
762        registry.register("echo", |request| async move {
763            ExecutionOutcome::success(
764                request.job_id.clone(),
765                request.request_id.clone(),
766                json!({"ok": true}),
767            )
768        });
769        let mut request = build_request("echo");
770        request.context.deadline = Some(
771            "2020-01-01T00:00:00Z"
772                .parse::<chrono::DateTime<Utc>>()
773                .unwrap(),
774        );
775        let outcome = execute_with_deadline(request, registry, &NoopTelemetry).await;
776        assert_eq!(outcome.status, OutcomeStatus::Timeout);
777    }
778
779    #[tokio::test]
780    async fn execute_with_deadline_succeeds_before_deadline() {
781        let mut registry = Registry::new();
782        registry.register("echo", |request| async move {
783            ExecutionOutcome::success(
784                request.job_id.clone(),
785                request.request_id.clone(),
786                json!({"ok": true}),
787            )
788        });
789        let mut request = build_request("echo");
790        request.context.deadline = Some(Utc::now() + chrono::Duration::seconds(5));
791        let outcome = execute_with_deadline(request, registry, &NoopTelemetry).await;
792        assert_eq!(outcome.status, OutcomeStatus::Success);
793    }
794
795    #[tokio::test]
796    async fn handle_connection_handles_unexpected_response_message() {
797        let registry = Registry::new();
798        let (client, server) = tcp_pair().await;
799        let in_flight = Arc::new(Mutex::new(HashMap::new()));
800        let job_index = Arc::new(Mutex::new(HashMap::new()));
801        let server_task = tokio::spawn(async move {
802            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
803                .await
804                .unwrap();
805        });
806        let mut client = client;
807        let response = RunnerMessage::Response {
808            payload: ExecutionOutcome::error("job-x", "req-x", "oops"),
809        };
810        write_message(&mut client, &response).await.unwrap();
811        let reply = read_message(&mut client).await.unwrap().unwrap();
812        match reply {
813            RunnerMessage::Response { payload } => {
814                assert_eq!(payload.status, OutcomeStatus::Error);
815                assert!(
816                    payload
817                        .error
818                        .as_ref()
819                        .unwrap()
820                        .message
821                        .contains("unexpected response")
822                );
823            }
824            _ => panic!("expected response"),
825        }
826        drop(client);
827        let _ = server_task.await;
828    }
829
830    #[tokio::test]
831    async fn handle_connection_cancels_by_job_id() {
832        let mut registry = Registry::new();
833        registry.register("sleep", |request| async move {
834            tokio::time::sleep(Duration::from_millis(200)).await;
835            ExecutionOutcome::success(
836                request.job_id.clone(),
837                request.request_id.clone(),
838                json!({"ok": true}),
839            )
840        });
841        let (client, server) = tcp_pair().await;
842        let in_flight = Arc::new(Mutex::new(HashMap::new()));
843        let job_index = Arc::new(Mutex::new(HashMap::new()));
844        let server_task = tokio::spawn(async move {
845            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
846                .await
847                .unwrap();
848        });
849        let mut client = client;
850        let request = build_request("sleep");
851        let message = RunnerMessage::Request {
852            payload: request.clone(),
853        };
854        write_message(&mut client, &message).await.unwrap();
855        let cancel = RunnerMessage::Cancel {
856            payload: CancelRequest {
857                protocol_version: PROTOCOL_VERSION.to_string(),
858                job_id: request.job_id.clone(),
859                request_id: None,
860                hard_kill: false,
861            },
862        };
863        write_message(&mut client, &cancel).await.unwrap();
864        let response = read_message(&mut client).await.unwrap().unwrap();
865        match response {
866            RunnerMessage::Response { payload } => {
867                assert_eq!(payload.status, OutcomeStatus::Error);
868                let error_type = payload
869                    .error
870                    .as_ref()
871                    .and_then(|error| error.error_type.as_deref());
872                assert_eq!(error_type, Some("cancelled"));
873            }
874            _ => panic!("expected response"),
875        }
876        drop(client);
877        let _ = server_task.await;
878    }
879
880    #[tokio::test]
881    async fn handle_cancel_by_job_id_cancels_all_requests() {
882        let mut registry = Registry::new();
883        registry.register("sleep", |request| async move {
884            tokio::time::sleep(Duration::from_millis(200)).await;
885            ExecutionOutcome::success(
886                request.job_id.clone(),
887                request.request_id.clone(),
888                json!({"ok": true}),
889            )
890        });
891        let (client, server) = tcp_pair().await;
892        let in_flight = Arc::new(Mutex::new(HashMap::new()));
893        let job_index = Arc::new(Mutex::new(HashMap::new()));
894        let server_task = tokio::spawn(async move {
895            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
896                .await
897                .unwrap();
898        });
899        let mut client = client;
900        let mut request1 = build_request("sleep");
901        request1.request_id = "req-1".to_string();
902        request1.job_id = "job-shared".to_string();
903        let mut request2 = build_request("sleep");
904        request2.request_id = "req-2".to_string();
905        request2.job_id = "job-shared".to_string();
906        write_message(&mut client, &RunnerMessage::Request { payload: request1 })
907            .await
908            .unwrap();
909        write_message(&mut client, &RunnerMessage::Request { payload: request2 })
910            .await
911            .unwrap();
912        let cancel = RunnerMessage::Cancel {
913            payload: CancelRequest {
914                protocol_version: PROTOCOL_VERSION.to_string(),
915                job_id: "job-shared".to_string(),
916                request_id: None,
917                hard_kill: false,
918            },
919        };
920        write_message(&mut client, &cancel).await.unwrap();
921
922        let mut cancelled = 0;
923        for _ in 0..2 {
924            let response = timeout(Duration::from_millis(200), read_message(&mut client))
925                .await
926                .unwrap()
927                .unwrap()
928                .unwrap();
929            match response {
930                RunnerMessage::Response { payload } => {
931                    assert_eq!(payload.status, OutcomeStatus::Error);
932                    let error_type = payload
933                        .error
934                        .as_ref()
935                        .and_then(|error| error.error_type.as_deref());
936                    assert_eq!(error_type, Some("cancelled"));
937                    cancelled += 1;
938                }
939                _ => panic!("expected response"),
940            }
941        }
942        assert_eq!(cancelled, 2);
943        drop(client);
944        let _ = server_task.await;
945    }
946
947    #[tokio::test]
948    async fn connection_teardown_clears_tracking_maps() {
949        let mut registry = Registry::new();
950        registry.register("sleep", |request| async move {
951            tokio::time::sleep(Duration::from_millis(200)).await;
952            ExecutionOutcome::success(
953                request.job_id.clone(),
954                request.request_id.clone(),
955                json!({"ok": true}),
956            )
957        });
958        let (client, server) = tcp_pair().await;
959        let in_flight = Arc::new(Mutex::new(HashMap::new()));
960        let job_index = Arc::new(Mutex::new(HashMap::new()));
961        let in_flight_for_server = in_flight.clone();
962        let job_index_for_server = job_index.clone();
963        let server_task = tokio::spawn(async move {
964            handle_connection(
965                server,
966                &registry,
967                &NoopTelemetry,
968                in_flight_for_server,
969                job_index_for_server,
970            )
971            .await
972            .unwrap();
973        });
974        let mut client = client;
975        let request = build_request("sleep");
976        let message = RunnerMessage::Request {
977            payload: request.clone(),
978        };
979        write_message(&mut client, &message).await.unwrap();
980
981        let mut inserted = false;
982        for _ in 0..20 {
983            let has_in_flight = {
984                let guard = in_flight.lock().await;
985                guard.contains_key(&request.request_id)
986            };
987            let has_job_index = {
988                let guard = job_index.lock().await;
989                guard.contains_key(&request.job_id)
990            };
991            if has_in_flight && has_job_index {
992                inserted = true;
993                break;
994            }
995            tokio::time::sleep(Duration::from_millis(10)).await;
996        }
997        assert!(inserted, "request never entered tracking maps");
998
999        drop(client);
1000        let _ = server_task.await;
1001
1002        let in_flight = in_flight.lock().await;
1003        let job_index = job_index.lock().await;
1004        assert!(in_flight.is_empty());
1005        assert!(job_index.is_empty());
1006    }
1007
1008    #[tokio::test]
1009    async fn handle_cancel_ignores_invalid_protocol() {
1010        let in_flight = Arc::new(Mutex::new(HashMap::new()));
1011        let job_index = Arc::new(Mutex::new(HashMap::new()));
1012        let (tx, _rx) = mpsc::channel(1);
1013        let handle = tokio::spawn(async {});
1014        let connection_requests = Arc::new(Mutex::new(HashSet::new()));
1015        {
1016            let mut guard = in_flight.lock().await;
1017            guard.insert(
1018                "req-1".to_string(),
1019                InFlightTask {
1020                    job_id: "job-1".to_string(),
1021                    handle,
1022                    response_tx: tx,
1023                    connection_requests,
1024                    completed: Arc::new(AtomicBool::new(false)),
1025                },
1026            );
1027        }
1028        let payload = CancelRequest {
1029            protocol_version: "0".to_string(),
1030            job_id: "job-1".to_string(),
1031            request_id: None,
1032            hard_kill: false,
1033        };
1034        handle_cancel(payload, &in_flight, &job_index).await;
1035        let guard = in_flight.lock().await;
1036        assert!(guard.contains_key("req-1"));
1037        guard.get("req-1").unwrap().handle.abort();
1038    }
1039
1040    #[tokio::test]
1041    async fn handle_cancel_skips_completed_requests() {
1042        let in_flight = Arc::new(Mutex::new(HashMap::new()));
1043        let job_index = Arc::new(Mutex::new(HashMap::new()));
1044        let (tx, mut rx) = mpsc::channel(1);
1045        let handle = tokio::spawn(async {
1046            tokio::time::sleep(Duration::from_millis(50)).await;
1047        });
1048        let connection_requests = Arc::new(Mutex::new(HashSet::new()));
1049        {
1050            let mut guard = in_flight.lock().await;
1051            guard.insert(
1052                "req-1".to_string(),
1053                InFlightTask {
1054                    job_id: "job-1".to_string(),
1055                    handle,
1056                    response_tx: tx,
1057                    connection_requests,
1058                    completed: Arc::new(AtomicBool::new(true)),
1059                },
1060            );
1061        }
1062        {
1063            let mut guard = job_index.lock().await;
1064            guard.insert("job-1".to_string(), HashSet::from(["req-1".to_string()]));
1065        }
1066        let payload = CancelRequest {
1067            protocol_version: PROTOCOL_VERSION.to_string(),
1068            job_id: "job-1".to_string(),
1069            request_id: Some("req-1".to_string()),
1070            hard_kill: false,
1071        };
1072        handle_cancel(payload, &in_flight, &job_index).await;
1073        assert!(in_flight.lock().await.contains_key("req-1"));
1074        assert!(job_index.lock().await.contains_key("job-1"));
1075        assert!(rx.try_recv().is_err());
1076        let task = in_flight.lock().await.remove("req-1").unwrap();
1077        task.handle.abort();
1078    }
1079
1080    #[tokio::test]
1081    async fn read_message_handles_empty_and_invalid_payloads() {
1082        let (mut client, mut server) = tokio::io::duplex(64);
1083        // length = 0
1084        client.write_all(&0u32.to_be_bytes()).await.unwrap();
1085        let err = read_message(&mut server).await.unwrap_err();
1086        assert!(err.to_string().contains("payload cannot be empty"));
1087
1088        // invalid json
1089        let (mut client, mut server) = tokio::io::duplex(64);
1090        let payload = b"not-json";
1091        let len = (payload.len() as u32).to_be_bytes();
1092        client.write_all(&len).await.unwrap();
1093        client.write_all(payload).await.unwrap();
1094        let err = read_message(&mut server).await.unwrap_err();
1095        assert!(err.to_string().contains("expected"));
1096
1097        // oversized payload
1098        let (mut client, mut server) = tokio::io::duplex(64);
1099        let len = ((MAX_FRAME_LEN + 1) as u32).to_be_bytes();
1100        client.write_all(&len).await.unwrap();
1101        let err = read_message(&mut server).await.unwrap_err();
1102        assert!(err.to_string().contains("payload too large"));
1103    }
1104
1105    #[tokio::test]
1106    async fn read_message_returns_none_on_eof() {
1107        let (client, mut server) = tokio::io::duplex(8);
1108        drop(client);
1109        let message = read_message(&mut server).await.unwrap();
1110        assert!(message.is_none());
1111    }
1112}