use std::fmt;
/// Type of a value in the SC IR (carried by `ScOp::Input` and `ScOp::Constant`).
#[derive(Debug, Clone, PartialEq)]
pub enum ScType {
    /// Bitstream parameterised by `length`; printed as `bitstream<length>`.
    Bitstream { length: usize },
    /// Fixed-point value of `width` total bits; `frac` is presumably the number of
    /// fractional bits (printed as `fixed<width,frac>`).
    FixedPoint { width: u32, frac: u32 },
    /// Rate value; [`ScType::bit_width`] reports it as 16 bits.
    Rate,
    /// Unsigned integer of `width` bits (printed as `u<width>`).
    UInt { width: u32 },
    /// Signed integer of `width` bits (printed as `i<width>`).
    SInt { width: u32 },
    /// Single-bit boolean.
    Bool,
    /// `count` elements of type `element`.
    Vec { element: Box<ScType>, count: usize },
}

impl ScType {
    /// Number of bits used to represent one value of this type.
    ///
    /// Integer and fixed-point types report their declared `width`, `Bool` is 1 and
    /// `Rate` is 16. A `Bitstream` reports 1 regardless of its `length`
    /// (NOTE(review): presumably because it is consumed one bit at a time; confirm).
    /// A `Vec` is `count` times its element's width, computed recursively.
    pub fn bit_width(&self) -> usize {
        match self {
            Self::Vec { element, count } => count * element.bit_width(),
            Self::Bitstream { .. } | Self::Bool => 1,
            Self::Rate => 16,
            Self::FixedPoint { width, .. } | Self::UInt { width } | Self::SInt { width } => {
                *width as usize
            }
        }
    }
}

impl fmt::Display for ScType {
    /// Writes the textual spelling of the type, e.g. `u8`, `fixed<16,8>`, `vec<bool,4>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Bool => write!(f, "bool"),
            Self::Rate => write!(f, "rate"),
            Self::UInt { width } => write!(f, "u{width}"),
            Self::SInt { width } => write!(f, "i{width}"),
            Self::FixedPoint { width, frac } => write!(f, "fixed<{width},{frac}>"),
            Self::Bitstream { length } => write!(f, "bitstream<{length}>"),
            // Recurses into the element's own Display impl.
            Self::Vec { element, count } => write!(f, "vec<{element},{count}>"),
        }
    }
}
/// Identifier of a value produced by an op; rendered as `%N`.
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct ValueId(pub u32);

impl fmt::Display for ValueId {
    /// Writes `%` followed by the raw number, e.g. `%7`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ValueId(raw) = *self;
        write!(f, "%{}", raw)
    }
}
/// Literal payload of an `ScOp::Constant`.
///
/// Scalars may be `f64`, `i64` or `u64`; vectors may hold `f64` or `i64`.
#[derive(Debug, Clone, PartialEq)]
pub enum ScConst {
    /// Floating-point scalar.
    F64(f64),
    /// Signed integer scalar.
    I64(i64),
    /// Unsigned integer scalar.
    U64(u64),
    /// Vector of floating-point values.
    F64Vec(Vec<f64>),
    /// Vector of signed integers.
    I64Vec(Vec<i64>),
}
/// Configuration carried by `ScOp::LifStep`.
///
/// NOTE(review): "LIF" is presumably leaky integrate-and-fire; the field meanings
/// below are inferred from their names. Confirm against the lowering/simulator.
#[derive(Debug, Clone, PartialEq)]
pub struct LifParams {
    /// Data width in bits (presumably of the fixed-point membrane value).
    pub data_width: u32,
    /// Presumably the number of fractional bits in that fixed-point format.
    pub fraction: u32,
    /// Resting potential.
    pub v_rest: i64,
    /// Potential after a reset.
    pub v_reset: i64,
    /// Firing threshold.
    pub v_threshold: i64,
    /// Refractory period (presumably counted in steps).
    pub refractory_period: u32,
}

