use thiserror::Error;
/// Sort order for a single key.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Direction {
    /// Smallest values first; rendered as `ASC`.
    Asc,
    /// Largest values first; rendered as `DESC`.
    Desc,
}

impl Direction {
    /// Returns the SQL keyword for this direction (`"ASC"` or `"DESC"`).
    pub fn sql(self) -> &'static str {
        match self {
            Self::Desc => "DESC",
            Self::Asc => "ASC",
        }
    }

    /// Orients an ascending comparison result according to this direction.
    ///
    /// `Asc` returns `ord` unchanged and `Desc` reverses it. `Equal` stays
    /// `Equal` in both cases.
    pub fn apply(self, ord: std::cmp::Ordering) -> std::cmp::Ordering {
        match self {
            Self::Desc => ord.reverse(),
            Self::Asc => ord,
        }
    }
}
/// Sorts `rows` in place by an ordered list of keys, falling back to `tiebreak`.
///
/// Keys are tried left to right. `cmp(key, a, b)` must return the *ascending*
/// ordering for that key, and the result is then oriented by the key's
/// [`Direction`]. The first key that tells `a` and `b` apart decides. When every
/// key reports `Equal`, `tiebreak(a, b)` decides, and its result is used as-is
/// with no direction applied. The sort is stable (`slice::sort_by`).
pub fn sort_in_memory<T, K>(
    rows: &mut [T],
    keys: &[(K, Direction)],
    cmp: impl Fn(&K, &T, &T) -> std::cmp::Ordering,
    tiebreak: impl Fn(&T, &T) -> std::cmp::Ordering,
) {
    rows.sort_by(|lhs, rhs| {
        keys.iter()
            // `map` + `find` is lazy, so later keys are only compared while earlier ones tie.
            .map(|(key, dir)| dir.apply(cmp(key, lhs, rhs)))
            .find(|ord| *ord != std::cmp::Ordering::Equal)
            .unwrap_or_else(|| tiebreak(lhs, rhs))
    });
}
/// One comma-separated item of a `--sort` spec: a normalised key name plus an
/// optional explicit direction.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SortField {
    /// Key name, ASCII-lowercased by the parser.
    pub key: String,
    /// Direction the user wrote, or `None` for a bare key.
    pub direction: Option<Direction>,
}

impl SortField {
    /// Parses a single sort field.
    ///
    /// Surrounding whitespace is ignored. Accepted forms:
    /// - `-key` means descending, and `+key` means ascending;
    /// - `key:dir`, where `dir` is `asc`/`ascending`/`desc`/`descending`
    ///   (case-insensitive); an empty `dir` (`key:`) means ascending;
    /// - `key` alone means no explicit direction.
    ///
    /// The `-`/`+` prefixes are checked before the `:` form, so `-a:desc`
    /// produces the key `a:desc`.
    ///
    /// # Errors
    /// - [`SortParseError::EmptyField`] if the input or its key part is empty.
    /// - [`SortParseError::BadDirection`] for an unrecognised `:dir` suffix,
    ///   carrying the trimmed, lowercased suffix. The suffix is validated before
    ///   the key, so `:bogus` reports the bad direction rather than the empty key.
    pub fn parse(raw: &str) -> Result<Self, SortParseError> {
        let text = raw.trim();
        if text.is_empty() {
            return Err(SortParseError::EmptyField);
        }
        let (key, direction) = if let Some(body) = text.strip_prefix('-') {
            (body.trim(), Direction::Desc)
        } else if let Some(body) = text.strip_prefix('+') {
            (body.trim(), Direction::Asc)
        } else if let Some((name, suffix)) = text.split_once(':') {
            (name.trim(), Self::direction_from_suffix(suffix)?)
        } else {
            return Self::new(text, None);
        };
        Self::new(key, Some(direction))
    }

    /// Maps the text after `:` to a direction. Empty text counts as ascending.
    fn direction_from_suffix(suffix: &str) -> Result<Direction, SortParseError> {
        match suffix.trim().to_ascii_lowercase().as_str() {
            "" | "asc" | "ascending" => Ok(Direction::Asc),
            "desc" | "descending" => Ok(Direction::Desc),
            other => Err(SortParseError::BadDirection(other.to_owned())),
        }
    }

    /// Builds a field, lowercasing `key`. The key is not trimmed here.
    ///
    /// # Errors
    /// [`SortParseError::EmptyField`] when `key` is empty.
    fn new(key: &str, direction: Option<Direction>) -> Result<Self, SortParseError> {
        match key {
            "" => Err(SortParseError::EmptyField),
            name => Ok(Self {
                key: name.to_ascii_lowercase(),
                direction,
            }),
        }
    }

