runmat_runtime/builtins/io/net/
write.rs

1//! MATLAB-compatible `write` builtin for TCP/IP clients in RunMat.
2
3use runmat_builtins::{IntValue, StructValue, Value};
4use runmat_macros::runtime_builtin;
5use std::io::{self, Write};
6use std::net::TcpStream;
7
8use crate::builtins::common::spec::{
9    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
10    ReductionNaN, ResidencyPolicy, ShapeRequirements,
11};
12use crate::{gather_if_needed, register_builtin_fusion_spec, register_builtin_gpu_spec};
13
14use super::accept::{client_handle, configure_stream, CLIENT_HANDLE_FIELD};
15
16#[cfg(feature = "doc_export")]
17use crate::register_builtin_doc_text;
18
// MATLAB-style error identifiers; `runtime_error` prefixes every message with one of these.

/// The first argument is not a usable tcpclient struct/handle.
const MESSAGE_ID_INVALID_CLIENT: &str = "MATLAB:write:InvalidTcpClient";
/// The client is flagged as disconnected before any bytes are sent.
const MESSAGE_ID_NOT_CONNECTED: &str = "MATLAB:write:NotConnected";
/// The peer closed or reset the connection while sending.
const MESSAGE_ID_CONNECTION_CLOSED: &str = "MATLAB:write:ConnectionClosed";
/// The socket write did not complete within the client's timeout.
const MESSAGE_ID_TIMEOUT: &str = "MATLAB:write:Timeout";
/// The `data` argument cannot be encoded.
const MESSAGE_ID_INVALID_DATA: &str = "MATLAB:write:InvalidData";
/// The optional `datatype` argument is missing, malformed, or unsupported.
const MESSAGE_ID_INVALID_DATATYPE: &str = "MATLAB:write:InvalidDataType";
/// Any other failure (socket clone/configuration, unexpected I/O errors).
const MESSAGE_ID_INTERNAL: &str = "MATLAB:write:InternalError";
26
#[cfg(feature = "doc_export")]
pub const DOC_MD: &str = r#"---
title: "write"
category: "io/net"
keywords: ["write", "tcpclient", "networking", "socket", "binary data", "text"]
summary: "Write numeric or text data to a remote host through a MATLAB-compatible tcpclient struct."
references:
  - https://www.mathworks.com/help/matlab/ref/tcpclient.write.html
gpu_support:
  elementwise: false
  reduction: false
  precisions: []
  broadcasting: "none"
  notes: "All TCP writes execute on the host CPU. GPU-resident arguments are gathered automatically before socket I/O."
fusion:
  elementwise: false
  reduction: false
  max_inputs: 3
  constants: "inline"
requires_feature: null
tested:
  unit: "builtins::io::net::write::tests"
---

# What does the `write` function do in MATLAB / RunMat?
`write(t, data)` transmits binary or textual data over the TCP/IP client returned by `tcpclient` (or `accept`).
The builtin mirrors MATLAB’s `write` behaviour so existing socket code continues working without modification.
It honours the client’s configured `Timeout`, applies the `ByteOrder` property when encoding multi-byte values,
and accepts the optional `datatype` argument used throughout MATLAB’s I/O APIs.

## How does the `write` function behave in MATLAB / RunMat?
- `write(t, data)` converts `data` to unsigned 8-bit integers (the MATLAB default) and sends the bytes to the peer.
  The return value is the number of elements written when the caller requests an output argument.
- `write(t, data, datatype)` encodes the payload using the supplied MATLAB datatype token.
  Supported values mirror MATLAB: `"uint8"` (default), `"int8"`, `"uint16"`, `"int16"`, `"uint32"`, `"int32"`,
  `"uint64"`, `"int64"`, `"single"`, `"double"`, `"char"`, and `"string"`. Numeric conversions saturate to the
  destination range just like MATLAB cast operations. `"char"` treats values as single-byte character codes and
  `"string"` encodes UTF-8 text.
- The client’s `ByteOrder` property controls how multi-byte numeric values are serialised. `"little-endian"` is
  the default, while `"big-endian"` matches the traditional network byte order.
- When the socket cannot send the entire payload before the timeout expires, `write` raises `MATLAB:write:Timeout`.
  If the peer closes the connection before or during the transfer the builtin raises `MATLAB:write:ConnectionClosed`
  and marks the client as disconnected.
- Inputs that originate on the GPU are gathered back to the host automatically before any bytes are written.

## `write` Function GPU Execution Behaviour
Networking occurs on the host CPU. If `data` or the tcpclient struct resides on the GPU, RunMat gathers the values
to host memory before converting them to bytes. Acceleration providers are not involved and the resulting payload
remains on the CPU. Providers that support residency tracking automatically mark any gathered tensors as released.

