1use runmat_builtins::{IntValue, StructValue, Value};
4use runmat_macros::runtime_builtin;
5use std::io::{self, Write};
6use std::net::TcpStream;
7
8use crate::builtins::common::spec::{
9 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
10 ReductionNaN, ResidencyPolicy, ShapeRequirements,
11};
12use crate::{gather_if_needed, register_builtin_fusion_spec, register_builtin_gpu_spec};
13
14use super::accept::{client_handle, configure_stream, CLIENT_HANDLE_FIELD};
15
16#[cfg(feature = "doc_export")]
17use crate::register_builtin_doc_text;
18
// MATLAB-style message identifiers. `runtime_error` prefixes one of these onto
// every error string this builtin returns, producing "<id>: <message>".

// First argument is not a usable tcpclient struct / registry handle.
const MESSAGE_ID_INVALID_CLIENT: &str = "MATLAB:write:InvalidTcpClient";
// `data` cannot be encoded for the requested datatype.
const MESSAGE_ID_INVALID_DATA: &str = "MATLAB:write:InvalidData";
// The optional datatype argument is malformed or unsupported.
const MESSAGE_ID_INVALID_DATATYPE: &str = "MATLAB:write:InvalidDataType";
// The client handle is already marked disconnected.
const MESSAGE_ID_NOT_CONNECTED: &str = "MATLAB:write:NotConnected";
// The socket send timed out before the whole payload was sent.
const MESSAGE_ID_TIMEOUT: &str = "MATLAB:write:Timeout";
// The peer closed or reset the connection during the send.
const MESSAGE_ID_CONNECTION_CLOSED: &str = "MATLAB:write:ConnectionClosed";
// Socket configuration/clone failures and other unexpected I/O errors.
const MESSAGE_ID_INTERNAL: &str = "MATLAB:write:InternalError";
26
27#[cfg(feature = "doc_export")]
28pub const DOC_MD: &str = r#"---
29title: "write"
30category: "io/net"
31keywords: ["write", "tcpclient", "networking", "socket", "binary data", "text"]
32summary: "Write numeric or text data to a remote host through a MATLAB-compatible tcpclient struct."
33references:
34 - https://www.mathworks.com/help/matlab/ref/tcpclient.write.html
35gpu_support:
36 elementwise: false
37 reduction: false
38 precisions: []
39 broadcasting: "none"
40 notes: "All TCP writes execute on the host CPU. GPU-resident arguments are gathered automatically before socket I/O."
41fusion:
42 elementwise: false
43 reduction: false
44 max_inputs: 3
45 constants: "inline"
46requires_feature: null
47tested:
48 unit: "builtins::io::net::write::tests"
49---
50
51# What does the `write` function do in MATLAB / RunMat?
52`write(t, data)` transmits binary or textual data over the TCP/IP client returned by `tcpclient` (or `accept`).
53The builtin mirrors MATLAB’s `write` behaviour so existing socket code continues working without modification.
54It honours the client’s configured `Timeout`, applies the `ByteOrder` property when encoding multi-byte values,
55and accepts the optional `datatype` argument used throughout MATLAB’s I/O APIs.
56
57## How does the `write` function behave in MATLAB / RunMat?
58- `write(t, data)` converts `data` to unsigned 8-bit integers (the MATLAB default) and sends the bytes to the peer.
59 The return value is the number of elements written when the caller requests an output argument.
60- `write(t, data, datatype)` encodes the payload using the supplied MATLAB datatype token.
61 Supported values mirror MATLAB: `"uint8"` (default), `"int8"`, `"uint16"`, `"int16"`, `"uint32"`, `"int32"`,
62 `"uint64"`, `"int64"`, `"single"`, `"double"`, `"char"`, and `"string"`. Numeric conversions saturate to the
63 destination range just like MATLAB cast operations. `"char"` treats values as single-byte character codes and
64 `"string"` encodes UTF-8 text.
65- The client’s `ByteOrder` property controls how multi-byte numeric values are serialised. `"little-endian"` is
66 the default, while `"big-endian"` matches the traditional network byte order.
67- When the socket cannot send the entire payload before the timeout expires, `write` raises `MATLAB:write:Timeout`.
68 If the peer closes the connection before or during the transfer the builtin raises `MATLAB:write:ConnectionClosed`
69 and marks the client as disconnected.
70- Inputs that originate on the GPU are gathered back to the host automatically before any bytes are written.
71
72## `write` Function GPU Execution Behaviour
73Networking occurs on the host CPU. If `data` or the tcpclient struct resides on the GPU, RunMat gathers the values
74to host memory before converting them to bytes. Acceleration providers are not involved and the resulting payload
75remains on the CPU. Providers that support residency tracking automatically mark any gathered tensors as released.
76
77## Examples of using the `write` function in MATLAB / RunMat
78
79### Sending an array of bytes to an echo service
80```matlab
81client = tcpclient("127.0.0.1", 50000);
82count = write(client, uint8(1:4));
83```
84Expected output when an output argument is requested:
85```matlab
86count =
87 4
88```
89
90### Writing doubles with explicit byte order
91```matlab
92client = tcpclient("localhost", 50001, "ByteOrder", "big-endian");
93values = [1.5 2.5 3.5];
94write(client, values, "double");
95```
96The remote peer receives 24 bytes representing the doubles in big-endian order.
97
98### Transmitting ASCII text
99```matlab
100client = tcpclient("127.0.0.1", 50002);
101write(client, "RunMat TCP", "char");
102```
103Expected payload (one byte per character):
104```
82 117 110 77 97 116 32 84 67 80
106```
107
108### Sending UTF-8 encoded strings
109```matlab
110client = tcpclient("127.0.0.1", 50003);
111write(client, "Διακριτό", "string");
112```
113The builtin encodes the Unicode text as UTF-8 before sending it across the socket.
114
115### Handling connection closures
116```matlab
117client = tcpclient("example.com", 12345, "Timeout", 0.25);
118try
119 write(client, uint8([1 2 3 4]));
120catch err
121 disp(err.identifier)
122end
123```
124Expected output when the peer closes the connection abruptly:
125```matlab
126MATLAB:write:ConnectionClosed
127```
128
129## FAQ
130
131### How many output values does `write` return?
When the caller requests an output, the builtin returns the number of elements written after datatype conversion:
one per value for numeric datatypes, one per byte for `"char"`, and one per string (1 for a scalar) for `"string"`.
This mirrors the behaviour of MATLAB’s numeric I/O routines. If no output is requested, the value is discarded.
134
135### Does `write` support complex numbers?
136No. The input must be real. Pass separate real and imaginary parts or convert to a byte representation manually.
137
138### How are values rounded when converting to integer datatypes?
Floating-point inputs are rounded to the nearest integer (halves round away from zero, `NaN` becomes 0) and then
saturated to the target range, matching MATLAB casts: for example, `uint8(255.7)` rounds to 256 and saturates to 255,
and `int8(-128.2)` rounds to -128.
141
142### What happens to GPU-resident tensors?
143They are gathered automatically before the write. Networking is a CPU-only subsystem, so the resulting data is sent
144from host memory and any temporary handles are released after the transfer.
145
146### Can I stream large payloads?
147Yes. `write` loops until the entire payload has been sent or an error occurs. Large payloads honour the client’s
148timeout and byte-order settings.
149
150## See also
151[tcpclient](./tcpclient), [accept](./accept), [read](./read), [readline](./readline)
152
153## Source & feedback
154- Implementation: [`crates/runmat-runtime/src/builtins/io/net/write.rs`](https://github.com/runmat-org/runmat/blob/main/crates/runmat-runtime/src/builtins/io/net/write.rs)
155- Please [open an issue](https://github.com/runmat-org/runmat/issues/new/choose) if you encounter behavioural
156 differences from MATLAB.
157"#;
158
/// GPU metadata for `write`. Socket I/O is host-only, so the builtin declares no
/// supported precisions or provider hooks, and its residency policy gathers any
/// GPU-resident inputs immediately.
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "write",
    op_kind: GpuOpKind::Custom("network"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Socket writes always execute on the host CPU; GPU providers are never consulted.",
};

register_builtin_gpu_spec!(GPU_SPEC);

/// Fusion metadata for `write`. It has neither an elementwise nor a reduction
/// form; per its notes it is executed eagerly on the CPU.
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "write",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Networking builtin executed eagerly on the CPU.",
};

register_builtin_fusion_spec!(FUSION_SPEC);

// Registers the markdown reference page (`DOC_MD`) only when doc export is enabled.
#[cfg(feature = "doc_export")]
register_builtin_doc_text!("write", DOC_MD);
190
191#[runtime_builtin(
192 name = "write",
193 category = "io/net",
194 summary = "Write numeric or text data to a TCP/IP client.",
195 keywords = "write,tcpclient,networking"
196)]
197fn write_builtin(client: Value, data: Value, rest: Vec<Value>) -> Result<Value, String> {
198 let client = gather_if_needed(&client)
199 .map_err(|err| runtime_error(MESSAGE_ID_INVALID_CLIENT, format!("write: {err}")))?;
200 let data = gather_if_needed(&data)
201 .map_err(|err| runtime_error(MESSAGE_ID_INVALID_DATA, format!("write: {err}")))?;
202
203 let mut gathered_rest = Vec::with_capacity(rest.len());
204 for value in rest {
205 gathered_rest.push(
206 gather_if_needed(&value).map_err(|err| {
207 runtime_error(MESSAGE_ID_INVALID_DATATYPE, format!("write: {err}"))
208 })?,
209 );
210 }
211 let datatype = parse_arguments(&gathered_rest)?;
212
213 let client_struct = match &client {
214 Value::Struct(st) => st,
215 _ => {
216 return Err(runtime_error(
217 MESSAGE_ID_INVALID_CLIENT,
218 "write: expected tcpclient struct as first argument",
219 ))
220 }
221 };
222
223 let client_id = extract_client_id(client_struct)?;
224 let handle = client_handle(client_id).ok_or_else(|| {
225 runtime_error(
226 MESSAGE_ID_INVALID_CLIENT,
227 "write: tcpclient handle is no longer valid",
228 )
229 })?;
230
231 let (mut stream, timeout, byte_order) = {
232 let guard = handle.lock().unwrap_or_else(|poison| poison.into_inner());
233 if !guard.connected {
234 return Err(runtime_error(
235 MESSAGE_ID_NOT_CONNECTED,
236 "write: tcpclient is disconnected",
237 ));
238 }
239 let timeout = guard.timeout;
240 let byte_order = parse_byte_order(&guard.byte_order);
241 let stream = guard.stream.try_clone().map_err(|err| {
242 runtime_error(MESSAGE_ID_INTERNAL, format!("write: clone failed ({err})"))
243 })?;
244 (stream, timeout, byte_order)
245 };
246
247 if let Err(err) = configure_stream(&stream, timeout) {
248 return Err(runtime_error(
249 MESSAGE_ID_INTERNAL,
250 format!("write: unable to configure socket timeout ({err})"),
251 ));
252 }
253
254 let payload = prepare_payload(&data, datatype, byte_order)?;
255 if payload.bytes.is_empty() {
256 return Ok(Value::Num(0.0));
257 }
258
259 match write_bytes(&mut stream, &payload.bytes) {
260 Ok(_) => Ok(Value::Num(payload.elements as f64)),
261 Err(WriteError::Timeout) => Err(runtime_error(
262 MESSAGE_ID_TIMEOUT,
263 "write: timed out while sending data",
264 )),
265 Err(WriteError::ConnectionClosed) => {
266 if let Ok(mut guard) = handle.lock() {
267 guard.connected = false;
268 }
269 Err(runtime_error(
270 MESSAGE_ID_CONNECTION_CLOSED,
271 "write: connection closed before all data was sent",
272 ))
273 }
274 Err(WriteError::Io(err)) => Err(runtime_error(
275 MESSAGE_ID_INTERNAL,
276 format!("write: socket error ({err})"),
277 )),
278 }
279}
280
/// Encoding selected by the optional `datatype` argument of `write`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DataType {
    UInt8,
    Int8,
    UInt16,
    Int16,
    UInt32,
    Int32,
    UInt64,
    Int64,
    Single,
    Double,
    Char,
    String,
}

