Skip to main content

rrq_runner/
runtime.rs

1use crate::registry::Registry;
2use crate::telemetry::{NoopTelemetry, Telemetry};
3use crate::types::{ExecutionError, ExecutionOutcome};
4use chrono::{DateTime, Utc};
5use rrq_protocol::{CancelRequest, OutcomeStatus, PROTOCOL_VERSION, RunnerMessage, encode_frame};
6use std::collections::{HashMap, HashSet};
7use std::net::{IpAddr, Ipv4Addr, SocketAddr};
8use std::sync::{
9    Arc,
10    atomic::{AtomicBool, Ordering},
11};
12use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
13use tokio::net::TcpListener;
14use tokio::sync::{Mutex, mpsc};
15use tokio::time::{Duration, timeout};
/// Largest frame payload, in bytes, that `read_message` will accept (16 MiB).
const MAX_FRAME_LEN: usize = 16 << 20;
/// Bound of each connection's outgoing response channel; the same number is
/// also used as the per-connection limit on concurrently admitted requests.
const RESPONSE_CHANNEL_CAPACITY: usize = 64;
/// How long a sender waits for room in a response channel before giving up.
const RESPONSE_SEND_TIMEOUT: Duration = Duration::from_millis(1_000);
19
/// Wraps `message` in a boxed `std::io::Error` of kind `InvalidInput`.
fn invalid_input(message: impl Into<String>) -> Box<dyn std::error::Error> {
    Box::new(std::io::Error::new(
        std::io::ErrorKind::InvalidInput,
        message.into(),
    ))
}

