1use runmat_builtins::{CharArray, LogicalArray, Tensor, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::spec::{
7 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
8 ReductionNaN, ResidencyPolicy, ShapeRequirements,
9};
10use crate::interaction;
11use crate::{
12 build_runtime_error, call_builtin_async, gather_if_needed_async, BuiltinResult, RuntimeError,
13};
14
/// Prompt text shown when `input` is called without a prompt argument.
const DEFAULT_PROMPT: &str = "Input: ";
16
17fn input_error(identifier: &str, message: impl Into<String>) -> RuntimeError {
18 build_runtime_error(message)
19 .with_identifier(identifier.to_string())
20 .with_builtin("input")
21 .build()
22}
23
/// GPU metadata for `input`.
///
/// The builtin is interactive and runs on the host. It declares no supported
/// precisions and no provider hooks, and its residency policy is `GatherImmediately`.
/// See `notes` for how GPU prompt values are handled.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::io::input")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "input",
    op_kind: GpuOpKind::Custom("interaction"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Prompts execute on the host. Input text is always delivered via the host handler; GPU tensors are only gathered when used as prompt strings.",
};
39
/// Fusion metadata for `input`.
///
/// No elementwise or reduction template is provided. As the notes state, the builtin
/// has side effects (it prompts the user) and is kept out of fusion plans.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::io::input")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "input",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Side-effecting builtin; excluded from fusion plans.",
};
50
51#[runtime_builtin(
52 name = "input",
53 type_resolver(crate::builtins::io::type_resolvers::input_type),
54 builtin_path = "crate::builtins::io::input"
55)]
56async fn input_builtin(args: Vec<Value>) -> BuiltinResult<Value> {
57 if args.len() > 2 {
58 return Err(input_error(
59 "RunMat:input:TooManyInputs",
60 "input: too many inputs",
61 ));
62 }
63
64 let mut prompt_index = if args.is_empty() { None } else { Some(0usize) };
65 let mut parsed_flag: Option<bool> = None;
66
67 if let Some(idx) = if args.len() == 2 { Some(1usize) } else { None } {
68 match parse_string_flag(&args[idx]).await {
69 Ok(flag) => parsed_flag = Some(flag),
70 Err(original_err) => {
71 if let Some(prompt_idx) = prompt_index {
72 match parse_string_flag(&args[prompt_idx]).await {
73 Ok(swapped_flag) => {
74 parsed_flag = Some(swapped_flag);
75 prompt_index = Some(idx);
76 }
77 Err(_) => {
78 return Err(original_err);
79 }
80 }
81 } else {
82 return Err(original_err);
83 }
84 }
85 }
86 }
87
88 let prompt = if let Some(idx) = prompt_index {
89 parse_prompt(&args[idx]).await?
90 } else {
91 DEFAULT_PROMPT.to_string()
92 };
93 let return_string = parsed_flag.unwrap_or(false);
94 let line = interaction::request_line_async(&prompt, true)
95 .await
96 .map_err(|err| {
97 let message = err.message().to_string();
98 build_runtime_error(format!("input: {message}"))
99 .with_identifier("RunMat:input:InteractionFailed")
100 .with_source(err)
101 .with_builtin("input")
102 .build()
103 })?;
104 if return_string {
105 return Ok(Value::CharArray(CharArray::new_row(&line)));
106 }
107 parse_numeric_response(&line).await
108}
109
110async fn parse_prompt(value: &Value) -> Result<String, RuntimeError> {
111 let gathered = gather_if_needed_async(value).await?;
112 match gathered {
113 Value::CharArray(ca) => {
114 if ca.rows != 1 {
115 Err(input_error(
116 "RunMat:input:PromptMustBeRowVector",
117 "input: prompt must be a row vector",
118 ))
119 } else {
120 Ok(ca.data.iter().collect())
121 }
122 }
123 Value::String(text) => Ok(text),
124 Value::StringArray(sa) => {
125 if sa.data.len() == 1 {
126 Ok(sa.data[0].clone())
127 } else {
128 Err(input_error(
129 "RunMat:input:PromptMustBeScalarString",
130 "input: prompt must be a scalar string",
131 ))
132 }
133 }
134 other => Err(input_error(
135 "RunMat:input:InvalidPromptType",
136 format!("input: invalid prompt type ({other:?})"),
137 )),
138 }
139}
140
141async fn parse_string_flag(value: &Value) -> Result<bool, RuntimeError> {
142 let gathered = gather_if_needed_async(value).await?;
143 let text = match gathered {
144 Value::CharArray(ca) if ca.rows == 1 => ca.data.iter().collect::<String>(),
145 Value::String(s) => s,
146 Value::StringArray(sa) if sa.data.len() == 1 => sa.data[0].clone(),
147 other => {
148 return Err(input_error(
149 "RunMat:input:InvalidStringFlag",
150 format!("input: invalid string flag ({other:?})"),
151 ))
152 }
153 };
154 let trimmed = text.trim();
155 if trimmed.eq_ignore_ascii_case("s") {
156 Ok(true)
157 } else {
158 Err(input_error(
159 "RunMat:input:InvalidStringFlag",
160 format!("input: invalid string flag ({trimmed})"),
161 ))
162 }
163}
164
165async fn parse_numeric_response(line: &str) -> Result<Value, RuntimeError> {
166 let trimmed = line.trim();
167 if trimmed.is_empty() || trimmed == "[]" {
168 return Ok(Value::Tensor(Tensor::zeros(vec![0, 0])));
169 }
170
171 if let Some(v) = parse_scalar_value(trimmed) {
174 return Ok(v);
175 }
176
177 if trimmed.starts_with('[') && trimmed.ends_with(']') {
180 if let Some(v) = parse_matrix_literal(trimmed) {
181 return Ok(v);
182 }
183 }
184
185 if let Some(hook) = interaction::current_eval_hook() {
190 return hook(trimmed.to_string()).await.map_err(|err| {
191 let message = err.message().to_string();
192 build_runtime_error(format!("input: invalid expression ({message})"))
193 .with_identifier("RunMat:input:EvalFailed")
194 .with_source(err)
195 .with_builtin("input")
196 .build()
197 });
198 }
199
200 call_builtin_async("str2double", &[Value::String(trimmed.to_string())])
202 .await
203 .map_err(|err| {
204 let message = err.message().to_string();
205 build_runtime_error(format!("input: invalid numeric expression ({message})"))
206 .with_identifier("RunMat:input:InvalidNumericExpression")
207 .with_source(err)
208 .with_builtin("input")
209 .build()
210 })
211}
212
213fn parse_scalar_value(s: &str) -> Option<Value> {
225 match s.to_ascii_lowercase().as_str() {
226 "true" => return Some(Value::Bool(true)),
227 "false" => return Some(Value::Bool(false)),
228 "pi" => return Some(Value::Num(std::f64::consts::PI)),
229 "inf" | "+inf" | "infinity" | "+infinity" => return Some(Value::Num(f64::INFINITY)),
230 "-inf" | "-infinity" => return Some(Value::Num(f64::NEG_INFINITY)),
231 "nan" => return Some(Value::Num(f64::NAN)),
232 _ => {}
233 }
234 let has_non_numeric = s.chars().any(|c| {
238 matches!(c, '[' | ']' | ',' | ';' | '(' | ')' | ' ' | '\t')
239 || (c.is_ascii_alphabetic() && c != 'e' && c != 'E' && c != 'i' && c != 'j')
240 });
241 if has_non_numeric {
242 return None;
243 }
244 s.parse::<f64>().ok().map(Value::Num)
245}
246
247fn parse_matrix_literal(s: &str) -> Option<Value> {
257 let inner = s.strip_prefix('[')?.strip_suffix(']')?;
258 let inner = inner.trim();
259 if inner.is_empty() {
260 return Some(Value::Tensor(Tensor::zeros(vec![0, 0])));
261 }
262
263 let row_strs: Vec<&str> = inner.split(';').collect();
264 let mut values: Vec<Value> = Vec::new();
265 let mut nrows = 0usize;
266 let mut ncols: Option<usize> = None;
267
268 for row_str in &row_strs {
269 let tokens: Vec<&str> = row_str
270 .split(|c: char| c == ',' || c.is_ascii_whitespace())
271 .filter(|t| !t.is_empty())
272 .collect();
273 if tokens.is_empty() {
274 continue;
275 }
276 match ncols {
277 None => ncols = Some(tokens.len()),
278 Some(expected) if tokens.len() != expected => return None,
279 _ => {}
280 }
281 for token in &tokens {
282 values.push(parse_scalar_value(token)?);
283 }
284 nrows += 1;
285 }
286
287 let ncols = ncols.unwrap_or(0);
288 if nrows == 0 || ncols == 0 {
289 return Some(Value::Tensor(Tensor::zeros(vec![0, 0])));
290 }
291 if nrows == 1 && ncols == 1 {
293 return Some(values.remove(0));
294 }
295
296 let all_logical = values.iter().all(|v| matches!(v, Value::Bool(_)));
301 if all_logical {
302 let mut data: Vec<u8> = vec![0u8; nrows * ncols];
303 for r in 0..nrows {
304 for c in 0..ncols {
305 let row_major_idx = r * ncols + c;
306 let col_major_idx = r + c * nrows;
307 data[col_major_idx] = match &values[row_major_idx] {
308 Value::Bool(b) => u8::from(*b),
309 _ => unreachable!(),
310 };
311 }
312 }
313 LogicalArray::new(data, vec![nrows, ncols])
314 .ok()
315 .map(Value::LogicalArray)
316 } else {
317 let mut data: Vec<f64> = vec![0f64; nrows * ncols];
318 for r in 0..nrows {
319 for c in 0..ncols {
320 let row_major_idx = r * ncols + c;
321 let col_major_idx = r + c * nrows;
322 data[col_major_idx] = match &values[row_major_idx] {
323 Value::Num(f) => *f,
324 Value::Bool(b) => f64::from(u8::from(*b)),
325 _ => unreachable!(),
326 };
327 }
328 }
329 Tensor::new_2d(data, nrows, ncols).ok().map(Value::Tensor)
330 }
331}
332
/// Unit tests for the `input` builtin.
///
/// Most tests queue a scripted host response with `push_queued_response` and then drive
/// the async builtin to completion with `futures::executor::block_on`.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::interaction::{push_queued_response, InteractionResponse};

    // A plain numeric reply with no prompt argument becomes a double scalar.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn numeric_input_parses_scalar() {
        push_queued_response(Ok(InteractionResponse::Line("41".into())));
        let value = futures::executor::block_on(input_builtin(vec![])).expect("input");
        assert_eq!(value, Value::Num(41.0));
    }

    // With the "s" flag the raw reply is returned unparsed as a char row.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn string_mode_returns_char_row() {
        push_queued_response(Ok(InteractionResponse::Line("RunMat".into())));
        let prompt = Value::CharArray(CharArray::new_row("Name: "));
        let mode = Value::String("s".to_string());
        let value = futures::executor::block_on(input_builtin(vec![prompt, mode])).expect("input");
        assert_eq!(value, Value::CharArray(CharArray::new_row("RunMat")));
    }

    // A whitespace-only reply is treated as empty and yields an empty tensor.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn empty_response_returns_empty_tensor() {
        push_queued_response(Ok(InteractionResponse::Line(" ".into())));
        let value = futures::executor::block_on(input_builtin(vec![])).expect("input");
        match value {
            Value::Tensor(t) => assert!(t.data.is_empty()),
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }

    // A bracketed row vector goes through the local literal parser; the test installs no eval hook.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn matrix_literal_parses_without_eval_hook() {
        push_queued_response(Ok(InteractionResponse::Line("[1 2 3]".into())));
        // Expect a 1×3 double tensor.
        let value = futures::executor::block_on(input_builtin(vec![])).expect("input");
        match value {
            Value::Tensor(t) => {
                assert_eq!(t.rows, 1);
                assert_eq!(t.cols, 3);
                assert_eq!(t.data, vec![1.0, 2.0, 3.0]);
            }
            other => panic!("expected 1×3 tensor, got {other:?}"),
        }
    }

    // The named constant `pi` is resolved locally.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn named_constants_parse_without_eval_hook() {
        push_queued_response(Ok(InteractionResponse::Line("pi".into())));
        let value = futures::executor::block_on(input_builtin(vec![])).expect("input");
        assert_eq!(value, Value::Num(std::f64::consts::PI));
    }

    // `e`/`E` are allowed characters (for exponents) but a bare `e` is not a number,
    // so the local parser must decline it rather than return Euler's number.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn bare_e_is_not_eulers_number() {
        assert_eq!(parse_scalar_value("e"), None);
        assert_eq!(parse_scalar_value("E"), None);
    }

    // One unparseable element makes the whole matrix literal decline.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn matrix_with_bare_e_does_not_parse() {
        assert_eq!(parse_matrix_literal("[1 e 3]"), None);
    }

    // `true` yields a logical scalar, not the double 1.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn true_input_returns_logical_not_double() {
        push_queued_response(Ok(InteractionResponse::Line("true".into())));
        let value = futures::executor::block_on(input_builtin(vec![])).expect("input");
        assert_eq!(value, Value::Bool(true));
    }

    // `false` yields a logical scalar, not the double 0.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn false_input_returns_logical_not_double() {
        push_queued_response(Ok(InteractionResponse::Line("false".into())));
        let value = futures::executor::block_on(input_builtin(vec![])).expect("input");
        assert_eq!(value, Value::Bool(false));
    }

    // Logical keywords are matched ASCII case-insensitively.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn bool_input_is_case_insensitive() {
        push_queued_response(Ok(InteractionResponse::Line("TRUE".into())));
        let value = futures::executor::block_on(input_builtin(vec![])).expect("input");
        assert_eq!(value, Value::Bool(true));
    }

    // `;` separates rows, producing a 3×1 column vector.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn column_vector_parses_without_eval_hook() {
        push_queued_response(Ok(InteractionResponse::Line("[1;2;3]".into())));
        let value = futures::executor::block_on(input_builtin(vec![])).expect("input");
        match value {
            Value::Tensor(t) => {
                assert_eq!(t.rows, 3);
                assert_eq!(t.cols, 1);
                assert_eq!(t.data, vec![1.0, 2.0, 3.0]);
            }
            other => panic!("expected 3×1 tensor, got {other:?}"),
        }
    }

    // An all-logical row literal becomes a LogicalArray with 0/1 data.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_row_vector_parses_as_logical_array() {
        push_queued_response(Ok(InteractionResponse::Line("[true false]".into())));
        let value = futures::executor::block_on(input_builtin(vec![])).expect("input");
        match value {
            Value::LogicalArray(la) => {
                assert_eq!(la.shape, vec![1, 2]);
                assert_eq!(la.data, vec![1, 0]);
            }
            other => panic!("expected LogicalArray, got {other:?}"),
        }
    }

    // An all-logical column literal keeps its 2×1 shape.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_column_vector_parses_as_logical_array() {
        push_queued_response(Ok(InteractionResponse::Line("[true; false]".into())));
        let value = futures::executor::block_on(input_builtin(vec![])).expect("input");
        match value {
            Value::LogicalArray(la) => {
                assert_eq!(la.shape, vec![2, 1]);
                assert_eq!(la.data, vec![1, 0]);
            }
            other => panic!("expected LogicalArray, got {other:?}"),
        }
    }

    // Mixing logical and numeric elements promotes everything to double.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mixed_logical_and_numeric_coerces_to_double_tensor() {
        push_queued_response(Ok(InteractionResponse::Line("[true 2.0]".into())));
        let value = futures::executor::block_on(input_builtin(vec![])).expect("input");
        match value {
            Value::Tensor(t) => {
                assert_eq!(t.rows, 1);
                assert_eq!(t.cols, 2);
                assert_eq!(t.data, vec![1.0, 2.0]);
            }
            other => panic!("expected Tensor, got {other:?}"),
        }
    }

    // The literal is written row by row, but the tensor must index it correctly,
    // which checks the row-major → column-major reordering.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn matrix_2x2_column_major_layout() {
        push_queued_response(Ok(InteractionResponse::Line("[1 2; 3 4]".into())));
        let value = futures::executor::block_on(input_builtin(vec![])).expect("input");
        match value {
            Value::Tensor(t) => {
                assert_eq!(t.rows, 2);
                assert_eq!(t.cols, 2);
                assert_eq!(t.get2(0, 0).unwrap(), 1.0, "(0,0) should be 1");
                assert_eq!(t.get2(0, 1).unwrap(), 2.0, "(0,1) should be 2");
                assert_eq!(t.get2(1, 0).unwrap(), 3.0, "(1,0) should be 3");
                assert_eq!(t.get2(1, 1).unwrap(), 4.0, "(1,1) should be 4");
            }
            other => panic!("expected 2×2 tensor, got {other:?}"),
        }
    }

    // Same reordering check for the logical branch.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_matrix_2x2_column_major_layout() {
        push_queued_response(Ok(InteractionResponse::Line(
            "[true false; false true]".into(),
        )));
        let value = futures::executor::block_on(input_builtin(vec![])).expect("input");
        match value {
            Value::LogicalArray(la) => {
                assert_eq!(la.shape, vec![2, 2]);
                assert_eq!(la.data, vec![1, 0, 0, 1]);
                // Column-major: (0,0)=1, (1,0)=0, (0,1)=0, (1,1)=1.
            }
            other => panic!("expected 2×2 LogicalArray, got {other:?}"),
        }
    }

    // An unknown mode flag is rejected during argument validation, before the host is prompted.
    // NOTE(review): the queued "ignored" line is presumably never consumed in this test;
    // if the response queue is shared between tests, confirm it cannot leak into another one.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn invalid_string_flag_errors_before_prompt() {
        push_queued_response(Ok(InteractionResponse::Line("ignored".into())));
        let prompt = Value::String("Ready?".to_string());
        let bad_flag = Value::String("not-string-mode".to_string());
        let err = futures::executor::block_on(input_builtin(vec![prompt, bad_flag])).unwrap_err();
        assert_eq!(err.identifier(), Some("RunMat:input:InvalidStringFlag"));
    }
}