1use runmat_builtins::{
4 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
5 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
6 IntValue, StructValue, Value,
7};
8use runmat_macros::runtime_builtin;
9use std::io::{self, Write};
10use std::net::TcpStream;
11
12use crate::builtins::common::spec::{
13 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
14 ReductionNaN, ResidencyPolicy, ShapeRequirements,
15};
16use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
17
18use super::accept::{client_handle, configure_stream, CLIENT_HANDLE_FIELD};
19
20const BUILTIN_NAME: &str = "write";
21
/// Output descriptor list shared by every `write` signature: a single required
/// numeric scalar named `count`.
const WRITE_OUTPUT_COUNT: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
    name: "count",
    ty: BuiltinParamType::NumericScalar,
    arity: BuiltinParamArity::Required,
    default: None,
    description: "Number of elements written to the socket.",
}];
/// Input descriptors for the two-argument form `write(client, data)`.
///
/// Both parameters are required and typed as `Any`; the runtime checks on them
/// happen inside `write_builtin`.
const WRITE_INPUTS_CLIENT_DATA: [BuiltinParamDescriptor; 2] = [
    BuiltinParamDescriptor {
        name: "client",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "tcpclient handle struct.",
    },
    BuiltinParamDescriptor {
        name: "data",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Payload to send.",
    },
];
/// Input descriptors for the three-argument form `write(client, data, datatype)`.
///
/// The first two entries repeat `WRITE_INPUTS_CLIENT_DATA`; the third is the
/// optional datatype label, whose documented default is `"uint8"`.
const WRITE_INPUTS_CLIENT_DATA_DATATYPE: [BuiltinParamDescriptor; 3] = [
    BuiltinParamDescriptor {
        name: "client",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "tcpclient handle struct.",
    },
    BuiltinParamDescriptor {
        name: "data",
        ty: BuiltinParamType::Any,
        arity: BuiltinParamArity::Required,
        default: None,
        description: "Payload to send.",
    },
    BuiltinParamDescriptor {
        name: "datatype",
        ty: BuiltinParamType::StringScalar,
        arity: BuiltinParamArity::Optional,
        default: Some("\"uint8\""),
        description: "Data type label (for example \"uint8\", \"double\", \"char\", \"string\").",
    },
];
/// The two call forms advertised for `write`: with and without the trailing
/// datatype label. Both report the same single `count` output.
const WRITE_SIGNATURES: [BuiltinSignatureDescriptor; 2] = [
    BuiltinSignatureDescriptor {
        label: "count = write(client, data)",
        inputs: &WRITE_INPUTS_CLIENT_DATA,
        outputs: &WRITE_OUTPUT_COUNT,
    },
    BuiltinSignatureDescriptor {
        label: "count = write(client, data, datatype)",
        inputs: &WRITE_INPUTS_CLIENT_DATA_DATATYPE,
        outputs: &WRITE_OUTPUT_COUNT,
    },
];
80
/// Raised when the first argument is not a usable tcpclient struct (wrong value
/// kind, missing or malformed handle field, or a handle that no longer resolves).
const WRITE_ERROR_INVALID_CLIENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.WRITE.INVALID_CLIENT",
    identifier: Some("RunMat:write:InvalidTcpClient"),
    when: "Client handle is missing, malformed, invalid, or disconnected.",
    message: "write: invalid tcpclient handle",
};
/// Raised by `parse_arguments` when more than one trailing argument is supplied.
const WRITE_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.WRITE.INVALID_INPUT",
    identifier: Some("RunMat:write:InvalidInput"),
    when: "Argument list shape is unsupported for write.",
    message: "write: invalid argument list",
};
/// Raised by the payload builders when `data` cannot be encoded for the
/// requested datatype.
const WRITE_ERROR_INVALID_DATA: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.WRITE.INVALID_DATA",
    identifier: Some("RunMat:write:InvalidData"),
    when: "Payload cannot be converted to the requested datatype.",
    message: "write: invalid data payload",
};
/// Raised when the datatype argument is not scalar text or names an unknown type.
const WRITE_ERROR_INVALID_DATATYPE: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.WRITE.INVALID_DATATYPE",
    identifier: Some("RunMat:write:InvalidDataType"),
    when: "Datatype argument is not a supported scalar text label.",
    message: "write: invalid datatype argument",
};
/// Raised when the client handle resolves but its `connected` flag is false.
const WRITE_ERROR_NOT_CONNECTED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.WRITE.NOT_CONNECTED",
    identifier: Some("RunMat:write:NotConnected"),
    when: "Client has no active socket connection.",
    message: "write: tcpclient is disconnected",
};
/// Raised when the socket write fails with a timeout-class I/O error
/// (see `is_timeout`).
const WRITE_ERROR_TIMEOUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.WRITE.TIMEOUT",
    identifier: Some("RunMat:write:Timeout"),
    when: "Socket write exceeds configured timeout.",
    message: "write: timed out while sending data",
};
/// Raised when the socket reports a zero-length write or a connection-closed
/// class I/O error (see `is_connection_closed_error`).
const WRITE_ERROR_CONNECTION_CLOSED: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.WRITE.CONNECTION_CLOSED",
    identifier: Some("RunMat:write:ConnectionClosed"),
    when: "Peer closes socket before payload is fully written.",
    message: "write: connection closed before all data was sent",
};
/// Catch-all for failures that are not the caller's fault: stream cloning,
/// socket configuration, sparse densification, and unclassified I/O errors.
const WRITE_ERROR_INTERNAL: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
    code: "RM.WRITE.INTERNAL",
    identifier: Some("RunMat:write:InternalError"),
    when: "Internal socket/control-flow conversion fails.",
    message: "write: internal socket error",
};
/// Complete error table exposed through `WRITE_DESCRIPTOR`.
const WRITE_ERRORS: [BuiltinErrorDescriptor; 8] = [
    WRITE_ERROR_INVALID_CLIENT,
    WRITE_ERROR_INVALID_INPUT,
    WRITE_ERROR_INVALID_DATA,
    WRITE_ERROR_INVALID_DATATYPE,
    WRITE_ERROR_NOT_CONNECTED,
    WRITE_ERROR_TIMEOUT,
    WRITE_ERROR_CONNECTION_CLOSED,
    WRITE_ERROR_INTERNAL,
];
/// Public descriptor for the `write` builtin, tying together its signatures and
/// error table. Referenced by path from the `runtime_builtin` attribute on
/// `write_builtin`.
pub const WRITE_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
    signatures: &WRITE_SIGNATURES,
    output_mode: BuiltinOutputMode::Fixed,
    completion_policy: BuiltinCompletionPolicy::Public,
    errors: &WRITE_ERRORS,
};
145
// GPU spec registered for `write`. It declares no supported precisions and no
// provider hooks; per the `notes` field, socket writes always run on the host
// CPU and GPU providers are never consulted.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::io::net::write")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "write",
    op_kind: GpuOpKind::Custom("network"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Socket writes always execute on the host CPU; GPU providers are never consulted.",
};
161
162fn write_error_with_message(
163 message: impl Into<String>,
164 error: &'static BuiltinErrorDescriptor,
165) -> RuntimeError {
166 let mut builder = build_runtime_error(message).with_builtin(BUILTIN_NAME);
167 if let Some(identifier) = error.identifier {
168 builder = builder.with_identifier(identifier);
169 }
170 builder.build()
171}
172
173fn write_error(error: &'static BuiltinErrorDescriptor) -> RuntimeError {
174 write_error_with_message(error.message, error)
175}
176
177fn write_error_with_detail(
178 error: &'static BuiltinErrorDescriptor,
179 detail: impl AsRef<str>,
180) -> RuntimeError {
181 let detail = detail.as_ref();
182 let detail = detail.strip_prefix("write: ").unwrap_or(detail);
183 write_error_with_message(format!("{}: {}", error.message, detail), error)
184}
185
186fn write_flow(error: &'static BuiltinErrorDescriptor, message: impl AsRef<str>) -> RuntimeError {
187 write_error_with_detail(error, message)
188}
189
190fn map_write_flow(err: RuntimeError, error: &'static BuiltinErrorDescriptor) -> RuntimeError {
191 let mut builder = build_runtime_error(format!("{BUILTIN_NAME}: {}", err.message()))
192 .with_builtin(BUILTIN_NAME)
193 .with_source(err);
194 if let Some(identifier) = error.identifier {
195 builder = builder.with_identifier(identifier);
196 }
197 builder.build()
198}
199
// Fusion spec registered for `write`. It provides neither an elementwise nor a
// reduction kernel; per the `notes` field the builtin is executed eagerly on
// the CPU.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::io::net::write")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "write",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Networking builtin executed eagerly on the CPU.",
};
210
211#[runtime_builtin(
212 name = "write",
213 category = "io/net",
214 summary = "Write numeric or text payloads to TCP client connections.",
215 keywords = "write,tcpclient,networking",
216 type_resolver(crate::builtins::io::type_resolvers::write_type),
217 descriptor(crate::builtins::io::net::write::WRITE_DESCRIPTOR),
218 builtin_path = "crate::builtins::io::net::write"
219)]
220async fn write_builtin(
221 client: Value,
222 data: Value,
223 rest: Vec<Value>,
224) -> crate::BuiltinResult<Value> {
225 let client = gather_if_needed_async(&client)
226 .await
227 .map_err(|flow| map_write_flow(flow, &WRITE_ERROR_INVALID_CLIENT))?;
228 let data = gather_if_needed_async(&data)
229 .await
230 .map_err(|flow| map_write_flow(flow, &WRITE_ERROR_INVALID_DATA))?;
231
232 let mut gathered_rest = Vec::with_capacity(rest.len());
233 for value in rest {
234 gathered_rest.push(
235 gather_if_needed_async(&value)
236 .await
237 .map_err(|flow| map_write_flow(flow, &WRITE_ERROR_INVALID_DATATYPE))?,
238 );
239 }
240 let datatype = parse_arguments(&gathered_rest)?;
241
242 let client_struct = match &client {
243 Value::Struct(st) => st,
244 _ => {
245 return Err(write_flow(
246 &WRITE_ERROR_INVALID_CLIENT,
247 "write: expected tcpclient struct as first argument",
248 ))
249 }
250 };
251
252 let client_id = extract_client_id(client_struct)?;
253 let handle = client_handle(client_id).ok_or_else(|| {
254 write_flow(
255 &WRITE_ERROR_INVALID_CLIENT,
256 "write: tcpclient handle is no longer valid",
257 )
258 })?;
259
260 let (mut stream, timeout, byte_order) = {
261 let guard = handle.lock().unwrap_or_else(|poison| poison.into_inner());
262 if !guard.connected {
263 return Err(write_error(&WRITE_ERROR_NOT_CONNECTED));
264 }
265 let timeout = guard.timeout;
266 let byte_order = parse_byte_order(&guard.byte_order);
267 let stream = guard.stream.try_clone().map_err(|err| {
268 write_flow(
269 &WRITE_ERROR_INTERNAL,
270 format!("write: clone failed ({err})"),
271 )
272 })?;
273 (stream, timeout, byte_order)
274 };
275
276 if let Err(err) = configure_stream(&stream, timeout) {
277 return Err(write_flow(
278 &WRITE_ERROR_INTERNAL,
279 format!("write: unable to configure socket timeout ({err})"),
280 ));
281 }
282
283 let payload = prepare_payload(&data, datatype, byte_order)?;
284 if payload.bytes.is_empty() {
285 return Ok(Value::Num(0.0));
286 }
287
288 match write_bytes(&mut stream, &payload.bytes) {
289 Ok(_) => Ok(Value::Num(payload.elements as f64)),
290 Err(WriteError::Timeout) => Err(write_error(&WRITE_ERROR_TIMEOUT)),
291 Err(WriteError::ConnectionClosed) => {
292 if let Ok(mut guard) = handle.lock() {
293 guard.connected = false;
294 }
295 Err(write_error(&WRITE_ERROR_CONNECTION_CLOSED))
296 }
297 Err(WriteError::Io(err)) => Err(write_flow(
298 &WRITE_ERROR_INTERNAL,
299 format!("write: socket error ({err})"),
300 )),
301 }
302}
303
/// Wire encodings selectable through the `datatype` argument of `write`.
#[derive(Clone, Copy)]
enum DataType {
    UInt8,
    Int8,
    UInt16,
    Int16,
    UInt32,
    Int32,
    UInt64,
    Int64,
    Single,
    Double,
    Char,
    String,
}

