1use once_cell::sync::OnceCell;
4use runmat_builtins::{IntValue, StructValue, Value};
5use runmat_macros::runtime_builtin;
6
7use super::tcpserver::{default_user_data, server_handle, TcpServerState, HANDLE_ID_FIELD};
8use crate::builtins::common::spec::{
9 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
10 ReductionNaN, ResidencyPolicy, ShapeRequirements,
11};
12use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
13use thiserror::Error;
14
15use runmat_time::Instant;
16use std::collections::HashMap;
17use std::io::{self, ErrorKind};
18use std::net::{Shutdown, SocketAddr, TcpListener, TcpStream};
19use std::sync::{Arc, Mutex};
20use std::time::Duration;
21
// Error identifiers attached (via `accept_error`) to every `RuntimeError` raised by `accept`.
// The handle passed to `accept` is missing, malformed, or no longer registered.
const MESSAGE_ID_INVALID_SERVER: &str = "RunMat:accept:InvalidTcpServer";
// No client connected before the effective timeout elapsed.
const MESSAGE_ID_TIMEOUT: &str = "RunMat:accept:Timeout";
// Bad name-value arguments, including an invalid `Timeout` value.
const MESSAGE_ID_INVALID_NAME_VALUE: &str = "RunMat:accept:InvalidNameValue";
// Unexpected internal failures (poisoned server lock, stream configuration errors).
const MESSAGE_ID_INTERNAL: &str = "RunMat:accept:InternalError";
// The OS-level `accept` call itself failed for a reason other than timing out.
const MESSAGE_ID_ACCEPT_FAILED: &str = "RunMat:accept:AcceptFailed";

/// Name of the struct field that stores the registry id of an accepted client
/// (the id returned by `insert_client`).
pub(crate) const CLIENT_HANDLE_FIELD: &str = "__tcpclient_id";

/// Shared, lockable handle to one accepted client's state, as stored in the registry.
type SharedTcpClient = Arc<Mutex<TcpClientState>>;
31
/// Per-connection state for a client accepted by `accept`.
#[derive(Debug)]
// NOTE(review): presumably some fields are only read by other net builtins (or not
// read anywhere yet); confirm before removing this allow.
#[allow(dead_code)]
pub(crate) struct TcpClientState {
    /// Registry id; the same value is exposed in the client struct under `CLIENT_HANDLE_FIELD`.
    pub(crate) id: u64,
    /// Id of the tcpserver whose listener produced this connection.
    pub(crate) server_id: u64,
    /// The accepted socket.
    pub(crate) stream: TcpStream,
    /// Remote (client-side) address of the connection.
    pub(crate) peer_addr: SocketAddr,
    /// Timeout in seconds used for this connection (may be `Inf`).
    pub(crate) timeout: f64,
    /// Byte order copied from the server at accept time.
    pub(crate) byte_order: String,
    /// Set to `false` once `close_client_state` has shut the socket down.
    pub(crate) connected: bool,
    /// Starts empty. NOTE(review): presumably holds partially read line data for a
    /// line-reading builtin — confirm against its users.
    pub(crate) readline_buffer: Vec<u8>,
}
44
/// Process-wide table of accepted clients, keyed by their registry id.
#[derive(Default)]
struct TcpClientRegistry {
    /// Last id handed out; `insert_client` increments it (wrapping) before use, so ids start at 1.
    next_id: u64,
    /// Live clients by id; entries are removed by the `close_*` helpers.
    clients: HashMap<u64, SharedTcpClient>,
}
50
/// Lazily initialised global client registry; always accessed through `client_registry()`.
static TCP_CLIENT_REGISTRY: OnceCell<Mutex<TcpClientRegistry>> = OnceCell::new();

/// Test-only lock handed out by `test_guard()` so tests touching the global
/// registries can run one at a time.
#[cfg(test)]
static TCP_CLIENT_TEST_GUARD: OnceCell<Mutex<()>> = OnceCell::new();
55
56fn client_registry() -> &'static Mutex<TcpClientRegistry> {
57 TCP_CLIENT_REGISTRY.get_or_init(|| Mutex::new(TcpClientRegistry::default()))
58}
59
/// Acquires the shared test lock, ignoring poisoning left behind by a failed test.
#[cfg(test)]
pub(crate) fn test_guard() -> std::sync::MutexGuard<'static, ()> {
    let lock = TCP_CLIENT_TEST_GUARD.get_or_init(Mutex::default);
    match lock.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
