1use runmat_builtins::{IntValue, StructValue, Value};
4use runmat_macros::runtime_builtin;
5use std::io::{self, Write};
6use std::net::TcpStream;
7
8use crate::builtins::common::spec::{
9 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
10 ReductionNaN, ResidencyPolicy, ShapeRequirements,
11};
12use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
13
14use super::accept::{client_handle, configure_stream, CLIENT_HANDLE_FIELD};
15
/// Error identifier used when the first argument is not a usable tcpclient (wrong type,
/// missing/invalid handle field, or a handle that is no longer registered).
const MESSAGE_ID_INVALID_CLIENT: &str = "RunMat:write:InvalidTcpClient";
/// Error identifier used when the data argument cannot be gathered or encoded.
const MESSAGE_ID_INVALID_DATA: &str = "RunMat:write:InvalidData";
/// Error identifier used for a malformed or unsupported datatype argument.
const MESSAGE_ID_INVALID_DATATYPE: &str = "RunMat:write:InvalidDataType";
/// Error identifier used when the client is already flagged as disconnected.
const MESSAGE_ID_NOT_CONNECTED: &str = "RunMat:write:NotConnected";
/// Error identifier used when the socket write times out.
const MESSAGE_ID_TIMEOUT: &str = "RunMat:write:Timeout";
/// Error identifier used when the peer closes the connection mid-write.
const MESSAGE_ID_CONNECTION_CLOSED: &str = "RunMat:write:ConnectionClosed";
/// Error identifier for unexpected failures (stream clone, socket configuration, other I/O).
const MESSAGE_ID_INTERNAL: &str = "RunMat:write:InternalError";
23
// GPU metadata for `write`. There is no device implementation (custom "network" op kind, no
// supported precisions, no provider hooks); inputs are gathered to the host immediately.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::io::net::write")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "write",
    op_kind: GpuOpKind::Custom("network"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    // Socket I/O needs host memory, so GPU-resident values are pulled back before use.
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Socket writes always execute on the host CPU; GPU providers are never consulted.",
};
39
40fn write_error(message_id: &'static str, message: impl Into<String>) -> RuntimeError {
41 build_runtime_error(message)
42 .with_identifier(message_id)
43 .with_builtin("write")
44 .build()
45}
46
47fn write_flow(message_id: &'static str, message: impl Into<String>) -> RuntimeError {
48 write_error(message_id, message)
49}
50
51fn map_write_flow(err: RuntimeError, message_id: &'static str, context: &str) -> RuntimeError {
52 build_runtime_error(format!("{context}: {}", err.message()))
53 .with_identifier(message_id)
54 .with_builtin("write")
55 .with_source(err)
56 .build()
57}
58
// Fusion metadata for `write`: it has no elementwise or reduction form, so the planner never
// fuses it and the call runs eagerly.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::io::net::write")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "write",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Networking builtin executed eagerly on the CPU.",
};
69
70#[runtime_builtin(
71 name = "write",
72 category = "io/net",
73 summary = "Write numeric or text data to a TCP/IP client.",
74 keywords = "write,tcpclient,networking",
75 type_resolver(crate::builtins::io::type_resolvers::write_type),
76 builtin_path = "crate::builtins::io::net::write"
77)]
78async fn write_builtin(
79 client: Value,
80 data: Value,
81 rest: Vec<Value>,
82) -> crate::BuiltinResult<Value> {
83 let client = gather_if_needed_async(&client)
84 .await
85 .map_err(|flow| map_write_flow(flow, MESSAGE_ID_INVALID_CLIENT, "write"))?;
86 let data = gather_if_needed_async(&data)
87 .await
88 .map_err(|flow| map_write_flow(flow, MESSAGE_ID_INVALID_DATA, "write"))?;
89
90 let mut gathered_rest = Vec::with_capacity(rest.len());
91 for value in rest {
92 gathered_rest.push(
93 gather_if_needed_async(&value)
94 .await
95 .map_err(|flow| map_write_flow(flow, MESSAGE_ID_INVALID_DATATYPE, "write"))?,
96 );
97 }
98 let datatype = parse_arguments(&gathered_rest)?;
99
100 let client_struct = match &client {
101 Value::Struct(st) => st,
102 _ => {
103 return Err(write_flow(
104 MESSAGE_ID_INVALID_CLIENT,
105 "write: expected tcpclient struct as first argument",
106 ))
107 }
108 };
109
110 let client_id = extract_client_id(client_struct)?;
111 let handle = client_handle(client_id).ok_or_else(|| {
112 write_flow(
113 MESSAGE_ID_INVALID_CLIENT,
114 "write: tcpclient handle is no longer valid",
115 )
116 })?;
117
118 let (mut stream, timeout, byte_order) = {
119 let guard = handle.lock().unwrap_or_else(|poison| poison.into_inner());
120 if !guard.connected {
121 return Err(write_flow(
122 MESSAGE_ID_NOT_CONNECTED,
123 "write: tcpclient is disconnected",
124 ));
125 }
126 let timeout = guard.timeout;
127 let byte_order = parse_byte_order(&guard.byte_order);
128 let stream = guard.stream.try_clone().map_err(|err| {
129 write_flow(MESSAGE_ID_INTERNAL, format!("write: clone failed ({err})"))
130 })?;
131 (stream, timeout, byte_order)
132 };
133
134 if let Err(err) = configure_stream(&stream, timeout) {
135 return Err(write_flow(
136 MESSAGE_ID_INTERNAL,
137 format!("write: unable to configure socket timeout ({err})"),
138 ));
139 }
140
141 let payload = prepare_payload(&data, datatype, byte_order)?;
142 if payload.bytes.is_empty() {
143 return Ok(Value::Num(0.0));
144 }
145
146 match write_bytes(&mut stream, &payload.bytes) {
147 Ok(_) => Ok(Value::Num(payload.elements as f64)),
148 Err(WriteError::Timeout) => Err(write_flow(
149 MESSAGE_ID_TIMEOUT,
150 "write: timed out while sending data",
151 )),
152 Err(WriteError::ConnectionClosed) => {
153 if let Ok(mut guard) = handle.lock() {
154 guard.connected = false;
155 }
156 Err(write_flow(
157 MESSAGE_ID_CONNECTION_CLOSED,
158 "write: connection closed before all data was sent",
159 ))
160 }
161 Err(WriteError::Io(err)) => Err(write_flow(
162 MESSAGE_ID_INTERNAL,
163 format!("write: socket error ({err})"),
164 )),
165 }
166}
167
/// Wire encodings accepted by the optional `datatype` argument of `write`.
///
/// The numeric variants encode each element at the named width. `Char` sends one byte per
/// element, and `String` sends the text's UTF-8 bytes.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DataType {
    UInt8,
    Int8,
    UInt16,
    Int16,
    UInt32,
    Int32,
    UInt64,
    Int64,
    Single,
    Double,
    Char,
    String,
}
183
184impl DataType {
185 fn default() -> Self {
186 DataType::UInt8
187 }
188
189 fn element_size(self) -> usize {
190 match self {
191 DataType::UInt8 | DataType::Int8 | DataType::Char | DataType::String => 1,
192 DataType::UInt16 | DataType::Int16 => 2,
193 DataType::UInt32 | DataType::Int32 | DataType::Single => 4,
194 DataType::UInt64 | DataType::Int64 | DataType::Double => 8,
195 }
196 }
197}
198
/// Byte order applied to multi-byte numeric encodings, derived from the client's byte-order
/// setting.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ByteOrder {
    Little,
    Big,
}
204
/// Encoded bytes ready to send, plus the element count reported back to the caller.
struct Payload {
    /// Number of logical elements encoded (this is what `write` returns).
    elements: usize,
    /// Raw bytes to put on the wire.
    bytes: Vec<u8>,
}
209
210fn parse_arguments(args: &[Value]) -> BuiltinResult<DataType> {
211 match args.len() {
212 0 => Ok(DataType::default()),
213 1 => parse_datatype(&args[0]),
214 _ => Err(write_flow(
215 MESSAGE_ID_INVALID_DATATYPE,
216 "write: expected at most one datatype argument",
217 )),
218 }
219}
220
221fn parse_datatype(value: &Value) -> BuiltinResult<DataType> {
222 let text = scalar_string(value)?;
223 let lowered = text.trim().to_ascii_lowercase();
224 if lowered.is_empty() {
225 return Err(write_flow(
226 MESSAGE_ID_INVALID_DATATYPE,
227 "write: datatype must not be empty",
228 ));
229 }
230 let dtype = match lowered.as_str() {
231 "uint8" => DataType::UInt8,
232 "int8" => DataType::Int8,
233 "uint16" => DataType::UInt16,
234 "int16" => DataType::Int16,
235 "uint32" => DataType::UInt32,
236 "int32" => DataType::Int32,
237 "uint64" => DataType::UInt64,
238 "int64" => DataType::Int64,
239 "single" => DataType::Single,
240 "double" => DataType::Double,
241 "char" => DataType::Char,
242 "string" => DataType::String,
243 _ => {
244 return Err(write_flow(
245 MESSAGE_ID_INVALID_DATATYPE,
246 format!("write: unsupported datatype '{text}'"),
247 ))
248 }
249 };
250 Ok(dtype)
251}
252
253fn prepare_payload(data: &Value, datatype: DataType, order: ByteOrder) -> BuiltinResult<Payload> {
254 match datatype {
255 DataType::Char => char_payload(data),
256 DataType::String => string_payload(data),
257 _ => numeric_payload(data, datatype, order),
258 }
259}
260
261fn numeric_payload(data: &Value, datatype: DataType, order: ByteOrder) -> BuiltinResult<Payload> {
262 let values = flatten_numeric(data)?;
263 let mut bytes = Vec::with_capacity(values.len() * datatype.element_size());
264 for value in values.iter().copied() {
265 match datatype {
266 DataType::UInt8 => bytes.push(cast_to_u8(value)),
267 DataType::Int8 => bytes.push(cast_to_i8(value) as u8),
268 DataType::UInt16 => extend_u16(&mut bytes, cast_to_u16(value), order),
269 DataType::Int16 => extend_i16(&mut bytes, cast_to_i16(value), order),
270 DataType::UInt32 => extend_u32(&mut bytes, cast_to_u32(value), order),
271 DataType::Int32 => extend_i32(&mut bytes, cast_to_i32(value), order),
272 DataType::UInt64 => extend_u64(&mut bytes, cast_to_u64(value), order),
273 DataType::Int64 => extend_i64(&mut bytes, cast_to_i64(value), order),
274 DataType::Single => extend_f32(&mut bytes, cast_to_f32(value), order),
275 DataType::Double => extend_f64(&mut bytes, value, order),
276 DataType::Char | DataType::String => unreachable!(),
277 }
278 }
279 Ok(Payload {
280 bytes,
281 elements: values.len(),
282 })
283}
284
285fn char_payload(data: &Value) -> BuiltinResult<Payload> {
286 let bytes = match data {
287 Value::CharArray(ca) => ca.data.iter().map(|&ch| (ch as u32 & 0xFF) as u8).collect(),
288 Value::String(text) => text.bytes().collect(),
289 Value::StringArray(sa) => {
290 if sa.data.len() != 1 {
291 return Err(write_flow(
292 MESSAGE_ID_INVALID_DATA,
293 "write: string array input must be scalar when using 'char'",
294 ));
295 }
296 sa.data[0].as_bytes().to_vec()
297 }
298 Value::Tensor(t) => t.data.iter().map(|&v| cast_to_u8(v)).collect::<Vec<u8>>(),
299 Value::Num(n) => vec![cast_to_u8(*n)],
300 Value::Int(iv) => vec![cast_to_u8(iv.to_f64())],
301 Value::Bool(b) => vec![if *b { 1 } else { 0 }],
302 Value::LogicalArray(la) => la
303 .data
304 .iter()
305 .map(|&b| if b != 0 { 1 } else { 0 })
306 .collect(),
307 _ => {
308 return Err(write_flow(
309 MESSAGE_ID_INVALID_DATA,
310 "write: unsupported input for 'char' datatype",
311 ))
312 }
313 };
314 Ok(Payload {
315 elements: bytes.len(),
316 bytes,
317 })
318}
319
320fn string_payload(data: &Value) -> BuiltinResult<Payload> {
321 match data {
322 Value::String(text) => Ok(Payload {
323 elements: 1,
324 bytes: text.as_bytes().to_vec(),
325 }),
326 Value::CharArray(ca) => {
327 let string: String = ca.data.iter().collect();
328 Ok(Payload {
329 elements: 1,
330 bytes: string.into_bytes(),
331 })
332 }
333 Value::StringArray(sa) => {
334 if sa.data.is_empty() {
335 return Ok(Payload {
336 elements: 0,
337 bytes: Vec::new(),
338 });
339 }
340 if sa.data.len() != 1 {
341 return Err(write_flow(
342 MESSAGE_ID_INVALID_DATA,
343 "write: string array input must be scalar when using 'string'",
344 ));
345 }
346 Ok(Payload {
347 elements: 1,
348 bytes: sa.data[0].as_bytes().to_vec(),
349 })
350 }
351 _ => Err(write_flow(
352 MESSAGE_ID_INVALID_DATA,
353 "write: expected text input when using 'string' datatype",
354 )),
355 }
356}
357
358fn flatten_numeric(value: &Value) -> BuiltinResult<Vec<f64>> {
359 match value {
360 Value::Tensor(t) => Ok(t.data.clone()),
361 Value::Num(n) => Ok(vec![*n]),
362 Value::Int(iv) => Ok(vec![iv.to_f64()]),
363 Value::Bool(b) => Ok(vec![if *b { 1.0 } else { 0.0 }]),
364 Value::LogicalArray(la) => Ok(la
365 .data
366 .iter()
367 .map(|&b| if b != 0 { 1.0 } else { 0.0 })
368 .collect()),
369 Value::CharArray(ca) => Ok(ca
370 .data
371 .iter()
372 .map(|&ch| (ch as u32 & 0xFF) as f64)
373 .collect()),
374 Value::String(text) => Ok(text.chars().map(|ch| (ch as u32) as f64).collect()),
375 Value::StringArray(sa) => {
376 if sa.data.len() != 1 {
377 return Err(write_flow(
378 MESSAGE_ID_INVALID_DATA,
379 "write: string array input must be scalar",
380 ));
381 }
382 Ok(sa.data[0].chars().map(|ch| (ch as u32) as f64).collect())
383 }
384 Value::Complex(_, _) | Value::ComplexTensor(_) => Err(write_flow(
385 MESSAGE_ID_INVALID_DATA,
386 "write: complex data is not supported",
387 )),
388 Value::Cell(_)
389 | Value::Struct(_)
390 | Value::Object(_)
391 | Value::HandleObject(_)
392 | Value::Listener(_)
393 | Value::FunctionHandle(_)
394 | Value::Closure(_)
395 | Value::ClassRef(_)
396 | Value::MException(_)
397 | Value::OutputList(_) => Err(write_flow(
398 MESSAGE_ID_INVALID_DATA,
399 "write: unsupported input type",
400 )),
401 Value::GpuTensor(_) => Err(write_flow(
402 MESSAGE_ID_INVALID_DATA,
403 "write: GPU tensor should have been gathered before encoding",
404 )),
405 }
406}
407
// Numeric narrowing helpers.
//
// Each integer cast first rounds half away from zero and maps NaN to 0 (`rounded_scalar`), then
// relies on Rust's float-to-integer `as` semantics. Since Rust 1.45 that conversion saturates:
// values below the target range (including -inf) become MIN, and values above it (including
// +inf) become MAX. This is exactly the clamping the encoders need.

