use crate::{LuaResult, LuaValue, lua_vm::LuaState};
/// Implements Lua's `table.sort(list [, comp])`.
///
/// The first `len` elements of the table (as reported by `obj_len`) are copied
/// into a temporary buffer, the buffer is sorted, and the sorted values are
/// written back to indices `1..=len`. Tables with a metatable are read and
/// written through `table_geti` / `table_seti`; plain tables use the raw
/// accessors directly.
///
/// # Errors
///
/// Fails when argument 1 is missing or not a table, when the reported length
/// is `i32::MAX` or more ("array too big"), or when a table access or the
/// sort itself reports an error.
///
/// Always returns `Ok(0)` on success (no Lua return values).
pub fn table_sort(l: &mut LuaState) -> LuaResult<usize> {
    let table_val = match l.get_arg(1) {
        Some(value) => value,
        None => return Err(crate::stdlib::debug::argerror(l, 1, "table expected")),
    };
    let comp = l.get_arg(2);
    if !table_val.is_table() {
        return Err(crate::stdlib::debug::arg_typeerror(
            l, 1, "table", &table_val,
        ));
    }

    let len = l.obj_len(&table_val)?;
    if len >= i32::MAX as i64 {
        return Err(l.error("bad argument #1 to 'sort' (array too big)".to_string()));
    }
    // Zero or one element is already sorted.
    if len <= 1 {
        return Ok(0);
    }

    // A missing or nil second argument selects the default `<` ordering.
    let (comp_func, has_comp) = match comp {
        Some(func) if !func.is_nil() => (func, true),
        _ => (LuaValue::nil(), false),
    };

    let has_meta = table_val
        .as_table_mut()
        .map_or(false, |t| t.has_metatable());

    // Snapshot the array part into a buffer so the sort works on a plain slice.
    let mut buf: Vec<LuaValue> = Vec::with_capacity(len as usize);
    if has_meta {
        for idx in 1..=len {
            buf.push(l.table_geti(&table_val, idx)?);
        }
    } else {
        let table = table_val.as_table_mut().unwrap();
        buf.extend((1..=len).map(|idx| table.raw_geti(idx).unwrap_or(LuaValue::nil())));
    }

    // `nny` is raised for the duration of the sort and restored before the
    // result is inspected, so an error cannot leave it incremented.
    // NOTE(review): presumably this marks the state as non-yieldable — confirm.
    l.nny += 1;
    let sorted = sort_buffer(l, &mut buf, &comp_func, has_comp);
    l.nny -= 1;
    sorted?;

    // Write the sorted values back to indices 1..=len.
    if has_meta {
        for (slot, val) in (1i64..).zip(buf) {
            l.table_seti(&table_val, slot, val)?;
        }
    } else {
        let table = table_val.as_table_mut().unwrap();
        for (slot, val) in (1i64..).zip(buf) {
            table.raw_seti(slot, val);
        }
    }

    if let Some(gc_ptr) = table_val.as_gc_ptr() {
        l.gc_barrier_back(gc_ptr);
    }
    Ok(0)
}
/// Sorts `buf` in place.
///
/// Without a user comparator (`has_comp == false`) homogeneous integer, float
/// and string buffers, as well as mixed integer/float buffers, are sorted
/// directly with the standard library. Everything else (including any buffer
/// sorted with a user comparator) goes through [`introsort`], which evaluates
/// the ordering via [`sort_compare`].
///
/// # Errors
///
/// Propagates any error raised by the comparison or by `introsort`; the
/// fast paths themselves never fail.
fn sort_buffer(
    l: &mut LuaState,
    buf: &mut [LuaValue],
    comp_func: &LuaValue,
    has_comp: bool,
) -> LuaResult<()> {
    let n = buf.len();
    if n <= 1 {
        return Ok(());
    }

    if !has_comp {
        // `sort_unstable_by` requires a total order and may panic (or leave an
        // unspecified order) when given an inconsistent one. Mapping the
        // "unordered" NaN case to `Equal` is inconsistent as soon as a NaN sits
        // next to two different numbers, so NaNs are ranked after every number
        // (and equal to each other) instead. Comparisons between ordinary
        // numbers are unchanged.
        let cmp_f64 = |x: f64, y: f64| {
            x.partial_cmp(&y)
                .unwrap_or_else(|| x.is_nan().cmp(&y.is_nan()))
        };

        let first_tt = buf[0].tt();
        if buf.iter().all(|v| v.tt() == first_tt) {
            if buf[0].is_integer() {
                buf.sort_unstable_by(|a, b| {
                    // SAFETY: every element has the same type tag as `buf[0]`,
                    // which is an integer, so `value.i` is presumed to be the
                    // active payload — TODO confirm against `LuaValue`.
                    let ia = unsafe { a.value.i };
                    let ib = unsafe { b.value.i };
                    ia.cmp(&ib)
                });
                return Ok(());
            }
            if buf[0].is_float() {
                buf.sort_unstable_by(|a, b| {
                    // SAFETY: every element has the same type tag as `buf[0]`,
                    // which is a float, so `value.n` is presumed to be the
                    // active payload — TODO confirm against `LuaValue`.
                    let fa = unsafe { a.value.n };
                    let fb = unsafe { b.value.n };
                    cmp_f64(fa, fb)
                });
                return Ok(());
            }
            if buf[0].is_string() {
                // Plain byte-wise comparison of the string contents.
                buf.sort_unstable_by(|a, b| {
                    let sa = a.as_bytes().unwrap_or(&[]);
                    let sb = b.as_bytes().unwrap_or(&[]);
                    sa.cmp(sb)
                });
                return Ok(());
            }
        }

        // Mixed integers and floats: compare through `as_number`.
        // NOTE(review): if `as_number` converts integers to f64, integers
        // beyond 2^53 may compare equal to neighbouring values — confirm
        // whether exact mixed comparison is needed here.
        if buf.iter().all(|v| v.is_integer() || v.is_float()) {
            buf.sort_unstable_by(|a, b| {
                let na = a.as_number().unwrap_or(0.0);
                let nb = b.as_number().unwrap_or(0.0);
                cmp_f64(na, nb)
            });
            return Ok(());
        }
    }

    // Recursion budget for introsort: twice the bit length of `n`.
    let max_depth = (usize::BITS - n.leading_zeros()) as usize * 2;
    introsort(l, buf, 0, n - 1, max_depth, comp_func, has_comp)
}
/// Evaluates the sort ordering for the pair (`a`, `b`).
///
/// With `has_comp` set the result comes from calling `comp_func` through
/// `LuaState::call_compare`; otherwise it is `LuaState::obj_lt(a, b)`.
/// Errors from either call are propagated.
#[inline]
fn sort_compare(
    l: &mut LuaState,
    a: LuaValue,
    b: LuaValue,
    comp_func: &LuaValue,
    has_comp: bool,
) -> LuaResult<bool> {
    if !has_comp {
        return l.obj_lt(&a, &b);
    }
    l.call_compare(*comp_func, a, b)
}
/// Orders the three slots `lo`, `mid` and `hi` of `buf` with at most three
/// comparisons, so that afterwards `buf[lo]`, `buf[mid]`, `buf[hi]` are in
/// non-descending order according to [`sort_compare`].
///
/// Errors raised by a comparison are propagated.
#[inline]
fn sort3(
    l: &mut LuaState,
    buf: &mut [LuaValue],
    lo: usize,
    mid: usize,
    hi: usize,
    comp_func: &LuaValue,
    has_comp: bool,
) -> LuaResult<()> {
    // Step 1: make sure buf[lo] does not come after buf[mid].
    let mid_before_lo = sort_compare(l, buf[mid], buf[lo], comp_func, has_comp)?;
    if mid_before_lo {
        buf.swap(lo, mid);
    }

    // Step 2: if buf[hi] is already in place there is nothing left to do.
    let hi_before_mid = sort_compare(l, buf[hi], buf[mid], comp_func, has_comp)?;
    if !hi_before_mid {
        return Ok(());
    }
    buf.swap(mid, hi);

    // Step 3: the value moved into `mid` may still belong before `lo`.
    let mid_before_lo = sort_compare(l, buf[mid], buf[lo], comp_func, has_comp)?;
    if mid_before_lo {
        buf.swap(lo, mid);
    }
    Ok(())
}
/// Partitions `buf[lo..=hi]` around a pivot and returns the pivot's final index.
///
/// The pivot is chosen by ordering `buf[lo]`, the middle element and
/// `buf[hi]` with [`sort3`] and taking the middle one. Ranges with three or
/// fewer elements are fully handled by that step. For larger ranges the pivot
/// is parked at `hi - 1` while two cursors scan towards each other, swapping
/// out-of-place pairs.
///
/// # Errors
///
/// Returns "invalid order function for sorting" when a cursor runs past its
/// bound, which only happens with an inconsistent comparison. Comparison
/// errors are propagated.
fn partition(
    l: &mut LuaState,
    buf: &mut [LuaValue],
    lo: usize,
    hi: usize,
    comp_func: &LuaValue,
    has_comp: bool,
) -> LuaResult<usize> {
    let mid = lo + (hi - lo) / 2;
    sort3(l, buf, lo, mid, hi, comp_func, has_comp)?;
    if hi - lo <= 2 {
        return Ok(mid);
    }

    // Park the pivot just before `hi`; buf[lo] and buf[hi] are already on the
    // correct side of it thanks to `sort3`.
    let pivot_slot = hi - 1;
    let pivot = buf[mid];
    buf.swap(mid, pivot_slot);

    let mut left = lo;
    let mut right = pivot_slot;
    loop {
        // Advance `left` while elements are strictly before the pivot.
        left += 1;
        while sort_compare(l, buf[left], pivot, comp_func, has_comp)? {
            if left == pivot_slot {
                return Err(l.error("invalid order function for sorting".to_string()));
            }
            left += 1;
        }

        // Retreat `right` while the pivot is strictly before the element.
        right -= 1;
        while sort_compare(l, pivot, buf[right], comp_func, has_comp)? {
            if right < left {
                return Err(l.error("invalid order function for sorting".to_string()));
            }
            right -= 1;
        }

        // Cursors crossed: the range is partitioned.
        if right < left {
            break;
        }
        buf.swap(left, right);
    }

    // Move the pivot from its parking slot into its final position.
    buf.swap(left, pivot_slot);
    Ok(left)
}
/// Sorts `buf[lo..=hi]` in place with heapsort.
///
/// First builds a max-heap over the range using [`sift_down`], then repeatedly
/// swaps the heap root to the end of the shrinking heap and restores the heap.
/// Comparison errors are propagated.
fn heapsort(
    l: &mut LuaState,
    buf: &mut [LuaValue],
    lo: usize,
    hi: usize,
    comp_func: &LuaValue,
    has_comp: bool,
) -> LuaResult<()> {
    let count = hi - lo + 1;
    if count <= 1 {
        return Ok(());
    }

    // Heapify: sift every internal node, last parent first.
    let mut root = count / 2;
    while root > 0 {
        root -= 1;
        sift_down(l, buf, lo, root, count, comp_func, has_comp)?;
    }

    // Extract the maximum into the tail, shrinking the heap by one each time.
    let mut heap_len = count;
    while heap_len > 1 {
        heap_len -= 1;
        buf.swap(lo, lo + heap_len);
        sift_down(l, buf, lo, 0, heap_len, comp_func, has_comp)?;
    }
    Ok(())
}
/// Restores the max-heap property below heap position `pos`.
///
/// The heap occupies `buf[lo..lo + n]`; `pos` and the child positions are
/// relative to `lo`. The element at `pos` is swapped with its larger child
/// until neither child compares greater or a leaf is reached. Comparison
/// errors are propagated.
fn sift_down(
    l: &mut LuaState,
    buf: &mut [LuaValue],
    lo: usize,
    mut pos: usize,
    n: usize,
    comp_func: &LuaValue,
    has_comp: bool,
) -> LuaResult<()> {
    let mut left = 2 * pos + 1;
    while left < n {
        let right = left + 1;

        // Pick the larger of the parent and its left child ...
        let top = if sort_compare(l, buf[lo + pos], buf[lo + left], comp_func, has_comp)? {
            left
        } else {
            pos
        };
        // ... then compare the winner against the right child, if there is one.
        let top = if right < n
            && sort_compare(l, buf[lo + top], buf[lo + right], comp_func, has_comp)?
        {
            right
        } else {
            top
        };

        if top == pos {
            return Ok(());
        }
        buf.swap(lo + pos, lo + top);
        pos = top;
        left = 2 * pos + 1;
    }
    Ok(())
}
/// Sorts `buf[lo..=hi]` in place with an introsort.
///
/// Ranges of two or three elements are ordered directly. Larger ranges are
/// split with [`partition`]; the smaller side is sorted recursively and the
/// larger side is handled by the next loop iteration. Each partitioning step
/// consumes one unit of `depth_limit`; once it reaches zero the remaining
/// range is handed to [`heapsort`]. Comparison and partition errors are
/// propagated.
fn introsort(
    l: &mut LuaState,
    buf: &mut [LuaValue],
    mut lo: usize,
    mut hi: usize,
    mut depth_limit: usize,
    comp_func: &LuaValue,
    has_comp: bool,
) -> LuaResult<()> {
    while lo < hi {
        match hi - lo + 1 {
            2 => {
                if sort_compare(l, buf[hi], buf[lo], comp_func, has_comp)? {
                    buf.swap(lo, hi);
                }
                return Ok(());
            }
            3 => {
                sort3(l, buf, lo, lo + 1, hi, comp_func, has_comp)?;
                return Ok(());
            }
            _ => {}
        }

        // Recursion budget exhausted: finish this range with heapsort.
        if depth_limit == 0 {
            return heapsort(l, buf, lo, hi, comp_func, has_comp);
        }
        depth_limit -= 1;

        let p = partition(l, buf, lo, hi, comp_func, has_comp)?;
        let left_len = p.saturating_sub(lo);
        let right_len = hi.saturating_sub(p);

        if left_len < right_len {
            // Left side is smaller: recurse into it, keep looping on the right.
            if p > lo {
                introsort(l, buf, lo, p - 1, depth_limit, comp_func, has_comp)?;
            }
            lo = p + 1;
            continue;
        }

        // Right side is smaller (or equal): recurse into it, keep looping on the left.
        if p < hi {
            introsort(l, buf, p + 1, hi, depth_limit, comp_func, has_comp)?;
        }
        // Guard against underflow when computing the new upper bound.
        if p == 0 {
            break;
        }
        hi = p - 1;
    }
    Ok(())
}