use oxionnx_core::Tensor;
use std::collections::HashMap;
/// A set of named input tensors and pre-bound output tensors.
///
/// Inputs and outputs live in two separate maps keyed by name. Binding a name
/// that is already bound replaces (and drops) the previously bound tensor.
#[derive(Debug, Default)]
pub struct IoBinding {
    /// Bound input tensors, keyed by input name.
    inputs: HashMap<String, Tensor>,
    /// Bound output tensors, keyed by output name.
    outputs: HashMap<String, Tensor>,
}

impl IoBinding {
    /// Creates an empty binding with no inputs and no outputs.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Binds `tensor` as the input named `name`.
    ///
    /// Any tensor previously bound under the same name is replaced and dropped.
    pub fn bind_input(&mut self, name: impl Into<String>, tensor: Tensor) {
        self.inputs.insert(name.into(), tensor);
    }

    /// Binds `buffer` as the output named `name`.
    ///
    /// Any tensor previously bound under the same name is replaced and dropped.
    /// Crate-internal code can reach the bound buffers mutably through
    /// `outputs_mut`.
    // NOTE(review): presumably the runtime writes results into this buffer;
    // confirm against the session code that consumes `outputs_mut`.
    pub fn bind_output(&mut self, name: impl Into<String>, buffer: Tensor) {
        self.outputs.insert(name.into(), buffer);
    }

    /// Returns the output tensor bound under `name`, or `None` if nothing is
    /// bound under that name.
    #[must_use]
    pub fn get_output(&self, name: &str) -> Option<&Tensor> {
        self.outputs.get(name)
    }

    /// Returns a mutable reference to the output tensor bound under `name`,
    /// or `None` if nothing is bound under that name.
    pub fn get_output_mut(&mut self, name: &str) -> Option<&mut Tensor> {
        self.outputs.get_mut(name)
    }

    /// Removes every bound input, leaving the outputs untouched.
    ///
    /// The underlying map keeps its allocated capacity for reuse.
    pub fn clear_inputs(&mut self) {
        self.inputs.clear();
    }

    /// Removes every bound output, leaving the inputs untouched.
    ///
    /// The underlying map keeps its allocated capacity for reuse.
    pub fn clear_outputs(&mut self) {
        self.outputs.clear();
    }

    /// Removes every bound input and output.
    pub fn clear(&mut self) {
        self.clear_inputs();
        self.clear_outputs();
    }

    /// Iterates over the names of the bound inputs.
    ///
    /// The order is unspecified, because the names are stored in a `HashMap`.
    pub fn input_names(&self) -> impl Iterator<Item = &str> {
        self.inputs.keys().map(String::as_str)
    }

    /// Iterates over the names of the bound outputs.
    ///
    /// The order is unspecified, because the names are stored in a `HashMap`.
    pub fn output_names(&self) -> impl Iterator<Item = &str> {
        self.outputs.keys().map(String::as_str)
    }

    /// Read-only access to the whole input map, for crate-internal consumers.
    pub(crate) fn inputs(&self) -> &HashMap<String, Tensor> {
        &self.inputs
    }

    /// Mutable access to the whole output map, for crate-internal consumers.
    pub(crate) fn outputs_mut(&mut self) -> &mut HashMap<String, Tensor> {
        &mut self.outputs
    }
}