use crate::error::{CoreResult, InferenceError, SignalError, StorageError};
use crate::types::{CsiFrame, FrameId, PoseEstimate, ProcessedSignal, Timestamp};
/// A value that can be encoded into a canonical byte form and fingerprinted.
///
/// Implementors supply [`CanonicalFrame::to_canonical_bytes`]; the provided
/// [`CanonicalFrame::witness_hash`] derives a 32-byte BLAKE3 digest from it.
pub trait CanonicalFrame {
    /// Encodes `self` into its canonical byte form.
    ///
    /// NOTE(review): `witness_hash` is only reproducible if this encoding is
    /// deterministic for equal values; implementors should guarantee that.
    fn to_canonical_bytes(&self) -> alloc_vec::Vec<u8>;

    /// Hashes the output of [`CanonicalFrame::to_canonical_bytes`] with
    /// `blake3::hash` and returns the raw 32-byte digest.
    fn witness_hash(&self) -> [u8; 32] {
        let canonical_bytes = self.to_canonical_bytes();
        let digest = blake3::hash(&canonical_bytes);
        digest.into()
    }
}
/// Private shim giving this module one `Vec` path for both `std` and
/// `no_std` builds; exactly one of the re-exports below is compiled in.
mod alloc_vec {
    #[cfg(feature = "std")]
    pub use std::vec::Vec;

    #[cfg(not(feature = "std"))]
    pub use alloc::vec::Vec;
}
/// Tuning parameters for a [`SignalProcessor`].
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct SignalProcessorConfig {
    /// Buffer capacity; presumably the number of CSI frames per processing window.
    pub buffer_size: usize,
    /// Sampling rate of the incoming stream, in hertz.
    pub sample_rate_hz: f64,
    /// Whether the noise filter is applied.
    pub apply_noise_filter: bool,
    /// Cutoff frequency of the noise filter, in hertz.
    pub filter_cutoff_hz: f64,
    /// Whether amplitudes are normalized.
    pub normalize_amplitude: bool,
    /// Whether phase unwrapping is applied.
    pub unwrap_phase: bool,
    /// Window function to apply; how it is used is up to the processor.
    pub window_function: WindowFunction,
}

impl Default for SignalProcessorConfig {
    /// Defaults: 64-entry buffer, 1 kHz sample rate, noise filter on with a
    /// 50 Hz cutoff, amplitude normalization and phase unwrapping on, Hann window.
    fn default() -> Self {
        Self {
            window_function: WindowFunction::Hann,
            unwrap_phase: true,
            normalize_amplitude: true,
            filter_cutoff_hz: 50.0,
            apply_noise_filter: true,
            sample_rate_hz: 1_000.0,
            buffer_size: 64,
        }
    }
}

/// Window function selectable in [`SignalProcessorConfig`]; `Hann` by default.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum WindowFunction {
    /// Rectangular (no tapering).
    Rectangular,
    /// Hann window; the `Default` variant.
    #[default]
    Hann,
    /// Hamming window.
    Hamming,
    /// Blackman window.
    Blackman,
    /// Kaiser window (no shape parameter is carried here).
    Kaiser,
}
/// A stateful stage that buffers [`CsiFrame`]s and produces [`ProcessedSignal`]s.
///
/// This trait only declares the interface. The method descriptions below follow
/// the method names and describe what implementors are expected to do; they are
/// not guarantees enforced here.
pub trait SignalProcessor: Send + Sync {
    /// Returns the active configuration.
    fn config(&self) -> &SignalProcessorConfig;

    /// Replaces the active configuration.
    ///
    /// # Errors
    /// Returns a [`SignalError`] if the implementation rejects `config`.
    fn set_config(&mut self, config: SignalProcessorConfig) -> Result<(), SignalError>;

    /// Adds one frame to the internal buffer, taking ownership of it.
    ///
    /// # Errors
    /// Returns a [`SignalError`] if the frame cannot be accepted.
    fn push_frame(&mut self, frame: CsiFrame) -> Result<(), SignalError>;

    /// Processes the buffer if the implementation considers it ready.
    ///
    /// `Ok(None)` means no signal was produced by this call; presumably too few
    /// frames are buffered (implementation-defined).
    fn try_process(&mut self) -> Result<Option<ProcessedSignal>, SignalError>;

    /// Processes the buffer unconditionally, presumably using whatever frames are
    /// currently buffered.
    fn force_process(&mut self) -> Result<ProcessedSignal, SignalError>;

    /// Returns how many frames are currently buffered.
    fn buffered_frame_count(&self) -> usize;

    /// Discards the buffered frames.
    fn clear_buffer(&mut self);