impl Default for LifParams {
    /// 16-bit data with `fraction` 8, rest and reset at 0, threshold 256 and a
    /// refractory period of 2.
    fn default() -> Self {
        let fraction: u32 = 8;
        Self {
            data_width: 16,
            fraction,
            v_rest: 0,
            v_reset: 0,
            // 1 << 8 == 256; NOTE(review): presumably 1.0 in the fixed-point scale.
            v_threshold: 1_i64 << fraction,
            refractory_period: 2,
        }
    }
}
/// Configuration carried by `ScOp::DenseForward`.
///
/// NOTE(review): field meanings are inferred from names; confirm against the
/// lowering that consumes them.
#[derive(Debug, Clone, PartialEq)]
pub struct DenseParams {
    /// Number of inputs to the layer.
    pub n_inputs: usize,
    /// Number of neurons (outputs) in the layer.
    pub n_neurons: usize,
    /// Data width in bits.
    pub data_width: u32,
    /// Presumably the stochastic bitstream length used for encoding.
    pub stream_length: usize,
    /// Base seed for the input encoders (presumably per-lane generator seeds).
    pub input_seed_base: u16,
    /// Base seed for the weight encoders.
    pub weight_seed_base: u16,
    /// Lower bound of the output range.
    pub y_min: i64,
    /// Upper bound of the output range.
    pub y_max: i64,
}

impl Default for DenseParams {
    /// A 3-input, 7-neuron, 16-bit layer with 1024-long streams, seeds
    /// `0xACE1` / `0xBEEF`, and an output range of `0..=256`.
    fn default() -> Self {
        Self {
            // Layer shape.
            n_inputs: 3,
            n_neurons: 7,
            data_width: 16,
            // Encoding.
            stream_length: 1024,
            input_seed_base: 0xACE1,
            weight_seed_base: 0xBEEF,
            // Output range.
            y_min: 0,
            y_max: 256,
        }
    }
}
/// Reduction mode selected by `ScOp::Reduce`; displays as `sum` or `max`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceMode {
    /// Sum reduction.
    Sum,
    /// Maximum reduction.
    Max,
}

