runmat_runtime/builtins/common/
arg_tokens.rs1use runmat_builtins::{LiteralValue, ResolveContext, Value};
2
/// Simplified, comparable view of a single builtin argument.
///
/// Tokens are built either from runtime values or from literal arguments in a
/// resolve context. String payloads are stored ASCII-lowercased by the
/// converters in this module.
#[derive(Debug, Clone, PartialEq)]
pub enum ArgToken {
    /// A numeric scalar.
    Number(f64),
    /// A boolean scalar.
    Bool(bool),
    /// A text argument.
    String(String),
    /// An ordered list of nested tokens.
    Vector(Vec<ArgToken>),
    /// An argument that has no token representation.
    Unknown,
}
11
12pub fn tokens_from_values(args: &[Value]) -> Vec<ArgToken> {
13 args.iter().map(token_from_value).collect()
14}
15
16pub fn tokens_from_context(ctx: &ResolveContext) -> Vec<ArgToken> {
17 ctx.literal_args.iter().map(token_from_literal).collect()
18}
19
20fn token_from_literal(value: &LiteralValue) -> ArgToken {
21 match value {
22 LiteralValue::Number(num) => ArgToken::Number(*num),
23 LiteralValue::Bool(value) => ArgToken::Bool(*value),
24 LiteralValue::String(text) => ArgToken::String(text.to_ascii_lowercase()),
25 LiteralValue::Vector(values) => {
26 ArgToken::Vector(values.iter().map(token_from_literal).collect())
27 }
28 LiteralValue::Unknown => ArgToken::Unknown,
29 }
30}
31
32fn token_from_value(value: &Value) -> ArgToken {
33 match value {
34 Value::Num(num) => ArgToken::Number(*num),
35 Value::Int(value) => ArgToken::Number(value.to_f64()),
36 Value::Bool(value) => ArgToken::Bool(*value),
37 Value::String(text) => ArgToken::String(text.to_ascii_lowercase()),
38 Value::StringArray(arr) if arr.data.len() == 1 => {
39 ArgToken::String(arr.data[0].to_ascii_lowercase())
40 }
41 Value::CharArray(arr) if arr.rows == 1 => {
42 let text: String = arr.data.iter().collect();
43 ArgToken::String(text.to_ascii_lowercase())
44 }
45 Value::Tensor(tensor) => token_from_tensor(&tensor.data, &tensor.shape),
46 Value::LogicalArray(arr) => token_from_logical(&arr.data, &arr.shape),
47 _ => ArgToken::Unknown,
48 }
49}
50
51fn token_from_tensor(data: &[f64], shape: &[usize]) -> ArgToken {
52 if data.len() == 1 {
53 return ArgToken::Number(data[0]);
54 }
55 if is_vector_shape(shape) {
56 return ArgToken::Vector(data.iter().copied().map(ArgToken::Number).collect());
57 }
58 ArgToken::Unknown
59}
60
61fn token_from_logical(data: &[u8], shape: &[usize]) -> ArgToken {
62 if data.len() == 1 {
63 return ArgToken::Bool(data[0] != 0);
64 }
65 if is_vector_shape(shape) {
66 return ArgToken::Vector(data.iter().map(|b| ArgToken::Bool(*b != 0)).collect());
67 }
68 ArgToken::Unknown
69}
70
/// Reports whether `shape` describes a vector.
///
/// A one-dimensional shape always counts. A two-dimensional shape counts when
/// at least one of its extents is exactly 1. Empty shapes and shapes with three
/// or more dimensions never count.
fn is_vector_shape(shape: &[usize]) -> bool {
    match shape {
        [_] => true,
        [rows, cols] => *rows == 1 || *cols == 1,
        _ => false,
    }
}
83
#[cfg(test)]
mod tests {
    use super::*;
    use runmat_builtins::{IntValue, LiteralValue, ResolveContext};

    /// Shorthand for the expected `ArgToken::String` in assertions.
    fn string_token(text: &str) -> ArgToken {
        ArgToken::String(text.to_string())
    }

    // String literals taken from a resolve context come back lowercased.
    #[test]
    fn tokens_from_context_lowercases_strings() {
        let literals = vec![LiteralValue::String("OmItNaN".to_string())];
        let ctx = ResolveContext::new(literals);

        let actual = tokens_from_context(&ctx);

        assert_eq!(actual, vec![string_token("omitnan")]);
    }

    // A literal vector is converted element by element, keeping order.
    #[test]
    fn tokens_from_context_handles_vectors() {
        let literal = LiteralValue::Vector(vec![
            LiteralValue::Number(1.0),
            LiteralValue::Bool(true),
        ]);
        let ctx = ResolveContext::new(vec![literal]);

        let actual = tokens_from_context(&ctx);

        let expected = vec![ArgToken::Vector(vec![
            ArgToken::Number(1.0),
            ArgToken::Bool(true),
        ])];
        assert_eq!(actual, expected);
    }

    // Scalar runtime values map onto the matching scalar tokens.
    #[test]
    fn tokens_from_values_handles_scalar_inputs() {
        let args = vec![
            Value::Num(2.0),
            Value::Int(IntValue::I32(3)),
            Value::Bool(true),
            Value::String("All".to_string()),
        ];

        let actual = tokens_from_values(&args);

        let expected = vec![
            ArgToken::Number(2.0),
            ArgToken::Number(3.0),
            ArgToken::Bool(true),
            string_token("all"),
        ];
        assert_eq!(actual, expected);
    }

    // A 1x2 tensor is reported as a vector of number tokens.
    #[test]
    fn tokens_from_values_handles_vector_tensor() {
        let row = runmat_builtins::Tensor::new(vec![1.0, 2.0], vec![1, 2]).unwrap();
        let args = vec![Value::Tensor(row)];

        let actual = tokens_from_values(&args);

        let expected = vec![ArgToken::Vector(vec![
            ArgToken::Number(1.0),
            ArgToken::Number(2.0),
        ])];
        assert_eq!(actual, expected);
    }
}