/// Rounds and saturates `value` into `u8`; NaN becomes 0.
fn cast_to_u8(value: f64) -> u8 {
    rounded_scalar(value) as u8
}

/// Rounds and saturates `value` into `i8`; NaN becomes 0.
fn cast_to_i8(value: f64) -> i8 {
    rounded_scalar(value) as i8
}

/// Rounds and saturates `value` into `u16`; NaN becomes 0.
fn cast_to_u16(value: f64) -> u16 {
    rounded_scalar(value) as u16
}

/// Rounds and saturates `value` into `i16`; NaN becomes 0.
fn cast_to_i16(value: f64) -> i16 {
    rounded_scalar(value) as i16
}

/// Rounds and saturates `value` into `u32`; NaN becomes 0.
fn cast_to_u32(value: f64) -> u32 {
    rounded_scalar(value) as u32
}

/// Rounds and saturates `value` into `i32`; NaN becomes 0.
fn cast_to_i32(value: f64) -> i32 {
    rounded_scalar(value) as i32
}

/// Rounds and saturates `value` into `u64`; NaN becomes 0.
fn cast_to_u64(value: f64) -> u64 {
    rounded_scalar(value) as u64
}

/// Rounds and saturates `value` into `i64`; NaN becomes 0.
fn cast_to_i64(value: f64) -> i64 {
    rounded_scalar(value) as i64
}