/// `uint8` is the encoding used when the caller supplies no datatype label.
///
/// Implemented as the standard trait rather than an inherent `default` method,
/// so `DataType::default()` keeps working and the type also fits generic
/// `Default` contexts.
impl Default for DataType {
    fn default() -> Self {
        DataType::UInt8
    }
}

impl DataType {
    /// Number of bytes one element occupies on the wire.
    ///
    /// `Char` and `String` report 1, the same as the 8-bit integer types.
    fn element_size(self) -> usize {
        match self {
            DataType::UInt8 | DataType::Int8 | DataType::Char | DataType::String => 1,
            DataType::UInt16 | DataType::Int16 => 2,
            DataType::UInt32 | DataType::Int32 | DataType::Single => 4,
            DataType::UInt64 | DataType::Int64 | DataType::Double => 8,
        }
    }
}
334
/// Byte order applied when encoding multi-byte elements; selected per client by
/// `parse_byte_order`.
#[derive(Clone, Copy)]
enum ByteOrder {
    /// Least-significant byte first (`to_le_bytes`).
    Little,
    /// Most-significant byte first (`to_be_bytes`).
    Big,
}
340
/// Encoded data ready to be sent, plus the element count reported to the caller.
struct Payload {
    /// Raw bytes handed to the socket.
    bytes: Vec<u8>,
    /// Value returned as `count` after a successful write; not necessarily
    /// equal to `bytes.len()`.
    elements: usize,
}
345
346fn parse_arguments(args: &[Value]) -> BuiltinResult<DataType> {
347 match args.len() {
348 0 => Ok(DataType::default()),
349 1 => parse_datatype(&args[0]),
350 _ => Err(write_flow(
351 &WRITE_ERROR_INVALID_INPUT,
352 "write: expected at most one datatype argument",
353 )),
354 }
355}
356
357fn parse_datatype(value: &Value) -> BuiltinResult<DataType> {
358 let text = scalar_string(value)?;
359 let lowered = text.trim().to_ascii_lowercase();
360 if lowered.is_empty() {
361 return Err(write_flow(
362 &WRITE_ERROR_INVALID_DATATYPE,
363 "write: datatype must not be empty",
364 ));
365 }
366 let dtype = match lowered.as_str() {
367 "uint8" => DataType::UInt8,
368 "int8" => DataType::Int8,
369 "uint16" => DataType::UInt16,
370 "int16" => DataType::Int16,
371 "uint32" => DataType::UInt32,
372 "int32" => DataType::Int32,
373 "uint64" => DataType::UInt64,
374 "int64" => DataType::Int64,
375 "single" => DataType::Single,
376 "double" => DataType::Double,
377 "char" => DataType::Char,
378 "string" => DataType::String,
379 _ => {
380 return Err(write_flow(
381 &WRITE_ERROR_INVALID_DATATYPE,
382 format!("write: unsupported datatype '{text}'"),
383 ))
384 }
385 };
386 Ok(dtype)
387}
388
389fn prepare_payload(data: &Value, datatype: DataType, order: ByteOrder) -> BuiltinResult<Payload> {
390 match datatype {
391 DataType::Char => char_payload(data),
392 DataType::String => string_payload(data),
393 _ => numeric_payload(data, datatype, order),
394 }
395}
396
397fn numeric_payload(data: &Value, datatype: DataType, order: ByteOrder) -> BuiltinResult<Payload> {
398 let values = flatten_numeric(data)?;
399 let mut bytes = Vec::with_capacity(values.len() * datatype.element_size());
400 for value in values.iter().copied() {
401 match datatype {
402 DataType::UInt8 => bytes.push(cast_to_u8(value)),
403 DataType::Int8 => bytes.push(cast_to_i8(value) as u8),
404 DataType::UInt16 => extend_u16(&mut bytes, cast_to_u16(value), order),
405 DataType::Int16 => extend_i16(&mut bytes, cast_to_i16(value), order),
406 DataType::UInt32 => extend_u32(&mut bytes, cast_to_u32(value), order),
407 DataType::Int32 => extend_i32(&mut bytes, cast_to_i32(value), order),
408 DataType::UInt64 => extend_u64(&mut bytes, cast_to_u64(value), order),
409 DataType::Int64 => extend_i64(&mut bytes, cast_to_i64(value), order),
410 DataType::Single => extend_f32(&mut bytes, cast_to_f32(value), order),
411 DataType::Double => extend_f64(&mut bytes, value, order),
412 DataType::Char | DataType::String => unreachable!(),
413 }
414 }
415 Ok(Payload {
416 bytes,
417 elements: values.len(),
418 })
419}
420
421fn char_payload(data: &Value) -> BuiltinResult<Payload> {
422 let bytes = match data {
423 Value::CharArray(ca) => ca.data.iter().map(|&ch| (ch as u32 & 0xFF) as u8).collect(),
424 Value::String(text) => text.bytes().collect(),
425 Value::StringArray(sa) => {
426 if sa.data.len() != 1 {
427 return Err(write_flow(
428 &WRITE_ERROR_INVALID_DATA,
429 "write: string array input must be scalar when using 'char'",
430 ));
431 }
432 sa.data[0].as_bytes().to_vec()
433 }
434 Value::Tensor(t) => t.data.iter().map(|&v| cast_to_u8(v)).collect::<Vec<u8>>(),
435 Value::Num(n) => vec![cast_to_u8(*n)],
436 Value::Int(iv) => vec![cast_to_u8(iv.to_f64())],
437 Value::Bool(b) => vec![if *b { 1 } else { 0 }],
438 Value::LogicalArray(la) => la
439 .data
440 .iter()
441 .map(|&b| if b != 0 { 1 } else { 0 })
442 .collect(),
443 _ => {
444 return Err(write_flow(
445 &WRITE_ERROR_INVALID_DATA,
446 "write: unsupported input for 'char' datatype",
447 ))
448 }
449 };
450 Ok(Payload {
451 elements: bytes.len(),
452 bytes,
453 })
454}
455
456fn string_payload(data: &Value) -> BuiltinResult<Payload> {
457 match data {
458 Value::String(text) => Ok(Payload {
459 elements: 1,
460 bytes: text.as_bytes().to_vec(),
461 }),
462 Value::CharArray(ca) => {
463 let string: String = ca.data.iter().collect();
464 Ok(Payload {
465 elements: 1,
466 bytes: string.into_bytes(),
467 })
468 }
469 Value::StringArray(sa) => {
470 if sa.data.is_empty() {
471 return Ok(Payload {
472 elements: 0,
473 bytes: Vec::new(),
474 });
475 }
476 if sa.data.len() != 1 {
477 return Err(write_flow(
478 &WRITE_ERROR_INVALID_DATA,
479 "write: string array input must be scalar when using 'string'",
480 ));
481 }
482 Ok(Payload {
483 elements: 1,
484 bytes: sa.data[0].as_bytes().to_vec(),
485 })
486 }
487 _ => Err(write_flow(
488 &WRITE_ERROR_INVALID_DATA,
489 "write: expected text input when using 'string' datatype",
490 )),
491 }
492}
493
494fn flatten_numeric(value: &Value) -> BuiltinResult<Vec<f64>> {
495 match value {
496 Value::Tensor(t) => Ok(t.data.clone()),
497 Value::SparseTensor(s) => {
498 let total_elements = s.rows.checked_mul(s.cols).ok_or_else(|| {
499 write_error_with_message(
500 "write: sparse matrix dimensions overflow",
501 &WRITE_ERROR_INTERNAL,
502 )
503 })?;
504 if total_elements > 10_000_000 {
505 return Err(write_error_with_message(
506 format!("write: cannot densify sparse tensor {}x{} ({} elements exceeds safe threshold)", s.rows, s.cols, total_elements),
507 &WRITE_ERROR_INTERNAL,
508 ));
509 }
510 s.to_dense().map(|dense| dense.data).map_err(|err| {
511 write_error_with_message(format!("write: {err}"), &WRITE_ERROR_INTERNAL)
512 })
513 }
514 Value::Num(n) => Ok(vec![*n]),
515 Value::Int(iv) => Ok(vec![iv.to_f64()]),
516 Value::Bool(b) => Ok(vec![if *b { 1.0 } else { 0.0 }]),
517 Value::LogicalArray(la) => Ok(la
518 .data
519 .iter()
520 .map(|&b| if b != 0 { 1.0 } else { 0.0 })
521 .collect()),
522 Value::CharArray(ca) => Ok(ca
523 .data
524 .iter()
525 .map(|&ch| (ch as u32 & 0xFF) as f64)
526 .collect()),
527 Value::String(text) => Ok(text.chars().map(|ch| (ch as u32) as f64).collect()),
528 Value::StringArray(sa) => {
529 if sa.data.len() != 1 {
530 return Err(write_flow(
531 &WRITE_ERROR_INVALID_DATA,
532 "write: string array input must be scalar",
533 ));
534 }
535 Ok(sa.data[0].chars().map(|ch| (ch as u32) as f64).collect())
536 }
537 Value::Complex(_, _) | Value::ComplexTensor(_) => Err(write_flow(
538 &WRITE_ERROR_INVALID_DATA,
539 "write: complex data is not supported",
540 )),
541 Value::Cell(_)
542 | Value::Struct(_)
543 | Value::Object(_)
544 | Value::HandleObject(_)
545 | Value::Listener(_)
546 | Value::FunctionHandle(_)
547 | Value::ExternalFunctionHandle(_)
548 | Value::MethodFunctionHandle(_)
549 | Value::BoundFunctionHandle { .. }
550 | Value::Closure(_)
551 | Value::ClassRef(_)
552 | Value::MException(_)
553 | Value::OutputList(_) => Err(write_flow(
554 &WRITE_ERROR_INVALID_DATA,
555 "write: unsupported input type",
556 )),
557 Value::GpuTensor(_) => Err(write_flow(
558 &WRITE_ERROR_INVALID_DATA,
559 "write: GPU tensor should have been gathered before encoding",
560 )),
561 }
562}
563
// Numeric narrowing helpers.
//
// Every integer helper rounds half away from zero, maps NaN to zero, and
// saturates at the target type's bounds (including for ±infinity). Rust's
// float-to-integer `as` cast already saturates, so once NaN has been handled by
// `rounded_scalar` a plain cast implements the whole contract.

