use catgrad::prelude::ops::*;
use catgrad::prelude::*;
use std::collections::HashMap;
/// End-to-end example: build the model term, render it to SVG, typecheck it
/// against declared parameter types, render the typed diagram, then run
/// inference on a selected backend and print the outputs.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let model = SimpleMNISTModel;
    let typed_term = model.term().expect("Failed to create typed term");

    // Untyped diagram of the model.
    save_svg(&typed_term.term, &format!("{}.svg", model.path()))?;

    // Register load ops for each declared parameter so the checker can see them.
    let parameters = load_param_types();
    let mut env = stdlib();
    env.declarations
        .extend(to_load_ops(model.path(), parameters.keys()));

    // BUGFIX: `¶meters` was a mis-encoded `&parameters` (HTML-entity garbling).
    let check_result =
        typecheck::check(&env, &parameters, typed_term.clone()).expect("typecheck failed");

    // Re-label the term's nodes with inferred types and render the typed diagram.
    let labeled_term = typed_term.term.clone().with_nodes(|_| check_result);
    let filename = &format!("{}_typed.svg", model.path());
    save_svg(&labeled_term.unwrap(), filename)?;

    // Run inference and print each output value.
    let backend = select_backend()?;
    let results = run_interpreter(&backend, &typed_term, env)?;
    for value in results {
        println!("{value:?}");
    }
    Ok(())
}
/// Runs the typed term on `backend` with a deterministic synthetic input batch.
///
/// Builds a batch of two fake 28×28 "images" (values cycle in [0, 1)), loads
/// synthetic parameter tensors, and returns the interpreter's outputs.
fn run_interpreter<B: interpreter::Backend>(
    backend: &B,
    typed_term: &TypedTerm,
    env: Environment,
) -> Result<Vec<interpreter::Value<B>>, Box<dyn std::error::Error>> {
    // Deterministic pseudo-pixels: i * 0.001 wrapped into [0, 1).
    let pixels: Vec<f32> = (0..2 * 28 * 28).map(|i| (i as f32 * 0.001) % 1.0).collect();

    let interp = interpreter::Interpreter::new(backend.clone(), env, load_param_data(backend));

    let batch = interpreter::tensor(
        &interp.backend,
        interpreter::Shape(vec![2, 28, 28]),
        &pixels,
    )
    .expect("Failed to create input tensor");

    let outputs = interp
        .run(typed_term.term.clone(), vec![batch])
        .expect("Failed to run inference");
    Ok(outputs)
}
/// Picks an interpreter backend at compile time based on enabled cargo features.
///
/// Priority: candle > ndarray > shape-only. The `#[cfg]` guards are mutually
/// exclusive, so exactly one `return` compiles for any feature combination —
/// required because the function returns a single `impl Backend` type.
fn select_backend() -> Result<impl interpreter::Backend, Box<dyn std::error::Error>> {
    // Candle wins whenever its feature is enabled.
    #[cfg(feature = "candle-backend")]
    {
        println!("selected candle backend...");
        use catgrad::interpreter::backend::candle::CandleBackend;
        #[allow(clippy::needless_return)]
        return Ok(CandleBackend::new());
    }
    // ndarray is used only when candle is NOT enabled, keeping the arms exclusive.
    #[cfg(all(feature = "ndarray-backend", not(feature = "candle-backend")))]
    {
        println!("selected ndarray backend...");
        use catgrad::interpreter::backend::ndarray::NdArrayBackend;
        #[allow(clippy::needless_return)]
        return Ok(NdArrayBackend);
    }
    // Fallback: shape-only backend computes shapes but no tensor data.
    #[cfg(not(any(feature = "candle-backend", feature = "ndarray-backend")))]
    {
        println!("selected ShapeOnly backend (no tensors computed)");
        return Ok(interpreter::backend::shape_only::ShapeOnlyBackend);
    }
}
pub struct SimpleMNISTModel;
impl Module<1, 1> for SimpleMNISTModel {
fn path(&self) -> Path {
Path::new(["model", "hidden"]).unwrap()
}
fn def(&self, builder: &Builder, [x]: [Var; 1]) -> [Var; 1] {
let [batch_size, h, w] = unpack::<3>(builder, shape(builder, x.clone()));
let flat_size = h * w;
let flat_shape = pack::<2>(builder, [batch_size, flat_size]);
let x = reshape(builder, flat_shape, x);
let root = self.path();
let p = param(builder, &root.extend(["0", "weights"]).unwrap());
let x = matmul(builder, x, p);
let x = nn::Sigmoid.call(builder, [x]);
let p = param(builder, &root.extend(["1", "weights"]).unwrap());
let x = matmul(builder, x, p);
let x = nn::Sigmoid.call(builder, [x]);
[x]
}
fn ty(&self) -> ([Type; 1], [Type; 1]) {
use catgrad::typecheck::*;
let batch_size = NatExpr::Var(0);
let t_x = Value::Tensor(TypeExpr::NdArrayType(NdArrayType {
dtype: DtypeExpr::Constant(Dtype::F32),
shape: ShapeExpr::Shape(vec![
batch_size.clone(),
NatExpr::Constant(28),
NatExpr::Constant(28),
]),
}));
let t_y = Value::Tensor(TypeExpr::NdArrayType(NdArrayType {
dtype: DtypeExpr::Constant(Dtype::F32),
shape: ShapeExpr::Shape(vec![batch_size, NatExpr::Constant(10)]),
}));
([t_x], [t_y])
}
}
/// Declares the expected type of each model parameter for the typechecker:
/// layer 0 weights are (28*28, 100), layer 1 weights are (100, 10), both f32.
fn load_param_types() -> typecheck::Parameters {
    use catgrad::category::core::Dtype;
    use catgrad::typecheck::value_types::{DtypeExpr, NatExpr, NdArrayType, ShapeExpr, TypeExpr};

    // Shared constructor: f32 tensor with the given symbolic shape.
    let f32_tensor = |dims: Vec<NatExpr>| {
        Value::Tensor(TypeExpr::NdArrayType(NdArrayType {
            dtype: DtypeExpr::Constant(Dtype::F32),
            shape: ShapeExpr::Shape(dims),
        }))
    };

    let mut map = HashMap::new();
    // Layer 0: input is the flattened 28×28 image, hidden width 100.
    map.insert(
        path(vec!["0", "weights"]).expect("invalid param path"),
        f32_tensor(vec![
            NatExpr::Mul(vec![NatExpr::Constant(28), NatExpr::Constant(28)]),
            NatExpr::Constant(100),
        ]),
    );
    // Layer 1: hidden width 100 down to 10 output classes.
    map.insert(
        path(vec!["1", "weights"]).expect("invalid param path"),
        f32_tensor(vec![NatExpr::Constant(100), NatExpr::Constant(10)]),
    );
    typecheck::Parameters::from(map)
}
/// Builds deterministic synthetic weight tensors for both layers so the example
/// runs without a trained checkpoint. Values cycle through [-1, 1).
fn load_param_data<B: interpreter::Backend>(backend: &B) -> interpreter::Parameters<B> {
    use catgrad::category::core::Shape;
    use std::collections::HashMap;

    // Deterministic pseudo-weights: (i * 0.01 mod 2) - 1, i.e. in [-1, 1).
    let synth = |len: usize| -> Vec<f32> {
        (0..len).map(|i| (i as f32 * 0.01 % 2.0) - 1.0).collect()
    };

    let mut map = HashMap::new();

    // Layer 0 weights: (784, 100).
    let w0 = interpreter::TaggedTensor::from_slice(backend, &synth(784 * 100), Shape(vec![784, 100]))
        .expect("Failed to create layer1 tensor");
    map.insert(path(vec!["0", "weights"]).expect("invalid param path"), w0);

    // Layer 1 weights: (100, 10).
    let w1 = interpreter::TaggedTensor::from_slice(backend, &synth(100 * 10), Shape(vec![100, 10]))
        .expect("Failed to create layer2 tensor");
    map.insert(path(vec!["1", "weights"]).expect("invalid param path"), w1);

    interpreter::Parameters::from(map)
}
/// Renders `term` as SVG into `<crate>/examples/images/<filename>`.
///
/// Best-effort by design: rendering and filesystem problems are logged to
/// stderr and swallowed (the function still returns `Ok`), so a missing
/// renderer never aborts the example.
#[cfg(feature = "svg")]
pub fn save_svg<
    O: PartialEq + Clone + std::fmt::Display + std::fmt::Debug,
    A: PartialEq + Clone + std::fmt::Display + std::fmt::Debug,
