use std::{
collections::{HashMap, HashSet},
fmt::Display,
path::PathBuf,
};
use itertools::Itertools;
use path_slash::PathBufExt as _;
use serde::{
de::{self, Visitor},
Deserialize, Deserializer,
};
/// Stateless serde seed and visitor that deserializes any self-describing value (as produced
/// from Lua) into a [`serde_value::Value`] tree.
pub(crate) struct LuaValueSeed;
impl<'de> Visitor<'de> for LuaValueSeed {
type Value = serde_value::Value;
fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.write_str("any Lua value")
}
fn visit_bool<E: de::Error>(self, v: bool) -> Result<Self::Value, E> {
Ok(serde_value::Value::Bool(v))
}
fn visit_i64<E: de::Error>(self, v: i64) -> Result<Self::Value, E> {
Ok(serde_value::Value::I64(v))
}
fn visit_u64<E: de::Error>(self, v: u64) -> Result<Self::Value, E> {
Ok(serde_value::Value::U64(v))
}
fn visit_f64<E: de::Error>(self, v: f64) -> Result<Self::Value, E> {
Ok(serde_value::Value::F64(v))
}
fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
Ok(serde_value::Value::String(v.to_string()))
}
fn visit_string<E: de::Error>(self, v: String) -> Result<Self::Value, E> {
Ok(serde_value::Value::String(v))
}
fn visit_bytes<E: de::Error>(self, v: &[u8]) -> Result<Self::Value, E> {
let s = std::str::from_utf8(v).map_err(de::Error::custom)?;
Ok(serde_value::Value::String(s.to_string()))
}
fn visit_byte_buf<E: de::Error>(self, v: Vec<u8>) -> Result<Self::Value, E> {
self.visit_bytes(&v)
}
fn visit_none<E: de::Error>(self) -> Result<Self::Value, E> {
Ok(serde_value::Value::Unit)
}
fn visit_some<D2: Deserializer<'de>>(self, d: D2) -> Result<Self::Value, D2::Error> {
d.deserialize_any(LuaValueSeed)
}
fn visit_unit<E: de::Error>(self) -> Result<Self::Value, E> {
Ok(serde_value::Value::Unit)
}
fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
let mut arr = Vec::new();
while let Some(v) = seq.next_element_seed(LuaValueSeed)? {
arr.push(v);
}
Ok(serde_value::Value::Seq(arr))
}
fn visit_map<A: de::MapAccess<'de>>(self, mut map: A) -> Result<Self::Value, A::Error> {
let mut obj = std::collections::BTreeMap::new();
while let Some(key) = map.next_key_seed(LuaValueSeed)? {
let val = map.next_value_seed(LuaValueSeed)?;
obj.insert(key, val);
}
Ok(serde_value::Value::Map(obj))
}
}
impl<'de> de::DeserializeSeed<'de> for LuaValueSeed {
type Value = serde_value::Value;
fn deserialize<D: Deserializer<'de>>(self, d: D) -> Result<Self::Value, D::Error> {
d.deserialize_any(self)
}
}
pub(crate) fn normalize_lua_value(value: serde_value::Value) -> serde_value::Value {
match value {
serde_value::Value::Bytes(bytes) => match String::from_utf8(bytes.clone()) {
Ok(s) => serde_value::Value::String(s),
Err(_) => serde_value::Value::Bytes(bytes),
},
serde_value::Value::Map(map)
if map
.keys()
.all(|k| matches!(k, serde_value::Value::I64(_) | serde_value::Value::U64(_))) =>
{
let seq = map
.iter()
.sorted_by_key(|(k, _)| match k {
serde_value::Value::I64(i) => *i,
serde_value::Value::U64(u) => *u as i64,
_ => unreachable!(),
})
.map(|(_, v)| normalize_lua_value(v.clone()))
.collect();
serde_value::Value::Seq(seq)
}
serde_value::Value::Map(map) => serde_value::Value::Map(
map.into_iter()
.map(|(k, v)| (normalize_lua_value(k), normalize_lua_value(v)))
.collect(),
),
serde_value::Value::Seq(seq) => {
serde_value::Value::Seq(seq.into_iter().map(normalize_lua_value).collect())
}
other => other,
}
}
/// A Lua table key: an unsigned integer index or a string.
///
/// With `#[serde(untagged)]`, serde tries the variants in declaration order. Input that
/// deserializes as a `u64` becomes `IntKey`; otherwise it is tried as a `String` (`StringKey`).
#[derive(Hash, Debug, Eq, PartialEq, Clone, Deserialize)]
#[serde(untagged)]
pub(crate) enum LuaTableKey {
    /// Integer key, e.g. an array index.
    IntKey(u64),
    /// String key.
    StringKey(String),
}
/// Deserializes either a single string or a Lua array into a `Vec<T>`.
///
/// The raw value is normalized once with [`normalize_lua_value`], so integer-keyed Lua
/// tables arrive as sequences. A lone string is wrapped in a one-element vector via
/// `T::from`. Any other value is deserialized as `Vec<T>`.
///
/// # Errors
///
/// Returns a custom deserialization error when the input is neither a string nor something
/// that deserializes into `Vec<T>`. The error includes the offending (normalized) value.
pub(crate) fn deserialize_vec_from_lua_array_or_string<'de, D, T>(
    deserializer: D,
) -> std::result::Result<Vec<T>, D::Error>
where
    D: Deserializer<'de>,
    T: From<String>,
    T: Deserialize<'de>,
{
    match normalize_lua_value(serde_value::Value::deserialize(deserializer)?) {
        serde_value::Value::String(text) => Ok(vec![T::from(text)]),
        // `value` is already normalized; the clone is kept only for the error message.
        value => value.clone().deserialize_into().map_err(|err| {
            de::Error::custom(format!(
                "expected a string or a list of strings, but got: {value:?} ({err})"
            ))
        }),
    }
}
/// Facts about a string that decide how it is written as a Lua string literal.
#[derive(Debug)]
struct StringAnalysis {
    /// Smallest `=` count `k` such that the long-bracket closer `]` + `=`×k + `]` cannot
    /// collide with the string's contents.
    long_string_equal_signs: usize,
    /// Whether the string contains any `\n` or `\r` byte.
    has_newline: bool,
    /// Whether the string contains a byte outside printable ASCII / `\t` (other than line
    /// breaks), or a run of line-break bytes made only of `\r`.
    has_nonprintable: bool,
}