    /// Returns the processor to its initial state. How this differs from
    /// [`SignalProcessor::clear_buffer`] is implementation-defined.
    fn reset(&mut self);
}
/// Settings for a [`NeuralInference`] engine.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct InferenceConfig {
    /// Location of the model to load; empty by default.
    pub model_path: String,
    /// Hardware backend to run on.
    pub device: InferenceDevice,
    /// Upper bound on batch size (presumably honoured by batched inference).
    pub max_batch_size: usize,
    /// Number of worker threads the backend may use.
    pub num_threads: usize,
    /// Minimum confidence for a result to be kept; presumably in `[0, 1]`.
    pub confidence_threshold: f32,
    /// Non-maximum-suppression threshold; presumably an IoU cutoff.
    pub nms_threshold: f32,
    /// Whether to run in half precision (FP16).
    pub use_fp16: bool,
}

impl Default for InferenceConfig {
    /// Defaults: CPU backend, empty model path, batch size 8, 4 threads,
    /// confidence threshold 0.5, NMS threshold 0.45, FP16 disabled.
    fn default() -> Self {
        Self {
            use_fp16: false,
            nms_threshold: 0.45,
            confidence_threshold: 0.5,
            num_threads: 4,
            max_batch_size: 8,
            device: InferenceDevice::Cpu,
            model_path: String::new(),
        }
    }
}

/// Hardware backend selected via [`InferenceConfig::device`]; `Cpu` by default.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum InferenceDevice {
    /// CPU execution; the `Default` variant.
    #[default]
    Cpu,
    /// CUDA GPU; `device_id` presumably selects the GPU ordinal.
    Cuda { device_id: usize },
    /// TensorRT; `device_id` presumably selects the GPU ordinal.
    TensorRt { device_id: usize },
    /// Apple Core ML.
    CoreMl,
    /// WebGPU.
    WebGpu,
}
pub trait NeuralInference: Send + Sync {
fn config(&self) -> &InferenceConfig;
fn is_ready(&self) -> bool;
fn model_version(&self) -> &str;
fn load_model(&mut self) -> Result<(), InferenceError>;
fn unload_model(&mut self);
fn infer(&self, signal: &ProcessedSignal) -> Result<PoseEstimate, InferenceError>;
fn infer_batch(&self, signals: &[ProcessedSignal])
-> Result<Vec<PoseEstimate>, InferenceError>;
fn warmup(&mut self) -> Result<(), InferenceError>;
fn stats(&self) -> InferenceStats;
}
/// Runtime statistics reported by an inference engine.
///
/// Like the configuration types in this module, it is (de)serializable when
/// the `serde` feature is enabled, so it can be exported as telemetry. All
/// fields are zero/`None` by default.
#[derive(Debug, Clone, Default, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct InferenceStats {
    /// Total number of inferences performed.
    pub total_inferences: u64,
    /// Mean latency, in milliseconds.
    pub avg_latency_ms: f64,
    /// 95th-percentile latency, in milliseconds.
    pub p95_latency_ms: f64,
    /// Maximum observed latency, in milliseconds.
    pub max_latency_ms: f64,
    /// Throughput; presumably inferences per second (TODO confirm units).
    pub throughput: f64,
    /// GPU memory in use, in bytes; `None` presumably when not applicable (e.g. CPU).
    pub gpu_memory_bytes: Option<u64>,
}
/// Filtering, paging, and ordering options for storage queries.
///
/// Every filter defaults to `None` and the order defaults to ascending. How a
/// store interprets a `None` filter is up to the implementation (presumably
/// "unrestricted").
#[derive(Clone, Debug, Default)]
pub struct QueryOptions {
    /// Maximum number of records to return.
    pub limit: Option<usize>,
    /// Number of matching records to skip.
    pub offset: Option<usize>,
    /// Lower time bound.
    pub start_time: Option<Timestamp>,
    /// Upper time bound. NOTE(review): inclusivity is not specified here.
    pub end_time: Option<Timestamp>,
    /// Restrict results to one device.
    pub device_id: Option<String>,
    /// Result ordering.
    pub sort_order: SortOrder,
}
/// Ordering applied to query results; [`SortOrder::Ascending`] by default.
///
/// Like the other public enums in this module, it is (de)serializable when the
/// `serde` feature is enabled. It is also `Hash`, so it can key maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum SortOrder {
    /// Ascending order; the `Default` variant.
    #[default]
    Ascending,
    /// Descending order.
    Descending,
}
pub trait DataStore: Send + Sync {
fn is_connected(&self) -> bool;
fn store_csi_frame(&self, frame: &CsiFrame) -> Result<(), StorageError>;
fn get_csi_frame(&self, id: &FrameId) -> Result<CsiFrame, StorageError>;
fn query_csi_frames(&self, options: &QueryOptions) -> Result<Vec<CsiFrame>, StorageError>;
fn store_pose_estimate(&self, estimate: &PoseEstimate) -> Result<(), StorageError>;
fn get_pose_estimate(&self, id: &FrameId) -> Result<PoseEstimate, StorageError>;
fn query_pose_estimates(
&self,
options: &QueryOptions,
) -> Result<Vec<PoseEstimate>, StorageError>;
fn get_recent_estimates(&self, count: usize) -> Result<Vec<PoseEstimate>, StorageError>;
fn delete_csi_frames_before(&self, timestamp: &Timestamp) -> Result<u64, StorageError>;
fn delete_pose_estimates_before(&self, timestamp: &Timestamp) -> Result<u64, StorageError>;
fn stats(&self) -> StorageStats;
}
/// Summary statistics reported by a data store; all zero/`None` by default.
#[derive(Clone, Debug, Default)]
pub struct StorageStats {
    /// Number of stored CSI frames.
    pub csi_frame_count: u64,
    /// Number of stored pose estimates.
    pub pose_estimate_count: u64,
    /// Total storage footprint, in bytes.
    pub total_size_bytes: u64,
    /// Timestamp of the oldest record; presumably `None` when the store is empty.
    pub oldest_record: Option<Timestamp>,
    /// Timestamp of the newest record; presumably `None` when the store is empty.
    pub newest_record: Option<Timestamp>,
}
#[cfg(feature = "async")]
use async_trait::async_trait;
/// Async counterpart of [`SignalProcessor`] with the same methods.
///
/// `config` and `buffered_frame_count` stay synchronous; every mutating
/// operation is `async`. Only available with the `async` feature.
#[cfg(feature = "async")]
#[async_trait]
pub trait AsyncSignalProcessor: Send + Sync {
    /// Returns the active configuration.
    fn config(&self) -> &SignalProcessorConfig;

