use tagged_dispatch::tagged_dispatch;
// Behaviour shared by the shape payload types in this file (`Circle`,
// `Rectangle`). The `#[tagged_dispatch]` attribute processes this trait so
// that enums annotated with `#[tagged_dispatch(Draw)]` can be built over its
// implementors — NOTE(review): the exact code the macro generates is not
// visible here; confirm against the crate docs.
#[tagged_dispatch]
trait Draw {
    // Returns a label for the shape; both implementations in this file return
    // a fixed string literal.
    fn draw(&self) -> &str;
}
/// Payload type for the `Shape::Circle` variant.
///
/// The tests in this file construct `Circle` values but never read `radius`;
/// the field exists only as sample payload data.
#[allow(dead_code)] // `radius` is intentionally unread; silences "field is never read"
#[derive(Clone)]
struct Circle {
    radius: f32,
}
impl Draw for Circle {
    /// Yields the fixed label `"circle"`; the radius is not consulted.
    fn draw(&self) -> &str {
        const LABEL: &str = "circle";
        LABEL
    }
}
/// Payload type for the `Shape::Rectangle` variant.
///
/// The tests in this file construct `Rectangle` values but never read
/// `width` or `height`; the fields exist only as sample payload data.
#[allow(dead_code)] // fields are intentionally unread; silences "fields are never read"
#[derive(Clone)]
struct Rectangle {
    width: f32,
    height: f32,
}
impl Draw for Rectangle {
    /// Yields the fixed label `"rectangle"`; the dimensions are not consulted.
    fn draw(&self) -> &str {
        const LABEL: &str = "rectangle";
        LABEL
    }
}
// Dispatch enum over the `Draw` implementors; each variant is named after an
// existing payload type (`Circle`, `Rectangle`).
// The tests in this file rely on the following items existing for `Shape`:
// lowercase constructors (`Shape::circle`, `Shape::rectangle`), a
// `tag_type()` accessor, and Debug/Clone/PartialEq/Eq/Ord impls.
// NOTE(review): none of these are written here — presumably they are all
// generated by the `tagged_dispatch` attribute; confirm against the crate.
#[tagged_dispatch(Draw)]
enum Shape {
    Circle,
    Rectangle,
}
/// The `Debug` rendering of a `Shape` must name both the enum and the active
/// variant (e.g. `Shape::Circle`).
#[test]
fn test_debug_implementation() {
    let circle = Shape::circle(Circle { radius: 1.0 });
    let rect = Shape::rectangle(Rectangle { width: 2.0, height: 3.0 });

    let circle_debug = format!("{:?}", circle);
    let rect_debug = format!("{:?}", rect);

    // Echo the actual rendering on failure so a mismatch is diagnosable
    // without re-running under a debugger.
    assert!(
        circle_debug.contains("Shape::Circle"),
        "unexpected Debug output for circle: {}",
        circle_debug
    );
    assert!(
        rect_debug.contains("Shape::Rectangle"),
        "unexpected Debug output for rectangle: {}",
        rect_debug
    );
}
/// Pins the equality semantics of `Shape`'s `PartialEq` impl.
///
/// Per the assertions below: two shapes built from identical payloads compare
/// unequal, a clone compares unequal to its source, and a value compares equal
/// to itself. NOTE(review): this looks like identity (pointer) equality rather
/// than payload equality — confirm against the `tagged_dispatch` docs.
// `assert_eq!(x, x)` is the reflexivity check under test, so clippy's
// deny-by-default `eq_op` lint is allowed deliberately for this function.
#[allow(clippy::eq_op)]
#[test]
fn test_equality() {
    let circle1 = Shape::circle(Circle { radius: 1.0 });
    let circle2 = Shape::circle(Circle { radius: 1.0 });
    let circle3 = circle1.clone();

    // Identical payloads, separate constructions: not equal.
    assert_ne!(circle1, circle2);
    // Reflexivity.
    assert_eq!(circle1, circle1);
    // A clone is a distinct value: not equal.
    assert_ne!(circle1, circle3);
}
/// After `sort()`, a mixed collection of shapes must be ordered so that the
/// variant tags, read front to back, are non-decreasing.
#[test]
fn test_ordering() {
    let mut shapes = vec![
        Shape::rectangle(Rectangle { width: 1.0, height: 2.0 }),
        Shape::circle(Circle { radius: 1.0 }),
        Shape::rectangle(Rectangle { width: 3.0, height: 4.0 }),
    ];
    shapes.sort();

    // Check each adjacent pair directly via `windows(2)` instead of index
    // arithmetic over a separately collected tag vector.
    for (position, pair) in shapes.windows(2).enumerate() {
        assert!(
            pair[0].tag_type() <= pair[1].tag_type(),
            "shapes at positions {} and {} are out of tag order after sort()",
            position,
            position + 1
        );
    }
}
/// A user type that embeds `Shape` can derive `Debug`, `PartialEq` and `Eq`,
/// and the derived impls delegate to `Shape`'s own impls.
#[test]
fn test_derives_work() {
    #[derive(Debug, PartialEq, Eq)]
    struct Container {
        shape: Shape,
        name: String,
    }

    // Each call builds a fresh `Shape`, so the two containers hold
    // separately constructed (but payload-identical) shapes.
    let make = || Container {
        shape: Shape::circle(Circle { radius: 1.0 }),
        name: String::from("test"),
    };
    let (first, second) = (make(), make());

    // `Shape` equality does not treat separately built circles as equal,
    // so neither does the derived `Container` equality.
    assert_ne!(first, second);

    let rendered = format!("{:?}", first);
    for needle in ["Container", "Shape::Circle"] {
        assert!(rendered.contains(needle));
    }
}
// Trait-behaviour checks for the arena-allocated flavour of a dispatch enum.
// Only compiled when the `allocator-bumpalo` feature is enabled.
#[cfg(feature = "allocator-bumpalo")]
#[test]
fn test_arena_version_traits() {
    // Arena flavour of the dispatch enum. NOTE(review): the `'a` lifetime
    // presumably ties values to the arena behind `arena_builder()` — confirm.
    #[tagged_dispatch(Draw)]
    enum ShapeArena<'a> {
        Circle,
        Rectangle,
    }
    let builder = ShapeArena::arena_builder();
    let circle1 = builder.circle(Circle { radius: 1.0 });
    let circle2 = builder.circle(Circle { radius: 1.0 });
    // Separately built values with identical payloads compare unequal.
    assert_ne!(circle1, circle2);
    // `circle1` is used again after this binding, so this line only compiles
    // if the value returned by `builder.circle` is `Copy`.
    let circle3 = circle1;
    // A copy compares equal to its source.
    assert_eq!(circle1, circle3);
    let debug = format!("{:?}", circle1);
    assert!(debug.contains("ShapeArena::Circle"));
    // Distinct values must be strictly ordered one way or the other.
    assert!(circle1 < circle2 || circle1 > circle2);
}