use super::item::FunctionFlags;
use crate::{PResult, Parser};
use solar_ast::{token::*, *};
use solar_interface::kw;
use std::{fmt, ops::RangeInclusive};
impl<'sess, 'ast> Parser<'sess, 'ast> {
#[instrument(level = "trace", skip_all)]
pub fn parse_type(&mut self) -> PResult<'sess, Type<'ast>> {
let mut ty = self
.parse_spanned(Self::parse_basic_ty_kind)
.map(|(span, kind)| Type { span, kind })?;
while self.eat(TokenKind::OpenDelim(Delimiter::Bracket)) {
let size = if self.check_noexpect(TokenKind::CloseDelim(Delimiter::Bracket)) {
None
} else {
Some(self.parse_expr()?)
};
self.expect(TokenKind::CloseDelim(Delimiter::Bracket))?;
ty = Type {
span: ty.span.to(self.prev_token.span),
kind: TypeKind::Array(self.alloc(TypeArray { element: ty, size })),
};
}
Ok(ty)
}
fn parse_basic_ty_kind(&mut self) -> PResult<'sess, TypeKind<'ast>> {
if self.check_elementary_type() {
self.parse_elementary_type().map(TypeKind::Elementary)
} else if self.eat_keyword(kw::Function) {
self.parse_function_header(FunctionFlags::FUNCTION_TY).map(|f| {
let FunctionHeader {
span: _,
name: _,
parameters,
visibility,
state_mutability,
modifiers: _,
virtual_: _,
override_: _,
returns,
} = f;
TypeKind::Function(self.alloc(TypeFunction {
parameters,
visibility,
state_mutability,
returns,
}))
})
} else if self.eat_keyword(kw::Mapping) {
self.parse_mapping_type().map(|x| TypeKind::Mapping(self.alloc(x)))
} else if self.check_path() {
self.parse_path().map(TypeKind::Custom)
} else {
self.unexpected()
}
}
pub(super) fn parse_elementary_type(&mut self) -> PResult<'sess, ElementaryType> {
let id = self.parse_ident_any()?;
debug_assert!(id.is_elementary_type());
let mut ty = match id.name {
kw::Address => ElementaryType::Address(false),
kw::Bool => ElementaryType::Bool,
kw::String => ElementaryType::String,
kw::Bytes => ElementaryType::Bytes,
kw::Fixed => ElementaryType::Fixed(TypeSize::ZERO, TypeFixedSize::ZERO),
kw::UFixed => ElementaryType::UFixed(TypeSize::ZERO, TypeFixedSize::ZERO),
kw::Int => ElementaryType::Int(TypeSize::ZERO),
kw::UInt => ElementaryType::UInt(TypeSize::ZERO),
s if s >= kw::UInt8 && s <= kw::UInt256 => {
let bytes = s.as_u32() - kw::UInt8.as_u32() + 1;
ElementaryType::UInt(TypeSize::new(bytes as u8).unwrap())
}
s if s >= kw::Int8 && s <= kw::Int256 => {
let bytes = s.as_u32() - kw::Int8.as_u32() + 1;
ElementaryType::Int(TypeSize::new(bytes as u8).unwrap())
}
s if s >= kw::Bytes1 && s <= kw::Bytes32 => {
let bytes = s.as_u32() - kw::Bytes1.as_u32() + 1;
ElementaryType::FixedBytes(TypeSize::new(bytes as u8).unwrap())
}
s => unreachable!("unexpected elementary type: {s}"),
};
let sm = self.parse_state_mutability();
match (&mut ty, sm) {
(ElementaryType::Address(p), Some(StateMutability::Payable)) => *p = true,
(_, None) => {}
(_, Some(_)) => {
let msg = if matches!(ty, ElementaryType::Address(_)) {
"address types can only be payable or non-payable"
} else {
"only address types can have state mutability"
};
self.dcx().err(msg).span(id.span.to(self.prev_token.span)).emit();
}
}
Ok(ty)
}
fn parse_mapping_type(&mut self) -> PResult<'sess, TypeMapping<'ast>> {
self.expect(TokenKind::OpenDelim(Delimiter::Parenthesis))?;
let key = self.parse_type()?;
if !key.is_elementary() && !key.is_custom() {
let msg =
"only elementary types or used-defined types can be used as key types in mappings";
self.dcx().err(msg).span(key.span).emit();
}
let key_name = self.parse_ident_opt()?;
self.expect(TokenKind::FatArrow)?;
let value = self.parse_type()?;
let value_name = self.parse_ident_opt()?;
self.expect(TokenKind::CloseDelim(Delimiter::Parenthesis))?;
Ok(TypeMapping { key, key_name, value, value_name })
}
}
#[derive(Debug, PartialEq)]
enum ParseTySizeError {
Parse(std::num::ParseIntError),
TryFrom(std::num::TryFromIntError),
NotMultipleOf8,
OutOfRange(RangeInclusive<u16>),
FixedX,
}
impl fmt::Display for ParseTySizeError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Parse(e) => e.fmt(f),
Self::TryFrom(e) => e.fmt(f),
Self::NotMultipleOf8 => f.write_str("number must be a multiple of 8"),
Self::OutOfRange(range) => {
write!(f, "size is out of range of {}:{} (inclusive)", range.start(), range.end())
}
Self::FixedX => f.write_str("`fixed` sizes must be separated by exactly one 'x'"),
}
}
}
#[allow(dead_code)]
fn parse_fixed_type(original: &str) -> Result<Option<ElementaryType>, ParseTySizeError> {
let s = original;
let tmp = s.strip_prefix('u');
let unsigned = tmp.is_some();
let s = tmp.unwrap_or(s);
if let Some(s) = s.strip_prefix("fixed") {
debug_assert!(!s.is_empty());
let (m, n) = parse_fixed_size(s)?;
return Ok(Some(if unsigned {
ElementaryType::UFixed(m, n)
} else {
ElementaryType::Fixed(m, n)
}));
}
Ok(None)
}
#[allow(dead_code)]
fn parse_fb_size(s: &str) -> Result<TypeSize, ParseTySizeError> {
parse_ty_size_u8(s, 1..=32, false).map(|x| TypeSize::new(x).unwrap())
}
#[allow(dead_code)]
fn parse_int_size(s: &str) -> Result<TypeSize, ParseTySizeError> {
parse_ty_size_u8(s, 1..=32, true).map(|x| TypeSize::new(x).unwrap())
}
#[allow(dead_code)]
fn parse_fixed_size(s: &str) -> Result<(TypeSize, TypeFixedSize), ParseTySizeError> {
let (m, n) = s.split_once('x').ok_or(ParseTySizeError::FixedX)?;
let m = parse_int_size(m)?;
let n = parse_ty_size_u8(n, 0..=80, false)?;
let n = TypeFixedSize::new(n).unwrap();
Ok((m, n))
}
fn parse_ty_size_u8(
s: &str,
real_range: RangeInclusive<u8>,
to_bytes: bool,
) -> Result<u8, ParseTySizeError> {
let mut n = s.parse::<u16>().map_err(ParseTySizeError::Parse)?;
if to_bytes {
if !n.is_multiple_of(8) {
return Err(ParseTySizeError::NotMultipleOf8);
}
n /= 8;
}
let n = u8::try_from(n).map_err(ParseTySizeError::TryFrom)?;
if !real_range.contains(&n) {
let display_range = if to_bytes {
*real_range.start() as u16 * 8..=*real_range.end() as u16 * 8
} else {
*real_range.start() as u16..=*real_range.end() as u16
};
return Err(ParseTySizeError::OutOfRange(display_range));
}
Ok(n)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_size() {
use ParseTySizeError::*;
assert_eq!(parse_ty_size_u8("0", 0..=1, false), Ok(0));
assert_eq!(parse_ty_size_u8("1", 0..=1, false), Ok(1));
assert_eq!(parse_ty_size_u8("0", 0..=1, true), Ok(0));
assert_eq!(parse_ty_size_u8("1", 0..=1, true), Err(NotMultipleOf8));
assert_eq!(parse_ty_size_u8("8", 0..=1, true), Ok(1));
assert_eq!(parse_ty_size_u8("0", 1..=32, false), Err(OutOfRange(1..=32)));
assert_eq!(parse_ty_size_u8("0", 1..=32, true), Err(OutOfRange(8..=256)));
for n in 1..=32 {
assert_eq!(parse_ty_size_u8(&n.to_string(), 1..=32, false), Ok(n as u8));
for m in 1..=7u16 {
assert_eq!(
parse_ty_size_u8(&((n - 1) * 8 + m).to_string(), 1..=32, true),
Err(NotMultipleOf8)
);
}
assert_eq!(parse_ty_size_u8(&(n * 8).to_string(), 1..=32, true), Ok(n as u8));
}
}
}