Skip to main content

runmat_runtime/builtins/io/
input.rs

1//! MATLAB-compatible `input` builtin for line-oriented console interaction.
2
3use runmat_builtins::{CharArray, Tensor, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::spec::{
7    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
8    ReductionNaN, ResidencyPolicy, ShapeRequirements,
9};
10use crate::interaction;
11use crate::{
12    build_runtime_error, call_builtin_async, gather_if_needed_async, BuiltinResult, RuntimeError,
13};
14
/// Prompt text used when `input` is called without any arguments.
const DEFAULT_PROMPT: &str = "Input: ";
16
17fn input_error(identifier: &str, message: impl Into<String>) -> RuntimeError {
18    build_runtime_error(message)
19        .with_identifier(identifier.to_string())
20        .with_builtin("input")
21        .build()
22}
23
24#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::io::input")]
25pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
26    name: "input",
27    op_kind: GpuOpKind::Custom("interaction"),
28    supported_precisions: &[],
29    broadcast: BroadcastSemantics::None,
30    provider_hooks: &[],
31    constant_strategy: ConstantStrategy::InlineLiteral,
32    residency: ResidencyPolicy::GatherImmediately,
33    nan_mode: ReductionNaN::Include,
34    two_pass_threshold: None,
35    workgroup_size: None,
36    accepts_nan_mode: false,
37    notes: "Prompts execute on the host. Input text is always delivered via the host handler; GPU tensors are only gathered when used as prompt strings.",
38};
39
40#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::io::input")]
41pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
42    name: "input",
43    shape: ShapeRequirements::Any,
44    constant_strategy: ConstantStrategy::InlineLiteral,
45    elementwise: None,
46    reduction: None,
47    emits_nan: false,
48    notes: "Side-effecting builtin; excluded from fusion plans.",
49};
50
51#[runtime_builtin(
52    name = "input",
53    type_resolver(crate::builtins::io::type_resolvers::input_type),
54    builtin_path = "crate::builtins::io::input"
55)]
56async fn input_builtin(args: Vec<Value>) -> BuiltinResult<Value> {
57    if args.len() > 2 {
58        return Err(input_error(
59            "RunMat:input:TooManyInputs",
60            "input: too many inputs",
61        ));
62    }
63
64    let mut prompt_index = if args.is_empty() { None } else { Some(0usize) };
65    let mut parsed_flag: Option<bool> = None;
66
67    if let Some(idx) = if args.len() == 2 { Some(1usize) } else { None } {
68        match parse_string_flag(&args[idx]).await {
69            Ok(flag) => parsed_flag = Some(flag),
70            Err(original_err) => {
71                if let Some(prompt_idx) = prompt_index {
72                    match parse_string_flag(&args[prompt_idx]).await {
73                        Ok(swapped_flag) => {
74                            parsed_flag = Some(swapped_flag);
75                            prompt_index = Some(idx);
76                        }
77                        Err(_) => {
78                            return Err(original_err);
79                        }
80                    }
81                } else {
82                    return Err(original_err);
83                }
84            }
85        }
86    }
87
88    let prompt = if let Some(idx) = prompt_index {
89        parse_prompt(&args[idx]).await?
90    } else {
91        DEFAULT_PROMPT.to_string()
92    };
93    let return_string = parsed_flag.unwrap_or(false);
94    let line = interaction::request_line_async(&prompt, true)
95        .await
96        .map_err(|err| {
97            let message = err.message().to_string();
98            build_runtime_error(format!("input: {message}"))
99                .with_identifier("RunMat:input:InteractionFailed")
100                .with_source(err)
101                .with_builtin("input")
102                .build()
103        })?;
104    if return_string {
105        return Ok(Value::CharArray(CharArray::new_row(&line)));
106    }
107    parse_numeric_response(&line).await
108}
109
110async fn parse_prompt(value: &Value) -> Result<String, RuntimeError> {
111    let gathered = gather_if_needed_async(value).await?;
112    match gathered {
113        Value::CharArray(ca) => {
114            if ca.rows != 1 {
115                Err(input_error(
116                    "RunMat:input:PromptMustBeRowVector",
117                    "input: prompt must be a row vector",
118                ))
119            } else {
120                Ok(ca.data.iter().collect())
121            }
122        }
123        Value::String(text) => Ok(text),
124        Value::StringArray(sa) => {
125            if sa.data.len() == 1 {
126                Ok(sa.data[0].clone())
127            } else {
128                Err(input_error(
129                    "RunMat:input:PromptMustBeScalarString",
130                    "input: prompt must be a scalar string",
131                ))
132            }
133        }
134        other => Err(input_error(
135            "RunMat:input:InvalidPromptType",
136            format!("input: invalid prompt type ({other:?})"),
137        )),
138    }
139}
140
141async fn parse_string_flag(value: &Value) -> Result<bool, RuntimeError> {
142    let gathered = gather_if_needed_async(value).await?;
143    let text = match gathered {
144        Value::CharArray(ca) if ca.rows == 1 => ca.data.iter().collect::<String>(),
145        Value::String(s) => s,
146        Value::StringArray(sa) if sa.data.len() == 1 => sa.data[0].clone(),
147        other => {
148            return Err(input_error(
149                "RunMat:input:InvalidStringFlag",
150                format!("input: invalid string flag ({other:?})"),
151            ))
152        }
153    };
154    let trimmed = text.trim();
155    if trimmed.eq_ignore_ascii_case("s") {
156        Ok(true)
157    } else {
158        Err(input_error(
159            "RunMat:input:InvalidStringFlag",
160            format!("input: invalid string flag ({trimmed})"),
161        ))
162    }
163}
164
165async fn parse_numeric_response(line: &str) -> Result<Value, RuntimeError> {
166    let trimmed = line.trim();
167    if trimmed.is_empty() || trimmed == "[]" {
168        return Ok(Value::Tensor(Tensor::zeros(vec![0, 0])));
169    }
170    call_builtin_async("str2double", &[Value::String(trimmed.to_string())])
171        .await
172        .map_err(|err| {
173            let message = err.message().to_string();
174            build_runtime_error(format!("input: invalid numeric expression ({message})"))
175                .with_identifier("RunMat:input:InvalidNumericExpression")
176                .with_source(err)
177                .with_builtin("input")
178                .build()
179        })
180}
181
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::interaction::{push_queued_response, InteractionResponse};

    /// Queues `text` as a successful line response for the interaction layer.
    fn queue_line(text: &'static str) {
        push_queued_response(Ok(InteractionResponse::Line(text.into())));
    }

    /// Runs `input_builtin` to completion with `futures::executor::block_on`.
    fn run_input(args: Vec<Value>) -> BuiltinResult<Value> {
        futures::executor::block_on(input_builtin(args))
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn numeric_input_parses_scalar() {
        queue_line("41");
        let parsed = run_input(Vec::new()).expect("input");
        assert_eq!(parsed, Value::Num(41.0));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn string_mode_returns_char_row() {
        queue_line("RunMat");
        let args = vec![
            Value::CharArray(CharArray::new_row("Name: ")),
            Value::String("s".to_string()),
        ];
        let reply = run_input(args).expect("input");
        assert_eq!(reply, Value::CharArray(CharArray::new_row("RunMat")));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn empty_response_returns_empty_tensor() {
        // A whitespace-only reply counts as empty input.
        queue_line("   ");
        let result = run_input(Vec::new()).expect("input");
        if let Value::Tensor(tensor) = result {
            assert!(tensor.data.is_empty());
        } else {
            let other = result;
            panic!("expected empty tensor, got {other:?}");
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn invalid_string_flag_errors_before_prompt() {
        // NOTE(review): flag validation fails before any line is requested, so
        // this queued response is presumably never consumed -- confirm it
        // cannot leak into another test's interaction.
        queue_line("ignored");
        let args = vec![
            Value::String("Ready?".to_string()),
            Value::String("not-string-mode".to_string()),
        ];
        let failure = run_input(args).unwrap_err();
        assert_eq!(failure.identifier(), Some("RunMat:input:InvalidStringFlag"));
    }
}