Skip to main content

rrq_runner/
runtime.rs

1use crate::registry::Registry;
2use crate::telemetry::{NoopTelemetry, Telemetry};
3use crate::types::{ExecutionError, ExecutionOutcome};
4use chrono::{DateTime, Utc};
5use rrq_protocol::{CancelRequest, OutcomeStatus, PROTOCOL_VERSION, RunnerMessage, encode_frame};
6use std::collections::{HashMap, HashSet};
7use std::net::{IpAddr, Ipv4Addr, SocketAddr};
8use std::sync::{
9    Arc,
10    atomic::{AtomicBool, Ordering},
11};
12use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
13use tokio::net::TcpListener;
14use tokio::sync::{Mutex, mpsc};
15use tokio::time::{Duration, timeout};
16
/// Environment variable name `RRQ_RUNNER_TCP_SOCKET`.
///
/// NOTE(review): presumably read by callers to locate the runner's TCP socket; it is not
/// referenced in the part of the file visible here — confirm against the caller.
pub const ENV_RUNNER_TCP_SOCKET: &str = "RRQ_RUNNER_TCP_SOCKET";
/// Largest frame payload `read_message` accepts, in bytes (16 MiB); longer frames are rejected.
const MAX_FRAME_LEN: usize = 16 * 1024 * 1024;
/// Capacity of the per-connection outcome channel, also used by `handle_connection` as the
/// maximum number of requests one connection may have in flight.
const RESPONSE_CHANNEL_CAPACITY: usize = 64;
/// How long a sender waits to queue an outcome on the response channel before giving up.
const RESPONSE_SEND_TIMEOUT: Duration = Duration::from_secs(1);
21
/// Wraps `message` in a boxed [`std::io::Error`] of kind `InvalidInput`.
fn invalid_input(message: impl Into<String>) -> Box<dyn std::error::Error> {
    let error = std::io::Error::new(std::io::ErrorKind::InvalidInput, message.into());
    Box::new(error)
}

