Skip to main content

runmat_runtime/builtins/io/net/
accept.rs

1//! MATLAB-compatible `accept` builtin for RunMat networking.
2
3use once_cell::sync::OnceCell;
4use runmat_builtins::{IntValue, StructValue, Value};
5use runmat_macros::runtime_builtin;
6
7use super::tcpserver::{default_user_data, server_handle, TcpServerState, HANDLE_ID_FIELD};
8use crate::builtins::common::spec::{
9    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
10    ReductionNaN, ResidencyPolicy, ShapeRequirements,
11};
12use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
13use thiserror::Error;
14
15use runmat_time::Instant;
16use std::collections::HashMap;
17use std::io::{self, ErrorKind};
18use std::net::{Shutdown, SocketAddr, TcpListener, TcpStream};
19use std::sync::{Arc, Mutex};
20use std::time::Duration;
21
// Error identifiers attached to the `RuntimeError`s produced by `accept`.

/// The first argument is not a tcpserver struct, lacks the internal handle
/// field, or refers to a server that is no longer registered.
const MESSAGE_ID_INVALID_SERVER: &str = "RunMat:accept:InvalidTcpServer";
/// No client connected before the timeout elapsed.
const MESSAGE_ID_TIMEOUT: &str = "RunMat:accept:Timeout";
/// A name-value argument (or the effective timeout) failed validation.
const MESSAGE_ID_INVALID_NAME_VALUE: &str = "RunMat:accept:InvalidNameValue";
/// Internal failures: poisoned server lock or stream configuration errors.
const MESSAGE_ID_INTERNAL: &str = "RunMat:accept:InternalError";
/// The underlying `accept` call failed for a reason other than a timeout.
const MESSAGE_ID_ACCEPT_FAILED: &str = "RunMat:accept:AcceptFailed";

/// Name of the struct field that stores the client registry id inside the
/// value returned by `accept`.
pub(crate) const CLIENT_HANDLE_FIELD: &str = "__tcpclient_id";

/// Shared, lockable handle to the state of one accepted client.
type SharedTcpClient = Arc<Mutex<TcpClientState>>;
31
/// Per-connection state stored in the client registry for every client
/// accepted from a TCP server.
#[derive(Debug)]
#[allow(dead_code)]
pub(crate) struct TcpClientState {
    /// Registry identifier handed out by `insert_client`.
    pub(crate) id: u64,
    /// Identifier of the server this client was accepted from; used by
    /// `close_clients_for_server` to find the clients to close.
    pub(crate) server_id: u64,
    /// The accepted connection.
    pub(crate) stream: TcpStream,
    /// Remote address reported when the connection was accepted.
    pub(crate) peer_addr: SocketAddr,
    /// Timeout in seconds that was in effect when the client was accepted.
    pub(crate) timeout: f64,
    /// Byte order string supplied at insertion (`accept` copies it from the server).
    pub(crate) byte_order: String,
    /// `true` from insertion until `close_client_state` shuts the stream down.
    pub(crate) connected: bool,
    /// Starts empty; presumably holds partial line data for a `readline`
    /// builtin defined elsewhere — TODO confirm against its users.
    pub(crate) readline_buffer: Vec<u8>,
}
44
/// Process-wide table of accepted clients, keyed by client id.
#[derive(Default)]
struct TcpClientRegistry {
    /// Most recently issued client id; `insert_client` increments it (with
    /// wrapping) before each use, so the first id handed out is 1.
    next_id: u64,
    /// Live clients by id.
    clients: HashMap<u64, SharedTcpClient>,
}
50
/// Lazily initialised global client registry; access it via `client_registry()`.
static TCP_CLIENT_REGISTRY: OnceCell<Mutex<TcpClientRegistry>> = OnceCell::new();

/// Test-only lock handed out by `test_guard()` so tests can serialise
/// themselves around the shared networking state.
#[cfg(test)]
static TCP_CLIENT_TEST_GUARD: OnceCell<Mutex<()>> = OnceCell::new();
55
56fn client_registry() -> &'static Mutex<TcpClientRegistry> {
57    TCP_CLIENT_REGISTRY.get_or_init(|| Mutex::new(TcpClientRegistry::default()))
58}
59
/// Acquires the test-only serialisation lock, recovering it if a previous
/// test panicked while holding it.
#[cfg(test)]
pub(crate) fn test_guard() -> std::sync::MutexGuard<'static, ()> {
    let lock = TCP_CLIENT_TEST_GUARD.get_or_init(|| Mutex::new(()));
    match lock.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