/// Converts `value` to `u8`, rounding and saturating to `0..=255`.
fn cast_to_u8(value: f64) -> u8 {
    rounded_scalar(value) as u8
}

/// Converts `value` to `i8`, rounding and saturating to `-128..=127`.
fn cast_to_i8(value: f64) -> i8 {
    rounded_scalar(value) as i8
}

/// Converts `value` to `u16`, rounding and saturating to the `u16` range.
fn cast_to_u16(value: f64) -> u16 {
    rounded_scalar(value) as u16
}

/// Converts `value` to `i16`, rounding and saturating to the `i16` range.
fn cast_to_i16(value: f64) -> i16 {
    rounded_scalar(value) as i16
}

/// Converts `value` to `u32`, rounding and saturating to the `u32` range.
fn cast_to_u32(value: f64) -> u32 {
    rounded_scalar(value) as u32
}

/// Converts `value` to `i32`, rounding and saturating to the `i32` range.
fn cast_to_i32(value: f64) -> i32 {
    rounded_scalar(value) as i32
}

/// Converts `value` to `u64`, rounding and saturating to the `u64` range.
fn cast_to_u64(value: f64) -> u64 {
    rounded_scalar(value) as u64
}

/// Converts `value` to `i64`, rounding and saturating to the `i64` range.
fn cast_to_i64(value: f64) -> i64 {
    rounded_scalar(value) as i64
}

