Skip to main content

runmat_runtime/builtins/io/net/
write.rs

1//! MATLAB-compatible `write` builtin for TCP/IP clients in RunMat.
2
3use runmat_builtins::{IntValue, StructValue, Value};
4use runmat_macros::runtime_builtin;
5use std::io::{self, Write};
6use std::net::TcpStream;
7
8use crate::builtins::common::spec::{
9    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
10    ReductionNaN, ResidencyPolicy, ShapeRequirements,
11};
12use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
13
14use super::accept::{client_handle, configure_stream, CLIENT_HANDLE_FIELD};
15
// Error identifiers attached to every `RuntimeError` raised by `write`; all share the
// `RunMat:write:<Reason>` shape.
// First argument is unusable: gather failed, not a struct, missing/invalid handle field, or
// the handle no longer resolves to a registered client.
const MESSAGE_ID_INVALID_CLIENT: &str = "RunMat:write:InvalidTcpClient";
// Payload cannot be encoded: gather failed, complex input, non-scalar string arrays, or
// unsupported value kinds for the chosen datatype.
const MESSAGE_ID_INVALID_DATA: &str = "RunMat:write:InvalidData";
// Bad optional datatype argument: too many arguments, non-text, empty, or unknown name.
const MESSAGE_ID_INVALID_DATATYPE: &str = "RunMat:write:InvalidDataType";
// The client's `connected` flag was already false before any bytes were sent.
const MESSAGE_ID_NOT_CONNECTED: &str = "RunMat:write:NotConnected";
// The socket write reported `TimedOut` or `WouldBlock`.
const MESSAGE_ID_TIMEOUT: &str = "RunMat:write:Timeout";
// The peer closed/reset the connection (or a write returned 0 bytes) mid-send.
const MESSAGE_ID_CONNECTION_CLOSED: &str = "RunMat:write:ConnectionClosed";
// Local failures: cloning the stream, configuring its timeout, or other socket errors.
const MESSAGE_ID_INTERNAL: &str = "RunMat:write:InternalError";
23
// GPU planner metadata for `write`. Socket output has no device implementation: no
// precisions or provider hooks are declared, and the residency policy is
// `GatherImmediately` (presumably directing the planner to keep inputs on the host —
// confirm against the planner). The builtin itself also gathers every argument explicitly.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::io::net::write")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "write",
    op_kind: GpuOpKind::Custom("network"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Socket writes always execute on the host CPU; GPU providers are never consulted.",
};
39
40fn write_error(message_id: &'static str, message: impl Into<String>) -> RuntimeError {
41    build_runtime_error(message)
42        .with_identifier(message_id)
43        .with_builtin("write")
44        .build()
45}
46
47fn write_flow(message_id: &'static str, message: impl Into<String>) -> RuntimeError {
48    write_error(message_id, message)
49}
50
51fn map_write_flow(err: RuntimeError, message_id: &'static str, context: &str) -> RuntimeError {
52    build_runtime_error(format!("{context}: {}", err.message()))
53        .with_identifier(message_id)
54        .with_builtin("write")
55        .with_source(err)
56        .build()
57}
58
// Fusion metadata for `write`: it declares no elementwise or reduction template, so the
// builtin is not a fusion candidate and runs eagerly on the CPU (per `notes`).
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::io::net::write")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "write",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Networking builtin executed eagerly on the CPU.",
};
69
70#[runtime_builtin(
71    name = "write",
72    category = "io/net",
73    summary = "Write numeric or text data to a TCP/IP client.",
74    keywords = "write,tcpclient,networking",
75    type_resolver(crate::builtins::io::type_resolvers::write_type),
76    builtin_path = "crate::builtins::io::net::write"
77)]
78async fn write_builtin(
79    client: Value,
80    data: Value,
81    rest: Vec<Value>,
82) -> crate::BuiltinResult<Value> {
83    let client = gather_if_needed_async(&client)
84        .await
85        .map_err(|flow| map_write_flow(flow, MESSAGE_ID_INVALID_CLIENT, "write"))?;
86    let data = gather_if_needed_async(&data)
87        .await
88        .map_err(|flow| map_write_flow(flow, MESSAGE_ID_INVALID_DATA, "write"))?;
89
90    let mut gathered_rest = Vec::with_capacity(rest.len());
91    for value in rest {
92        gathered_rest.push(
93            gather_if_needed_async(&value)
94                .await
95                .map_err(|flow| map_write_flow(flow, MESSAGE_ID_INVALID_DATATYPE, "write"))?,
96        );
97    }
98    let datatype = parse_arguments(&gathered_rest)?;
99
100    let client_struct = match &client {
101        Value::Struct(st) => st,
102        _ => {
103            return Err(write_flow(
104                MESSAGE_ID_INVALID_CLIENT,
105                "write: expected tcpclient struct as first argument",
106            ))
107        }
108    };
109
110    let client_id = extract_client_id(client_struct)?;
111    let handle = client_handle(client_id).ok_or_else(|| {
112        write_flow(
113            MESSAGE_ID_INVALID_CLIENT,
114            "write: tcpclient handle is no longer valid",
115        )
116    })?;
117
118    let (mut stream, timeout, byte_order) = {
119        let guard = handle.lock().unwrap_or_else(|poison| poison.into_inner());
120        if !guard.connected {
121            return Err(write_flow(
122                MESSAGE_ID_NOT_CONNECTED,
123                "write: tcpclient is disconnected",
124            ));
125        }
126        let timeout = guard.timeout;
127        let byte_order = parse_byte_order(&guard.byte_order);
128        let stream = guard.stream.try_clone().map_err(|err| {
129            write_flow(MESSAGE_ID_INTERNAL, format!("write: clone failed ({err})"))
130        })?;
131        (stream, timeout, byte_order)
132    };
133
134    if let Err(err) = configure_stream(&stream, timeout) {
135        return Err(write_flow(
136            MESSAGE_ID_INTERNAL,
137            format!("write: unable to configure socket timeout ({err})"),
138        ));
139    }
140
141    let payload = prepare_payload(&data, datatype, byte_order)?;
142    if payload.bytes.is_empty() {
143        return Ok(Value::Num(0.0));
144    }
145
146    match write_bytes(&mut stream, &payload.bytes) {
147        Ok(_) => Ok(Value::Num(payload.elements as f64)),
148        Err(WriteError::Timeout) => Err(write_flow(
149            MESSAGE_ID_TIMEOUT,
150            "write: timed out while sending data",
151        )),
152        Err(WriteError::ConnectionClosed) => {
153            if let Ok(mut guard) = handle.lock() {
154                guard.connected = false;
155            }
156            Err(write_flow(
157                MESSAGE_ID_CONNECTION_CLOSED,
158                "write: connection closed before all data was sent",
159            ))
160        }
161        Err(WriteError::Io(err)) => Err(write_flow(
162            MESSAGE_ID_INTERNAL,
163            format!("write: socket error ({err})"),
164        )),
165    }
166}
167
/// Wire encodings accepted as the optional `datatype` argument of `write`.
///
/// `UInt8` is the default used when no datatype argument is supplied.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
enum DataType {
    #[default]
    UInt8,
    Int8,
    UInt16,
    Int16,
    UInt32,
    Int32,
    UInt64,
    Int64,
    Single,
    Double,
    Char,
    String,
}