    /// Replaces the active configuration.
    ///
    /// # Errors
    /// Returns a [`SignalError`] if the implementation rejects `config`.
    async fn set_config(&mut self, config: SignalProcessorConfig) -> Result<(), SignalError>;

    /// Adds one frame to the internal buffer.
    async fn push_frame(&mut self, frame: CsiFrame) -> Result<(), SignalError>;

    /// Processes the buffer if ready; `Ok(None)` means nothing was produced.
    async fn try_process(&mut self) -> Result<Option<ProcessedSignal>, SignalError>;

    /// Processes the buffer unconditionally.
    async fn force_process(&mut self) -> Result<ProcessedSignal, SignalError>;

    /// Returns how many frames are currently buffered.
    fn buffered_frame_count(&self) -> usize;

    /// Discards the buffered frames.
    async fn clear_buffer(&mut self);

    /// Returns the processor to its initial state (implementation-defined).
    async fn reset(&mut self);
}
/// Async counterpart of [`NeuralInference`] with the same methods.
///
/// `config`, `is_ready`, `model_version` and `stats` stay synchronous; the rest
/// are `async`. Only available with the `async` feature.
#[cfg(feature = "async")]
#[async_trait]
pub trait AsyncNeuralInference: Send + Sync {
    // `infer_batch` spells out `alloc_vec::Vec` (identical to `std::vec::Vec`
    // under the `std` feature) for consistency with the synchronous traits and
    // `CanonicalFrame`, so the signature does not rely on `Vec` being in the prelude.

    /// Returns the configuration this engine uses.
    fn config(&self) -> &InferenceConfig;

    /// Reports whether the engine can serve inference calls.
    fn is_ready(&self) -> bool;

    /// Returns a version identifier for the model.
    fn model_version(&self) -> &str;

    /// Loads the model.
    ///
    /// # Errors
    /// Returns an [`InferenceError`] if loading fails.
    async fn load_model(&mut self) -> Result<(), InferenceError>;

    /// Releases the loaded model.
    async fn unload_model(&mut self);

    /// Runs inference on one processed signal.
    async fn infer(&self, signal: &ProcessedSignal) -> Result<PoseEstimate, InferenceError>;

    /// Runs inference on a batch of processed signals.
    async fn infer_batch(
        &self,
        signals: &[ProcessedSignal],
    ) -> Result<alloc_vec::Vec<PoseEstimate>, InferenceError>;

    /// Prepares the engine before real traffic (implementation-defined).
    async fn warmup(&mut self) -> Result<(), InferenceError>;

    /// Returns a snapshot of runtime statistics.
    fn stats(&self) -> InferenceStats;
}
/// Async counterpart of [`DataStore`] with the same methods.
///
/// `is_connected` and `stats` stay synchronous; the rest are `async`. Only
/// available with the `async` feature.
#[cfg(feature = "async")]
#[async_trait]
pub trait AsyncDataStore: Send + Sync {
    // Collection-returning methods spell out `alloc_vec::Vec` (identical to
    // `std::vec::Vec` under the `std` feature) for consistency with `DataStore`
    // and `CanonicalFrame`, so the signatures do not rely on `Vec` being in the prelude.

