1use std::collections::VecDeque;
4use std::time::Duration;
5
6use base64::engine::general_purpose::STANDARD as BASE64_ENGINE;
7use base64::Engine;
8use runmat_builtins::{CellArray, CharArray, StructValue, Tensor, Value};
9use runmat_macros::runtime_builtin;
10use url::Url;
11
12use super::transport::{
13 self, decode_body_as_text, header_value, HttpMethod, HttpRequest, HEADER_CONTENT_TYPE,
14};
15use crate::builtins::common::spec::{
16 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
17 ReductionNaN, ResidencyPolicy, ShapeRequirements,
18};
19use crate::builtins::io::json::jsondecode::decode_json_text;
20use crate::call_builtin_async;
21use crate::{build_runtime_error, gather_if_needed_async, BuiltinResult, RuntimeError};
22
/// Request timeout, in seconds, used by `WebWriteOptions::default()` when no `Timeout`
/// option overrides it.
const DEFAULT_TIMEOUT_SECONDS: f64 = 60.0;
/// User-Agent string `execute_request` falls back to when the `UserAgent` option is
/// missing or blank.
const DEFAULT_USER_AGENT: &str = "RunMat webwrite/0.0";
25
// GPU metadata for the `webwrite` builtin, registered through `register_gpu_spec`.
//
// The spec declares a custom "http-write" operation with no supported precisions and no
// provider hooks; the residency policy is `GatherImmediately`, and the `notes` field records
// that uploads run on the CPU after gathering gpuArray inputs.
// NOTE(review): `clippy::too_many_lines` normally targets functions — confirm the `allow`
// below is still needed on this constant.
#[allow(clippy::too_many_lines)]
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::io::http::webwrite")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "webwrite",
    op_kind: GpuOpKind::Custom("http-write"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "HTTP uploads run on the CPU and gather gpuArray inputs before serialisation.",
};
42
43fn webwrite_error(message: impl Into<String>) -> RuntimeError {
44 build_runtime_error(message)
45 .with_builtin("webwrite")
46 .build()
47}
48
49fn remap_webwrite_flow<F>(err: RuntimeError, message: F) -> RuntimeError
50where
51 F: FnOnce(&RuntimeError) -> String,
52{
53 build_runtime_error(message(&err))
54 .with_builtin("webwrite")
55 .with_source(err)
56 .build()
57}
58
59fn webwrite_flow_with_context(err: RuntimeError) -> RuntimeError {
60 remap_webwrite_flow(err, |err| format!("webwrite: {}", err.message()))
61}
62
// Fusion metadata for the `webwrite` builtin, registered through `register_fusion_spec`.
//
// No elementwise or reduction kernel is provided; the `notes` field states that the builtin
// performs network I/O and terminates fusion graphs.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::io::http::webwrite")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "webwrite",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "webwrite performs network I/O and terminates fusion graphs.",
};
73
74#[runtime_builtin(
75 name = "webwrite",
76 category = "io/http",
77 summary = "Send data to web services using HTTP POST/PUT requests and return the response.",
78 keywords = "webwrite,http post,rest client,json upload,form post",
79 accel = "sink",
80 type_resolver(crate::builtins::io::type_resolvers::webwrite_type),
81 builtin_path = "crate::builtins::io::http::webwrite"
82)]
83async fn webwrite_builtin(url: Value, rest: Vec<Value>) -> crate::BuiltinResult<Value> {
84 let gathered_url = gather_if_needed_async(&url)
85 .await
86 .map_err(webwrite_flow_with_context)?;
87 let url_text = expect_string_scalar(
88 &gathered_url,
89 "webwrite: URL must be a character vector or string scalar",
90 )?;
91 if url_text.trim().is_empty() {
92 return Err(webwrite_error("webwrite: URL must not be empty"));
93 }
94 if rest.is_empty() {
95 return Err(webwrite_error("webwrite: missing data argument"));
96 }
97
98 let mut gathered = Vec::with_capacity(rest.len());
99 for value in rest {
100 gathered.push(
101 gather_if_needed_async(&value)
102 .await
103 .map_err(webwrite_flow_with_context)?,
104 );
105 }
106 let mut queue: VecDeque<Value> = VecDeque::from(gathered);
107 let data_value = queue
108 .pop_front()
109 .ok_or_else(|| webwrite_error("webwrite: missing data argument"))?;
110
111 let (options, query_params) = parse_arguments(queue)?;
112 let body = prepare_request_body(data_value, &options).await?;
113 execute_request(&url_text, options, &query_params, body)
114}
115
116fn parse_arguments(
117 mut queue: VecDeque<Value>,
118) -> BuiltinResult<(WebWriteOptions, Vec<(String, String)>)> {
119 let mut options = WebWriteOptions::default();
120 let mut query_params = Vec::new();
121
122 if matches!(queue.front(), Some(Value::Struct(_))) {
123 if let Some(Value::Struct(struct_value)) = queue.pop_front() {
124 process_struct_fields(&struct_value, &mut options, &mut query_params)?;
125 }
126 } else if matches!(queue.front(), Some(Value::Cell(_))) {
127 if let Some(Value::Cell(cell)) = queue.pop_front() {
128 append_query_from_cell(&cell, &mut query_params)?;
129 }
130 }
131
132 while let Some(name_value) = queue.pop_front() {
133 let name = expect_string_scalar(
134 &name_value,
135 "webwrite: parameter names must be character vectors or strings",
136 )?;
137 let value = queue
138 .pop_front()
139 .ok_or_else(|| webwrite_error("webwrite: missing value for name-value argument"))?;
140 process_name_value_pair(&name, &value, &mut options, &mut query_params)?;
141 }
142
143 Ok((options, query_params))
144}
145
146fn process_struct_fields(
147 struct_value: &StructValue,
148 options: &mut WebWriteOptions,
149 query_params: &mut Vec<(String, String)>,
150) -> BuiltinResult<()> {
151 for (key, value) in &struct_value.fields {
152 process_name_value_pair(key, value, options, query_params)?;
153 }
154 Ok(())
155}
156
157fn process_name_value_pair(
158 name: &str,
159 value: &Value,
160 options: &mut WebWriteOptions,
161 query_params: &mut Vec<(String, String)>,
162) -> BuiltinResult<()> {
163 let lower = name.to_ascii_lowercase();
164 match lower.as_str() {
165 "contenttype" => {
166 let ct = parse_content_type(value)?;
167 options.content_type = ct;
168 Ok(())
169 }
170 "mediatype" => {
171 let media = expect_string_scalar(
172 value,
173 "webwrite: MediaType must be a character vector or string scalar",
174 )?;
175 let trimmed = media.trim();
176 if trimmed.is_empty() || trimmed.eq_ignore_ascii_case("auto") {
177 options.media_type = None;
178 options.request_format = RequestFormat::Auto;
179 options.request_format_explicit = false;
180 } else {
181 options.media_type = Some(media.clone());
182 options.request_format = infer_request_format(&media);
183 options.request_format_explicit = true;
184 }
185 Ok(())
186 }
187 "timeout" => {
188 options.timeout = parse_timeout(value)?;
189 Ok(())
190 }
191 "headerfields" => {
192 let headers = parse_header_fields(value)?;
193 options.headers.extend(headers);
194 Ok(())
195 }
196 "useragent" => {
197 options.user_agent = Some(expect_string_scalar(
198 value,
199 "webwrite: UserAgent must be a character vector or string scalar",
200 )?);
201 Ok(())
202 }
203 "username" => {
204 options.username = Some(expect_string_scalar(
205 value,
206 "webwrite: Username must be a character vector or string scalar",
207 )?);
208 Ok(())
209 }
210 "password" => {
211 options.password = Some(expect_string_scalar(
212 value,
213 "webwrite: Password must be a character vector or string scalar",
214 )?);
215 Ok(())
216 }
217 "requestmethod" => {
218 options.method = parse_request_method(value)?;
219 Ok(())
220 }
221 "queryparameters" => append_query_from_value(value, query_params),
222 _ => {
223 let param_value = value_to_query_string(value, name)?;
224 query_params.push((name.to_string(), param_value));
225 Ok(())
226 }
227 }
228}
229
230fn execute_request(
231 url_text: &str,
232 options: WebWriteOptions,
233 query_params: &[(String, String)],
234 body: PreparedBody,
235) -> BuiltinResult<Value> {
236 let username_present = options
237 .username
238 .as_ref()
239 .map(|s| !s.is_empty())
240 .unwrap_or(false);
241 let password_present = options
242 .password
243 .as_ref()
244 .map(|s| !s.is_empty())
245 .unwrap_or(false);
246 if password_present && !username_present {
247 return Err(webwrite_error(
248 "webwrite: Password requires a Username option",
249 ));
250 }
251
252 let mut url = Url::parse(url_text).map_err(|err| {
253 build_runtime_error(format!("webwrite: invalid URL '{url_text}': {err}"))
254 .with_builtin("webwrite")
255 .with_source(err)
256 .build()
257 })?;
258 if !query_params.is_empty() {
259 {
260 let mut pairs = url.query_pairs_mut();
261 for (name, value) in query_params {
262 pairs.append_pair(name, value);
263 }
264 }
265 }
266 let user_agent = options
267 .user_agent
268 .as_deref()
269 .filter(|ua| !ua.trim().is_empty())
270 .unwrap_or(DEFAULT_USER_AGENT)
271 .to_string();
272
273 let mut headers = options.headers.clone();
274 let has_auth_header = headers
275 .iter()
276 .any(|(name, _)| name.eq_ignore_ascii_case("authorization"));
277 if !has_auth_header {
278 if let Some(username) = options.username.as_ref().filter(|s| !s.is_empty()) {
279 let password = options.password.clone().unwrap_or_default();
280 let token = BASE64_ENGINE.encode(format!("{username}:{password}"));
281 headers.push(("Authorization".to_string(), format!("Basic {token}")));
282 }
283 }
284
285 let has_ct_header = headers
286 .iter()
287 .any(|(name, _)| name.eq_ignore_ascii_case("content-type"));
288 if !has_ct_header {
289 if let Some(ct) = &body.content_type {
290 headers.push(("Content-Type".to_string(), ct.clone()));
291 }
292 }
293
294 let request = HttpRequest {
295 url,
296 method: options.method,
297 headers,
298 body: Some(body.bytes),
299 timeout: options.timeout,
300 user_agent,
301 };
302
303 let response = transport::send_request(&request).map_err(|err| {
304 build_runtime_error(err.message_with_prefix("webwrite"))
305 .with_builtin("webwrite")
306 .with_source(err)
307 .build()
308 })?;
309
310 let header_content_type =
311 header_value(&response.headers, HEADER_CONTENT_TYPE).map(|value| value.to_string());
312 let resolved = options.resolve_content_type(header_content_type.as_deref());
313
314 match resolved {
315 ResolvedContentType::Json => {
316 let body_text = decode_body_as_text(&response.body, header_content_type.as_deref());
317 let value = decode_json_text(&body_text).map_err(map_json_error)?;
318 Ok(value)
319 }
320 ResolvedContentType::Text => {
321 let body_text = decode_body_as_text(&response.body, header_content_type.as_deref());
322 Ok(Value::CharArray(CharArray::new_row(&body_text)))
323 }
324 ResolvedContentType::Binary => {
325 let data: Vec<f64> = response.body.iter().map(|b| f64::from(*b)).collect();
326 let cols = data.len();
327 let tensor = Tensor::new(data, vec![1, cols])
328 .map_err(|err| webwrite_error(format!("webwrite: {err}")))?;
329 Ok(Value::Tensor(tensor))
330 }
331 }
332}
333
334async fn prepare_request_body(
335 data: Value,
336 options: &WebWriteOptions,
337) -> BuiltinResult<PreparedBody> {
338 let format = match options.request_format {
339 RequestFormat::Auto => guess_request_format(&data),
340 set => set,
341 };
342 let content_type = options
343 .media_type
344 .clone()
345 .or_else(|| default_content_type_for(format));
346 let bytes = match format {
347 RequestFormat::Form => encode_form_payload(&data)?,
348 RequestFormat::Json => encode_json_payload(&data).await?,
349 RequestFormat::Text => encode_text_payload(&data)?,
350 RequestFormat::Binary => encode_binary_payload(&data)?,
351 RequestFormat::Auto => encode_json_payload(&data).await?,
352 };
353 Ok(PreparedBody {
354 bytes,
355 content_type,
356 })
357}
358
359fn encode_form_payload(value: &Value) -> BuiltinResult<Vec<u8>> {
360 let mut pairs = Vec::new();
361 match value {
362 Value::Struct(struct_value) => {
363 for (key, val) in &struct_value.fields {
364 let text = value_to_query_string(val, key)?;
365 pairs.push((key.clone(), text));
366 }
367 }
368 Value::Cell(cell) => {
369 append_query_from_cell(cell, &mut pairs)?;
370 }
371 Value::CharArray(_)
372 | Value::String(_)
373 | Value::Num(_)
374 | Value::Int(_)
375 | Value::Tensor(_) => {
376 let text = scalar_to_string(value)?;
378 pairs.push(("data".to_string(), text));
379 }
380 _ => {
381 return Err(webwrite_error(
382 "webwrite: form payloads must be structs, two-column cell arrays, or scalars",
383 ))
384 }
385 }
386
387 let encoded = encode_form_pairs(&pairs);
388 Ok(encoded.into_bytes())
389}
390
391fn encode_form_pairs(pairs: &[(String, String)]) -> String {
392 let mut result = String::new();
393 for (idx, (name, value)) in pairs.iter().enumerate() {
394 if idx > 0 {
395 result.push('&');
396 }
397 result.push_str(&url_encode_component(name));
398 result.push('=');
399 result.push_str(&url_encode_component(value));
400 }
401 result
402}
403
404fn url_encode_component(input: &str) -> String {
405 let mut out = String::new();
406 for byte in input.bytes() {
407 match byte {
408 b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'*' => {
409 out.push(byte as char);
410 }
411 b' ' => out.push('+'),
412 _ => {
413 out.push('%');
414 out.push(hex_digit(byte >> 4));
415 out.push(hex_digit(byte & 0xF));
416 }
417 }
418 }
419 out
420}
421
/// Returns the uppercase hexadecimal digit for a 4-bit value.
///
/// # Panics
/// Panics if `nibble` is greater than 15; callers only pass masked or shifted bytes.
fn hex_digit(nibble: u8) -> char {
    match char::from_digit(u32::from(nibble), 16) {
        Some(digit) => digit.to_ascii_uppercase(),
        None => unreachable!(),
    }
}
429
430async fn encode_json_payload(value: &Value) -> BuiltinResult<Vec<u8>> {
431 let encoded = call_builtin_async("jsonencode", std::slice::from_ref(value))
432 .await
433 .map_err(|flow| remap_webwrite_flow(flow, |err| format!("webwrite: {}", err.message())))?;
434 let text = expect_string_scalar(
435 &encoded,
436 "webwrite: jsonencode returned unexpected value; expected text scalar",
437 )?;
438 Ok(text.into_bytes())
439}
440
441fn encode_text_payload(value: &Value) -> BuiltinResult<Vec<u8>> {
442 let text = scalar_to_string(value)?;
443 Ok(text.into_bytes())
444}
445
446fn encode_binary_payload(value: &Value) -> BuiltinResult<Vec<u8>> {
447 match value {
448 Value::Tensor(tensor) => tensor_f64_to_bytes(tensor),
449 Value::Num(n) => Ok(vec![float_to_byte(*n)?]),
450 Value::Int(i) => Ok(vec![int_to_byte(i.to_i64())?]),
451 Value::Bool(b) => Ok(vec![if *b { 1 } else { 0 }]),
452 Value::LogicalArray(array) => Ok(array.data.clone()),
453 Value::CharArray(ca) => {
454 let mut bytes = Vec::with_capacity(ca.data.len());
455 for ch in &ca.data {
456 let code = *ch as u32;
457 if code > 0xFF {
458 return Err(webwrite_error(
459 "webwrite: character codes exceed 255 for binary payload",
460 ));
461 }
462 bytes.push(code as u8);
463 }
464 Ok(bytes)
465 }
466 Value::String(s) => Ok(s.as_bytes().to_vec()),
467 Value::StringArray(sa) => {
468 if sa.data.len() == 1 {
469 Ok(sa.data[0].as_bytes().to_vec())
470 } else {
471 Err(webwrite_error(
472 "webwrite: binary payload string arrays must be scalar",
473 ))
474 }
475 }
476 _ => Err(webwrite_error(
477 "webwrite: unsupported value for binary payload",
478 )),
479 }
480}
481
482fn tensor_f64_to_bytes(tensor: &Tensor) -> BuiltinResult<Vec<u8>> {
483 let mut bytes = Vec::with_capacity(tensor.data.len());
484 for value in &tensor.data {
485 bytes.push(float_to_byte(*value)?);
486 }
487 Ok(bytes)
488}
489
490fn float_to_byte(value: f64) -> BuiltinResult<u8> {
491 if !value.is_finite() {
492 return Err(webwrite_error(
493 "webwrite: binary payload values must be finite",
494 ));
495 }
496 let rounded = value.round();
497 if (value - rounded).abs() > 1e-9 {
498 return Err(webwrite_error(
499 "webwrite: binary payload values must be integers in 0..255",
500 ));
501 }
502 let int_val = rounded as i64;
503 int_to_byte(int_val)
504}
505
506fn int_to_byte(value: i64) -> BuiltinResult<u8> {
507 if !(0..=255).contains(&value) {
508 return Err(webwrite_error(
509 "webwrite: binary payload values must be in the range 0..255",
510 ));
511 }
512
513 Ok(value as u8)
514}
515
516fn append_query_from_value(
517 value: &Value,
518 query_params: &mut Vec<(String, String)>,
519) -> BuiltinResult<()> {
520 match value {
521 Value::Struct(struct_value) => {
522 for (key, val) in &struct_value.fields {
523 let text = value_to_query_string(val, key)?;
524 query_params.push((key.clone(), text));
525 }
526 Ok(())
527 }
528 Value::Cell(cell) => append_query_from_cell(cell, query_params),
529 _ => Err(webwrite_error(
530 "webwrite: QueryParameters must be a struct or cell array",
531 )),
532 }
533}
534
535fn append_query_from_cell(
536 cell: &CellArray,
537 query_params: &mut Vec<(String, String)>,
538) -> BuiltinResult<()> {
539 if cell.cols != 2 {
540 return Err(webwrite_error(
541 "webwrite: cell array of query parameters must have two columns",
542 ));
543 }
544 for row in 0..cell.rows {
545 let name_value = cell
546 .get(row, 0)
547 .map_err(|err| webwrite_error(format!("webwrite: {err}")))?;
548 let value_value = cell
549 .get(row, 1)
550 .map_err(|err| webwrite_error(format!("webwrite: {err}")))?;
551 let name = expect_string_scalar(
552 &name_value,
553 "webwrite: query parameter names must be text scalars",
554 )?;
555 let text = value_to_query_string(&value_value, &name)?;
556 query_params.push((name, text));
557 }
558 Ok(())
559}
560
561fn parse_content_type(value: &Value) -> BuiltinResult<ContentTypeHint> {
562 let text = expect_string_scalar(
563 value,
564 "webwrite: ContentType must be a character vector or string scalar",
565 )?;
566 let lower = text.trim().to_ascii_lowercase();
567 match lower.as_str() {
568 "auto" => Ok(ContentTypeHint::Auto),
569 "json" => Ok(ContentTypeHint::Json),
570 "text" => Ok(ContentTypeHint::Text),
571 "binary" => Ok(ContentTypeHint::Binary),
572 _ => Err(webwrite_error(
573 "webwrite: ContentType must be 'auto', 'json', 'text', or 'binary'",
574 )),
575 }
576}
577
578fn parse_timeout(value: &Value) -> BuiltinResult<Duration> {
579 let seconds = numeric_scalar(
580 value,
581 "webwrite: Timeout must be a finite, non-negative scalar numeric value",
582 )?;
583 if !seconds.is_finite() || seconds < 0.0 {
584 return Err(webwrite_error(
585 "webwrite: Timeout must be a finite, non-negative scalar numeric value",
586 ));
587 }
588 Ok(Duration::from_secs_f64(seconds))
589}
590
591fn parse_request_method(value: &Value) -> BuiltinResult<HttpMethod> {
592 let text = expect_string_scalar(
593 value,
594 "webwrite: RequestMethod must be a character vector or string scalar",
595 )?;
596 match text.trim().to_ascii_lowercase().as_str() {
597 "auto" => Ok(HttpMethod::Post),
598 "post" => Ok(HttpMethod::Post),
599 "put" => Ok(HttpMethod::Put),
600 "patch" => Ok(HttpMethod::Patch),
601 "delete" => Ok(HttpMethod::Delete),
602 other => Err(webwrite_error(format!(
603 "webwrite: unsupported RequestMethod '{}'; expected auto, post, put, patch, or delete",
604 other
605 ))),
606 }
607}
608
609fn parse_header_fields(value: &Value) -> BuiltinResult<Vec<(String, String)>> {
610 match value {
611 Value::Struct(struct_value) => {
612 let mut headers = Vec::with_capacity(struct_value.fields.len());
613 for (key, val) in &struct_value.fields {
614 let header_value = expect_string_scalar(
615 val,
616 "webwrite: header values must be character vectors or string scalars",
617 )?;
618 headers.push((key.clone(), header_value));
619 }
620 Ok(headers)
621 }
622 Value::Cell(cell) => {
623 if cell.cols != 2 {
624 return Err(webwrite_error(
625 "webwrite: HeaderFields cell array must have exactly two columns",
626 ));
627 }
628 let mut headers = Vec::with_capacity(cell.rows);
629 for row in 0..cell.rows {
630 let name = cell
631 .get(row, 0)
632 .map_err(|err| webwrite_error(format!("webwrite: {err}")))?;
633 let value = cell
634 .get(row, 1)
635 .map_err(|err| webwrite_error(format!("webwrite: {err}")))?;
636 let header_name = expect_string_scalar(
637 &name,
638 "webwrite: header names must be character vectors or string scalars",
639 )?;
640 if header_name.trim().is_empty() {
641 return Err(webwrite_error("webwrite: header names must not be empty"));
642 }
643 let header_value = expect_string_scalar(
644 &value,
645 "webwrite: header values must be character vectors or string scalars",
646 )?;
647 headers.push((header_name, header_value));
648 }
649 Ok(headers)
650 }
651 _ => Err(webwrite_error(
652 "webwrite: HeaderFields must be a struct or two-column cell array",
653 )),
654 }
655}
656
657fn map_json_error(err: RuntimeError) -> RuntimeError {
658 let message = if let Some(rest) = err.message().strip_prefix("jsondecode: ") {
659 format!("webwrite: failed to parse JSON response ({rest})")
660 } else {
661 format!(
662 "webwrite: failed to parse JSON response ({})",
663 err.message()
664 )
665 };
666 build_runtime_error(message)
667 .with_builtin("webwrite")
668 .with_source(err)
669 .build()
670}
671
672fn numeric_scalar(value: &Value, context: &str) -> BuiltinResult<f64> {
673 match value {
674 Value::Num(n) => Ok(*n),
675 Value::Int(i) => Ok(i.to_f64()),
676 Value::Tensor(tensor) => {
677 if tensor.data.len() == 1 {
678 Ok(tensor.data[0])
679 } else {
680 Err(webwrite_error(context))
681 }
682 }
683 _ => Err(webwrite_error(context)),
684 }
685}
686
687fn scalar_to_string(value: &Value) -> BuiltinResult<String> {
688 match value {
689 Value::String(s) => Ok(s.clone()),
690 Value::CharArray(ca) if ca.rows == 1 => Ok(ca.data.iter().collect()),
691 Value::StringArray(sa) if sa.data.len() == 1 => Ok(sa.data[0].clone()),
692 Value::Num(n) => Ok(format!("{}", n)),
693 Value::Int(i) => Ok(i.to_i64().to_string()),
694 Value::Bool(b) => Ok(if *b { "true".into() } else { "false".into() }),
695 Value::Tensor(tensor) => {
696 if tensor.data.len() == 1 {
697 Ok(format!("{}", tensor.data[0]))
698 } else {
699 Err(webwrite_error(
700 "webwrite: expected scalar value for text payload",
701 ))
702 }
703 }
704 Value::LogicalArray(array) => {
705 if array.len() == 1 {
706 Ok(if array.data[0] != 0 {
707 "true".into()
708 } else {
709 "false".into()
710 })
711 } else {
712 Err(webwrite_error(
713 "webwrite: expected scalar value for text payload",
714 ))
715 }
716 }
717 _ => Err(webwrite_error(
718 "webwrite: unsupported value type for text payload",
719 )),
720 }
721}
722
723fn expect_string_scalar(value: &Value, context: &str) -> BuiltinResult<String> {
724 match value {
725 Value::String(s) => Ok(s.clone()),
726 Value::CharArray(ca) if ca.rows == 1 => Ok(ca.data.iter().collect()),
727 Value::StringArray(sa) if sa.data.len() == 1 => Ok(sa.data[0].clone()),
728 _ => Err(webwrite_error(context)),
729 }
730}
731
732fn value_to_query_string(value: &Value, name: &str) -> BuiltinResult<String> {
733 match value {
734 Value::String(s) => Ok(s.clone()),
735 Value::CharArray(ca) if ca.rows == 1 => Ok(ca.data.iter().collect()),
736 Value::StringArray(sa) if sa.data.len() == 1 => Ok(sa.data[0].clone()),
737 Value::Num(n) => Ok(format!("{}", n)),
738 Value::Int(i) => Ok(i.to_i64().to_string()),
739 Value::Bool(b) => Ok(if *b { "true".into() } else { "false".into() }),
740 Value::Tensor(tensor) => {
741 if tensor.data.len() == 1 {
742 Ok(format!("{}", tensor.data[0]))
743 } else {
744 Err(webwrite_error(format!(
745 "webwrite: query parameter '{}' must be scalar",
746 name
747 )))
748 }
749 }
750 Value::LogicalArray(array) => {
751 if array.len() == 1 {
752 Ok(if array.data[0] != 0 {
753 "true".into()
754 } else {
755 "false".into()
756 })
757 } else {
758 Err(webwrite_error(format!(
759 "webwrite: query parameter '{}' must be scalar",
760 name
761 )))
762 }
763 }
764 _ => Err(webwrite_error(format!(
765 "webwrite: unsupported value type for query parameter '{}'",
766 name
767 ))),
768 }
769}
770
771fn guess_request_format(value: &Value) -> RequestFormat {
772 match value {
773 Value::Struct(_) => RequestFormat::Form,
774 Value::Cell(cell) if cell.cols == 2 => RequestFormat::Form,
775 Value::CharArray(ca) if ca.rows == 1 => RequestFormat::Text,
776 Value::String(_) => RequestFormat::Text,
777 Value::StringArray(sa) => {
778 if sa.data.len() == 1 {
779 RequestFormat::Text
780 } else {
781 RequestFormat::Json
782 }
783 }
784 Value::Tensor(_) | Value::LogicalArray(_) => RequestFormat::Json,
785 Value::Num(_) | Value::Int(_) | Value::Bool(_) => RequestFormat::Json,
786 _ => RequestFormat::Json,
787 }
788}
789
790fn infer_request_format(media_type: &str) -> RequestFormat {
791 let lower = media_type.trim().to_ascii_lowercase();
792 if lower.contains("json") {
793 RequestFormat::Json
794 } else if lower.starts_with("text/") || lower.contains("xml") {
795 RequestFormat::Text
796 } else if lower == "application/x-www-form-urlencoded" {
797 RequestFormat::Form
798 } else {
799 RequestFormat::Binary
800 }
801}
802
803fn default_content_type_for(format: RequestFormat) -> Option<String> {
804 match format {
805 RequestFormat::Form => Some("application/x-www-form-urlencoded".to_string()),
806 RequestFormat::Json => Some("application/json".to_string()),
807 RequestFormat::Text => Some("text/plain; charset=utf-8".to_string()),
808 RequestFormat::Binary => Some("application/octet-stream".to_string()),
809 RequestFormat::Auto => None,
810 }
811}
812
/// Encoded request body produced by `prepare_request_body`.
#[derive(Clone, Debug)]
struct PreparedBody {
    /// Bytes sent as the request body.
    bytes: Vec<u8>,
    /// `Content-Type` to send with the body; `execute_request` adds it only when the caller
    /// did not already supply a `Content-Type` header.
    content_type: Option<String>,
}
818
/// Value of the `ContentType` option, as parsed by `parse_content_type`.
///
/// Controls how the response body is decoded; `Auto` defers to the response's
/// `Content-Type` header.
#[derive(Clone, Copy, Debug)]
enum ContentTypeHint {
    /// Decide from the response `Content-Type` header.
    Auto,
    /// Always decode the response as text.
    Text,
    /// Always decode the response as JSON.
    Json,
    /// Always return the response as raw byte values.
    Binary,
}
826
/// Final decision on how `execute_request` decodes the response body.
#[derive(Clone, Copy, Debug)]
enum ResolvedContentType {
    /// Return the body as a character row vector.
    Text,
    /// Parse the body as JSON.
    Json,
    /// Return the body as a 1xN tensor of byte values.
    Binary,
}
833
/// Encoding applied to the request payload.
#[derive(Clone, Copy, Debug)]
enum RequestFormat {
    /// No explicit format; `prepare_request_body` guesses one from the payload value.
    Auto,
    /// `application/x-www-form-urlencoded` name/value pairs.
    Form,
    /// JSON text produced by the `jsonencode` builtin.
    Json,
    /// Plain text rendering of a scalar value.
    Text,
    /// Raw bytes.
    Binary,
}
842
/// Options collected from the struct and name-value arguments of a `webwrite` call.
#[derive(Clone, Debug)]
struct WebWriteOptions {
    /// How the response should be decoded (`ContentType` option).
    content_type: ContentTypeHint,
    /// Request timeout (`Timeout` option).
    timeout: Duration,
    /// Extra request headers (`HeaderFields` option), in the order they were supplied.
    headers: Vec<(String, String)>,
    /// Custom User-Agent (`UserAgent` option); a blank value falls back to the default.
    user_agent: Option<String>,
    /// Username for HTTP Basic authentication (`Username` option).
    username: Option<String>,
    /// Password for HTTP Basic authentication (`Password` option).
    password: Option<String>,
    /// HTTP method (`RequestMethod` option).
    method: HttpMethod,
    /// Encoding of the request body, derived from the `MediaType` option.
    request_format: RequestFormat,
    /// Set to `true` when `MediaType` names a concrete media type.
    /// NOTE(review): not read anywhere in the visible part of this file — confirm it is used.
    request_format_explicit: bool,
    /// Media type sent as the request `Content-Type` when given explicitly.
    media_type: Option<String>,
}
856
857impl Default for WebWriteOptions {
858 fn default() -> Self {
859 Self {
860 content_type: ContentTypeHint::Auto,
861 timeout: Duration::from_secs_f64(DEFAULT_TIMEOUT_SECONDS),
862 headers: Vec::new(),
863 user_agent: None,
864 username: None,
865 password: None,
866 method: HttpMethod::Post,
867 request_format: RequestFormat::Auto,
868 request_format_explicit: false,
869 media_type: None,
870 }
871 }
872}
873
874impl WebWriteOptions {
875 fn resolve_content_type(&self, header: Option<&str>) -> ResolvedContentType {
876 match self.content_type {
877 ContentTypeHint::Json => ResolvedContentType::Json,
878 ContentTypeHint::Text => ResolvedContentType::Text,
879 ContentTypeHint::Binary => ResolvedContentType::Binary,
880 ContentTypeHint::Auto => infer_response_content_type(header),
881 }
882 }
883}
884
885fn infer_response_content_type(header: Option<&str>) -> ResolvedContentType {
886 if let Some(raw) = header {
887 let mime = raw
888 .split(';')
889 .next()
890 .map(|part| part.trim().to_ascii_lowercase())
891 .unwrap_or_default();
892 if mime == "application/json" || mime == "text/json" || mime.ends_with("+json") {
893 ResolvedContentType::Json
894 } else if mime.starts_with("text/")
895 || mime == "application/xml"
896 || mime.ends_with("+xml")
897 || mime == "application/xhtml+xml"
898 || mime == "application/javascript"
899 || mime == "application/x-www-form-urlencoded"
900 {
901 ResolvedContentType::Text
902 } else {
903 ResolvedContentType::Binary
904 }
905 } else {
906 ResolvedContentType::Text
907 }
908}
909
910#[cfg(test)]
911pub(crate) mod tests {
912 use super::*;
913 use std::io::{Read, Write};
914 use std::net::{TcpListener, TcpStream};
915 use std::sync::mpsc;
916 use std::thread;
917
918 fn spawn_server<F>(handler: F) -> String
919 where
920 F: FnOnce(TcpStream) + Send + 'static,
921 {
922 let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
923 let addr = listener.local_addr().unwrap();
924 thread::spawn(move || {
925 if let Ok((stream, _)) = listener.accept() {
926 handler(stream);
927 }
928 });
929 format!("http://{}", addr)
930 }
931
    /// Reads one HTTP request from `stream`.
    ///
    /// Returns the header section (up to and including the blank line) as lossily decoded
    /// text, and the body bytes. The body is read until `Content-Length` bytes have arrived
    /// or the stream ends or fails.
    fn read_request(stream: &mut TcpStream) -> (String, Vec<u8>) {
        let mut buffer = Vec::new();
        let mut tmp = [0u8; 512];
        let mut header_end = None;
        // Phase 1: read until the end-of-headers marker shows up in the accumulated bytes.
        loop {
            match stream.read(&mut tmp) {
                Ok(0) => break,
                Ok(n) => {
                    buffer.extend_from_slice(&tmp[..n]);
                    if let Some(idx) = buffer.windows(4).position(|w| w == b"\r\n\r\n") {
                        header_end = Some(idx + 4);
                        break;
                    }
                }
                Err(_) => break,
            }
        }
        // Without a marker, everything read so far is treated as headers.
        let header_end = header_end.unwrap_or(buffer.len());
        let headers = String::from_utf8_lossy(&buffer[..header_end]).to_string();
        // A missing or unparsable Content-Length counts as an empty body.
        let content_length = headers
            .lines()
            .find_map(|line| {
                let mut parts = line.splitn(2, ':');
                let name = parts.next()?.trim();
                let value = parts.next()?.trim();
                if name.eq_ignore_ascii_case("content-length") {
                    value.parse::<usize>().ok()
                } else {
                    None
                }
            })
            .unwrap_or(0);
        // Phase 2: body bytes already read with the headers come first, then the remainder.
        let mut body = buffer[header_end..].to_vec();
        while body.len() < content_length {
            match stream.read(&mut tmp) {
                Ok(0) => break,
                Ok(n) => body.extend_from_slice(&tmp[..n]),
                Err(_) => break,
            }
        }
        (headers, body)
    }
974
975 fn respond_with(mut stream: TcpStream, content_type: &str, body: &[u8]) {
976 let response = format!(
977 "HTTP/1.1 200 OK\r\nContent-Length: {}\r\nContent-Type: {}\r\nConnection: close\r\n\r\n",
978 body.len(),
979 content_type
980 );
981 let _ = stream.write_all(response.as_bytes());
982 let _ = stream.write_all(body);
983 }
984
985 fn run_webwrite(url: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
986 futures::executor::block_on(webwrite_builtin(url, rest))
987 }
988
    // A struct payload without a MediaType must be POSTed as form data, and the JSON reply
    // must be decoded into a struct because ContentType is set to "json".
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn webwrite_posts_form_data_by_default() {
        let payload = {
            let mut st = StructValue::new();
            st.fields.insert("name".to_string(), Value::from("Ada"));
            st.fields.insert("score".to_string(), Value::Num(42.0));
            st
        };
        let opts = {
            let mut st = StructValue::new();
            st.fields
                .insert("ContentType".to_string(), Value::from("json"));
            st
        };

        // The server thread forwards the captured request before it replies.
        let (tx, rx) = mpsc::channel();
        let url = spawn_server(move |mut stream| {
            let (headers, body) = read_request(&mut stream);
            tx.send((headers, body)).unwrap();
            respond_with(
                stream,
                "application/json",
                br#"{"status":"ok","received":true}"#,
            );
        });

        let result = run_webwrite(
            Value::from(url),
            vec![Value::Struct(payload), Value::Struct(opts)],
        )
        .expect("webwrite");

        // Request side: method, content type and form-encoded fields.
        let (headers, body) = rx.recv().expect("request captured");
        assert!(headers.starts_with("POST "));
        let headers_lower = headers.to_ascii_lowercase();
        assert!(headers_lower.contains("content-type: application/x-www-form-urlencoded"));
        let body_text = String::from_utf8(body).expect("utf8 body");
        assert!(body_text.contains("name=Ada"));
        assert!(body_text.contains("score=42"));

        // Response side: the JSON object arrives as a struct.
        match result {
            Value::Struct(reply) => {
                assert!(matches!(
                    reply.fields.get("received"),
                    Some(Value::Bool(true))
                ));
            }
            other => panic!("expected struct response, got {other:?}"),
        }
    }
1040
    // A JSON MediaType must make a struct payload go out as a JSON body with a JSON
    // Content-Type, and the JSON reply must be decoded into a struct.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn webwrite_sends_json_when_media_type_json() {
        let payload = {
            let mut st = StructValue::new();
            st.fields.insert("title".to_string(), Value::from("RunMat"));
            st.fields.insert("stars".to_string(), Value::Num(5.0));
            st
        };
        let opts = {
            let mut st = StructValue::new();
            st.fields.insert(
                "MediaType".to_string(),
                Value::from("application/json; charset=utf-8"),
            );
            st.fields
                .insert("ContentType".to_string(), Value::from("json"));
            st
        };

        // The server thread forwards the captured request before it replies.
        let (tx, rx) = mpsc::channel();
        let url = spawn_server(move |mut stream| {
            let (headers, body) = read_request(&mut stream);
            tx.send((headers, body)).unwrap();
            respond_with(stream, "application/json", br#"{"ok":true}"#);
        });

        let result = run_webwrite(
            Value::from(url),
            vec![Value::Struct(payload), Value::Struct(opts)],
        )
        .expect("webwrite");

        // Request side: JSON content type and JSON-encoded fields.
        let (headers, body) = rx.recv().expect("request");
        let headers_lower = headers.to_ascii_lowercase();
        assert!(headers_lower.contains("content-type: application/json"));
        let body_text = String::from_utf8(body).expect("utf8 body");
        assert!(body_text.contains("\"title\":\"RunMat\""));
        assert!(body_text.contains("\"stars\":5"));

        // Response side: the JSON object arrives as a struct.
        match result {
            Value::Struct(reply) => {
                assert!(matches!(reply.fields.get("ok"), Some(Value::Bool(true))));
            }
            other => panic!("expected struct response, got {other:?}"),
        }
    }
1088
1089 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1090 #[test]
1091 fn webwrite_applies_basic_auth_and_custom_headers() {
1092 let payload = Value::from("");
1093 let mut header_struct = StructValue::new();
1094 header_struct
1095 .fields
1096 .insert("X-Test".to_string(), Value::from("yes"));
1097 header_struct
1098 .fields
1099 .insert("Accept".to_string(), Value::from("text/plain"));
1100 let mut opts_struct = StructValue::new();
1101 opts_struct
1102 .fields
1103 .insert("Username".to_string(), Value::from("ada"));
1104 opts_struct
1105 .fields
1106 .insert("Password".to_string(), Value::from("secret"));
1107 opts_struct
1108 .fields
1109 .insert("HeaderFields".to_string(), Value::Struct(header_struct));
1110 opts_struct
1111 .fields
1112 .insert("ContentType".to_string(), Value::from("text"));
1113 opts_struct
1114 .fields
1115 .insert("MediaType".to_string(), Value::from("text/plain"));
1116
1117 let (tx, rx) = mpsc::channel();
1118 let url = spawn_server(move |mut stream| {
1119 let (headers, _) = read_request(&mut stream);
1120 tx.send(headers).unwrap();
1121 respond_with(stream, "text/plain", b"OK");
1122 });
1123
1124 let result = run_webwrite(Value::from(url), vec![payload, Value::Struct(opts_struct)])
1125 .expect("webwrite");
1126
1127 let headers = rx.recv().expect("headers");
1128 let headers_lower = headers.to_ascii_lowercase();
1129 assert!(headers_lower.contains("authorization: basic"));
1130 assert!(headers_lower.contains("x-test: yes"));
1131 assert!(headers_lower.contains("accept: text/plain"));
1132
1133 match result {
1134 Value::CharArray(ca) => {
1135 let text: String = ca.data.iter().collect();
1136 assert_eq!(text, "OK");
1137 }
1138 other => panic!("expected char array, got {other:?}"),
1139 }
1140 }
1141
1142 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1143 #[test]
1144 fn webwrite_supports_query_parameters() {
1145 let payload = Value::Struct(StructValue::new());
1146 let mut qp_struct = StructValue::new();
1147 qp_struct.fields.insert("page".to_string(), Value::Num(2.0));
1148 qp_struct
1149 .fields
1150 .insert("verbose".to_string(), Value::Bool(true));
1151 let mut opts_struct = StructValue::new();
1152 opts_struct
1153 .fields
1154 .insert("QueryParameters".to_string(), Value::Struct(qp_struct));
1155
1156 let (tx, rx) = mpsc::channel();
1157 let url = spawn_server(move |mut stream| {
1158 let (headers, _) = read_request(&mut stream);
1159 tx.send(headers).unwrap();
1160 respond_with(stream, "application/json", br#"{"ok":true}"#);
1161 });
1162
1163 let _ = run_webwrite(
1164 Value::from(url.clone()),
1165 vec![payload, Value::Struct(opts_struct)],
1166 )
1167 .expect("webwrite");
1168
1169 let headers = rx.recv().expect("headers");
1170 let first_line = headers.lines().next().unwrap_or("");
1171 assert!(first_line.starts_with("POST "));
1172 assert!(first_line.contains("page=2"));
1173 assert!(first_line.contains("verbose=true"));
1174 }
1175
1176 #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
1177 #[test]
1178 fn webwrite_binary_payload_respected() {
1179 let tensor = Tensor::new(vec![1.0, 2.0, 3.0, 255.0], vec![4, 1]).unwrap();
1180 let payload = Value::Tensor(tensor);
1181 let mut opts_struct = StructValue::new();
1182 opts_struct
1183 .fields
1184 .insert("ContentType".to_string(), Value::from("binary"));
1185 opts_struct.fields.insert(
1186 "MediaType".to_string(),
1187 Value::from("application/octet-stream"),
1188 );
1189
1190 let (tx, rx) = mpsc::channel();
1191 let url = spawn_server(move |mut stream| {
1192 let (headers, body) = read_request(&mut stream);
1193 tx.send((headers, body)).unwrap();
1194 respond_with(stream, "text/plain", b"OK");
1195 });
1196
1197 let _ = run_webwrite(Value::from(url), vec![payload, Value::Struct(opts_struct)])
1198 .expect("webwrite");
1199
1200 let (headers, body) = rx.recv().expect("request");
1201 let headers_lower = headers.to_ascii_lowercase();
1202 assert!(headers_lower.contains("content-type: application/octet-stream"));
1203 assert_eq!(body, vec![1, 2, 3, 255]);
1204 }
1205}