Skip to main content

rrq_runner/
runtime.rs

1use crate::parent_guard;
2use crate::registry::Registry;
3use crate::telemetry::{NoopTelemetry, Telemetry};
4use crate::types::{ExecutionError, ExecutionOutcome};
5use chrono::{DateTime, Utc};
6use rrq_protocol::{CancelRequest, OutcomeStatus, PROTOCOL_VERSION, RunnerMessage, encode_frame};
7use std::collections::{HashMap, HashSet};
8use std::net::{IpAddr, Ipv4Addr, SocketAddr};
9use std::sync::{
10    Arc,
11    atomic::{AtomicBool, Ordering},
12};
13use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
14use tokio::net::TcpListener;
15use tokio::sync::{Mutex, mpsc};
16use tokio::time::{Duration, timeout};
/// Largest JSON payload accepted in a single frame: 16 MiB.
const MAX_FRAME_LEN: usize = 16 << 20;
/// Per-connection bound used both as the response-channel capacity and as the
/// maximum number of simultaneously active requests on one connection.
const RESPONSE_CHANNEL_CAPACITY: usize = 64;
/// How long a sender waits for room in a connection's response channel.
const RESPONSE_SEND_TIMEOUT: Duration = Duration::from_millis(1_000);
20
/// Wraps `message` in a boxed `std::io::Error` of kind `InvalidInput`.
fn invalid_input(message: impl Into<String>) -> Box<dyn std::error::Error> {
    let io_err = std::io::Error::new(std::io::ErrorKind::InvalidInput, message.into());
    io_err.into()
}
27
28pub fn parse_tcp_socket(raw: &str) -> Result<SocketAddr, Box<dyn std::error::Error>> {
29    let raw = raw.trim();
30    if raw.is_empty() {
31        return Err(invalid_input("runner tcp_socket cannot be empty"));
32    }
33
34    let (host, port_str) = if let Some(rest) = raw.strip_prefix('[') {
35        let (host, port_str) = rest
36            .split_once("]:")
37            .ok_or_else(|| invalid_input("runner tcp_socket must be in [host]:port format"))?;
38        (host, port_str)
39    } else {
40        let (host, port_str) = raw
41            .rsplit_once(':')
42            .ok_or_else(|| invalid_input("runner tcp_socket must be in host:port format"))?;
43        if host.is_empty() {
44            return Err(invalid_input("runner tcp_socket host cannot be empty"));
45        }
46        (host, port_str)
47    };
48
49    let port: u16 = port_str
50        .parse()
51        .map_err(|_| invalid_input(format!("Invalid runner tcp_socket port: {port_str}")))?;
52    if port == 0 {
53        return Err(invalid_input("runner tcp_socket port must be > 0"));
54    }
55
56    let ip = if host == "localhost" {
57        IpAddr::V4(Ipv4Addr::LOCALHOST)
58    } else {
59        let parsed: IpAddr = host
60            .parse()
61            .map_err(|_| invalid_input(format!("Invalid runner tcp_socket host: {host}")))?;
62        if !parsed.is_loopback() {
63            return Err(invalid_input("runner tcp_socket host must be localhost"));
64        }
65        parsed
66    };
67
68    Ok(SocketAddr::new(ip, port))
69}
70
71pub struct RunnerRuntime {
72    runtime: tokio::runtime::Runtime,
73}
74
75impl RunnerRuntime {
76    pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
77        parent_guard::install_parent_lifecycle_guard()?;
78        Ok(Self {
79            runtime: tokio::runtime::Runtime::new()?,
80        })
81    }
82
83    pub fn enter(&self) -> tokio::runtime::EnterGuard<'_> {
84        self.runtime.enter()
85    }
86
87    pub fn run_tcp(
88        &self,
89        registry: &Registry,
90        addr: SocketAddr,
91    ) -> Result<(), Box<dyn std::error::Error>> {
92        let telemetry = NoopTelemetry;
93        self.run_tcp_with(registry, addr, &telemetry)
94    }
95
96    pub fn run_tcp_with<T: Telemetry + ?Sized>(
97        &self,
98        registry: &Registry,
99        addr: SocketAddr,
100        telemetry: &T,
101    ) -> Result<(), Box<dyn std::error::Error>> {
102        run_tcp_loop(&self.runtime, registry, addr, telemetry)
103    }
104}
105
106pub fn run_tcp(registry: &Registry, addr: SocketAddr) -> Result<(), Box<dyn std::error::Error>> {
107    RunnerRuntime::new()?.run_tcp(registry, addr)
108}
109
110pub fn run_tcp_with<T: Telemetry + ?Sized>(
111    registry: &Registry,
112    addr: SocketAddr,
113    telemetry: &T,
114) -> Result<(), Box<dyn std::error::Error>> {
115    RunnerRuntime::new()?.run_tcp_with(registry, addr, telemetry)
116}
117
118fn run_tcp_loop<T: Telemetry + ?Sized>(
119    runtime: &tokio::runtime::Runtime,
120    registry: &Registry,
121    addr: SocketAddr,
122    telemetry: &T,
123) -> Result<(), Box<dyn std::error::Error>> {
124    let registry = registry.clone();
125    let in_flight: Arc<Mutex<HashMap<String, InFlightTask>>> = Arc::new(Mutex::new(HashMap::new()));
126    let job_index: Arc<Mutex<HashMap<String, HashSet<String>>>> =
127        Arc::new(Mutex::new(HashMap::new()));
128    let telemetry = telemetry.clone_box();
129    runtime.block_on(async move {
130        if !addr.ip().is_loopback() {
131            return Err(invalid_input(format!(
132                "runner tcp_socket must be loopback-only (got {addr})"
133            )));
134        }
135        let listener = TcpListener::bind(addr).await?;
136        loop {
137            let (stream, _) = listener.accept().await?;
138            let registry = registry.clone();
139            let telemetry = telemetry.clone();
140            let in_flight = in_flight.clone();
141            let job_index = job_index.clone();
142            tokio::spawn(async move {
143                if let Err(err) =
144                    handle_connection(stream, &registry, telemetry.as_ref(), in_flight, job_index)
145                        .await
146                {
147                    tracing::error!("runner connection error: {err}");
148                }
149            });
150        }
151    })
152}
153
154async fn handle_connection<S, T>(
155    stream: S,
156    registry: &Registry,
157    telemetry: &T,
158    in_flight: Arc<Mutex<HashMap<String, InFlightTask>>>,
159    job_index: Arc<Mutex<HashMap<String, HashSet<String>>>>,
160) -> Result<(), Box<dyn std::error::Error>>
161where
162    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
163    T: Telemetry + ?Sized,
164{
165    let (mut reader, mut writer) = tokio::io::split(stream);
166    let (response_tx, mut response_rx) =
167        mpsc::channel::<ExecutionOutcome>(RESPONSE_CHANNEL_CAPACITY);
168    let writer_task = tokio::spawn(async move {
169        while let Some(outcome) = response_rx.recv().await {
170            let response = RunnerMessage::Response { payload: outcome };
171            if write_message(&mut writer, &response).await.is_err() {
172                break;
173            }
174        }
175    });
176    let connection_requests: Arc<Mutex<std::collections::HashSet<String>>> =
177        Arc::new(Mutex::new(std::collections::HashSet::new()));
178
179    loop {
180        let message = match read_message(&mut reader).await? {
181            Some(message) => message,
182            None => break,
183        };
184        match message {
185            RunnerMessage::Request { payload } => {
186                if payload.protocol_version != PROTOCOL_VERSION {
187                    let outcome = ExecutionOutcome::error(
188                        payload.job_id.clone(),
189                        payload.request_id.clone(),
190                        "Unsupported protocol version",
191                    );
192                    let _ = response_tx.send(outcome).await;
193                    continue;
194                }
195
196                let request_id = payload.request_id.clone();
197                let job_id = payload.job_id.clone();
198                {
199                    let mut active = connection_requests.lock().await;
200                    if active.len() >= RESPONSE_CHANNEL_CAPACITY {
201                        let outcome = ExecutionOutcome::error(
202                            payload.job_id.clone(),
203                            payload.request_id.clone(),
204                            "Runner busy: too many in-flight requests",
205                        );
206                        drop(active);
207                        let send_result =
208                            timeout(RESPONSE_SEND_TIMEOUT, response_tx.send(outcome)).await;
209                        match send_result {
210                            Ok(Ok(())) => {}
211                            Ok(Err(_)) => {
212                                return Err("runner response channel closed".into());
213                            }
214                            Err(_) => {
215                                return Err("runner response channel stalled".into());
216                            }
217                        }
218                        continue;
219                    }
220                    active.insert(request_id.clone());
221                    crate::telemetry::record_runner_channel_pressure(active.len());
222                }
223                let response_tx = response_tx.clone();
224                let registry = registry.clone();
225                let telemetry = telemetry.clone_box();
226                let in_flight_for_task = in_flight.clone();
227                let job_index_for_task = job_index.clone();
228                let active_for_task = connection_requests.clone();
229                let request_id_for_task = request_id.clone();
230                let job_id_for_task = job_id.clone();
231                let response_tx_for_task = response_tx.clone();
232                let completed = Arc::new(AtomicBool::new(false));
233                let completed_for_task = completed.clone();
234
235                let handle = tokio::spawn(async move {
236                    let outcome =
237                        execute_with_deadline(payload, registry, telemetry.as_ref()).await;
238                    completed_for_task.store(true, Ordering::SeqCst);
239                    let send_result =
240                        timeout(RESPONSE_SEND_TIMEOUT, response_tx_for_task.send(outcome)).await;
241                    match send_result {
242                        Ok(Ok(())) => {}
243                        Ok(Err(_)) => {
244                            tracing::warn!("runner response channel closed; dropping outcome");
245                        }
246                        Err(_) => {
247                            tracing::warn!("runner response channel stalled; dropping outcome");
248                        }
249                    }
250                    {
251                        let mut in_flight = in_flight_for_task.lock().await;
252                        if in_flight.remove(&request_id_for_task).is_some() {
253                            crate::telemetry::record_runner_inflight_delta(-1);
254                        }
255                    }
256                    {
257                        let mut job_index = job_index_for_task.lock().await;
258                        if let Some(entries) = job_index.get_mut(&job_id_for_task) {
259                            entries.remove(&request_id_for_task);
260                            if entries.is_empty() {
261                                job_index.remove(&job_id_for_task);
262                            }
263                        }
264                    }
265                    {
266                        let mut active = active_for_task.lock().await;
267                        active.remove(&request_id_for_task);
268                        crate::telemetry::record_runner_channel_pressure(active.len());
269                    }
270                });
271
272                {
273                    let mut in_flight = in_flight.lock().await;
274                    in_flight.insert(
275                        request_id.clone(),
276                        InFlightTask {
277                            job_id: job_id.clone(),
278                            handle,
279                            response_tx: response_tx.clone(),
280                            connection_requests: connection_requests.clone(),
281                            completed,
282                        },
283                    );
284                }
285                crate::telemetry::record_runner_inflight_delta(1);
286                {
287                    let mut job_index = job_index.lock().await;
288                    job_index
289                        .entry(job_id)
290                        .or_insert_with(HashSet::new)
291                        .insert(request_id);
292                }
293            }
294            RunnerMessage::Cancel { payload } => {
295                handle_cancel(payload, &in_flight, &job_index).await;
296            }
297            RunnerMessage::Response { .. } => {
298                let outcome = ExecutionOutcome {
299                    job_id: Some("unknown".to_string()),
300                    request_id: None,
301                    status: rrq_protocol::OutcomeStatus::Error,
302                    result: None,
303                    error: Some(ExecutionError {
304                        message: "unexpected response message".to_string(),
305                        error_type: None,
306                        code: None,
307                        details: None,
308                    }),
309                    retry_after_seconds: None,
310                };
311                let _ = response_tx.send(outcome).await;
312            }
313        }
314    }
315
316    let request_ids = {
317        let mut active = connection_requests.lock().await;
318        active.drain().collect::<Vec<_>>()
319    };
320    crate::telemetry::record_runner_channel_pressure(0);
321    for request_id in request_ids {
322        let task = {
323            let mut in_flight = in_flight.lock().await;
324            in_flight.remove(&request_id)
325        };
326        if let Some(task) = task {
327            task.handle.abort();
328            crate::telemetry::record_runner_inflight_delta(-1);
329            let mut job_index = job_index.lock().await;
330            if let Some(entries) = job_index.get_mut(&task.job_id) {
331                entries.remove(&request_id);
332                if entries.is_empty() {
333                    job_index.remove(&task.job_id);
334                }
335            }
336        }
337    }
338    writer_task.abort();
339
340    Ok(())
341}
342
343struct InFlightTask {
344    job_id: String,
345    handle: tokio::task::JoinHandle<()>,
346    response_tx: mpsc::Sender<ExecutionOutcome>,
347    connection_requests: Arc<Mutex<HashSet<String>>>,
348    completed: Arc<AtomicBool>,
349}
350
351async fn handle_cancel(
352    payload: CancelRequest,
353    in_flight: &Arc<Mutex<HashMap<String, InFlightTask>>>,
354    job_index: &Arc<Mutex<HashMap<String, HashSet<String>>>>,
355) {
356    if payload.protocol_version != PROTOCOL_VERSION {
357        return;
358    }
359    let request_ids = if let Some(request_id) = payload.request_id.clone() {
360        vec![request_id]
361    } else {
362        let job_index = job_index.lock().await;
363        job_index
364            .get(&payload.job_id)
365            .map(|ids| ids.iter().cloned().collect())
366            .unwrap_or_else(Vec::new)
367    };
368    if request_ids.is_empty() {
369        return;
370    }
371    let scope = if payload.request_id.is_some() {
372        "request"
373    } else {
374        "job"
375    };
376    crate::telemetry::record_cancellation(scope);
377
378    for request_id in request_ids {
379        let task = {
380            let mut in_flight = in_flight.lock().await;
381            if let Some(task) = in_flight.get(&request_id)
382                && task.completed.load(Ordering::SeqCst)
383            {
384                None
385            } else {
386                in_flight.remove(&request_id)
387            }
388        };
389        if let Some(task) = task {
390            task.handle.abort();
391            crate::telemetry::record_runner_inflight_delta(-1);
392            {
393                let mut active = task.connection_requests.lock().await;
394                active.remove(&request_id);
395                crate::telemetry::record_runner_channel_pressure(active.len());
396            }
397            let outcome = ExecutionOutcome {
398                job_id: Some(payload.job_id.clone()),
399                request_id: Some(request_id.clone()),
400                status: OutcomeStatus::Error,
401                result: None,
402                error: Some(ExecutionError {
403                    message: "Job cancelled".to_string(),
404                    error_type: Some("cancelled".to_string()),
405                    code: None,
406                    details: None,
407                }),
408                retry_after_seconds: None,
409            };
410            let send_result = timeout(RESPONSE_SEND_TIMEOUT, task.response_tx.send(outcome)).await;
411            match send_result {
412                Ok(Ok(())) => {}
413                Ok(Err(_)) => {
414                    tracing::warn!("runner response channel closed; dropping cancel outcome");
415                }
416                Err(_) => {
417                    tracing::warn!("runner response channel stalled; dropping cancel outcome");
418                }
419            }
420            let mut job_index = job_index.lock().await;
421            if let Some(entries) = job_index.get_mut(&task.job_id) {
422                entries.remove(&request_id);
423                if entries.is_empty() {
424                    job_index.remove(&task.job_id);
425                }
426            }
427        }
428    }
429}
430
431async fn execute_with_deadline<T: Telemetry + ?Sized>(
432    request: rrq_protocol::ExecutionRequest,
433    registry: Registry,
434    telemetry: &T,
435) -> ExecutionOutcome {
436    let job_id = request.job_id.clone();
437    let request_id = request.request_id.clone();
438    let deadline = request.context.deadline;
439    if let Some(deadline) = deadline {
440        let now: DateTime<Utc> = Utc::now();
441        if deadline <= now {
442            crate::telemetry::record_deadline_expired();
443            return ExecutionOutcome::timeout(
444                job_id.clone(),
445                request_id.clone(),
446                "Job deadline exceeded",
447            );
448        }
449        if let Ok(remaining) = (deadline - now).to_std() {
450            match tokio::time::timeout(remaining, registry.execute_with(request, telemetry)).await {
451                Ok(outcome) => return outcome,
452                Err(_) => {
453                    crate::telemetry::record_deadline_expired();
454                    return ExecutionOutcome::timeout(
455                        job_id.clone(),
456                        request_id.clone(),
457                        "Job execution timed out",
458                    );
459                }
460            }
461        }
462        crate::telemetry::record_deadline_expired();
463        return ExecutionOutcome::timeout(job_id, request_id, "Job deadline exceeded");
464    }
465    registry.execute_with(request, telemetry).await
466}
467
468async fn read_message<R: AsyncRead + Unpin>(
469    stream: &mut R,
470) -> Result<Option<RunnerMessage>, Box<dyn std::error::Error>> {
471    let mut header = [0u8; 4];
472    match stream.read_exact(&mut header).await {
473        Ok(_) => {}
474        Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
475        Err(err) => return Err(err.into()),
476    }
477    let length = u32::from_be_bytes(header) as usize;
478    if length == 0 {
479        return Err("runner message payload cannot be empty".into());
480    }
481    if length > MAX_FRAME_LEN {
482        return Err("runner message payload too large".into());
483    }
484    let mut payload = vec![0u8; length];
485    stream.read_exact(&mut payload).await?;
486    let message = serde_json::from_slice(&payload)?;
487    Ok(Some(message))
488}
489
490async fn write_message<W: AsyncWrite + Unpin>(
491    stream: &mut W,
492    message: &RunnerMessage,
493) -> Result<(), Box<dyn std::error::Error>> {
494    let framed = encode_frame(message)?;
495    stream.write_all(&framed).await?;
496    stream.flush().await?;
497    Ok(())
498}
499
500#[cfg(test)]
501mod tests {
502    use super::*;
503    use crate::registry::Registry;
504    use crate::telemetry::NoopTelemetry;
505    use chrono::Utc;
506    use rrq_protocol::{ExecutionContext, ExecutionRequest, OutcomeStatus};
507    use serde_json::json;
508    use tokio::net::{TcpListener, TcpStream};
509    use tokio::time::{Duration, timeout};
510
511    fn build_request(function_name: &str) -> ExecutionRequest {
512        ExecutionRequest {
513            protocol_version: PROTOCOL_VERSION.to_string(),
514            request_id: "req-1".to_string(),
515            job_id: "job-1".to_string(),
516            function_name: function_name.to_string(),
517            params: std::collections::HashMap::new(),
518            context: ExecutionContext {
519                job_id: "job-1".to_string(),
520                attempt: 1,
521                enqueue_time: "2024-01-01T00:00:00Z".parse().unwrap(),
522                queue_name: "default".to_string(),
523                deadline: None,
524                trace_context: None,
525                correlation_context: None,
526                worker_id: None,
527            },
528        }
529    }
530
531    async fn tcp_pair() -> (TcpStream, TcpStream) {
532        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
533        let addr = listener.local_addr().unwrap();
534        let client = TcpStream::connect(addr).await.unwrap();
535        let (server, _) = listener.accept().await.unwrap();
536        (client, server)
537    }
538
539    #[tokio::test]
540    async fn handle_connection_executes_request() {
541        let mut registry = Registry::new();
542        registry.register("echo", |request| async move {
543            ExecutionOutcome::success(
544                request.job_id.clone(),
545                request.request_id.clone(),
546                json!({"ok": true}),
547            )
548        });
549        let (client, server) = tcp_pair().await;
550        let in_flight = Arc::new(Mutex::new(HashMap::new()));
551        let job_index = Arc::new(Mutex::new(HashMap::new()));
552        let server_task = tokio::spawn(async move {
553            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
554                .await
555                .unwrap();
556        });
557        let mut client = client;
558        let request = build_request("echo");
559        let message = RunnerMessage::Request { payload: request };
560        write_message(&mut client, &message).await.unwrap();
561        let response = read_message(&mut client).await.unwrap().unwrap();
562        match response {
563            RunnerMessage::Response { payload } => {
564                assert_eq!(payload.status, OutcomeStatus::Success);
565                assert_eq!(payload.result, Some(json!({"ok": true})));
566            }
567            _ => panic!("expected response"),
568        }
569        drop(client);
570        let _ = server_task.await;
571    }
572
573    #[tokio::test]
574    async fn handle_connection_rejects_bad_protocol() {
575        let registry = Registry::new();
576        let (client, server) = tcp_pair().await;
577        let in_flight = Arc::new(Mutex::new(HashMap::new()));
578        let job_index = Arc::new(Mutex::new(HashMap::new()));
579        let server_task = tokio::spawn(async move {
580            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
581                .await
582                .unwrap();
583        });
584        let mut client = client;
585        let mut request = build_request("echo");
586        request.protocol_version = "0".to_string();
587        let message = RunnerMessage::Request { payload: request };
588        write_message(&mut client, &message).await.unwrap();
589        let response = read_message(&mut client).await.unwrap().unwrap();
590        match response {
591            RunnerMessage::Response { payload } => {
592                assert_eq!(payload.status, OutcomeStatus::Error);
593            }
594            _ => panic!("expected response"),
595        }
596        drop(client);
597        let _ = server_task.await;
598    }
599
600    #[tokio::test]
601    async fn handle_connection_cancels_inflight() {
602        let mut registry = Registry::new();
603        registry.register("sleep", |request| async move {
604            tokio::time::sleep(Duration::from_millis(200)).await;
605            ExecutionOutcome::success(
606                request.job_id.clone(),
607                request.request_id.clone(),
608                json!({"ok": true}),
609            )
610        });
611        let (client, server) = tcp_pair().await;
612        let in_flight = Arc::new(Mutex::new(HashMap::new()));
613        let job_index = Arc::new(Mutex::new(HashMap::new()));
614        let server_task = tokio::spawn(async move {
615            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
616                .await
617                .unwrap();
618        });
619        let mut client = client;
620        let request = ExecutionRequest {
621            protocol_version: PROTOCOL_VERSION.to_string(),
622            request_id: "req-cancel".to_string(),
623            job_id: "job-cancel".to_string(),
624            function_name: "sleep".to_string(),
625            params: std::collections::HashMap::new(),
626            context: ExecutionContext {
627                job_id: "job-cancel".to_string(),
628                attempt: 1,
629                enqueue_time: "2024-01-01T00:00:00Z".parse().unwrap(),
630                queue_name: "default".to_string(),
631                deadline: None,
632                trace_context: None,
633                correlation_context: None,
634                worker_id: None,
635            },
636        };
637        let message = RunnerMessage::Request {
638            payload: request.clone(),
639        };
640        write_message(&mut client, &message).await.unwrap();
641        let cancel = RunnerMessage::Cancel {
642            payload: CancelRequest {
643                protocol_version: PROTOCOL_VERSION.to_string(),
644                job_id: request.job_id.clone(),
645                request_id: Some(request.request_id.clone()),
646                hard_kill: false,
647            },
648        };
649        write_message(&mut client, &cancel).await.unwrap();
650        let response = read_message(&mut client).await.unwrap().unwrap();
651        match response {
652            RunnerMessage::Response { payload } => {
653                assert_eq!(payload.status, OutcomeStatus::Error);
654                let error_type = payload
655                    .error
656                    .as_ref()
657                    .and_then(|error| error.error_type.as_deref());
658                assert_eq!(error_type, Some("cancelled"));
659            }
660            _ => panic!("expected response"),
661        }
662        drop(client);
663        let _ = server_task.await;
664    }
665
666    #[tokio::test]
667    async fn cancel_frees_connection_capacity() {
668        let mut registry = Registry::new();
669        let gate = Arc::new(tokio::sync::Semaphore::new(0));
670        let gate_for_handler = gate.clone();
671        registry.register("block", move |request| {
672            let gate = gate_for_handler.clone();
673            async move {
674                let _permit = gate.acquire().await.expect("semaphore closed");
675                ExecutionOutcome::success(
676                    request.job_id.clone(),
677                    request.request_id.clone(),
678                    json!({"ok": true}),
679                )
680            }
681        });
682        let (client, server) = tcp_pair().await;
683        let in_flight = Arc::new(Mutex::new(HashMap::new()));
684        let job_index = Arc::new(Mutex::new(HashMap::new()));
685        let server_task = tokio::spawn(async move {
686            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
687                .await
688                .unwrap();
689        });
690        let mut client = client;
691        let job_id = "job-capacity".to_string();
692        for i in 0..RESPONSE_CHANNEL_CAPACITY {
693            let mut request = build_request("block");
694            request.request_id = format!("req-{i}");
695            request.job_id = job_id.clone();
696            request.context.job_id = job_id.clone();
697            write_message(&mut client, &RunnerMessage::Request { payload: request })
698                .await
699                .unwrap();
700        }
701
702        let cancel = RunnerMessage::Cancel {
703            payload: CancelRequest {
704                protocol_version: PROTOCOL_VERSION.to_string(),
705                job_id: job_id.clone(),
706                request_id: Some("req-0".to_string()),
707                hard_kill: false,
708            },
709        };
710        write_message(&mut client, &cancel).await.unwrap();
711        let response = timeout(Duration::from_secs(1), read_message(&mut client))
712            .await
713            .unwrap()
714            .unwrap()
715            .unwrap();
716        match response {
717            RunnerMessage::Response { payload } => {
718                assert_eq!(payload.status, OutcomeStatus::Error);
719                let error_type = payload
720                    .error
721                    .as_ref()
722                    .and_then(|error| error.error_type.as_deref());
723                assert_eq!(error_type, Some("cancelled"));
724            }
725            _ => panic!("expected response"),
726        }
727
728        let mut extra_request = build_request("block");
729        extra_request.request_id = "req-extra".to_string();
730        extra_request.job_id = job_id.clone();
731        extra_request.context.job_id = job_id.clone();
732        write_message(
733            &mut client,
734            &RunnerMessage::Request {
735                payload: extra_request,
736            },
737        )
738        .await
739        .unwrap();
740
741        gate.add_permits(RESPONSE_CHANNEL_CAPACITY + 1);
742
743        let mut saw_extra = false;
744        for _ in 0..RESPONSE_CHANNEL_CAPACITY {
745            let response = timeout(Duration::from_secs(1), read_message(&mut client))
746                .await
747                .unwrap()
748                .unwrap()
749                .unwrap();
750            if let RunnerMessage::Response { payload } = response
751                && payload.request_id.as_deref() == Some("req-extra")
752            {
753                assert_eq!(payload.status, OutcomeStatus::Success);
754                saw_extra = true;
755            }
756        }
757        assert!(saw_extra, "extra request never completed");
758
759        drop(client);
760        let _ = server_task.await;
761    }
762
763    #[tokio::test]
764    async fn execute_with_deadline_times_out() {
765        let mut registry = Registry::new();
766        registry.register("echo", |request| async move {
767            ExecutionOutcome::success(
768                request.job_id.clone(),
769                request.request_id.clone(),
770                json!({"ok": true}),
771            )
772        });
773        let mut request = build_request("echo");
774        request.context.deadline = Some(
775            "2020-01-01T00:00:00Z"
776                .parse::<chrono::DateTime<Utc>>()
777                .unwrap(),
778        );
779        let outcome = execute_with_deadline(request, registry, &NoopTelemetry).await;
780        assert_eq!(outcome.status, OutcomeStatus::Timeout);
781    }
782
783    #[tokio::test]
784    async fn execute_with_deadline_succeeds_before_deadline() {
785        let mut registry = Registry::new();
786        registry.register("echo", |request| async move {
787            ExecutionOutcome::success(
788                request.job_id.clone(),
789                request.request_id.clone(),
790                json!({"ok": true}),
791            )
792        });
793        let mut request = build_request("echo");
794        request.context.deadline = Some(Utc::now() + chrono::Duration::seconds(5));
795        let outcome = execute_with_deadline(request, registry, &NoopTelemetry).await;
796        assert_eq!(outcome.status, OutcomeStatus::Success);
797    }
798
799    #[tokio::test]
800    async fn handle_connection_handles_unexpected_response_message() {
801        let registry = Registry::new();
802        let (client, server) = tcp_pair().await;
803        let in_flight = Arc::new(Mutex::new(HashMap::new()));
804        let job_index = Arc::new(Mutex::new(HashMap::new()));
805        let server_task = tokio::spawn(async move {
806            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
807                .await
808                .unwrap();
809        });
810        let mut client = client;
811        let response = RunnerMessage::Response {
812            payload: ExecutionOutcome::error("job-x", "req-x", "oops"),
813        };
814        write_message(&mut client, &response).await.unwrap();
815        let reply = read_message(&mut client).await.unwrap().unwrap();
816        match reply {
817            RunnerMessage::Response { payload } => {
818                assert_eq!(payload.status, OutcomeStatus::Error);
819                assert!(
820                    payload
821                        .error
822                        .as_ref()
823                        .unwrap()
824                        .message
825                        .contains("unexpected response")
826                );
827            }
828            _ => panic!("expected response"),
829        }
830        drop(client);
831        let _ = server_task.await;
832    }
833
834    #[tokio::test]
835    async fn handle_connection_cancels_by_job_id() {
836        let mut registry = Registry::new();
837        registry.register("sleep", |request| async move {
838            tokio::time::sleep(Duration::from_millis(200)).await;
839            ExecutionOutcome::success(
840                request.job_id.clone(),
841                request.request_id.clone(),
842                json!({"ok": true}),
843            )
844        });
845        let (client, server) = tcp_pair().await;
846        let in_flight = Arc::new(Mutex::new(HashMap::new()));
847        let job_index = Arc::new(Mutex::new(HashMap::new()));
848        let server_task = tokio::spawn(async move {
849            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
850                .await
851                .unwrap();
852        });
853        let mut client = client;
854        let request = build_request("sleep");
855        let message = RunnerMessage::Request {
856            payload: request.clone(),
857        };
858        write_message(&mut client, &message).await.unwrap();
859        let cancel = RunnerMessage::Cancel {
860            payload: CancelRequest {
861                protocol_version: PROTOCOL_VERSION.to_string(),
862                job_id: request.job_id.clone(),
863                request_id: None,
864                hard_kill: false,
865            },
866        };
867        write_message(&mut client, &cancel).await.unwrap();
868        let response = read_message(&mut client).await.unwrap().unwrap();
869        match response {
870            RunnerMessage::Response { payload } => {
871                assert_eq!(payload.status, OutcomeStatus::Error);
872                let error_type = payload
873                    .error
874                    .as_ref()
875                    .and_then(|error| error.error_type.as_deref());
876                assert_eq!(error_type, Some("cancelled"));
877            }
878            _ => panic!("expected response"),
879        }
880        drop(client);
881        let _ = server_task.await;
882    }
883
884    #[tokio::test]
885    async fn handle_cancel_by_job_id_cancels_all_requests() {
886        let mut registry = Registry::new();
887        registry.register("sleep", |request| async move {
888            tokio::time::sleep(Duration::from_millis(200)).await;
889            ExecutionOutcome::success(
890                request.job_id.clone(),
891                request.request_id.clone(),
892                json!({"ok": true}),
893            )
894        });
895        let (client, server) = tcp_pair().await;
896        let in_flight = Arc::new(Mutex::new(HashMap::new()));
897        let job_index = Arc::new(Mutex::new(HashMap::new()));
898        let server_task = tokio::spawn(async move {
899            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
900                .await
901                .unwrap();
902        });
903        let mut client = client;
904        let mut request1 = build_request("sleep");
905        request1.request_id = "req-1".to_string();
906        request1.job_id = "job-shared".to_string();
907        let mut request2 = build_request("sleep");
908        request2.request_id = "req-2".to_string();
909        request2.job_id = "job-shared".to_string();
910        write_message(&mut client, &RunnerMessage::Request { payload: request1 })
911            .await
912            .unwrap();
913        write_message(&mut client, &RunnerMessage::Request { payload: request2 })
914            .await
915            .unwrap();
916        let cancel = RunnerMessage::Cancel {
917            payload: CancelRequest {
918                protocol_version: PROTOCOL_VERSION.to_string(),
919                job_id: "job-shared".to_string(),
920                request_id: None,
921                hard_kill: false,
922            },
923        };
924        write_message(&mut client, &cancel).await.unwrap();
925
926        let mut cancelled = 0;
927        for _ in 0..2 {
928            let response = timeout(Duration::from_millis(200), read_message(&mut client))
929                .await
930                .unwrap()
931                .unwrap()
932                .unwrap();
933            match response {
934                RunnerMessage::Response { payload } => {
935                    assert_eq!(payload.status, OutcomeStatus::Error);
936                    let error_type = payload
937                        .error
938                        .as_ref()
939                        .and_then(|error| error.error_type.as_deref());
940                    assert_eq!(error_type, Some("cancelled"));
941                    cancelled += 1;
942                }
943                _ => panic!("expected response"),
944            }
945        }
946        assert_eq!(cancelled, 2);
947        drop(client);
948        let _ = server_task.await;
949    }
950
951    #[tokio::test]
952    async fn connection_teardown_clears_tracking_maps() {
953        let mut registry = Registry::new();
954        registry.register("sleep", |request| async move {
955            tokio::time::sleep(Duration::from_millis(200)).await;
956            ExecutionOutcome::success(
957                request.job_id.clone(),
958                request.request_id.clone(),
959                json!({"ok": true}),
960            )
961        });
962        let (client, server) = tcp_pair().await;
963        let in_flight = Arc::new(Mutex::new(HashMap::new()));
964        let job_index = Arc::new(Mutex::new(HashMap::new()));
965        let in_flight_for_server = in_flight.clone();
966        let job_index_for_server = job_index.clone();
967        let server_task = tokio::spawn(async move {
968            handle_connection(
969                server,
970                &registry,
971                &NoopTelemetry,
972                in_flight_for_server,
973                job_index_for_server,
974            )
975            .await
976            .unwrap();
977        });
978        let mut client = client;
979        let request = build_request("sleep");
980        let message = RunnerMessage::Request {
981            payload: request.clone(),
982        };
983        write_message(&mut client, &message).await.unwrap();
984
985        let mut inserted = false;
986        for _ in 0..20 {
987            let has_in_flight = {
988                let guard = in_flight.lock().await;
989                guard.contains_key(&request.request_id)
990            };
991            let has_job_index = {
992                let guard = job_index.lock().await;
993                guard.contains_key(&request.job_id)
994            };
995            if has_in_flight && has_job_index {
996                inserted = true;
997                break;
998            }
999            tokio::time::sleep(Duration::from_millis(10)).await;
1000        }
1001        assert!(inserted, "request never entered tracking maps");
1002
1003        drop(client);
1004        let _ = server_task.await;
1005
1006        let in_flight = in_flight.lock().await;
1007        let job_index = job_index.lock().await;
1008        assert!(in_flight.is_empty());
1009        assert!(job_index.is_empty());
1010    }
1011
1012    #[tokio::test]
1013    async fn handle_cancel_ignores_invalid_protocol() {
1014        let in_flight = Arc::new(Mutex::new(HashMap::new()));
1015        let job_index = Arc::new(Mutex::new(HashMap::new()));
1016        let (tx, _rx) = mpsc::channel(1);
1017        let handle = tokio::spawn(async {});
1018        let connection_requests = Arc::new(Mutex::new(HashSet::new()));
1019        {
1020            let mut guard = in_flight.lock().await;
1021            guard.insert(
1022                "req-1".to_string(),
1023                InFlightTask {
1024                    job_id: "job-1".to_string(),
1025                    handle,
1026                    response_tx: tx,
1027                    connection_requests,
1028                    completed: Arc::new(AtomicBool::new(false)),
1029                },
1030            );
1031        }
1032        let payload = CancelRequest {
1033            protocol_version: "0".to_string(),
1034            job_id: "job-1".to_string(),
1035            request_id: None,
1036            hard_kill: false,
1037        };
1038        handle_cancel(payload, &in_flight, &job_index).await;
1039        let guard = in_flight.lock().await;
1040        assert!(guard.contains_key("req-1"));
1041        guard.get("req-1").unwrap().handle.abort();
1042    }
1043
1044    #[tokio::test]
1045    async fn handle_cancel_skips_completed_requests() {
1046        let in_flight = Arc::new(Mutex::new(HashMap::new()));
1047        let job_index = Arc::new(Mutex::new(HashMap::new()));
1048        let (tx, mut rx) = mpsc::channel(1);
1049        let handle = tokio::spawn(async {
1050            tokio::time::sleep(Duration::from_millis(50)).await;
1051        });
1052        let connection_requests = Arc::new(Mutex::new(HashSet::new()));
1053        {
1054            let mut guard = in_flight.lock().await;
1055            guard.insert(
1056                "req-1".to_string(),
1057                InFlightTask {
1058                    job_id: "job-1".to_string(),
1059                    handle,
1060                    response_tx: tx,
1061                    connection_requests,
1062                    completed: Arc::new(AtomicBool::new(true)),
1063                },
1064            );
1065        }
1066        {
1067            let mut guard = job_index.lock().await;
1068            guard.insert("job-1".to_string(), HashSet::from(["req-1".to_string()]));
1069        }
1070        let payload = CancelRequest {
1071            protocol_version: PROTOCOL_VERSION.to_string(),
1072            job_id: "job-1".to_string(),
1073            request_id: Some("req-1".to_string()),
1074            hard_kill: false,
1075        };
1076        handle_cancel(payload, &in_flight, &job_index).await;
1077        assert!(in_flight.lock().await.contains_key("req-1"));
1078        assert!(job_index.lock().await.contains_key("job-1"));
1079        assert!(rx.try_recv().is_err());
1080        let task = in_flight.lock().await.remove("req-1").unwrap();
1081        task.handle.abort();
1082    }
1083
1084    #[tokio::test]
1085    async fn read_message_handles_empty_and_invalid_payloads() {
1086        let (mut client, mut server) = tokio::io::duplex(64);
1087        // length = 0
1088        client.write_all(&0u32.to_be_bytes()).await.unwrap();
1089        let err = read_message(&mut server).await.unwrap_err();
1090        assert!(err.to_string().contains("payload cannot be empty"));
1091
1092        // invalid json
1093        let (mut client, mut server) = tokio::io::duplex(64);
1094        let payload = b"not-json";
1095        let len = (payload.len() as u32).to_be_bytes();
1096        client.write_all(&len).await.unwrap();
1097        client.write_all(payload).await.unwrap();
1098        let err = read_message(&mut server).await.unwrap_err();
1099        assert!(err.to_string().contains("expected"));
1100
1101        // oversized payload
1102        let (mut client, mut server) = tokio::io::duplex(64);
1103        let len = ((MAX_FRAME_LEN + 1) as u32).to_be_bytes();
1104        client.write_all(&len).await.unwrap();
1105        let err = read_message(&mut server).await.unwrap_err();
1106        assert!(err.to_string().contains("payload too large"));
1107    }
1108
1109    #[tokio::test]
1110    async fn read_message_returns_none_on_eof() {
1111        let (client, mut server) = tokio::io::duplex(8);
1112        drop(client);
1113        let message = read_message(&mut server).await.unwrap();
1114        assert!(message.is_none());
1115    }
1116}