    /// Reports whether the backing store is currently reachable.
    fn is_connected(&self) -> bool;

    /// Persists one CSI frame.
    async fn store_csi_frame(&self, frame: &CsiFrame) -> Result<(), StorageError>;

    /// Fetches the CSI frame with the given id.
    async fn get_csi_frame(&self, id: &FrameId) -> Result<CsiFrame, StorageError>;

    /// Returns the CSI frames matching `options`.
    async fn query_csi_frames(
        &self,
        options: &QueryOptions,
    ) -> Result<alloc_vec::Vec<CsiFrame>, StorageError>;

    /// Persists one pose estimate.
    async fn store_pose_estimate(&self, estimate: &PoseEstimate) -> Result<(), StorageError>;

    /// Fetches the pose estimate with the given id.
    async fn get_pose_estimate(&self, id: &FrameId) -> Result<PoseEstimate, StorageError>;

    /// Returns the pose estimates matching `options`.
    async fn query_pose_estimates(
        &self,
        options: &QueryOptions,
    ) -> Result<alloc_vec::Vec<PoseEstimate>, StorageError>;

    /// Returns up to `count` of the most recent pose estimates.
    async fn get_recent_estimates(
        &self,
        count: usize,
    ) -> Result<alloc_vec::Vec<PoseEstimate>, StorageError>;

    /// Deletes CSI frames older than `timestamp`; presumably returns how many were removed.
    async fn delete_csi_frames_before(&self, timestamp: &Timestamp) -> Result<u64, StorageError>;

    /// Deletes pose estimates older than `timestamp`; presumably returns how many were removed.
    async fn delete_pose_estimates_before(
        &self,
        timestamp: &Timestamp,
    ) -> Result<u64, StorageError>;

    /// Returns a snapshot of storage statistics.
    fn stats(&self) -> StorageStats;
}
/// A single fallible transformation stage from `Input` to `Output`.
pub trait Pipeline: Send + Sync {
    /// Value consumed by [`Pipeline::process`].
    type Input;
    /// Value produced on success.
    type Output;
    /// Failure type; no trait bounds are imposed on it.
    type Error;

    /// Runs the stage on `input`, which is taken by value.
    fn process(&self, input: Self::Input) -> Result<Self::Output, Self::Error>;
}
/// Types that can check their own invariants.
pub trait Validate {
    /// Returns `Ok(())` when `self` is valid and an error through [`CoreResult`]
    /// otherwise. What counts as valid is defined by each implementor.
    fn validate(&self) -> CoreResult<()>;
}
/// Types that can be returned to an initial state in place.
pub trait Resettable {
    /// Resets `self`; what "initial" means is up to the implementor.
    fn reset(&mut self);
}
/// Components that can report their health.
pub trait HealthCheck {
    /// Implementor-defined detailed status value.
    type Status;

    /// Produces a detailed status report.
    fn health_check(&self) -> Self::Status;

    /// Returns a yes/no health summary.
    ///
    /// NOTE(review): nothing here ties this to [`HealthCheck::health_check`];
    /// implementors should keep the two consistent.
    fn is_healthy(&self) -> bool;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks the signal-processor defaults.
    #[test]
    fn test_signal_processor_config_default() {
        let SignalProcessorConfig {
            buffer_size,
            apply_noise_filter,
            sample_rate_hz,
            ..
        } = SignalProcessorConfig::default();
        assert_eq!(buffer_size, 64);
        assert!(apply_noise_filter);
        assert!(sample_rate_hz > 0.0);
    }

    /// Spot-checks the inference defaults.
    #[test]
    fn test_inference_config_default() {
        let InferenceConfig {
            device,
            confidence_threshold,
            max_batch_size,
            ..
        } = InferenceConfig::default();
        assert_eq!(device, InferenceDevice::Cpu);
        assert!(confidence_threshold > 0.0);
        assert!(max_batch_size > 0);
    }

    /// Query options start unpaged and ascending.
    #[test]
    fn test_query_options_default() {
        let QueryOptions {
            limit,
            offset,
            sort_order,
            ..
        } = QueryOptions::default();
        assert_eq!(limit, None);
        assert_eq!(offset, None);
        assert_eq!(sort_order, SortOrder::Ascending);
    }

    /// Device variants keep the ids they were built with.
    #[test]
    fn test_inference_device_variants() {
        assert_eq!(InferenceDevice::Cpu, InferenceDevice::Cpu);
        assert!(matches!(
            InferenceDevice::Cuda { device_id: 0 },
            InferenceDevice::Cuda { device_id: 0 }
        ));
        assert!(matches!(
            InferenceDevice::TensorRt { device_id: 1 },
            InferenceDevice::TensorRt { device_id: 1 }
        ));
    }
}