/// Narrows `value` to single precision using the standard `as` conversion.
fn cast_to_f32(value: f64) -> f32 {
    value as f32
}

/// Rounds `value` to the nearest integer (ties away from zero); NaN becomes
/// `0.0` and infinities pass through unchanged.
fn rounded_scalar(value: f64) -> f64 {
    if value.is_nan() {
        return 0.0;
    }
    value.round()
}
719
720fn extend_u16(buffer: &mut Vec<u8>, value: u16, order: ByteOrder) {
721 match order {
722 ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
723 ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
724 }
725}
726
727fn extend_i16(buffer: &mut Vec<u8>, value: i16, order: ByteOrder) {
728 match order {
729 ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
730 ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
731 }
732}
733
734fn extend_u32(buffer: &mut Vec<u8>, value: u32, order: ByteOrder) {
735 match order {
736 ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
737 ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
738 }
739}
740
741fn extend_i32(buffer: &mut Vec<u8>, value: i32, order: ByteOrder) {
742 match order {
743 ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
744 ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
745 }
746}
747
748fn extend_u64(buffer: &mut Vec<u8>, value: u64, order: ByteOrder) {
749 match order {
750 ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
751 ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
752 }
753}
754
755fn extend_i64(buffer: &mut Vec<u8>, value: i64, order: ByteOrder) {
756 match order {
757 ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
758 ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
759 }
760}
761
762fn extend_f32(buffer: &mut Vec<u8>, value: f32, order: ByteOrder) {
763 match order {
764 ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
765 ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
766 }
767}
768
769fn extend_f64(buffer: &mut Vec<u8>, value: f64, order: ByteOrder) {
770 match order {
771 ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
772 ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
773 }
774}
775
776fn parse_byte_order(text: &str) -> ByteOrder {
777 if text.eq_ignore_ascii_case("big-endian") || text.eq_ignore_ascii_case("big endian") {
778 ByteOrder::Big
779 } else {
780 ByteOrder::Little
781 }
782}
783
784fn scalar_string(value: &Value) -> BuiltinResult<String> {
785 match value {
786 Value::String(s) => Ok(s.clone()),
787 Value::CharArray(ca) if ca.rows == 1 => Ok(ca.data.iter().collect()),
788 Value::StringArray(sa) if sa.data.len() == 1 => Ok(sa.data[0].clone()),
789 _ => Err(write_flow(
790 &WRITE_ERROR_INVALID_DATATYPE,
791 "write: datatype argument must be a string scalar or character row vector",
792 )),
793 }
794}
795
796fn extract_client_id(struct_value: &StructValue) -> BuiltinResult<u64> {
797 let id_value = struct_value
798 .fields
799 .get(CLIENT_HANDLE_FIELD)
800 .ok_or_else(|| {
801 write_flow(
802 &WRITE_ERROR_INVALID_CLIENT,
803 "write: tcpclient struct is missing internal handle",
804 )
805 })?;
806 match id_value {
807 Value::Int(IntValue::U64(id)) => Ok(*id),
808 Value::Int(iv) => Ok(iv.to_i64() as u64),
809 _ => Err(write_flow(
810 &WRITE_ERROR_INVALID_CLIENT,
811 "write: tcpclient struct has invalid handle field",
812 )),
813 }
814}
815
/// Outcome categories for a failed socket write, as classified by `write_bytes`.
enum WriteError {
    /// The write failed with a timeout-class error (see `is_timeout`).
    Timeout,
    /// The socket accepted zero bytes or reported a closed-connection error.
    ConnectionClosed,
    /// Any other I/O failure, carried verbatim.
    Io(io::Error),
}
821
822fn write_bytes(stream: &mut TcpStream, bytes: &[u8]) -> Result<(), WriteError> {
823 let mut offset = 0usize;
824 while offset < bytes.len() {
825 match stream.write(&bytes[offset..]) {
826 Ok(0) => return Err(WriteError::ConnectionClosed),
827 Ok(n) => offset += n,
828 Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
829 Err(err) if is_timeout(&err) => return Err(WriteError::Timeout),
830 Err(err) if is_connection_closed_error(&err) => {
831 return Err(WriteError::ConnectionClosed)
832 }
833 Err(err) => return Err(WriteError::Io(err)),
834 }
835 }
836 Ok(())
837}
838
/// Reports whether `err` is a timeout-class error: `TimedOut` or `WouldBlock`.
fn is_timeout(err: &io::Error) -> bool {
    let kind = err.kind();
    kind == io::ErrorKind::TimedOut || kind == io::ErrorKind::WouldBlock
}
845
/// Reports whether `err` indicates that the connection is no longer usable.
///
/// Covers `BrokenPipe`, `ConnectionReset`, `ConnectionAborted`, `NotConnected`,
/// and `UnexpectedEof`.
fn is_connection_closed_error(err: &io::Error) -> bool {
    const CLOSED_KINDS: [io::ErrorKind; 5] = [
        io::ErrorKind::BrokenPipe,
        io::ErrorKind::ConnectionReset,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::NotConnected,
        io::ErrorKind::UnexpectedEof,
    ];
    CLOSED_KINDS.contains(&err.kind())
}
856
// Tests for the `write` builtin. The socket tests run against a loopback
// listener: a helper thread accepts one connection and captures what it
// receives while the test thread performs the write.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::io::net::accept::{
        configure_stream, insert_client, remove_client_for_test,
    };
    use runmat_builtins::{CharArray, IntValue, StructValue, Tensor};
    use std::io::Read;
    use std::net::{TcpListener, TcpStream};
    use std::sync::{Arc, Barrier};
    use std::thread;

    // Registers `stream` as a client and returns a struct value carrying the
    // resulting id under `CLIENT_HANDLE_FIELD`, the shape `write_builtin` expects.
    fn make_client(stream: TcpStream, timeout: f64, byte_order: &str) -> Value {
        let peer_addr = stream.peer_addr().expect("peer addr");
        configure_stream(&stream, timeout).expect("configure stream");
        let client_id = insert_client(stream, 0, peer_addr, timeout, byte_order.to_string());
        let mut st = StructValue::new();
        st.fields.insert(
            CLIENT_HANDLE_FIELD.to_string(),
            Value::Int(IntValue::U64(client_id)),
        );
        Value::Struct(st)
    }

    // Reads the client id back out of a struct built by `make_client`; panics on
    // any other shape.
    fn client_id(client: &Value) -> u64 {
        match client {
            Value::Struct(st) => match st.fields.get(CLIENT_HANDLE_FIELD) {
                Some(Value::Int(IntValue::U64(id))) => *id,
                Some(Value::Int(iv)) => iv.to_i64() as u64,
                other => panic!("unexpected id field {other:?}"),
            },
            other => panic!("expected struct, got {other:?}"),
        }
    }

    // Asserts that `err` carries exactly the `expected` identifier.
    fn assert_error_identifier(err: RuntimeError, expected: &str) {
        assert_eq!(err.identifier(), Some(expected));
    }

    // Drives the async builtin to completion on the current thread.
    fn run_write(client: Value, data: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        futures::executor::block_on(write_builtin(client, data, rest))
    }

    // The descriptor advertises both the two- and three-argument call forms.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn write_descriptor_signatures_cover_core_forms() {
        let labels: Vec<&str> = WRITE_DESCRIPTOR
            .signatures
            .iter()
            .map(|sig| sig.label)
            .collect();
        assert!(labels.contains(&"count = write(client, data)"));
        assert!(labels.contains(&"count = write(client, data, datatype)"));
    }

    // Without a datatype argument, each tensor element is sent as one byte.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn write_default_uint8_sends_bytes() {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener");
        let port = listener.local_addr().unwrap().port();
        let handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept");
            let mut received = Vec::new();
            stream.read_to_end(&mut received).unwrap_or_default();
            received
        });

        let stream = TcpStream::connect(("127.0.0.1", port)).expect("connect");
        let client = make_client(stream, 1.0, "little-endian");
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
        let result = run_write(client.clone(), Value::Tensor(tensor), Vec::new()).expect("write");
        match result {
            Value::Num(count) => assert_eq!(count, 4.0),
            other => panic!("expected numeric result, got {other:?}"),
        }
        // The client is removed before joining; the reader uses `read_to_end`,
        // so it presumably needs the connection closed to finish — confirm
        // against `remove_client_for_test`.
        remove_client_for_test(client_id(&client));
        let received = handle.join().expect("join");
        assert_eq!(received, vec![1, 2, 3, 4]);
    }

    // A big-endian client sends doubles most-significant byte first.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn write_double_big_endian_encodes_correctly() {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener");
        let port = listener.local_addr().unwrap().port();
        let handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept");
            // Three doubles at eight bytes each.
            let mut buf = [0u8; 24];
            stream.read_exact(&mut buf).expect("read");
            buf
        });

        let stream = TcpStream::connect(("127.0.0.1", port)).expect("connect");
        let client = make_client(stream, 1.0, "big-endian");
        let tensor = Tensor::new(vec![1.5, 2.5, 3.5], vec![1, 3]).unwrap();
        let result = run_write(
            client.clone(),
            Value::Tensor(tensor),
            vec![Value::from("double")],
        )
        .expect("write");
        match result {
            Value::Num(count) => assert_eq!(count, 3.0),
            other => panic!("expected numeric count, got {other:?}"),
        }
        remove_client_for_test(client_id(&client));

        let received = handle.join().expect("join");
        let mut expected = Vec::new();
        extend_f64(&mut expected, 1.5, ByteOrder::Big);
        extend_f64(&mut expected, 2.5, ByteOrder::Big);
        extend_f64(&mut expected, 3.5, ByteOrder::Big);
        assert_eq!(received.to_vec(), expected);
    }

    // The 'char' datatype sends one byte per character for ASCII text.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn write_char_payload_encodes_ascii() {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener");
        let port = listener.local_addr().unwrap().port();
        let handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept");
            let mut buf = Vec::new();
            stream.read_to_end(&mut buf).unwrap_or_default();
            buf
        });

        let stream = TcpStream::connect(("127.0.0.1", port)).expect("connect");
        let client = make_client(stream, 1.0, "little-endian");
        let chars = CharArray::new("RunMat".chars().collect(), 1, 6).unwrap();
        let result = run_write(
            client.clone(),
            Value::CharArray(chars),
            vec![Value::from("char")],
        )
        .expect("write");
        match result {
            Value::Num(count) => assert_eq!(count, 6.0),
            other => panic!("expected numeric count, got {other:?}"),
        }
        remove_client_for_test(client_id(&client));
        let received = handle.join().expect("join");
        assert_eq!(received, b"RunMat");
    }

    // A handle whose `connected` flag is false is rejected with the
    // NotConnected identifier.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn write_errors_when_client_disconnected() {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener");
        let port = listener.local_addr().unwrap().port();
        // The barrier keeps the server side of the connection open until the
        // test has finished its assertions.
        let barrier = Arc::new(Barrier::new(2));
        let thread_barrier = barrier.clone();
        let handle = thread::spawn(move || {
            let (stream, _) = listener.accept().expect("accept");
            thread_barrier.wait();
            drop(stream);
        });

        let stream = TcpStream::connect(("127.0.0.1", port)).expect("connect");
        let client = make_client(stream, 1.0, "little-endian");
        let id = client_id(&client);
        // Mark the client as disconnected directly on the shared handle.
        if let Some(handle_ref) = client_handle(id) {
            if let Ok(mut guard) = handle_ref.lock() {
                guard.connected = false;
            }
        }

        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]).unwrap();
        let err = run_write(client.clone(), Value::Tensor(tensor), Vec::new()).expect_err("write");
        assert_error_identifier(err, WRITE_ERROR_NOT_CONNECTED.identifier.unwrap());

        remove_client_for_test(id);
        barrier.wait();
        handle.join().expect("join");
    }
}