impl From<&str> for StringAnalysis {
    /// Scans `value` once, byte by byte, to fill in every [`StringAnalysis`] field.
    fn from(value: &str) -> Self {
        let bytes = value.as_bytes();
        // Long-bracket levels that the contents would prematurely close.
        let mut taken_levels = HashSet::new();
        let mut has_newline = false;
        let mut has_nonprintable = false;
        let mut pos = 0;

        while let Some(&byte) = bytes.get(pos) {
            match byte {
                b']' => {
                    let rest = &bytes[pos + 1..];
                    let equals = rest.iter().take_while(|&&b| b == b'=').count();
                    match rest.get(equals) {
                        // A trailing `]` + `=`×k would merge with a level-k closing bracket.
                        None => {
                            taken_levels.insert(equals);
                            break;
                        }
                        // `]` + `=`×k + `]` is literally the level-k closer. The second `]`
                        // is revisited on the next iteration as a possible new start.
                        Some(&b']') => {
                            taken_levels.insert(equals);
                        }
                        Some(_) => {}
                    }
                    pos += 1 + equals;
                }
                b'\n' | b'\r' => {
                    has_newline = true;
                    let (run_len, saw_lf) = bytes[pos..]
                        .iter()
                        .take_while(|&&b| b == b'\n' || b == b'\r')
                        .fold((0, false), |(len, lf), &b| (len + 1, lf || b == b'\n'));
                    // A CR-only run is flagged so the caller picks the quoted form; presumably
                    // because Lua rewrites a bare `\r` inside long strings to `\n`.
                    if !saw_lf {
                        has_nonprintable = true;
                    }
                    pos += run_len;
                }
                b'\t' | b' '..=b'~' => pos += 1,
                _ => {
                    has_nonprintable = true;
                    pos += 1;
                }
            }
        }

        // First level not claimed by the contents.
        let mut long_string_equal_signs = 0;
        while taken_levels.contains(&long_string_equal_signs) {
            long_string_equal_signs += 1;
        }

        Self {
            long_string_equal_signs,
            has_newline,
            has_nonprintable,
        }
    }
}
/// A value that renders as Lua source through its `Display` implementation.
pub(crate) enum DisplayLuaValue {
    /// A Lua boolean literal.
    Boolean(bool),
    /// A Lua string literal; the `Display` impl chooses quoted or long-bracket form.
    String(String),
    /// A list-style table constructor, one element per line.
    List(Vec<Self>),
    /// A record-style table constructor of `key = value` entries, one per line.
    Table(Vec<DisplayLuaKV>),
}
/// A `key = value` entry rendered as Lua source through its `Display` implementation.
pub(crate) struct DisplayLuaKV {
    /// The entry's key; written bare when usable as a plain name, otherwise in `['...']` form.
    pub(crate) key: String,
    /// The value written after ` = `.
    pub(crate) value: DisplayLuaValue,
}
impl Display for DisplayLuaValue {
    /// Renders the value as a Lua expression.
    ///
    /// - Booleans print as `true` / `false`.
    /// - A string containing a line break and only printable ASCII / tab bytes is written as
    ///   a long-bracket string (`[==[` + newline + text + `]==]`) with the smallest safe level.
    ///   The leading newline is skipped by the Lua lexer.
    /// - Any other string is double-quoted, using C-style escapes where Lua has them and
    ///   zero-padded `\ddd` decimal escapes for the remaining non-printable bytes.
    /// - Lists and tables are written as `{`, one `item,` per line, then `}`.
    ///
    /// The result is passed through StyLua; if formatting fails, the raw text is written.
    /// NOTE(review): the rendered text is a bare expression rather than a full Lua chunk, and
    /// the tests expect unformatted output, so the StyLua pass appears to always fall back.
    /// Confirm whether it is still wanted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use std::fmt::Write;

