use snafu::Snafu;
use svod_dtype::DType;
mod recurrent;
pub use recurrent::{JitRecurrent, LstmState, RecurrentJit, StepTiming};
/// Specification of an input: a shape paired with an element dtype.
///
/// Both fields are public, so a value can be built with a struct literal as
/// well as through the constructors on this type.
#[derive(Clone, Debug)]
pub struct InputSpec {
/// Shape of the input, stored as an owned list of `usize` sizes.
pub shape: Vec<usize>,
/// Element type of the input.
pub dtype: DType,
}
impl InputSpec {
pub fn new(shape: &[usize], dtype: DType) -> Self {
Self { shape: shape.to_vec(), dtype }
}
pub fn f32(shape: &[usize]) -> Self {
Self::new(shape, DType::Float32)
}
pub fn i32(shape: &[usize]) -> Self {
Self::new(shape, DType::Int32)
}
pub fn i64(shape: &[usize]) -> Self {
Self::new(shape, DType::Int64)
}
}
/// Errors reported by the JIT layer.
///
/// Variants that carry a `source` field wrap an underlying error and display
/// it verbatim (`{source}`). `#[snafu(visibility(pub))]` makes the
/// SNAFU-generated context selectors public.
#[derive(Debug, Snafu)]
#[snafu(visibility(pub))]
pub enum JitError {
/// The JIT was used before being prepared; the message tells the caller to
/// call `prepare()` first.
#[snafu(display("JIT not prepared: call prepare() first"))]
NotPrepared,
/// No input buffer could be found for the given `name`.
#[snafu(display("input buffer not found: {name}"))]
InputBufferNotFound { name: &'static str },
/// Two named JIT inputs alias each other: `name` duplicates `duplicate_of`,
/// and `buffer_id` (shown with `Debug` formatting) identifies the buffer.
#[snafu(display("duplicate JIT input buffer: {name} aliases {duplicate_of} with {buffer_id:?}"))]
DuplicateInputBuffer { name: &'static str, duplicate_of: &'static str, buffer_id: svod_device::BufferId },
/// Wraps an arbitrary boxed error.
// NOTE(review): presumably produced by the user-supplied `build` closure — confirm against call sites.
#[snafu(display("{source}"))]
Build { source: Box<dyn std::error::Error + Send + Sync> },
/// Wraps a `svod_tensor` error; `source(from(.., Box::new))` boxes it
/// during conversion.
#[snafu(display("{source}"))]
Tensor {
#[snafu(source(from(svod_tensor::error::Error, Box::new)))]
source: Box<svod_tensor::error::Error>,
},
/// Wraps a `svod_device` error; `source(from(.., Box::new))` boxes it
/// during conversion.
#[snafu(display("{source}"))]
Device {
#[snafu(source(from(svod_device::error::Error, Box::new)))]
source: Box<svod_device::error::Error>,
},
/// The JIT output has `actual` elements, which does not match the declared
/// head and state element counts. The message prints both declared counts,
/// their sum, and a hint about the layout the `build` closure should return.
#[snafu(display(
"JIT output layout mismatch: declared {declared_head} head + {declared_state} state elements \
({}), actual {actual} elements. Check that the `build` closure returns `cat([head, h, c], -1)` \
with the declared shapes.",
declared_head + declared_state
))]
OutputLayoutMismatch { declared_head: usize, declared_state: usize, actual: usize },
/// Wraps a `svod_runtime::Error` directly (stored unboxed, unlike `Tensor`
/// and `Device`).
#[snafu(display("{source}"))]
Runtime { source: svod_runtime::Error },
}
/// Convenience alias for `std::result::Result` with [`JitError`] as the error type.
pub type Result<T> = std::result::Result<T, JitError>;