67
68pub(crate) fn insert_client(
69 stream: TcpStream,
70 server_id: u64,
71 peer_addr: SocketAddr,
72 timeout: f64,
73 byte_order: String,
74) -> u64 {
75 let mut guard = client_registry()
76 .lock()
77 .unwrap_or_else(|poison| poison.into_inner());
78 guard.next_id = guard.next_id.wrapping_add(1);
79 let id = guard.next_id;
80 let state = TcpClientState {
81 id,
82 server_id,
83 stream,
84 peer_addr,
85 timeout,
86 byte_order,
87 connected: true,
88 readline_buffer: Vec::new(),
89 };
90 let shared = Arc::new(Mutex::new(state));
91 guard.clients.insert(id, shared);
92 id
93}
94
95#[allow(dead_code)]
96pub(crate) fn client_handle(id: u64) -> Option<SharedTcpClient> {
97 client_registry()
98 .lock()
99 .unwrap_or_else(|poison| poison.into_inner())
100 .clients
101 .get(&id)
102 .cloned()
103}
104
105pub(crate) fn close_client(id: u64) -> bool {
106 let entry = {
107 let mut guard = client_registry()
108 .lock()
109 .unwrap_or_else(|poison| poison.into_inner());
110 guard.clients.remove(&id)
111 };
112
113 if let Some(client) = entry {
114 close_client_state(&client);
115 true
116 } else {
117 false
118 }
119}
120
121pub(crate) fn close_clients_for_server(server_id: u64) -> usize {
122 let mut guard = client_registry()
123 .lock()
124 .unwrap_or_else(|poison| poison.into_inner());
125
126 let mut to_close: Vec<(u64, SharedTcpClient)> = Vec::new();
127 for (id, client) in guard.clients.iter() {
128 if let Ok(state) = client.lock() {
129 if state.server_id == server_id {
130 to_close.push((*id, client.clone()));
131 }
132 }
133 }
134
135 for (id, _) in &to_close {
136 guard.clients.remove(id);
137 }
138 drop(guard);
139
140 for (_, client) in &to_close {
141 close_client_state(client);
142 }
143
144 to_close.len()
145}
146
147pub(crate) fn close_all_clients() -> usize {
148 let entries = {
149 let mut guard = client_registry()
150 .lock()
151 .unwrap_or_else(|poison| poison.into_inner());
152 guard.clients.drain().collect::<Vec<_>>()
153 };
154
155 for (_, client) in &entries {
156 close_client_state(client);
157 }
158
159 entries.len()
160}
161
162fn close_client_state(client: &SharedTcpClient) {
163 if let Ok(mut state) = client.lock() {
164 if state.connected {
165 let _ = state.stream.shutdown(Shutdown::Both);
166 state.connected = false;
167 }
168 }
169}
170
/// Test helper: unregisters client `id` without shutting its socket down explicitly.
#[cfg(test)]
pub(super) fn remove_client_for_test(id: u64) {
    let removed = client_registry()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .clients
        .remove(&id);
    drop(removed);
}
182
// GPU metadata for `accept`: there is no GPU kernel (no precisions, no provider
// hooks), and any device-resident input is gathered to the host immediately.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::io::net::accept")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "accept",
    op_kind: GpuOpKind::Custom("network"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Host-only networking builtin; GPU inputs are gathered to CPU before accepting clients.",
};
198
// Fusion metadata for `accept`: no elementwise or reduction template, so the
// builtin is never fused and always runs eagerly on the CPU.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::io::net::accept")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "accept",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Networking builtin executed eagerly on the CPU.",
};
209
210#[runtime_builtin(
211 name = "accept",
212 category = "io/net",
213 summary = "Accept a pending client connection on a TCP server.",
214 keywords = "accept,tcpserver,tcpclient",
215 type_resolver(crate::builtins::io::type_resolvers::accept_type),
216 builtin_path = "crate::builtins::io::net::accept"
217)]
218pub(crate) async fn accept_builtin(server: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
219 let server = gather_if_needed_async(&server).await?;
220 let server_id = extract_server_id(&server)?;
221
222 let options = parse_accept_options(rest).await?;
223
224 let shared_server = server_handle(server_id).ok_or_else(|| {
225 accept_error(
226 MESSAGE_ID_INVALID_SERVER,
227 "accept: tcpserver handle is no longer valid",
228 )
229 })?;
230
231 let server_guard = shared_server
232 .lock()
233 .map_err(|_| accept_error(MESSAGE_ID_INTERNAL, "accept: server lock poisoned"))?;
234
235 let timeout = options.timeout.unwrap_or(server_guard.timeout);
236 validate_timeout(timeout)?;
237
238 match accept_with_timeout(&server_guard.listener, timeout) {
239 Ok((stream, peer_addr)) => {
240 if let Err(err) = configure_stream(&stream, timeout) {
241 drop(server_guard);
242 return Err(accept_error(
243 MESSAGE_ID_INTERNAL,
244 format!("accept: failed to configure stream timeouts ({err})"),
245 ));
246 }
247 let byte_order = server_guard.byte_order.clone();
248 let client_id = insert_client(
249 stream,
250 server_guard.id,
251 peer_addr,
252 timeout,
253 byte_order.clone(),
254 );
255 let client_value =
256 build_tcpclient_value(client_id, &server_guard, peer_addr, timeout, byte_order);
257 drop(server_guard);
258 Ok(client_value)
259 }
260 Err(err) => {
261 drop(server_guard);
262 let message = match err.kind() {
263 ErrorKind::WouldBlock => accept_error(
264 MESSAGE_ID_TIMEOUT,
265 format!(
266 "accept: timed out waiting for a client connection after {:.3} seconds",
267 timeout
268 ),
269 ),
270 _ => accept_error(
271 MESSAGE_ID_ACCEPT_FAILED,
272 format!("accept: failed to accept client ({err})"),
273 ),
274 };
275 Err(message)
276 }
277 }
278}
279
280fn extract_server_id(value: &Value) -> BuiltinResult<u64> {
281 match value {
282 Value::Struct(struct_value) => {
283 let id_value = struct_value.fields.get(HANDLE_ID_FIELD).ok_or_else(|| {
284 accept_error(
285 MESSAGE_ID_INVALID_SERVER,
286 "accept: tcpserver struct missing internal identifier",
287 )
288 })?;
289 let id = match id_value {
290 Value::Int(IntValue::U64(id)) => *id,
291 Value::Int(iv) => iv.to_i64() as u64,
292 other => {
293 return Err(accept_error(
294 MESSAGE_ID_INVALID_SERVER,
295 format!("accept: expected numeric tcpserver identifier, got {other:?}"),
296 ));
297 }
298 };
299 Ok(id)
300 }
301 _ => Err(accept_error(
302 MESSAGE_ID_INVALID_SERVER,
303 "accept: first argument must be the struct returned by tcpserver",
304 )),
305 }
306}
307
/// Parsed name-value options for `accept`.
#[derive(Default)]
struct AcceptOptions {
    /// Per-call `Timeout` in seconds; `None` means "use the server's timeout".
    timeout: Option<f64>,
}
312
313async fn parse_accept_options(rest: Vec<Value>) -> BuiltinResult<AcceptOptions> {
314 if rest.is_empty() {
315 return Ok(AcceptOptions::default());
316 }
317 if !rest.len().is_multiple_of(2) {
318 return Err(accept_error(
319 MESSAGE_ID_INVALID_NAME_VALUE,
320 "accept: name-value arguments must appear in pairs",
321 ));
322 }
323
324 let mut options = AcceptOptions::default();
325 let mut iter = rest.into_iter();
326 while let Some(name_raw) = iter.next() {
327 let value_raw = iter
328 .next()
329 .expect("paired iteration guarantees value exists");
330 let name_value = gather_if_needed_async(&name_raw).await?;
331 let name = match name_value {
332 Value::String(ref s) => s.clone(),
333 Value::CharArray(ref ca) if ca.rows == 1 => ca.data.iter().collect(),
334 Value::StringArray(ref sa) if sa.data.len() == 1 => sa.data[0].clone(),
335 other => {
336 return Err(accept_error(
337 MESSAGE_ID_INVALID_NAME_VALUE,
338 format!("accept: invalid option name ({other:?})"),
339 ));
340 }
341 };
342 let lower = name.to_ascii_lowercase();
343 match lower.as_str() {
344 "timeout" => {
345 let gathered = gather_if_needed_async(&value_raw).await?;
346 let timeout = parse_timeout_value(&gathered).map_err(|msg| {
347 accept_error(
348 MESSAGE_ID_INVALID_NAME_VALUE,
349 format!("accept: invalid Timeout value: {msg}"),
350 )
351 })?;
352 options.timeout = Some(timeout);
353 }
354 _ => {
355 return Err(accept_error(
356 MESSAGE_ID_INVALID_NAME_VALUE,
357 format!("accept: unsupported option '{name}'"),
358 ));
359 }
360 }
361 }
362 Ok(options)
363}
364
/// Reasons `parse_timeout_value` can reject a `Timeout` argument; the `Display`
/// text of each variant is embedded in user-facing error messages.
#[derive(Debug, Error)]
pub(crate) enum TimeoutParseError {
    /// A tensor with other than exactly one element.
    #[error("Timeout must be a scalar")]
    NonScalar,
    /// A value that is not a number, integer, or tensor.
    #[error("Timeout must be numeric")]
    NonNumeric,
    /// Returned for NaN (infinite values are accepted).
    #[error("Timeout must be finite or Inf")]
    NonFinite,
    /// Returned for values with the sign bit set.
    #[error("Timeout must be non-negative")]
    Negative,
}
376
377pub(crate) fn parse_timeout_value(value: &Value) -> Result<f64, TimeoutParseError> {
378 let timeout = match value {
379 Value::Num(n) => *n,
380 Value::Int(i) => i.to_f64(),
381 Value::Tensor(t) if t.data.len() == 1 => t.data[0],
382 Value::Tensor(_) => {
383 return Err(TimeoutParseError::NonScalar);
384 }
385 _ => return Err(TimeoutParseError::NonNumeric),
386 };
387 if !timeout.is_finite() && !timeout.is_infinite() {
388 return Err(TimeoutParseError::NonFinite);
389 }
390 if timeout.is_sign_negative() {
391 return Err(TimeoutParseError::Negative);
392 }
393 Ok(timeout)
394}
395
396fn validate_timeout(timeout: f64) -> BuiltinResult<()> {
397 if timeout.is_nan() {
398 return Err(accept_error(
399 MESSAGE_ID_INVALID_NAME_VALUE,
400 "accept: Timeout must not be NaN",
401 ));
402 }
403 if timeout.is_sign_negative() {
404 return Err(accept_error(
405 MESSAGE_ID_INVALID_NAME_VALUE,
406 "accept: Timeout must be non-negative",
407 ));
408 }
409 Ok(())
410}
411
/// Accepts one connection from `listener`, waiting at most `timeout` seconds.
///
/// * `Inf`, or a finite value too large for [`Duration`], blocks until a client connects.
/// * Any other non-negative value polls the listener in non-blocking mode every
///   10 ms (less near the deadline) until a client arrives or the deadline passes.
///   A timeout of `0` makes a single non-blocking attempt.
///
/// Afterwards the listener is put back into blocking mode (best effort). The
/// accepted stream is always switched to blocking mode. On BSD-derived systems
/// (including macOS) and on Windows, an accepted socket inherits the listener's
/// non-blocking flag, which would otherwise defeat the read/write timeouts
/// configured later.
///
/// # Errors
///
/// * `ErrorKind::WouldBlock` when the deadline passes without a connection.
/// * `ErrorKind::InvalidInput` when `timeout` is NaN or negative.
/// * Any other I/O error reported by `accept` or `set_nonblocking`.
fn accept_with_timeout(
    listener: &TcpListener,
    timeout: f64,
) -> io::Result<(TcpStream, SocketAddr)> {
    const POLL_INTERVAL: Duration = Duration::from_millis(10);

    // `None` means "wait forever".
    let deadline = if timeout.is_infinite() {
        None
    } else if timeout == 0.0 {
        Some(Duration::ZERO)
    } else {
        match Duration::try_from_secs_f64(timeout) {
            Ok(deadline) => Some(deadline),
            // Finite but larger than `Duration` can represent: effectively forever.
            // (`Duration::from_secs_f64` would panic here.)
            Err(_) if timeout > 0.0 => None,
            Err(_) => {
                return Err(io::Error::new(
                    ErrorKind::InvalidInput,
                    "accept timeout must be a non-negative number",
                ));
            }
        }
    };

    let (stream, addr) = match deadline {
        None => {
            // A previous best-effort restore may have failed; make sure we really block.
            listener.set_nonblocking(false)?;
            listener.accept()?
        }
        Some(deadline) => {
            listener.set_nonblocking(true)?;
            let started = Instant::now();
            let outcome = loop {
                match listener.accept() {
                    Ok(pair) => break Ok(pair),
                    Err(err) if err.kind() == ErrorKind::WouldBlock => {
                        let elapsed = started.elapsed();
                        if elapsed >= deadline {
                            break Err(io::Error::new(ErrorKind::WouldBlock, "accept timeout"));
                        }
                        // Never sleep past the deadline.
                        std::thread::sleep(POLL_INTERVAL.min(deadline - elapsed));
                    }
                    // A signal interrupted the call; simply retry.
                    Err(err) if err.kind() == ErrorKind::Interrupted => {}
                    Err(err) => break Err(err),
                }
            };
            // Best effort: if this fails, the unbounded branch above re-checks it on next use.
            let _ = listener.set_nonblocking(false);
            outcome?
        }
    };

    // Accepted sockets may inherit O_NONBLOCK from the listener (BSD/macOS, Windows).
    stream.set_nonblocking(false)?;
    Ok((stream, addr))
}
442
/// Applies `timeout` (seconds) as both the read and write timeout of `stream`.
///
/// `0`, `Inf`, and finite values too large for [`Duration`] all mean "no
/// timeout" (blocking I/O). A positive value that rounds to zero nanoseconds is
/// clamped to the smallest positive duration, because the standard library
/// rejects a zero timeout.
///
/// # Errors
///
/// Returns `ErrorKind::InvalidInput` for NaN or negative timeouts. These
/// previously panicked inside `Duration::from_secs_f64`. Errors from
/// `set_read_timeout` / `set_write_timeout` are propagated.
pub(crate) fn configure_stream(stream: &TcpStream, timeout: f64) -> io::Result<()> {
    let opt = if timeout.is_infinite() || timeout == 0.0 {
        None
    } else {
        match Duration::try_from_secs_f64(timeout) {
            Ok(duration) => Some(duration.max(Duration::from_nanos(1))),
            // Finite but unrepresentable: treat like `Inf`.
            Err(_) if timeout > 0.0 => None,
            Err(_) => {
                return Err(io::Error::new(
                    ErrorKind::InvalidInput,
                    format!("invalid stream timeout {timeout}"),
                ));
            }
        }
    };
    stream.set_read_timeout(opt)?;
    stream.set_write_timeout(opt)?;
    Ok(())
}
453
454fn build_tcpclient_value(
455 client_id: u64,
456 server_state: &TcpServerState,
457 peer_addr: SocketAddr,
458 timeout: f64,
459 byte_order: String,
460) -> Value {
461 let mut st = StructValue::new();
462 st.fields
463 .insert("Type".to_string(), Value::String("tcpclient".to_string()));
464 st.fields.insert(
465 "Address".to_string(),
466 Value::String(peer_addr.ip().to_string()),
467 );
468 st.fields.insert(
469 "Port".to_string(),
470 Value::Int(IntValue::U16(peer_addr.port())),
471 );
472 st.fields.insert(
473 "ServerAddress".to_string(),
474 Value::String(server_state.local_addr.ip().to_string()),
475 );
476 st.fields.insert(
477 "ServerPort".to_string(),
478 Value::Int(IntValue::U16(server_state.local_addr.port())),
479 );
480 st.fields.insert("Connected".to_string(), Value::Bool(true));
481 st.fields
482 .insert("Status".to_string(), Value::String("connected".to_string()));
483 st.fields.insert(
484 "NumBytesAvailable".to_string(),
485 Value::Int(IntValue::I32(0)),
486 );
487 st.fields
488 .insert("BytesAvailableFcn".to_string(), default_user_data());
489 st.fields.insert(
490 "BytesAvailableFcnMode".to_string(),
491 Value::String("byte".to_string()),
492 );
493 st.fields.insert(
494 "BytesAvailableFcnCount".to_string(),
495 Value::Int(IntValue::I32(1)),
496 );
497 st.fields
498 .insert("ByteOrder".to_string(), Value::String(byte_order));
499 st.fields.insert(
500 "Timeout".to_string(),
501 Value::Num(if timeout.is_infinite() {
502 f64::INFINITY
503 } else {
504 timeout
505 }),
506 );
507 st.fields
508 .insert("UserData".to_string(), default_user_data());
509 st.fields.insert(
510 CLIENT_HANDLE_FIELD.to_string(),
511 Value::Int(IntValue::U64(client_id)),
512 );
513 st.fields.insert(
514 HANDLE_ID_FIELD.to_string(),
515 Value::Int(IntValue::U64(server_state.id)),
516 );
517 Value::Struct(st)
518}
519
520fn accept_error(message_id: &'static str, message: impl Into<String>) -> RuntimeError {
521 build_runtime_error(message)
522 .with_identifier(message_id)
523 .with_builtin("accept")
524 .build()
525}
526
// Unit tests for `accept`. They create real loopback servers, so each test holds
// `net_guard()` to run one at a time with respect to the global registries.
#[cfg(test)]
pub(crate) mod tests {
    use super::super::tcpserver::{
        remove_server_for_test, tcpserver_builtin, HANDLE_ID_FIELD as SERVER_FIELD,
    };
    use super::*;
    use runmat_builtins::Value;
    use std::net::TcpStream;
    use std::thread;
    use std::time::Duration;