impl DataType {
    /// Number of bytes one element occupies in the encoded payload.
    ///
    /// `Char` and `String` count as one byte per element.
    fn element_size(self) -> usize {
        match self {
            DataType::UInt8 | DataType::Int8 | DataType::Char | DataType::String => 1,
            DataType::UInt16 | DataType::Int16 => 2,
            DataType::UInt32 | DataType::Int32 | DataType::Single => 4,
            DataType::UInt64 | DataType::Int64 | DataType::Double => 8,
        }
    }
}
198
/// Byte order used when encoding multi-byte numeric elements.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ByteOrder {
    Little,
    Big,
}
204
/// Encoded bytes ready to send, plus the element count reported back to the caller.
#[derive(Debug)]
struct Payload {
    /// Raw bytes to write to the socket.
    bytes: Vec<u8>,
    /// Number of logical elements encoded (returned as the builtin's result).
    elements: usize,
}
209
210fn parse_arguments(args: &[Value]) -> BuiltinResult<DataType> {
211    match args.len() {
212        0 => Ok(DataType::default()),
213        1 => parse_datatype(&args[0]),
214        _ => Err(write_flow(
215            MESSAGE_ID_INVALID_DATATYPE,
216            "write: expected at most one datatype argument",
217        )),
218    }
219}
220
221fn parse_datatype(value: &Value) -> BuiltinResult<DataType> {
222    let text = scalar_string(value)?;
223    let lowered = text.trim().to_ascii_lowercase();
224    if lowered.is_empty() {
225        return Err(write_flow(
226            MESSAGE_ID_INVALID_DATATYPE,
227            "write: datatype must not be empty",
228        ));
229    }
230    let dtype = match lowered.as_str() {
231        "uint8" => DataType::UInt8,
232        "int8" => DataType::Int8,
233        "uint16" => DataType::UInt16,
234        "int16" => DataType::Int16,
235        "uint32" => DataType::UInt32,
236        "int32" => DataType::Int32,
237        "uint64" => DataType::UInt64,
238        "int64" => DataType::Int64,
239        "single" => DataType::Single,
240        "double" => DataType::Double,
241        "char" => DataType::Char,
242        "string" => DataType::String,
243        _ => {
244            return Err(write_flow(
245                MESSAGE_ID_INVALID_DATATYPE,
246                format!("write: unsupported datatype '{text}'"),
247            ))
248        }
249    };
250    Ok(dtype)
251}
252
253fn prepare_payload(data: &Value, datatype: DataType, order: ByteOrder) -> BuiltinResult<Payload> {
254    match datatype {
255        DataType::Char => char_payload(data),
256        DataType::String => string_payload(data),
257        _ => numeric_payload(data, datatype, order),
258    }
259}
260
261fn numeric_payload(data: &Value, datatype: DataType, order: ByteOrder) -> BuiltinResult<Payload> {
262    let values = flatten_numeric(data)?;
263    let mut bytes = Vec::with_capacity(values.len() * datatype.element_size());
264    for value in values.iter().copied() {
265        match datatype {
266            DataType::UInt8 => bytes.push(cast_to_u8(value)),
267            DataType::Int8 => bytes.push(cast_to_i8(value) as u8),
268            DataType::UInt16 => extend_u16(&mut bytes, cast_to_u16(value), order),
269            DataType::Int16 => extend_i16(&mut bytes, cast_to_i16(value), order),
270            DataType::UInt32 => extend_u32(&mut bytes, cast_to_u32(value), order),
271            DataType::Int32 => extend_i32(&mut bytes, cast_to_i32(value), order),
272            DataType::UInt64 => extend_u64(&mut bytes, cast_to_u64(value), order),
273            DataType::Int64 => extend_i64(&mut bytes, cast_to_i64(value), order),
274            DataType::Single => extend_f32(&mut bytes, cast_to_f32(value), order),
275            DataType::Double => extend_f64(&mut bytes, value, order),
276            DataType::Char | DataType::String => unreachable!(),
277        }
278    }
279    Ok(Payload {
280        bytes,
281        elements: values.len(),
282    })
283}
284
285fn char_payload(data: &Value) -> BuiltinResult<Payload> {
286    let bytes = match data {
287        Value::CharArray(ca) => ca.data.iter().map(|&ch| (ch as u32 & 0xFF) as u8).collect(),
288        Value::String(text) => text.bytes().collect(),
289        Value::StringArray(sa) => {
290            if sa.data.len() != 1 {
291                return Err(write_flow(
292                    MESSAGE_ID_INVALID_DATA,
293                    "write: string array input must be scalar when using 'char'",
294                ));
295            }
296            sa.data[0].as_bytes().to_vec()
297        }
298        Value::Tensor(t) => t.data.iter().map(|&v| cast_to_u8(v)).collect::<Vec<u8>>(),
299        Value::Num(n) => vec![cast_to_u8(*n)],
300        Value::Int(iv) => vec![cast_to_u8(iv.to_f64())],
301        Value::Bool(b) => vec![if *b { 1 } else { 0 }],
302        Value::LogicalArray(la) => la
303            .data
304            .iter()
305            .map(|&b| if b != 0 { 1 } else { 0 })
306            .collect(),
307        _ => {
308            return Err(write_flow(
309                MESSAGE_ID_INVALID_DATA,
310                "write: unsupported input for 'char' datatype",
311            ))
312        }
313    };
314    Ok(Payload {
315        elements: bytes.len(),
316        bytes,
317    })
318}
319
320fn string_payload(data: &Value) -> BuiltinResult<Payload> {
321    match data {
322        Value::String(text) => Ok(Payload {
323            elements: 1,
324            bytes: text.as_bytes().to_vec(),
325        }),
326        Value::CharArray(ca) => {
327            let string: String = ca.data.iter().collect();
328            Ok(Payload {
329                elements: 1,
330                bytes: string.into_bytes(),
331            })
332        }
333        Value::StringArray(sa) => {
334            if sa.data.is_empty() {
335                return Ok(Payload {
336                    elements: 0,
337                    bytes: Vec::new(),
338                });
339            }
340            if sa.data.len() != 1 {
341                return Err(write_flow(
342                    MESSAGE_ID_INVALID_DATA,
343                    "write: string array input must be scalar when using 'string'",
344                ));
345            }
346            Ok(Payload {
347                elements: 1,
348                bytes: sa.data[0].as_bytes().to_vec(),
349            })
350        }
351        _ => Err(write_flow(
352            MESSAGE_ID_INVALID_DATA,
353            "write: expected text input when using 'string' datatype",
354        )),
355    }
356}
357
358fn flatten_numeric(value: &Value) -> BuiltinResult<Vec<f64>> {
359    match value {
360        Value::Tensor(t) => Ok(t.data.clone()),
361        Value::Num(n) => Ok(vec![*n]),
362        Value::Int(iv) => Ok(vec![iv.to_f64()]),
363        Value::Bool(b) => Ok(vec![if *b { 1.0 } else { 0.0 }]),
364        Value::LogicalArray(la) => Ok(la
365            .data
366            .iter()
367            .map(|&b| if b != 0 { 1.0 } else { 0.0 })
368            .collect()),
369        Value::CharArray(ca) => Ok(ca
370            .data
371            .iter()
372            .map(|&ch| (ch as u32 & 0xFF) as f64)
373            .collect()),
374        Value::String(text) => Ok(text.chars().map(|ch| (ch as u32) as f64).collect()),
375        Value::StringArray(sa) => {
376            if sa.data.len() != 1 {
377                return Err(write_flow(
378                    MESSAGE_ID_INVALID_DATA,
379                    "write: string array input must be scalar",
380                ));
381            }
382            Ok(sa.data[0].chars().map(|ch| (ch as u32) as f64).collect())
383        }
384        Value::Complex(_, _) | Value::ComplexTensor(_) => Err(write_flow(
385            MESSAGE_ID_INVALID_DATA,
386            "write: complex data is not supported",
387        )),
388        Value::Cell(_)
389        | Value::Struct(_)
390        | Value::Object(_)
391        | Value::HandleObject(_)
392        | Value::Listener(_)
393        | Value::FunctionHandle(_)
394        | Value::Closure(_)
395        | Value::ClassRef(_)
396        | Value::MException(_)
397        | Value::OutputList(_) => Err(write_flow(
398            MESSAGE_ID_INVALID_DATA,
399            "write: unsupported input type",
400        )),
401        Value::GpuTensor(_) => Err(write_flow(
402            MESSAGE_ID_INVALID_DATA,
403            "write: GPU tensor should have been gathered before encoding",
404        )),
405    }
406}
407
// Numeric conversion helpers. Each integer cast rounds half away from zero, maps NaN to 0,
// and saturates at the target type's bounds (±inf map to MIN/MAX). Rust's float-to-integer
// `as` cast already saturates out-of-range and infinite values, so once `rounded_scalar`
// has replaced NaN with 0.0 no manual clamping is required.

