rten_generate/
model.rs

1//! Abstraction over [`rten::Model`] for querying and executing ML models.
2
3use std::error::Error;
4
5use rten::{Dimension, NodeId, RunOptions, Value, ValueOrView};
6
7/// Describes the name and shape of a model input or output.
8///
9/// This is similar to [`rten::NodeInfo`] but the name and shape are required.
10#[derive(Clone)]
11pub struct NodeInfo {
12    name: String,
13    shape: Vec<Dimension>,
14}
15
16impl NodeInfo {
17    pub fn name(&self) -> &str {
18        &self.name
19    }
20
21    pub fn shape(&self) -> &[Dimension] {
22        &self.shape
23    }
24
25    pub fn from_name_shape(name: &str, shape: &[Dimension]) -> NodeInfo {
26        NodeInfo {
27            name: name.to_string(),
28            shape: shape.to_vec(),
29        }
30    }
31}
32
33/// Abstraction over [`rten::Model`] used by [`Generator`](crate::Generator) to
34/// query and execute a machine learning model.
35///
36/// This is implemented by [`rten::Model`] and the trait's methods correspond
37/// to methods of the same name in that type.
38pub trait Model {
39    /// Get the ID of an input or output node.
40    fn find_node(&self, name: &str) -> Option<NodeId>;
41
42    /// Get the name and shape of an input or output node.
43    ///
44    /// Returns `None` if the node does not exist, or name or shape information
45    /// is not available.
46    fn node_info(&self, id: NodeId) -> Option<NodeInfo>;
47
48    /// Return the node IDs of the model's inputs.
49    fn input_ids(&self) -> &[NodeId];
50
51    /// Run the model with the provided inputs and return the results.
52    fn run(
53        &self,
54        inputs: Vec<(NodeId, ValueOrView)>,
55        outputs: &[NodeId],
56        opts: Option<RunOptions>,
57    ) -> Result<Vec<Value>, Box<dyn Error>>;
58
59    /// Run as much of the model as possible given the provided inputs and
60    /// return the leaves of the evaluation where execution stopped.
61    fn partial_run(
62        &self,
63        inputs: Vec<(NodeId, ValueOrView)>,
64        outputs: &[NodeId],
65        opts: Option<RunOptions>,
66    ) -> Result<Vec<(NodeId, Value)>, Box<dyn Error>>;
67}
68
69impl Model for rten::Model {
70    fn find_node(&self, name: &str) -> Option<NodeId> {
71        self.find_node(name)
72    }
73
74    fn node_info(&self, id: NodeId) -> Option<NodeInfo> {
75        self.node_info(id).and_then(|info| {
76            let name = info.name()?;
77            let dims = info.shape()?;
78
79            Some(NodeInfo {
80                name: name.to_string(),
81                shape: dims,
82            })
83        })
84    }
85
86    fn input_ids(&self) -> &[NodeId] {
87        self.input_ids()
88    }
89
90    fn run(
91        &self,
92        inputs: Vec<(NodeId, ValueOrView)>,
93        outputs: &[NodeId],
94        opts: Option<RunOptions>,
95    ) -> Result<Vec<Value>, Box<dyn Error>> {
96        self.run(inputs, outputs, opts).map_err(|e| e.into())
97    }
98
99    fn partial_run(
100        &self,
101        inputs: Vec<(NodeId, ValueOrView)>,
102        outputs: &[NodeId],
103        opts: Option<RunOptions>,
104    ) -> Result<Vec<(NodeId, Value)>, Box<dyn Error>> {
105        self.partial_run(inputs, outputs, opts)
106            .map_err(|e| e.into())
107    }
108}