## Examples of using the `write` function in MATLAB / RunMat

### Sending an array of bytes to an echo service
```matlab
client = tcpclient("127.0.0.1", 50000);
count = write(client, uint8(1:4));
```
Expected output when an output argument is requested:
```matlab
count =
     4
```

### Writing doubles with explicit byte order
```matlab
client = tcpclient("localhost", 50001, "ByteOrder", "big-endian");
values = [1.5 2.5 3.5];
write(client, values, "double");
```
The remote peer receives 24 bytes representing the doubles in big-endian order.

### Transmitting ASCII text
```matlab
client = tcpclient("127.0.0.1", 50002);
write(client, "RunMat TCP", "char");
```
Expected payload (one byte per character):
```
82 117 110 77 97 116 32 84 67 80
```

### Sending UTF-8 encoded strings
```matlab
client = tcpclient("127.0.0.1", 50003);
write(client, "Διακριτό", "string");
```
The builtin encodes the Unicode text as UTF-8 before sending it across the socket.

### Handling connection closures
```matlab
client = tcpclient("example.com", 12345, "Timeout", 0.25);
try
    write(client, uint8([1 2 3 4]));
catch err
    disp(err.identifier)
end
```
Expected output when the peer closes the connection abruptly:
```matlab
MATLAB:write:ConnectionClosed
```

## FAQ

### How many output values does `write` return?
When the caller requests an output, the builtin returns the number of elements written (after datatype conversion).
With `"char"` every byte counts as one element, while `"string"` counts strings rather than bytes (a scalar string
reports `1`). This mirrors the behaviour of MATLAB’s numeric I/O routines. If no output is requested, the value is
discarded.

### Does `write` support complex numbers?
No. The input must be real. Pass separate real and imaginary parts or convert to a byte representation manually.

### How are values rounded when converting to integer datatypes?
Floating-point inputs are rounded to the nearest integer (halves round away from zero) and then saturated to the
target range, matching MATLAB casts. For example, `255.7` rounds to `256` and saturates to `255` for `"uint8"`,
`-128.2` rounds to `-128` for `"int8"`, and `NaN` becomes `0`.

### What happens to GPU-resident tensors?
They are gathered automatically before the write. Networking is a CPU-only subsystem, so the resulting data is sent
from host memory and any temporary handles are released after the transfer.

### Can I stream large payloads?
Yes. `write` loops until the entire payload has been sent or an error occurs. Large payloads honour the client’s
timeout and byte-order settings.

## See also
[tcpclient](./tcpclient), [accept](./accept), [read](./read), [readline](./readline)