/// Rounds and saturates `value` into `u8`.
fn cast_to_u8(value: f64) -> u8 {
    rounded_scalar(value) as u8
}

/// Rounds and saturates `value` into `i8`.
fn cast_to_i8(value: f64) -> i8 {
    rounded_scalar(value) as i8
}

/// Rounds and saturates `value` into `u16`.
fn cast_to_u16(value: f64) -> u16 {
    rounded_scalar(value) as u16
}

/// Rounds and saturates `value` into `i16`.
fn cast_to_i16(value: f64) -> i16 {
    rounded_scalar(value) as i16
}

/// Rounds and saturates `value` into `u32`.
fn cast_to_u32(value: f64) -> u32 {
    rounded_scalar(value) as u32
}

/// Rounds and saturates `value` into `i32`.
fn cast_to_i32(value: f64) -> i32 {
    rounded_scalar(value) as i32
}

/// Rounds and saturates `value` into `u64`.
fn cast_to_u64(value: f64) -> u64 {
    rounded_scalar(value) as u64
}

/// Rounds and saturates `value` into `i64`.
fn cast_to_i64(value: f64) -> i64 {
    rounded_scalar(value) as i64
}

/// Narrows to `f32` (nearest representable value; overflow becomes ±inf, NaN stays NaN).
fn cast_to_f32(value: f64) -> f32 {
    value as f32
}