    /// Returns the explicit direction, or `default` when the field had none.
    pub fn direction_or(&self, default: Direction) -> Direction {
        match self.direction {
            Some(explicit) => explicit,
            None => default,
        }
    }
}
/// A parsed `--sort` spec: one or more [`SortField`]s in priority order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SortSpec(pub Vec<SortField>);

impl SortSpec {
    /// Parses a comma-separated list of sort fields (see [`SortField::parse`]).
    ///
    /// Empty items between commas (such as a trailing `,`) are skipped.
    ///
    /// # Errors
    /// - [`SortParseError::EmptySpec`] if the input has no non-blank items.
    /// - The first error returned by [`SortField::parse`] for a bad item.
    pub fn parse(raw: &str) -> Result<Self, SortParseError> {
        // Collecting into `Result` stops at the first failing field.
        let fields = raw
            .split(',')
            .map(str::trim)
            .filter(|part| !part.is_empty())
            .map(SortField::parse)
            .collect::<Result<Vec<_>, _>>()?;
        if fields.is_empty() {
            return Err(SortParseError::EmptySpec);
        }
        Ok(Self(fields))
    }

    /// Returns the fields in priority order.
    pub fn fields(&self) -> &[SortField] {
        &self.0
    }

    /// Builds a spec with one key and no explicit direction.
    ///
    /// The key is trimmed and ASCII-lowercased, the same normalisation
    /// [`SortSpec::parse`] applies. `single(k)` therefore equals `parse(k)` for
    /// any plain key, and the key matches lowercase spec names. Unlike `parse`,
    /// this does not reject an empty key.
    pub fn single(key: &str) -> Self {
        Self(vec![SortField {
            key: key.trim().to_ascii_lowercase(),
            direction: None,
        }])
    }
}
#[derive(Debug, Error)]
pub enum SortParseError {
#[error("--sort spec is empty")]
EmptySpec,
#[error("--sort field is empty (check for stray commas)")]
EmptyField,
#[error("unknown sort direction `{0}` (expected `asc` or `desc`)")]
BadDirection(String),
#[error("unknown --sort key `{key}` (expected: {expected})")]
UnknownKey { key: String, expected: String },
}
/// Static description of one sortable key, used by [`SortKeyDef`] implementors.
#[derive(Clone, Copy, Debug)]
pub struct SortKeySpec<T: 'static> {
    /// Value produced when this key is selected.
    pub variant: T,
    /// Primary name; shown in help text and in unknown-key errors.
    pub canonical: &'static str,
    /// Extra names that also select this key.
    pub aliases: &'static [&'static str],
    /// Direction used when the user does not give one.
    pub default_dir: Direction,
}
/// Describes the set of sort keys a command accepts and how to resolve them.
///
/// Implementors only provide [`SortKeyDef::specs`]. Key lookup, unknown-key
/// errors and help text are all derived from that list.
pub trait SortKeyDef: Sized + Copy + 'static {
    /// All accepted keys, in display order. `help_text` labels the first one `default`.
    fn specs() -> &'static [SortKeySpec<Self>];

    /// Resolves a parsed field to a key variant and its effective direction.
    ///
    /// The field key is compared to each spec's canonical name and aliases with
    /// ASCII case ignored. Specs may therefore use mixed-case names even though
    /// the parser lowercases user input. The first matching spec wins. The
    /// field's explicit direction is used if present, otherwise the spec's
    /// `default_dir`.
    ///
    /// # Errors
    /// [`SortParseError::UnknownKey`] when no spec matches. `expected` lists the
    /// canonical names, comma-separated, in spec order.
    fn from_field(field: &SortField) -> Result<(Self, Direction), SortParseError> {
        let specs = Self::specs();
        let is_match = |name: &str| name.eq_ignore_ascii_case(field.key.as_str());
        let found = specs.iter().find(|spec| {
            is_match(spec.canonical) || spec.aliases.iter().any(|&alias| is_match(alias))
        });
        match found {
            Some(spec) => Ok((spec.variant, field.direction_or(spec.default_dir))),
            None => Err(SortParseError::UnknownKey {
                key: field.key.clone(),
                expected: specs
                    .iter()
                    .map(|s| s.canonical)
                    .collect::<Vec<_>>()
                    .join(", "),
            }),
        }
    }

    /// Renders a one-line summary of the accepted keys and their default directions.
    ///
    /// Example: `total (default DESC), name (ASC)`. Only the first entry gets the
    /// `default` label. An empty spec list yields an empty string.
    fn help_text() -> String {
        Self::specs()
            .iter()
            .enumerate()
            .map(|(i, spec)| {
                let marker = if i == 0 { "default " } else { "" };
                format!("{} ({marker}{})", spec.canonical, spec.default_dir.sql())
            })
            .collect::<Vec<_>>()
            .join(", ")
    }
}
/// Builds the body of an SQL `ORDER BY` clause, without the `ORDER BY` keywords.
///
/// Each `(column, direction)` pair becomes `"column DIR"`, and the pairs are
/// joined with `", "`. A non-empty `tiebreaker_column` is appended last and
/// takes the direction of the final pair, or `ASC` when `parts` is empty.
///
/// NOTE(review): column names are interpolated verbatim. Callers must pass
/// trusted identifiers, never raw user input.
pub fn build_order_by(parts: &[(&str, Direction)], tiebreaker_column: &str) -> String {
    let mut clauses: Vec<String> = Vec::with_capacity(parts.len() + 1);
    for &(column, dir) in parts {
        clauses.push(format!("{column} {}", dir.sql()));
    }
    if !tiebreaker_column.is_empty() {
        // Follow the last key's direction so the tiebreaker runs the same way as the final key.
        let tail = match parts.last() {
            Some(&(_, dir)) => dir,
            None => Direction::Asc,
        };
        clauses.push(format!("{tiebreaker_column} {}", tail.sql()));
    }
    clauses.join(", ")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns field `i` of `spec` as an error instead of panicking on a short spec.
    fn field_at(spec: &SortSpec, i: usize) -> anyhow::Result<&SortField> {
        spec.0
            .get(i)
            .ok_or_else(|| anyhow::anyhow!("sort spec has no field at index {i}"))
    }

    #[test]
    fn direction_sql_and_apply() {
        use std::cmp::Ordering;
        assert_eq!(Direction::Asc.sql(), "ASC");
        assert_eq!(Direction::Desc.sql(), "DESC");
        assert_eq!(Direction::Asc.apply(Ordering::Less), Ordering::Less);
        assert_eq!(Direction::Desc.apply(Ordering::Less), Ordering::Greater);
        assert_eq!(Direction::Desc.apply(Ordering::Equal), Ordering::Equal);
    }

    #[test]
    fn sort_in_memory_orders_by_key_then_tiebreak() {
        let mut rows = vec![(1, 5, 30), (1, 5, 10), (2, 9, 20)];
        sort_in_memory(
            &mut rows,
            &[("val", Direction::Desc)],
            |k, a, b| match *k {
                "val" => a.1.cmp(&b.1),
                _ => std::cmp::Ordering::Equal,
            },
            |a, b| a.2.cmp(&b.2),
        );
        assert_eq!(rows, vec![(2, 9, 20), (1, 5, 10), (1, 5, 30)]);
    }

    #[test]
    fn sort_in_memory_multi_key_asc() {
        let mut rows = vec![(1, 5), (2, 5), (1, 3)];
        sort_in_memory(
            &mut rows,
            &[("g", Direction::Asc), ("v", Direction::Asc)],
            |k, a, b| match *k {
                "g" => a.0.cmp(&b.0),
                "v" => a.1.cmp(&b.1),
                _ => std::cmp::Ordering::Equal,
            },
            |_, _| std::cmp::Ordering::Equal,
        );
        assert_eq!(rows, vec![(1, 3), (1, 5), (2, 5)]);
    }

    #[test]
    fn single_bare_key_has_no_direction() -> anyhow::Result<()> {
        let s = SortSpec::parse("total")?;
        assert_eq!(s.0.len(), 1);
        let f0 = field_at(&s, 0)?;
        assert_eq!(f0.key, "total");
        assert!(f0.direction.is_none());
        Ok(())
    }

    #[test]
    fn colon_direction_parses() -> anyhow::Result<()> {
        let s = SortSpec::parse("name:desc")?;
        assert_eq!(field_at(&s, 0)?.direction, Some(Direction::Desc));
        let s = SortSpec::parse("Count:ASC")?;
        let f0 = field_at(&s, 0)?;
        assert_eq!(f0.key, "count");
        assert_eq!(f0.direction, Some(Direction::Asc));
        Ok(())
    }

    #[test]
    fn dash_plus_prefix_shorthand() -> anyhow::Result<()> {
        let s = SortSpec::parse("-duration")?;
        assert_eq!(field_at(&s, 0)?.direction, Some(Direction::Desc));
        let s = SortSpec::parse(" + start ")?;
        let f0 = field_at(&s, 0)?;
        assert_eq!(f0.key, "start");
        assert_eq!(f0.direction, Some(Direction::Asc));
        Ok(())
    }

    #[test]
    fn multi_field_parses_each_independently() -> anyhow::Result<()> {
        let s = SortSpec::parse("-total, +name, p99:desc")?;
        assert_eq!(s.0.len(), 3);
        let f0 = field_at(&s, 0)?;
        let f1 = field_at(&s, 1)?;
        let f2 = field_at(&s, 2)?;
        assert_eq!(f0.key, "total");
        assert_eq!(f0.direction, Some(Direction::Desc));
        assert_eq!(f1.key, "name");
        assert_eq!(f1.direction, Some(Direction::Asc));
        assert_eq!(f2.key, "p99");
        assert_eq!(f2.direction, Some(Direction::Desc));
        Ok(())
    }

    #[test]
    fn bad_direction_rejected() {
        assert!(matches!(
            SortSpec::parse("total:nope"),
            Err(SortParseError::BadDirection(_))
        ));
    }

    #[test]
    fn lone_prefix_or_missing_key_is_empty_field() {
        assert!(matches!(
            SortSpec::parse("-"),
            Err(SortParseError::EmptyField)
        ));
        assert!(matches!(
            SortSpec::parse("total, +"),
            Err(SortParseError::EmptyField)
        ));
        assert!(matches!(
            SortSpec::parse(":desc"),
            Err(SortParseError::EmptyField)
        ));
    }

    #[test]
    fn empty_inputs_rejected() -> anyhow::Result<()> {
        assert!(matches!(
            SortSpec::parse(""),
            Err(SortParseError::EmptySpec)
        ));
        assert!(matches!(
            SortSpec::parse(" "),
            Err(SortParseError::EmptySpec)
        ));
        assert!(matches!(
            SortSpec::parse(" , ,"),
            Err(SortParseError::EmptySpec)
        ));
        // Stray commas are skipped rather than rejected.
        assert_eq!(SortSpec::parse("total,")?.0.len(), 1);
        Ok(())
    }

    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum ToyKey {
        Total,
        Name,
        Duration,
    }

    impl SortKeyDef for ToyKey {
        fn specs() -> &'static [SortKeySpec<Self>] {
            &[
                SortKeySpec {
                    variant: ToyKey::Total,
                    canonical: "total",
                    aliases: &["total_ns"],
                    default_dir: Direction::Desc,
                },
                SortKeySpec {
                    variant: ToyKey::Name,
                    canonical: "name",
                    aliases: &[],
                    default_dir: Direction::Asc,
                },
                SortKeySpec {
                    variant: ToyKey::Duration,
                    canonical: "duration",
                    aliases: &["dur"],
                    default_dir: Direction::Desc,
                },
            ]
        }
    }

    #[test]
    fn sort_key_def_resolves_canonical() -> Result<(), SortParseError> {
        let f = SortField::parse("total")?;
        let (k, d) = ToyKey::from_field(&f)?;
        assert_eq!(k, ToyKey::Total);
        assert_eq!(d, Direction::Desc);
        Ok(())
    }

    #[test]
    fn sort_key_def_resolves_alias() -> Result<(), SortParseError> {
        let f = SortField::parse("total_ns:asc")?;
        let (k, d) = ToyKey::from_field(&f)?;
        assert_eq!(k, ToyKey::Total);
        assert_eq!(d, Direction::Asc);
        Ok(())
    }

    #[test]
    fn sort_key_def_alias_uses_spec_default_direction() -> Result<(), SortParseError> {
        let f = SortField::parse("DUR")?;
        let (k, d) = ToyKey::from_field(&f)?;
        assert_eq!(k, ToyKey::Duration);
        assert_eq!(d, Direction::Desc);
        Ok(())
    }

    #[test]
    fn sort_key_def_rejects_unknown_with_expected_list() -> Result<(), SortParseError> {
        let f = SortField::parse("nope")?;
        let err = ToyKey::from_field(&f).expect_err("`nope` must not resolve to a sort key");
        assert!(matches!(err, SortParseError::UnknownKey { .. }));
        let msg = err.to_string();
        assert!(msg.contains("unknown --sort key `nope`"), "got: {msg}");
        assert!(msg.contains("total"));
        assert!(msg.contains("name"));
        assert!(msg.contains("duration"));
        Ok(())
    }

    #[test]
    fn sort_key_def_help_text_marks_first_as_default() {
        let s = ToyKey::help_text();
        assert_eq!(s, "total (default DESC), name (ASC), duration (DESC)");
    }

    #[test]
    fn direction_or_default() -> anyhow::Result<()> {
        let s = SortSpec::parse("total")?;
        assert_eq!(
            field_at(&s, 0)?.direction_or(Direction::Desc),
            Direction::Desc
        );
        let s = SortSpec::parse("total:asc")?;
        assert_eq!(
            field_at(&s, 0)?.direction_or(Direction::Desc),
            Direction::Asc
        );
        Ok(())
    }

    #[test]
    fn build_order_by_appends_tiebreaker_with_last_direction() {
        assert_eq!(
            build_order_by(&[("total", Direction::Desc), ("name", Direction::Asc)], "id"),
            "total DESC, name ASC, id ASC"
        );
        assert_eq!(
            build_order_by(&[("total", Direction::Desc)], "id"),
            "total DESC, id DESC"
        );
        assert_eq!(build_order_by(&[], "id"), "id ASC");
        assert_eq!(build_order_by(&[("total", Direction::Desc)], ""), "total DESC");
    }
}