use super::super::types::OptLevel;
use super::super::SessionBuilder;
use crate::graph::{Attributes, Graph, Node, OpKind};
use crate::tensor::Tensor;
use std::collections::HashMap;
/// `OpPlacement::CpuOnly` must resolve every listed operator to the CPU,
/// even when the reported input size is large (1,000,000 bytes here).
#[test]
fn test_op_placement_cpu_only() {
    use crate::execution_providers::{decide_placement, OpPlacement, ProviderKind};

    let cpu_only = OpPlacement::CpuOnly;
    let candidates = [
        OpKind::MatMul,
        OpKind::Conv,
        OpKind::Add,
        OpKind::Reshape,
        OpKind::Softmax,
        OpKind::Relu,
    ];

    candidates.iter().for_each(|kind| {
        let chosen = decide_placement(kind, 1_000_000, &cpu_only);
        assert_eq!(
            chosen,
            ProviderKind::Cpu,
            "CpuOnly must always return Cpu for {:?}",
            kind
        );
    });
}
/// Under `OpPlacement::Auto` with a 64 KiB threshold, a `MatMul` whose input
/// is only 100 bytes is placed on the CPU.
#[test]
fn test_op_placement_auto_small_input() {
    use crate::execution_providers::{decide_placement, OpPlacement, ProviderKind};

    let threshold_64k = OpPlacement::Auto {
        gpu_threshold_bytes: 65536,
    };
    assert_eq!(
        decide_placement(&OpKind::MatMul, 100, &threshold_64k),
        ProviderKind::Cpu
    );
}
/// Checks the `Auto` threshold boundary with `gpu_threshold_bytes = 1024`:
/// - 512 bytes (below the threshold): `MatMul` stays on the CPU;
/// - 1024 bytes (exactly at the threshold): `MatMul` goes to the GPU when the
///   `gpu` feature is enabled, and to the CPU otherwise;
/// - `Reshape` at 2048 bytes (above the threshold) still stays on the CPU.
#[test]
fn test_op_placement_auto_threshold() {
    use crate::execution_providers::{decide_placement, OpPlacement, ProviderKind};

    let auto_1k = OpPlacement::Auto {
        gpu_threshold_bytes: 1024,
    };

    assert_eq!(
        decide_placement(&OpKind::MatMul, 512, &auto_1k),
        ProviderKind::Cpu
    );

    let at_threshold = decide_placement(&OpKind::MatMul, 1024, &auto_1k);
    // The expected provider depends on the `gpu` feature. The expectation is
    // kept behind `#[cfg]` rather than `cfg!()` in case `ProviderKind::Gpu`
    // is itself feature-gated (NOTE(review): confirm in execution_providers).
    #[cfg(feature = "gpu")]
    assert_eq!(at_threshold, ProviderKind::Gpu);
    #[cfg(not(feature = "gpu"))]
    assert_eq!(at_threshold, ProviderKind::Cpu);

    assert_eq!(
        decide_placement(&OpKind::Reshape, 2048, &auto_1k),
        ProviderKind::Cpu
    );
}
/// A `Manual` placement map decides the provider for the ops it contains.
/// `MatMul` is mapped explicitly; `Reshape`, which has no entry, resolves to
/// the CPU.
#[test]
fn test_op_placement_manual() {
    use crate::execution_providers::{decide_placement, OpPlacement, ProviderKind};

    // The provider MatMul is mapped to: GPU when the `gpu` feature is
    // compiled in, CPU otherwise. A closure is used so the same expectation
    // can be produced twice without requiring `ProviderKind: Copy`.
    #[cfg(feature = "gpu")]
    let matmul_target = || ProviderKind::Gpu;
    #[cfg(not(feature = "gpu"))]
    let matmul_target = || ProviderKind::Cpu;

    let mut overrides = HashMap::new();
    overrides.insert(OpKind::MatMul, matmul_target());
    let manual = OpPlacement::Manual(overrides);

    assert_eq!(
        decide_placement(&OpKind::MatMul, 0, &manual),
        matmul_target()
    );
    assert_eq!(
        decide_placement(&OpKind::Reshape, 0, &manual),
        ProviderKind::Cpu
    );
}
/// With `OpPlacement::default()`, an `Add` over a 999,999-byte input is still
/// placed on the CPU.
#[test]
fn test_decide_placement_default() {
    use crate::execution_providers::{decide_placement, OpPlacement, ProviderKind};

    let default_placement = OpPlacement::default();
    assert_eq!(
        decide_placement(&OpKind::Add, 999999, &default_placement),
        ProviderKind::Cpu,
    );
}
/// `is_gpu_capable` reports `true` for MatMul, Gemm, Conv, Softmax, Relu and
/// ReduceMean.
#[test]
fn test_is_gpu_capable_matmul() {
    use crate::execution_providers::is_gpu_capable;

    let gpu_ops = [
        OpKind::MatMul,
        OpKind::Gemm,
        OpKind::Conv,
        OpKind::Softmax,
        OpKind::Relu,
        OpKind::ReduceMean,
    ];
    for op in gpu_ops.iter() {
        assert!(is_gpu_capable(op), "{:?} should be GPU-capable", op);
    }
}
/// `is_gpu_capable` reports `false` for Reshape, Squeeze, Flatten, Gather and
/// Shape.
#[test]
fn test_is_gpu_capable_reshape() {
    use crate::execution_providers::is_gpu_capable;

    let cpu_only_ops = [
        OpKind::Reshape,
        OpKind::Squeeze,
        OpKind::Flatten,
        OpKind::Gather,
        OpKind::Shape,
    ];
    for op in cpu_only_ops.iter() {
        assert!(!is_gpu_capable(op), "{:?} should not be GPU-capable", op);
    }
}
/// `SessionBuilder::with_op_placement` stores the placement it is given, and a
/// session built with an `Auto` placement runs a single-`Relu` graph and
/// produces the expected output.
#[test]
fn test_builder_op_placement_api() {
    use crate::execution_providers::OpPlacement;

    // Part 1: the builder records exactly the placement passed in.
    let builder = SessionBuilder::new().with_op_placement(OpPlacement::Auto {
        gpu_threshold_bytes: 4096,
    });
    if let OpPlacement::Auto {
        gpu_threshold_bytes,
    } = &builder.op_placement
    {
        assert_eq!(*gpu_threshold_bytes, 4096);
    } else {
        panic!("Expected Auto, got {:?}", builder.op_placement);
    }

    // Part 2: build and run a one-node graph, input -> relu0 -> output,
    // with optimization level `None` and an `Auto` placement.
    let relu_node = Node {
        name: "relu0".to_string(),
        op: OpKind::Relu,
        inputs: vec!["input".to_string()],
        outputs: vec!["output".to_string()],
        attrs: Attributes::default(),
    };
    let graph = Graph {
        nodes: vec![relu_node],
        input_names: vec!["input".to_string()],
        output_names: vec!["output".to_string()],
        ..Default::default()
    };
    let session = SessionBuilder::new()
        .with_optimization_level(OptLevel::None)
        .with_op_placement(OpPlacement::Auto {
            gpu_threshold_bytes: 1024,
        })
        .build_from_graph(graph, HashMap::new())
        .expect("build with op placement");

    // Negative inputs should come out as 0.0; positive ones pass through.
    let outputs = session
        .run_one("input", Tensor::new(vec![-1.0, 2.0, -3.0], vec![1, 3]))
        .expect("run");
    let relu_out = outputs.get("output").expect("output");
    assert_eq!(relu_out.data, vec![0.0, 2.0, 0.0]);
}