/// Parses a runner `tcp_socket` setting into a loopback [`SocketAddr`].
///
/// Accepted forms are `host:port` and `[host]:port` (the bracketed form is
/// intended for IPv6 literals such as `[::1]:9000`). Surrounding whitespace is
/// ignored. `host` must be either `localhost`, matched ASCII case-insensitively
/// and mapped to `127.0.0.1`, or an IP literal for which
/// [`IpAddr::is_loopback`] is true.
///
/// # Errors
/// Returns an `InvalidInput` I/O error if the value is empty, is not in one of
/// the accepted forms, has an empty host (unbracketed form), a port that is not
/// a `u16` or is `0`, or a host that is neither `localhost` nor a loopback IP.
pub fn parse_tcp_socket(raw: &str) -> Result<SocketAddr, Box<dyn std::error::Error>> {
    let raw = raw.trim();
    if raw.is_empty() {
        return Err(invalid_input("runner tcp_socket cannot be empty"));
    }

    let (host, port_str) = if let Some(rest) = raw.strip_prefix('[') {
        rest.split_once("]:")
            .ok_or_else(|| invalid_input("runner tcp_socket must be in [host]:port format"))?
    } else {
        // Split on the last ':' so the port is always the final component.
        let (host, port_str) = raw
            .rsplit_once(':')
            .ok_or_else(|| invalid_input("runner tcp_socket must be in host:port format"))?;
        if host.is_empty() {
            return Err(invalid_input("runner tcp_socket host cannot be empty"));
        }
        (host, port_str)
    };

    let port: u16 = port_str
        .parse()
        .map_err(|_| invalid_input(format!("Invalid runner tcp_socket port: {port_str}")))?;
    if port == 0 {
        return Err(invalid_input("runner tcp_socket port must be > 0"));
    }

    // Host names are case-insensitive, so `LOCALHOST` / `LocalHost` are accepted too.
    let ip = if host.eq_ignore_ascii_case("localhost") {
        IpAddr::V4(Ipv4Addr::LOCALHOST)
    } else {
        let parsed: IpAddr = host
            .parse()
            .map_err(|_| invalid_input(format!("Invalid runner tcp_socket host: {host}")))?;
        if !parsed.is_loopback() {
            return Err(invalid_input("runner tcp_socket host must be localhost"));
        }
        parsed
    };

    Ok(SocketAddr::new(ip, port))
}
69
70pub struct RunnerRuntime {
71    runtime: tokio::runtime::Runtime,
72}
73
74impl RunnerRuntime {
75    pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
76        Ok(Self {
77            runtime: tokio::runtime::Runtime::new()?,
78        })
79    }
80
81    pub fn enter(&self) -> tokio::runtime::EnterGuard<'_> {
82        self.runtime.enter()
83    }
84
85    pub fn run_tcp(
86        &self,
87        registry: &Registry,
88        addr: SocketAddr,
89    ) -> Result<(), Box<dyn std::error::Error>> {
90        let telemetry = NoopTelemetry;
91        self.run_tcp_with(registry, addr, &telemetry)
92    }
93
94    pub fn run_tcp_with<T: Telemetry + ?Sized>(
95        &self,
96        registry: &Registry,
97        addr: SocketAddr,
98        telemetry: &T,
99    ) -> Result<(), Box<dyn std::error::Error>> {
100        run_tcp_loop(&self.runtime, registry, addr, telemetry)
101    }
102}
103
104pub fn run_tcp(registry: &Registry, addr: SocketAddr) -> Result<(), Box<dyn std::error::Error>> {
105    RunnerRuntime::new()?.run_tcp(registry, addr)
106}
107
108pub fn run_tcp_with<T: Telemetry + ?Sized>(
109    registry: &Registry,
110    addr: SocketAddr,
111    telemetry: &T,
112) -> Result<(), Box<dyn std::error::Error>> {
113    RunnerRuntime::new()?.run_tcp_with(registry, addr, telemetry)
114}
115
116fn run_tcp_loop<T: Telemetry + ?Sized>(
117    runtime: &tokio::runtime::Runtime,
118    registry: &Registry,
119    addr: SocketAddr,
120    telemetry: &T,
121) -> Result<(), Box<dyn std::error::Error>> {
122    let registry = registry.clone();
123    let in_flight: Arc<Mutex<HashMap<String, InFlightTask>>> = Arc::new(Mutex::new(HashMap::new()));
124    let job_index: Arc<Mutex<HashMap<String, HashSet<String>>>> =
125        Arc::new(Mutex::new(HashMap::new()));
126    let telemetry = telemetry.clone_box();
127    runtime.block_on(async move {
128        if !addr.ip().is_loopback() {
129            return Err(invalid_input(format!(
130                "runner tcp_socket must be loopback-only (got {addr})"
131            )));
132        }
133        let listener = TcpListener::bind(addr).await?;
134        loop {
135            let (stream, _) = listener.accept().await?;
136            let registry = registry.clone();
137            let telemetry = telemetry.clone();
138            let in_flight = in_flight.clone();
139            let job_index = job_index.clone();
140            tokio::spawn(async move {
141                if let Err(err) =
142                    handle_connection(stream, &registry, telemetry.as_ref(), in_flight, job_index)
143                        .await
144                {
145                    tracing::error!("runner connection error: {err}");
146                }
147            });
148        }
149    })
150}
151
152async fn handle_connection<S, T>(
153    stream: S,
154    registry: &Registry,
155    telemetry: &T,
156    in_flight: Arc<Mutex<HashMap<String, InFlightTask>>>,
157    job_index: Arc<Mutex<HashMap<String, HashSet<String>>>>,
158) -> Result<(), Box<dyn std::error::Error>>
159where
160    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
161    T: Telemetry + ?Sized,
162{
163    let (mut reader, mut writer) = tokio::io::split(stream);
164    let (response_tx, mut response_rx) =
165        mpsc::channel::<ExecutionOutcome>(RESPONSE_CHANNEL_CAPACITY);
166    let writer_task = tokio::spawn(async move {
167        while let Some(outcome) = response_rx.recv().await {
168            let response = RunnerMessage::Response { payload: outcome };
169            if write_message(&mut writer, &response).await.is_err() {
170                break;
171            }
172        }
173    });
174    let connection_requests: Arc<Mutex<std::collections::HashSet<String>>> =
175        Arc::new(Mutex::new(std::collections::HashSet::new()));
176
177    loop {
178        let message = match read_message(&mut reader).await? {
179            Some(message) => message,
180            None => break,
181        };
182        match message {
183            RunnerMessage::Request { payload } => {
184                if payload.protocol_version != PROTOCOL_VERSION {
185                    let outcome = ExecutionOutcome::error(
186                        payload.job_id.clone(),
187                        payload.request_id.clone(),
188                        "Unsupported protocol version",
189                    );
190                    let _ = response_tx.send(outcome).await;
191                    continue;
192                }
193
194                let request_id = payload.request_id.clone();
195                let job_id = payload.job_id.clone();
196                {
197                    let mut active = connection_requests.lock().await;
198                    if active.len() >= RESPONSE_CHANNEL_CAPACITY {
199                        let outcome = ExecutionOutcome::error(
200                            payload.job_id.clone(),
201                            payload.request_id.clone(),
202                            "Runner busy: too many in-flight requests",
203                        );
204                        drop(active);
205                        let send_result =
206                            timeout(RESPONSE_SEND_TIMEOUT, response_tx.send(outcome)).await;
207                        match send_result {
208                            Ok(Ok(())) => {}
209                            Ok(Err(_)) => {
210                                return Err("runner response channel closed".into());
211                            }
212                            Err(_) => {
213                                return Err("runner response channel stalled".into());
214                            }
215                        }
216                        continue;
217                    }
218                    active.insert(request_id.clone());
219                }
220                let response_tx = response_tx.clone();
221                let registry = registry.clone();
222                let telemetry = telemetry.clone_box();
223                let in_flight_for_task = in_flight.clone();
224                let job_index_for_task = job_index.clone();
225                let active_for_task = connection_requests.clone();
226                let request_id_for_task = request_id.clone();
227                let job_id_for_task = job_id.clone();
228                let response_tx_for_task = response_tx.clone();
229                let completed = Arc::new(AtomicBool::new(false));
230                let completed_for_task = completed.clone();
231
232                let handle = tokio::spawn(async move {
233                    let outcome =
234                        execute_with_deadline(payload, registry, telemetry.as_ref()).await;
235                    completed_for_task.store(true, Ordering::SeqCst);
236                    let send_result =
237                        timeout(RESPONSE_SEND_TIMEOUT, response_tx_for_task.send(outcome)).await;
238                    match send_result {
239                        Ok(Ok(())) => {}
240                        Ok(Err(_)) => {
241                            tracing::warn!("runner response channel closed; dropping outcome");
242                        }
243                        Err(_) => {
244                            tracing::warn!("runner response channel stalled; dropping outcome");
245                        }
246                    }
247                    {
248                        let mut in_flight = in_flight_for_task.lock().await;
249                        in_flight.remove(&request_id_for_task);
250                    }
251                    {
252                        let mut job_index = job_index_for_task.lock().await;
253                        if let Some(entries) = job_index.get_mut(&job_id_for_task) {
254                            entries.remove(&request_id_for_task);
255                            if entries.is_empty() {
256                                job_index.remove(&job_id_for_task);
257                            }
258                        }
259                    }
260                    {
261                        let mut active = active_for_task.lock().await;
262                        active.remove(&request_id_for_task);
263                    }
264                });
265
266                {
267                    let mut in_flight = in_flight.lock().await;
268                    in_flight.insert(
269                        request_id.clone(),
270                        InFlightTask {
271                            job_id: job_id.clone(),
272                            handle,
273                            response_tx: response_tx.clone(),
274                            connection_requests: connection_requests.clone(),
275                            completed,
276                        },
277                    );
278                }
279                {
280                    let mut job_index = job_index.lock().await;
281                    job_index
282                        .entry(job_id)
283                        .or_insert_with(HashSet::new)
284                        .insert(request_id);
285                }
286            }
287            RunnerMessage::Cancel { payload } => {
288                handle_cancel(payload, &in_flight, &job_index).await;
289            }
290            RunnerMessage::Response { .. } => {
291                let outcome = ExecutionOutcome {
292                    job_id: Some("unknown".to_string()),
293                    request_id: None,
294                    status: rrq_protocol::OutcomeStatus::Error,
295                    result: None,
296                    error: Some(ExecutionError {
297                        message: "unexpected response message".to_string(),
298                        error_type: None,
299                        code: None,
300                        details: None,
301                    }),
302                    retry_after_seconds: None,
303                };
304                let _ = response_tx.send(outcome).await;
305            }
306        }
307    }
308
309    let request_ids = {
310        let mut active = connection_requests.lock().await;
311        active.drain().collect::<Vec<_>>()
312    };
313    for request_id in request_ids {
314        let task = {
315            let mut in_flight = in_flight.lock().await;
316            in_flight.remove(&request_id)
317        };
318        if let Some(task) = task {
319            task.handle.abort();
320            let mut job_index = job_index.lock().await;
321            if let Some(entries) = job_index.get_mut(&task.job_id) {
322                entries.remove(&request_id);
323                if entries.is_empty() {
324                    job_index.remove(&task.job_id);
325                }
326            }
327        }
328    }
329    writer_task.abort();
330
331    Ok(())
332}
333
334struct InFlightTask {
335    job_id: String,
336    handle: tokio::task::JoinHandle<()>,
337    response_tx: mpsc::Sender<ExecutionOutcome>,
338    connection_requests: Arc<Mutex<HashSet<String>>>,
339    completed: Arc<AtomicBool>,
340}
341
342async fn handle_cancel(
343    payload: CancelRequest,
344    in_flight: &Arc<Mutex<HashMap<String, InFlightTask>>>,
345    job_index: &Arc<Mutex<HashMap<String, HashSet<String>>>>,
346) {
347    if payload.protocol_version != PROTOCOL_VERSION {
348        return;
349    }
350    let request_ids = if let Some(request_id) = payload.request_id.clone() {
351        vec![request_id]
352    } else {
353        let job_index = job_index.lock().await;
354        job_index
355            .get(&payload.job_id)
356            .map(|ids| ids.iter().cloned().collect())
357            .unwrap_or_else(Vec::new)
358    };
359    if request_ids.is_empty() {
360        return;
361    }
362
363    for request_id in request_ids {
364        let task = {
365            let mut in_flight = in_flight.lock().await;
366            if let Some(task) = in_flight.get(&request_id)
367                && task.completed.load(Ordering::SeqCst)
368            {
369                None
370            } else {
371                in_flight.remove(&request_id)
372            }
373        };
374        if let Some(task) = task {
375            task.handle.abort();
376            {
377                let mut active = task.connection_requests.lock().await;
378                active.remove(&request_id);
379            }
380            let outcome = ExecutionOutcome {
381                job_id: Some(payload.job_id.clone()),
382                request_id: Some(request_id.clone()),
383                status: OutcomeStatus::Error,
384                result: None,
385                error: Some(ExecutionError {
386                    message: "Job cancelled".to_string(),
387                    error_type: Some("cancelled".to_string()),
388                    code: None,
389                    details: None,
390                }),
391                retry_after_seconds: None,
392            };
393            let send_result = timeout(RESPONSE_SEND_TIMEOUT, task.response_tx.send(outcome)).await;
394            match send_result {
395                Ok(Ok(())) => {}
396                Ok(Err(_)) => {
397                    tracing::warn!("runner response channel closed; dropping cancel outcome");
398                }
399                Err(_) => {
400                    tracing::warn!("runner response channel stalled; dropping cancel outcome");
401                }
402            }
403            let mut job_index = job_index.lock().await;
404            if let Some(entries) = job_index.get_mut(&task.job_id) {
405                entries.remove(&request_id);
406                if entries.is_empty() {
407                    job_index.remove(&task.job_id);
408                }
409            }
410        }
411    }
412}
413
414async fn execute_with_deadline<T: Telemetry + ?Sized>(
415    request: rrq_protocol::ExecutionRequest,
416    registry: Registry,
417    telemetry: &T,
418) -> ExecutionOutcome {
419    let job_id = request.job_id.clone();
420    let request_id = request.request_id.clone();
421    let deadline = request.context.deadline;
422    if let Some(deadline) = deadline {
423        let now: DateTime<Utc> = Utc::now();
424        if deadline <= now {
425            return ExecutionOutcome::timeout(
426                job_id.clone(),
427                request_id.clone(),
428                "Job deadline exceeded",
429            );
430        }
431        if let Ok(remaining) = (deadline - now).to_std() {
432            match tokio::time::timeout(remaining, registry.execute_with(request, telemetry)).await {
433                Ok(outcome) => return outcome,
434                Err(_) => {
435                    return ExecutionOutcome::timeout(
436                        job_id.clone(),
437                        request_id.clone(),
438                        "Job execution timed out",
439                    );
440                }
441            }
442        }
443        return ExecutionOutcome::timeout(job_id, request_id, "Job deadline exceeded");
444    }
445    registry.execute_with(request, telemetry).await
446}
447
448async fn read_message<R: AsyncRead + Unpin>(
449    stream: &mut R,
450) -> Result<Option<RunnerMessage>, Box<dyn std::error::Error>> {
451    let mut header = [0u8; 4];
452    match stream.read_exact(&mut header).await {
453        Ok(_) => {}
454        Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
455        Err(err) => return Err(err.into()),
456    }
457    let length = u32::from_be_bytes(header) as usize;
458    if length == 0 {
459        return Err("runner message payload cannot be empty".into());
460    }
461    if length > MAX_FRAME_LEN {
462        return Err("runner message payload too large".into());
463    }
464    let mut payload = vec![0u8; length];
465    stream.read_exact(&mut payload).await?;
466    let message = serde_json::from_slice(&payload)?;
467    Ok(Some(message))
468}
469
470async fn write_message<W: AsyncWrite + Unpin>(
471    stream: &mut W,
472    message: &RunnerMessage,
473) -> Result<(), Box<dyn std::error::Error>> {
474    let framed = encode_frame(message)?;
475    stream.write_all(&framed).await?;
476    stream.flush().await?;
477    Ok(())
478}
479
480#[cfg(test)]
481mod tests {
482    use super::*;
483    use crate::registry::Registry;
484    use crate::telemetry::NoopTelemetry;
485    use chrono::Utc;
486    use rrq_protocol::{ExecutionContext, ExecutionRequest, OutcomeStatus};
487    use serde_json::json;
488    use tokio::net::{TcpListener, TcpStream};
489    use tokio::time::{Duration, timeout};
490
491    fn build_request(function_name: &str) -> ExecutionRequest {
492        ExecutionRequest {
493            protocol_version: PROTOCOL_VERSION.to_string(),
494            request_id: "req-1".to_string(),
495            job_id: "job-1".to_string(),
496            function_name: function_name.to_string(),
497            params: std::collections::HashMap::new(),
498            context: ExecutionContext {
499                job_id: "job-1".to_string(),
500                attempt: 1,
501                enqueue_time: "2024-01-01T00:00:00Z".parse().unwrap(),
502                queue_name: "default".to_string(),
503                deadline: None,
504                trace_context: None,
505                worker_id: None,
506            },
507        }
508    }
509
510    async fn tcp_pair() -> (TcpStream, TcpStream) {
511        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
512        let addr = listener.local_addr().unwrap();
513        let client = TcpStream::connect(addr).await.unwrap();
514        let (server, _) = listener.accept().await.unwrap();
515        (client, server)
516    }
517
518    #[tokio::test]
519    async fn handle_connection_executes_request() {
520        let mut registry = Registry::new();
521        registry.register("echo", |request| async move {
522            ExecutionOutcome::success(
523                request.job_id.clone(),
524                request.request_id.clone(),
525                json!({"ok": true}),
526            )
527        });
528        let (client, server) = tcp_pair().await;
529        let in_flight = Arc::new(Mutex::new(HashMap::new()));
530        let job_index = Arc::new(Mutex::new(HashMap::new()));
531        let server_task = tokio::spawn(async move {
532            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
533                .await
534                .unwrap();
535        });
536        let mut client = client;
537        let request = build_request("echo");
538        let message = RunnerMessage::Request { payload: request };
539        write_message(&mut client, &message).await.unwrap();
540        let response = read_message(&mut client).await.unwrap().unwrap();
541        match response {
542            RunnerMessage::Response { payload } => {
543                assert_eq!(payload.status, OutcomeStatus::Success);
544                assert_eq!(payload.result, Some(json!({"ok": true})));
545            }
546            _ => panic!("expected response"),
547        }
548        drop(client);
549        let _ = server_task.await;
550    }
551
552    #[tokio::test]
553    async fn handle_connection_rejects_bad_protocol() {
554        let registry = Registry::new();
555        let (client, server) = tcp_pair().await;
556        let in_flight = Arc::new(Mutex::new(HashMap::new()));
557        let job_index = Arc::new(Mutex::new(HashMap::new()));
558        let server_task = tokio::spawn(async move {
559            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
560                .await
561                .unwrap();
562        });
563        let mut client = client;
564        let mut request = build_request("echo");
565        request.protocol_version = "0".to_string();
566        let message = RunnerMessage::Request { payload: request };
567        write_message(&mut client, &message).await.unwrap();
568        let response = read_message(&mut client).await.unwrap().unwrap();
569        match response {
570            RunnerMessage::Response { payload } => {
571                assert_eq!(payload.status, OutcomeStatus::Error);
572            }
573            _ => panic!("expected response"),
574        }
575        drop(client);
576        let _ = server_task.await;
577    }
578
579    #[tokio::test]
580    async fn handle_connection_cancels_inflight() {
581        let mut registry = Registry::new();
582        registry.register("sleep", |request| async move {
583            tokio::time::sleep(Duration::from_millis(200)).await;
584            ExecutionOutcome::success(
585                request.job_id.clone(),
586                request.request_id.clone(),
587                json!({"ok": true}),
588            )
589        });
590        let (client, server) = tcp_pair().await;
591        let in_flight = Arc::new(Mutex::new(HashMap::new()));
592        let job_index = Arc::new(Mutex::new(HashMap::new()));
593        let server_task = tokio::spawn(async move {
594            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
595                .await
596                .unwrap();
597        });
598        let mut client = client;
599        let request = ExecutionRequest {
600            protocol_version: PROTOCOL_VERSION.to_string(),
601            request_id: "req-cancel".to_string(),
602            job_id: "job-cancel".to_string(),
603            function_name: "sleep".to_string(),
604            params: std::collections::HashMap::new(),
605            context: ExecutionContext {
606                job_id: "job-cancel".to_string(),
607                attempt: 1,
608                enqueue_time: "2024-01-01T00:00:00Z".parse().unwrap(),
609                queue_name: "default".to_string(),
610                deadline: None,
611                trace_context: None,
612                worker_id: None,
613            },
614        };
615        let message = RunnerMessage::Request {
616            payload: request.clone(),
617        };
618        write_message(&mut client, &message).await.unwrap();
619        let cancel = RunnerMessage::Cancel {
620            payload: CancelRequest {
621                protocol_version: PROTOCOL_VERSION.to_string(),
622                job_id: request.job_id.clone(),
623                request_id: Some(request.request_id.clone()),
624                hard_kill: false,
625            },
626        };
627        write_message(&mut client, &cancel).await.unwrap();
628        let response = read_message(&mut client).await.unwrap().unwrap();
629        match response {
630            RunnerMessage::Response { payload } => {
631                assert_eq!(payload.status, OutcomeStatus::Error);
632                let error_type = payload
633                    .error
634                    .as_ref()
635                    .and_then(|error| error.error_type.as_deref());
636                assert_eq!(error_type, Some("cancelled"));
637            }
638            _ => panic!("expected response"),
639        }
640        drop(client);
641        let _ = server_task.await;
642    }
643
644    #[tokio::test]
645    async fn cancel_frees_connection_capacity() {
646        let mut registry = Registry::new();
647        let gate = Arc::new(tokio::sync::Semaphore::new(0));
648        let gate_for_handler = gate.clone();
649        registry.register("block", move |request| {
650            let gate = gate_for_handler.clone();
651            async move {
652                let _permit = gate.acquire().await.expect("semaphore closed");
653                ExecutionOutcome::success(
654                    request.job_id.clone(),
655                    request.request_id.clone(),
656                    json!({"ok": true}),
657                )
658            }
659        });
660        let (client, server) = tcp_pair().await;
661        let in_flight = Arc::new(Mutex::new(HashMap::new()));
662        let job_index = Arc::new(Mutex::new(HashMap::new()));
663        let server_task = tokio::spawn(async move {
664            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
665                .await
666                .unwrap();
667        });
668        let mut client = client;
669        let job_id = "job-capacity".to_string();
670        for i in 0..RESPONSE_CHANNEL_CAPACITY {
671            let mut request = build_request("block");
672            request.request_id = format!("req-{i}");
673            request.job_id = job_id.clone();
674            request.context.job_id = job_id.clone();
675            write_message(&mut client, &RunnerMessage::Request { payload: request })
676                .await
677                .unwrap();
678        }
679
680        let cancel = RunnerMessage::Cancel {
681            payload: CancelRequest {
682                protocol_version: PROTOCOL_VERSION.to_string(),
683                job_id: job_id.clone(),
684                request_id: Some("req-0".to_string()),
685                hard_kill: false,
686            },
687        };
688        write_message(&mut client, &cancel).await.unwrap();
689        let response = timeout(Duration::from_secs(1), read_message(&mut client))
690            .await
691            .unwrap()
692            .unwrap()
693            .unwrap();
694        match response {
695            RunnerMessage::Response { payload } => {
696                assert_eq!(payload.status, OutcomeStatus::Error);
697                let error_type = payload
698                    .error
699                    .as_ref()
700                    .and_then(|error| error.error_type.as_deref());
701                assert_eq!(error_type, Some("cancelled"));
702            }
703            _ => panic!("expected response"),
704        }
705
706        let mut extra_request = build_request("block");
707        extra_request.request_id = "req-extra".to_string();
708        extra_request.job_id = job_id.clone();
709        extra_request.context.job_id = job_id.clone();
710        write_message(
711            &mut client,
712            &RunnerMessage::Request {
713                payload: extra_request,
714            },
715        )
716        .await
717        .unwrap();
718
719        gate.add_permits(RESPONSE_CHANNEL_CAPACITY + 1);
720
721        let mut saw_extra = false;
722        for _ in 0..RESPONSE_CHANNEL_CAPACITY {
723            let response = timeout(Duration::from_secs(1), read_message(&mut client))
724                .await
725                .unwrap()
726                .unwrap()
727                .unwrap();
728            if let RunnerMessage::Response { payload } = response
729                && payload.request_id.as_deref() == Some("req-extra")
730            {
731                assert_eq!(payload.status, OutcomeStatus::Success);
732                saw_extra = true;
733            }
734        }
735        assert!(saw_extra, "extra request never completed");
736
737        drop(client);
738        let _ = server_task.await;
739    }
740
741    #[tokio::test]
742    async fn execute_with_deadline_times_out() {
743        let mut registry = Registry::new();
744        registry.register("echo", |request| async move {
745            ExecutionOutcome::success(
746                request.job_id.clone(),
747                request.request_id.clone(),
748                json!({"ok": true}),
749            )
750        });
751        let mut request = build_request("echo");
752        request.context.deadline = Some(
753            "2020-01-01T00:00:00Z"
754                .parse::<chrono::DateTime<Utc>>()
755                .unwrap(),
756        );
757        let outcome = execute_with_deadline(request, registry, &NoopTelemetry).await;
758        assert_eq!(outcome.status, OutcomeStatus::Timeout);
759    }
760
761    #[tokio::test]
762    async fn execute_with_deadline_succeeds_before_deadline() {
763        let mut registry = Registry::new();
764        registry.register("echo", |request| async move {
765            ExecutionOutcome::success(
766                request.job_id.clone(),
767                request.request_id.clone(),
768                json!({"ok": true}),
769            )
770        });
771        let mut request = build_request("echo");
772        request.context.deadline = Some(Utc::now() + chrono::Duration::seconds(5));
773        let outcome = execute_with_deadline(request, registry, &NoopTelemetry).await;
774        assert_eq!(outcome.status, OutcomeStatus::Success);
775    }
776
777    #[tokio::test]
778    async fn handle_connection_handles_unexpected_response_message() {
779        let registry = Registry::new();
780        let (client, server) = tcp_pair().await;
781        let in_flight = Arc::new(Mutex::new(HashMap::new()));
782        let job_index = Arc::new(Mutex::new(HashMap::new()));
783        let server_task = tokio::spawn(async move {
784            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
785                .await
786                .unwrap();
787        });
788        let mut client = client;
789        let response = RunnerMessage::Response {
790            payload: ExecutionOutcome::error("job-x", "req-x", "oops"),
791        };
792        write_message(&mut client, &response).await.unwrap();
793        let reply = read_message(&mut client).await.unwrap().unwrap();
794        match reply {
795            RunnerMessage::Response { payload } => {
796                assert_eq!(payload.status, OutcomeStatus::Error);
797                assert!(
798                    payload
799                        .error
800                        .as_ref()
801                        .unwrap()
802                        .message
803                        .contains("unexpected response")
804                );
805            }
806            _ => panic!("expected response"),
807        }
808        drop(client);
809        let _ = server_task.await;
810    }
811
812    #[tokio::test]
813    async fn handle_connection_cancels_by_job_id() {
814        let mut registry = Registry::new();
815        registry.register("sleep", |request| async move {
816            tokio::time::sleep(Duration::from_millis(200)).await;
817            ExecutionOutcome::success(
818                request.job_id.clone(),
819                request.request_id.clone(),
820                json!({"ok": true}),
821            )
822        });
823        let (client, server) = tcp_pair().await;
824        let in_flight = Arc::new(Mutex::new(HashMap::new()));
825        let job_index = Arc::new(Mutex::new(HashMap::new()));
826        let server_task = tokio::spawn(async move {
827            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
828                .await
829                .unwrap();
830        });
831        let mut client = client;
832        let request = build_request("sleep");
833        let message = RunnerMessage::Request {
834            payload: request.clone(),
835        };
836        write_message(&mut client, &message).await.unwrap();
837        let cancel = RunnerMessage::Cancel {
838            payload: CancelRequest {
839                protocol_version: PROTOCOL_VERSION.to_string(),
840                job_id: request.job_id.clone(),
841                request_id: None,
842                hard_kill: false,
843            },
844        };
845        write_message(&mut client, &cancel).await.unwrap();
846        let response = read_message(&mut client).await.unwrap().unwrap();
847        match response {
848            RunnerMessage::Response { payload } => {
849                assert_eq!(payload.status, OutcomeStatus::Error);
850                let error_type = payload
851                    .error
852                    .as_ref()
853                    .and_then(|error| error.error_type.as_deref());
854                assert_eq!(error_type, Some("cancelled"));
855            }
856            _ => panic!("expected response"),
857        }
858        drop(client);
859        let _ = server_task.await;
860    }
861
862    #[tokio::test]
863    async fn handle_cancel_by_job_id_cancels_all_requests() {
864        let mut registry = Registry::new();
865        registry.register("sleep", |request| async move {
866            tokio::time::sleep(Duration::from_millis(200)).await;
867            ExecutionOutcome::success(
868                request.job_id.clone(),
869                request.request_id.clone(),
870                json!({"ok": true}),
871            )
872        });
873        let (client, server) = tcp_pair().await;
874        let in_flight = Arc::new(Mutex::new(HashMap::new()));
875        let job_index = Arc::new(Mutex::new(HashMap::new()));
876        let server_task = tokio::spawn(async move {
877            handle_connection(server, &registry, &NoopTelemetry, in_flight, job_index)
878                .await
879                .unwrap();
880        });
881        let mut client = client;
882        let mut request1 = build_request("sleep");
883        request1.request_id = "req-1".to_string();
884        request1.job_id = "job-shared".to_string();
885        let mut request2 = build_request("sleep");
886        request2.request_id = "req-2".to_string();
887        request2.job_id = "job-shared".to_string();
888        write_message(&mut client, &RunnerMessage::Request { payload: request1 })
889            .await
890            .unwrap();
891        write_message(&mut client, &RunnerMessage::Request { payload: request2 })
892            .await
893            .unwrap();
894        let cancel = RunnerMessage::Cancel {
895            payload: CancelRequest {
896                protocol_version: PROTOCOL_VERSION.to_string(),
897                job_id: "job-shared".to_string(),
898                request_id: None,
899                hard_kill: false,
900            },
901        };
902        write_message(&mut client, &cancel).await.unwrap();
903
904        let mut cancelled = 0;
905        for _ in 0..2 {
906            let response = timeout(Duration::from_millis(200), read_message(&mut client))
907                .await
908                .unwrap()
909                .unwrap()
910                .unwrap();
911            match response {
912                RunnerMessage::Response { payload } => {
913                    assert_eq!(payload.status, OutcomeStatus::Error);
914                    let error_type = payload
915                        .error
916                        .as_ref()
917                        .and_then(|error| error.error_type.as_deref());
918                    assert_eq!(error_type, Some("cancelled"));
919                    cancelled += 1;
920                }
921                _ => panic!("expected response"),
922            }
923        }
924        assert_eq!(cancelled, 2);
925        drop(client);
926        let _ = server_task.await;
927    }
928
929    #[tokio::test]
930    async fn connection_teardown_clears_tracking_maps() {
931        let mut registry = Registry::new();
932        registry.register("sleep", |request| async move {
933            tokio::time::sleep(Duration::from_millis(200)).await;
934            ExecutionOutcome::success(
935                request.job_id.clone(),
936                request.request_id.clone(),
937                json!({"ok": true}),
938            )
939        });
940        let (client, server) = tcp_pair().await;
941        let in_flight = Arc::new(Mutex::new(HashMap::new()));
942        let job_index = Arc::new(Mutex::new(HashMap::new()));
943        let in_flight_for_server = in_flight.clone();
944        let job_index_for_server = job_index.clone();
945        let server_task = tokio::spawn(async move {
946            handle_connection(
947                server,
948                &registry,
949                &NoopTelemetry,
950                in_flight_for_server,
951                job_index_for_server,
952            )
953            .await
954            .unwrap();
955        });
956        let mut client = client;
957        let request = build_request("sleep");
958        let message = RunnerMessage::Request {
959            payload: request.clone(),
960        };
961        write_message(&mut client, &message).await.unwrap();
962
963        let mut inserted = false;
964        for _ in 0..20 {
965            let has_in_flight = {
966                let guard = in_flight.lock().await;
967                guard.contains_key(&request.request_id)
968            };
969            let has_job_index = {
970                let guard = job_index.lock().await;
971                guard.contains_key(&request.job_id)
972            };
973            if has_in_flight && has_job_index {
974                inserted = true;
975                break;
976            }
977            tokio::time::sleep(Duration::from_millis(10)).await;
978        }
979        assert!(inserted, "request never entered tracking maps");
980
981        drop(client);
982        let _ = server_task.await;
983
984        let in_flight = in_flight.lock().await;
985        let job_index = job_index.lock().await;
986        assert!(in_flight.is_empty());
987        assert!(job_index.is_empty());
988    }
989
990    #[tokio::test]
991    async fn handle_cancel_ignores_invalid_protocol() {
992        let in_flight = Arc::new(Mutex::new(HashMap::new()));
993        let job_index = Arc::new(Mutex::new(HashMap::new()));
994        let (tx, _rx) = mpsc::channel(1);
995        let handle = tokio::spawn(async {});
996        let connection_requests = Arc::new(Mutex::new(HashSet::new()));
997        {
998            let mut guard = in_flight.lock().await;
999            guard.insert(
1000                "req-1".to_string(),
1001                InFlightTask {
1002                    job_id: "job-1".to_string(),
1003                    handle,
1004                    response_tx: tx,
1005                    connection_requests,
1006                    completed: Arc::new(AtomicBool::new(false)),
1007                },
1008            );
1009        }
1010        let payload = CancelRequest {
1011            protocol_version: "0".to_string(),
1012            job_id: "job-1".to_string(),
1013            request_id: None,
1014            hard_kill: false,
1015        };
1016        handle_cancel(payload, &in_flight, &job_index).await;
1017        let guard = in_flight.lock().await;
1018        assert!(guard.contains_key("req-1"));
1019        guard.get("req-1").unwrap().handle.abort();
1020    }
1021
1022    #[tokio::test]
1023    async fn handle_cancel_skips_completed_requests() {
1024        let in_flight = Arc::new(Mutex::new(HashMap::new()));
1025        let job_index = Arc::new(Mutex::new(HashMap::new()));
1026        let (tx, mut rx) = mpsc::channel(1);
1027        let handle = tokio::spawn(async {
1028            tokio::time::sleep(Duration::from_millis(50)).await;
1029        });
1030        let connection_requests = Arc::new(Mutex::new(HashSet::new()));
1031        {
1032            let mut guard = in_flight.lock().await;
1033            guard.insert(
1034                "req-1".to_string(),
1035                InFlightTask {
1036                    job_id: "job-1".to_string(),
1037                    handle,
1038                    response_tx: tx,
1039                    connection_requests,
1040                    completed: Arc::new(AtomicBool::new(true)),
1041                },
1042            );
1043        }
1044        {
1045            let mut guard = job_index.lock().await;
1046            guard.insert("job-1".to_string(), HashSet::from(["req-1".to_string()]));
1047        }
1048        let payload = CancelRequest {
1049            protocol_version: PROTOCOL_VERSION.to_string(),
1050            job_id: "job-1".to_string(),
1051            request_id: Some("req-1".to_string()),
1052            hard_kill: false,
1053        };
1054        handle_cancel(payload, &in_flight, &job_index).await;
1055        assert!(in_flight.lock().await.contains_key("req-1"));
1056        assert!(job_index.lock().await.contains_key("job-1"));
1057        assert!(rx.try_recv().is_err());
1058        let task = in_flight.lock().await.remove("req-1").unwrap();
1059        task.handle.abort();
1060    }
1061
1062    #[tokio::test]
1063    async fn read_message_handles_empty_and_invalid_payloads() {
1064        let (mut client, mut server) = tokio::io::duplex(64);
1065        // length = 0
1066        client.write_all(&0u32.to_be_bytes()).await.unwrap();
1067        let err = read_message(&mut server).await.unwrap_err();
1068        assert!(err.to_string().contains("payload cannot be empty"));
1069
1070        // invalid json
1071        let (mut client, mut server) = tokio::io::duplex(64);
1072        let payload = b"not-json";
1073        let len = (payload.len() as u32).to_be_bytes();
1074        client.write_all(&len).await.unwrap();
1075        client.write_all(payload).await.unwrap();
1076        let err = read_message(&mut server).await.unwrap_err();
1077        assert!(err.to_string().contains("expected"));
1078
1079        // oversized payload
1080        let (mut client, mut server) = tokio::io::duplex(64);
1081        let len = ((MAX_FRAME_LEN + 1) as u32).to_be_bytes();
1082        client.write_all(&len).await.unwrap();
1083        let err = read_message(&mut server).await.unwrap_err();
1084        assert!(err.to_string().contains("payload too large"));
1085    }
1086
1087    #[tokio::test]
1088    async fn read_message_returns_none_on_eof() {
1089        let (client, mut server) = tokio::io::duplex(8);
1090        drop(client);
1091        let message = read_message(&mut server).await.unwrap();
1092        assert!(message.is_none());
1093    }
1094}