Skip to main content

runmat_runtime/builtins/io/net/
write.rs

1//! MATLAB-compatible `write` builtin for TCP/IP clients in RunMat.
2
3use runmat_builtins::{
4    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
6    IntValue, StructValue, Value,
7};
8use runmat_macros::runtime_builtin;
9use std::io::{self, Write};
10use std::net::TcpStream;
11
12use crate::builtins::common::spec::{
13    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
14    ReductionNaN, ResidencyPolicy, ShapeRequirements,
15};
16use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
17
18use super::accept::{client_handle, configure_stream, CLIENT_HANDLE_FIELD};
19
20const BUILTIN_NAME: &str = "write";
21
22const WRITE_OUTPUT_COUNT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
23    name: "count",
24    ty: BuiltinParamType::NumericScalar,
25    arity: BuiltinParamArity::Required,
26    default: None,
27    description: "Number of elements written to the socket.",
28}];
29const WRITE_INPUTS_CLIENT_DATA: [BuiltinParamDescriptor; 2] = [
30    BuiltinParamDescriptor {
31        name: "client",
32        ty: BuiltinParamType::Any,
33        arity: BuiltinParamArity::Required,
34        default: None,
35        description: "tcpclient handle struct.",
36    },
37    BuiltinParamDescriptor {
38        name: "data",
39        ty: BuiltinParamType::Any,
40        arity: BuiltinParamArity::Required,
41        default: None,
42        description: "Payload to send.",
43    },
44];
45const WRITE_INPUTS_CLIENT_DATA_DATATYPE: [BuiltinParamDescriptor; 3] = [
46    BuiltinParamDescriptor {
47        name: "client",
48        ty: BuiltinParamType::Any,
49        arity: BuiltinParamArity::Required,
50        default: None,
51        description: "tcpclient handle struct.",
52    },
53    BuiltinParamDescriptor {
54        name: "data",
55        ty: BuiltinParamType::Any,
56        arity: BuiltinParamArity::Required,
57        default: None,
58        description: "Payload to send.",
59    },
60    BuiltinParamDescriptor {
61        name: "datatype",
62        ty: BuiltinParamType::StringScalar,
63        arity: BuiltinParamArity::Optional,
64        default: Some("\"uint8\""),
65        description: "Data type label (for example \"uint8\", \"double\", \"char\", \"string\").",
66    },
67];
68const WRITE_SIGNATURES: [BuiltinSignatureDescriptor; 2] = [
69    BuiltinSignatureDescriptor {
70        label: "count = write(client, data)",
71        inputs: &WRITE_INPUTS_CLIENT_DATA,
72        outputs: &WRITE_OUTPUT_COUNT,
73    },
74    BuiltinSignatureDescriptor {
75        label: "count = write(client, data, datatype)",
76        inputs: &WRITE_INPUTS_CLIENT_DATA_DATATYPE,
77        outputs: &WRITE_OUTPUT_COUNT,
78    },
79];
80
81const WRITE_ERROR_INVALID_CLIENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
82    code: "RM.WRITE.INVALID_CLIENT",
83    identifier: Some("RunMat:write:InvalidTcpClient"),
84    when: "Client handle is missing, malformed, invalid, or disconnected.",
85    message: "write: invalid tcpclient handle",
86};
87const WRITE_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
88    code: "RM.WRITE.INVALID_INPUT",
89    identifier: Some("RunMat:write:InvalidInput"),
90    when: "Argument list shape is unsupported for write.",
91    message: "write: invalid argument list",
92};
93const WRITE_ERROR_INVALID_DATA: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
94    code: "RM.WRITE.INVALID_DATA",
95    identifier: Some("RunMat:write:InvalidData"),
96    when: "Payload cannot be converted to the requested datatype.",
97    message: "write: invalid data payload",
98};
99const WRITE_ERROR_INVALID_DATATYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
100    code: "RM.WRITE.INVALID_DATATYPE",
101    identifier: Some("RunMat:write:InvalidDataType"),
102    when: "Datatype argument is not a supported scalar text label.",
103    message: "write: invalid datatype argument",
104};
105const WRITE_ERROR_NOT_CONNECTED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
106    code: "RM.WRITE.NOT_CONNECTED",
107    identifier: Some("RunMat:write:NotConnected"),
108    when: "Client has no active socket connection.",
109    message: "write: tcpclient is disconnected",
110};
111const WRITE_ERROR_TIMEOUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
112    code: "RM.WRITE.TIMEOUT",
113    identifier: Some("RunMat:write:Timeout"),
114    when: "Socket write exceeds configured timeout.",
115    message: "write: timed out while sending data",
116};
117const WRITE_ERROR_CONNECTION_CLOSED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
118    code: "RM.WRITE.CONNECTION_CLOSED",
119    identifier: Some("RunMat:write:ConnectionClosed"),
120    when: "Peer closes socket before payload is fully written.",
121    message: "write: connection closed before all data was sent",
122};
123const WRITE_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
124    code: "RM.WRITE.INTERNAL",
125    identifier: Some("RunMat:write:InternalError"),
126    when: "Internal socket/control-flow conversion fails.",
127    message: "write: internal socket error",
128};
129const WRITE_ERRORS: [BuiltinErrorDescriptor; 8] = [
130    WRITE_ERROR_INVALID_CLIENT,
131    WRITE_ERROR_INVALID_INPUT,
132    WRITE_ERROR_INVALID_DATA,
133    WRITE_ERROR_INVALID_DATATYPE,
134    WRITE_ERROR_NOT_CONNECTED,
135    WRITE_ERROR_TIMEOUT,
136    WRITE_ERROR_CONNECTION_CLOSED,
137    WRITE_ERROR_INTERNAL,
138];
139pub const WRITE_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
140    signatures: &WRITE_SIGNATURES,
141    output_mode: BuiltinOutputMode::Fixed,
142    completion_policy: BuiltinCompletionPolicy::Public,
143    errors: &WRITE_ERRORS,
144};
145
// GPU metadata registered for `write`. As the `notes` field states, socket
// writes always execute on the host CPU and GPU providers are never consulted,
// which is why no precisions or provider hooks are listed and the residency
// policy is `GatherImmediately`.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::io::net::write")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "write",
    op_kind: GpuOpKind::Custom("network"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Socket writes always execute on the host CPU; GPU providers are never consulted.",
};
161
162fn write_error_with_message(
163    message: impl Into<String>,
164    error: &'static BuiltinErrorDescriptor,
165) -> RuntimeError {
166    let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
167    if let Some(identifier) = error.identifier {
168        builder = builder.with_identifier(identifier);
169    }
170    builder.build()
171}
172
173fn write_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
174    write_error_with_message(error.message, error)
175}
176
177fn write_error_with_detail(
178    error: &'static BuiltinErrorDescriptor,
179    detail: impl AsRef<str>,
180) -> RuntimeError {
181    let detail = detail.as_ref();
182    let detail = detail.strip_prefix("write: ").unwrap_or(detail);
183    write_error_with_message(format!("{}: {}", error.message, detail), error)
184}
185
186fn write_flow(error: &'static BuiltinErrorDescriptor, message: impl AsRef<str>) -> RuntimeError {
187    write_error_with_detail(error, message)
188}
189
190fn map_write_flow(err: RuntimeError, error: &'static BuiltinErrorDescriptor) -> RuntimeError {
191    let mut builder = build_runtime_error(format!("{BUILTIN_NAME}: {}", err.message()))
192        .with_builtin(BUILTIN_NAME)
193        .with_source(err);
194    if let Some(identifier) = error.identifier {
195        builder = builder.with_identifier(identifier);
196    }
197    builder.build()
198}
199
// Fusion metadata registered for `write`. No elementwise or reduction kernel
// is provided; the `notes` field records that this networking builtin is
// executed eagerly on the CPU.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::io::net::write")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "write",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Networking builtin executed eagerly on the CPU.",
};
210
211#[runtime_builtin(
212    name = "write",
213    category = "io/net",
214    summary = "Write numeric or text payloads to TCP client connections.",
215    keywords = "write,tcpclient,networking",
216    type_resolver(crate::builtins::io::type_resolvers::write_type),
217    descriptor(crate::builtins::io::net::write::WRITE_DESCRIPTOR),
218    builtin_path = "crate::builtins::io::net::write"
219)]
220async fn write_builtin(
221    client: Value,
222    data: Value,
223    rest: Vec<Value>,
224) -> crate::BuiltinResult<Value> {
225    let client = gather_if_needed_async(&client)
226        .await
227        .map_err(|flow| map_write_flow(flow, &WRITE_ERROR_INVALID_CLIENT))?;
228    let data = gather_if_needed_async(&data)
229        .await
230        .map_err(|flow| map_write_flow(flow, &WRITE_ERROR_INVALID_DATA))?;
231
232    let mut gathered_rest = Vec::with_capacity(rest.len());
233    for value in rest {
234        gathered_rest.push(
235            gather_if_needed_async(&value)
236                .await
237                .map_err(|flow| map_write_flow(flow, &WRITE_ERROR_INVALID_DATATYPE))?,
238        );
239    }
240    let datatype = parse_arguments(&gathered_rest)?;
241
242    let client_struct = match &client {
243        Value::Struct(st) => st,
244        _ => {
245            return Err(write_flow(
246                &WRITE_ERROR_INVALID_CLIENT,
247                "write: expected tcpclient struct as first argument",
248            ))
249        }
250    };
251
252    let client_id = extract_client_id(client_struct)?;
253    let handle = client_handle(client_id).ok_or_else(|| {
254        write_flow(
255            &WRITE_ERROR_INVALID_CLIENT,
256            "write: tcpclient handle is no longer valid",
257        )
258    })?;
259
260    let (mut stream, timeout, byte_order) = {
261        let guard = handle.lock().unwrap_or_else(|poison| poison.into_inner());
262        if !guard.connected {
263            return Err(write_error(&WRITE_ERROR_NOT_CONNECTED));
264        }
265        let timeout = guard.timeout;
266        let byte_order = parse_byte_order(&guard.byte_order);
267        let stream = guard.stream.try_clone().map_err(|err| {
268            write_flow(
269                &WRITE_ERROR_INTERNAL,
270                format!("write: clone failed ({err})"),
271            )
272        })?;
273        (stream, timeout, byte_order)
274    };
275
276    if let Err(err) = configure_stream(&stream, timeout) {
277        return Err(write_flow(
278            &WRITE_ERROR_INTERNAL,
279            format!("write: unable to configure socket timeout ({err})"),
280        ));
281    }
282
283    let payload = prepare_payload(&data, datatype, byte_order)?;
284    if payload.bytes.is_empty() {
285        return Ok(Value::Num(0.0));
286    }
287
288    match write_bytes(&mut stream, &payload.bytes) {
289        Ok(_) => Ok(Value::Num(payload.elements as f64)),
290        Err(WriteError::Timeout) => Err(write_error(&WRITE_ERROR_TIMEOUT)),
291        Err(WriteError::ConnectionClosed) => {
292            if let Ok(mut guard) = handle.lock() {
293                guard.connected = false;
294            }
295            Err(write_error(&WRITE_ERROR_CONNECTION_CLOSED))
296        }
297        Err(WriteError::Io(err)) => Err(write_flow(
298            &WRITE_ERROR_INTERNAL,
299            format!("write: socket error ({err})"),
300        )),
301    }
302}
303
/// Wire encodings selectable through the `datatype` argument of `write`.
#[derive(Clone, Copy)]
enum DataType {
    UInt8,
    Int8,
    UInt16,
    Int16,
    UInt32,
    Int32,
    UInt64,
    Int64,
    Single,
    Double,
    Char,
    String,
}