>(
    term: &open_hypergraphs::lax::OpenHypergraph<O, A>,
    filename: &str,
) -> Result<(), std::io::Error> {
    use catgrad::svg::to_svg;

    // Render first; bail out quietly if the diagram cannot be generated.
    let bytes = match to_svg(term) {
        Err(e) => {
            eprintln!("Failed to generate SVG: {e}");
            return Ok(());
        }
        Ok(rendered) => rendered,
    };

    // Images live under the crate's examples/images directory.
    let output_dir = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
        .join("examples")
        .join("images");
    if let Err(e) = std::fs::create_dir_all(&output_dir) {
        eprintln!("Failed to create directory {output_dir:?}: {e}");
        return Ok(());
    }

    let output_path = output_dir.join(filename);
    println!("saving svg to {output_path:?}");
    if let Err(e) = std::fs::write(&output_path, bytes) {
        eprintln!("Failed to write SVG file {output_path:?}: {e}");
    }
    Ok(())
}
/// Stub used when the `svg` feature is disabled: prints a notice and succeeds
/// without writing anything, so callers need no feature-gated call sites.
#[cfg(not(feature = "svg"))]
pub fn save_svg<O, A>(
    _term: &open_hypergraphs::lax::OpenHypergraph<O, A>,
    _filename: &str,
) -> Result<(), std::io::Error> {
    println!("SVG feature not enabled, skipping diagram generation");
    Ok(())
}
#[cfg(test)]
mod tests {
    // Smoke test: the entire example pipeline (build, typecheck, render, run)
    // must complete without error.
    #[test]
    fn main() {
        super::main().unwrap();
    }
}