#![no_std]
#![warn(missing_docs)]
use embedded_hal::delay::DelayNs;
use embedded_hal::digital::{OutputPin, PinState};
use embedded_hal::spi::SpiDevice;
/// Errors returned by [`ShiftRegister`] operations.
///
/// `SPI` is the error type of the underlying SPI device; it is only carried by
/// the [`Error::Spi`] variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error<SPI> {
    /// The requested output index is not smaller than `N * 8`, the number of
    /// outputs in the chain.
    IndexOutOfBounds,
    /// Driving the output-enable or latch pin failed. The pin's own error value
    /// is discarded.
    IoError,
    /// The SPI transfer failed; carries the SPI device's error.
    Spi(SPI),
}
/// Driver for a chain of `N` 8-bit shift registers that is loaded over SPI and
/// transferred to the outputs by pulsing a separate latch pin.
///
/// The driver keeps a local image of all `N * 8` outputs. The image is only sent
/// to the hardware by `write_all` / `write_output`.
///
/// Generic parameters:
/// - `N`: number of bytes in the chain (8 outputs per byte).
/// - `SPI`: SPI device the output image is shifted out through.
/// - `OE`: output-enable pin, driven low to enable the outputs.
/// - `LATCH`: latch pin, pulsed high after each SPI write.
/// - `D`: delay provider used to time the latch pulse.
pub struct ShiftRegister<const N: usize, SPI, OE, LATCH, D> {
    /// SPI device the output image is written to.
    spi: SPI,
    /// Output-enable pin; low means outputs enabled.
    not_oe: OE,
    /// Latch pin pulsed after each SPI write.
    latch: LATCH,
    /// Delay used between raising and lowering the latch pin.
    delay: D,
    /// Local output image. Output `i` is bit `i % 8` of `data[N - 1 - i / 8]`,
    /// so outputs 0..=7 live in the last byte sent over SPI.
    data: [u8; N],
}
impl<const N: usize, SPI, OE, LATCH, D> ShiftRegister<N, SPI, OE, LATCH, D>
where
SPI: SpiDevice,
OE: OutputPin,
LATCH: OutputPin,
D: DelayNs,
{
pub fn new(spi: SPI, not_oe: OE, latch: LATCH, delay: D) -> Self {
ShiftRegister {
spi,
not_oe,
latch,
delay,
data: [0u8; N],
}
}
pub fn enable_output(&mut self, enable: bool) -> Result<(), Error<SPI::Error>> {
self.not_oe.set_state(PinState::from(!enable)).map_err(|_| Error::IoError)
}
fn latch(&mut self) -> Result<(), LATCH::Error> {
self.latch.set_high()?;
self.delay.delay_us(5);
self.latch.set_low()
}
pub fn get_output(&self, idx: usize) -> Result<bool, Error<SPI::Error>> {
let max_bits: usize = N.checked_mul(8).ok_or(Error::IndexOutOfBounds)?;
if idx >= max_bits {
return Err(Error::IndexOutOfBounds);
}
let byte_idx = N - 1 - (idx / 8);
let bit_idx = idx % 8;
Ok(self.data[byte_idx] & (1 << bit_idx) != 0)
}
pub fn set_output(&mut self, idx: usize, output_state: bool) -> Result<(), Error<SPI::Error>> {
let max_bits: usize = N.checked_mul(8).ok_or(Error::IndexOutOfBounds)?;
if idx >= max_bits {
return Err(Error::IndexOutOfBounds);
}
let byte_idx = N - 1 - (idx / 8);
let bit_idx = idx % 8;
if output_state {
self.data[byte_idx] |= 1 << bit_idx;
} else {
self.data[byte_idx] &= !(1u8 << bit_idx);
}
Ok(())
}
pub fn clear(&mut self) {
self.data = [0u8; N];
}
pub fn write_all(&mut self) -> Result<(), Error<SPI::Error>> {
self.spi.write(&self.data).map_err(Error::Spi)?;
self.latch().map_err(|_| Error::IoError)?;
Ok(())
}
pub fn write_output(
&mut self,
idx: usize,
output_state: bool,
) -> Result<(), Error<SPI::Error>> {
self.set_output(idx, output_state)?;
self.write_all()
}
}
#[cfg(test)]
mod test {
    extern crate std;

    use super::*;
    use embedded_hal_mock::eh1::{
        delay::{CheckedDelay, NoopDelay as DelayMock, Transaction as DelayTransaction},
        digital::{Mock as PinMock, State, Transaction as PinTransaction},
        spi::{Mock as SpiMock, Transaction as SpiTransaction},
    };