67
68pub(crate) fn insert_client(
69    stream: TcpStream,
70    server_id: u64,
71    peer_addr: SocketAddr,
72    timeout: f64,
73    byte_order: String,
74) -> u64 {
75    let mut guard = client_registry()
76        .lock()
77        .unwrap_or_else(|poison| poison.into_inner());
78    guard.next_id = guard.next_id.wrapping_add(1);
79    let id = guard.next_id;
80    let state = TcpClientState {
81        id,
82        server_id,
83        stream,
84        peer_addr,
85        timeout,
86        byte_order,
87        connected: true,
88        readline_buffer: Vec::new(),
89    };
90    let shared = Arc::new(Mutex::new(state));
91    guard.clients.insert(id, shared);
92    id
93}
94
95#[allow(dead_code)]
96pub(crate) fn client_handle(id: u64) -> Option<SharedTcpClient> {
97    client_registry()
98        .lock()
99        .unwrap_or_else(|poison| poison.into_inner())
100        .clients
101        .get(&id)
102        .cloned()
103}
104
105pub(crate) fn close_client(id: u64) -> bool {
106    let entry = {
107        let mut guard = client_registry()
108            .lock()
109            .unwrap_or_else(|poison| poison.into_inner());
110        guard.clients.remove(&id)
111    };
112
113    if let Some(client) = entry {
114        close_client_state(&client);
115        true
116    } else {
117        false
118    }
119}
120
121pub(crate) fn close_clients_for_server(server_id: u64) -> usize {
122    let mut guard = client_registry()
123        .lock()
124        .unwrap_or_else(|poison| poison.into_inner());
125
126    let mut to_close: Vec<(u64, SharedTcpClient)> = Vec::new();
127    for (id, client) in guard.clients.iter() {
128        if let Ok(state) = client.lock() {
129            if state.server_id == server_id {
130                to_close.push((*id, client.clone()));
131            }
132        }
133    }
134
135    for (id, _) in &to_close {
136        guard.clients.remove(id);
137    }
138    drop(guard);
139
140    for (_, client) in &to_close {
141        close_client_state(client);
142    }
143
144    to_close.len()
145}
146
147pub(crate) fn close_all_clients() -> usize {
148    let entries = {
149        let mut guard = client_registry()
150            .lock()
151            .unwrap_or_else(|poison| poison.into_inner());
152        guard.clients.drain().collect::<Vec<_>>()
153    };
154
155    for (_, client) in &entries {
156        close_client_state(client);
157    }
158
159    entries.len()
160}
161
162fn close_client_state(client: &SharedTcpClient) {
163    if let Ok(mut state) = client.lock() {
164        if state.connected {
165            let _ = state.stream.shutdown(Shutdown::Both);
166            state.connected = false;
167        }
168    }
169}
170
/// Test helper: drops the registry entry for `id` without an explicit
/// shutdown of its stream.
#[cfg(test)]
pub(super) fn remove_client_for_test(id: u64) {
    let mut registry = match client_registry().lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    // The removed handle (if any) is dropped right here, under the lock.
    drop(registry.clients.remove(&id));
}
182
// GPU metadata for `accept`. The builtin is host-only: it declares no
// supported precisions or provider hooks, and uses the `GatherImmediately`
// residency policy so GPU inputs are gathered to the CPU first.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::io::net::accept")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "accept",
    op_kind: GpuOpKind::Custom("network"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Host-only networking builtin; GPU inputs are gathered to CPU before accepting clients.",
};
198
// Fusion metadata for `accept`. It declares no elementwise or reduction
// template; per the notes string the builtin is executed eagerly on the CPU.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::io::net::accept")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "accept",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Networking builtin executed eagerly on the CPU.",
};
209
210#[runtime_builtin(
211    name = "accept",
212    category = "io/net",
213    summary = "Accept a pending client connection on a TCP server.",
214    keywords = "accept,tcpserver,tcpclient",
215    type_resolver(crate::builtins::io::type_resolvers::accept_type),
216    builtin_path = "crate::builtins::io::net::accept"
217)]
218pub(crate) async fn accept_builtin(server: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
219    let server = gather_if_needed_async(&server).await?;
220    let server_id = extract_server_id(&server)?;
221
222    let options = parse_accept_options(rest).await?;
223
224    let shared_server = server_handle(server_id).ok_or_else(|| {
225        accept_error(
226            MESSAGE_ID_INVALID_SERVER,
227            "accept: tcpserver handle is no longer valid",
228        )
229    })?;
230
231    let server_guard = shared_server
232        .lock()
233        .map_err(|_| accept_error(MESSAGE_ID_INTERNAL, "accept: server lock poisoned"))?;
234
235    let timeout = options.timeout.unwrap_or(server_guard.timeout);
236    validate_timeout(timeout)?;
237
238    match accept_with_timeout(&server_guard.listener, timeout) {
239        Ok((stream, peer_addr)) => {
240            if let Err(err) = configure_stream(&stream, timeout) {
241                drop(server_guard);
242                return Err(accept_error(
243                    MESSAGE_ID_INTERNAL,
244                    format!("accept: failed to configure stream timeouts ({err})"),
245                ));
246            }
247            let byte_order = server_guard.byte_order.clone();
248            let client_id = insert_client(
249                stream,
250                server_guard.id,
251                peer_addr,
252                timeout,
253                byte_order.clone(),
254            );
255            let client_value =
256                build_tcpclient_value(client_id, &server_guard, peer_addr, timeout, byte_order);
257            drop(server_guard);
258            Ok(client_value)
259        }
260        Err(err) => {
261            drop(server_guard);
262            let message = match err.kind() {
263                ErrorKind::WouldBlock => accept_error(
264                    MESSAGE_ID_TIMEOUT,
265                    format!(
266                        "accept: timed out waiting for a client connection after {:.3} seconds",
267                        timeout
268                    ),
269                ),
270                _ => accept_error(
271                    MESSAGE_ID_ACCEPT_FAILED,
272                    format!("accept: failed to accept client ({err})"),
273                ),
274            };
275            Err(message)
276        }
277    }
278}
279
280fn extract_server_id(value: &Value) -> BuiltinResult<u64> {
281    match value {
282        Value::Struct(struct_value) => {
283            let id_value = struct_value.fields.get(HANDLE_ID_FIELD).ok_or_else(|| {
284                accept_error(
285                    MESSAGE_ID_INVALID_SERVER,
286                    "accept: tcpserver struct missing internal identifier",
287                )
288            })?;
289            let id = match id_value {
290                Value::Int(IntValue::U64(id)) => *id,
291                Value::Int(iv) => iv.to_i64() as u64,
292                other => {
293                    return Err(accept_error(
294                        MESSAGE_ID_INVALID_SERVER,
295                        format!("accept: expected numeric tcpserver identifier, got {other:?}"),
296                    ));
297                }
298            };
299            Ok(id)
300        }
301        _ => Err(accept_error(
302            MESSAGE_ID_INVALID_SERVER,
303            "accept: first argument must be the struct returned by tcpserver",
304        )),
305    }
306}
307
/// Name-value options recognised by `accept`.
#[derive(Default)]
struct AcceptOptions {
    /// Per-call `Timeout` in seconds; `None` means "use the server's timeout".
    timeout: Option<f64>,
}
312
313async fn parse_accept_options(rest: Vec<Value>) -> BuiltinResult<AcceptOptions> {
314    if rest.is_empty() {
315        return Ok(AcceptOptions::default());
316    }
317    if !rest.len().is_multiple_of(2) {
318        return Err(accept_error(
319            MESSAGE_ID_INVALID_NAME_VALUE,
320            "accept: name-value arguments must appear in pairs",
321        ));
322    }
323
324    let mut options = AcceptOptions::default();
325    let mut iter = rest.into_iter();
326    while let Some(name_raw) = iter.next() {
327        let value_raw = iter
328            .next()
329            .expect("paired iteration guarantees value exists");
330        let name_value = gather_if_needed_async(&name_raw).await?;
331        let name = match name_value {
332            Value::String(ref s) => s.clone(),
333            Value::CharArray(ref ca) if ca.rows == 1 => ca.data.iter().collect(),
334            Value::StringArray(ref sa) if sa.data.len() == 1 => sa.data[0].clone(),
335            other => {
336                return Err(accept_error(
337                    MESSAGE_ID_INVALID_NAME_VALUE,
338                    format!("accept: invalid option name ({other:?})"),
339                ));
340            }
341        };
342        let lower = name.to_ascii_lowercase();
343        match lower.as_str() {
344            "timeout" => {
345                let gathered = gather_if_needed_async(&value_raw).await?;
346                let timeout = parse_timeout_value(&gathered).map_err(|msg| {
347                    accept_error(
348                        MESSAGE_ID_INVALID_NAME_VALUE,
349                        format!("accept: invalid Timeout value: {msg}"),
350                    )
351                })?;
352                options.timeout = Some(timeout);
353            }
354            _ => {
355                return Err(accept_error(
356                    MESSAGE_ID_INVALID_NAME_VALUE,
357                    format!("accept: unsupported option '{name}'"),
358                ));
359            }
360        }
361    }
362    Ok(options)
363}
364
/// Reasons a value cannot be used as a timeout; produced by
/// `parse_timeout_value`.
#[derive(Debug, Error)]
pub(crate) enum TimeoutParseError {
    /// A tensor with a number of elements other than one was supplied.
    #[error("Timeout must be a scalar")]
    NonScalar,
    /// The value is not a number, integer, or tensor.
    #[error("Timeout must be numeric")]
    NonNumeric,
    /// The value is NaN (infinite values are accepted).
    #[error("Timeout must be finite or Inf")]
    NonFinite,
    /// The value is below zero.
    #[error("Timeout must be non-negative")]
    Negative,
}
376
377pub(crate) fn parse_timeout_value(value: &Value) -> Result<f64, TimeoutParseError> {
378    let timeout = match value {
379        Value::Num(n) => *n,
380        Value::Int(i) => i.to_f64(),
381        Value::Tensor(t) if t.data.len() == 1 => t.data[0],
382        Value::Tensor(_) => {
383            return Err(TimeoutParseError::NonScalar);
384        }
385        _ => return Err(TimeoutParseError::NonNumeric),
386    };
387    if !timeout.is_finite() && !timeout.is_infinite() {
388        return Err(TimeoutParseError::NonFinite);
389    }
390    if timeout.is_sign_negative() {
391        return Err(TimeoutParseError::Negative);
392    }
393    Ok(timeout)
394}
395
396fn validate_timeout(timeout: f64) -> BuiltinResult<()> {
397    if timeout.is_nan() {
398        return Err(accept_error(
399            MESSAGE_ID_INVALID_NAME_VALUE,
400            "accept: Timeout must not be NaN",
401        ));
402    }
403    if timeout.is_sign_negative() {
404        return Err(accept_error(
405            MESSAGE_ID_INVALID_NAME_VALUE,
406            "accept: Timeout must be non-negative",
407        ));
408    }
409    Ok(())
410}
411
/// Waits up to `timeout` seconds for a connection on `listener`.
///
/// An infinite timeout performs a plain blocking `accept`. Otherwise the
/// listener is switched to non-blocking mode and polled until a client
/// arrives or the deadline passes; blocking mode is restored (best-effort)
/// before returning.
///
/// The returned stream is always put into blocking mode: on some platforms an
/// accepted socket inherits the listener's non-blocking flag.
///
/// # Errors
/// * `ErrorKind::WouldBlock` — the deadline elapsed with no pending client.
/// * `ErrorKind::InvalidInput` — `timeout` is NaN or a negative finite value.
/// * Any other error reported by the operating system.
fn accept_with_timeout(
    listener: &TcpListener,
    timeout: f64,
) -> io::Result<(TcpStream, SocketAddr)> {
    /// Upper bound on the pause between two accept attempts.
    const POLL_INTERVAL: Duration = Duration::from_millis(10);

    if timeout.is_infinite() {
        return listener.accept();
    }
    if timeout.is_nan() || timeout < 0.0 {
        return Err(io::Error::new(
            ErrorKind::InvalidInput,
            "accept timeout must be a non-negative number",
        ));
    }
    // Finite values too large for `Duration` are treated as "practically
    // forever" instead of panicking inside `Duration::from_secs_f64`.
    let deadline = Duration::try_from_secs_f64(timeout).unwrap_or(Duration::MAX);

    listener.set_nonblocking(true)?;
    let start = Instant::now();
    let outcome = loop {
        match listener.accept() {
            Ok(accepted) => break Ok(accepted),
            Err(err) if err.kind() == ErrorKind::WouldBlock => {
                let elapsed = start.elapsed();
                if elapsed >= deadline {
                    break Err(io::Error::new(ErrorKind::WouldBlock, "accept timeout"));
                }
                // Never sleep past the deadline.
                let remaining = deadline.saturating_sub(elapsed);
                std::thread::sleep(remaining.min(POLL_INTERVAL));
            }
            Err(err) => break Err(err),
        }
    };
    // Best-effort restore; the accept outcome matters more than this result.
    let _ = listener.set_nonblocking(false);

    let (stream, addr) = outcome?;
    // Do not rely on whether the platform inherited the non-blocking flag.
    stream.set_nonblocking(false)?;
    Ok((stream, addr))
}
442
443pub(crate) fn configure_stream(stream: &TcpStream, timeout: f64) -> io::Result<()> {
444    let opt = if timeout.is_infinite() || timeout == 0.0 {
445        None
446    } else {
447        Some(Duration::from_secs_f64(timeout))
448    };
449    stream.set_read_timeout(opt)?;
450    stream.set_write_timeout(opt)?;
451    Ok(())
452}
453
454fn build_tcpclient_value(
455    client_id: u64,
456    server_state: &TcpServerState,
457    peer_addr: SocketAddr,
458    timeout: f64,
459    byte_order: String,
460) -> Value {
461    let mut st = StructValue::new();
462    st.fields
463        .insert("Type".to_string(), Value::String("tcpclient".to_string()));
464    st.fields.insert(
465        "Address".to_string(),
466        Value::String(peer_addr.ip().to_string()),
467    );
468    st.fields.insert(
469        "Port".to_string(),
470        Value::Int(IntValue::U16(peer_addr.port())),
471    );
472    st.fields.insert(
473        "ServerAddress".to_string(),
474        Value::String(server_state.local_addr.ip().to_string()),
475    );
476    st.fields.insert(
477        "ServerPort".to_string(),
478        Value::Int(IntValue::U16(server_state.local_addr.port())),
479    );
480    st.fields.insert("Connected".to_string(), Value::Bool(true));
481    st.fields
482        .insert("Status".to_string(), Value::String("connected".to_string()));
483    st.fields.insert(
484        "NumBytesAvailable".to_string(),
485        Value::Int(IntValue::I32(0)),
486    );
487    st.fields
488        .insert("BytesAvailableFcn".to_string(), default_user_data());
489    st.fields.insert(
490        "BytesAvailableFcnMode".to_string(),
491        Value::String("byte".to_string()),
492    );
493    st.fields.insert(
494        "BytesAvailableFcnCount".to_string(),
495        Value::Int(IntValue::I32(1)),
496    );
497    st.fields
498        .insert("ByteOrder".to_string(), Value::String(byte_order));
499    st.fields.insert(
500        "Timeout".to_string(),
501        Value::Num(if timeout.is_infinite() {
502            f64::INFINITY
503        } else {
504            timeout
505        }),
506    );
507    st.fields
508        .insert("UserData".to_string(), default_user_data());
509    st.fields.insert(
510        CLIENT_HANDLE_FIELD.to_string(),
511        Value::Int(IntValue::U64(client_id)),
512    );
513    st.fields.insert(
514        HANDLE_ID_FIELD.to_string(),
515        Value::Int(IntValue::U64(server_state.id)),
516    );
517    Value::Struct(st)
518}
519
520fn accept_error(message_id: &'static str, message: impl Into<String>) -> RuntimeError {
521    build_runtime_error(message)
522        .with_identifier(message_id)
523        .with_builtin("accept")
524        .build()
525}
526
/// Tests for the `accept` builtin. They create real loopback servers, so each
/// test takes the shared networking guard first.
#[cfg(test)]
pub(crate) mod tests {
    use super::super::tcpserver::{
        remove_server_for_test, tcpserver_builtin, HANDLE_ID_FIELD as SERVER_FIELD,
    };
    use super::*;
    use runmat_builtins::Value;
    use std::net::TcpStream;
    use std::thread;
    use std::time::Duration;