/// Parses a runner socket specification into a loopback [`SocketAddr`].
///
/// Surrounding whitespace is ignored. Two shapes are accepted:
///
/// * `[host]:port` — split at the first `]:`.
/// * `host:port` — split at the last `:`; the host part must not be empty.
///
/// The host must be the literal `localhost` (mapped to `127.0.0.1`) or an IP address that is
/// a loopback address. The port must parse as a `u16` and be non-zero.
///
/// # Errors
/// Returns an `InvalidInput` I/O error (boxed) describing the first rule that was violated.
pub fn parse_tcp_socket(raw: &str) -> Result<SocketAddr, Box<dyn std::error::Error>> {
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        return Err(invalid_input("runner tcp_socket cannot be empty"));
    }

    let (host, port_str) = match trimmed.strip_prefix('[') {
        Some(bracketed) => bracketed
            .split_once("]:")
            .ok_or_else(|| invalid_input("runner tcp_socket must be in [host]:port format"))?,
        None => {
            let Some((host, port_str)) = trimmed.rsplit_once(':') else {
                return Err(invalid_input(
                    "runner tcp_socket must be in host:port format",
                ));
            };
            if host.is_empty() {
                return Err(invalid_input("runner tcp_socket host cannot be empty"));
            }
            (host, port_str)
        }
    };

    // The port is validated before the host, so a bad port is reported first.
    let port = match port_str.parse::<u16>() {
        Ok(0) => return Err(invalid_input("runner tcp_socket port must be > 0")),
        Ok(port) => port,
        Err(_) => {
            return Err(invalid_input(format!(
                "Invalid runner tcp_socket port: {port_str}"
            )));
        }
    };

    let ip = if host == "localhost" {
        IpAddr::V4(Ipv4Addr::LOCALHOST)
    } else {
        match host.parse::<IpAddr>() {
            Ok(parsed) if parsed.is_loopback() => parsed,
            Ok(_) => return Err(invalid_input("runner tcp_socket host must be localhost")),
            Err(_) => {
                return Err(invalid_input(format!(
                    "Invalid runner tcp_socket host: {host}"
                )));
            }
        }
    };

    Ok(SocketAddr::new(ip, port))
}
71
72pub struct RunnerRuntime {
73    runtime: tokio::runtime::Runtime,
74}
75
76impl RunnerRuntime {
77    pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
78        Ok(Self {
79            runtime: tokio::runtime::Runtime::new()?,
80        })
81    }
82
83    pub fn enter(&self) -> tokio::runtime::EnterGuard<'_> {
84        self.runtime.enter()
85    }
86
87    pub fn run_tcp(
88        &self,
89        registry: &Registry,
90        addr: SocketAddr,
91    ) -> Result<(), Box<dyn std::error::Error>> {
92        let telemetry = NoopTelemetry;
93        self.run_tcp_with(registry, addr, &telemetry)
94    }
95
96    pub fn run_tcp_with<T: Telemetry + ?Sized>(
97        &self,
98        registry: &Registry,
99        addr: SocketAddr,
100        telemetry: &T,
101    ) -> Result<(), Box<dyn std::error::Error>> {
102        run_tcp_loop(&self.runtime, registry, addr, telemetry)
103    }
104}
105
106pub fn run_tcp(registry: &Registry, addr: SocketAddr) -> Result<(), Box<dyn std::error::Error>> {
107    RunnerRuntime::new()?.run_tcp(registry, addr)
108}
109
110pub fn run_tcp_with<T: Telemetry + ?Sized>(
111    registry: &Registry,
112    addr: SocketAddr,
113    telemetry: &T,
114) -> Result<(), Box<dyn std::error::Error>> {
115    RunnerRuntime::new()?.run_tcp_with(registry, addr, telemetry)
116}
117
118fn run_tcp_loop<T: Telemetry + ?Sized>(
119    runtime: &tokio::runtime::Runtime,
120    registry: &Registry,
121    addr: SocketAddr,
122    telemetry: &T,
123) -> Result<(), Box<dyn std::error::Error>> {
124    let registry = registry.clone();
125    let in_flight: Arc<Mutex<HashMap<String, InFlightTask>>> = Arc::new(Mutex::new(HashMap::new()));
126    let job_index: Arc<Mutex<HashMap<String, HashSet<String>>>> =
127        Arc::new(Mutex::new(HashMap::new()));
128    let telemetry = telemetry.clone_box();
129    runtime.block_on(async move {
130        if !addr.ip().is_loopback() {
131            return Err(invalid_input(format!(
132                "runner tcp_socket must be loopback-only (got {addr})"
133            )));
134        }
135        let listener = TcpListener::bind(addr).await?;
136        loop {
137            let (stream, _) = listener.accept().await?;
138            let registry = registry.clone();
139            let telemetry = telemetry.clone();
140            let in_flight = in_flight.clone();
141            let job_index = job_index.clone();
142            tokio::spawn(async move {
143                if let Err(err) =
144                    handle_connection(stream, &registry, telemetry.as_ref(), in_flight, job_index)
145                        .await
146                {
147                    tracing::error!("runner connection error: {err}");
148                }
149            });
150        }
151    })
152}
153
154async fn handle_connection<S, T>(
155    stream: S,
156    registry: &Registry,
157    telemetry: &T,
158    in_flight: Arc<Mutex<HashMap<String, InFlightTask>>>,
159    job_index: Arc<Mutex<HashMap<String, HashSet<String>>>>,
160) -> Result<(), Box<dyn std::error::Error>>
161where
162    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
163    T: Telemetry + ?Sized,
164{
165    let (mut reader, mut writer) = tokio::io::split(stream);
166    let (response_tx, mut response_rx) =
167        mpsc::channel::<ExecutionOutcome>(RESPONSE_CHANNEL_CAPACITY);
168    let writer_task = tokio::spawn(async move {
169        while let Some(outcome) = response_rx.recv().await {
170            let response = RunnerMessage::Response { payload: outcome };
171            if write_message(&mut writer, &response).await.is_err() {
172                break;
173            }
174        }
175    });
176    let connection_requests: Arc<Mutex<std::collections::HashSet<String>>> =
177        Arc::new(Mutex::new(std::collections::HashSet::new()));
178
179    loop {
180        let message = match read_message(&mut reader).await? {
181            Some(message) => message,
182            None => break,
183        };
184        match message {
185            RunnerMessage::Request { payload } => {
186                if payload.protocol_version != PROTOCOL_VERSION {
187                    let outcome = ExecutionOutcome::error(
188                        payload.job_id.clone(),
189                        payload.request_id.clone(),
190                        "Unsupported protocol version",
191                    );
192                    let _ = response_tx.send(outcome).await;
193                    continue;
194                }
195
196                let request_id = payload.request_id.clone();
197                let job_id = payload.job_id.clone();
198                {
199                    let mut active = connection_requests.lock().await;
200                    if active.len() >= RESPONSE_CHANNEL_CAPACITY {
201                        let outcome = ExecutionOutcome::error(
202                            payload.job_id.clone(),
203                            payload.request_id.clone(),
204                            "Runner busy: too many in-flight requests",
205                        );
206                        drop(active);
207                        let send_result =
208                            timeout(RESPONSE_SEND_TIMEOUT, response_tx.send(outcome)).await;
209                        match send_result {
210                            Ok(Ok(())) => {}
211                            Ok(Err(_)) => {
212                                return Err("runner response channel closed".into());
213                            }
214                            Err(_) => {
215                                return Err("runner response channel stalled".into());
216                            }
217                        }
218                        continue;
219                    }
220                    active.insert(request_id.clone());
221                }
222                let response_tx = response_tx.clone();
223                let registry = registry.clone();
224                let telemetry = telemetry.clone_box();
225                let in_flight_for_task = in_flight.clone();
226                let job_index_for_task = job_index.clone();
227                let active_for_task = connection_requests.clone();
228                let request_id_for_task = request_id.clone();
229                let job_id_for_task = job_id.clone();
230                let response_tx_for_task = response_tx.clone();
231                let completed = Arc::new(AtomicBool::new(false));
232                let completed_for_task = completed.clone();
233
234                let handle = tokio::spawn(async move {
235                    let outcome =
236                        execute_with_deadline(payload, registry, telemetry.as_ref()).await;
237                    completed_for_task.store(true, Ordering::SeqCst);
238                    let send_result =
239                        timeout(RESPONSE_SEND_TIMEOUT, response_tx_for_task.send(outcome)).await;
240                    match send_result {
241                        Ok(Ok(())) => {}
242                        Ok(Err(_)) => {
243                            tracing::warn!("runner response channel closed; dropping outcome");
244                        }
245                        Err(_) => {
246                            tracing::warn!("runner response channel stalled; dropping outcome");
247                        }
248                    }
249                    {
250                        let mut in_flight = in_flight_for_task.lock().await;
251                        in_flight.remove(&request_id_for_task);
252                    }
253                    {
254                        let mut job_index = job_index_for_task.lock().await;
255                        if let Some(entries) = job_index.get_mut(&job_id_for_task) {
256                            entries.remove(&request_id_for_task);
257                            if entries.is_empty() {
258                                job_index.remove(&job_id_for_task);
259                            }
260                        }
261                    }
262                    {
263                        let mut active = active_for_task.lock().await;
264                        active.remove(&request_id_for_task);
265                    }
266                });
267
268                {
269                    let mut in_flight = in_flight.lock().await;
270                    in_flight.insert(
271                        request_id.clone(),
272                        InFlightTask {
273                            job_id: job_id.clone(),
274                            handle,
275                            response_tx: response_tx.clone(),
276                            connection_requests: connection_requests.clone(),
277                            completed,
278                        },
279                    );
280                }
281                {
282                    let mut job_index = job_index.lock().await;
283                    job_index
284                        .entry(job_id)
285                        .or_insert_with(HashSet::new)
286                        .insert(request_id);
287                }
288            }
289            RunnerMessage::Cancel { payload } => {
290                handle_cancel(payload, &in_flight, &job_index).await;
291            }
292            RunnerMessage::Response { .. } => {
293                let outcome = ExecutionOutcome {
294                    job_id: Some("unknown".to_string()),
295                    request_id: None,
296                    status: rrq_protocol::OutcomeStatus::Error,
297                    result: None,
298                    error: Some(ExecutionError {
299                        message: "unexpected response message".to_string(),
300                        error_type: None,
301                        code: None,
302                        details: None,
303                    }),
304                    retry_after_seconds: None,
305                };
306                let _ = response_tx.send(outcome).await;
307            }
308        }
309    }
310
311    let request_ids = {
312        let mut active = connection_requests.lock().await;
313        active.drain().collect::<Vec<_>>()
314    };
315    for request_id in request_ids {
316        let task = {
317            let mut in_flight = in_flight.lock().await;
318            in_flight.remove(&request_id)
319        };
320        if let Some(task) = task {
321            task.handle.abort();
322            let mut job_index = job_index.lock().await;
323            if let Some(entries) = job_index.get_mut(&task.job_id) {
324                entries.remove(&request_id);
325                if entries.is_empty() {
326                    job_index.remove(&task.job_id);
327                }
328            }
329        }
330    }
331    writer_task.abort();
332
333    Ok(())
334}
335
/// Bookkeeping for one request that is currently executing, stored in the shared `in_flight`
/// map under its request id.
struct InFlightTask {
    /// Job the request belongs to; used to clean up the matching `job_index` entry.
    job_id: String,
    /// Handle of the spawned execution task; aborted on cancel or when the connection ends.
    handle: tokio::task::JoinHandle<()>,
    /// Sender for the owning connection's response channel; `handle_cancel` uses it to
    /// deliver the "cancelled" outcome.
    response_tx: mpsc::Sender<ExecutionOutcome>,
    /// The owning connection's set of active request ids; a cancel removes the id from it.
    connection_requests: Arc<Mutex<HashSet<String>>>,
    /// Set to `true` by the execution task once the handler has produced its outcome;
    /// `handle_cancel` leaves such tasks alone.
    completed: Arc<AtomicBool>,
}
343
344async fn handle_cancel(
345    payload: CancelRequest,
346    in_flight: &Arc<Mutex<HashMap<String, InFlightTask>>>,
347    job_index: &Arc<Mutex<HashMap<String, HashSet<String>>>>,
348) {
349    if payload.protocol_version != PROTOCOL_VERSION {
350        return;
351    }
352    let request_ids = if let Some(request_id) = payload.request_id.clone() {
353        vec![request_id]
354    } else {
355        let job_index = job_index.lock().await;
356        job_index
357            .get(&payload.job_id)
358            .map(|ids| ids.iter().cloned().collect())
359            .unwrap_or_else(Vec::new)
360    };
361    if request_ids.is_empty() {
362        return;
363    }
364
365    for request_id in request_ids {
366        let task = {
367            let mut in_flight = in_flight.lock().await;
368            if let Some(task) = in_flight.get(&request_id)
369                && task.completed.load(Ordering::SeqCst)
370            {
371                None
372            } else {
373                in_flight.remove(&request_id)
374            }
375        };
376        if let Some(task) = task {
377            task.handle.abort();
378            {
379                let mut active = task.connection_requests.lock().await;
380                active.remove(&request_id);
381            }
382            let outcome = ExecutionOutcome {
383                job_id: Some(payload.job_id.clone()),
384                request_id: Some(request_id.clone()),
385                status: OutcomeStatus::Error,
386                result: None,
387                error: Some(ExecutionError {
388                    message: "Job cancelled".to_string(),
389                    error_type: Some("cancelled".to_string()),
390                    code: None,
391                    details: None,
392                }),
393                retry_after_seconds: None,
394            };
395            let send_result = timeout(RESPONSE_SEND_TIMEOUT, task.response_tx.send(outcome)).await;
396            match send_result {
397                Ok(Ok(())) => {}
398                Ok(Err(_)) => {
399                    tracing::warn!("runner response channel closed; dropping cancel outcome");
400                }
401                Err(_) => {
402                    tracing::warn!("runner response channel stalled; dropping cancel outcome");
403                }
404            }
405            let mut job_index = job_index.lock().await;
406            if let Some(entries) = job_index.get_mut(&task.job_id) {
407                entries.remove(&request_id);
408                if entries.is_empty() {
409                    job_index.remove(&task.job_id);
410                }
411            }
412        }
413    }
414}
415
416async fn execute_with_deadline<T: Telemetry + ?Sized>(
417    request: rrq_protocol::ExecutionRequest,
418    registry: Registry,
419    telemetry: &T,
420) -> ExecutionOutcome {
421    let job_id = request.job_id.clone();
422    let request_id = request.request_id.clone();
423    let deadline = request.context.deadline;
424    if let Some(deadline) = deadline {
425        let now: DateTime<Utc> = Utc::now();
426        if deadline <= now {
427            return ExecutionOutcome::timeout(
428                job_id.clone(),
429                request_id.clone(),
430                "Job deadline exceeded",
431            );
432        }
433        if let Ok(remaining) = (deadline - now).to_std() {
434            match tokio::time::timeout(remaining, registry.execute_with(request, telemetry)).await {
435                Ok(outcome) => return outcome,
436                Err(_) => {
437                    return ExecutionOutcome::timeout(
438                        job_id.clone(),
439                        request_id.clone(),
440                        "Job execution timed out",
441                    );
442                }
443            }
444        }
445        return ExecutionOutcome::timeout(job_id, request_id, "Job deadline exceeded");
446    }
447    registry.execute_with(request, telemetry).await
448}
449
450async fn read_message<R: AsyncRead + Unpin>(
451    stream: &mut R,
452) -> Result<Option<RunnerMessage>, Box<dyn std::error::Error>> {
453    let mut header = [0u8; 4];
454    match stream.read_exact(&mut header).await {
455        Ok(_) => {}
456        Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
457        Err(err) => return Err(err.into()),
458    }
459    let length = u32::from_be_bytes(header) as usize;
460    if length == 0 {
461        return Err("runner message payload cannot be empty".into());
462    }
463    if length > MAX_FRAME_LEN {
464        return Err("runner message payload too large".into());
465    }
466    let mut payload = vec![0u8; length];
467    stream.read_exact(&mut payload).await?;
468    let message = serde_json::from_slice(&payload)?;
469    Ok(Some(message))
470}
471
472async fn write_message<W: AsyncWrite + Unpin>(
473    stream: &mut W,
474    message: &RunnerMessage,
475) -> Result<(), Box<dyn std::error::Error>> {
476    let framed = encode_frame(message)?;
477    stream.write_all(&framed).await?;
478    stream.flush().await?;
479    Ok(())
480}
481
482#[cfg(test)]
483mod tests {
484    use super::*;
485    use crate::registry::Registry;
486    use crate::telemetry::NoopTelemetry;
487    use chrono::Utc;
488    use rrq_protocol::{ExecutionContext, ExecutionRequest, OutcomeStatus};
489    use serde_json::json;
490    use tokio::net::{TcpListener, TcpStream};
491    use tokio::time::{Duration, timeout};
492
493    fn build_request(function_name: &str) -> ExecutionRequest {
494        ExecutionRequest {
495            protocol_version: PROTOCOL_VERSION.to_string(),
496            request_id: "req-1".to_string(),
497            job_id: "job-1".to_string(),
498            function_name: function_name.to_string(),
499            params: std::collections::HashMap::new(),
500            context: ExecutionContext {
501                job_id: "job-1".to_string(),
502                attempt: 1,
503                enqueue_time: "2024-01-01T00:00:00Z".parse().unwrap(),
504                queue_name: "default".to_string(),
505                deadline: None,
506                trace_context: None,
507                worker_id: None,
508            },
509        }
510    }
511
512    async fn tcp_pair() -> (TcpStream, TcpStream) {
513        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
514        let addr = listener.local_addr().unwrap();
515        let client = TcpStream::connect(addr).await.unwrap();
516        let (server, _) = listener.accept().await.unwrap();
517        (client, server)
518    }
519
520    #[tokio::test]
521    async fn handle_connection_executes_request() {
522        let mut registry = Registry::new();
523        registry.register("echo", |request| async move {
524            ExecutionOutcome::success(
525                request.job_id.clone(),
526                request.request_id.clone(),
527                json!({"ok": true}),
528            )
529        });
530        let (client, server) = tcp_pair().await;
531        let in_flight = Arc::new(Mutex::new(HashMap::new()));
532        let job_index = Arc::new(Mutex::new(HashMap::new()));
533        let server_task = tokio::spawn(async move {
534            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
535                .await
536                .unwrap();
537        });
538        let mut client = client;
539        let request = build_request("echo");
540        let message = RunnerMessage::Request { payload: request };
541        write_message(&mut client, &message).await.unwrap();
542        let response = read_message(&mut client).await.unwrap().unwrap();
543        match response {
544            RunnerMessage::Response { payload } => {
545                assert_eq!(payload.status, OutcomeStatus::Success);
546                assert_eq!(payload.result, Some(json!({"ok": true})));
547            }
548            _ => panic!("expected response"),
549        }
550        drop(client);
551        let _ = server_task.await;
552    }
553
554    #[tokio::test]
555    async fn handle_connection_rejects_bad_protocol() {
556        let registry = Registry::new();
557        let (client, server) = tcp_pair().await;
558        let in_flight = Arc::new(Mutex::new(HashMap::new()));
559        let job_index = Arc::new(Mutex::new(HashMap::new()));
560        let server_task = tokio::spawn(async move {
561            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
562                .await
563                .unwrap();
564        });
565        let mut client = client;
566        let mut request = build_request("echo");
567        request.protocol_version = "0".to_string();
568        let message = RunnerMessage::Request { payload: request };
569        write_message(&mut client, &message).await.unwrap();
570        let response = read_message(&mut client).await.unwrap().unwrap();
571        match response {
572            RunnerMessage::Response { payload } => {
573                assert_eq!(payload.status, OutcomeStatus::Error);
574            }
575            _ => panic!("expected response"),
576        }
577        drop(client);
578        let _ = server_task.await;
579    }
580
581    #[tokio::test]
582    async fn handle_connection_cancels_inflight() {
583        let mut registry = Registry::new();
584        registry.register("sleep", |request| async move {
585            tokio::time::sleep(Duration::from_millis(200)).await;
586            ExecutionOutcome::success(
587                request.job_id.clone(),
588                request.request_id.clone(),
589                json!({"ok": true}),
590            )
591        });
592        let (client, server) = tcp_pair().await;
593        let in_flight = Arc::new(Mutex::new(HashMap::new()));
594        let job_index = Arc::new(Mutex::new(HashMap::new()));
595        let server_task = tokio::spawn(async move {
596            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
597                .await
598                .unwrap();
599        });
600        let mut client = client;
601        let request = ExecutionRequest {
602            protocol_version: PROTOCOL_VERSION.to_string(),
603            request_id: "req-cancel".to_string(),
604            job_id: "job-cancel".to_string(),
605            function_name: "sleep".to_string(),
606            params: std::collections::HashMap::new(),
607            context: ExecutionContext {
608                job_id: "job-cancel".to_string(),
609                attempt: 1,
610                enqueue_time: "2024-01-01T00:00:00Z".parse().unwrap(),
611                queue_name: "default".to_string(),
612                deadline: None,
613                trace_context: None,
614                worker_id: None,
615            },
616        };
617        let message = RunnerMessage::Request {
618            payload: request.clone(),
619        };
620        write_message(&mut client, &message).await.unwrap();
621        let cancel = RunnerMessage::Cancel {
622            payload: CancelRequest {
623                protocol_version: PROTOCOL_VERSION.to_string(),
624                job_id: request.job_id.clone(),
625                request_id: Some(request.request_id.clone()),
626                hard_kill: false,
627            },
628        };
629        write_message(&mut client, &cancel).await.unwrap();
630        let response = read_message(&mut client).await.unwrap().unwrap();
631        match response {
632            RunnerMessage::Response { payload } => {
633                assert_eq!(payload.status, OutcomeStatus::Error);
634                let error_type = payload
635                    .error
636                    .as_ref()
637                    .and_then(|error| error.error_type.as_deref());
638                assert_eq!(error_type, Some("cancelled"));
639            }
640            _ => panic!("expected response"),
641        }
642        drop(client);
643        let _ = server_task.await;
644    }
645
646    #[tokio::test]
647    async fn cancel_frees_connection_capacity() {
648        let mut registry = Registry::new();
649        let gate = Arc::new(tokio::sync::Semaphore::new(0));
650        let gate_for_handler = gate.clone();
651        registry.register("block", move |request| {
652            let gate = gate_for_handler.clone();
653            async move {
654                let _permit = gate.acquire().await.expect("semaphore closed");
655                ExecutionOutcome::success(
656                    request.job_id.clone(),
657                    request.request_id.clone(),
658                    json!({"ok": true}),
659                )
660            }
661        });
662        let (client, server) = tcp_pair().await;
663        let in_flight = Arc::new(Mutex::new(HashMap::new()));
664        let job_index = Arc::new(Mutex::new(HashMap::new()));
665        let server_task = tokio::spawn(async move {
666            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
667                .await
668                .unwrap();
669        });
670        let mut client = client;
671        let job_id = "job-capacity".to_string();
672        for i in 0..RESPONSE_CHANNEL_CAPACITY {
673            let mut request = build_request("block");
674            request.request_id = format!("req-{i}");
675            request.job_id = job_id.clone();
676            request.context.job_id = job_id.clone();
677            write_message(&mut client, &RunnerMessage::Request { payload: request })
678                .await
679                .unwrap();
680        }
681
682        let cancel = RunnerMessage::Cancel {
683            payload: CancelRequest {
684                protocol_version: PROTOCOL_VERSION.to_string(),
685                job_id: job_id.clone(),
686                request_id: Some("req-0".to_string()),
687                hard_kill: false,
688            },
689        };
690        write_message(&mut client, &cancel).await.unwrap();
691        let response = timeout(Duration::from_secs(1), read_message(&mut client))
692            .await
693            .unwrap()
694            .unwrap()
695            .unwrap();
696        match response {
697            RunnerMessage::Response { payload } => {
698                assert_eq!(payload.status, OutcomeStatus::Error);
699                let error_type = payload
700                    .error
701                    .as_ref()
702                    .and_then(|error| error.error_type.as_deref());
703                assert_eq!(error_type, Some("cancelled"));
704            }
705            _ => panic!("expected response"),
706        }
707
708        let mut extra_request = build_request("block");
709        extra_request.request_id = "req-extra".to_string();
710        extra_request.job_id = job_id.clone();
711        extra_request.context.job_id = job_id.clone();
712        write_message(
713            &mut client,
714            &RunnerMessage::Request {
715                payload: extra_request,
716            },
717        )
718        .await
719        .unwrap();
720
721        gate.add_permits(RESPONSE_CHANNEL_CAPACITY + 1);
722
723        let mut saw_extra = false;
724        for _ in 0..RESPONSE_CHANNEL_CAPACITY {
725            let response = timeout(Duration::from_secs(1), read_message(&mut client))
726                .await
727                .unwrap()
728                .unwrap()
729                .unwrap();
730            if let RunnerMessage::Response { payload } = response
731                && payload.request_id.as_deref() == Some("req-extra")
732            {
733                assert_eq!(payload.status, OutcomeStatus::Success);
734                saw_extra = true;
735            }
736        }
737        assert!(saw_extra, "extra request never completed");
738
739        drop(client);
740        let _ = server_task.await;
741    }
742
743    #[tokio::test]
744    async fn execute_with_deadline_times_out() {
745        let mut registry = Registry::new();
746        registry.register("echo", |request| async move {
747            ExecutionOutcome::success(
748                request.job_id.clone(),
749                request.request_id.clone(),
750                json!({"ok": true}),
751            )
752        });
753        let mut request = build_request("echo");
754        request.context.deadline = Some(
755            "2020-01-01T00:00:00Z"
756                .parse::<chrono::DateTime<Utc>>()
757                .unwrap(),
758        );
759        let outcome = execute_with_deadline(request, registry, &NoopTelemetry).await;
760        assert_eq!(outcome.status, OutcomeStatus::Timeout);
761    }
762
763    #[tokio::test]
764    async fn execute_with_deadline_succeeds_before_deadline() {
765        let mut registry = Registry::new();
766        registry.register("echo", |request| async move {
767            ExecutionOutcome::success(
768                request.job_id.clone(),
769                request.request_id.clone(),
770                json!({"ok": true}),
771            )
772        });
773        let mut request = build_request("echo");
774        request.context.deadline = Some(Utc::now() + chrono::Duration::seconds(5));
775        let outcome = execute_with_deadline(request, registry, &NoopTelemetry).await;
776        assert_eq!(outcome.status, OutcomeStatus::Success);
777    }
778
779    #[tokio::test]
780    async fn handle_connection_handles_unexpected_response_message() {
781        let registry = Registry::new();
782        let (client, server) = tcp_pair().await;
783        let in_flight = Arc::new(Mutex::new(HashMap::new()));
784        let job_index = Arc::new(Mutex::new(HashMap::new()));
785        let server_task = tokio::spawn(async move {
786            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
787                .await
788                .unwrap();
789        });
790        let mut client = client;
791        let response = RunnerMessage::Response {
792            payload: ExecutionOutcome::error("job-x", "req-x", "oops"),
793        };
794        write_message(&mut client, &response).await.unwrap();
795        let reply = read_message(&mut client).await.unwrap().unwrap();
796        match reply {
797            RunnerMessage::Response { payload } => {
798                assert_eq!(payload.status, OutcomeStatus::Error);
799                assert!(
800                    payload
801                        .error
802                        .as_ref()
803                        .unwrap()
804                        .message
805                        .contains("unexpected response")
806                );
807            }
808            _ => panic!("expected response"),
809        }
810        drop(client);
811        let _ = server_task.await;
812    }
813
814    #[tokio::test]
815    async fn handle_connection_cancels_by_job_id() {
816        let mut registry = Registry::new();
817        registry.register("sleep", |request| async move {
818            tokio::time::sleep(Duration::from_millis(200)).await;
819            ExecutionOutcome::success(
820                request.job_id.clone(),
821                request.request_id.clone(),
822                json!({"ok": true}),
823            )
824        });
825        let (client, server) = tcp_pair().await;
826        let in_flight = Arc::new(Mutex::new(HashMap::new()));
827        let job_index = Arc::new(Mutex::new(HashMap::new()));
828        let server_task = tokio::spawn(async move {
829            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
830                .await
831                .unwrap();
832        });
833        let mut client = client;
834        let request = build_request("sleep");
835        let message = RunnerMessage::Request {
836            payload: request.clone(),
837        };
838        write_message(&mut client, &message).await.unwrap();
839        let cancel = RunnerMessage::Cancel {
840            payload: CancelRequest {
841                protocol_version: PROTOCOL_VERSION.to_string(),
842                job_id: request.job_id.clone(),
843                request_id: None,
844                hard_kill: false,
845            },
846        };
847        write_message(&mut client, &cancel).await.unwrap();
848        let response = read_message(&mut client).await.unwrap().unwrap();
849        match response {
850            RunnerMessage::Response { payload } => {
851                assert_eq!(payload.status, OutcomeStatus::Error);
852                let error_type = payload
853                    .error
854                    .as_ref()
855                    .and_then(|error| error.error_type.as_deref());
856                assert_eq!(error_type, Some("cancelled"));
857            }
858            _ => panic!("expected response"),
859        }
860        drop(client);
861        let _ = server_task.await;
862    }
863
864    #[tokio::test]
865    async fn handle_cancel_by_job_id_cancels_all_requests() {
866        let mut registry = Registry::new();
867        registry.register("sleep", |request| async move {
868            tokio::time::sleep(Duration::from_millis(200)).await;
869            ExecutionOutcome::success(
870                request.job_id.clone(),
871                request.request_id.clone(),
872                json!({"ok": true}),
873            )
874        });
875        let (client, server) = tcp_pair().await;
876        let in_flight = Arc::new(Mutex::new(HashMap::new()));
877        let job_index = Arc::new(Mutex::new(HashMap::new()));
878        let server_task = tokio::spawn(async move {
879            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
880                .await
881                .unwrap();
882        });
883        let mut client = client;
884        let mut request1 = build_request("sleep");
885        request1.request_id = "req-1".to_string();
886        request1.job_id = "job-shared".to_string();
887        let mut request2 = build_request("sleep");
888        request2.request_id = "req-2".to_string();
889        request2.job_id = "job-shared".to_string();
890        write_message(&mut client, &RunnerMessage::Request { payload: request1 })
891            .await
892            .unwrap();
893        write_message(&mut client, &RunnerMessage::Request { payload: request2 })
894            .await
895            .unwrap();
896        let cancel = RunnerMessage::Cancel {
897            payload: CancelRequest {
898                protocol_version: PROTOCOL_VERSION.to_string(),
899                job_id: "job-shared".to_string(),
900                request_id: None,
901                hard_kill: false,
902            },
903        };
904        write_message(&mut client, &cancel).await.unwrap();
905
906        let mut cancelled = 0;
907        for _ in 0..2 {
908            let response = timeout(Duration::from_millis(200), read_message(&mut client))
909                .await
910                .unwrap()
911                .unwrap()
912                .unwrap();
913            match response {
914                RunnerMessage::Response { payload } => {
915                    assert_eq!(payload.status, OutcomeStatus::Error);
916                    let error_type = payload
917                        .error
918                        .as_ref()
919                        .and_then(|error| error.error_type.as_deref());
920                    assert_eq!(error_type, Some("cancelled"));
921                    cancelled += 1;
922                }
923                _ => panic!("expected response"),
924            }
925        }
926        assert_eq!(cancelled, 2);
927        drop(client);
928        let _ = server_task.await;
929    }
930
931    #[tokio::test]
932    async fn connection_teardown_clears_tracking_maps() {
933        let mut registry = Registry::new();
934        registry.register("sleep", |request| async move {
935            tokio::time::sleep(Duration::from_millis(200)).await;
936            ExecutionOutcome::success(
937                request.job_id.clone(),
938                request.request_id.clone(),
939                json!({"ok": true}),
940            )
941        });
942        let (client, server) = tcp_pair().await;
943        let in_flight = Arc::new(Mutex::new(HashMap::new()));
944        let job_index = Arc::new(Mutex::new(HashMap::new()));
945        let in_flight_for_server = in_flight.clone();
946        let job_index_for_server = job_index.clone();
947        let server_task = tokio::spawn(async move {
948            handle_connection(
949                server,
950                &registry,
951                &NoopTelemetry,
952                in_flight_for_server,
953                job_index_for_server,
954            )
955            .await
956            .unwrap();
957        });
958        let mut client = client;
959        let request = build_request("sleep");
960        let message = RunnerMessage::Request {
961            payload: request.clone(),
962        };
963        write_message(&mut client, &message).await.unwrap();
964
965        let mut inserted = false;
966        for _ in 0..20 {
967            let has_in_flight = {
968                let guard = in_flight.lock().await;
969                guard.contains_key(&request.request_id)
970            };
971            let has_job_index = {
972                let guard = job_index.lock().await;
973                guard.contains_key(&request.job_id)
974            };
975            if has_in_flight && has_job_index {
976                inserted = true;
977                break;
978            }
979            tokio::time::sleep(Duration::from_millis(10)).await;
980        }
981        assert!(inserted, "request never entered tracking maps");
982
983        drop(client);
984        let _ = server_task.await;
985
986        let in_flight = in_flight.lock().await;
987        let job_index = job_index.lock().await;
988        assert!(in_flight.is_empty());
989        assert!(job_index.is_empty());
990    }
991
992    #[tokio::test]
993    async fn handle_cancel_ignores_invalid_protocol() {
994        let in_flight = Arc::new(Mutex::new(HashMap::new()));
995        let job_index = Arc::new(Mutex::new(HashMap::new()));
996        let (tx, _rx) = mpsc::channel(1);
997        let handle = tokio::spawn(async {});
998        let connection_requests = Arc::new(Mutex::new(HashSet::new()));
999        {
1000            let mut guard = in_flight.lock().await;
1001            guard.insert(
1002                "req-1".to_string(),
1003                InFlightTask {
1004                    job_id: "job-1".to_string(),
1005                    handle,
1006                    response_tx: tx,
1007                    connection_requests,
1008                    completed: Arc::new(AtomicBool::new(false)),
1009                },
1010            );
1011        }
1012        let payload = CancelRequest {
1013            protocol_version: "0".to_string(),
1014            job_id: "job-1".to_string(),
1015            request_id: None,
1016            hard_kill: false,
1017        };
1018        handle_cancel(payload, &in_flight, &job_index).await;
1019        let guard = in_flight.lock().await;
1020        assert!(guard.contains_key("req-1"));
1021        guard.get("req-1").unwrap().handle.abort();
1022    }
1023
1024    #[tokio::test]
1025    async fn handle_cancel_skips_completed_requests() {
1026        let in_flight = Arc::new(Mutex::new(HashMap::new()));
1027        let job_index = Arc::new(Mutex::new(HashMap::new()));
1028        let (tx, mut rx) = mpsc::channel(1);
1029        let handle = tokio::spawn(async {
1030            tokio::time::sleep(Duration::from_millis(50)).await;
1031        });
1032        let connection_requests = Arc::new(Mutex::new(HashSet::new()));
1033        {
1034            let mut guard = in_flight.lock().await;
1035            guard.insert(
1036                "req-1".to_string(),
1037                InFlightTask {
1038                    job_id: "job-1".to_string(),
1039                    handle,
1040                    response_tx: tx,
1041                    connection_requests,
1042                    completed: Arc::new(AtomicBool::new(true)),
1043                },
1044            );
1045        }
1046        {
1047            let mut guard = job_index.lock().await;
1048            guard.insert("job-1".to_string(), HashSet::from(["req-1".to_string()]));
1049        }
1050        let payload = CancelRequest {
1051            protocol_version: PROTOCOL_VERSION.to_string(),
1052            job_id: "job-1".to_string(),
1053            request_id: Some("req-1".to_string()),
1054            hard_kill: false,
1055        };
1056        handle_cancel(payload, &in_flight, &job_index).await;
1057        assert!(in_flight.lock().await.contains_key("req-1"));
1058        assert!(job_index.lock().await.contains_key("job-1"));
1059        assert!(rx.try_recv().is_err());
1060        let task = in_flight.lock().await.remove("req-1").unwrap();
1061        task.handle.abort();
1062    }
1063
1064    #[tokio::test]
1065    async fn read_message_handles_empty_and_invalid_payloads() {
1066        let (mut client, mut server) = tokio::io::duplex(64);
1067        // length = 0
1068        client.write_all(&0u32.to_be_bytes()).await.unwrap();
1069        let err = read_message(&mut server).await.unwrap_err();
1070        assert!(err.to_string().contains("payload cannot be empty"));
1071
1072        // invalid json
1073        let (mut client, mut server) = tokio::io::duplex(64);
1074        let payload = b"not-json";
1075        let len = (payload.len() as u32).to_be_bytes();
1076        client.write_all(&len).await.unwrap();
1077        client.write_all(payload).await.unwrap();
1078        let err = read_message(&mut server).await.unwrap_err();
1079        assert!(err.to_string().contains("expected"));
1080
1081        // oversized payload
1082        let (mut client, mut server) = tokio::io::duplex(64);
1083        let len = ((MAX_FRAME_LEN + 1) as u32).to_be_bytes();
1084        client.write_all(&len).await.unwrap();
1085        let err = read_message(&mut server).await.unwrap_err();
1086        assert!(err.to_string().contains("payload too large"));
1087    }
1088
1089    #[tokio::test]
1090    async fn read_message_returns_none_on_eof() {
1091        let (client, mut server) = tokio::io::duplex(8);
1092        drop(client);
1093        let message = read_message(&mut server).await.unwrap();
1094        assert!(message.is_none());
1095    }
1096}