## Source & feedback
- Implementation: [`crates/runmat-runtime/src/builtins/io/net/write.rs`](https://github.com/runmat-org/runmat/blob/main/crates/runmat-runtime/src/builtins/io/net/write.rs)
- Please [open an issue](https://github.com/runmat-org/runmat/issues/new/choose) if you encounter behavioural
  differences from MATLAB.
"#;
158
159pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
160    name: "write",
161    op_kind: GpuOpKind::Custom("network"),
162    supported_precisions: &[],
163    broadcast: BroadcastSemantics::None,
164    provider_hooks: &[],
165    constant_strategy: ConstantStrategy::InlineLiteral,
166    residency: ResidencyPolicy::GatherImmediately,
167    nan_mode: ReductionNaN::Include,
168    two_pass_threshold: None,
169    workgroup_size: None,
170    accepts_nan_mode: false,
171    notes: "Socket writes always execute on the host CPU; GPU providers are never consulted.",
172};
173
174register_builtin_gpu_spec!(GPU_SPEC);
175
176pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
177    name: "write",
178    shape: ShapeRequirements::Any,
179    constant_strategy: ConstantStrategy::InlineLiteral,
180    elementwise: None,
181    reduction: None,
182    emits_nan: false,
183    notes: "Networking builtin executed eagerly on the CPU.",
184};
185
186register_builtin_fusion_spec!(FUSION_SPEC);
187
188#[cfg(feature = "doc_export")]
189register_builtin_doc_text!("write", DOC_MD);
190
191#[runtime_builtin(
192    name = "write",
193    category = "io/net",
194    summary = "Write numeric or text data to a TCP/IP client.",
195    keywords = "write,tcpclient,networking"
196)]
197fn write_builtin(client: Value, data: Value, rest: Vec<Value>) -> Result<Value, String> {
198    let client = gather_if_needed(&client)
199        .map_err(|err| runtime_error(MESSAGE_ID_INVALID_CLIENT, format!("write: {err}")))?;
200    let data = gather_if_needed(&data)
201        .map_err(|err| runtime_error(MESSAGE_ID_INVALID_DATA, format!("write: {err}")))?;
202
203    let mut gathered_rest = Vec::with_capacity(rest.len());
204    for value in rest {
205        gathered_rest.push(
206            gather_if_needed(&value).map_err(|err| {
207                runtime_error(MESSAGE_ID_INVALID_DATATYPE, format!("write: {err}"))
208            })?,
209        );
210    }
211    let datatype = parse_arguments(&gathered_rest)?;
212
213    let client_struct = match &client {
214        Value::Struct(st) => st,
215        _ => {
216            return Err(runtime_error(
217                MESSAGE_ID_INVALID_CLIENT,
218                "write: expected tcpclient struct as first argument",
219            ))
220        }
221    };
222
223    let client_id = extract_client_id(client_struct)?;
224    let handle = client_handle(client_id).ok_or_else(|| {
225        runtime_error(
226            MESSAGE_ID_INVALID_CLIENT,
227            "write: tcpclient handle is no longer valid",
228        )
229    })?;
230
231    let (mut stream, timeout, byte_order) = {
232        let guard = handle.lock().unwrap_or_else(|poison| poison.into_inner());
233        if !guard.connected {
234            return Err(runtime_error(
235                MESSAGE_ID_NOT_CONNECTED,
236                "write: tcpclient is disconnected",
237            ));
238        }
239        let timeout = guard.timeout;
240        let byte_order = parse_byte_order(&guard.byte_order);
241        let stream = guard.stream.try_clone().map_err(|err| {
242            runtime_error(MESSAGE_ID_INTERNAL, format!("write: clone failed ({err})"))
243        })?;
244        (stream, timeout, byte_order)
245    };
246
247    if let Err(err) = configure_stream(&stream, timeout) {
248        return Err(runtime_error(
249            MESSAGE_ID_INTERNAL,
250            format!("write: unable to configure socket timeout ({err})"),
251        ));
252    }
253
254    let payload = prepare_payload(&data, datatype, byte_order)?;
255    if payload.bytes.is_empty() {
256        return Ok(Value::Num(0.0));
257    }
258
259    match write_bytes(&mut stream, &payload.bytes) {
260        Ok(_) => Ok(Value::Num(payload.elements as f64)),
261        Err(WriteError::Timeout) => Err(runtime_error(
262            MESSAGE_ID_TIMEOUT,
263            "write: timed out while sending data",
264        )),
265        Err(WriteError::ConnectionClosed) => {
266            if let Ok(mut guard) = handle.lock() {
267                guard.connected = false;
268            }
269            Err(runtime_error(
270                MESSAGE_ID_CONNECTION_CLOSED,
271                "write: connection closed before all data was sent",
272            ))
273        }
274        Err(WriteError::Io(err)) => Err(runtime_error(
275            MESSAGE_ID_INTERNAL,
276            format!("write: socket error ({err})"),
277        )),
278    }
279}
280
/// Wire encodings accepted by the optional `datatype` argument of `write`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DataType {
    /// `"uint8"`: one unsigned byte per element (MATLAB's default).
    UInt8,
    /// `"int8"`: one signed byte per element.
    Int8,
    /// `"uint16"`: two bytes per element, ordered by the client's `ByteOrder`.
    UInt16,
    /// `"int16"`: two bytes per element, ordered by the client's `ByteOrder`.
    Int16,
    /// `"uint32"`: four bytes per element, ordered by the client's `ByteOrder`.
    UInt32,
    /// `"int32"`: four bytes per element, ordered by the client's `ByteOrder`.
    Int32,
    /// `"uint64"`: eight bytes per element, ordered by the client's `ByteOrder`.
    UInt64,
    /// `"int64"`: eight bytes per element, ordered by the client's `ByteOrder`.
    Int64,
    /// `"single"`: IEEE-754 binary32, ordered by the client's `ByteOrder`.
    Single,
    /// `"double"`: IEEE-754 binary64, ordered by the client's `ByteOrder`.
    Double,
    /// `"char"`: one byte per character/element.
    Char,
    /// `"string"`: UTF-8 encoded text.
    String,
}

impl Default for DataType {
    /// `uint8`, the datatype used when `write` is called without a datatype argument.
    fn default() -> Self {
        DataType::UInt8
    }
}