    /// Returns the named field of a struct value, panicking when absent.
    fn struct_field<'a>(value: &'a Value, name: &str) -> &'a Value {
        let Value::Struct(st) = value else {
            panic!("expected struct");
        };
        st.fields
            .get(name)
            .unwrap_or_else(|| panic!("missing field {name}"))
    }

    /// Reads the client registry id stored in an accepted client struct.
    fn client_id(value: &Value) -> u64 {
        match struct_field(value, CLIENT_HANDLE_FIELD) {
            Value::Int(IntValue::U64(id)) => *id,
            Value::Int(iv) => iv.to_i64() as u64,
            other => panic!("expected id int, got {other:?}"),
        }
    }

    /// Reads the server registry id stored in a tcpserver struct.
    fn server_id(value: &Value) -> u64 {
        match struct_field(value, SERVER_FIELD) {
            Value::Int(IntValue::U64(id)) => *id,
            Value::Int(iv) => iv.to_i64() as u64,
            other => panic!("expected server id int, got {other:?}"),
        }
    }

    fn assert_error_identifier(err: RuntimeError, expected: &str) {
        assert_eq!(err.identifier(), Some(expected));
    }

    fn run_accept(server: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        futures::executor::block_on(accept_builtin(server, rest))
    }

    fn run_tcpserver(address: Value, port: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        futures::executor::block_on(tcpserver_builtin(address, port, rest))
    }

    fn net_guard() -> std::sync::MutexGuard<'static, ()> {
        super::test_guard()
    }

    /// Starts a server on 127.0.0.1 with an OS-assigned port.
    fn loopback_server() -> Value {
        run_tcpserver(
            Value::from("127.0.0.1"),
            Value::Int(IntValue::I32(0)),
            Vec::new(),
        )
        .expect("tcpserver")
    }

    /// Extracts the port the server is listening on.
    fn server_port(server: &Value) -> u16 {
        match struct_field(server, "ServerPort") {
            Value::Int(iv) => iv.to_i64() as u16,
            other => panic!("expected ServerPort int, got {other:?}"),
        }
    }

    /// Spawns a thread that connects to the loopback port after `delay`.
    fn connect_after(port: u16, delay: Duration) -> thread::JoinHandle<TcpStream> {
        thread::spawn(move || {
            thread::sleep(delay);
            TcpStream::connect(("127.0.0.1", port)).expect("connect")
        })
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn accept_rejects_non_struct() {
        let _guard = net_guard();
        let err = run_accept(Value::Num(1.0), Vec::new()).unwrap_err();
        assert_error_identifier(err, MESSAGE_ID_INVALID_SERVER);
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn accept_establishes_client_connection() {
        let _guard = net_guard();
        let server_value = loopback_server();
        let port = server_port(&server_value);

        let handle = connect_after(port, Duration::from_millis(50));

        let client = run_accept(server_value.clone(), Vec::new()).expect("accept");
        let stream = handle.join().expect("client thread");
        drop(stream);

        match struct_field(&client, "Connected") {
            Value::Bool(flag) => assert!(*flag),
            other => panic!("expected Connected bool, got {other:?}"),
        }
        match struct_field(&client, "Address") {
            Value::String(addr) => assert_eq!(addr, "127.0.0.1"),
            other => panic!("expected Address string, got {other:?}"),
        }
        match struct_field(&client, "Timeout") {
            Value::Num(n) => assert_eq!(*n, 10.0),
            other => panic!("expected Timeout numeric, got {other:?}"),
        }

        remove_client_for_test(client_id(&client));
        remove_server_for_test(server_id(&server_value));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn accept_times_out_when_no_client_connects() {
        let _guard = net_guard();
        let server_value = loopback_server();
        let options = vec![Value::from("Timeout"), Value::Num(0.05)];
        let err = run_accept(server_value.clone(), options).unwrap_err();
        assert_error_identifier(err, MESSAGE_ID_TIMEOUT);
        remove_server_for_test(server_id(&server_value));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn accept_rejects_invalid_timeout_name_value() {
        let _guard = net_guard();
        let server_value = loopback_server();
        let options = vec![Value::from("Timeout"), Value::Num(-1.0)];
        let err = run_accept(server_value.clone(), options).unwrap_err();
        assert_error_identifier(err, MESSAGE_ID_INVALID_NAME_VALUE);
        remove_server_for_test(server_id(&server_value));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn accept_respects_per_call_timeout_override() {
        let _guard = net_guard();
        let server_value = loopback_server();
        let port = server_port(&server_value);

        let handle = connect_after(port, Duration::from_millis(20));

        let options = vec![Value::from("Timeout"), Value::Num(1.0)];
        let client = run_accept(server_value.clone(), options).expect("accept");
        handle.join().expect("join");

        let timeout_val = match struct_field(&client, "Timeout") {
            Value::Num(n) => *n,
            other => panic!("expected Timeout numeric, got {other:?}"),
        };
        assert_eq!(timeout_val, 1.0);

        remove_client_for_test(client_id(&client));
        remove_server_for_test(server_id(&server_value));
    }
}