1use std::collections::HashSet;
7
8use crate::proto::onnx::{FunctionProto, ModelProto};
9
/// Structural problems detected by the verification passes in this module.
///
/// Each variant carries enough context (function name, node position, wire id)
/// to locate the offending element in the model.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum VerifyError {
    /// A node has an empty `op_type` string.
    EmptyOpType {
        /// Name of the function containing the node.
        function_name: String,
        /// Zero-based position of the node in the function's node list.
        node_index: usize,
    },

    /// A `Send` node uses a wire id for which the same function has no `Recv`.
    UnpairedWireSend {
        /// Wire id read from the node's metadata.
        wire_id: u64,
        /// Name of the function containing the unmatched `Send`.
        function_name: String,
    },

    /// A `Recv` node uses a wire id for which the same function has no `Send`.
    UnpairedWireRecv {
        /// Wire id read from the node's metadata.
        wire_id: u64,
        /// Name of the function containing the unmatched `Recv`.
        function_name: String,
    },

    /// A call node names a function that is not in the model's function table.
    UnresolvedFunctionCall {
        /// The function name the call node refers to.
        target_name: String,
        /// Name of the function containing the call node.
        function_name: String,
        /// Zero-based position of the call node in the caller's node list.
        node_index: usize,
    },

    /// The model defines no functions at all.
    EmptyFunctionTable,
}
51
52impl std::fmt::Display for VerifyError {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 match self {
55 Self::EmptyOpType {
56 function_name,
57 node_index,
58 } => write!(
59 f,
60 "empty op_type at function `{}` node #{}",
61 function_name, node_index
62 ),
63 Self::UnpairedWireSend {
64 wire_id,
65 function_name,
66 } => write!(
67 f,
68 "wire.Send with wire_id={} in function `{}` has no matching Recv",
69 wire_id, function_name
70 ),
71 Self::UnpairedWireRecv {
72 wire_id,
73 function_name,
74 } => write!(
75 f,
76 "wire.Recv with wire_id={} in function `{}` has no matching Send",
77 wire_id, function_name
78 ),
79 Self::UnresolvedFunctionCall {
80 target_name,
81 function_name,
82 node_index,
83 } => write!(
84 f,
85 "function `{}` node #{} calls undefined function `{}`",
86 function_name, node_index, target_name
87 ),
88 Self::EmptyFunctionTable => f.write_str("ModelProto.functions is empty"),
89 }
90 }
91}
92
93impl std::error::Error for VerifyError {}
94
95pub fn types(model: &ModelProto) -> Result<(), VerifyError> {
98 if model.functions.is_empty() {
99 return Err(VerifyError::EmptyFunctionTable);
100 }
101 for function in &model.functions {
102 for (i, node) in function.node.iter().enumerate() {
103 if node.op_type.is_empty() {
104 return Err(VerifyError::EmptyOpType {
105 function_name: function.name.clone(),
106 node_index: i,
107 });
108 }
109 }
110 }
111 Ok(())
112}
113
114pub fn wire_pairs(model: &ModelProto) -> Result<(), VerifyError> {
116 for function in &model.functions {
117 let mut sends: HashSet<u64> = HashSet::new();
118 let mut recvs: HashSet<u64> = HashSet::new();
119 for node in &function.node {
120 let Some(wire_id) = read_wire_id(node) else {
121 continue;
122 };
123 if node.op_type == "Send" {
124 sends.insert(wire_id);
125 } else if node.op_type == "Recv" {
126 recvs.insert(wire_id);
127 }
128 }
129 if let Some(wire_id) = sends.difference(&recvs).next() {
130 return Err(VerifyError::UnpairedWireSend {
131 wire_id: *wire_id,
132 function_name: function.name.clone(),
133 });
134 }
135 if let Some(wire_id) = recvs.difference(&sends).next() {
136 return Err(VerifyError::UnpairedWireRecv {
137 wire_id: *wire_id,
138 function_name: function.name.clone(),
139 });
140 }
141 }
142 Ok(())
143}
144
145pub fn function_calls(model: &ModelProto) -> Result<(), VerifyError> {
147 let defined: HashSet<&str> = model.functions.iter().map(|f| f.name.as_str()).collect();
148 for function in &model.functions {
149 for (i, node) in function.node.iter().enumerate() {
150 if !node.op_type.starts_with("Call") {
151 continue;
152 }
153 let target = node.name.as_str();
155 if target.is_empty() {
156 continue;
157 }
158 if !defined.contains(target) {
159 return Err(VerifyError::UnresolvedFunctionCall {
160 target_name: target.to_string(),
161 function_name: function.name.clone(),
162 node_index: i,
163 });
164 }
165 }
166 }
167 Ok(())
168}
169
170fn read_wire_id(node: &crate::proto::onnx::NodeProto) -> Option<u64> {
173 node.metadata_props
174 .iter()
175 .find(|p| p.key == crate::keys::WIRE_ID_KEY)
176 .and_then(|p| p.value.parse::<u64>().ok())
177}
178
179pub fn wire_pairs_in_function(function: &FunctionProto) -> Result<(), VerifyError> {
181 let mut sends: HashSet<u64> = HashSet::new();
182 let mut recvs: HashSet<u64> = HashSet::new();
183 for node in &function.node {
184 let Some(wire_id) = read_wire_id(node) else {
185 continue;
186 };
187 if node.op_type == "Send" {
188 sends.insert(wire_id);
189 } else if node.op_type == "Recv" {
190 recvs.insert(wire_id);
191 }
192 }
193 if let Some(wire_id) = sends.difference(&recvs).next() {
194 return Err(VerifyError::UnpairedWireSend {
195 wire_id: *wire_id,
196 function_name: function.name.clone(),
197 });
198 }
199 if let Some(wire_id) = recvs.difference(&sends).next() {
200 return Err(VerifyError::UnpairedWireRecv {
201 wire_id: *wire_id,
202 function_name: function.name.clone(),
203 });
204 }
205 Ok(())
206}
207