impl fmt::Display for ReduceMode {
    /// Writes the lowercase mode name. Formatter width/fill flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match *self {
            ReduceMode::Sum => "sum",
            ReduceMode::Max => "max",
        };
        f.write_str(text)
    }
}
/// One operation in an SC graph.
///
/// Every variant has an `id` naming the value it defines (see `ScOp::result_id`).
/// Every other `ValueId` field is an operand (see `ScOp::operands`). Descriptions
/// of the numeric kernels below are inferred from names and hedged accordingly.
#[derive(Debug, Clone, PartialEq)]
pub enum ScOp {
    /// Graph input named `name` with type `ty`.
    Input {
        id: ValueId,
        name: String,
        ty: ScType,
    },
    /// Graph output named `name` exposing the value `source`.
    Output {
        id: ValueId,
        name: String,
        source: ValueId,
    },
    /// Literal `value` of type `ty`.
    Constant {
        id: ValueId,
        value: ScConst,
        ty: ScType,
    },
    /// Encodes `prob` as a bitstream; `length` and `seed` presumably select the
    /// stream length and generator seed.
    Encode {
        id: ValueId,
        prob: ValueId,
        length: usize,
        seed: u16,
    },
    /// Bitwise AND of `lhs` and `rhs` (`sc.and`).
    BitwiseAnd {
        id: ValueId,
        lhs: ValueId,
        rhs: ValueId,
    },
    /// Population count of `input` (`sc.popcount`).
    Popcount {
        id: ValueId,
        input: ValueId,
    },
    /// One neuron step (presumably leaky integrate-and-fire) configured by `params`.
    LifStep {
        id: ValueId,
        current: ValueId,
        leak: ValueId,
        gain: ValueId,
        noise: ValueId,
        params: LifParams,
    },
    /// Dense-layer forward pass over `inputs` and `weights`, configured by `params`.
    DenseForward {
        id: ValueId,
        inputs: ValueId,
        weights: ValueId,
        leak: ValueId,
        gain: ValueId,
        params: DenseParams,
    },
    /// Bitwise XOR of `lhs` and `rhs` (`sc.xor`).
    BitwiseXor {
        id: ValueId,
        lhs: ValueId,
        rhs: ValueId,
    },
    /// Reduction of `input` using `mode`.
    Reduce {
        id: ValueId,
        input: ValueId,
        mode: ReduceMode,
    },
    /// Graph-layer forward pass over `features` and `adjacency`; `n_nodes` and
    /// `n_features` presumably give the problem dimensions.
    GraphForward {
        id: ValueId,
        features: ValueId,
        adjacency: ValueId,
        n_nodes: usize,
        n_features: usize,
    },
    /// Softmax attention over `q`, `k`, `v`; `dim_k` is presumably the key dimension.
    SoftmaxAttention {
        id: ValueId,
        q: ValueId,
        k: ValueId,
        v: ValueId,
        dim_k: usize,
    },
    /// One step (presumably of a Kuramoto oscillator model) with time step `dt`.
    KuramotoStep {
        id: ValueId,
        phases: ValueId,
        omega: ValueId,
        coupling: ValueId,
        dt: f64,
    },
    /// Scales `input` by the constant `factor` (`sc.scale`).
    Scale {
        id: ValueId,
        input: ValueId,
        factor: f64,
    },
    /// Offsets `input` by the constant `offset` (`sc.offset`).
    Offset {
        id: ValueId,
        input: ValueId,
        offset: f64,
    },
    /// Divides `input` by the constant `divisor` (`sc.div_const`).
    DivConst {
        id: ValueId,
        input: ValueId,
        divisor: u64,
    },
}
impl ScOp {
    /// Returns the `ValueId` this op defines (the `id` field every variant carries).
    pub fn result_id(&self) -> ValueId {
        let id = match self {
            Self::Input { id, .. } | Self::Output { id, .. } | Self::Constant { id, .. } => id,
            Self::Encode { id, .. } | Self::Popcount { id, .. } | Self::Reduce { id, .. } => id,
            Self::BitwiseAnd { id, .. } | Self::BitwiseXor { id, .. } => id,
            Self::LifStep { id, .. } | Self::DenseForward { id, .. } => id,
            Self::GraphForward { id, .. }
            | Self::SoftmaxAttention { id, .. }
            | Self::KuramotoStep { id, .. } => id,
            Self::Scale { id, .. } | Self::Offset { id, .. } | Self::DivConst { id, .. } => id,
        };
        *id
    }

    /// Returns the values this op reads, in field-declaration order.
    ///
    /// `Input` and `Constant` read nothing, so they return an empty vector.
    pub fn operands(&self) -> Vec<ValueId> {
        // `ValueId` is `Copy`, so matching on `*self` binds operands by value.
        match *self {
            Self::Input { .. } | Self::Constant { .. } => Vec::new(),
            Self::Output { source: only, .. }
            | Self::Encode { prob: only, .. }
            | Self::Popcount { input: only, .. }
            | Self::Reduce { input: only, .. }
            | Self::Scale { input: only, .. }
            | Self::Offset { input: only, .. }
            | Self::DivConst { input: only, .. } => vec![only],
            Self::BitwiseAnd { lhs, rhs, .. } | Self::BitwiseXor { lhs, rhs, .. } => {
                vec![lhs, rhs]
            }
            Self::GraphForward {
                features,
                adjacency,
                ..
            } => vec![features, adjacency],
            Self::SoftmaxAttention { q, k, v, .. } => vec![q, k, v],
            Self::KuramotoStep {
                phases,
                omega,
                coupling,
                ..
            } => vec![phases, omega, coupling],
            Self::LifStep {
                current,
                leak,
                gain,
                noise,
                ..
            } => vec![current, leak, gain, noise],
            Self::DenseForward {
                inputs,
                weights,
                leak,
                gain,
                ..
            } => vec![inputs, weights, leak, gain],
        }
    }

