runmat_runtime/builtins/io/
input.rs1use runmat_builtins::{CharArray, Tensor, Value};
4use runmat_macros::runtime_builtin;
5
6use crate::builtins::common::spec::{
7 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
8 ReductionNaN, ResidencyPolicy, ShapeRequirements,
9};
10use crate::interaction;
11use crate::{
12 build_runtime_error, call_builtin_async, gather_if_needed_async, BuiltinResult, RuntimeError,
13};
14
/// Prompt text displayed when `input` is called without a prompt argument.
const DEFAULT_PROMPT: &str = "Input: ";
16
17fn input_error(identifier: &str, message: impl Into<String>) -> RuntimeError {
18 build_runtime_error(message)
19 .with_identifier(identifier.to_string())
20 .with_builtin("input")
21 .build()
22}
23
24#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::io::input")]
25pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
26 name: "input",
27 op_kind: GpuOpKind::Custom("interaction"),
28 supported_precisions: &[],
29 broadcast: BroadcastSemantics::None,
30 provider_hooks: &[],
31 constant_strategy: ConstantStrategy::InlineLiteral,
32 residency: ResidencyPolicy::GatherImmediately,
33 nan_mode: ReductionNaN::Include,
34 two_pass_threshold: None,
35 workgroup_size: None,
36 accepts_nan_mode: false,
37 notes: "Prompts execute on the host. Input text is always delivered via the host handler; GPU tensors are only gathered when used as prompt strings.",
38};
39
40#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::io::input")]
41pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
42 name: "input",
43 shape: ShapeRequirements::Any,
44 constant_strategy: ConstantStrategy::InlineLiteral,
45 elementwise: None,
46 reduction: None,
47 emits_nan: false,
48 notes: "Side-effecting builtin; excluded from fusion plans.",
49};
50
51#[runtime_builtin(
52 name = "input",
53 type_resolver(crate::builtins::io::type_resolvers::input_type),
54 builtin_path = "crate::builtins::io::input"
55)]
56async fn input_builtin(args: Vec<Value>) -> BuiltinResult<Value> {
57 if args.len() > 2 {
58 return Err(input_error(
59 "RunMat:input:TooManyInputs",
60 "input: too many inputs",
61 ));
62 }
63
64 let mut prompt_index = if args.is_empty() { None } else { Some(0usize) };
65 let mut parsed_flag: Option<bool> = None;
66
67 if let Some(idx) = if args.len() == 2 { Some(1usize) } else { None } {
68 match parse_string_flag(&args[idx]).await {
69 Ok(flag) => parsed_flag = Some(flag),
70 Err(original_err) => {
71 if let Some(prompt_idx) = prompt_index {
72 match parse_string_flag(&args[prompt_idx]).await {
73 Ok(swapped_flag) => {
74 parsed_flag = Some(swapped_flag);
75 prompt_index = Some(idx);
76 }
77 Err(_) => {
78 return Err(original_err);
79 }
80 }
81 } else {
82 return Err(original_err);
83 }
84 }
85 }
86 }
87
88 let prompt = if let Some(idx) = prompt_index {
89 parse_prompt(&args[idx]).await?
90 } else {
91 DEFAULT_PROMPT.to_string()
92 };
93 let return_string = parsed_flag.unwrap_or(false);
94 let line = interaction::request_line_async(&prompt, true)
95 .await
96 .map_err(|err| {
97 let message = err.message().to_string();
98 build_runtime_error(format!("input: {message}"))
99 .with_identifier("RunMat:input:InteractionFailed")
100 .with_source(err)
101 .with_builtin("input")
102 .build()
103 })?;
104 if return_string {
105 return Ok(Value::CharArray(CharArray::new_row(&line)));
106 }
107 parse_numeric_response(&line).await
108}
109
110async fn parse_prompt(value: &Value) -> Result<String, RuntimeError> {
111 let gathered = gather_if_needed_async(value).await?;
112 match gathered {
113 Value::CharArray(ca) => {
114 if ca.rows != 1 {
115 Err(input_error(
116 "RunMat:input:PromptMustBeRowVector",
117 "input: prompt must be a row vector",
118 ))
119 } else {
120 Ok(ca.data.iter().collect())
121 }
122 }
123 Value::String(text) => Ok(text),
124 Value::StringArray(sa) => {
125 if sa.data.len() == 1 {
126 Ok(sa.data[0].clone())
127 } else {
128 Err(input_error(
129 "RunMat:input:PromptMustBeScalarString",
130 "input: prompt must be a scalar string",
131 ))
132 }
133 }
134 other => Err(input_error(
135 "RunMat:input:InvalidPromptType",
136 format!("input: invalid prompt type ({other:?})"),
137 )),
138 }
139}
140
141async fn parse_string_flag(value: &Value) -> Result<bool, RuntimeError> {
142 let gathered = gather_if_needed_async(value).await?;
143 let text = match gathered {
144 Value::CharArray(ca) if ca.rows == 1 => ca.data.iter().collect::<String>(),
145 Value::String(s) => s,
146 Value::StringArray(sa) if sa.data.len() == 1 => sa.data[0].clone(),
147 other => {
148 return Err(input_error(
149 "RunMat:input:InvalidStringFlag",
150 format!("input: invalid string flag ({other:?})"),
151 ))
152 }
153 };
154 let trimmed = text.trim();
155 if trimmed.eq_ignore_ascii_case("s") {
156 Ok(true)
157 } else {
158 Err(input_error(
159 "RunMat:input:InvalidStringFlag",
160 format!("input: invalid string flag ({trimmed})"),
161 ))
162 }
163}
164
165async fn parse_numeric_response(line: &str) -> Result<Value, RuntimeError> {
166 let trimmed = line.trim();
167 if trimmed.is_empty() || trimmed == "[]" {
168 return Ok(Value::Tensor(Tensor::zeros(vec![0, 0])));
169 }
170 call_builtin_async("str2double", &[Value::String(trimmed.to_string())])
171 .await
172 .map_err(|err| {
173 let message = err.message().to_string();
174 build_runtime_error(format!("input: invalid numeric expression ({message})"))
175 .with_identifier("RunMat:input:InvalidNumericExpression")
176 .with_source(err)
177 .with_builtin("input")
178 .build()
179 })
180}
181
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::interaction::{push_queued_response, InteractionResponse};

    /// Queue `reply` as the next line the interaction layer hands back.
    fn queue_line(reply: &'static str) {
        push_queued_response(Ok(InteractionResponse::Line(reply.into())));
    }

    /// Drive the async `input` builtin to completion on the current thread.
    fn run_input(args: Vec<Value>) -> crate::BuiltinResult<Value> {
        futures::executor::block_on(input_builtin(args))
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn numeric_input_parses_scalar() {
        queue_line("41");
        let parsed = run_input(Vec::new()).expect("input");
        assert_eq!(parsed, Value::Num(41.0));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn string_mode_returns_char_row() {
        queue_line("RunMat");
        let args = vec![
            Value::CharArray(CharArray::new_row("Name: ")),
            Value::String("s".to_string()),
        ];
        let reply = run_input(args).expect("input");
        assert_eq!(reply, Value::CharArray(CharArray::new_row("RunMat")));
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn empty_response_returns_empty_tensor() {
        queue_line(" ");
        let value = run_input(Vec::new()).expect("input");
        match &value {
            Value::Tensor(tensor) => assert!(tensor.data.is_empty()),
            _ => panic!("expected empty tensor, got {value:?}"),
        }
    }

    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
    #[test]
    fn invalid_string_flag_errors_before_prompt() {
        // NOTE(review): flag validation fails before the prompt, so this queued reply is
        // never consumed. Confirm the response queue cannot carry it into another test.
        queue_line("ignored");
        let args = vec![
            Value::String("Ready?".to_string()),
            Value::String("not-string-mode".to_string()),
        ];
        let failure = run_input(args).unwrap_err();
        assert_eq!(failure.identifier(), Some("RunMat:input:InvalidStringFlag"));
    }
}