/// `uint8` is the datatype used when the caller does not pass one.
///
/// Provided through the standard `Default` trait rather than an inherent
/// `fn default()` (which clippy flags as `should_implement_trait`). Existing
/// `DataType::default()` call sites resolve to this impl unchanged.
impl Default for DataType {
    fn default() -> Self {
        DataType::UInt8
    }
}

impl DataType {
    /// Number of bytes one encoded element occupies on the wire.
    ///
    /// `Char` and `String` report 1 because they are emitted byte-by-byte.
    fn element_size(self) -> usize {
        match self {
            DataType::UInt8 | DataType::Int8 | DataType::Char | DataType::String => 1,
            DataType::UInt16 | DataType::Int16 => 2,
            DataType::UInt32 | DataType::Int32 | DataType::Single => 4,
            DataType::UInt64 | DataType::Int64 | DataType::Double => 8,
        }
    }
}

/// Byte order applied to multi-byte numeric encodings.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ByteOrder {
    Little,
    Big,
}

/// Encoded bytes ready for the socket, plus the element count reported back
/// to the caller.
#[derive(Debug)]
struct Payload {
    bytes: Vec<u8>,
    elements: usize,
}
322
323fn parse_arguments(args: &[Value]) -> Result<DataType, String> {
324 match args.len() {
325 0 => Ok(DataType::default()),
326 1 => parse_datatype(&args[0]),
327 _ => Err(runtime_error(
328 MESSAGE_ID_INVALID_DATATYPE,
329 "write: expected at most one datatype argument",
330 )),
331 }
332}
333
334fn parse_datatype(value: &Value) -> Result<DataType, String> {
335 let text =
336 scalar_string(value).map_err(|err| runtime_error(MESSAGE_ID_INVALID_DATATYPE, err))?;
337 let lowered = text.trim().to_ascii_lowercase();
338 if lowered.is_empty() {
339 return Err(runtime_error(
340 MESSAGE_ID_INVALID_DATATYPE,
341 "write: datatype must not be empty",
342 ));
343 }
344 let dtype = match lowered.as_str() {
345 "uint8" => DataType::UInt8,
346 "int8" => DataType::Int8,
347 "uint16" => DataType::UInt16,
348 "int16" => DataType::Int16,
349 "uint32" => DataType::UInt32,
350 "int32" => DataType::Int32,
351 "uint64" => DataType::UInt64,
352 "int64" => DataType::Int64,
353 "single" => DataType::Single,
354 "double" => DataType::Double,
355 "char" => DataType::Char,
356 "string" => DataType::String,
357 _ => {
358 return Err(runtime_error(
359 MESSAGE_ID_INVALID_DATATYPE,
360 format!("write: unsupported datatype '{text}'"),
361 ))
362 }
363 };
364 Ok(dtype)
365}
366
367fn prepare_payload(data: &Value, datatype: DataType, order: ByteOrder) -> Result<Payload, String> {
368 match datatype {
369 DataType::Char => char_payload(data),
370 DataType::String => string_payload(data),
371 _ => numeric_payload(data, datatype, order),
372 }
373}
374
375fn numeric_payload(data: &Value, datatype: DataType, order: ByteOrder) -> Result<Payload, String> {
376 let values = flatten_numeric(data)?;
377 let mut bytes = Vec::with_capacity(values.len() * datatype.element_size());
378 for value in values.iter().copied() {
379 match datatype {
380 DataType::UInt8 => bytes.push(cast_to_u8(value)),
381 DataType::Int8 => bytes.push(cast_to_i8(value) as u8),
382 DataType::UInt16 => extend_u16(&mut bytes, cast_to_u16(value), order),
383 DataType::Int16 => extend_i16(&mut bytes, cast_to_i16(value), order),
384 DataType::UInt32 => extend_u32(&mut bytes, cast_to_u32(value), order),
385 DataType::Int32 => extend_i32(&mut bytes, cast_to_i32(value), order),
386 DataType::UInt64 => extend_u64(&mut bytes, cast_to_u64(value), order),
387 DataType::Int64 => extend_i64(&mut bytes, cast_to_i64(value), order),
388 DataType::Single => extend_f32(&mut bytes, cast_to_f32(value), order),
389 DataType::Double => extend_f64(&mut bytes, value, order),
390 DataType::Char | DataType::String => unreachable!(),
391 }
392 }
393 Ok(Payload {
394 bytes,
395 elements: values.len(),
396 })
397}
398
399fn char_payload(data: &Value) -> Result<Payload, String> {
400 let bytes = match data {
401 Value::CharArray(ca) => ca.data.iter().map(|&ch| (ch as u32 & 0xFF) as u8).collect(),
402 Value::String(text) => text.bytes().collect(),
403 Value::StringArray(sa) => {
404 if sa.data.len() != 1 {
405 return Err(runtime_error(
406 MESSAGE_ID_INVALID_DATA,
407 "write: string array input must be scalar when using 'char'",
408 ));
409 }
410 sa.data[0].as_bytes().to_vec()
411 }
412 Value::Tensor(t) => t.data.iter().map(|&v| cast_to_u8(v)).collect::<Vec<u8>>(),
413 Value::Num(n) => vec![cast_to_u8(*n)],
414 Value::Int(iv) => vec![cast_to_u8(iv.to_f64())],
415 Value::Bool(b) => vec![if *b { 1 } else { 0 }],
416 Value::LogicalArray(la) => la
417 .data
418 .iter()
419 .map(|&b| if b != 0 { 1 } else { 0 })
420 .collect(),
421 _ => {
422 return Err(runtime_error(
423 MESSAGE_ID_INVALID_DATA,
424 "write: unsupported input for 'char' datatype",
425 ))
426 }
427 };
428 Ok(Payload {
429 elements: bytes.len(),
430 bytes,
431 })
432}
433
434fn string_payload(data: &Value) -> Result<Payload, String> {
435 match data {
436 Value::String(text) => Ok(Payload {
437 elements: 1,
438 bytes: text.as_bytes().to_vec(),
439 }),
440 Value::CharArray(ca) => {
441 let string: String = ca.data.iter().collect();
442 Ok(Payload {
443 elements: 1,
444 bytes: string.into_bytes(),
445 })
446 }
447 Value::StringArray(sa) => {
448 if sa.data.is_empty() {
449 return Ok(Payload {
450 elements: 0,
451 bytes: Vec::new(),
452 });
453 }
454 if sa.data.len() != 1 {
455 return Err(runtime_error(
456 MESSAGE_ID_INVALID_DATA,
457 "write: string array input must be scalar when using 'string'",
458 ));
459 }
460 Ok(Payload {
461 elements: 1,
462 bytes: sa.data[0].as_bytes().to_vec(),
463 })
464 }
465 _ => Err(runtime_error(
466 MESSAGE_ID_INVALID_DATA,
467 "write: expected text input when using 'string' datatype",
468 )),
469 }
470}
471
472fn flatten_numeric(value: &Value) -> Result<Vec<f64>, String> {
473 match value {
474 Value::Tensor(t) => Ok(t.data.clone()),
475 Value::Num(n) => Ok(vec![*n]),
476 Value::Int(iv) => Ok(vec![iv.to_f64()]),
477 Value::Bool(b) => Ok(vec![if *b { 1.0 } else { 0.0 }]),
478 Value::LogicalArray(la) => Ok(la
479 .data
480 .iter()
481 .map(|&b| if b != 0 { 1.0 } else { 0.0 })
482 .collect()),
483 Value::CharArray(ca) => Ok(ca
484 .data
485 .iter()
486 .map(|&ch| (ch as u32 & 0xFF) as f64)
487 .collect()),
488 Value::String(text) => Ok(text.chars().map(|ch| (ch as u32) as f64).collect()),
489 Value::StringArray(sa) => {
490 if sa.data.len() != 1 {
491 return Err(runtime_error(
492 MESSAGE_ID_INVALID_DATA,
493 "write: string array input must be scalar",
494 ));
495 }
496 Ok(sa.data[0].chars().map(|ch| (ch as u32) as f64).collect())
497 }
498 Value::Complex(_, _) | Value::ComplexTensor(_) => Err(runtime_error(
499 MESSAGE_ID_INVALID_DATA,
500 "write: complex data is not supported",
501 )),
502 Value::Cell(_)
503 | Value::Struct(_)
504 | Value::Object(_)
505 | Value::HandleObject(_)
506 | Value::Listener(_)
507 | Value::FunctionHandle(_)
508 | Value::Closure(_)
509 | Value::ClassRef(_)
510 | Value::MException(_) => Err(runtime_error(
511 MESSAGE_ID_INVALID_DATA,
512 "write: unsupported input type",
513 )),
514 Value::GpuTensor(_) => Err(runtime_error(
515 MESSAGE_ID_INVALID_DATA,
516 "write: GPU tensor should have been gathered before encoding",
517 )),
518 }
519}
520
// Saturating conversions used by the numeric encoders.
//
// Each integer `cast_to_*` works in two steps:
// - It rounds to the nearest integer with `rounded_scalar` (halves away from
//   zero, NaN -> 0).
// - It then uses Rust's float-to-integer `as` cast. That cast saturates
//   out-of-range values, including +/-infinity, to the target's MIN/MAX and
//   maps NaN to 0.
//
// This reproduces the previous hand-written range and infinity checks exactly,
// without duplicating them eight times.