    /// `clear` zeroes the local image without touching any peripheral.
    #[test]
    fn test_clear() {
        let mut spi_mock: SpiMock<u8> = SpiMock::new(&[]);
        let mut oe_mock = PinMock::new(&[]);
        let mut latch_mock = PinMock::new(&[]);
        let delay_mock = DelayMock::new();
        let mut dev = ShiftRegister::<2, _, _, _, _>::new(
            spi_mock.clone(),
            oe_mock.clone(),
            latch_mock.clone(),
            delay_mock,
        );
        dev.data = [0xff, 0x12];
        dev.clear();
        assert_eq!(dev.data, [0u8, 0u8]);
        spi_mock.done();
        oe_mock.done();
        latch_mock.done();
    }

    /// Output 0 is bit 0 of the last byte; higher indices walk towards `data[0]`.
    #[test]
    fn test_get_output() {
        let mut spi_mock: SpiMock<u8> = SpiMock::new(&[]);
        let mut oe_mock = PinMock::new(&[]);
        let mut latch_mock = PinMock::new(&[]);
        let delay_mock = DelayMock::new();
        let mut dev = ShiftRegister::<2, _, _, _, _>::new(
            spi_mock.clone(),
            oe_mock.clone(),
            latch_mock.clone(),
            delay_mock,
        );
        dev.data = [0x81, 0x13];
        assert!(dev.get_output(0).unwrap());
        assert!(dev.get_output(1).unwrap());
        assert!(!dev.get_output(2).unwrap());
        assert!(dev.get_output(8).unwrap());
        assert!(!dev.get_output(9).unwrap());
        assert!(!dev.get_output(12).unwrap());
        assert!(dev.get_output(15).unwrap());
        let err = dev.get_output(16).unwrap_err();
        assert!(matches!(err, Error::IndexOutOfBounds));
        spi_mock.done();
        oe_mock.done();
        latch_mock.done();
    }

    /// Setting and clearing bits updates the matching byte/bit only.
    #[test]
    fn test_set_output() {
        let mut spi_mock: SpiMock<u8> = SpiMock::new(&[]);
        let mut oe_mock = PinMock::new(&[]);
        let mut latch_mock = PinMock::new(&[]);
        let delay_mock = DelayMock::new();
        let mut dev = ShiftRegister::<2, _, _, _, _>::new(
            spi_mock.clone(),
            oe_mock.clone(),
            latch_mock.clone(),
            delay_mock,
        );
        dev.set_output(0, true).unwrap();
        dev.set_output(1, true).unwrap();
        dev.set_output(4, true).unwrap();
        dev.set_output(8, true).unwrap();
        dev.set_output(15, true).unwrap();
        assert_eq!(dev.data, [0x81, 0x13]);
        dev.set_output(1, false).unwrap();
        dev.set_output(8, false).unwrap();
        assert_eq!(dev.data, [0x80, 0x11]);
        let err = dev.set_output(16, true).unwrap_err();
        assert!(matches!(err, Error::IndexOutOfBounds));
        spi_mock.done();
        oe_mock.done();
        latch_mock.done();
    }

    /// The output-enable pin is active-low: disable drives high, enable drives low.
    #[test]
    fn test_enable_outputs() {
        {
            let mut spi: SpiMock<u8> = SpiMock::new(&[]);
            let mut latch_mock = PinMock::new(&[]);
            let delay_mock = DelayMock::new();
            let mut oe_mock = PinMock::new(&[PinTransaction::set(State::High)]);
            let mut dev = ShiftRegister::<1, _, _, _, _>::new(
                spi.clone(),
                oe_mock.clone(),
                latch_mock.clone(),
                delay_mock,
            );
            dev.enable_output(false).unwrap();
            spi.done();
            oe_mock.done();
            latch_mock.done();
        }
        {
            let mut spi: SpiMock<u8> = SpiMock::new(&[]);
            let mut latch_mock = PinMock::new(&[]);
            let delay_mock = DelayMock::new();
            let mut oe_mock = PinMock::new(&[PinTransaction::set(State::Low)]);
            let mut dev = ShiftRegister::<1, _, _, _, _>::new(
                spi.clone(),
                oe_mock.clone(),
                latch_mock.clone(),
                delay_mock,
            );
            dev.enable_output(true).unwrap();
            spi.done();
            oe_mock.done();
            latch_mock.done();
        }
    }

    /// A latch pulse is high, a 5 µs delay, then low.
    #[test]
    fn test_latch() {
        let mut spi_mock: SpiMock<u8> = SpiMock::new(&[]);
        let mut oe_mock = PinMock::new(&[]);
        let mut delay_mock = CheckedDelay::new(&[DelayTransaction::delay_us(5)]);
        let mut latch_mock = PinMock::new(&[
            PinTransaction::set(State::High),
            PinTransaction::set(State::Low),
        ]);
        let mut dev = ShiftRegister::<1, _, _, _, _>::new(
            spi_mock.clone(),
            oe_mock.clone(),
            latch_mock.clone(),
            delay_mock.clone(),
        );
        dev.latch().unwrap();
        spi_mock.done();
        oe_mock.done();
        latch_mock.done();
        delay_mock.done();
    }

