1use std::error::Error;
4
5use rten::{Dimension, NodeId, RunOptions, Value, ValueOrView};
6
7#[derive(Clone)]
11pub struct NodeInfo {
12 name: String,
13 shape: Vec<Dimension>,
14}
15
16impl NodeInfo {
17 pub fn name(&self) -> &str {
18 &self.name
19 }
20
21 pub fn shape(&self) -> &[Dimension] {
22 &self.shape
23 }
24
25 pub fn from_name_shape(name: &str, shape: &[Dimension]) -> NodeInfo {
26 NodeInfo {
27 name: name.to_string(),
28 shape: shape.to_vec(),
29 }
30 }
31}
32
33pub trait Model {
39 fn find_node(&self, name: &str) -> Option<NodeId>;
41
42 fn node_info(&self, id: NodeId) -> Option<NodeInfo>;
47
48 fn input_ids(&self) -> &[NodeId];
50
51 fn run(
53 &self,
54 inputs: Vec<(NodeId, ValueOrView)>,
55 outputs: &[NodeId],
56 opts: Option<RunOptions>,
57 ) -> Result<Vec<Value>, Box<dyn Error>>;
58
59 fn partial_run(
62 &self,
63 inputs: Vec<(NodeId, ValueOrView)>,
64 outputs: &[NodeId],
65 opts: Option<RunOptions>,
66 ) -> Result<Vec<(NodeId, Value)>, Box<dyn Error>>;
67}
68
69impl Model for rten::Model {
70 fn find_node(&self, name: &str) -> Option<NodeId> {
71 self.find_node(name)
72 }
73
74 fn node_info(&self, id: NodeId) -> Option<NodeInfo> {
75 self.node_info(id).and_then(|info| {
76 let name = info.name()?;
77 let dims = info.shape()?;
78
79 Some(NodeInfo {
80 name: name.to_string(),
81 shape: dims,
82 })
83 })
84 }
85
86 fn input_ids(&self) -> &[NodeId] {
87 self.input_ids()
88 }
89
90 fn run(
91 &self,
92 inputs: Vec<(NodeId, ValueOrView)>,
93 outputs: &[NodeId],
94 opts: Option<RunOptions>,
95 ) -> Result<Vec<Value>, Box<dyn Error>> {
96 self.run(inputs, outputs, opts).map_err(|e| e.into())
97 }
98
99 fn partial_run(
100 &self,
101 inputs: Vec<(NodeId, ValueOrView)>,
102 outputs: &[NodeId],
103 opts: Option<RunOptions>,
104 ) -> Result<Vec<(NodeId, Value)>, Box<dyn Error>> {
105 self.partial_run(inputs, outputs, opts)
106 .map_err(|e| e.into())
107 }
108}