/// Rounds `value` and saturates it into `u8`.
fn cast_to_u8(value: f64) -> u8 {
    rounded_scalar(value) as u8
}

/// Rounds `value` and saturates it into `i8`.
fn cast_to_i8(value: f64) -> i8 {
    rounded_scalar(value) as i8
}

/// Rounds `value` and saturates it into `u16`.
fn cast_to_u16(value: f64) -> u16 {
    rounded_scalar(value) as u16
}

/// Rounds `value` and saturates it into `i16`.
fn cast_to_i16(value: f64) -> i16 {
    rounded_scalar(value) as i16
}

/// Rounds `value` and saturates it into `u32`.
fn cast_to_u32(value: f64) -> u32 {
    rounded_scalar(value) as u32
}

/// Rounds `value` and saturates it into `i32`.
fn cast_to_i32(value: f64) -> i32 {
    rounded_scalar(value) as i32
}

/// Rounds `value` and saturates it into `u64`.
fn cast_to_u64(value: f64) -> u64 {
    rounded_scalar(value) as u64
}

/// Rounds `value` and saturates it into `i64`.
fn cast_to_i64(value: f64) -> i64 {
    rounded_scalar(value) as i64
}

/// Narrows `value` to the nearest `f32`; out-of-range magnitudes become
/// infinity of the same sign.
fn cast_to_f32(value: f64) -> f32 {
    value as f32
}