    /// `write_all` sends the image in `data` order, then latches.
    #[test]
    fn test_write_all() {
        let mut oe_mock = PinMock::new(&[]);
        let delay_mock = DelayMock::new();
        let expectations = [
            SpiTransaction::transaction_start(),
            SpiTransaction::write_vec(std::vec![0x33, 0x22, 0x11]),
            SpiTransaction::transaction_end(),
        ];
        let mut spi_mock = SpiMock::new(&expectations);
        let mut latch_mock = PinMock::new(&[
            PinTransaction::set(State::High),
            PinTransaction::set(State::Low),
        ]);
        let mut dev = ShiftRegister::<3, _, _, _, _>::new(
            spi_mock.clone(),
            oe_mock.clone(),
            latch_mock.clone(),
            delay_mock,
        );
        dev.data = [0x33, 0x22, 0x11];
        dev.write_all().unwrap();
        spi_mock.done();
        oe_mock.done();
        latch_mock.done();
    }

    /// `write_output` updates one bit and pushes the whole image out.
    #[test]
    fn test_write_output() {
        {
            let mut oe_mock = PinMock::new(&[]);
            let delay_mock = DelayMock::new();
            let expectations = [
                SpiTransaction::transaction_start(),
                SpiTransaction::write_vec(std::vec![0x00, 0x01, 0x00]),
                SpiTransaction::transaction_end(),
            ];
            let mut spi_mock = SpiMock::new(&expectations);
            let mut latch_mock = PinMock::new(&[
                PinTransaction::set(State::High),
                PinTransaction::set(State::Low),
            ]);
            let mut dev = ShiftRegister::<3, _, _, _, _>::new(
                spi_mock.clone(),
                oe_mock.clone(),
                latch_mock.clone(),
                delay_mock,
            );
            dev.write_output(8, true).unwrap();
            spi_mock.done();
            oe_mock.done();
            latch_mock.done();
        }
        {
            let mut oe_mock = PinMock::new(&[]);
            let delay_mock = DelayMock::new();
            let expectations = [
                SpiTransaction::transaction_start(),
                SpiTransaction::write_vec(std::vec![0xFF, 0xFE, 0xFF]),
                SpiTransaction::transaction_end(),
            ];
            let mut spi_mock = SpiMock::new(&expectations);
            let mut latch_mock = PinMock::new(&[
                PinTransaction::set(State::High),
                PinTransaction::set(State::Low),
            ]);
            let mut dev = ShiftRegister::<3, _, _, _, _>::new(
                spi_mock.clone(),
                oe_mock.clone(),
                latch_mock.clone(),
                delay_mock,
            );
            dev.data = [0xFF, 0xFF, 0xFF];
            dev.write_output(8, false).unwrap();
            spi_mock.done();
            oe_mock.done();
            latch_mock.done();
        }
    }

    /// An out-of-range `write_output` fails before any SPI or latch activity
    /// and leaves the local image untouched.
    #[test]
    fn test_write_output_out_of_range_does_not_touch_bus() {
        let mut spi_mock: SpiMock<u8> = SpiMock::new(&[]);
        let mut oe_mock = PinMock::new(&[]);
        let mut latch_mock = PinMock::new(&[]);
        let delay_mock = DelayMock::new();
        let mut dev = ShiftRegister::<3, _, _, _, _>::new(
            spi_mock.clone(),
            oe_mock.clone(),
            latch_mock.clone(),
            delay_mock,
        );
        let err = dev.write_output(24, true).unwrap_err();
        assert!(matches!(err, Error::IndexOutOfBounds));
        assert_eq!(dev.data, [0u8; 3]);
        spi_mock.done();
        oe_mock.done();
        latch_mock.done();
    }

    /// A zero-width chain has no outputs, so every index is rejected.
    #[test]
    fn test_zero_width_register_rejects_every_index() {
        let mut spi_mock: SpiMock<u8> = SpiMock::new(&[]);
        let mut oe_mock = PinMock::new(&[]);
        let mut latch_mock = PinMock::new(&[]);
        let delay_mock = DelayMock::new();
        let mut dev = ShiftRegister::<0, _, _, _, _>::new(
            spi_mock.clone(),
            oe_mock.clone(),
            latch_mock.clone(),
            delay_mock,
        );
        assert!(matches!(dev.get_output(0), Err(Error::IndexOutOfBounds)));
        assert!(matches!(dev.set_output(0, true), Err(Error::IndexOutOfBounds)));
        assert!(matches!(dev.get_output(usize::MAX), Err(Error::IndexOutOfBounds)));
        spi_mock.done();
        oe_mock.done();
        latch_mock.done();
    }
}