    /// Returns the op's dotted mnemonic, e.g. `sc.lif_step`.
    pub fn op_name(&self) -> &'static str {
        match self {
            // Graph boundary and literals.
            Self::Input { .. } => "sc.input",
            Self::Output { .. } => "sc.output",
            Self::Constant { .. } => "sc.constant",
            // Bit-level primitives.
            Self::Encode { .. } => "sc.encode",
            Self::BitwiseAnd { .. } => "sc.and",
            Self::BitwiseXor { .. } => "sc.xor",
            Self::Popcount { .. } => "sc.popcount",
            Self::Reduce { .. } => "sc.reduce",
            // Composite kernels.
            Self::LifStep { .. } => "sc.lif_step",
            Self::DenseForward { .. } => "sc.dense_forward",
            Self::GraphForward { .. } => "sc.graph_forward",
            Self::SoftmaxAttention { .. } => "sc.softmax_attention",
            Self::KuramotoStep { .. } => "sc.kuramoto_step",
            // Elementwise constant arithmetic.
            Self::Scale { .. } => "sc.scale",
            Self::Offset { .. } => "sc.offset",
            Self::DivConst { .. } => "sc.div_const",
        }
    }
}
/// A named, flat list of [`ScOp`]s plus the counter [`ScGraph::fresh_id`] uses to
/// mint new [`ValueId`]s.
#[derive(Debug, Clone, PartialEq)]
pub struct ScGraph {
    /// Name of the graph.
    pub name: String,
    /// Operations; [`ScGraph::push`] appends to the end.
    pub ops: Vec<ScOp>,
    /// Raw number of the next id `fresh_id` will return. `push` keeps it above
    /// every result id it has seen, so fresh ids do not collide with existing ones.
    pub(crate) next_id: u32,
}

impl ScGraph {
    /// Creates an empty graph named `name`. The first fresh id will be `%0`.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            ops: Vec::new(),
            next_id: 0,
        }
    }

    /// Returns a previously unissued `ValueId` and advances the counter.
    ///
    /// # Panics
    ///
    /// Panics once the `u32` id space is exhausted. A plain `+= 1` would instead
    /// wrap in release builds and silently hand out `%0` again.
    pub fn fresh_id(&mut self) -> ValueId {
        let id = ValueId(self.next_id);
        self.next_id = self
            .next_id
            .checked_add(1)
            .expect("ScGraph::fresh_id: ValueId space (u32) exhausted");
        id
    }

    /// Appends `op` and returns its result id.
    ///
    /// Also moves the id counter past that result id; the counter never moves
    /// backwards. An op whose id was chosen by hand, rather than obtained from
    /// [`ScGraph::fresh_id`], therefore cannot be duplicated by a later fresh id.
    pub fn push(&mut self, op: ScOp) -> ValueId {
        let id = op.result_id();
        // Saturates at u32::MAX. If an op uses id u32::MAX, one final fresh id
        // could still equal it before `fresh_id` panics on exhaustion.
        self.next_id = self.next_id.max(id.0.saturating_add(1));
        self.ops.push(op);
        id
    }

    /// Number of ops in the graph.
    pub fn len(&self) -> usize {
        self.ops.len()
    }

    /// `true` when the graph has no ops.
    pub fn is_empty(&self) -> bool {
        self.ops.is_empty()
    }

    /// References to every `ScOp::Input`, in `ops` order.
    pub fn inputs(&self) -> Vec<&ScOp> {
        self.ops
            .iter()
            .filter(|op| matches!(op, ScOp::Input { .. }))
            .collect()
    }

    /// References to every `ScOp::Output`, in `ops` order.
    pub fn outputs(&self) -> Vec<&ScOp> {
        self.ops
            .iter()
            .filter(|op| matches!(op, ScOp::Output { .. }))
            .collect()
    }
}