Skip to main content

rrq_runner/
runtime.rs

1use crate::registry::Registry;
2use crate::telemetry::{NoopTelemetry, Telemetry};
3use crate::types::{ExecutionError, ExecutionOutcome};
4use chrono::{DateTime, Utc};
5use rrq_protocol::{CancelRequest, OutcomeStatus, PROTOCOL_VERSION, RunnerMessage, encode_frame};
6use std::collections::{HashMap, HashSet};
7use std::net::{IpAddr, Ipv4Addr, SocketAddr};
8use std::sync::{
9    Arc,
10    atomic::{AtomicBool, Ordering},
11};
12use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
13use tokio::net::TcpListener;
14use tokio::sync::{Mutex, mpsc};
15use tokio::time::{Duration, timeout};
/// Largest frame payload, in bytes, that `read_message` will accept (16 MiB).
const MAX_FRAME_LEN: usize = 16 << 20;
/// Capacity of the per-connection response channel; also used as the limit on
/// requests a single connection may have active at once.
const RESPONSE_CHANNEL_CAPACITY: usize = 64;
/// Longest time a sender waits for room in the response channel before giving up.
const RESPONSE_SEND_TIMEOUT: Duration = Duration::from_millis(1_000);
19
/// Wraps `message` in a boxed [`std::io::Error`] of kind `InvalidInput`.
fn invalid_input(message: impl Into<String>) -> Box<dyn std::error::Error> {
    let text: String = message.into();
    std::io::Error::new(std::io::ErrorKind::InvalidInput, text).into()
}
26
27pub fn parse_tcp_socket(raw: &str) -> Result<SocketAddr, Box<dyn std::error::Error>> {
28    let raw = raw.trim();
29    if raw.is_empty() {
30        return Err(invalid_input("runner tcp_socket cannot be empty"));
31    }
32
33    let (host, port_str) = if let Some(rest) = raw.strip_prefix('[') {
34        let (host, port_str) = rest
35            .split_once("]:")
36            .ok_or_else(|| invalid_input("runner tcp_socket must be in [host]:port format"))?;
37        (host, port_str)
38    } else {
39        let (host, port_str) = raw
40            .rsplit_once(':')
41            .ok_or_else(|| invalid_input("runner tcp_socket must be in host:port format"))?;
42        if host.is_empty() {
43            return Err(invalid_input("runner tcp_socket host cannot be empty"));
44        }
45        (host, port_str)
46    };
47
48    let port: u16 = port_str
49        .parse()
50        .map_err(|_| invalid_input(format!("Invalid runner tcp_socket port: {port_str}")))?;
51    if port == 0 {
52        return Err(invalid_input("runner tcp_socket port must be > 0"));
53    }
54
55    let ip = if host == "localhost" {
56        IpAddr::V4(Ipv4Addr::LOCALHOST)
57    } else {
58        let parsed: IpAddr = host
59            .parse()
60            .map_err(|_| invalid_input(format!("Invalid runner tcp_socket host: {host}")))?;
61        if !parsed.is_loopback() {
62            return Err(invalid_input("runner tcp_socket host must be localhost"));
63        }
64        parsed
65    };
66
67    Ok(SocketAddr::new(ip, port))
68}
69
70pub struct RunnerRuntime {
71    runtime: tokio::runtime::Runtime,
72}
73
74impl RunnerRuntime {
75    pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
76        Ok(Self {
77            runtime: tokio::runtime::Runtime::new()?,
78        })
79    }
80
81    pub fn enter(&self) -> tokio::runtime::EnterGuard<'_> {
82        self.runtime.enter()
83    }
84
85    pub fn run_tcp(
86        &self,
87        registry: &Registry,
88        addr: SocketAddr,
89    ) -> Result<(), Box<dyn std::error::Error>> {
90        let telemetry = NoopTelemetry;
91        self.run_tcp_with(registry, addr, &telemetry)
92    }
93
94    pub fn run_tcp_with<T: Telemetry + ?Sized>(
95        &self,
96        registry: &Registry,
97        addr: SocketAddr,
98        telemetry: &T,
99    ) -> Result<(), Box<dyn std::error::Error>> {
100        run_tcp_loop(&self.runtime, registry, addr, telemetry)
101    }
102}
103
104pub fn run_tcp(registry: &Registry, addr: SocketAddr) -> Result<(), Box<dyn std::error::Error>> {
105    RunnerRuntime::new()?.run_tcp(registry, addr)
106}
107
108pub fn run_tcp_with<T: Telemetry + ?Sized>(
109    registry: &Registry,
110    addr: SocketAddr,
111    telemetry: &T,
112) -> Result<(), Box<dyn std::error::Error>> {
113    RunnerRuntime::new()?.run_tcp_with(registry, addr, telemetry)
114}
115
116fn run_tcp_loop<T: Telemetry + ?Sized>(
117    runtime: &tokio::runtime::Runtime,
118    registry: &Registry,
119    addr: SocketAddr,
120    telemetry: &T,
121) -> Result<(), Box<dyn std::error::Error>> {
122    let registry = registry.clone();
123    let in_flight: Arc<Mutex<HashMap<String, InFlightTask>>> = Arc::new(Mutex::new(HashMap::new()));
124    let job_index: Arc<Mutex<HashMap<String, HashSet<String>>>> =
125        Arc::new(Mutex::new(HashMap::new()));
126    let telemetry = telemetry.clone_box();
127    runtime.block_on(async move {
128        if !addr.ip().is_loopback() {
129            return Err(invalid_input(format!(
130                "runner tcp_socket must be loopback-only (got {addr})"
131            )));
132        }
133        let listener = TcpListener::bind(addr).await?;
134        loop {
135            let (stream, _) = listener.accept().await?;
136            let registry = registry.clone();
137            let telemetry = telemetry.clone();
138            let in_flight = in_flight.clone();
139            let job_index = job_index.clone();
140            tokio::spawn(async move {
141                if let Err(err) =
142                    handle_connection(stream, &registry, telemetry.as_ref(), in_flight, job_index)
143                        .await
144                {
145                    tracing::error!("runner connection error: {err}");
146                }
147            });
148        }
149    })
150}
151
152async fn handle_connection<S, T>(
153    stream: S,
154    registry: &Registry,
155    telemetry: &T,
156    in_flight: Arc<Mutex<HashMap<String, InFlightTask>>>,
157    job_index: Arc<Mutex<HashMap<String, HashSet<String>>>>,
158) -> Result<(), Box<dyn std::error::Error>>
159where
160    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
161    T: Telemetry + ?Sized,
162{
163    let (mut reader, mut writer) = tokio::io::split(stream);
164    let (response_tx, mut response_rx) =
165        mpsc::channel::<ExecutionOutcome>(RESPONSE_CHANNEL_CAPACITY);
166    let writer_task = tokio::spawn(async move {
167        while let Some(outcome) = response_rx.recv().await {
168            let response = RunnerMessage::Response { payload: outcome };
169            if write_message(&mut writer, &response).await.is_err() {
170                break;
171            }
172        }
173    });
174    let connection_requests: Arc<Mutex<std::collections::HashSet<String>>> =
175        Arc::new(Mutex::new(std::collections::HashSet::new()));
176
177    loop {
178        let message = match read_message(&mut reader).await? {
179            Some(message) => message,
180            None => break,
181        };
182        match message {
183            RunnerMessage::Request { payload } => {
184                if payload.protocol_version != PROTOCOL_VERSION {
185                    let outcome = ExecutionOutcome::error(
186                        payload.job_id.clone(),
187                        payload.request_id.clone(),
188                        "Unsupported protocol version",
189                    );
190                    let _ = response_tx.send(outcome).await;
191                    continue;
192                }
193
194                let request_id = payload.request_id.clone();
195                let job_id = payload.job_id.clone();
196                {
197                    let mut active = connection_requests.lock().await;
198                    if active.len() >= RESPONSE_CHANNEL_CAPACITY {
199                        let outcome = ExecutionOutcome::error(
200                            payload.job_id.clone(),
201                            payload.request_id.clone(),
202                            "Runner busy: too many in-flight requests",
203                        );
204                        drop(active);
205                        let send_result =
206                            timeout(RESPONSE_SEND_TIMEOUT, response_tx.send(outcome)).await;
207                        match send_result {
208                            Ok(Ok(())) => {}
209                            Ok(Err(_)) => {
210                                return Err("runner response channel closed".into());
211                            }
212                            Err(_) => {
213                                return Err("runner response channel stalled".into());
214                            }
215                        }
216                        continue;
217                    }
218                    active.insert(request_id.clone());
219                    crate::telemetry::record_runner_channel_pressure(active.len());
220                }
221                let response_tx = response_tx.clone();
222                let registry = registry.clone();
223                let telemetry = telemetry.clone_box();
224                let in_flight_for_task = in_flight.clone();
225                let job_index_for_task = job_index.clone();
226                let active_for_task = connection_requests.clone();
227                let request_id_for_task = request_id.clone();
228                let job_id_for_task = job_id.clone();
229                let response_tx_for_task = response_tx.clone();
230                let completed = Arc::new(AtomicBool::new(false));
231                let completed_for_task = completed.clone();
232
233                let handle = tokio::spawn(async move {
234                    let outcome =
235                        execute_with_deadline(payload, registry, telemetry.as_ref()).await;
236                    completed_for_task.store(true, Ordering::SeqCst);
237                    let send_result =
238                        timeout(RESPONSE_SEND_TIMEOUT, response_tx_for_task.send(outcome)).await;
239                    match send_result {
240                        Ok(Ok(())) => {}
241                        Ok(Err(_)) => {
242                            tracing::warn!("runner response channel closed; dropping outcome");
243                        }
244                        Err(_) => {
245                            tracing::warn!("runner response channel stalled; dropping outcome");
246                        }
247                    }
248                    {
249                        let mut in_flight = in_flight_for_task.lock().await;
250                        if in_flight.remove(&request_id_for_task).is_some() {
251                            crate::telemetry::record_runner_inflight_delta(-1);
252                        }
253                    }
254                    {
255                        let mut job_index = job_index_for_task.lock().await;
256                        if let Some(entries) = job_index.get_mut(&job_id_for_task) {
257                            entries.remove(&request_id_for_task);
258                            if entries.is_empty() {
259                                job_index.remove(&job_id_for_task);
260                            }
261                        }
262                    }
263                    {
264                        let mut active = active_for_task.lock().await;
265                        active.remove(&request_id_for_task);
266                        crate::telemetry::record_runner_channel_pressure(active.len());
267                    }
268                });
269
270                {
271                    let mut in_flight = in_flight.lock().await;
272                    in_flight.insert(
273                        request_id.clone(),
274                        InFlightTask {
275                            job_id: job_id.clone(),
276                            handle,
277                            response_tx: response_tx.clone(),
278                            connection_requests: connection_requests.clone(),
279                            completed,
280                        },
281                    );
282                }
283                crate::telemetry::record_runner_inflight_delta(1);
284                {
285                    let mut job_index = job_index.lock().await;
286                    job_index
287                        .entry(job_id)
288                        .or_insert_with(HashSet::new)
289                        .insert(request_id);
290                }
291            }
292            RunnerMessage::Cancel { payload } => {
293                handle_cancel(payload, &in_flight, &job_index).await;
294            }
295            RunnerMessage::Response { .. } => {
296                let outcome = ExecutionOutcome {
297                    job_id: Some("unknown".to_string()),
298                    request_id: None,
299                    status: rrq_protocol::OutcomeStatus::Error,
300                    result: None,
301                    error: Some(ExecutionError {
302                        message: "unexpected response message".to_string(),
303                        error_type: None,
304                        code: None,
305                        details: None,
306                    }),
307                    retry_after_seconds: None,
308                };
309                let _ = response_tx.send(outcome).await;
310            }
311        }
312    }
313
314    let request_ids = {
315        let mut active = connection_requests.lock().await;
316        active.drain().collect::<Vec<_>>()
317    };
318    crate::telemetry::record_runner_channel_pressure(0);
319    for request_id in request_ids {
320        let task = {
321            let mut in_flight = in_flight.lock().await;
322            in_flight.remove(&request_id)
323        };
324        if let Some(task) = task {
325            task.handle.abort();
326            crate::telemetry::record_runner_inflight_delta(-1);
327            let mut job_index = job_index.lock().await;
328            if let Some(entries) = job_index.get_mut(&task.job_id) {
329                entries.remove(&request_id);
330                if entries.is_empty() {
331                    job_index.remove(&task.job_id);
332                }
333            }
334        }
335    }
336    writer_task.abort();
337
338    Ok(())
339}
340
341struct InFlightTask {
342    job_id: String,
343    handle: tokio::task::JoinHandle<()>,
344    response_tx: mpsc::Sender<ExecutionOutcome>,
345    connection_requests: Arc<Mutex<HashSet<String>>>,
346    completed: Arc<AtomicBool>,
347}
348
349async fn handle_cancel(
350    payload: CancelRequest,
351    in_flight: &Arc<Mutex<HashMap<String, InFlightTask>>>,
352    job_index: &Arc<Mutex<HashMap<String, HashSet<String>>>>,
353) {
354    if payload.protocol_version != PROTOCOL_VERSION {
355        return;
356    }
357    let request_ids = if let Some(request_id) = payload.request_id.clone() {
358        vec![request_id]
359    } else {
360        let job_index = job_index.lock().await;
361        job_index
362            .get(&payload.job_id)
363            .map(|ids| ids.iter().cloned().collect())
364            .unwrap_or_else(Vec::new)
365    };
366    if request_ids.is_empty() {
367        return;
368    }
369    let scope = if payload.request_id.is_some() {
370        "request"
371    } else {
372        "job"
373    };
374    crate::telemetry::record_cancellation(scope);
375
376    for request_id in request_ids {
377        let task = {
378            let mut in_flight = in_flight.lock().await;
379            if let Some(task) = in_flight.get(&request_id)
380                && task.completed.load(Ordering::SeqCst)
381            {
382                None
383            } else {
384                in_flight.remove(&request_id)
385            }
386        };
387        if let Some(task) = task {
388            task.handle.abort();
389            crate::telemetry::record_runner_inflight_delta(-1);
390            {
391                let mut active = task.connection_requests.lock().await;
392                active.remove(&request_id);
393                crate::telemetry::record_runner_channel_pressure(active.len());
394            }
395            let outcome = ExecutionOutcome {
396                job_id: Some(payload.job_id.clone()),
397                request_id: Some(request_id.clone()),
398                status: OutcomeStatus::Error,
399                result: None,
400                error: Some(ExecutionError {
401                    message: "Job cancelled".to_string(),
402                    error_type: Some("cancelled".to_string()),
403                    code: None,
404                    details: None,
405                }),
406                retry_after_seconds: None,
407            };
408            let send_result = timeout(RESPONSE_SEND_TIMEOUT, task.response_tx.send(outcome)).await;
409            match send_result {
410                Ok(Ok(())) => {}
411                Ok(Err(_)) => {
412                    tracing::warn!("runner response channel closed; dropping cancel outcome");
413                }
414                Err(_) => {
415                    tracing::warn!("runner response channel stalled; dropping cancel outcome");
416                }
417            }
418            let mut job_index = job_index.lock().await;
419            if let Some(entries) = job_index.get_mut(&task.job_id) {
420                entries.remove(&request_id);
421                if entries.is_empty() {
422                    job_index.remove(&task.job_id);
423                }
424            }
425        }
426    }
427}
428
429async fn execute_with_deadline<T: Telemetry + ?Sized>(
430    request: rrq_protocol::ExecutionRequest,
431    registry: Registry,
432    telemetry: &T,
433) -> ExecutionOutcome {
434    let job_id = request.job_id.clone();
435    let request_id = request.request_id.clone();
436    let deadline = request.context.deadline;
437    if let Some(deadline) = deadline {
438        let now: DateTime<Utc> = Utc::now();
439        if deadline <= now {
440            crate::telemetry::record_deadline_expired();
441            return ExecutionOutcome::timeout(
442                job_id.clone(),
443                request_id.clone(),
444                "Job deadline exceeded",
445            );
446        }
447        if let Ok(remaining) = (deadline - now).to_std() {
448            match tokio::time::timeout(remaining, registry.execute_with(request, telemetry)).await {
449                Ok(outcome) => return outcome,
450                Err(_) => {
451                    crate::telemetry::record_deadline_expired();
452                    return ExecutionOutcome::timeout(
453                        job_id.clone(),
454                        request_id.clone(),
455                        "Job execution timed out",
456                    );
457                }
458            }
459        }
460        crate::telemetry::record_deadline_expired();
461        return ExecutionOutcome::timeout(job_id, request_id, "Job deadline exceeded");
462    }
463    registry.execute_with(request, telemetry).await
464}
465
466async fn read_message<R: AsyncRead + Unpin>(
467    stream: &mut R,
468) -> Result<Option<RunnerMessage>, Box<dyn std::error::Error>> {
469    let mut header = [0u8; 4];
470    match stream.read_exact(&mut header).await {
471        Ok(_) => {}
472        Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
473        Err(err) => return Err(err.into()),
474    }
475    let length = u32::from_be_bytes(header) as usize;
476    if length == 0 {
477        return Err("runner message payload cannot be empty".into());
478    }
479    if length > MAX_FRAME_LEN {
480        return Err("runner message payload too large".into());
481    }
482    let mut payload = vec![0u8; length];
483    stream.read_exact(&mut payload).await?;
484    let message = serde_json::from_slice(&payload)?;
485    Ok(Some(message))
486}
487
488async fn write_message<W: AsyncWrite + Unpin>(
489    stream: &mut W,
490    message: &RunnerMessage,
491) -> Result<(), Box<dyn std::error::Error>> {
492    let framed = encode_frame(message)?;
493    stream.write_all(&framed).await?;
494    stream.flush().await?;
495    Ok(())
496}
497
498#[cfg(test)]
499mod tests {
500    use super::*;
501    use crate::registry::Registry;
502    use crate::telemetry::NoopTelemetry;
503    use chrono::Utc;
504    use rrq_protocol::{ExecutionContext, ExecutionRequest, OutcomeStatus};
505    use serde_json::json;
506    use tokio::net::{TcpListener, TcpStream};
507    use tokio::time::{Duration, timeout};
508
509    fn build_request(function_name: &str) -> ExecutionRequest {
510        ExecutionRequest {
511            protocol_version: PROTOCOL_VERSION.to_string(),
512            request_id: "req-1".to_string(),
513            job_id: "job-1".to_string(),
514            function_name: function_name.to_string(),
515            params: std::collections::HashMap::new(),
516            context: ExecutionContext {
517                job_id: "job-1".to_string(),
518                attempt: 1,
519                enqueue_time: "2024-01-01T00:00:00Z".parse().unwrap(),
520                queue_name: "default".to_string(),
521                deadline: None,
522                trace_context: None,
523                correlation_context: None,
524                worker_id: None,
525            },
526        }
527    }
528
529    async fn tcp_pair() -> (TcpStream, TcpStream) {
530        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
531        let addr = listener.local_addr().unwrap();
532        let client = TcpStream::connect(addr).await.unwrap();
533        let (server, _) = listener.accept().await.unwrap();
534        (client, server)
535    }
536
537    #[tokio::test]
538    async fn handle_connection_executes_request() {
539        let mut registry = Registry::new();
540        registry.register("echo", |request| async move {
541            ExecutionOutcome::success(
542                request.job_id.clone(),
543                request.request_id.clone(),
544                json!({"ok": true}),
545            )
546        });
547        let (client, server) = tcp_pair().await;
548        let in_flight = Arc::new(Mutex::new(HashMap::new()));
549        let job_index = Arc::new(Mutex::new(HashMap::new()));
550        let server_task = tokio::spawn(async move {
551            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
552                .await
553                .unwrap();
554        });
555        let mut client = client;
556        let request = build_request("echo");
557        let message = RunnerMessage::Request { payload: request };
558        write_message(&mut client, &message).await.unwrap();
559        let response = read_message(&mut client).await.unwrap().unwrap();
560        match response {
561            RunnerMessage::Response { payload } => {
562                assert_eq!(payload.status, OutcomeStatus::Success);
563                assert_eq!(payload.result, Some(json!({"ok": true})));
564            }
565            _ => panic!("expected response"),
566        }
567        drop(client);
568        let _ = server_task.await;
569    }
570
571    #[tokio::test]
572    async fn handle_connection_rejects_bad_protocol() {
573        let registry = Registry::new();
574        let (client, server) = tcp_pair().await;
575        let in_flight = Arc::new(Mutex::new(HashMap::new()));
576        let job_index = Arc::new(Mutex::new(HashMap::new()));
577        let server_task = tokio::spawn(async move {
578            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
579                .await
580                .unwrap();
581        });
582        let mut client = client;
583        let mut request = build_request("echo");
584        request.protocol_version = "0".to_string();
585        let message = RunnerMessage::Request { payload: request };
586        write_message(&mut client, &message).await.unwrap();
587        let response = read_message(&mut client).await.unwrap().unwrap();
588        match response {
589            RunnerMessage::Response { payload } => {
590                assert_eq!(payload.status, OutcomeStatus::Error);
591            }
592            _ => panic!("expected response"),
593        }
594        drop(client);
595        let _ = server_task.await;
596    }
597
598    #[tokio::test]
599    async fn handle_connection_cancels_inflight() {
600        let mut registry = Registry::new();
601        registry.register("sleep", |request| async move {
602            tokio::time::sleep(Duration::from_millis(200)).await;
603            ExecutionOutcome::success(
604                request.job_id.clone(),
605                request.request_id.clone(),
606                json!({"ok": true}),
607            )
608        });
609        let (client, server) = tcp_pair().await;
610        let in_flight = Arc::new(Mutex::new(HashMap::new()));
611        let job_index = Arc::new(Mutex::new(HashMap::new()));
612        let server_task = tokio::spawn(async move {
613            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
614                .await
615                .unwrap();
616        });
617        let mut client = client;
618        let request = ExecutionRequest {
619            protocol_version: PROTOCOL_VERSION.to_string(),
620            request_id: "req-cancel".to_string(),
621            job_id: "job-cancel".to_string(),
622            function_name: "sleep".to_string(),
623            params: std::collections::HashMap::new(),
624            context: ExecutionContext {
625                job_id: "job-cancel".to_string(),
626                attempt: 1,
627                enqueue_time: "2024-01-01T00:00:00Z".parse().unwrap(),
628                queue_name: "default".to_string(),
629                deadline: None,
630                trace_context: None,
631                correlation_context: None,
632                worker_id: None,
633            },
634        };
635        let message = RunnerMessage::Request {
636            payload: request.clone(),
637        };
638        write_message(&mut client, &message).await.unwrap();
639        let cancel = RunnerMessage::Cancel {
640            payload: CancelRequest {
641                protocol_version: PROTOCOL_VERSION.to_string(),
642                job_id: request.job_id.clone(),
643                request_id: Some(request.request_id.clone()),
644                hard_kill: false,
645            },
646        };
647        write_message(&mut client, &cancel).await.unwrap();
648        let response = read_message(&mut client).await.unwrap().unwrap();
649        match response {
650            RunnerMessage::Response { payload } => {
651                assert_eq!(payload.status, OutcomeStatus::Error);
652                let error_type = payload
653                    .error
654                    .as_ref()
655                    .and_then(|error| error.error_type.as_deref());
656                assert_eq!(error_type, Some("cancelled"));
657            }
658            _ => panic!("expected response"),
659        }
660        drop(client);
661        let _ = server_task.await;
662    }
663
664    #[tokio::test]
665    async fn cancel_frees_connection_capacity() {
666        let mut registry = Registry::new();
667        let gate = Arc::new(tokio::sync::Semaphore::new(0));
668        let gate_for_handler = gate.clone();
669        registry.register("block", move |request| {
670            let gate = gate_for_handler.clone();
671            async move {
672                let _permit = gate.acquire().await.expect("semaphore closed");
673                ExecutionOutcome::success(
674                    request.job_id.clone(),
675                    request.request_id.clone(),
676                    json!({"ok": true}),
677                )
678            }
679        });
680        let (client, server) = tcp_pair().await;
681        let in_flight = Arc::new(Mutex::new(HashMap::new()));
682        let job_index = Arc::new(Mutex::new(HashMap::new()));
683        let server_task = tokio::spawn(async move {
684            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
685                .await
686                .unwrap();
687        });
688        let mut client = client;
689        let job_id = "job-capacity".to_string();
690        for i in 0..RESPONSE_CHANNEL_CAPACITY {
691            let mut request = build_request("block");
692            request.request_id = format!("req-{i}");
693            request.job_id = job_id.clone();
694            request.context.job_id = job_id.clone();
695            write_message(&mut client, &RunnerMessage::Request { payload: request })
696                .await
697                .unwrap();
698        }
699
700        let cancel = RunnerMessage::Cancel {
701            payload: CancelRequest {
702                protocol_version: PROTOCOL_VERSION.to_string(),
703                job_id: job_id.clone(),
704                request_id: Some("req-0".to_string()),
705                hard_kill: false,
706            },
707        };
708        write_message(&mut client, &cancel).await.unwrap();
709        let response = timeout(Duration::from_secs(1), read_message(&mut client))
710            .await
711            .unwrap()
712            .unwrap()
713            .unwrap();
714        match response {
715            RunnerMessage::Response { payload } => {
716                assert_eq!(payload.status, OutcomeStatus::Error);
717                let error_type = payload
718                    .error
719                    .as_ref()
720                    .and_then(|error| error.error_type.as_deref());
721                assert_eq!(error_type, Some("cancelled"));
722            }
723            _ => panic!("expected response"),
724        }
725
726        let mut extra_request = build_request("block");
727        extra_request.request_id = "req-extra".to_string();
728        extra_request.job_id = job_id.clone();
729        extra_request.context.job_id = job_id.clone();
730        write_message(
731            &mut client,
732            &RunnerMessage::Request {
733                payload: extra_request,
734            },
735        )
736        .await
737        .unwrap();
738
739        gate.add_permits(RESPONSE_CHANNEL_CAPACITY + 1);
740
741        let mut saw_extra = false;
742        for _ in 0..RESPONSE_CHANNEL_CAPACITY {
743            let response = timeout(Duration::from_secs(1), read_message(&mut client))
744                .await
745                .unwrap()
746                .unwrap()
747                .unwrap();
748            if let RunnerMessage::Response { payload } = response
749                && payload.request_id.as_deref() == Some("req-extra")
750            {
751                assert_eq!(payload.status, OutcomeStatus::Success);
752                saw_extra = true;
753            }
754        }
755        assert!(saw_extra, "extra request never completed");
756
757        drop(client);
758        let _ = server_task.await;
759    }
760
761    #[tokio::test]
762    async fn execute_with_deadline_times_out() {
763        let mut registry = Registry::new();
764        registry.register("echo", |request| async move {
765            ExecutionOutcome::success(
766                request.job_id.clone(),
767                request.request_id.clone(),
768                json!({"ok": true}),
769            )
770        });
771        let mut request = build_request("echo");
772        request.context.deadline = Some(
773            "2020-01-01T00:00:00Z"
774                .parse::<chrono::DateTime<Utc>>()
775                .unwrap(),
776        );
777        let outcome = execute_with_deadline(request, registry, &NoopTelemetry).await;
778        assert_eq!(outcome.status, OutcomeStatus::Timeout);
779    }
780
781    #[tokio::test]
782    async fn execute_with_deadline_succeeds_before_deadline() {
783        let mut registry = Registry::new();
784        registry.register("echo", |request| async move {
785            ExecutionOutcome::success(
786                request.job_id.clone(),
787                request.request_id.clone(),
788                json!({"ok": true}),
789            )
790        });
791        let mut request = build_request("echo");
792        request.context.deadline = Some(Utc::now() + chrono::Duration::seconds(5));
793        let outcome = execute_with_deadline(request, registry, &NoopTelemetry).await;
794        assert_eq!(outcome.status, OutcomeStatus::Success);
795    }
796
797    #[tokio::test]
798    async fn handle_connection_handles_unexpected_response_message() {
799        let registry = Registry::new();
800        let (client, server) = tcp_pair().await;
801        let in_flight = Arc::new(Mutex::new(HashMap::new()));
802        let job_index = Arc::new(Mutex::new(HashMap::new()));
803        let server_task = tokio::spawn(async move {
804            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
805                .await
806                .unwrap();
807        });
808        let mut client = client;
809        let response = RunnerMessage::Response {
810            payload: ExecutionOutcome::error("job-x", "req-x", "oops"),
811        };
812        write_message(&mut client, &response).await.unwrap();
813        let reply = read_message(&mut client).await.unwrap().unwrap();
814        match reply {
815            RunnerMessage::Response { payload } => {
816                assert_eq!(payload.status, OutcomeStatus::Error);
817                assert!(
818                    payload
819                        .error
820                        .as_ref()
821                        .unwrap()
822                        .message
823                        .contains("unexpected response")
824                );
825            }
826            _ => panic!("expected response"),
827        }
828        drop(client);
829        let _ = server_task.await;
830    }
831
832    #[tokio::test]
833    async fn handle_connection_cancels_by_job_id() {
834        let mut registry = Registry::new();
835        registry.register("sleep", |request| async move {
836            tokio::time::sleep(Duration::from_millis(200)).await;
837            ExecutionOutcome::success(
838                request.job_id.clone(),
839                request.request_id.clone(),
840                json!({"ok": true}),
841            )
842        });
843        let (client, server) = tcp_pair().await;
844        let in_flight = Arc::new(Mutex::new(HashMap::new()));
845        let job_index = Arc::new(Mutex::new(HashMap::new()));
846        let server_task = tokio::spawn(async move {
847            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
848                .await
849                .unwrap();
850        });
851        let mut client = client;
852        let request = build_request("sleep");
853        let message = RunnerMessage::Request {
854            payload: request.clone(),
855        };
856        write_message(&mut client, &message).await.unwrap();
857        let cancel = RunnerMessage::Cancel {
858            payload: CancelRequest {
859                protocol_version: PROTOCOL_VERSION.to_string(),
860                job_id: request.job_id.clone(),
861                request_id: None,
862                hard_kill: false,
863            },
864        };
865        write_message(&mut client, &cancel).await.unwrap();
866        let response = read_message(&mut client).await.unwrap().unwrap();
867        match response {
868            RunnerMessage::Response { payload } => {
869                assert_eq!(payload.status, OutcomeStatus::Error);
870                let error_type = payload
871                    .error
872                    .as_ref()
873                    .and_then(|error| error.error_type.as_deref());
874                assert_eq!(error_type, Some("cancelled"));
875            }
876            _ => panic!("expected response"),
877        }
878        drop(client);
879        let _ = server_task.await;
880    }
881
882    #[tokio::test]
883    async fn handle_cancel_by_job_id_cancels_all_requests() {
884        let mut registry = Registry::new();
885        registry.register("sleep", |request| async move {
886            tokio::time::sleep(Duration::from_millis(200)).await;
887            ExecutionOutcome::success(
888                request.job_id.clone(),
889                request.request_id.clone(),
890                json!({"ok": true}),
891            )
892        });
893        let (client, server) = tcp_pair().await;
894        let in_flight = Arc::new(Mutex::new(HashMap::new()));
895        let job_index = Arc::new(Mutex::new(HashMap::new()));
896        let server_task = tokio::spawn(async move {
897            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
898                .await
899                .unwrap();
900        });
901        let mut client = client;
902        let mut request1 = build_request("sleep");
903        request1.request_id = "req-1".to_string();
904        request1.job_id = "job-shared".to_string();
905        let mut request2 = build_request("sleep");
906        request2.request_id = "req-2".to_string();
907        request2.job_id = "job-shared".to_string();
908        write_message(&mut client, &RunnerMessage::Request { payload: request1 })
909            .await
910            .unwrap();
911        write_message(&mut client, &RunnerMessage::Request { payload: request2 })
912            .await
913            .unwrap();
914        let cancel = RunnerMessage::Cancel {
915            payload: CancelRequest {
916                protocol_version: PROTOCOL_VERSION.to_string(),
917                job_id: "job-shared".to_string(),
918                request_id: None,
919                hard_kill: false,
920            },
921        };
922        write_message(&mut client, &cancel).await.unwrap();
923
924        let mut cancelled = 0;
925        for _ in 0..2 {
926            let response = timeout(Duration::from_millis(200), read_message(&mut client))
927                .await
928                .unwrap()
929                .unwrap()
930                .unwrap();
931            match response {
932                RunnerMessage::Response { payload } => {
933                    assert_eq!(payload.status, OutcomeStatus::Error);
934                    let error_type = payload
935                        .error
936                        .as_ref()
937                        .and_then(|error| error.error_type.as_deref());
938                    assert_eq!(error_type, Some("cancelled"));
939                    cancelled += 1;
940                }
941                _ => panic!("expected response"),
942            }
943        }
944        assert_eq!(cancelled, 2);
945        drop(client);
946        let _ = server_task.await;
947    }
948
949    #[tokio::test]
950    async fn connection_teardown_clears_tracking_maps() {
951        let mut registry = Registry::new();
952        registry.register("sleep", |request| async move {
953            tokio::time::sleep(Duration::from_millis(200)).await;
954            ExecutionOutcome::success(
955                request.job_id.clone(),
956                request.request_id.clone(),
957                json!({"ok": true}),
958            )
959        });
960        let (client, server) = tcp_pair().await;
961        let in_flight = Arc::new(Mutex::new(HashMap::new()));
962        let job_index = Arc::new(Mutex::new(HashMap::new()));
963        let in_flight_for_server = in_flight.clone();
964        let job_index_for_server = job_index.clone();
965        let server_task = tokio::spawn(async move {
966            handle_connection(
967                server,
968                &registry,
969                &NoopTelemetry,
970                in_flight_for_server,
971                job_index_for_server,
972            )
973            .await
974            .unwrap();
975        });
976        let mut client = client;
977        let request = build_request("sleep");
978        let message = RunnerMessage::Request {
979            payload: request.clone(),
980        };
981        write_message(&mut client, &message).await.unwrap();
982
983        let mut inserted = false;
984        for _ in 0..20 {
985            let has_in_flight = {
986                let guard = in_flight.lock().await;
987                guard.contains_key(&request.request_id)
988            };
989            let has_job_index = {
990                let guard = job_index.lock().await;
991                guard.contains_key(&request.job_id)
992            };
993            if has_in_flight && has_job_index {
994                inserted = true;
995                break;
996            }
997            tokio::time::sleep(Duration::from_millis(10)).await;
998        }
999        assert!(inserted, "request never entered tracking maps");
1000
1001        drop(client);
1002        let _ = server_task.await;
1003
1004        let in_flight = in_flight.lock().await;
1005        let job_index = job_index.lock().await;
1006        assert!(in_flight.is_empty());
1007        assert!(job_index.is_empty());
1008    }
1009
1010    #[tokio::test]
1011    async fn handle_cancel_ignores_invalid_protocol() {
1012        let in_flight = Arc::new(Mutex::new(HashMap::new()));
1013        let job_index = Arc::new(Mutex::new(HashMap::new()));
1014        let (tx, _rx) = mpsc::channel(1);
1015        let handle = tokio::spawn(async {});
1016        let connection_requests = Arc::new(Mutex::new(HashSet::new()));
1017        {
1018            let mut guard = in_flight.lock().await;
1019            guard.insert(
1020                "req-1".to_string(),
1021                InFlightTask {
1022                    job_id: "job-1".to_string(),
1023                    handle,
1024                    response_tx: tx,
1025                    connection_requests,
1026                    completed: Arc::new(AtomicBool::new(false)),
1027                },
1028            );
1029        }
1030        let payload = CancelRequest {
1031            protocol_version: "0".to_string(),
1032            job_id: "job-1".to_string(),
1033            request_id: None,
1034            hard_kill: false,
1035        };
1036        handle_cancel(payload, &in_flight, &job_index).await;
1037        let guard = in_flight.lock().await;
1038        assert!(guard.contains_key("req-1"));
1039        guard.get("req-1").unwrap().handle.abort();
1040    }
1041
1042    #[tokio::test]
1043    async fn handle_cancel_skips_completed_requests() {
1044        let in_flight = Arc::new(Mutex::new(HashMap::new()));
1045        let job_index = Arc::new(Mutex::new(HashMap::new()));
1046        let (tx, mut rx) = mpsc::channel(1);
1047        let handle = tokio::spawn(async {
1048            tokio::time::sleep(Duration::from_millis(50)).await;
1049        });
1050        let connection_requests = Arc::new(Mutex::new(HashSet::new()));
1051        {
1052            let mut guard = in_flight.lock().await;
1053            guard.insert(
1054                "req-1".to_string(),
1055                InFlightTask {
1056                    job_id: "job-1".to_string(),
1057                    handle,
1058                    response_tx: tx,
1059                    connection_requests,
1060                    completed: Arc::new(AtomicBool::new(true)),
1061                },
1062            );
1063        }
1064        {
1065            let mut guard = job_index.lock().await;
1066            guard.insert("job-1".to_string(), HashSet::from(["req-1".to_string()]));
1067        }
1068        let payload = CancelRequest {
1069            protocol_version: PROTOCOL_VERSION.to_string(),
1070            job_id: "job-1".to_string(),
1071            request_id: Some("req-1".to_string()),
1072            hard_kill: false,
1073        };
1074        handle_cancel(payload, &in_flight, &job_index).await;
1075        assert!(in_flight.lock().await.contains_key("req-1"));
1076        assert!(job_index.lock().await.contains_key("job-1"));
1077        assert!(rx.try_recv().is_err());
1078        let task = in_flight.lock().await.remove("req-1").unwrap();
1079        task.handle.abort();
1080    }
1081
1082    #[tokio::test]
1083    async fn read_message_handles_empty_and_invalid_payloads() {
1084        let (mut client, mut server) = tokio::io::duplex(64);
1085        // length = 0
1086        client.write_all(&0u32.to_be_bytes()).await.unwrap();
1087        let err = read_message(&mut server).await.unwrap_err();
1088        assert!(err.to_string().contains("payload cannot be empty"));
1089
1090        // invalid json
1091        let (mut client, mut server) = tokio::io::duplex(64);
1092        let payload = b"not-json";
1093        let len = (payload.len() as u32).to_be_bytes();
1094        client.write_all(&len).await.unwrap();
1095        client.write_all(payload).await.unwrap();
1096        let err = read_message(&mut server).await.unwrap_err();
1097        assert!(err.to_string().contains("expected"));
1098
1099        // oversized payload
1100        let (mut client, mut server) = tokio::io::duplex(64);
1101        let len = ((MAX_FRAME_LEN + 1) as u32).to_be_bytes();
1102        client.write_all(&len).await.unwrap();
1103        let err = read_message(&mut server).await.unwrap_err();
1104        assert!(err.to_string().contains("payload too large"));
1105    }
1106
1107    #[tokio::test]
1108    async fn read_message_returns_none_on_eof() {
1109        let (client, mut server) = tokio::io::duplex(8);
1110        drop(client);
1111        let message = read_message(&mut server).await.unwrap();
1112        assert!(message.is_none());
1113    }
1114}