        /// Writes `items` as a braced, comma-terminated, one-per-line Lua table body.
        fn write_braced<T: Display>(out: &mut String, items: &[T]) -> std::fmt::Result {
            out.push_str("{\n");
            for item in items {
                writeln!(out, "{item},")?;
            }
            out.push('}');
            Ok(())
        }

        let mut out = String::new();
        match self {
            Self::Boolean(flag) => write!(out, "{flag}")?,
            Self::String(text) => {
                let analysis = StringAnalysis::from(text.as_str());
                if analysis.has_newline && !analysis.has_nonprintable {
                    let level = "=".repeat(analysis.long_string_equal_signs);
                    write!(out, "[{level}[\n{text}]{level}]")?;
                } else {
                    out.push('"');
                    for byte in text.bytes() {
                        let escaped = match byte {
                            b'"' => "\\\"",
                            b'\x07' => "\\a",
                            b'\x08' => "\\b",
                            b'\x0B' => "\\v",
                            b'\x0C' => "\\f",
                            b'\n' => "\\n",
                            b'\r' => "\\r",
                            b'\t' => "\\t",
                            b'\\' => "\\\\",
                            b' '..=b'~' => {
                                out.push(char::from(byte));
                                continue;
                            }
                            other => {
                                write!(out, "\\{other:03}")?;
                                continue;
                            }
                        };
                        out.push_str(escaped);
                    }
                    out.push('"');
                }
            }
            Self::List(items) => write_braced(&mut out, items)?,
            Self::Table(entries) => write_braced(&mut out, entries)?,
        }

        let rendered = stylua_lib::format_code(
            &out,
            stylua_lib::Config::default(),
            None,
            stylua_lib::OutputVerification::Full,
        )
        .unwrap_or(out);
        f.write_str(&rendered)
    }
}
impl Display for DisplayLuaKV {
    /// Renders the entry as `key = value`.
    ///
    /// The key is written bare only when it is a valid Lua name. That means it is non-empty,
    /// made of ASCII letters, digits and `_`, does not start with a digit, and is not a
    /// reserved word. Every other key uses the bracketed form `['key'] = value`. In that form
    /// the key is escaped as a single-quoted Lua string, so quotes, backslashes and control
    /// characters cannot break the literal.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if is_lua_name(&self.key) {
            write!(f, "{} = {}", self.key, self.value)
        } else {
            f.write_str("['")?;
            write_lua_single_quoted_contents(f, &self.key)?;
            write!(f, "'] = {}", self.value)
        }
    }
}

