use core::cell::RefCell;
use alloc::{rc::Rc, vec::Vec};
use vexide::devices::{
PortError,
adi::AdiEncoder,
position::Position,
smart::{
motor::{Motor, MotorError},
rotation::RotationSensor,
},
};
/// A device (or group of devices) that can report a rotational [`Position`].
///
/// Implementors expose a single fallible read; the error type is chosen per
/// implementor through [`RotarySensor::Error`].
pub trait RotarySensor {
    /// Error returned when the position cannot be read.
    type Error;

    /// Reads the sensor's current position.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when the underlying read fails.
    fn position(&self) -> Result<Position, Self::Error>;
}
/// Implements [`RotarySensor`] for a device type by forwarding to one of its
/// inherent methods.
///
/// - `$struct`: the device type to implement the trait for.
/// - `$method`: an inherent method on `$struct` callable as
///   `$struct::$method(&$struct) -> Result<Position, $err>`.
/// - `$err`: the error type that method returns; it becomes
///   `RotarySensor::Error` for `$struct`.
///
/// A trailing comma after `$err` is accepted.
macro_rules! impl_rotary_sensor {
    ($struct:ident, $method:ident, $err:ty $(,)?) => {
        impl RotarySensor for $struct {
            type Error = $err;

            fn position(&self) -> Result<Position, Self::Error> {
                // `self` is already `&Self`, so pass it straight through.
                // Inherent associated functions take precedence over trait
                // methods in path resolution, so this reaches the device's own
                // method. If `$method` were `position` and the type had no
                // inherent method of that name, it would resolve to this trait
                // method and recurse.
                $struct::$method(self)
            }
        }
    };
}
// Smart-port rotation sensors and ADI encoders both surface read failures as
// `PortError`.
impl_rotary_sensor!(RotationSensor, position, PortError);
impl_rotary_sensor!(AdiEncoder, position, PortError);

// Motors surface read failures through their own `MotorError`.
impl_rotary_sensor!(Motor, position, MotorError);
/// Treats a `Vec` of sensors as a single sensor reporting their mean position.
///
/// Reads are best-effort: a sensor whose read fails is skipped, and the mean
/// is taken (in degrees) over the sensors that succeeded.
impl<T: RotarySensor> RotarySensor for Vec<T> {
    type Error = T::Error;

    /// Returns the mean position of every sensor that could be read.
    ///
    /// An empty `Vec` reports `Position::default()`.
    ///
    /// # Errors
    ///
    /// Returns the error from the last failing sensor, in iteration order, but
    /// only when *every* sensor failed. Partial failures are not reported.
    fn position(&self) -> Result<Position, Self::Error> {
        let mut successes: u32 = 0;
        let mut degree_sum = 0.0_f64;
        let mut last_error = None;

        for sensor in self {
            match sensor.position() {
                Ok(position) => {
                    successes += 1;
                    degree_sum += position.as_degrees();
                }
                Err(error) => last_error = Some(error),
            }
        }

        if successes == 0 {
            // Either the Vec is empty (nothing to average) or every read failed.
            return match last_error {
                Some(error) => Err(error),
                None => Ok(Position::default()),
            };
        }

        // `f64::from(u32)` is lossless, so no precision-loss lint applies here.
        Ok(Position::from_degrees(degree_sum / f64::from(successes)))
    }
}
/// Treats a fixed-size array of sensors as a single sensor reporting their
/// mean position.
///
/// Reads are best-effort: a sensor whose read fails is skipped, and the mean
/// is taken (in degrees) over the sensors that succeeded.
impl<const N: usize, T: RotarySensor> RotarySensor for [T; N] {
    type Error = T::Error;

    /// Returns the mean position of every sensor that could be read.
    ///
    /// A zero-length array reports `Position::default()`.
    ///
    /// # Errors
    ///
    /// Returns the error from the last failing sensor, in iteration order, but
    /// only when *every* sensor failed. Partial failures are not reported.
    fn position(&self) -> Result<Position, Self::Error> {
        let mut successes: u32 = 0;
        let mut degree_sum = 0.0_f64;
        let mut last_error = None;

        for sensor in self {
            match sensor.position() {
                Ok(position) => {
                    successes += 1;
                    degree_sum += position.as_degrees();
                }
                Err(error) => last_error = Some(error),
            }
        }

        if successes == 0 {
            // Either N == 0 (nothing to average) or every read failed.
            return match last_error {
                Some(error) => Err(error),
                None => Ok(Position::default()),
            };
        }

        // `f64::from(u32)` is lossless, so no precision-loss lint applies here.
        Ok(Position::from_degrees(degree_sum / f64::from(successes)))
    }
}
/// Lets a sensor shared through `Rc<RefCell<_>>` be used wherever a
/// [`RotarySensor`] is expected, by forwarding to the wrapped sensor.
impl<T> RotarySensor for Rc<RefCell<T>>
where
    T: RotarySensor,
{
    type Error = T::Error;

    /// Reads the wrapped sensor's position through a shared borrow.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the wrapped sensor returns.
    ///
    /// # Panics
    ///
    /// Panics if the `RefCell` is currently mutably borrowed, because the
    /// sensor is accessed with [`RefCell::borrow`].
    fn position(&self) -> Result<Position, Self::Error> {
        let inner = RefCell::borrow(self);
        inner.position()
    }
}