    /// Returns field `name` of a struct value; panics if it is not a struct or lacks the field.
    fn struct_field<'a>(value: &'a Value, name: &str) -> &'a Value {
        match value {
            Value::Struct(st) => st
                .fields
                .get(name)
                .unwrap_or_else(|| panic!("missing field {name}")),
            _ => panic!("expected struct"),
        }
    }

    /// Extracts the client registry id stored under `CLIENT_HANDLE_FIELD`.
    fn client_id(value: &Value) -> u64 {
        match struct_field(value, CLIENT_HANDLE_FIELD) {
            Value::Int(IntValue::U64(id)) => *id,
            Value::Int(iv) => iv.to_i64() as u64,
            other => panic!("expected id int, got {other:?}"),
        }
    }

    /// Asserts that `err` carries the expected message identifier.
    fn assert_error_identifier(err: RuntimeError, expected: &str) {
        assert_eq!(err.identifier(), Some(expected));
    }

    /// Runs the async `accept` builtin to completion on the current thread.
    fn run_accept(server: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        futures::executor::block_on(accept_builtin(server, rest))
    }

    /// Runs the async `tcpserver` builtin to completion on the current thread.
    fn run_tcpserver(address: Value, port: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        futures::executor::block_on(tcpserver_builtin(address, port, rest))
    }

    /// Extracts the server registry id from a tcpserver struct.
    fn server_id(value: &Value) -> u64 {
        match struct_field(value, SERVER_FIELD) {
            Value::Int(IntValue::U64(id)) => *id,
            Value::Int(iv) => iv.to_i64() as u64,
            other => panic!("expected server id int, got {other:?}"),
        }
    }

    /// Takes the shared test lock.
    fn net_guard() -> std::sync::MutexGuard<'static, ()> {
        super::test_guard()
    }

    // A non-struct first argument is rejected as an invalid server.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn accept_rejects_non_struct() {
        let _guard = net_guard();
        let err = run_accept(Value::Num(1.0), Vec::new()).unwrap_err();
        assert_error_identifier(err, MESSAGE_ID_INVALID_SERVER);
    }

    // Happy path: a client connecting from another thread is accepted and described.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn accept_establishes_client_connection() {
        let _guard = net_guard();
        // NOTE(review): port 0 presumably lets the OS pick an ephemeral port; the
        // actual port is read back from `ServerPort` below.
        let server_value = run_tcpserver(
            Value::from("127.0.0.1"),
            Value::Int(IntValue::I32(0)),
            Vec::new(),
        )
        .expect("tcpserver");
        let port = match struct_field(&server_value, "ServerPort") {
            Value::Int(iv) => iv.to_i64() as u16,
            other => panic!("expected ServerPort int, got {other:?}"),
        };

        // Connect after a short delay so `accept` is typically already waiting.
        let handle = thread::spawn(move || {
            thread::sleep(Duration::from_millis(50));
            TcpStream::connect(("127.0.0.1", port)).expect("connect")
        });

        let client = run_accept(server_value.clone(), Vec::new()).expect("accept");
        let stream = handle.join().expect("client thread");
        drop(stream);

        match struct_field(&client, "Connected") {
            Value::Bool(flag) => assert!(*flag),
            other => panic!("expected Connected bool, got {other:?}"),
        }
        match struct_field(&client, "Address") {
            Value::String(addr) => assert_eq!(addr, "127.0.0.1"),
            other => panic!("expected Address string, got {other:?}"),
        }
        // NOTE(review): 10.0 is presumably tcpserver's default Timeout; confirm in tcpserver.
        match struct_field(&client, "Timeout") {
            Value::Num(n) => assert_eq!(*n, 10.0),
            other => panic!("expected Timeout numeric, got {other:?}"),
        }

        let cid = client_id(&client);
        remove_client_for_test(cid);
        remove_server_for_test(server_id(&server_value));
    }

    // With nobody connecting, a short per-call Timeout yields the timeout identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn accept_times_out_when_no_client_connects() {
        let _guard = net_guard();
        let server_value = run_tcpserver(
            Value::from("127.0.0.1"),
            Value::Int(IntValue::I32(0)),
            Vec::new(),
        )
        .expect("tcpserver");
        let err = run_accept(
            server_value.clone(),
            vec![Value::from("Timeout"), Value::Num(0.05)],
        )
        .unwrap_err();
        assert_error_identifier(err, MESSAGE_ID_TIMEOUT);
        remove_server_for_test(server_id(&server_value));
    }

    // A negative Timeout is a name-value error, raised before any waiting.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn accept_rejects_invalid_timeout_name_value() {
        let _guard = net_guard();
        let server_value = run_tcpserver(
            Value::from("127.0.0.1"),
            Value::Int(IntValue::I32(0)),
            Vec::new(),
        )
        .expect("tcpserver");
        let err = run_accept(
            server_value.clone(),
            vec![Value::from("Timeout"), Value::Num(-1.0)],
        )
        .unwrap_err();
        assert_error_identifier(err, MESSAGE_ID_INVALID_NAME_VALUE);
        remove_server_for_test(server_id(&server_value));
    }

    // A per-call Timeout overrides the server's and is reported on the client struct.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn accept_respects_per_call_timeout_override() {
        let _guard = net_guard();
        let server_value = run_tcpserver(
            Value::from("127.0.0.1"),
            Value::Int(IntValue::I32(0)),
            Vec::new(),
        )
        .expect("tcpserver");
        let port = match struct_field(&server_value, "ServerPort") {
            Value::Int(iv) => iv.to_i64() as u16,
            other => panic!("expected ServerPort int, got {other:?}"),
        };

        let handle = thread::spawn(move || {
            thread::sleep(Duration::from_millis(20));
            TcpStream::connect(("127.0.0.1", port)).expect("connect")
        });

        let client = run_accept(
            server_value.clone(),
            vec![Value::from("Timeout"), Value::Num(1.0)],
        )
        .expect("accept");
        handle.join().expect("join");
        let timeout_val = match struct_field(&client, "Timeout") {
            Value::Num(n) => *n,
            other => panic!("expected Timeout numeric, got {other:?}"),
        };
        assert_eq!(timeout_val, 1.0);
        let cid = client_id(&client);
        remove_client_for_test(cid);
        remove_server_for_test(server_id(&server_value));
    }
}