1use crate::lex::error::LexError;
50
51const MAX_RECURSION_DEPTH: usize = 100;
53
54const MAX_TENSOR_ELEMENTS: usize = 10_000_000;
56
/// An n-dimensional tensor literal: either a single scalar or a nested
/// array of sub-tensors. The parser guarantees rectangularity (all siblings
/// at a level share one shape); a hand-built value may violate that.
#[derive(Debug, Clone, PartialEq)]
pub enum Tensor {
    /// A single floating-point value (rank 0).
    Scalar(f64),
    /// An ordered list of sub-tensors, one per index along this dimension.
    Array(Vec<Tensor>),
}
89
90impl std::fmt::Display for Tensor {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 match self {
96 Tensor::Scalar(n) => {
97 if n.fract() == 0.0 && n.is_finite() {
98 if *n >= i64::MIN as f64 && *n <= i64::MAX as f64 {
99 write!(f, "{}", *n as i64)
100 } else {
101 write!(f, "{}", n)
102 }
103 } else {
104 write!(f, "{}", n)
105 }
106 }
107 Tensor::Array(items) => {
108 write!(f, "[")?;
109 for (i, item) in items.iter().enumerate() {
110 if i > 0 {
111 write!(f, ", ")?;
112 }
113 write!(f, "{}", item)?;
114 }
115 write!(f, "]")
116 }
117 }
118 }
119}
120
121impl Tensor {
122 pub fn is_integer(&self) -> bool {
138 match self {
139 Tensor::Scalar(n) => n.fract() == 0.0,
140 Tensor::Array(items) => items.iter().all(|t| t.is_integer()),
141 }
142 }
143
144 pub fn shape(&self) -> Vec<usize> {
161 match self {
162 Tensor::Scalar(_) => vec![],
163 Tensor::Array(items) => {
164 if items.is_empty() {
165 vec![0]
166 } else {
167 let mut shape = vec![items.len()];
168 shape.extend(items[0].shape());
169 shape
170 }
171 }
172 }
173 }
174
175 pub fn flatten(&self) -> Vec<f64> {
186 let capacity = self.count_elements();
187 let mut result = Vec::with_capacity(capacity);
188 self.flatten_into(&mut result);
189 result
190 }
191
192 fn count_elements(&self) -> usize {
194 match self {
195 Tensor::Scalar(_) => 1,
196 Tensor::Array(items) => items.iter().map(|t| t.count_elements()).sum(),
197 }
198 }
199
200 fn flatten_into(&self, result: &mut Vec<f64>) {
202 match self {
203 Tensor::Scalar(n) => result.push(*n),
204 Tensor::Array(items) => {
205 for item in items {
206 item.flatten_into(result);
207 }
208 }
209 }
210 }
211
212 #[inline]
214 pub fn is_scalar(&self) -> bool {
215 matches!(self, Tensor::Scalar(_))
216 }
217
218 #[inline]
220 pub fn is_array(&self) -> bool {
221 matches!(self, Tensor::Array(_))
222 }
223
224 #[inline]
226 pub fn ndim(&self) -> usize {
227 self.shape().len()
228 }
229
230 #[inline]
232 pub fn len(&self) -> usize {
233 self.count_elements()
234 }
235
236 #[inline]
238 pub fn is_empty(&self) -> bool {
239 match self {
240 Tensor::Scalar(_) => false,
241 Tensor::Array(items) => items.is_empty(),
242 }
243 }
244}
245
/// Fast syntactic pre-check: does `s` look like a single bracketed tensor
/// literal (only digits, dots, minus signs, commas, whitespace, and balanced
/// brackets)?
///
/// This does not validate structure (dimension consistency, number syntax);
/// it only rejects strings that cannot possibly parse as one tensor.
#[inline]
pub fn is_tensor_literal(s: &str) -> bool {
    let s = s.trim();
    let bytes = s.as_bytes();
    if bytes.first() != Some(&b'[') || bytes.last() != Some(&b']') {
        return false;
    }

    let mut depth: i32 = 0;
    for (i, &b) in bytes.iter().enumerate() {
        match b {
            b'[' => depth += 1,
            b']' => {
                depth -= 1;
                if depth < 0 {
                    return false;
                }
                // The outermost bracket may only close at the very last byte;
                // otherwise adjacent literals such as "[1],[2]" or "[1][2]"
                // would be accepted here even though parse_tensor rejects
                // them (a single literal must span the whole string).
                if depth == 0 && i + 1 != bytes.len() {
                    return false;
                }
            }
            b'0'..=b'9' | b'.' | b'-' | b',' | b' ' | b'\t' => {}
            _ => return false,
        }
    }

    depth == 0
}
285
286pub fn parse_tensor(s: &str) -> Result<Tensor, LexError> {
315 let s = s.trim();
316 if !s.starts_with('[') {
317 return Err(LexError::UnexpectedChar(s.chars().next().unwrap_or(' ')));
318 }
319
320 let (tensor, remaining) = parse_tensor_inner(s, 0)?;
321
322 if !remaining.trim().is_empty() {
323 return Err(LexError::UnexpectedChar(
324 remaining.trim().chars().next().unwrap_or('?'),
325 ));
326 }
327
328 let element_count = tensor.count_elements();
330 if element_count > MAX_TENSOR_ELEMENTS {
331 return Err(LexError::InvalidStructure(format!(
332 "Tensor element count exceeds maximum: {} (max: {})",
333 element_count, MAX_TENSOR_ELEMENTS
334 )));
335 }
336
337 Ok(tensor)
338}
339
/// Cheap capacity hint for the element `Vec` of one array level.
///
/// `s` is the text immediately after an opening `[`. Counts the commas at
/// this nesting level up to the matching `]` and returns `commas + 1`;
/// falls back to 2 when no comma is seen. This is a hint only — it need
/// not be exact.
fn estimate_array_size(s: &str) -> usize {
    let mut nesting = 0usize;
    let mut separators = 0usize;

    for ch in s.chars() {
        match ch {
            '[' => nesting += 1,
            // The matching close bracket of this level ends the scan.
            ']' if nesting == 0 => break,
            ']' => nesting -= 1,
            // Only top-level commas separate elements of this array.
            ',' if nesting == 0 => separators += 1,
            _ => {}
        }
    }

    if separators == 0 { 2 } else { separators + 1 }
}
365
/// Recursively parses one tensor value (a bracketed array or a bare number)
/// from the front of `s`, returning the parsed tensor and the unconsumed
/// remainder of the string.
///
/// `depth` counts bracket nesting and is capped at `MAX_RECURSION_DEPTH`
/// to prevent stack overflow on deeply nested input.
///
/// # Errors
/// - `InvalidStructure` when nesting exceeds the depth limit
/// - `UnbalancedBrackets` when input ends inside an open array
/// - `UnexpectedChar` when a separator or element is malformed
/// - `EmptyTensor` for `[]`
/// - `InconsistentDimensions` when sibling elements differ in shape
fn parse_tensor_inner(s: &str, depth: usize) -> Result<(Tensor, &str), LexError> {
    if depth > MAX_RECURSION_DEPTH {
        return Err(LexError::InvalidStructure(format!(
            "Recursion depth exceeded (max: {})",
            MAX_RECURSION_DEPTH
        )));
    }
    let s = s.trim();

    if let Some(remaining_str) = s.strip_prefix('[') {
        let mut remaining = remaining_str;
        // Pre-size from a cheap comma count to avoid Vec regrowth.
        let estimated_capacity = estimate_array_size(remaining_str);
        let mut items = Vec::with_capacity(estimated_capacity);

        loop {
            remaining = remaining.trim_start();

            // Input ended before this array's closing bracket.
            if remaining.is_empty() {
                return Err(LexError::UnbalancedBrackets);
            }

            if remaining.starts_with(']') {
                remaining = &remaining[1..];
                break;
            }

            // Every element after the first must be preceded by a comma.
            if !items.is_empty() {
                if !remaining.starts_with(',') {
                    return Err(LexError::UnexpectedChar(
                        remaining.chars().next().unwrap_or('?'),
                    ));
                }
                remaining = remaining[1..].trim_start();
            }

            // A ']' right after a comma: tolerate trailing commas, e.g. "[1, 2,]".
            if remaining.starts_with(']') {
                remaining = &remaining[1..];
                break;
            }

            // Element is either a nested array or a plain number.
            if remaining.starts_with('[') {
                let (tensor, rest) = parse_tensor_inner(remaining, depth + 1)?;
                items.push(tensor);
                remaining = rest;
            } else {
                let (num, rest) = parse_number(remaining)?;
                items.push(Tensor::Scalar(num));
                remaining = rest;
            }
        }

        if items.is_empty() {
            return Err(LexError::EmptyTensor);
        }

        // All siblings must agree on shape so the tensor is rectangular.
        if items.len() > 1 {
            let first_shape = items[0].shape();
            for item in &items[1..] {
                if item.shape() != first_shape {
                    return Err(LexError::InconsistentDimensions);
                }
            }
        }

        Ok((Tensor::Array(items), remaining))
    } else {
        // No opening bracket: parse a bare scalar.
        let (num, rest) = parse_number(s)?;
        Ok((Tensor::Scalar(num), rest))
    }
}
437
438fn parse_number(s: &str) -> Result<(f64, &str), LexError> {
439 let s = s.trim_start();
440 let bytes = s.as_bytes();
441
442 let mut end = 0;
443 let mut has_dot = false;
444
445 if bytes.first() == Some(&b'-') {
446 end = 1;
447 }
448
449 while end < bytes.len() {
450 match bytes[end] {
451 b'0'..=b'9' => end += 1,
452 b'.' if !has_dot => {
453 has_dot = true;
454 end += 1;
455 }
456 _ => break,
457 }
458 }
459
460 if end == 0 || (end == 1 && bytes[0] == b'-') {
461 return Err(LexError::InvalidNumber(
462 s.chars().take(10).collect::<String>(),
463 ));
464 }
465
466 let num_str = &s[..end];
467 let num: f64 = num_str.parse().map_err(|_| {
468 let context = if num_str.len() > 80 {
469 format!("{}...", &num_str[..80])
470 } else {
471 num_str.to_string()
472 };
473 LexError::InvalidNumber(context)
474 })?;
475
476 if !num.is_finite() {
477 return Err(LexError::InvalidNumber(format!(
478 "{} (non-finite values not allowed)",
479 num_str
480 )));
481 }
482
483 Ok((num, &s[end..]))
484}
485
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_tensor_literal_valid() {
        assert!(is_tensor_literal("[1, 2, 3]"));
        assert!(is_tensor_literal("[[1, 2], [3, 4]]"));
        assert!(is_tensor_literal("[1.5, 2.5]"));
        assert!(is_tensor_literal("[-1, -2]"));
        assert!(is_tensor_literal(" [1, 2, 3] "));
    }

    #[test]
    fn test_is_tensor_literal_invalid() {
        assert!(!is_tensor_literal("hello"));
        assert!(!is_tensor_literal("@reference"));
        assert!(!is_tensor_literal("123"));
        assert!(!is_tensor_literal(""));
        assert!(!is_tensor_literal("[1, 2"));
        assert!(!is_tensor_literal("[a, b]"));
    }

    #[test]
    fn test_parse_1d() {
        let t = parse_tensor("[1, 2, 3]").unwrap();
        assert_eq!(t.shape(), vec![3]);
        assert_eq!(t.flatten(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_parse_2d() {
        let t = parse_tensor("[[1, 2], [3, 4]]").unwrap();
        assert_eq!(t.shape(), vec![2, 2]);
        assert_eq!(t.flatten(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_parse_3d() {
        // Three levels of nesting; innermost arrays have a single element.
        let t = parse_tensor("[[[1], [2]], [[3], [4]]]").unwrap();
        assert_eq!(t.shape(), vec![2, 2, 1]);
        assert_eq!(t.flatten(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_parse_floats() {
        let t = parse_tensor("[1.5, 2.5, 3.5]").unwrap();
        assert_eq!(t.flatten(), vec![1.5, 2.5, 3.5]);
        assert!(!t.is_integer());
    }

    #[test]
    fn test_is_integer_true() {
        let t = parse_tensor("[[1, 2], [3, 4]]").unwrap();
        assert!(t.is_integer());
    }

    #[test]
    fn test_parse_negatives() {
        let t = parse_tensor("[-1, -2, -3]").unwrap();
        assert_eq!(t.flatten(), vec![-1.0, -2.0, -3.0]);
    }

    #[test]
    fn test_parse_trailing_comma() {
        let t = parse_tensor("[1, 2, 3,]").unwrap();
        assert_eq!(t.flatten(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_empty_tensor_error() {
        assert!(matches!(parse_tensor("[]"), Err(LexError::EmptyTensor)));
    }

    #[test]
    fn test_unbalanced_brackets_error() {
        assert!(matches!(
            parse_tensor("[1, 2"),
            Err(LexError::UnbalancedBrackets)
        ));
    }

    #[test]
    fn test_inconsistent_dimensions_error() {
        assert!(matches!(
            parse_tensor("[[1, 2], [3]]"),
            Err(LexError::InconsistentDimensions)
        ));
    }

    #[test]
    fn test_invalid_number_error() {
        assert!(matches!(
            parse_tensor("[abc]"),
            Err(LexError::InvalidNumber(_))
        ));
    }

    #[test]
    fn test_trailing_garbage_error() {
        // A valid literal followed by extra input must be rejected.
        assert!(matches!(
            parse_tensor("[1, 2] extra"),
            Err(LexError::UnexpectedChar(_))
        ));
    }

    #[test]
    fn test_recursion_depth_limit() {
        // More opening brackets than the depth cap allows.
        let deep = "[".repeat(MAX_RECURSION_DEPTH + 2);
        assert!(matches!(
            parse_tensor(&deep),
            Err(LexError::InvalidStructure(_))
        ));
    }

    #[test]
    fn test_tensor_display() {
        let t = parse_tensor("[1, 2, 3]").unwrap();
        assert_eq!(format!("{}", t), "[1, 2, 3]");

        let t = parse_tensor("[[1, 2], [3, 4]]").unwrap();
        assert_eq!(format!("{}", t), "[[1, 2], [3, 4]]");
    }

    #[test]
    fn test_scalar_display() {
        // Whole values print without a decimal point; others print as floats.
        assert_eq!(Tensor::Scalar(3.0).to_string(), "3");
        assert_eq!(Tensor::Scalar(2.5).to_string(), "2.5");
        assert_eq!(Tensor::Scalar(-7.0).to_string(), "-7");
    }

    #[test]
    fn test_tensor_methods() {
        let scalar = Tensor::Scalar(42.0);
        assert!(scalar.is_scalar());
        assert!(!scalar.is_array());
        assert_eq!(scalar.ndim(), 0);
        assert_eq!(scalar.len(), 1);
        assert!(!scalar.is_empty());

        let array = parse_tensor("[1, 2, 3]").unwrap();
        assert!(!array.is_scalar());
        assert!(array.is_array());
        assert_eq!(array.ndim(), 1);
        assert_eq!(array.len(), 3);
        assert!(!array.is_empty());
    }

    #[test]
    fn test_tensor_equality() {
        let t1 = parse_tensor("[1, 2, 3]").unwrap();
        let t2 = parse_tensor("[1, 2, 3]").unwrap();
        assert_eq!(t1, t2);

        let t3 = parse_tensor("[1, 2, 4]").unwrap();
        assert_ne!(t1, t3);
    }

    #[test]
    fn test_tensor_clone() {
        let t1 = parse_tensor("[[1, 2], [3, 4]]").unwrap();
        let t2 = t1.clone();
        assert_eq!(t1, t2);
    }

    #[test]
    fn test_display_roundtrip() {
        let original = parse_tensor("[1, 2, 3]").unwrap();
        let serialized = original.to_string();
        let parsed = parse_tensor(&serialized).unwrap();
        assert_eq!(original, parsed);

        let original = parse_tensor("[[1.5, 2.5], [3.5, 4.5]]").unwrap();
        let serialized = original.to_string();
        let parsed = parse_tensor(&serialized).unwrap();
        assert_eq!(original, parsed);
    }
}