/// `uint8` is the datatype used when the caller does not supply one.
///
/// This is the standard trait rather than an inherent `default()` method, so
/// `DataType::default()` keeps working and the type also composes with
/// `Default`-bounded generic code.
impl Default for DataType {
    fn default() -> Self {
        DataType::UInt8
    }
}

impl DataType {
    /// Number of bytes one encoded element occupies on the wire.
    ///
    /// The text datatypes report 1. This value is only used to size the
    /// buffer for numeric payloads.
    fn element_size(self) -> usize {
        match self {
            DataType::UInt8 | DataType::Int8 | DataType::Char | DataType::String => 1,
            DataType::UInt16 | DataType::Int16 => 2,
            DataType::UInt32 | DataType::Int32 | DataType::Single => 4,
            DataType::UInt64 | DataType::Int64 | DataType::Double => 8,
        }
    }
}
334
/// Byte order applied when encoding multi-byte numeric elements.
#[derive(Clone, Copy)]
enum ByteOrder {
    /// Least-significant byte first (`to_le_bytes`).
    Little,
    /// Most-significant byte first (`to_be_bytes`).
    Big,
}
340
/// Encoded data ready to be sent, plus the count reported back to the caller.
struct Payload {
    /// Raw bytes handed to the socket.
    bytes: Vec<u8>,
    /// Element count that `write` returns on success. It can differ from
    /// `bytes.len()`, for example with multi-byte numeric types.
    elements: usize,
}
345
346fn parse_arguments(args: &[Value]) -> BuiltinResult<DataType> {
347    match args.len() {
348        0 => Ok(DataType::default()),
349        1 => parse_datatype(&args[0]),
350        _ => Err(write_flow(
351            &WRITE_ERROR_INVALID_INPUT,
352            "write: expected at most one datatype argument",
353        )),
354    }
355}
356
357fn parse_datatype(value: &Value) -> BuiltinResult<DataType> {
358    let text = scalar_string(value)?;
359    let lowered = text.trim().to_ascii_lowercase();
360    if lowered.is_empty() {
361        return Err(write_flow(
362            &WRITE_ERROR_INVALID_DATATYPE,
363            "write: datatype must not be empty",
364        ));
365    }
366    let dtype = match lowered.as_str() {
367        "uint8" => DataType::UInt8,
368        "int8" => DataType::Int8,
369        "uint16" => DataType::UInt16,
370        "int16" => DataType::Int16,
371        "uint32" => DataType::UInt32,
372        "int32" => DataType::Int32,
373        "uint64" => DataType::UInt64,
374        "int64" => DataType::Int64,
375        "single" => DataType::Single,
376        "double" => DataType::Double,
377        "char" => DataType::Char,
378        "string" => DataType::String,
379        _ => {
380            return Err(write_flow(
381                &WRITE_ERROR_INVALID_DATATYPE,
382                format!("write: unsupported datatype '{text}'"),
383            ))
384        }
385    };
386    Ok(dtype)
387}
388
389fn prepare_payload(data: &Value, datatype: DataType, order: ByteOrder) -> BuiltinResult<Payload> {
390    match datatype {
391        DataType::Char => char_payload(data),
392        DataType::String => string_payload(data),
393        _ => numeric_payload(data, datatype, order),
394    }
395}
396
397fn numeric_payload(data: &Value, datatype: DataType, order: ByteOrder) -> BuiltinResult<Payload> {
398    let values = flatten_numeric(data)?;
399    let mut bytes = Vec::with_capacity(values.len() * datatype.element_size());
400    for value in values.iter().copied() {
401        match datatype {
402            DataType::UInt8 => bytes.push(cast_to_u8(value)),
403            DataType::Int8 => bytes.push(cast_to_i8(value) as u8),
404            DataType::UInt16 => extend_u16(&mut bytes, cast_to_u16(value), order),
405            DataType::Int16 => extend_i16(&mut bytes, cast_to_i16(value), order),
406            DataType::UInt32 => extend_u32(&mut bytes, cast_to_u32(value), order),
407            DataType::Int32 => extend_i32(&mut bytes, cast_to_i32(value), order),
408            DataType::UInt64 => extend_u64(&mut bytes, cast_to_u64(value), order),
409            DataType::Int64 => extend_i64(&mut bytes, cast_to_i64(value), order),
410            DataType::Single => extend_f32(&mut bytes, cast_to_f32(value), order),
411            DataType::Double => extend_f64(&mut bytes, value, order),
412            DataType::Char | DataType::String => unreachable!(),
413        }
414    }
415    Ok(Payload {
416        bytes,
417        elements: values.len(),
418    })
419}
420
421fn char_payload(data: &Value) -> BuiltinResult<Payload> {
422    let bytes = match data {
423        Value::CharArray(ca) => ca.data.iter().map(|&ch| (ch as u32 & 0xFF) as u8).collect(),
424        Value::String(text) => text.bytes().collect(),
425        Value::StringArray(sa) => {
426            if sa.data.len() != 1 {
427                return Err(write_flow(
428                    &WRITE_ERROR_INVALID_DATA,
429                    "write: string array input must be scalar when using 'char'",
430                ));
431            }
432            sa.data[0].as_bytes().to_vec()
433        }
434        Value::Tensor(t) => t.data.iter().map(|&v| cast_to_u8(v)).collect::<Vec<u8>>(),
435        Value::Num(n) => vec![cast_to_u8(*n)],
436        Value::Int(iv) => vec![cast_to_u8(iv.to_f64())],
437        Value::Bool(b) => vec![if *b { 1 } else { 0 }],
438        Value::LogicalArray(la) => la
439            .data
440            .iter()
441            .map(|&b| if b != 0 { 1 } else { 0 })
442            .collect(),
443        _ => {
444            return Err(write_flow(
445                &WRITE_ERROR_INVALID_DATA,
446                "write: unsupported input for 'char' datatype",
447            ))
448        }
449    };
450    Ok(Payload {
451        elements: bytes.len(),
452        bytes,
453    })
454}
455
456fn string_payload(data: &Value) -> BuiltinResult<Payload> {
457    match data {
458        Value::String(text) => Ok(Payload {
459            elements: 1,
460            bytes: text.as_bytes().to_vec(),
461        }),
462        Value::CharArray(ca) => {
463            let string: String = ca.data.iter().collect();
464            Ok(Payload {
465                elements: 1,
466                bytes: string.into_bytes(),
467            })
468        }
469        Value::StringArray(sa) => {
470            if sa.data.is_empty() {
471                return Ok(Payload {
472                    elements: 0,
473                    bytes: Vec::new(),
474                });
475            }
476            if sa.data.len() != 1 {
477                return Err(write_flow(
478                    &WRITE_ERROR_INVALID_DATA,
479                    "write: string array input must be scalar when using 'string'",
480                ));
481            }
482            Ok(Payload {
483                elements: 1,
484                bytes: sa.data[0].as_bytes().to_vec(),
485            })
486        }
487        _ => Err(write_flow(
488            &WRITE_ERROR_INVALID_DATA,
489            "write: expected text input when using 'string' datatype",
490        )),
491    }
492}
493
494fn flatten_numeric(value: &Value) -> BuiltinResult<Vec<f64>> {
495    match value {
496        Value::Tensor(t) => Ok(t.data.clone()),
497        Value::SparseTensor(s) => {
498            let total_elements = s.rows.checked_mul(s.cols).ok_or_else(|| {
499                write_error_with_message(
500                    "write: sparse matrix dimensions overflow",
501                    &WRITE_ERROR_INTERNAL,
502                )
503            })?;
504            if total_elements > 10_000_000 {
505                return Err(write_error_with_message(
506                    format!("write: cannot densify sparse tensor {}x{} ({} elements exceeds safe threshold)", s.rows, s.cols, total_elements),
507                    &WRITE_ERROR_INTERNAL,
508                ));
509            }
510            s.to_dense().map(|dense| dense.data).map_err(|err| {
511                write_error_with_message(format!("write: {err}"), &WRITE_ERROR_INTERNAL)
512            })
513        }
514        Value::Num(n) => Ok(vec![*n]),
515        Value::Int(iv) => Ok(vec![iv.to_f64()]),
516        Value::Bool(b) => Ok(vec![if *b { 1.0 } else { 0.0 }]),
517        Value::LogicalArray(la) => Ok(la
518            .data
519            .iter()
520            .map(|&b| if b != 0 { 1.0 } else { 0.0 })
521            .collect()),
522        Value::CharArray(ca) => Ok(ca
523            .data
524            .iter()
525            .map(|&ch| (ch as u32 & 0xFF) as f64)
526            .collect()),
527        Value::String(text) => Ok(text.chars().map(|ch| (ch as u32) as f64).collect()),
528        Value::StringArray(sa) => {
529            if sa.data.len() != 1 {
530                return Err(write_flow(
531                    &WRITE_ERROR_INVALID_DATA,
532                    "write: string array input must be scalar",
533                ));
534            }
535            Ok(sa.data[0].chars().map(|ch| (ch as u32) as f64).collect())
536        }
537        Value::Complex(_, _) | Value::ComplexTensor(_) => Err(write_flow(
538            &WRITE_ERROR_INVALID_DATA,
539            "write: complex data is not supported",
540        )),
541        Value::Cell(_)
542        | Value::Struct(_)
543        | Value::Object(_)
544        | Value::HandleObject(_)
545        | Value::Listener(_)
546        | Value::FunctionHandle(_)
547        | Value::ExternalFunctionHandle(_)
548        | Value::MethodFunctionHandle(_)
549        | Value::BoundFunctionHandle { .. }
550        | Value::Closure(_)
551        | Value::ClassRef(_)
552        | Value::MException(_)
553        | Value::OutputList(_) => Err(write_flow(
554            &WRITE_ERROR_INVALID_DATA,
555            "write: unsupported input type",
556        )),
557        Value::GpuTensor(_) => Err(write_flow(
558            &WRITE_ERROR_INVALID_DATA,
559            "write: GPU tensor should have been gathered before encoding",
560        )),
561    }
562}
563
564fn cast_to_u8(value: f64) -> u8 {
565    let rounded = rounded_scalar(value);
566    if !rounded.is_finite() {
567        return if rounded.is_sign_negative() {
568            0
569        } else {
570            u8::MAX
571        };
572    }
573    if rounded < 0.0 {
574        0
575    } else if rounded > u8::MAX as f64 {
576        u8::MAX
577    } else {
578        rounded as u8
579    }
580}
581
582fn cast_to_i8(value: f64) -> i8 {
583    let rounded = rounded_scalar(value);
584    if !rounded.is_finite() {
585        return if rounded.is_sign_negative() {
586            i8::MIN
587        } else {
588            i8::MAX
589        };
590    }
591    if rounded < i8::MIN as f64 {
592        i8::MIN
593    } else if rounded > i8::MAX as f64 {
594        i8::MAX
595    } else {
596        rounded as i8
597    }
598}
599
600fn cast_to_u16(value: f64) -> u16 {
601    let rounded = rounded_scalar(value);
602    if !rounded.is_finite() {
603        return if rounded.is_sign_negative() {
604            0
605        } else {
606            u16::MAX
607        };
608    }
609    if rounded < 0.0 {
610        0
611    } else if rounded > u16::MAX as f64 {
612        u16::MAX
613    } else {
614        rounded as u16
615    }
616}
617
618fn cast_to_i16(value: f64) -> i16 {
619    let rounded = rounded_scalar(value);
620    if !rounded.is_finite() {
621        return if rounded.is_sign_negative() {
622            i16::MIN
623        } else {
624            i16::MAX
625        };
626    }
627    if rounded < i16::MIN as f64 {
628        i16::MIN
629    } else if rounded > i16::MAX as f64 {
630        i16::MAX
631    } else {
632        rounded as i16
633    }
634}
635
636fn cast_to_u32(value: f64) -> u32 {
637    let rounded = rounded_scalar(value);
638    if !rounded.is_finite() {
639        return if rounded.is_sign_negative() {
640            0
641        } else {
642            u32::MAX
643        };
644    }
645    if rounded < 0.0 {
646        0
647    } else if rounded > u32::MAX as f64 {
648        u32::MAX
649    } else {
650        rounded as u32
651    }
652}
653
654fn cast_to_i32(value: f64) -> i32 {
655    let rounded = rounded_scalar(value);
656    if !rounded.is_finite() {
657        return if rounded.is_sign_negative() {
658            i32::MIN
659        } else {
660            i32::MAX
661        };
662    }
663    if rounded < i32::MIN as f64 {
664        i32::MIN
665    } else if rounded > i32::MAX as f64 {
666        i32::MAX
667    } else {
668        rounded as i32
669    }
670}
671
672fn cast_to_u64(value: f64) -> u64 {
673    let rounded = rounded_scalar(value);
674    if !rounded.is_finite() {
675        return if rounded.is_sign_negative() {
676            0
677        } else {
678            u64::MAX
679        };
680    }
681    if rounded < 0.0 {
682        0
683    } else if rounded > u64::MAX as f64 {
684        u64::MAX
685    } else {
686        rounded as u64
687    }
688}
689
690fn cast_to_i64(value: f64) -> i64 {
691    let rounded = rounded_scalar(value);
692    if !rounded.is_finite() {
693        return if rounded.is_sign_negative() {
694            i64::MIN
695        } else {
696            i64::MAX
697        };
698    }
699    if rounded < i64::MIN as f64 {
700        i64::MIN
701    } else if rounded > i64::MAX as f64 {
702        i64::MAX
703    } else {
704        rounded as i64
705    }
706}
707
/// Narrows `value` to single precision with a plain `as` cast. No rounding to
/// an integer and no NaN substitution is applied.
fn cast_to_f32(value: f64) -> f32 {
    let narrowed: f32 = value as f32;
    narrowed
}
711
/// Rounds `value` to the nearest integer (ties away from zero, per
/// `f64::round`), mapping NaN to 0.0. Infinities pass through unchanged.
fn rounded_scalar(value: f64) -> f64 {
    if value.is_nan() {
        return 0.0;
    }
    value.round()
}
719
720fn extend_u16(buffer: &mut Vec<u8>, value: u16, order: ByteOrder) {
721    match order {
722        ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
723        ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
724    }
725}
726
727fn extend_i16(buffer: &mut Vec<u8>, value: i16, order: ByteOrder) {
728    match order {
729        ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
730        ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
731    }
732}
733
734fn extend_u32(buffer: &mut Vec<u8>, value: u32, order: ByteOrder) {
735    match order {
736        ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
737        ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
738    }
739}
740
741fn extend_i32(buffer: &mut Vec<u8>, value: i32, order: ByteOrder) {
742    match order {
743        ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
744        ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
745    }
746}
747
748fn extend_u64(buffer: &mut Vec<u8>, value: u64, order: ByteOrder) {
749    match order {
750        ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
751        ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
752    }
753}
754
755fn extend_i64(buffer: &mut Vec<u8>, value: i64, order: ByteOrder) {
756    match order {
757        ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
758        ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
759    }
760}
761
762fn extend_f32(buffer: &mut Vec<u8>, value: f32, order: ByteOrder) {
763    match order {
764        ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
765        ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
766    }
767}
768
769fn extend_f64(buffer: &mut Vec<u8>, value: f64, order: ByteOrder) {
770    match order {
771        ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
772        ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
773    }
774}
775
776fn parse_byte_order(text: &str) -> ByteOrder {
777    if text.eq_ignore_ascii_case("big-endian") || text.eq_ignore_ascii_case("big endian") {
778        ByteOrder::Big
779    } else {
780        ByteOrder::Little
781    }
782}
783
784fn scalar_string(value: &Value) -> BuiltinResult<String> {
785    match value {
786        Value::String(s) => Ok(s.clone()),
787        Value::CharArray(ca) if ca.rows == 1 => Ok(ca.data.iter().collect()),
788        Value::StringArray(sa) if sa.data.len() == 1 => Ok(sa.data[0].clone()),
789        _ => Err(write_flow(
790            &WRITE_ERROR_INVALID_DATATYPE,
791            "write: datatype argument must be a string scalar or character row vector",
792        )),
793    }
794}
795
796fn extract_client_id(struct_value: &StructValue) -> BuiltinResult<u64> {
797    let id_value = struct_value
798        .fields
799        .get(CLIENT_HANDLE_FIELD)
800        .ok_or_else(|| {
801            write_flow(
802                &WRITE_ERROR_INVALID_CLIENT,
803                "write: tcpclient struct is missing internal handle",
804            )
805        })?;
806    match id_value {
807        Value::Int(IntValue::U64(id)) => Ok(*id),
808        Value::Int(iv) => Ok(iv.to_i64() as u64),
809        _ => Err(write_flow(
810            &WRITE_ERROR_INVALID_CLIENT,
811            "write: tcpclient struct has invalid handle field",
812        )),
813    }
814}
815
/// Outcome categories `write_bytes` reports so the builtin can choose the
/// matching error descriptor.
enum WriteError {
    /// The write failed with a timeout-like error kind (see `is_timeout`).
    Timeout,
    /// The socket accepted zero bytes, or failed with an error kind that
    /// `is_connection_closed_error` recognises.
    ConnectionClosed,
    /// Any other I/O failure, carried through unchanged.
    Io(io::Error),
}
821
822fn write_bytes(stream: &mut TcpStream, bytes: &[u8]) -> Result<(), WriteError> {
823    let mut offset = 0usize;
824    while offset < bytes.len() {
825        match stream.write(&bytes[offset..]) {
826            Ok(0) => return Err(WriteError::ConnectionClosed),
827            Ok(n) => offset += n,
828            Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
829            Err(err) if is_timeout(&err) => return Err(WriteError::Timeout),
830            Err(err) if is_connection_closed_error(&err) => {
831                return Err(WriteError::ConnectionClosed)
832            }
833            Err(err) => return Err(WriteError::Io(err)),
834        }
835    }
836    Ok(())
837}
838
/// Returns `true` when `err` has kind `TimedOut` or `WouldBlock`; both are
/// treated as a write timeout.
fn is_timeout(err: &io::Error) -> bool {
    let kind = err.kind();
    kind == io::ErrorKind::TimedOut || kind == io::ErrorKind::WouldBlock
}
845
/// Returns `true` when `err` has one of the kinds treated as "the peer is
/// gone": broken pipe, reset, aborted, not connected, or unexpected EOF.
fn is_connection_closed_error(err: &io::Error) -> bool {
    const CLOSED_KINDS: [io::ErrorKind; 5] = [
        io::ErrorKind::BrokenPipe,
        io::ErrorKind::ConnectionReset,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::NotConnected,
        io::ErrorKind::UnexpectedEof,
    ];
    CLOSED_KINDS.contains(&err.kind())
}
856
857#[cfg(test)]
858pub(crate) mod tests {
859    use super::*;
860    use crate::builtins::io::net::accept::{
861        configure_stream, insert_client, remove_client_for_test,
862    };
863    use runmat_builtins::{CharArray, IntValue, StructValue, Tensor};
864    use std::io::Read;
865    use std::net::{TcpListener, TcpStream};
866    use std::sync::{Arc, Barrier};
867    use std::thread;
868
869    fn make_client(stream: TcpStream, timeout: f64, byte_order: &str) -> Value {
870        let peer_addr = stream.peer_addr().expect("peer addr");
871        configure_stream(&stream, timeout).expect("configure stream");
872        let client_id = insert_client(stream, 0, peer_addr, timeout, byte_order.to_string());
873        let mut st = StructValue::new();
874        st.fields.insert(
875            CLIENT_HANDLE_FIELD.to_string(),
876            Value::Int(IntValue::U64(client_id)),
877        );
878        Value::Struct(st)
879    }
880
881    fn client_id(client: &Value) -> u64 {
882        match client {
883            Value::Struct(st) => match st.fields.get(CLIENT_HANDLE_FIELD) {
884                Some(Value::Int(IntValue::U64(id))) => *id,
885                Some(Value::Int(iv)) => iv.to_i64() as u64,
886                other => panic!("unexpected id field {other:?}"),
887            },
888            other => panic!("expected struct, got {other:?}"),
889        }
890    }
891
892    fn assert_error_identifier(err: RuntimeError, expected: &str) {
893        assert_eq!(err.identifier(), Some(expected));
894    }
895
896    fn run_write(client: Value, data: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
897        futures::executor::block_on(write_builtin(client, data, rest))
898    }
899
900    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
901    #[test]
902    fn write_descriptor_signatures_cover_core_forms() {
903        let labels: Vec<&str> = WRITE_DESCRIPTOR
904            .signatures
905            .iter()
906            .map(|sig| sig.label)
907            .collect();
908        assert!(labels.contains(&"count = write(client, data)"));
909        assert!(labels.contains(&"count = write(client, data, datatype)"));
910    }
911
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
/// Writing a 1x4 tensor without a datatype argument reports four elements
/// written, and the peer receives the bytes `[1, 2, 3, 4]`.
fn write_default_uint8_sends_bytes() {
    let listener = TcpListener::bind("127.0.0.1:0").expect("listener");
    let port = listener.local_addr().unwrap().port();

    // Peer side: accept one connection and collect everything it sends.
    let server = thread::spawn(move || {
        let (mut peer, _) = listener.accept().expect("accept");
        let mut bytes = Vec::new();
        // Read errors are deliberately ignored; whatever arrived is returned.
        peer.read_to_end(&mut bytes).unwrap_or_default();
        bytes
    });

    let socket = TcpStream::connect(("127.0.0.1", port)).expect("connect");
    let client = make_client(socket, 1.0, "little-endian");
    let payload = Value::Tensor(Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap());

    match run_write(client.clone(), payload, Vec::new()).expect("write") {
        Value::Num(count) => assert_eq!(count, 4.0),
        other => panic!("expected numeric result, got {other:?}"),
    }

    // NOTE(review): removing the client presumably closes the socket so the
    // peer's `read_to_end` can finish before the join — confirm.
    remove_client_for_test(client_id(&client));
    let bytes_seen = server.join().expect("join");
    assert_eq!(bytes_seen, vec![1, 2, 3, 4]);
}
936
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
/// Writing three values as `"double"` on a big-endian client reports three
/// elements written and sends 24 bytes matching `extend_f64` with
/// `ByteOrder::Big`.
fn write_double_big_endian_encodes_correctly() {
    let listener = TcpListener::bind("127.0.0.1:0").expect("listener");
    let port = listener.local_addr().unwrap().port();

    // Peer side: read exactly three 8-byte doubles.
    let server = thread::spawn(move || {
        let (mut peer, _) = listener.accept().expect("accept");
        let mut frame = [0u8; 24];
        peer.read_exact(&mut frame).expect("read");
        frame
    });

    let socket = TcpStream::connect(("127.0.0.1", port)).expect("connect");
    let client = make_client(socket, 1.0, "big-endian");
    let payload = Value::Tensor(Tensor::new(vec![1.5, 2.5, 3.5], vec![1, 3]).unwrap());
    let datatype_args = vec![Value::from("double")];

    match run_write(client.clone(), payload, datatype_args).expect("write") {
        Value::Num(count) => assert_eq!(count, 3.0),
        other => panic!("expected numeric count, got {other:?}"),
    }
    remove_client_for_test(client_id(&client));

    let frame = server.join().expect("join");

    // Build the reference encoding with the same helper.
    let mut expected = Vec::new();
    for sample in [1.5, 2.5, 3.5] {
        extend_f64(&mut expected, sample, ByteOrder::Big);
    }
    assert_eq!(frame.to_vec(), expected);
}
971
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
/// Writing the 1x6 char array "RunMat" with datatype `"char"` reports six
/// elements written, and the peer receives the bytes `b"RunMat"`.
fn write_char_payload_encodes_ascii() {
    let listener = TcpListener::bind("127.0.0.1:0").expect("listener");
    let port = listener.local_addr().unwrap().port();

    // Peer side: accept one connection and collect everything it sends.
    let server = thread::spawn(move || {
        let (mut peer, _) = listener.accept().expect("accept");
        let mut bytes = Vec::new();
        // Read errors are deliberately ignored; whatever arrived is returned.
        peer.read_to_end(&mut bytes).unwrap_or_default();
        bytes
    });

    let socket = TcpStream::connect(("127.0.0.1", port)).expect("connect");
    let client = make_client(socket, 1.0, "little-endian");
    let text = CharArray::new("RunMat".chars().collect(), 1, 6).unwrap();
    let datatype_args = vec![Value::from("char")];

    match run_write(client.clone(), Value::CharArray(text), datatype_args).expect("write") {
        Value::Num(count) => assert_eq!(count, 6.0),
        other => panic!("expected numeric count, got {other:?}"),
    }

    remove_client_for_test(client_id(&client));
    let bytes_seen = server.join().expect("join");
    assert_eq!(bytes_seen, b"RunMat");
}
1001
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
/// After the registered client's `connected` flag is cleared, `write` must
/// fail with the identifier of `WRITE_ERROR_NOT_CONNECTED`.
///
/// The peer thread keeps its accepted stream open until the main thread
/// reaches the barrier, so the socket is still open while the write is
/// attempted.
fn write_errors_when_client_disconnected() {
    let listener = TcpListener::bind("127.0.0.1:0").expect("listener");
    let port = listener.local_addr().unwrap().port();
    let barrier = Arc::new(Barrier::new(2));
    let thread_barrier = barrier.clone();
    let handle = thread::spawn(move || {
        let (stream, _) = listener.accept().expect("accept");
        thread_barrier.wait();
        drop(stream);
    });

    let stream = TcpStream::connect(("127.0.0.1", port)).expect("connect");
    let client = make_client(stream, 1.0, "little-endian");
    let id = client_id(&client);
    {
        // Fail loudly if the setup cannot be applied: silently skipping it
        // would surface later as a misleading failure of the write assertion.
        let handle_ref = client_handle(id).expect("client handle should be registered");
        let mut guard = handle_ref.lock().expect("client handle lock");
        guard.connected = false;
        // The guard is dropped at the end of this scope, before `write` runs.
    }

    let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]).unwrap();
    let err = run_write(client.clone(), Value::Tensor(tensor), Vec::new()).expect_err("write");
    assert_error_identifier(err, WRITE_ERROR_NOT_CONNECTED.identifier.unwrap());

    remove_client_for_test(id);
    barrier.wait();
    handle.join().expect("join");
}
1032}