/// Rounds `value` to the nearest integer (halves away from zero); NaN maps to 0.
fn rounded_scalar(value: f64) -> f64 {
    if value.is_nan() {
        0.0
    } else {
        value.round()
    }
}
676
677fn extend_u16(buffer: &mut Vec<u8>, value: u16, order: ByteOrder) {
678 match order {
679 ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
680 ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
681 }
682}
683
684fn extend_i16(buffer: &mut Vec<u8>, value: i16, order: ByteOrder) {
685 match order {
686 ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
687 ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
688 }
689}
690
691fn extend_u32(buffer: &mut Vec<u8>, value: u32, order: ByteOrder) {
692 match order {
693 ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
694 ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
695 }
696}
697
698fn extend_i32(buffer: &mut Vec<u8>, value: i32, order: ByteOrder) {
699 match order {
700 ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
701 ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
702 }
703}
704
705fn extend_u64(buffer: &mut Vec<u8>, value: u64, order: ByteOrder) {
706 match order {
707 ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
708 ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
709 }
710}
711
712fn extend_i64(buffer: &mut Vec<u8>, value: i64, order: ByteOrder) {
713 match order {
714 ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
715 ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
716 }
717}
718
719fn extend_f32(buffer: &mut Vec<u8>, value: f32, order: ByteOrder) {
720 match order {
721 ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
722 ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
723 }
724}
725
726fn extend_f64(buffer: &mut Vec<u8>, value: f64, order: ByteOrder) {
727 match order {
728 ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
729 ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
730 }
731}
732
733fn parse_byte_order(text: &str) -> ByteOrder {
734 if text.eq_ignore_ascii_case("big-endian") || text.eq_ignore_ascii_case("big endian") {
735 ByteOrder::Big
736 } else {
737 ByteOrder::Little
738 }
739}
740
741fn scalar_string(value: &Value) -> Result<String, String> {
742 match value {
743 Value::String(s) => Ok(s.clone()),
744 Value::CharArray(ca) if ca.rows == 1 => Ok(ca.data.iter().collect()),
745 Value::StringArray(sa) if sa.data.len() == 1 => Ok(sa.data[0].clone()),
746 _ => Err(
747 "write: datatype argument must be a string scalar or character row vector".to_string(),
748 ),
749 }
750}
751
752fn extract_client_id(struct_value: &StructValue) -> Result<u64, String> {
753 let id_value = struct_value
754 .fields
755 .get(CLIENT_HANDLE_FIELD)
756 .ok_or_else(|| {
757 runtime_error(
758 MESSAGE_ID_INVALID_CLIENT,
759 "write: tcpclient struct is missing internal handle",
760 )
761 })?;
762 match id_value {
763 Value::Int(IntValue::U64(id)) => Ok(*id),
764 Value::Int(iv) => Ok(iv.to_i64() as u64),
765 _ => Err(runtime_error(
766 MESSAGE_ID_INVALID_CLIENT,
767 "write: tcpclient struct has invalid handle field",
768 )),
769 }
770}
771
/// Builds an error string of the form `"<message_id>: <message>"`.
fn runtime_error(message_id: &'static str, message: impl Into<String>) -> String {
    let body: String = message.into();
    let mut out = String::with_capacity(message_id.len() + 2 + body.len());
    out.push_str(message_id);
    out.push_str(": ");
    out.push_str(&body);
    out
}
775
/// Failure modes of `write_bytes`.
enum WriteError {
    /// The socket's send timeout elapsed before all bytes were accepted.
    Timeout,
    /// The peer closed or reset the connection, or a write accepted zero bytes.
    ConnectionClosed,
    /// Any other I/O failure, kept for its message.
    Io(io::Error),
}