/// Lua reserved words; none of them may be used as a bare table key.
const LUA_RESERVED_WORDS: &[&str] = &[
    "and", "break", "do", "else", "elseif", "end", "false", "for", "function", "goto", "if",
    "in", "local", "nil", "not", "or", "repeat", "return", "then", "true", "until", "while",
];

/// Returns whether `key` is a valid Lua name that can be written without brackets.
fn is_lua_name(key: &str) -> bool {
    let mut chars = key.chars();
    let valid_start = chars
        .next()
        .is_some_and(|c| c == '_' || c.is_ascii_alphabetic());
    valid_start
        && chars.all(|c| c == '_' || c.is_ascii_alphanumeric())
        && !LUA_RESERVED_WORDS.contains(&key)
}

/// Writes `text` escaped for the inside of a single-quoted Lua string literal.
///
/// Quotes, backslashes and ASCII control characters are escaped. All other characters,
/// including non-ASCII, are written verbatim.
fn write_lua_single_quoted_contents(
    f: &mut std::fmt::Formatter<'_>,
    text: &str,
) -> std::fmt::Result {
    use std::fmt::Write as _;
    for c in text.chars() {
        match c {
            '\'' => f.write_str("\\'")?,
            '\\' => f.write_str("\\\\")?,
            '\n' => f.write_str("\\n")?,
            '\r' => f.write_str("\\r")?,
            '\t' => f.write_str("\\t")?,
            // Zero-pad to three digits so a following digit is not absorbed into the escape.
            c if c.is_ascii_control() => write!(f, "\\{:03}", u32::from(c))?,
            c => f.write_char(c)?,
        }
    }
    Ok(())
}
/// Types that can be rendered as a single Lua `key = value` entry.
pub(crate) trait DisplayAsLuaKV {
    /// Builds the [`DisplayLuaKV`] describing `self`.
    fn display_lua(&self) -> DisplayLuaKV;
}
/// Types that can be rendered as a Lua value.
pub(crate) trait DisplayAsLuaValue {
    /// Builds the [`DisplayLuaValue`] describing `self`.
    fn display_lua_value(&self) -> DisplayLuaValue;
}
impl DisplayAsLuaValue for String {
    /// Wraps a copy of the string as a Lua string value.
    fn display_lua_value(&self) -> DisplayLuaValue {
        let text = self.to_owned();
        DisplayLuaValue::String(text)
    }
}
impl DisplayAsLuaValue for bool {
    /// Wraps the flag as a Lua boolean literal.
    fn display_lua_value(&self) -> DisplayLuaValue {
        let &flag = self;
        DisplayLuaValue::Boolean(flag)
    }
}
impl DisplayAsLuaValue for PathBuf {
    /// Renders the path as a Lua string with `/` separators, via `to_slash_lossy`.
    fn display_lua_value(&self) -> DisplayLuaValue {
        let slash_path = self.to_slash_lossy();
        DisplayLuaValue::String(slash_path.into_owned())
    }
}
impl DisplayAsLuaValue for Vec<String> {
    /// Renders the strings as a Lua list of string values, preserving their order.
    fn display_lua_value(&self) -> DisplayLuaValue {
        let items = self
            .iter()
            .map(|s| DisplayLuaValue::String(s.clone()))
            .collect::<Vec<_>>();
        DisplayLuaValue::List(items)
    }
}
impl DisplayAsLuaValue for Vec<PathBuf> {
    /// Renders the paths as a Lua list of `/`-separated strings, preserving their order.
    fn display_lua_value(&self) -> DisplayLuaValue {
        // Reuse the single-path rendering so both stay consistent.
        let items = self
            .iter()
            .map(PathBuf::display_lua_value)
            .collect::<Vec<_>>();
        DisplayLuaValue::List(items)
    }
}
impl DisplayAsLuaValue for HashMap<String, String> {
    /// Renders the map as a Lua table of `key = "value"` entries.
    ///
    /// Entries are sorted by key. `HashMap` iteration order is unspecified and varies
    /// between runs, which would otherwise make the rendered Lua non-deterministic.
    fn display_lua_value(&self) -> DisplayLuaValue {
        DisplayLuaValue::Table(
            self.iter()
                .sorted_by(|(a, _), (b, _)| a.cmp(b))
                .map(|(k, v)| DisplayLuaKV {
                    key: k.clone(),
                    value: DisplayLuaValue::String(v.clone()),
                })
                .collect_vec(),
        )
    }
}
impl DisplayAsLuaValue for HashMap<String, PathBuf> {
    /// Renders the map as a Lua table of `key = "path"` entries. Paths are converted to
    /// `/`-separated strings via `to_slash_lossy`.
    ///
    /// Entries are sorted by key. `HashMap` iteration order is unspecified and varies
    /// between runs, which would otherwise make the rendered Lua non-deterministic.
    fn display_lua_value(&self) -> DisplayLuaValue {
        DisplayLuaValue::Table(
            self.iter()
                .sorted_by(|(a, _), (b, _)| a.cmp(b))
                .map(|(k, v)| DisplayLuaKV {
                    key: k.clone(),
                    value: DisplayLuaValue::String(v.to_slash_lossy().into_owned()),
                })
                .collect_vec(),
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders `input` through `DisplayLuaValue::String`.
    fn render_string(input: &str) -> String {
        DisplayLuaValue::String(input.to_string()).to_string()
    }

    #[test]
    fn display_lua_value() {
        // (input, expected rendering). A string with a line break and only printable ASCII
        // uses long brackets at the smallest safe level; everything else is quoted.
        let string_cases: &[(&str, &str)] = &[
            ("hello", r#""hello""#),
            ("he\"llo", r#""he\"llo""#),
            ("q\"a'", r#""q\"a'""#),
            (
                "1\"2\x073\x084\x0B5\x0C6\n7\r8\t9'a\\b\u{FFFFF}c\0d\u{1}e",
                r#""1\"2\a3\b4\v5\f6\n7\r8\t9'a\\b\243\191\191\191c\000d\001e""#,
            ),
            ("\n", "[[\n\n]]"),
            ("\n]", "[=[\n\n]]=]"),
            ("first line\nsecond line", "[[\nfirst line\nsecond line]]"),
            ("first line\nsecond line]", "[=[\nfirst line\nsecond line]]=]"),
            ("first line\nsecond line]=]", "[==[\nfirst line\nsecond line]=]]==]"),
            (
                "first line\nsecond line]\nthird line",
                "[[\nfirst line\nsecond line]\nthird line]]",
            ),
            (
                "first line\nsecond line]]\nthird line",
                "[=[\nfirst line\nsecond line]]\nthird line]=]",
            ),
            (
                "first line\nsecond line]=]\nthird line",
                "[[\nfirst line\nsecond line]=]\nthird line]]",
            ),
            (
                "first line\nsecond line]=]]\nthird line",
                "[==[\nfirst line\nsecond line]=]]\nthird line]==]",
            ),
            (
                "first line\nsecond line]]=]\nthird line",
                "[==[\nfirst line\nsecond line]]=]\nthird line]==]",
            ),
            ("\tfirst line\n\tsecond line", "[[\n\tfirst line\n\tsecond line]]"),
            ("\tfirst line\r\n\tsecond line", "[[\n\tfirst line\r\n\tsecond line]]"),
            ("\tfirst line\r\tsecond line", r#""\tfirst line\r\tsecond line""#),
        ];
        for &(input, expected) in string_cases {
            assert_eq!(render_string(input), expected, "rendering {input:?}");
        }

        assert_eq!(DisplayLuaValue::Boolean(true).to_string(), "true");

        let list = DisplayLuaValue::List(vec![
            DisplayLuaValue::String("hello".to_string()),
            DisplayLuaValue::Boolean(true),
        ]);
        assert_eq!(list.to_string(), "{\n\"hello\",\ntrue,\n}");

        let table = DisplayLuaValue::Table(vec![
            DisplayLuaKV {
                key: "key".to_string(),
                value: DisplayLuaValue::String("value".to_string()),
            },
            DisplayLuaKV {
                key: "key2".to_string(),
                value: DisplayLuaValue::Boolean(true),
            },
            DisplayLuaKV {
                key: "key3.key4".to_string(),
                value: DisplayLuaValue::Boolean(true),
            },
        ]);
        assert_eq!(
            table.to_string(),
            "{\nkey = \"value\",\nkey2 = true,\n['key3.key4'] = true,\n}"
        );
    }
}