use oxionnx_core::Tensor;
use std::collections::HashMap;
/// A collection of named input tensors and named output buffers.
///
/// Both maps are keyed by name; binding a name that is already present
/// replaces the previous entry.
#[derive(Default, Debug)]
pub struct IoBinding {
    /// Input tensors, keyed by input name.
    inputs: HashMap<String, Tensor>,
    /// Output buffers, keyed by output name.
    outputs: HashMap<String, Tensor>,
}
impl IoBinding {
    /// Creates an empty binding with no inputs and no output buffers.
    ///
    /// Equivalent to [`IoBinding::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Binds `tensor` as the input named `name`.
    ///
    /// If an input with the same name is already bound, it is replaced and the
    /// previously bound tensor is dropped. No validation of the name or tensor
    /// is performed here.
    pub fn bind_input(&mut self, name: impl Into<String>, tensor: Tensor) {
        self.inputs.insert(name.into(), tensor);
    }

    /// Binds `buffer` as the output named `name`.
    ///
    /// If an output with the same name is already bound, it is replaced and the
    /// previously bound buffer is dropped. No validation of the name or buffer
    /// is performed here.
    pub fn bind_output(&mut self, name: impl Into<String>, buffer: Tensor) {
        self.outputs.insert(name.into(), buffer);
    }

    /// Returns the input tensor bound under `name`, or `None` if no input with
    /// that name is bound.
    #[must_use]
    pub fn get_input(&self, name: &str) -> Option<&Tensor> {
        self.inputs.get(name)
    }

    /// Returns the output buffer bound under `name`, or `None` if no output
    /// with that name is bound.
    #[must_use]
    pub fn get_output(&self, name: &str) -> Option<&Tensor> {
        self.outputs.get(name)
    }

    /// Returns a mutable reference to the output buffer bound under `name`,
    /// or `None` if no output with that name is bound.
    pub fn get_output_mut(&mut self, name: &str) -> Option<&mut Tensor> {
        self.outputs.get_mut(name)
    }

    /// Removes every bound input; bound output buffers are kept.
    pub fn clear_inputs(&mut self) {
        self.inputs.clear();
    }

    /// Removes every bound output buffer; bound inputs are kept.
    pub fn clear_outputs(&mut self) {
        self.outputs.clear();
    }

    /// Removes every bound input and every bound output buffer.
    pub fn clear(&mut self) {
        self.clear_inputs();
        self.clear_outputs();
    }

    /// Iterates over the names of all bound inputs.
    ///
    /// The iteration order is unspecified (it follows the backing `HashMap`).
    pub fn input_names(&self) -> impl Iterator<Item = &str> {
        self.inputs.keys().map(String::as_str)
    }

    /// Iterates over the names of all bound outputs.
    ///
    /// The iteration order is unspecified (it follows the backing `HashMap`).
    pub fn output_names(&self) -> impl Iterator<Item = &str> {
        self.outputs.keys().map(String::as_str)
    }

    /// Borrows the full map of bound inputs, keyed by input name.
    #[must_use]
    pub(crate) fn inputs(&self) -> &HashMap<String, Tensor> {
        &self.inputs
    }

    /// Removes the output buffer bound under `name` and hands ownership of it
    /// to the caller, or returns `None` if no such output is bound.
    ///
    /// After this call the binding has no output under `name` until one is
    /// inserted again (for example with `put_output_buffer`).
    pub(crate) fn take_output_buffer(&mut self, name: &str) -> Option<Tensor> {
        self.outputs.remove(name)
    }

    /// Inserts `tensor` as the output buffer under `name`, replacing (and
    /// dropping) any buffer already bound there.
    ///
    /// Together with `take_output_buffer` this lets a caller move a buffer out
    /// of the binding and put it back afterwards.
    pub(crate) fn put_output_buffer(&mut self, name: String, tensor: Tensor) {
        self.outputs.insert(name, tensor);
    }
}