impl DataType {
    /// Number of bytes one encoded element occupies.
    ///
    /// `Char` and `String` are byte-oriented and report 1; for `String` this is a per-byte
    /// figure (used for buffer sizing), not a per-string one.
    fn element_size(self) -> usize {
        match self {
            DataType::UInt8 | DataType::Int8 | DataType::Char | DataType::String => 1,
            DataType::UInt16 | DataType::Int16 => 2,
            DataType::UInt32 | DataType::Int32 | DataType::Single => 4,
            DataType::UInt64 | DataType::Int64 | DataType::Double => 8,
        }
    }
}
311
/// Byte order used when serialising multi-byte numeric values.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ByteOrder {
    /// Least-significant byte first (the tcpclient default).
    Little,
    /// Most-significant byte first (network byte order).
    Big,
}
317
/// Encoded bytes ready to be sent, plus the element count reported back to the caller.
struct Payload {
    /// Number of logical elements encoded (the value `write` returns).
    elements: usize,
    /// Raw bytes to put on the wire.
    bytes: Vec<u8>,
}
322
323fn parse_arguments(args: &[Value]) -> Result<DataType, String> {
324    match args.len() {
325        0 => Ok(DataType::default()),
326        1 => parse_datatype(&args[0]),
327        _ => Err(runtime_error(
328            MESSAGE_ID_INVALID_DATATYPE,
329            "write: expected at most one datatype argument",
330        )),
331    }
332}
333
334fn parse_datatype(value: &Value) -> Result<DataType, String> {
335    let text =
336        scalar_string(value).map_err(|err| runtime_error(MESSAGE_ID_INVALID_DATATYPE, err))?;
337    let lowered = text.trim().to_ascii_lowercase();
338    if lowered.is_empty() {
339        return Err(runtime_error(
340            MESSAGE_ID_INVALID_DATATYPE,
341            "write: datatype must not be empty",
342        ));
343    }
344    let dtype = match lowered.as_str() {
345        "uint8" => DataType::UInt8,
346        "int8" => DataType::Int8,
347        "uint16" => DataType::UInt16,
348        "int16" => DataType::Int16,
349        "uint32" => DataType::UInt32,
350        "int32" => DataType::Int32,
351        "uint64" => DataType::UInt64,
352        "int64" => DataType::Int64,
353        "single" => DataType::Single,
354        "double" => DataType::Double,
355        "char" => DataType::Char,
356        "string" => DataType::String,
357        _ => {
358            return Err(runtime_error(
359                MESSAGE_ID_INVALID_DATATYPE,
360                format!("write: unsupported datatype '{text}'"),
361            ))
362        }
363    };
364    Ok(dtype)
365}
366
367fn prepare_payload(data: &Value, datatype: DataType, order: ByteOrder) -> Result<Payload, String> {
368    match datatype {
369        DataType::Char => char_payload(data),
370        DataType::String => string_payload(data),
371        _ => numeric_payload(data, datatype, order),
372    }
373}
374
375fn numeric_payload(data: &Value, datatype: DataType, order: ByteOrder) -> Result<Payload, String> {
376    let values = flatten_numeric(data)?;
377    let mut bytes = Vec::with_capacity(values.len() * datatype.element_size());
378    for value in values.iter().copied() {
379        match datatype {
380            DataType::UInt8 => bytes.push(cast_to_u8(value)),
381            DataType::Int8 => bytes.push(cast_to_i8(value) as u8),
382            DataType::UInt16 => extend_u16(&mut bytes, cast_to_u16(value), order),
383            DataType::Int16 => extend_i16(&mut bytes, cast_to_i16(value), order),
384            DataType::UInt32 => extend_u32(&mut bytes, cast_to_u32(value), order),
385            DataType::Int32 => extend_i32(&mut bytes, cast_to_i32(value), order),
386            DataType::UInt64 => extend_u64(&mut bytes, cast_to_u64(value), order),
387            DataType::Int64 => extend_i64(&mut bytes, cast_to_i64(value), order),
388            DataType::Single => extend_f32(&mut bytes, cast_to_f32(value), order),
389            DataType::Double => extend_f64(&mut bytes, value, order),
390            DataType::Char | DataType::String => unreachable!(),
391        }
392    }
393    Ok(Payload {
394        bytes,
395        elements: values.len(),
396    })
397}
398
399fn char_payload(data: &Value) -> Result<Payload, String> {
400    let bytes = match data {
401        Value::CharArray(ca) => ca.data.iter().map(|&ch| (ch as u32 & 0xFF) as u8).collect(),
402        Value::String(text) => text.bytes().collect(),
403        Value::StringArray(sa) => {
404            if sa.data.len() != 1 {
405                return Err(runtime_error(
406                    MESSAGE_ID_INVALID_DATA,
407                    "write: string array input must be scalar when using 'char'",
408                ));
409            }
410            sa.data[0].as_bytes().to_vec()
411        }
412        Value::Tensor(t) => t.data.iter().map(|&v| cast_to_u8(v)).collect::<Vec<u8>>(),
413        Value::Num(n) => vec![cast_to_u8(*n)],
414        Value::Int(iv) => vec![cast_to_u8(iv.to_f64())],
415        Value::Bool(b) => vec![if *b { 1 } else { 0 }],
416        Value::LogicalArray(la) => la
417            .data
418            .iter()
419            .map(|&b| if b != 0 { 1 } else { 0 })
420            .collect(),
421        _ => {
422            return Err(runtime_error(
423                MESSAGE_ID_INVALID_DATA,
424                "write: unsupported input for 'char' datatype",
425            ))
426        }
427    };
428    Ok(Payload {
429        elements: bytes.len(),
430        bytes,
431    })
432}
433
434fn string_payload(data: &Value) -> Result<Payload, String> {
435    match data {
436        Value::String(text) => Ok(Payload {
437            elements: 1,
438            bytes: text.as_bytes().to_vec(),
439        }),
440        Value::CharArray(ca) => {
441            let string: String = ca.data.iter().collect();
442            Ok(Payload {
443                elements: 1,
444                bytes: string.into_bytes(),
445            })
446        }
447        Value::StringArray(sa) => {
448            if sa.data.is_empty() {
449                return Ok(Payload {
450                    elements: 0,
451                    bytes: Vec::new(),
452                });
453            }
454            if sa.data.len() != 1 {
455                return Err(runtime_error(
456                    MESSAGE_ID_INVALID_DATA,
457                    "write: string array input must be scalar when using 'string'",
458                ));
459            }
460            Ok(Payload {
461                elements: 1,
462                bytes: sa.data[0].as_bytes().to_vec(),
463            })
464        }
465        _ => Err(runtime_error(
466            MESSAGE_ID_INVALID_DATA,
467            "write: expected text input when using 'string' datatype",
468        )),
469    }
470}
471
472fn flatten_numeric(value: &Value) -> Result<Vec<f64>, String> {
473    match value {
474        Value::Tensor(t) => Ok(t.data.clone()),
475        Value::Num(n) => Ok(vec![*n]),
476        Value::Int(iv) => Ok(vec![iv.to_f64()]),
477        Value::Bool(b) => Ok(vec![if *b { 1.0 } else { 0.0 }]),
478        Value::LogicalArray(la) => Ok(la
479            .data
480            .iter()
481            .map(|&b| if b != 0 { 1.0 } else { 0.0 })
482            .collect()),
483        Value::CharArray(ca) => Ok(ca
484            .data
485            .iter()
486            .map(|&ch| (ch as u32 & 0xFF) as f64)
487            .collect()),
488        Value::String(text) => Ok(text.chars().map(|ch| (ch as u32) as f64).collect()),
489        Value::StringArray(sa) => {
490            if sa.data.len() != 1 {
491                return Err(runtime_error(
492                    MESSAGE_ID_INVALID_DATA,
493                    "write: string array input must be scalar",
494                ));
495            }
496            Ok(sa.data[0].chars().map(|ch| (ch as u32) as f64).collect())
497        }
498        Value::Complex(_, _) | Value::ComplexTensor(_) => Err(runtime_error(
499            MESSAGE_ID_INVALID_DATA,
500            "write: complex data is not supported",
501        )),
502        Value::Cell(_)
503        | Value::Struct(_)
504        | Value::Object(_)
505        | Value::HandleObject(_)
506        | Value::Listener(_)
507        | Value::FunctionHandle(_)
508        | Value::Closure(_)
509        | Value::ClassRef(_)
510        | Value::MException(_) => Err(runtime_error(
511            MESSAGE_ID_INVALID_DATA,
512            "write: unsupported input type",
513        )),
514        Value::GpuTensor(_) => Err(runtime_error(
515            MESSAGE_ID_INVALID_DATA,
516            "write: GPU tensor should have been gathered before encoding",
517        )),
518    }
519}
520
// MATLAB-style numeric conversions used by the encoders.
//
// Every integer conversion first goes through `rounded_scalar` (NaN -> 0, otherwise
// round-half-away-from-zero). The result is then converted with `as`: Rust's float-to-integer
// `as` casts saturate at the target type's bounds (including for +/-infinity). That is exactly
// MATLAB's saturating cast, so no explicit clamping is needed.