/// Converts to single precision (nearest representable value; out-of-range values become ±inf).
fn cast_to_f32(value: f64) -> f32 {
    value as f32
}

/// Rounds half away from zero, treating NaN as zero.
fn rounded_scalar(value: f64) -> f64 {
    match value.is_nan() {
        true => 0.0,
        false => value.round(),
    }
}
563
564fn extend_u16(buffer: &mut Vec<u8>, value: u16, order: ByteOrder) {
565 match order {
566 ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
567 ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
568 }
569}
570
571fn extend_i16(buffer: &mut Vec<u8>, value: i16, order: ByteOrder) {
572 match order {
573 ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
574 ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
575 }
576}
577
578fn extend_u32(buffer: &mut Vec<u8>, value: u32, order: ByteOrder) {
579 match order {
580 ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
581 ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
582 }
583}
584
585fn extend_i32(buffer: &mut Vec<u8>, value: i32, order: ByteOrder) {
586 match order {
587 ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
588 ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
589 }
590}
591
592fn extend_u64(buffer: &mut Vec<u8>, value: u64, order: ByteOrder) {
593 match order {
594 ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
595 ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
596 }
597}
598
599fn extend_i64(buffer: &mut Vec<u8>, value: i64, order: ByteOrder) {
600 match order {
601 ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
602 ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
603 }
604}
605
606fn extend_f32(buffer: &mut Vec<u8>, value: f32, order: ByteOrder) {
607 match order {
608 ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
609 ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
610 }
611}
612
613fn extend_f64(buffer: &mut Vec<u8>, value: f64, order: ByteOrder) {
614 match order {
615 ByteOrder::Little => buffer.extend_from_slice(&value.to_le_bytes()),
616 ByteOrder::Big => buffer.extend_from_slice(&value.to_be_bytes()),
617 }
618}
619
620fn parse_byte_order(text: &str) -> ByteOrder {
621 if text.eq_ignore_ascii_case("big-endian") || text.eq_ignore_ascii_case("big endian") {
622 ByteOrder::Big
623 } else {
624 ByteOrder::Little
625 }
626}
627
628fn scalar_string(value: &Value) -> BuiltinResult<String> {
629 match value {
630 Value::String(s) => Ok(s.clone()),
631 Value::CharArray(ca) if ca.rows == 1 => Ok(ca.data.iter().collect()),
632 Value::StringArray(sa) if sa.data.len() == 1 => Ok(sa.data[0].clone()),
633 _ => Err(write_flow(
634 MESSAGE_ID_INVALID_DATATYPE,
635 "write: datatype argument must be a string scalar or character row vector",
636 )),
637 }
638}
639
640fn extract_client_id(struct_value: &StructValue) -> BuiltinResult<u64> {
641 let id_value = struct_value
642 .fields
643 .get(CLIENT_HANDLE_FIELD)
644 .ok_or_else(|| {
645 write_flow(
646 MESSAGE_ID_INVALID_CLIENT,
647 "write: tcpclient struct is missing internal handle",
648 )
649 })?;
650 match id_value {
651 Value::Int(IntValue::U64(id)) => Ok(*id),
652 Value::Int(iv) => Ok(iv.to_i64() as u64),
653 _ => Err(write_flow(
654 MESSAGE_ID_INVALID_CLIENT,
655 "write: tcpclient struct has invalid handle field",
656 )),
657 }
658}
659
/// Failure modes of `write_bytes`; `write_builtin` maps each one to a distinct error identifier.
#[derive(Debug)]
enum WriteError {
    /// The socket reported a timeout (or would block).
    Timeout,
    /// The peer closed, reset, or aborted the connection, or the socket accepted zero bytes.
    ConnectionClosed,
    /// Any other I/O failure.
    Io(io::Error),
}
665
666fn write_bytes(stream: &mut TcpStream, bytes: &[u8]) -> Result<(), WriteError> {
667 let mut offset = 0usize;
668 while offset < bytes.len() {
669 match stream.write(&bytes[offset..]) {
670 Ok(0) => return Err(WriteError::ConnectionClosed),
671 Ok(n) => offset += n,
672 Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
673 Err(err) if is_timeout(&err) => return Err(WriteError::Timeout),
674 Err(err) if is_connection_closed_error(&err) => {
675 return Err(WriteError::ConnectionClosed)
676 }
677 Err(err) => return Err(WriteError::Io(err)),
678 }
679 }
680 Ok(())
681}
682
/// True when `err` reports a timed-out or would-block socket operation.
fn is_timeout(err: &io::Error) -> bool {
    let kind = err.kind();
    kind == io::ErrorKind::TimedOut || kind == io::ErrorKind::WouldBlock
}
689
/// True when `err` indicates the peer is gone (broken pipe, reset, aborted, not connected, EOF).
fn is_connection_closed_error(err: &io::Error) -> bool {
    const CLOSED_KINDS: [io::ErrorKind; 5] = [
        io::ErrorKind::BrokenPipe,
        io::ErrorKind::ConnectionReset,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::NotConnected,
        io::ErrorKind::UnexpectedEof,
    ];
    CLOSED_KINDS.contains(&err.kind())
}
700
// Loopback tests for `write`: each test opens a local TCP listener, registers the connecting
// stream through `insert_client`, and drives `write_builtin` synchronously via `block_on`.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::builtins::io::net::accept::{
        configure_stream, insert_client, remove_client_for_test,
    };
    use runmat_builtins::{CharArray, IntValue, StructValue, Tensor};
    use std::io::Read;
    use std::net::{TcpListener, TcpStream};
    use std::sync::{Arc, Barrier};
    use std::thread;

    // Registers `stream` in the client table and returns a struct holding its id under
    // `CLIENT_HANDLE_FIELD` as a `U64` integer, the shape `extract_client_id` accepts.
    fn make_client(stream: TcpStream, timeout: f64, byte_order: &str) -> Value {
        let peer_addr = stream.peer_addr().expect("peer addr");
        configure_stream(&stream, timeout).expect("configure stream");
        let client_id = insert_client(stream, 0, peer_addr, timeout, byte_order.to_string());
        let mut st = StructValue::new();
        st.fields.insert(
            CLIENT_HANDLE_FIELD.to_string(),
            Value::Int(IntValue::U64(client_id)),
        );
        Value::Struct(st)
    }

    // Reads the handle id back out of a client struct built by `make_client`.
    fn client_id(client: &Value) -> u64 {
        match client {
            Value::Struct(st) => match st.fields.get(CLIENT_HANDLE_FIELD) {
                Some(Value::Int(IntValue::U64(id))) => *id,
                Some(Value::Int(iv)) => iv.to_i64() as u64,
                other => panic!("unexpected id field {other:?}"),
            },
            other => panic!("expected struct, got {other:?}"),
        }
    }

    fn assert_error_identifier(err: RuntimeError, expected: &str) {
        assert_eq!(err.identifier(), Some(expected));
    }

    // Runs the async builtin to completion on the current thread.
    fn run_write(client: Value, data: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
        futures::executor::block_on(write_builtin(client, data, rest))
    }

    // Default datatype: each element goes out as a single uint8 byte and the count is returned.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn write_default_uint8_sends_bytes() {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener");
        let port = listener.local_addr().unwrap().port();
        let handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept");
            let mut received = Vec::new();
            // Reads until EOF; presumably EOF arrives once `remove_client_for_test` drops the
            // registered stream (TODO confirm against accept.rs).
            stream.read_to_end(&mut received).unwrap_or_default();
            received
        });

        let stream = TcpStream::connect(("127.0.0.1", port)).expect("connect");
        let client = make_client(stream, 1.0, "little-endian");
        let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
        let result = run_write(client.clone(), Value::Tensor(tensor), Vec::new()).expect("write");
        match result {
            Value::Num(count) => assert_eq!(count, 4.0),
            other => panic!("expected numeric result, got {other:?}"),
        }
        remove_client_for_test(client_id(&client));
        let received = handle.join().expect("join");
        assert_eq!(received, vec![1, 2, 3, 4]);
    }

    // "double" with a big-endian client: three 8-byte big-endian IEEE-754 values.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn write_double_big_endian_encodes_correctly() {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener");
        let port = listener.local_addr().unwrap().port();
        let handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept");
            // Exactly 3 elements * 8 bytes.
            let mut buf = [0u8; 24];
            stream.read_exact(&mut buf).expect("read");
            buf
        });

        let stream = TcpStream::connect(("127.0.0.1", port)).expect("connect");
        let client = make_client(stream, 1.0, "big-endian");
        let tensor = Tensor::new(vec![1.5, 2.5, 3.5], vec![1, 3]).unwrap();
        let result = run_write(
            client.clone(),
            Value::Tensor(tensor),
            vec![Value::from("double")],
        )
        .expect("write");
        match result {
            Value::Num(count) => assert_eq!(count, 3.0),
            other => panic!("expected numeric count, got {other:?}"),
        }
        remove_client_for_test(client_id(&client));

        let received = handle.join().expect("join");
        let mut expected = Vec::new();
        extend_f64(&mut expected, 1.5, ByteOrder::Big);
        extend_f64(&mut expected, 2.5, ByteOrder::Big);
        extend_f64(&mut expected, 3.5, ByteOrder::Big);
        assert_eq!(received.to_vec(), expected);
    }

    // "char" with an ASCII char array: one byte per character.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn write_char_payload_encodes_ascii() {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener");
        let port = listener.local_addr().unwrap().port();
        let handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().expect("accept");
            let mut buf = Vec::new();
            stream.read_to_end(&mut buf).unwrap_or_default();
            buf
        });

        let stream = TcpStream::connect(("127.0.0.1", port)).expect("connect");
        let client = make_client(stream, 1.0, "little-endian");
        let chars = CharArray::new("RunMat".chars().collect(), 1, 6).unwrap();
        let result = run_write(
            client.clone(),
            Value::CharArray(chars),
            vec![Value::from("char")],
        )
        .expect("write");
        match result {
            Value::Num(count) => assert_eq!(count, 6.0),
            other => panic!("expected numeric count, got {other:?}"),
        }
        remove_client_for_test(client_id(&client));
        let received = handle.join().expect("join");
        assert_eq!(received, b"RunMat");
    }

    // A client flagged as disconnected is rejected with NotConnected before any I/O happens.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn write_errors_when_client_disconnected() {
        let listener = TcpListener::bind("127.0.0.1:0").expect("listener");
        let port = listener.local_addr().unwrap().port();
        // Keeps the server side open until the assertion has run.
        let barrier = Arc::new(Barrier::new(2));
        let thread_barrier = barrier.clone();
        let handle = thread::spawn(move || {
            let (stream, _) = listener.accept().expect("accept");
            thread_barrier.wait();
            drop(stream);
        });

        let stream = TcpStream::connect(("127.0.0.1", port)).expect("connect");
        let client = make_client(stream, 1.0, "little-endian");
        let id = client_id(&client);
        // Flip the flag directly instead of waiting for a real socket failure.
        if let Some(handle_ref) = client_handle(id) {
            if let Ok(mut guard) = handle_ref.lock() {
                guard.connected = false;
            }
        }

        let tensor = Tensor::new(vec![1.0, 2.0, 3.0], vec![1, 3]).unwrap();
        let err = run_write(client.clone(), Value::Tensor(tensor), Vec::new()).expect_err("write");
        assert_error_identifier(err, MESSAGE_ID_NOT_CONNECTED);

        remove_client_for_test(id);
        barrier.wait();
        handle.join().expect("join");
    }
}