/// Rounds half away from zero, replacing NaN with `0.0` so it encodes as zero.
fn rounded_scalar(value: f64) -> f64 {
    match value {
        nan if nan.is_nan() => 0.0,
        other => other.round(),
    }
}
563
564fn extend_u16(buffer: &mut Vec<u8>, value: u16, order: ByteOrder) {
565    match order {
566        ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
567        ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
568    }
569}
570
571fn extend_i16(buffer: &mut Vec<u8>, value: i16, order: ByteOrder) {
572    match order {
573        ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
574        ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
575    }
576}
577
578fn extend_u32(buffer: &mut Vec<u8>, value: u32, order: ByteOrder) {
579    match order {
580        ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
581        ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
582    }
583}
584
585fn extend_i32(buffer: &mut Vec<u8>, value: i32, order: ByteOrder) {
586    match order {
587        ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
588        ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
589    }
590}
591
592fn extend_u64(buffer: &mut Vec<u8>, value: u64, order: ByteOrder) {
593    match order {
594        ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
595        ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
596    }
597}
598
599fn extend_i64(buffer: &mut Vec<u8>, value: i64, order: ByteOrder) {
600    match order {
601        ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
602        ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
603    }
604}
605
606fn extend_f32(buffer: &mut Vec<u8>, value: f32, order: ByteOrder) {
607    match order {
608        ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
609        ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
610    }
611}
612
613fn extend_f64(buffer: &mut Vec<u8>, value: f64, order: ByteOrder) {
614    match order {
615        ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
616        ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
617    }
618}
619
620fn parse_byte_order(text: &str) -> ByteOrder {
621    if text.eq_ignore_ascii_case("big-endian") || text.eq_ignore_ascii_case("big endian") {
622        ByteOrder::Big
623    } else {
624        ByteOrder::Little
625    }
626}
627
628fn scalar_string(value: &Value) -> BuiltinResult<String> {
629    match value {
630        Value::String(s) => Ok(s.clone()),
631        Value::CharArray(ca) if ca.rows == 1 => Ok(ca.data.iter().collect()),
632        Value::StringArray(sa) if sa.data.len() == 1 => Ok(sa.data[0].clone()),
633        _ => Err(write_flow(
634            MESSAGE_ID_INVALID_DATATYPE,
635            "write: datatype argument must be a string scalar or character row vector",
636        )),
637    }
638}
639
640fn extract_client_id(struct_value: &StructValue) -> BuiltinResult<u64> {
641    let id_value = struct_value
642        .fields
643        .get(CLIENT_HANDLE_FIELD)
644        .ok_or_else(|| {
645            write_flow(
646                MESSAGE_ID_INVALID_CLIENT,
647                "write: tcpclient struct is missing internal handle",
648            )
649        })?;
650    match id_value {
651        Value::Int(IntValue::U64(id)) => Ok(*id),
652        Value::Int(iv) => Ok(iv.to_i64() as u64),
653        _ => Err(write_flow(
654            MESSAGE_ID_INVALID_CLIENT,
655            "write: tcpclient struct has invalid handle field",
656        )),
657    }
658}
659
/// Classified outcome of a failed socket write.
#[derive(Debug)]
enum WriteError {
    /// The write timed out (`TimedOut` or `WouldBlock`).
    Timeout,
    /// The peer closed or reset the connection, or a write accepted zero bytes.
    ConnectionClosed,
    /// Any other I/O failure, preserved for the error message.
    Io(io::Error),
}
665
666fn write_bytes(stream: &mut TcpStream, bytes: &[u8]) -> Result<(), WriteError> {
667    let mut offset = 0usize;
668    while offset < bytes.len() {
669        match stream.write(&bytes[offset..]) {
670            Ok(0) => return Err(WriteError::ConnectionClosed),
671            Ok(n) => offset += n,
672            Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
673            Err(err) if is_timeout(&err) => return Err(WriteError::Timeout),
674            Err(err) if is_connection_closed_error(&err) => {
675                return Err(WriteError::ConnectionClosed)
676            }
677            Err(err) => return Err(WriteError::Io(err)),
678        }
679    }
680    Ok(())
681}
682
/// True when `err` signals that a blocking socket operation ran out of time.
///
/// A socket read/write timeout may surface as either `TimedOut` or `WouldBlock`, depending
/// on the platform.
fn is_timeout(err: &io::Error) -> bool {
    let kind = err.kind();
    kind == io::ErrorKind::TimedOut || kind == io::ErrorKind::WouldBlock
}
689
/// True when `err` indicates the peer has gone away (closed, reset, or aborted the connection).
fn is_connection_closed_error(err: &io::Error) -> bool {
    const CLOSED_KINDS: [io::ErrorKind; 5] = [
        io::ErrorKind::BrokenPipe,
        io::ErrorKind::ConnectionReset,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::NotConnected,
        io::ErrorKind::UnexpectedEof,
    ];
    CLOSED_KINDS.contains(&err.kind())
}
700
// Loopback tests. Each one binds a listener on 127.0.0.1 with an OS-assigned port, connects
// a real `TcpStream`, registers it through `insert_client`, and drives `write_builtin`
// synchronously with `futures::executor::block_on`.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::io::net::accept::{
        configure_stream, insert_client, remove_client_for_test,
    };
    use runmat_builtins::{CharArray, IntValue, StructValue, Tensor};
    use std::io::Read;
    use std::net::{TcpListener, TcpStream};
    use std::sync::{Arc, Barrier};
    use std::thread;

    // Registers `stream` in the client registry and returns a tcpclient-like struct whose
    // only field is the handle id. The `0` passed to `insert_client` is defined in accept.rs
    // (NOTE(review): presumably a server/listener id — confirm there).
    fn make_client(stream: TcpStream, timeout: f64, byte_order: &str) -> Value {
        let peer_addr = stream.peer_addr().expect("peer addr");
        configure_stream(&stream, timeout).expect("configure stream");
        let client_id = insert_client(stream, 0, peer_addr, timeout, byte_order.to_string());
        let mut st = StructValue::new();
        st.fields.insert(
            CLIENT_HANDLE_FIELD.to_string(),
            Value::Int(IntValue::U64(client_id)),
        );
        Value::Struct(st)
    }

    // Reads the registry id back out of a struct built by `make_client`.
    fn client_id(client: &Value) -> u64 {
        match client {
            Value::Struct(st) => match st.fields.get(CLIENT_HANDLE_FIELD) {
                Some(Value::Int(IntValue::U64(id))) => *id,
                Some(Value::Int(iv)) => iv.to_i64() as u64,
                other => panic!("unexpected id field {other:?}"),
            },
            other => panic!("expected struct, got {other:?}"),
        }
    }

    fn assert_error_identifier(err: RuntimeError, expected: &str) {
        assert_eq!(err.identifier(), Some(expected));
    }

    // Runs the async builtin to completion on the current thread.
    fn run_write(client: Value, data: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        futures::executor::block_on(write_builtin(client, data, rest))
    }

    // Default datatype is uint8: four doubles become four raw bytes. The server reads until
    // EOF, which presumably arrives once `remove_client_for_test` drops the registered
    // stream (verify in accept.rs).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn write_default_uint8_sends_bytes() {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener");
        let port = listener.local_addr().unwrap().port();
        let handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept");
            let mut received = Vec::new();
            stream.read_to_end(&mut received).unwrap_or_default();
            received
        });

        let stream = TcpStream::connect(("127.0.0.1", port)).expect("connect");
        let client = make_client(stream, 1.0, "little-endian");
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
        let result = run_write(client.clone(), Value::Tensor(tensor), Vec::new()).expect("write");
        match result {
            Value::Num(count) => assert_eq!(count, 4.0),
            other => panic!("expected numeric result, got {other:?}"),
        }
        remove_client_for_test(client_id(&client));
        let received = handle.join().expect("join");
        assert_eq!(received, vec![1, 2, 3, 4]);
    }

    // 'double' with a big-endian client: 3 elements x 8 bytes = 24 bytes, compared against
    // the same encoding produced by `extend_f64`.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn write_double_big_endian_encodes_correctly() {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener");
        let port = listener.local_addr().unwrap().port();
        let handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept");
            let mut buf = [0u8; 24];
            stream.read_exact(&mut buf).expect("read");
            buf
        });

        let stream = TcpStream::connect(("127.0.0.1", port)).expect("connect");
        let client = make_client(stream, 1.0, "big-endian");
        let tensor = Tensor::new(vec![1.5, 2.5, 3.5], vec![1, 3]).unwrap();
        let result = run_write(
            client.clone(),
            Value::Tensor(tensor),
            vec![Value::from("double")],
        )
        .expect("write");
        match result {
            Value::Num(count) => assert_eq!(count, 3.0),
            other => panic!("expected numeric count, got {other:?}"),
        }
        remove_client_for_test(client_id(&client));

        let received = handle.join().expect("join");
        let mut expected = Vec::new();
        extend_f64(&mut expected, 1.5, ByteOrder::Big);
        extend_f64(&mut expected, 2.5, ByteOrder::Big);
        extend_f64(&mut expected, 3.5, ByteOrder::Big);
        assert_eq!(received.to_vec(), expected);
    }

    // 'char' sends one byte per character of an ASCII char row vector.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn write_char_payload_encodes_ascii() {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener");
        let port = listener.local_addr().unwrap().port();
        let handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept");
            let mut buf = Vec::new();
            stream.read_to_end(&mut buf).unwrap_or_default();
            buf
        });

        let stream = TcpStream::connect(("127.0.0.1", port)).expect("connect");
        let client = make_client(stream, 1.0, "little-endian");
        let chars = CharArray::new("RunMat".chars().collect(), 1, 6).unwrap();
        let result = run_write(
            client.clone(),
            Value::CharArray(chars),
            vec![Value::from("char")],
        )
        .expect("write");
        match result {
            Value::Num(count) => assert_eq!(count, 6.0),
            other => panic!("expected numeric count, got {other:?}"),
        }
        remove_client_for_test(client_id(&client));
        let received = handle.join().expect("join");
        assert_eq!(received, b"RunMat");
    }

    // Flips the registry's `connected` flag directly to exercise the NotConnected path.
    // The barrier keeps the server-side stream alive until after the write attempt.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn write_errors_when_client_disconnected() {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener");
        let port = listener.local_addr().unwrap().port();
        let barrier = Arc::new(Barrier::new(2));
        let thread_barrier = barrier.clone();
        let handle = thread::spawn(move || {
            let (stream, _) = listener.accept().expect("accept");
            thread_barrier.wait();
            drop(stream);
        });

        let stream = TcpStream::connect(("127.0.0.1", port)).expect("connect");
        let client = make_client(stream, 1.0, "little-endian");
        let id = client_id(&client);
        if let Some(handle_ref) = client_handle(id) {
            if let Ok(mut guard) = handle_ref.lock() {
                guard.connected = false;
            }
        }

        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]).unwrap();
        let err = run_write(client.clone(), Value::Tensor(tensor), Vec::new()).expect_err("write");
        assert_error_identifier(err, MESSAGE_ID_NOT_CONNECTED);

        remove_client_for_test(id);
        barrier.wait();
        handle.join().expect("join");
    }
}