/// Rounds and saturates `value` into `uint8` (NaN becomes 0).
fn cast_to_u8(value: f64) -> u8 {
    rounded_scalar(value) as u8
}

/// Rounds and saturates `value` into `int8` (NaN becomes 0).
fn cast_to_i8(value: f64) -> i8 {
    rounded_scalar(value) as i8
}

/// Rounds and saturates `value` into `uint16` (NaN becomes 0).
fn cast_to_u16(value: f64) -> u16 {
    rounded_scalar(value) as u16
}

/// Rounds and saturates `value` into `int16` (NaN becomes 0).
fn cast_to_i16(value: f64) -> i16 {
    rounded_scalar(value) as i16
}

/// Rounds and saturates `value` into `uint32` (NaN becomes 0).
fn cast_to_u32(value: f64) -> u32 {
    rounded_scalar(value) as u32
}

/// Rounds and saturates `value` into `int32` (NaN becomes 0).
fn cast_to_i32(value: f64) -> i32 {
    rounded_scalar(value) as i32
}

/// Rounds and saturates `value` into `uint64` (NaN becomes 0).
fn cast_to_u64(value: f64) -> u64 {
    rounded_scalar(value) as u64
}

/// Rounds and saturates `value` into `int64` (NaN becomes 0).
fn cast_to_i64(value: f64) -> i64 {
    rounded_scalar(value) as i64
}

