Skip to main content

runmat_runtime/builtins/io/
input.rs

1//! MATLAB-compatible `input` builtin for line-oriented console interaction.
2
3use runmat_builtins::{CharArray, LogicalArray, Tensor, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::spec::{
7    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
8    ReductionNaN, ResidencyPolicy, ShapeRequirements,
9};
10use crate::interaction;
11use crate::{
12    build_runtime_error, call_builtin_async, gather_if_needed_async, BuiltinResult, RuntimeError,
13};
14
/// Prompt text displayed when `input` is called with no arguments.
const DEFAULT_PROMPT: &str = "Input: ";
16
17fn input_error(identifier: &str, message: impl Into<String>) -> RuntimeError {
18    build_runtime_error(message)
19        .with_identifier(identifier.to_string())
20        .with_builtin("input")
21        .build()
22}
23
/// GPU planner metadata for the `input` builtin.
///
/// Per its `notes`, prompting happens on the host, so the spec lists no
/// supported precisions and no provider hooks, and uses a `GatherImmediately`
/// residency policy.
#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::io::input")]
pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
    name: "input",
    // Custom op kind: `input` is an interaction, not an elementwise or reduction kernel.
    op_kind: GpuOpKind::Custom("interaction"),
    supported_precisions: &[],
    broadcast: BroadcastSemantics::None,
    provider_hooks: &[],
    constant_strategy: ConstantStrategy::InlineLiteral,
    residency: ResidencyPolicy::GatherImmediately,
    nan_mode: ReductionNaN::Include,
    two_pass_threshold: None,
    workgroup_size: None,
    accepts_nan_mode: false,
    notes: "Prompts execute on the host. Input text is always delivered via the host handler; GPU tensors are only gathered when used as prompt strings.",
};
39
/// Fusion planner metadata for the `input` builtin.
///
/// `input` is side-effecting (it prompts for and reads a line of console
/// input), so it provides neither an elementwise nor a reduction template and,
/// per its `notes`, is excluded from fusion plans.
#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::io::input")]
pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
    name: "input",
    shape: ShapeRequirements::Any,
    constant_strategy: ConstantStrategy::InlineLiteral,
    elementwise: None,
    reduction: None,
    emits_nan: false,
    notes: "Side-effecting builtin; excluded from fusion plans.",
};
50
51#[runtime_builtin(
52    name = "input",
53    type_resolver(crate::builtins::io::type_resolvers::input_type),
54    builtin_path = "crate::builtins::io::input"
55)]
56async fn input_builtin(args: Vec<Value>) -> BuiltinResult<Value> {
57    if args.len() > 2 {
58        return Err(input_error(
59            "RunMat:input:TooManyInputs",
60            "input: too many inputs",
61        ));
62    }
63
64    let mut prompt_index = if args.is_empty() { None } else { Some(0usize) };
65    let mut parsed_flag: Option<bool> = None;
66
67    if let Some(idx) = if args.len() == 2 { Some(1usize) } else { None } {
68        match parse_string_flag(&args[idx]).await {
69            Ok(flag) => parsed_flag = Some(flag),
70            Err(original_err) => {
71                if let Some(prompt_idx) = prompt_index {
72                    match parse_string_flag(&args[prompt_idx]).await {
73                        Ok(swapped_flag) => {
74                            parsed_flag = Some(swapped_flag);
75                            prompt_index = Some(idx);
76                        }
77                        Err(_) => {
78                            return Err(original_err);
79                        }
80                    }
81                } else {
82                    return Err(original_err);
83                }
84            }
85        }
86    }
87
88    let prompt = if let Some(idx) = prompt_index {
89        parse_prompt(&args[idx]).await?
90    } else {
91        DEFAULT_PROMPT.to_string()
92    };
93    let return_string = parsed_flag.unwrap_or(false);
94    let line = interaction::request_line_async(&prompt, true)
95        .await
96        .map_err(|err| {
97            let message = err.message().to_string();
98            build_runtime_error(format!("input: {message}"))
99                .with_identifier("RunMat:input:InteractionFailed")
100                .with_source(err)
101                .with_builtin("input")
102                .build()
103        })?;
104    if return_string {
105        return Ok(Value::CharArray(CharArray::new_row(&line)));
106    }
107    parse_numeric_response(&line).await
108}
109
110async fn parse_prompt(value: &Value) -> Result<String, RuntimeError> {
111    let gathered = gather_if_needed_async(value).await?;
112    match gathered {
113        Value::CharArray(ca) => {
114            if ca.rows != 1 {
115                Err(input_error(
116                    "RunMat:input:PromptMustBeRowVector",
117                    "input: prompt must be a row vector",
118                ))
119            } else {
120                Ok(ca.data.iter().collect())
121            }
122        }
123        Value::String(text) => Ok(text),
124        Value::StringArray(sa) => {
125            if sa.data.len() == 1 {
126                Ok(sa.data[0].clone())
127            } else {
128                Err(input_error(
129                    "RunMat:input:PromptMustBeScalarString",
130                    "input: prompt must be a scalar string",
131                ))
132            }
133        }
134        other => Err(input_error(
135            "RunMat:input:InvalidPromptType",
136            format!("input: invalid prompt type ({other:?})"),
137        )),
138    }
139}
140
141async fn parse_string_flag(value: &Value) -> Result<bool, RuntimeError> {
142    let gathered = gather_if_needed_async(value).await?;
143    let text = match gathered {
144        Value::CharArray(ca) if ca.rows == 1 => ca.data.iter().collect::<String>(),
145        Value::String(s) => s,
146        Value::StringArray(sa) if sa.data.len() == 1 => sa.data[0].clone(),
147        other => {
148            return Err(input_error(
149                "RunMat:input:InvalidStringFlag",
150                format!("input: invalid string flag ({other:?})"),
151            ))
152        }
153    };
154    let trimmed = text.trim();
155    if trimmed.eq_ignore_ascii_case("s") {
156        Ok(true)
157    } else {
158        Err(input_error(
159            "RunMat:input:InvalidStringFlag",
160            format!("input: invalid string flag ({trimmed})"),
161        ))
162    }
163}
164
165async fn parse_numeric_response(line: &str) -> Result<Value, RuntimeError> {
166    let trimmed = line.trim();
167    if trimmed.is_empty() || trimmed == "[]" {
168        return Ok(Value::Tensor(Tensor::zeros(vec![0, 0])));
169    }
170
171    // Fast path 1: scalar literals, named constants, and logical keywords.
172    // Handles the vast majority of input() use cases without touching the VM.
173    if let Some(v) = parse_scalar_value(trimmed) {
174        return Ok(v);
175    }
176
177    // Fast path 2: matrix/vector literals like `[1 2 3]`, `[1;2;3]`, `[true false]`.
178    // Avoids recursive interpret() calls for this common case.
179    if trimmed.starts_with('[') && trimmed.ends_with(']') {
180        if let Some(v) = parse_matrix_literal(trimmed) {
181            return Ok(v);
182        }
183    }
184
185    // Full eval path for complex expressions (`sqrt(2)`, `pi/2`, `ones(3)`, etc.).
186    // The eval hook is only safe to call when the executor can handle re-entrant
187    // polls (e.g. the WASM async runtime). On native the fast paths above cover
188    // the common cases; truly complex expressions fall back to str2double here.
189    if let Some(hook) = interaction::current_eval_hook() {
190        return hook(trimmed.to_string()).await.map_err(|err| {
191            let message = err.message().to_string();
192            build_runtime_error(format!("input: invalid expression ({message})"))
193                .with_identifier("RunMat:input:EvalFailed")
194                .with_source(err)
195                .with_builtin("input")
196                .build()
197        });
198    }
199
200    // Fallback when no eval hook is installed (unit tests, native REPL).
201    call_builtin_async("str2double", &[Value::String(trimmed.to_string())])
202        .await
203        .map_err(|err| {
204            let message = err.message().to_string();
205            build_runtime_error(format!("input: invalid numeric expression ({message})"))
206                .with_identifier("RunMat:input:InvalidNumericExpression")
207                .with_source(err)
208                .with_builtin("input")
209                .build()
210        })
211}
212
213/// Parse a single MATLAB scalar token into a [`Value`].
214///
215/// Returns [`Value::Bool`] for `true`/`false` (case-insensitive), [`Value::Num`]
216/// for numeric literals and named constants (`pi`, `inf`, `nan`), and
217/// `None` for anything that looks like a matrix, function call, or unknown
218/// identifier.
219///
220/// Note: `e` is intentionally **not** handled here. It is not a MATLAB built-in
221/// constant; typing `e` at an `input()` prompt would perform a variable lookup in
222/// MATLAB and error if `e` is undefined. Unknown identifiers fall through to the
223/// eval hook or `str2double`, which produce the correct error.
224fn parse_scalar_value(s: &str) -> Option<Value> {
225    match s.to_ascii_lowercase().as_str() {
226        "true" => return Some(Value::Bool(true)),
227        "false" => return Some(Value::Bool(false)),
228        "pi" => return Some(Value::Num(std::f64::consts::PI)),
229        "inf" | "+inf" | "infinity" | "+infinity" => return Some(Value::Num(f64::INFINITY)),
230        "-inf" | "-infinity" => return Some(Value::Num(f64::NEG_INFINITY)),
231        "nan" => return Some(Value::Num(f64::NAN)),
232        _ => {}
233    }
234    // Plain numeric literals: integers, decimals, scientific notation, optional sign.
235    // We reject anything containing brackets, commas, spaces (which would indicate a
236    // matrix or an expression), or letters other than 'e'/'E' for exponent notation.
237    let has_non_numeric = s.chars().any(|c| {
238        matches!(c, '[' | ']' | ',' | ';' | '(' | ')' | ' ' | '\t')
239            || (c.is_ascii_alphabetic() && c != 'e' && c != 'E' && c != 'i' && c != 'j')
240    });
241    if has_non_numeric {
242        return None;
243    }
244    s.parse::<f64>().ok().map(Value::Num)
245}
246
247/// Parse a MATLAB matrix literal of the form `[elements]`.
248///
249/// Rows are separated by `;` and elements within a row by whitespace and/or `,`.
250/// Every element must be a token accepted by [`parse_scalar_value`].
251/// Returns `None` if the literal is malformed or contains non-scalar elements.
252///
253/// Output type mirrors MATLAB semantics:
254/// - All-logical elements → [`Value::LogicalArray`]
255/// - Any numeric element  → [`Value::Tensor`] (logical elements coerced to `f64`)
256fn parse_matrix_literal(s: &str) -> Option<Value> {
257    let inner = s.strip_prefix('[')?.strip_suffix(']')?;
258    let inner = inner.trim();
259    if inner.is_empty() {
260        return Some(Value::Tensor(Tensor::zeros(vec![0, 0])));
261    }
262
263    let row_strs: Vec<&str> = inner.split(';').collect();
264    let mut values: Vec<Value> = Vec::new();
265    let mut nrows = 0usize;
266    let mut ncols: Option<usize> = None;
267
268    for row_str in &row_strs {
269        let tokens: Vec<&str> = row_str
270            .split(|c: char| c == ',' || c.is_ascii_whitespace())
271            .filter(|t| !t.is_empty())
272            .collect();
273        if tokens.is_empty() {
274            continue;
275        }
276        match ncols {
277            None => ncols = Some(tokens.len()),
278            Some(expected) if tokens.len() != expected => return None,
279            _ => {}
280        }
281        for token in &tokens {
282            values.push(parse_scalar_value(token)?);
283        }
284        nrows += 1;
285    }
286
287    let ncols = ncols.unwrap_or(0);
288    if nrows == 0 || ncols == 0 {
289        return Some(Value::Tensor(Tensor::zeros(vec![0, 0])));
290    }
291    // Scalar: preserve the exact type (Bool or Num) rather than always wrapping in Tensor.
292    if nrows == 1 && ncols == 1 {
293        return Some(values.remove(0));
294    }
295
296    // All-logical → LogicalArray; any numeric element → Tensor (bools coerced to f64).
297    // `values` is in row-major order (row 0 left-to-right, then row 1, …), but both
298    // Tensor and LogicalArray store data in column-major order (data[r + c*rows]).
299    // Reorder so that column-major index maps to the correct element.
300    let all_logical = values.iter().all(|v| matches!(v, Value::Bool(_)));
301    if all_logical {
302        let mut data: Vec<u8> = vec![0u8; nrows * ncols];
303        for r in 0..nrows {
304            for c in 0..ncols {
305                let row_major_idx = r * ncols + c;
306                let col_major_idx = r + c * nrows;
307                data[col_major_idx] = match &values[row_major_idx] {
308                    Value::Bool(b) => u8::from(*b),
309                    _ => unreachable!(),
310                };
311            }
312        }
313        LogicalArray::new(data, vec![nrows, ncols])
314            .ok()
315            .map(Value::LogicalArray)
316    } else {
317        let mut data: Vec<f64> = vec![0f64; nrows * ncols];
318        for r in 0..nrows {
319            for c in 0..ncols {
320                let row_major_idx = r * ncols + c;
321                let col_major_idx = r + c * nrows;
322                data[col_major_idx] = match &values[row_major_idx] {
323                    Value::Num(f) => *f,
324                    Value::Bool(b) => f64::from(u8::from(*b)),
325                    _ => unreachable!(),
326                };
327            }
328        }
329        Tensor::new_2d(data, nrows, ncols).ok().map(Value::Tensor)
330    }
331}
332
/// Unit tests for the `input` builtin.
///
/// Tests that call `input_builtin` first queue a canned console reply with
/// `push_queued_response`, then drive the async builtin to completion with
/// `futures::executor::block_on`. These tests do not install an eval hook
/// themselves, so they target the scalar/matrix fast paths.
///
/// NOTE(review): if the store behind `push_queued_response` is process-global
/// rather than per-thread, tests run in parallel by the default harness could
/// consume each other's queued replies. Confirm it is thread-local, or
/// serialize these tests.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::interaction::{push_queued_response, InteractionResponse};

    /// A plain integer reply parses to a double scalar.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn numeric_input_parses_scalar() {
        push_queued_response(Ok(InteractionResponse::Line("41".into())));
        let value = futures::executor::block_on(input_builtin(vec![])).expect("input");
        assert_eq!(value, Value::Num(41.0));
    }

    /// With the `'s'` flag the raw reply is returned as a char row, unevaluated.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn string_mode_returns_char_row() {
        push_queued_response(Ok(InteractionResponse::Line("RunMat".into())));
        let prompt = Value::CharArray(CharArray::new_row("Name: "));
        let mode = Value::String("s".to_string());
        let value = futures::executor::block_on(input_builtin(vec![prompt, mode])).expect("input");
        assert_eq!(value, Value::CharArray(CharArray::new_row("RunMat")));
    }

    /// A whitespace-only reply yields an empty tensor.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn empty_response_returns_empty_tensor() {
        push_queued_response(Ok(InteractionResponse::Line("   ".into())));
        let value = futures::executor::block_on(input_builtin(vec![])).expect("input");
        match value {
            Value::Tensor(t) => assert!(t.data.is_empty()),
            other => panic!("expected empty tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn matrix_literal_parses_without_eval_hook() {
        // The fast-path parser handles `[1 2 3]` directly, so no eval hook (and
        // therefore no recursive interpret() call) is needed.
        push_queued_response(Ok(InteractionResponse::Line("[1 2 3]".into())));
        let value = futures::executor::block_on(input_builtin(vec![])).expect("input");
        match value {
            Value::Tensor(t) => {
                assert_eq!(t.rows, 1);
                assert_eq!(t.cols, 3);
                assert_eq!(t.data, vec![1.0, 2.0, 3.0]);
            }
            other => panic!("expected 1×3 tensor, got {other:?}"),
        }
    }

    /// `pi` is resolved by the scalar fast path to `std::f64::consts::PI`.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn named_constants_parse_without_eval_hook() {
        push_queued_response(Ok(InteractionResponse::Line("pi".into())));
        let value = futures::executor::block_on(input_builtin(vec![])).expect("input");
        assert_eq!(value, Value::Num(std::f64::consts::PI));
    }

    /// `e` is not a MATLAB built-in constant. The fast-path parser must not map
    /// it to Euler's number; it should fall through so the eval hook or
    /// `str2double` can handle it (which will NaN or error on an unknown identifier).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn bare_e_is_not_eulers_number() {
        assert_eq!(parse_scalar_value("e"), None);
        assert_eq!(parse_scalar_value("E"), None);
    }

    /// `[1 e 3]` must not silently produce `[1.0, 2.718…, 3.0]`.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn matrix_with_bare_e_does_not_parse() {
        assert_eq!(parse_matrix_literal("[1 e 3]"), None);
    }

    /// `true` yields a logical scalar, not the double `1`.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn true_input_returns_logical_not_double() {
        push_queued_response(Ok(InteractionResponse::Line("true".into())));
        let value = futures::executor::block_on(input_builtin(vec![])).expect("input");
        assert_eq!(value, Value::Bool(true));
    }

    /// `false` yields a logical scalar, not the double `0`.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn false_input_returns_logical_not_double() {
        push_queued_response(Ok(InteractionResponse::Line("false".into())));
        let value = futures::executor::block_on(input_builtin(vec![])).expect("input");
        assert_eq!(value, Value::Bool(false));
    }

    /// The fast path deliberately matches logical keywords case-insensitively.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn bool_input_is_case_insensitive() {
        push_queued_response(Ok(InteractionResponse::Line("TRUE".into())));
        let value = futures::executor::block_on(input_builtin(vec![])).expect("input");
        assert_eq!(value, Value::Bool(true));
    }

    /// `;` separates rows, so `[1;2;3]` is a 3×1 column.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn column_vector_parses_without_eval_hook() {
        push_queued_response(Ok(InteractionResponse::Line("[1;2;3]".into())));
        let value = futures::executor::block_on(input_builtin(vec![])).expect("input");
        match value {
            Value::Tensor(t) => {
                assert_eq!(t.rows, 3);
                assert_eq!(t.cols, 1);
                assert_eq!(t.data, vec![1.0, 2.0, 3.0]);
            }
            other => panic!("expected 3×1 tensor, got {other:?}"),
        }
    }

    /// A literal made only of logical keywords stays logical.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_row_vector_parses_as_logical_array() {
        push_queued_response(Ok(InteractionResponse::Line("[true false]".into())));
        let value = futures::executor::block_on(input_builtin(vec![])).expect("input");
        match value {
            Value::LogicalArray(la) => {
                assert_eq!(la.shape, vec![1, 2]);
                assert_eq!(la.data, vec![1, 0]);
            }
            other => panic!("expected LogicalArray, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_column_vector_parses_as_logical_array() {
        push_queued_response(Ok(InteractionResponse::Line("[true; false]".into())));
        let value = futures::executor::block_on(input_builtin(vec![])).expect("input");
        match value {
            Value::LogicalArray(la) => {
                assert_eq!(la.shape, vec![2, 1]);
                assert_eq!(la.data, vec![1, 0]);
            }
            other => panic!("expected LogicalArray, got {other:?}"),
        }
    }

    /// Any numeric element promotes the whole literal to double, with logical
    /// elements coerced to 0/1.
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn mixed_logical_and_numeric_coerces_to_double_tensor() {
        push_queued_response(Ok(InteractionResponse::Line("[true 2.0]".into())));
        let value = futures::executor::block_on(input_builtin(vec![])).expect("input");
        match value {
            Value::Tensor(t) => {
                assert_eq!(t.rows, 1);
                assert_eq!(t.cols, 2);
                assert_eq!(t.data, vec![1.0, 2.0]);
            }
            other => panic!("expected Tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn matrix_2x2_column_major_layout() {
        // [1 2; 3 4] → get2(r,c) must return element at row r, col c, not the transpose.
        // Column-major storage: data = [1, 3, 2, 4] (not the row-major [1, 2, 3, 4]).
        push_queued_response(Ok(InteractionResponse::Line("[1 2; 3 4]".into())));
        let value = futures::executor::block_on(input_builtin(vec![])).expect("input");
        match value {
            Value::Tensor(t) => {
                assert_eq!(t.rows, 2);
                assert_eq!(t.cols, 2);
                assert_eq!(t.get2(0, 0).unwrap(), 1.0, "(0,0) should be 1");
                assert_eq!(t.get2(0, 1).unwrap(), 2.0, "(0,1) should be 2");
                assert_eq!(t.get2(1, 0).unwrap(), 3.0, "(1,0) should be 3");
                assert_eq!(t.get2(1, 1).unwrap(), 4.0, "(1,1) should be 4");
            }
            other => panic!("expected 2×2 tensor, got {other:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn logical_matrix_2x2_column_major_layout() {
        // [true false; false true] → column-major data = [1, 0, 0, 1].
        push_queued_response(Ok(InteractionResponse::Line(
            "[true false; false true]".into(),
        )));
        let value = futures::executor::block_on(input_builtin(vec![])).expect("input");
        match value {
            Value::LogicalArray(la) => {
                assert_eq!(la.shape, vec![2, 2]);
                // column-major: col 0 first ([true, false]), then col 1 ([false, true])
                assert_eq!(la.data, vec![1, 0, 0, 1]);
            }
            other => panic!("expected 2×2 LogicalArray, got {other:?}"),
        }
    }

    /// An invalid mode flag is reported as `InvalidStringFlag` (the queued reply
    /// is pushed but is not what the assertion checks).
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn invalid_string_flag_errors_before_prompt() {
        push_queued_response(Ok(InteractionResponse::Line("ignored".into())));
        let prompt = Value::String("Ready?".to_string());
        let bad_flag = Value::String("not-string-mode".to_string());
        let err = futures::executor::block_on(input_builtin(vec![prompt, bad_flag])).unwrap_err();
        assert_eq!(err.identifier(), Some("RunMat:input:InvalidStringFlag"));
    }
}