/// Sends all of `bytes` on `stream` and classifies any failure.
///
/// `Write::write_all` retries on `Interrupted`. It turns a zero-byte write into
/// an `ErrorKind::WriteZero` error, which this function reports as a closed
/// connection.
fn write_bytes(stream: &mut TcpStream, bytes: &[u8]) -> Result<(), WriteError> {
    stream.write_all(bytes).map_err(|err| {
        if is_timeout(&err) {
            WriteError::Timeout
        } else if err.kind() == io::ErrorKind::WriteZero || is_connection_closed_error(&err) {
            WriteError::ConnectionClosed
        } else {
            WriteError::Io(err)
        }
    })
}

/// True when `err` signals that the send timeout expired.
///
/// Socket write timeouts surface as `WouldBlock` on Unix and `TimedOut` on
/// Windows, so both count.
fn is_timeout(err: &io::Error) -> bool {
    [io::ErrorKind::TimedOut, io::ErrorKind::WouldBlock].contains(&err.kind())
}

/// True when `err` indicates the peer has gone away.
fn is_connection_closed_error(err: &io::Error) -> bool {
    const CLOSED_KINDS: [io::ErrorKind; 5] = [
        io::ErrorKind::BrokenPipe,
        io::ErrorKind::ConnectionReset,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::NotConnected,
        io::ErrorKind::UnexpectedEof,
    ];
    CLOSED_KINDS.contains(&err.kind())
}
816
817#[cfg(test)]
818mod tests {
819 use super::*;
820 #[cfg(feature = "doc_export")]
821 use crate::builtins::common::test_support;
822 use crate::builtins::io::net::accept::{
823 configure_stream, insert_client, remove_client_for_test,
824 };
825 use runmat_builtins::{CharArray, IntValue, StructValue, Tensor};
826 use std::io::Read;
827 use std::net::{TcpListener, TcpStream};
828 use std::sync::{Arc, Barrier};
829 use std::thread;
830
831 fn make_client(stream: TcpStream, timeout: f64, byte_order: &str) -> Value {
832 let peer_addr = stream.peer_addr().expect("peer addr");
833 configure_stream(&stream, timeout).expect("configure stream");
834 let client_id = insert_client(stream, 0, peer_addr, timeout, byte_order.to_string());
835 let mut st = StructValue::new();
836 st.fields.insert(
837 CLIENT_HANDLE_FIELD.to_string(),
838 Value::Int(IntValue::U64(client_id)),
839 );
840 Value::Struct(st)
841 }
842
843 fn client_id(client: &Value) -> u64 {
844 match client {
845 Value::Struct(st) => match st.fields.get(CLIENT_HANDLE_FIELD) {
846 Some(Value::Int(IntValue::U64(id))) => *id,
847 Some(Value::Int(iv)) => iv.to_i64() as u64,
848 other => panic!("unexpected id field {other:?}"),
849 },
850 other => panic!("expected struct, got {other:?}"),
851 }
852 }
853
854 #[test]
855 fn write_default_uint8_sends_bytes() {
856 let listener = TcpListener::bind("127.0.0.1:0").expect("listener");
857 let port = listener.local_addr().unwrap().port();
858 let handle = thread::spawn(move || {
859 let (mut stream, _) = listener.accept().expect("accept");
860 let mut received = Vec::new();
861 stream.read_to_end(&mut received).unwrap_or_default();
862 received
863 });
864
865 let stream = TcpStream::connect(("127.0.0.1", port)).expect("connect");
866 let client = make_client(stream, 1.0, "little-endian");
867 let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
868 let result =
869 write_builtin(client.clone(), Value::Tensor(tensor), Vec::new()).expect("write");
870 match result {
871 Value::Num(count) => assert_eq!(count, 4.0),
872 other => panic!("expected numeric result, got {other:?}"),
873 }
874 remove_client_for_test(client_id(&client));
875 let received = handle.join().expect("join");
876 assert_eq!(received, vec![1, 2, 3, 4]);
877 }
878
879 #[test]
880 fn write_double_big_endian_encodes_correctly() {
881 let listener = TcpListener::bind("127.0.0.1:0").expect("listener");
882 let port = listener.local_addr().unwrap().port();
883 let handle = thread::spawn(move || {
884 let (mut stream, _) = listener.accept().expect("accept");
885 let mut buf = [0u8; 24];
886 stream.read_exact(&mut buf).expect("read");
887 buf
888 });
889
890 let stream = TcpStream::connect(("127.0.0.1", port)).expect("connect");
891 let client = make_client(stream, 1.0, "big-endian");
892 let tensor = Tensor::new(vec![1.5, 2.5, 3.5], vec![1, 3]).unwrap();
893 let result = write_builtin(
894 client.clone(),
895 Value::Tensor(tensor),
896 vec![Value::from("double")],
897 )
898 .expect("write");
899 match result {
900 Value::Num(count) => assert_eq!(count, 3.0),
901 other => panic!("expected numeric count, got {other:?}"),
902 }
903 remove_client_for_test(client_id(&client));
904
905 let received = handle.join().expect("join");
906 let mut expected = Vec::new();
907 extend_f64(&mut expected, 1.5, ByteOrder::Big);
908 extend_f64(&mut expected, 2.5, ByteOrder::Big);
909 extend_f64(&mut expected, 3.5, ByteOrder::Big);
910 assert_eq!(received.to_vec(), expected);
911 }
912
913 #[test]
914 fn write_char_payload_encodes_ascii() {
915 let listener = TcpListener::bind("127.0.0.1:0").expect("listener");
916 let port = listener.local_addr().unwrap().port();
917 let handle = thread::spawn(move || {
918 let (mut stream, _) = listener.accept().expect("accept");
919 let mut buf = Vec::new();
920 stream.read_to_end(&mut buf).unwrap_or_default();
921 buf
922 });
923
924 let stream = TcpStream::connect(("127.0.0.1", port)).expect("connect");
925 let client = make_client(stream, 1.0, "little-endian");
926 let chars = CharArray::new("RunMat".chars().collect(), 1, 6).unwrap();
927 let result = write_builtin(
928 client.clone(),
929 Value::CharArray(chars),
930 vec![Value::from("char")],
931 )
932 .expect("write");
933 match result {
934 Value::Num(count) => assert_eq!(count, 6.0),
935 other => panic!("expected numeric count, got {other:?}"),
936 }
937 remove_client_for_test(client_id(&client));
938 let received = handle.join().expect("join");
939 assert_eq!(received, b"RunMat");
940 }
941
942 #[test]
943 fn write_errors_when_client_disconnected() {
944 let listener = TcpListener::bind("127.0.0.1:0").expect("listener");
945 let port = listener.local_addr().unwrap().port();
946 let barrier = Arc::new(Barrier::new(2));
947 let thread_barrier = barrier.clone();
948 let handle = thread::spawn(move || {
949 let (stream, _) = listener.accept().expect("accept");
950 thread_barrier.wait();
951 drop(stream);
952 });
953
954 let stream = TcpStream::connect(("127.0.0.1", port)).expect("connect");
955 let client = make_client(stream, 1.0, "little-endian");
956 let id = client_id(&client);
957 if let Some(handle_ref) = client_handle(id) {
958 if let Ok(mut guard) = handle_ref.lock() {
959 guard.connected = false;
960 }
961 }
962
963 let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]).unwrap();
964 let err =
965 write_builtin(client.clone(), Value::Tensor(tensor), Vec::new()).expect_err("write");
966 assert!(
967 err.starts_with(MESSAGE_ID_NOT_CONNECTED),
968 "unexpected error {err}"
969 );
970
971 remove_client_for_test(id);
972 barrier.wait();
973 handle.join().expect("join");
974 }
975
976 #[test]
977 #[cfg(feature = "doc_export")]
978 fn doc_examples_compile() {
979 let blocks = test_support::doc_examples(DOC_MD);
980 assert!(!blocks.is_empty());
981 }
982}