/// Narrows `value` to `single`; out-of-range magnitudes become infinite.
fn cast_to_f32(value: f64) -> f32 {
    value as f32
}

/// Rounds to the nearest integer (halves away from zero), mapping NaN to 0.
fn rounded_scalar(value: f64) -> f64 {
    match value {
        nan if nan.is_nan() => 0.0,
        finite_or_infinite => finite_or_infinite.round(),
    }
}
676
677fn extend_u16(buffer: &mut Vec<u8>, value: u16, order: ByteOrder) {
678    match order {
679        ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
680        ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
681    }
682}
683
684fn extend_i16(buffer: &mut Vec<u8>, value: i16, order: ByteOrder) {
685    match order {
686        ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
687        ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
688    }
689}
690
691fn extend_u32(buffer: &mut Vec<u8>, value: u32, order: ByteOrder) {
692    match order {
693        ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
694        ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
695    }
696}
697
698fn extend_i32(buffer: &mut Vec<u8>, value: i32, order: ByteOrder) {
699    match order {
700        ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
701        ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
702    }
703}
704
705fn extend_u64(buffer: &mut Vec<u8>, value: u64, order: ByteOrder) {
706    match order {
707        ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
708        ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
709    }
710}
711
712fn extend_i64(buffer: &mut Vec<u8>, value: i64, order: ByteOrder) {
713    match order {
714        ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
715        ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
716    }
717}
718
719fn extend_f32(buffer: &mut Vec<u8>, value: f32, order: ByteOrder) {
720    match order {
721        ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
722        ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
723    }
724}
725
726fn extend_f64(buffer: &mut Vec<u8>, value: f64, order: ByteOrder) {
727    match order {
728        ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
729        ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
730    }
731}
732
733fn parse_byte_order(text: &str) -> ByteOrder {
734    if text.eq_ignore_ascii_case("big-endian") || text.eq_ignore_ascii_case("big endian") {
735        ByteOrder::Big
736    } else {
737        ByteOrder::Little
738    }
739}
740
741fn scalar_string(value: &Value) -> Result<String, String> {
742    match value {
743        Value::String(s) => Ok(s.clone()),
744        Value::CharArray(ca) if ca.rows == 1 => Ok(ca.data.iter().collect()),
745        Value::StringArray(sa) if sa.data.len() == 1 => Ok(sa.data[0].clone()),
746        _ => Err(
747            "write: datatype argument must be a string scalar or character row vector".to_string(),
748        ),
749    }
750}
751
752fn extract_client_id(struct_value: &StructValue) -> Result<u64, String> {
753    let id_value = struct_value
754        .fields
755        .get(CLIENT_HANDLE_FIELD)
756        .ok_or_else(|| {
757            runtime_error(
758                MESSAGE_ID_INVALID_CLIENT,
759                "write: tcpclient struct is missing internal handle",
760            )
761        })?;
762    match id_value {
763        Value::Int(IntValue::U64(id)) => Ok(*id),
764        Value::Int(iv) => Ok(iv.to_i64() as u64),
765        _ => Err(runtime_error(
766            MESSAGE_ID_INVALID_CLIENT,
767            "write: tcpclient struct has invalid handle field",
768        )),
769    }
770}
771
/// Builds the `"<message id>: <message>"` error string used by every failure path.
fn runtime_error(message_id: &'static str, message: impl Into<String>) -> String {
    let message: String = message.into();
    let mut formatted = String::with_capacity(message_id.len() + 2 + message.len());
    formatted.push_str(message_id);
    formatted.push_str(": ");
    formatted.push_str(&message);
    formatted
}
775
/// Classified outcome of a failed socket write (see `write_bytes`).
#[derive(Debug)]
enum WriteError {
    /// The write timed out (`TimedOut` or `WouldBlock`).
    Timeout,
    /// The peer closed or reset the connection, or a write accepted zero bytes.
    ConnectionClosed,
    /// Any other I/O failure, kept for the error message.
    Io(io::Error),
}
781
782fn write_bytes(stream: &mut TcpStream, bytes: &[u8]) -> Result<(), WriteError> {
783    let mut offset = 0usize;
784    while offset < bytes.len() {
785        match stream.write(&bytes[offset..]) {
786            Ok(0) => return Err(WriteError::ConnectionClosed),
787            Ok(n) => offset += n,
788            Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
789            Err(err) if is_timeout(&err) => return Err(WriteError::Timeout),
790            Err(err) if is_connection_closed_error(&err) => {
791                return Err(WriteError::ConnectionClosed)
792            }
793            Err(err) => return Err(WriteError::Io(err)),
794        }
795    }
796    Ok(())
797}
798
/// True when `err` means the socket write deadline passed.
///
/// Both `TimedOut` and `WouldBlock` are accepted because platforms differ in which kind they
/// report when a socket write timeout expires.
fn is_timeout(err: &io::Error) -> bool {
    let kind = err.kind();
    kind == io::ErrorKind::TimedOut || kind == io::ErrorKind::WouldBlock
}
805
/// True when `err` indicates the peer is gone (broken pipe, reset, aborted, not connected,
/// or unexpected EOF).
fn is_connection_closed_error(err: &io::Error) -> bool {
    const CLOSED_KINDS: [io::ErrorKind; 5] = [
        io::ErrorKind::BrokenPipe,
        io::ErrorKind::ConnectionReset,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::NotConnected,
        io::ErrorKind::UnexpectedEof,
    ];
    CLOSED_KINDS.contains(&err.kind())
}
816
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "doc_export")]
    use crate::builtins::common::test_support;
    use crate::builtins::io::net::accept::{
        configure_stream, insert_client, remove_client_for_test,
    };
    use runmat_builtins::{CharArray, IntValue, StructValue, Tensor};
    use std::io::Read;
    use std::net::{TcpListener, TcpStream};
    use std::sync::{Arc, Barrier};
    use std::thread;

    /// Registers `stream` as a tcpclient and returns the struct that refers to it by id.
    fn make_client(stream: TcpStream, timeout: f64, byte_order: &str) -> Value {
        let peer_addr = stream.peer_addr().expect("peer addr");
        configure_stream(&stream, timeout).expect("configure stream");
        let client_id = insert_client(stream, 0, peer_addr, timeout, byte_order.to_string());
        let mut st = StructValue::new();
        st.fields.insert(
            CLIENT_HANDLE_FIELD.to_string(),
            Value::Int(IntValue::U64(client_id)),
        );
        Value::Struct(st)
    }

    /// Reads back the registry id stored in a struct produced by `make_client`.
    fn client_id(client: &Value) -> u64 {
        match client {
            Value::Struct(st) => match st.fields.get(CLIENT_HANDLE_FIELD) {
                Some(Value::Int(IntValue::U64(id))) => *id,
                Some(Value::Int(iv)) => iv.to_i64() as u64,
                other => panic!("unexpected id field {other:?}"),
            },
            other => panic!("expected struct, got {other:?}"),
        }
    }

    #[test]
    fn write_default_uint8_sends_bytes() {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener");
        let port = listener.local_addr().unwrap().port();
        let handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept");
            let mut received = Vec::new();
            // Reads until EOF, which arrives once the client is removed from the registry.
            stream.read_to_end(&mut received).unwrap_or_default();
            received
        });

        let stream = TcpStream::connect(("127.0.0.1", port)).expect("connect");
        let client = make_client(stream, 1.0, "little-endian");
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
        let result =
            write_builtin(client.clone(), Value::Tensor(tensor), Vec::new()).expect("write");
        match result {
            Value::Num(count) => assert_eq!(count, 4.0),
            other => panic!("expected numeric result, got {other:?}"),
        }
        remove_client_for_test(client_id(&client));
        let received = handle.join().expect("join");
        assert_eq!(received, vec![1, 2, 3, 4]);
    }

    #[test]
    fn write_double_big_endian_encodes_correctly() {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener");
        let port = listener.local_addr().unwrap().port();
        let handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept");
            let mut buf = [0u8; 24];
            stream.read_exact(&mut buf).expect("read");
            buf
        });

        let stream = TcpStream::connect(("127.0.0.1", port)).expect("connect");
        let client = make_client(stream, 1.0, "big-endian");
        let tensor = Tensor::new(vec![1.5, 2.5, 3.5], vec![1, 3]).unwrap();
        let result = write_builtin(
            client.clone(),
            Value::Tensor(tensor),
            vec![Value::from("double")],
        )
        .expect("write");
        match result {
            Value::Num(count) => assert_eq!(count, 3.0),
            other => panic!("expected numeric count, got {other:?}"),
        }
        remove_client_for_test(client_id(&client));

        let received = handle.join().expect("join");
        let mut expected = Vec::new();
        extend_f64(&mut expected, 1.5, ByteOrder::Big);
        extend_f64(&mut expected, 2.5, ByteOrder::Big);
        extend_f64(&mut expected, 3.5, ByteOrder::Big);
        assert_eq!(received.to_vec(), expected);
    }

    #[test]
    fn write_char_payload_encodes_ascii() {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener");
        let port = listener.local_addr().unwrap().port();
        let handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept");
            let mut buf = Vec::new();
            stream.read_to_end(&mut buf).unwrap_or_default();
            buf
        });

        let stream = TcpStream::connect(("127.0.0.1", port)).expect("connect");
        let client = make_client(stream, 1.0, "little-endian");
        let chars = CharArray::new("RunMat".chars().collect(), 1, 6).unwrap();
        let result = write_builtin(
            client.clone(),
            Value::CharArray(chars),
            vec![Value::from("char")],
        )
        .expect("write");
        match result {
            Value::Num(count) => assert_eq!(count, 6.0),
            other => panic!("expected numeric count, got {other:?}"),
        }
        remove_client_for_test(client_id(&client));
        let received = handle.join().expect("join");
        assert_eq!(received, b"RunMat");
    }

    #[test]
    fn write_errors_when_client_disconnected() {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener");
        let port = listener.local_addr().unwrap().port();
        // Keeps the accepted server-side socket alive until the assertion has run.
        let barrier = Arc::new(Barrier::new(2));
        let thread_barrier = barrier.clone();
        let handle = thread::spawn(move || {
            let (stream, _) = listener.accept().expect("accept");
            thread_barrier.wait();
            drop(stream);
        });

        let stream = TcpStream::connect(("127.0.0.1", port)).expect("connect");
        let client = make_client(stream, 1.0, "little-endian");
        let id = client_id(&client);
        if let Some(handle_ref) = client_handle(id) {
            if let Ok(mut guard) = handle_ref.lock() {
                guard.connected = false;
            }
        }

        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]).unwrap();
        let err =
            write_builtin(client.clone(), Value::Tensor(tensor), Vec::new()).expect_err("write");
        assert!(
            err.starts_with(MESSAGE_ID_NOT_CONNECTED),
            "unexpected error {err}"
        );

        remove_client_for_test(id);
        barrier.wait();
        handle.join().expect("join");
    }

    #[test]
    fn write_rejects_non_struct_client() {
        let err = write_builtin(Value::Num(1.0), Value::Num(1.0), Vec::new()).expect_err("write");
        assert!(
            err.starts_with(MESSAGE_ID_INVALID_CLIENT),
            "unexpected error {err}"
        );
    }

    #[test]
    fn write_rejects_extra_datatype_arguments() {
        // Datatype arguments are validated before the client, so no socket is needed.
        let err = write_builtin(
            Value::Num(1.0),
            Value::Num(1.0),
            vec![Value::from("uint8"), Value::from("uint8")],
        )
        .expect_err("write");
        assert!(
            err.starts_with(MESSAGE_ID_INVALID_DATATYPE),
            "unexpected error {err}"
        );
    }

    #[test]
    fn parse_datatype_is_case_insensitive_and_rejects_unknown_tokens() {
        assert!(matches!(
            parse_datatype(&Value::from(" UInt16 ")),
            Ok(DataType::UInt16)
        ));
        match parse_datatype(&Value::from("float128")) {
            Err(err) => assert!(
                err.starts_with(MESSAGE_ID_INVALID_DATATYPE),
                "unexpected error {err}"
            ),
            Ok(_) => panic!("expected an unsupported datatype error"),
        }
    }

    #[test]
    fn numeric_payload_uint16_big_endian_saturates() {
        let tensor = Tensor::new(vec![1.0, 258.0, 70000.0], vec![1, 3]).unwrap();
        let payload = numeric_payload(&Value::Tensor(tensor), DataType::UInt16, ByteOrder::Big)
            .expect("payload");
        assert_eq!(payload.elements, 3);
        assert_eq!(payload.bytes, vec![0x00, 0x01, 0x01, 0x02, 0xFF, 0xFF]);
    }

    #[test]
    #[cfg(feature = "doc_export")]
    fn doc_examples_compile() {
        let blocks = test_support::doc_examples(DOC_MD);
        assert!(!blocks.is_empty());
    }
}