use core::cell::UnsafeCell;
use crate::device::DeviceSpec;
use crate::error::OsStatus;
use crate::io::IoBuffer;
use crate::realtime::{RealtimeContext, Refcount, State, StateCell};
/// Static identity of a driver: display name, vendor and version string.
///
/// Normally built from a [`Driver`]'s associated constants via [`DriverInfo::of`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DriverInfo {
    /// Human-readable driver name (taken from `Driver::NAME`).
    pub name: &'static str,
    /// Vendor string (taken from `Driver::MANUFACTURER`).
    pub manufacturer: &'static str,
    /// Version string (taken from `Driver::VERSION`).
    pub version: &'static str,
}
impl DriverInfo {
    /// Collects the identity constants of driver type `T` into a [`DriverInfo`].
    ///
    /// This is a `const fn`, so it can initialise constants, e.g.
    /// `const INFO: DriverInfo = DriverInfo::of::<MyDriver>();`.
    #[must_use]
    pub const fn of<T>() -> Self
    where
        T: Driver,
    {
        let name = T::NAME;
        let manufacturer = T::MANUFACTURER;
        let version = T::VERSION;
        Self {
            name,
            manufacturer,
            version,
        }
    }
}
/// A user-implemented audio driver, wrapped and driven by [`DriverInstance`].
///
/// Implementors supply identity constants, a constructor, the device they publish and the
/// I/O callback. The lifecycle hooks (`initialize`, `start_io`, `stop_io`) default to no-ops.
pub trait Driver
where
    Self: Sized + Send,
{
    /// Human-readable driver name, surfaced as [`DriverInfo::name`].
    const NAME: &'static str;
    /// Vendor string, surfaced as [`DriverInfo::manufacturer`].
    const MANUFACTURER: &'static str;
    /// Version string, surfaced as [`DriverInfo::version`].
    const VERSION: &'static str;

    /// Constructs the driver value; called by [`DriverInstance::new`].
    fn new() -> Self;

    /// Describes the device this driver publishes.
    fn device(&self) -> DeviceSpec;

    /// Setup hook run by [`DriverInstance::initialize`] after the instance's state has
    /// moved to `Initialized`. The default does nothing.
    ///
    /// # Errors
    ///
    /// Any status returned here is propagated to the caller, and the wrapper rolls its
    /// state back.
    fn initialize(&mut self) -> Result<(), OsStatus> {
        Ok(())
    }

    /// Hook run by [`DriverInstance::start_io`] after the instance's state has moved to
    /// `Running`. The default does nothing.
    ///
    /// # Errors
    ///
    /// Any status returned here is propagated to the caller, and the wrapper moves its
    /// state back out of `Running`.
    fn start_io(&mut self) -> Result<(), OsStatus> {
        Ok(())
    }

    /// Hook run by [`DriverInstance::stop_io`] when a running instance is stopped.
    /// It is infallible. The default does nothing.
    fn stop_io(&mut self) {}

    /// I/O callback, forwarded by [`DriverInstance::process_io`] only while the instance
    /// reports itself as running.
    fn process_io(&mut self, rt: &RealtimeContext, buffer: &mut IoBuffer<'_>);
}
/// Object-safe, type-erased interface to a driver instance.
///
/// It lets a host hold drivers of different `T: Driver` behind a single pointer type, for
/// example `Arc<dyn AnyDriver>`. The implementation for [`DriverInstance`] forwards every
/// method to the inherent method of the same name.
pub trait AnyDriver: Send + Sync {
    // --- reference counting -------------------------------------------------------------

    /// Increments the reference count and returns the new count.
    fn add_ref(&self) -> u32;
    /// Decrements the reference count and returns the new count.
    fn release(&self) -> u32;
    /// Returns the current reference count.
    fn refcount(&self) -> u32;

    // --- lifecycle ----------------------------------------------------------------------

    /// Returns the current lifecycle state.
    fn state(&self) -> State;
    /// Initializes the driver.
    ///
    /// # Errors
    ///
    /// Fails if the current state does not allow initialization, or if the driver's own
    /// initialization fails.
    fn initialize(&self) -> Result<(), OsStatus>;
    /// Starts I/O.
    ///
    /// # Errors
    ///
    /// Fails if the current state does not allow starting, or if the driver refuses to start.
    fn start_io(&self) -> Result<(), OsStatus>;
    /// Stops I/O.
    ///
    /// # Errors
    ///
    /// Fails if the instance is not currently running.
    fn stop_io(&self) -> Result<(), OsStatus>;
    /// Runs one I/O cycle over `buffer`.
    ///
    /// # Errors
    ///
    /// Fails if the instance is not currently running.
    fn process_io(
        &self,
        rt: &RealtimeContext,
        buffer: &mut IoBuffer<'_>,
    ) -> Result<(), OsStatus>;

    // --- identity -----------------------------------------------------------------------

    /// Returns the device description published by the driver.
    fn device(&self) -> DeviceSpec;
    /// Returns the driver's static identity.
    fn info(&self) -> DriverInfo;
}
/// Runtime wrapper around a user [`Driver`].
///
/// It owns the driver value together with its lifecycle state and reference count, and
/// gates every call into the driver on that state.
pub struct DriverInstance<T>
where
    T: Driver,
{
    /// The user driver. It is mutated through `&self`, hence the `UnsafeCell`.
    inner: UnsafeCell<T>,
    /// Lifecycle state machine (tracks `State::Uninitialized` / `Initialized` / `Running`).
    state: StateCell,
    /// Counter maintained by `add_ref` / `release`.
    refcount: Refcount,
}
// SAFETY: `DriverInstance` hands out `&mut T` through `&self` (in `initialize`, `start_io`,
// `stop_io` and `process_io`) and `&T` in `device`, with no lock around `inner`. The
// `T: Send` bound on `Driver` lets the driver value be used from whichever thread makes
// those calls. Sharing is therefore only sound if those calls never overlap in time.
// NOTE(review): nothing in this file enforces non-overlap. It presumably relies on the host
// serializing calls into the driver. Confirm this, or guard `inner` with a lock.
unsafe impl<T> Sync for DriverInstance<T>
where
    T: Driver,
{
}
impl<T: Driver> DriverInstance<T> {
    /// Creates an instance around a freshly constructed `T::new()`.
    ///
    /// The instance starts in [`State::Uninitialized`] with a reference count of 1.
    #[must_use]
    pub fn new() -> Self {
        Self {
            inner: UnsafeCell::new(T::new()),
            state: StateCell::new(),
            refcount: Refcount::new(),
        }
    }

    /// Returns the current lifecycle state.
    #[inline]
    #[must_use]
    pub fn state(&self) -> State {
        self.state.load()
    }

    /// Returns the current reference count.
    #[inline]
    #[must_use]
    pub fn refcount(&self) -> u32 {
        self.refcount.count()
    }

    /// Increments the reference count and returns the new count.
    #[inline]
    pub fn add_ref(&self) -> u32 {
        self.refcount.add_ref()
    }

    /// Decrements the reference count and returns the new count.
    #[inline]
    pub fn release(&self) -> u32 {
        self.refcount.release()
    }

    /// Moves the instance to `Initialized`, then runs the user driver's
    /// [`Driver::initialize`].
    ///
    /// # Errors
    ///
    /// - `OsStatus::ILLEGAL_OPERATION` if the state machine rejects the transition
    ///   (e.g. the instance is already initialized).
    /// - Any status returned by the user hook. In that case the state is rolled back with
    ///   `StateCell::reset` before the status is returned.
    pub fn initialize(&self) -> Result<(), OsStatus> {
        self.state
            .initialize()
            .map_err(|_| OsStatus::ILLEGAL_OPERATION)?;
        // SAFETY: this call just won the transition into `Initialized`. Exclusive access to
        // `inner` additionally assumes the host does not run other driver calls concurrently
        // (see the `Sync` impl).
        let inner = unsafe { &mut *self.inner.get() };
        inner.initialize().map_err(|status| {
            // Best-effort rollback: the user's status is the error worth reporting, so a
            // failed rollback is deliberately ignored.
            let _ = self.state.reset();
            status
        })
    }

    /// Moves the instance to `Running`, then runs the user driver's [`Driver::start_io`].
    ///
    /// # Errors
    ///
    /// - `OsStatus::ILLEGAL_OPERATION` if the state machine rejects the transition
    ///   (e.g. the instance is not initialized).
    /// - Any status returned by the user hook. In that case the state is moved back out of
    ///   `Running` with `StateCell::stop` before the status is returned.
    pub fn start_io(&self) -> Result<(), OsStatus> {
        self.state
            .start()
            .map_err(|_| OsStatus::ILLEGAL_OPERATION)?;
        // SAFETY: this call just won the transition into `Running`. Exclusive access to
        // `inner` additionally assumes the host does not run other driver calls concurrently
        // (see the `Sync` impl).
        let inner = unsafe { &mut *self.inner.get() };
        inner.start_io().map_err(|status| {
            // Best-effort rollback, as in `initialize`.
            let _ = self.state.stop();
            status
        })
    }

    /// Leaves `Running`, then runs the user driver's [`Driver::stop_io`].
    ///
    /// # Errors
    ///
    /// `OsStatus::ILLEGAL_OPERATION` if the instance is not running or the state transition
    /// fails. In either case the user hook is not called.
    pub fn stop_io(&self) -> Result<(), OsStatus> {
        if self.state.load() != State::Running {
            return Err(OsStatus::ILLEGAL_OPERATION);
        }
        // Transition *before* tearing the user driver down, for two reasons:
        // - `process_io` stops passing its `is_running` gate for the rest of the teardown;
        // - if the transition fails, we bail out without having stopped a driver whose
        //   state would still say `Running`.
        self.state.stop().map_err(|_| OsStatus::ILLEGAL_OPERATION)?;
        // SAFETY: this call just won the transition out of `Running`. Exclusive access to
        // `inner` additionally assumes the host does not run other driver calls concurrently
        // (see the `Sync` impl).
        let inner = unsafe { &mut *self.inner.get() };
        inner.stop_io();
        Ok(())
    }

    /// Forwards one I/O cycle to the user driver's [`Driver::process_io`].
    ///
    /// # Errors
    ///
    /// `OsStatus::NOT_RUNNING` if the instance is not running. In that case the user driver
    /// is not called and `buffer` is left untouched.
    pub fn process_io(
        &self,
        rt: &RealtimeContext,
        buffer: &mut IoBuffer<'_>,
    ) -> Result<(), OsStatus> {
        if !self.state.is_running() {
            return Err(OsStatus::NOT_RUNNING);
        }
        // SAFETY: the state check only gates on lifecycle. Exclusive access to `inner`
        // assumes the host never overlaps this call with another driver call (see the
        // `Sync` impl).
        let inner = unsafe { &mut *self.inner.get() };
        inner.process_io(rt, buffer);
        Ok(())
    }

    /// Returns the device description published by the user driver.
    #[must_use]
    pub fn device(&self) -> DeviceSpec {
        // SAFETY: only a shared borrow is taken. NOTE(review): it could still overlap a
        // `&mut T` held by a concurrent lifecycle or `process_io` call, so soundness assumes
        // the host never issues them concurrently.
        let inner = unsafe { &*self.inner.get() };
        inner.device()
    }

    /// Returns `T`'s static identity. It is derived from associated constants, not from
    /// instance state.
    #[inline]
    #[must_use]
    pub fn info(&self) -> DriverInfo {
        DriverInfo::of::<T>()
    }
}
impl<T: Driver> Default for DriverInstance<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: Driver> AnyDriver for DriverInstance<T> {
#[inline]
fn add_ref(&self) -> u32 {
Self::add_ref(self)
}
#[inline]
fn release(&self) -> u32 {
Self::release(self)
}
#[inline]
fn refcount(&self) -> u32 {
Self::refcount(self)
}
#[inline]
fn state(&self) -> State {
Self::state(self)
}
#[inline]
fn initialize(&self) -> Result<(), OsStatus> {
Self::initialize(self)
}
#[inline]
fn start_io(&self) -> Result<(), OsStatus> {
Self::start_io(self)
}
#[inline]
fn stop_io(&self) -> Result<(), OsStatus> {
Self::stop_io(self)
}
#[inline]
fn process_io(&self, rt: &RealtimeContext, buffer: &mut IoBuffer<'_>) -> Result<(), OsStatus> {
Self::process_io(self, rt, buffer)
}
#[inline]
fn device(&self) -> DeviceSpec {
Self::device(self)
}
#[inline]
fn info(&self) -> DriverInfo {
Self::info(self)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::DeviceSpec;
    use crate::format::StreamFormat;
    use crate::io::{IoOperation, Timestamp};
    use crate::stream::StreamSpec;
    use core::cell::Cell;
    use static_assertions::assert_impl_all;

    /// Test driver that copies input to output and counts how often each hook runs.
    ///
    /// The counters are `Cell`s so the test body can read them through a shared reference
    /// to the wrapped driver.
    struct Loopback {
        initialize_seen: Cell<u32>,
        start_seen: Cell<u32>,
        stop_seen: Cell<u32>,
        process_seen: Cell<u32>,
        // When set, `start_io` fails with `OsStatus::NOT_READY` without counting a start.
        start_should_fail: Cell<bool>,
    }

    impl Driver for Loopback {
        const NAME: &'static str = "tympan-aspl loopback (test)";
        const MANUFACTURER: &'static str = "tympan-aspl";
        const VERSION: &'static str = "0.0.0";

        fn new() -> Self {
            Self {
                initialize_seen: Cell::new(0),
                start_seen: Cell::new(0),
                stop_seen: Cell::new(0),
                process_seen: Cell::new(0),
                start_should_fail: Cell::new(false),
            }
        }

        fn device(&self) -> DeviceSpec {
            DeviceSpec::new(
                "com.tympan.test.loopback",
                "Test Loopback",
                Self::MANUFACTURER,
            )
            .with_input(StreamSpec::input(StreamFormat::float32(48_000.0, 2)))
            .with_output(StreamSpec::output(StreamFormat::float32(48_000.0, 2)))
        }

        fn initialize(&mut self) -> Result<(), OsStatus> {
            self.initialize_seen.set(self.initialize_seen.get() + 1);
            Ok(())
        }

        fn start_io(&mut self) -> Result<(), OsStatus> {
            if self.start_should_fail.get() {
                return Err(OsStatus::NOT_READY);
            }
            self.start_seen.set(self.start_seen.get() + 1);
            Ok(())
        }

        fn stop_io(&mut self) {
            self.stop_seen.set(self.stop_seen.get() + 1);
        }

        fn process_io(&mut self, _rt: &RealtimeContext, buffer: &mut IoBuffer<'_>) {
            self.process_seen.set(self.process_seen.get() + 1);
            let n = buffer.output.len().min(buffer.input.len());
            buffer.output[..n].copy_from_slice(&buffer.input[..n]);
        }
    }

    assert_impl_all!(DriverInstance<Loopback>: Sync, AnyDriver);

    fn rt() -> RealtimeContext {
        // SAFETY: NOTE(review): tests do not run on a realtime thread; this assumes
        // `new_unchecked` merely skips that verification — confirm its contract.
        unsafe { RealtimeContext::new_unchecked() }
    }

    /// Borrows the wrapped `Loopback` to read or poke its cells.
    ///
    /// Only call this between driver operations, never while one is in progress.
    fn user_driver(d: &DriverInstance<Loopback>) -> &Loopback {
        // SAFETY: tests are single-threaded. Callers use the returned reference only
        // between driver calls, so no `&mut Loopback` is live at the same time.
        unsafe { &*d.inner.get() }
    }

    #[test]
    fn new_starts_uninitialized_with_refcount_one() {
        let d = DriverInstance::<Loopback>::new();
        assert_eq!(d.state(), State::Uninitialized);
        assert_eq!(d.refcount(), 1);
    }

    #[test]
    fn default_matches_new() {
        let d: DriverInstance<Loopback> = DriverInstance::default();
        assert_eq!(d.state(), State::Uninitialized);
        assert_eq!(d.refcount(), 1);
    }

    #[test]
    fn add_ref_release_delegate_to_refcount() {
        let d = DriverInstance::<Loopback>::new();
        assert_eq!(d.add_ref(), 2);
        assert_eq!(d.add_ref(), 3);
        assert_eq!(d.release(), 2);
        assert_eq!(d.release(), 1);
    }

    #[test]
    fn info_reflects_associated_constants() {
        let d = DriverInstance::<Loopback>::new();
        let info = d.info();
        assert_eq!(info.name, "tympan-aspl loopback (test)");
        assert_eq!(info.manufacturer, "tympan-aspl");
        assert_eq!(info.version, "0.0.0");
    }

    #[test]
    fn device_is_forwarded_from_the_user_driver() {
        let d = DriverInstance::<Loopback>::new();
        let spec = d.device();
        assert_eq!(spec.uid(), "com.tympan.test.loopback");
        assert!(spec.is_loopback());
    }

    #[test]
    fn initialize_transitions_and_forwards_to_user() {
        let d = DriverInstance::<Loopback>::new();
        assert!(d.initialize().is_ok());
        assert_eq!(d.state(), State::Initialized);
        assert_eq!(user_driver(&d).initialize_seen.get(), 1);
    }

    #[test]
    fn double_initialize_is_illegal() {
        let d = DriverInstance::<Loopback>::new();
        d.initialize().unwrap();
        assert_eq!(d.initialize(), Err(OsStatus::ILLEGAL_OPERATION));
        // The rejected call must not reach the user driver.
        assert_eq!(user_driver(&d).initialize_seen.get(), 1);
    }

    #[test]
    fn start_requires_initialized() {
        let d = DriverInstance::<Loopback>::new();
        assert_eq!(d.start_io(), Err(OsStatus::ILLEGAL_OPERATION));
        assert_eq!(d.state(), State::Uninitialized);
        assert_eq!(user_driver(&d).start_seen.get(), 0);
    }

    #[test]
    fn start_io_transitions_and_forwards_to_user() {
        let d = DriverInstance::<Loopback>::new();
        d.initialize().unwrap();
        d.start_io().unwrap();
        assert_eq!(d.state(), State::Running);
        assert_eq!(user_driver(&d).start_seen.get(), 1);
    }

    #[test]
    fn start_failure_rolls_state_back_to_initialized() {
        let d = DriverInstance::<Loopback>::new();
        d.initialize().unwrap();
        user_driver(&d).start_should_fail.set(true);
        assert_eq!(d.start_io(), Err(OsStatus::NOT_READY));
        assert_eq!(d.state(), State::Initialized);
        assert_eq!(user_driver(&d).start_seen.get(), 0);
        // A driver that never reached `Running` cannot be stopped, and its stop hook
        // must not run.
        assert_eq!(d.stop_io(), Err(OsStatus::ILLEGAL_OPERATION));
        assert_eq!(user_driver(&d).stop_seen.get(), 0);
        // Once the failure cause is gone, starting works normally.
        user_driver(&d).start_should_fail.set(false);
        d.start_io().unwrap();
        assert_eq!(d.state(), State::Running);
        assert_eq!(user_driver(&d).start_seen.get(), 1);
    }

    #[test]
    fn stop_io_returns_to_initialized() {
        let d = DriverInstance::<Loopback>::new();
        d.initialize().unwrap();
        d.start_io().unwrap();
        d.stop_io().unwrap();
        assert_eq!(d.state(), State::Initialized);
        assert_eq!(user_driver(&d).stop_seen.get(), 1);
    }

    #[test]
    fn stop_without_start_is_illegal() {
        let d = DriverInstance::<Loopback>::new();
        assert_eq!(d.stop_io(), Err(OsStatus::ILLEGAL_OPERATION));
        d.initialize().unwrap();
        assert_eq!(d.stop_io(), Err(OsStatus::ILLEGAL_OPERATION));
        assert_eq!(d.state(), State::Initialized);
        assert_eq!(user_driver(&d).stop_seen.get(), 0);
    }

    #[test]
    fn second_stop_is_illegal_and_not_forwarded() {
        let d = DriverInstance::<Loopback>::new();
        d.initialize().unwrap();
        d.start_io().unwrap();
        d.stop_io().unwrap();
        assert_eq!(d.stop_io(), Err(OsStatus::ILLEGAL_OPERATION));
        assert_eq!(d.state(), State::Initialized);
        assert_eq!(user_driver(&d).stop_seen.get(), 1);
    }

    #[test]
    fn process_io_requires_running_state() {
        let d = DriverInstance::<Loopback>::new();
        let input = [1.0_f32; 4];
        let mut output = [0.0_f32; 4];
        let rt = rt();
        let mut buffer = IoBuffer::new(
            Timestamp::ZERO,
            IoOperation::PROCESS_OUTPUT,
            &input,
            &mut output,
        );
        assert_eq!(d.process_io(&rt, &mut buffer), Err(OsStatus::NOT_RUNNING));
        d.initialize().unwrap();
        let mut buffer = IoBuffer::new(
            Timestamp::ZERO,
            IoOperation::PROCESS_OUTPUT,
            &input,
            &mut output,
        );
        assert_eq!(d.process_io(&rt, &mut buffer), Err(OsStatus::NOT_RUNNING));
        // Rejected cycles must neither reach the user driver nor touch the output.
        assert_eq!(user_driver(&d).process_seen.get(), 0);
        assert_eq!(output, [0.0_f32; 4]);
    }

    #[test]
    fn process_io_after_stop_is_rejected() {
        let d = DriverInstance::<Loopback>::new();
        d.initialize().unwrap();
        d.start_io().unwrap();
        d.stop_io().unwrap();
        let input = [1.0_f32; 4];
        let mut output = [0.0_f32; 4];
        let rt = rt();
        let mut buffer = IoBuffer::new(
            Timestamp::ZERO,
            IoOperation::PROCESS_OUTPUT,
            &input,
            &mut output,
        );
        assert_eq!(d.process_io(&rt, &mut buffer), Err(OsStatus::NOT_RUNNING));
        assert_eq!(user_driver(&d).process_seen.get(), 0);
        assert_eq!(output, [0.0_f32; 4]);
    }

    #[test]
    fn process_io_after_start_copies_input_to_output() {
        let d = DriverInstance::<Loopback>::new();
        d.initialize().unwrap();
        d.start_io().unwrap();
        let input = [0.1_f32, -0.2, 0.3, -0.4];
        let mut output = [0.0_f32; 4];
        let rt = rt();
        let mut buffer = IoBuffer::new(
            Timestamp::ZERO,
            IoOperation::PROCESS_OUTPUT,
            &input,
            &mut output,
        );
        d.process_io(&rt, &mut buffer).unwrap();
        assert_eq!(output, input);
        assert_eq!(user_driver(&d).process_seen.get(), 1);
    }

    #[test]
    fn full_lifecycle_round_trip() {
        let d = DriverInstance::<Loopback>::new();
        d.initialize().unwrap();
        d.start_io().unwrap();
        let input = [0.5_f32; 4];
        let mut output = [0.0_f32; 4];
        let rt = rt();
        for _ in 0..3 {
            let mut buffer = IoBuffer::new(
                Timestamp::ZERO,
                IoOperation::PROCESS_OUTPUT,
                &input,
                &mut output,
            );
            d.process_io(&rt, &mut buffer).unwrap();
        }
        d.stop_io().unwrap();
        assert_eq!(d.state(), State::Initialized);
        d.start_io().unwrap();
        d.stop_io().unwrap();
        assert_eq!(d.state(), State::Initialized);
        let inner = user_driver(&d);
        assert_eq!(inner.process_seen.get(), 3);
        assert_eq!(inner.start_seen.get(), 2);
        assert_eq!(inner.stop_seen.get(), 2);
    }

    #[test]
    fn type_erased_dispatch_drives_full_lifecycle() {
        use std::sync::Arc;
        let driver: Arc<dyn AnyDriver> = Arc::new(DriverInstance::<Loopback>::new());
        assert_eq!(driver.state(), State::Uninitialized);
        assert_eq!(driver.refcount(), 1);
        assert_eq!(driver.add_ref(), 2);
        driver.initialize().unwrap();
        driver.start_io().unwrap();
        assert_eq!(driver.state(), State::Running);
        let input = [0.25_f32, -0.5, 0.75, -1.0];
        let mut output = [0.0_f32; 4];
        let rt = rt();
        let mut buffer = IoBuffer::new(
            Timestamp::ZERO,
            IoOperation::PROCESS_OUTPUT,
            &input,
            &mut output,
        );
        driver.process_io(&rt, &mut buffer).unwrap();
        assert_eq!(output, input);
        driver.stop_io().unwrap();
        assert_eq!(driver.state(), State::Initialized);
        assert_eq!(driver.release(), 1);
        assert_eq!(driver.info().name, "tympan-aspl loopback (test)");
        assert!(driver.device().is_loopback());
    }
}