use onnx_ir::{Argument, ir::ArgType};
use proc_macro2::{Ident, Span, TokenStream};
use quote::quote;
use std::collections::HashMap;
use super::node_traits::arg_to_ident;
/// Tracks tensor variables by name during code generation so that each use can be
/// emitted either as a move of the variable or as a `.clone()` when further uses of the
/// same tensor are still pending.
#[derive(Clone, Debug, Default)]
pub struct Scope {
    /// Bookkeeping per tensor, keyed by the argument name (one entry per name).
    variables: HashMap<String, TensorVariable>,
}
/// Reference-count bookkeeping for a single tensor variable tracked by [`Scope`].
#[derive(Clone, Debug, new)]
struct TensorVariable {
    /// Number of registered uses that have not yet been consumed by
    /// `Scope::tensor_use_owned`.
    references: usize,
    /// Position of the node at which this name was first registered; uses at earlier
    /// positions are ignored.
    node_position: usize,
}
impl Scope {
    /// Declares the tensor variable carried by `arg`, produced by the node at
    /// `node_position`.
    ///
    /// A name seen for the first time is recorded with zero pending references. If the
    /// name is already tracked at the same `node_position`, one more reference is counted.
    /// An existing entry registered at a different position is left unchanged.
    pub fn tensor_register_variable(&mut self, arg: &Argument, node_position: usize) {
        if let Some(variable) = self.variables.get_mut(&arg.name) {
            if variable.node_position == node_position {
                variable.references += 1;
            }
        } else {
            self.variables
                .insert(arg.name.clone(), TensorVariable::new(0, node_position));
        }
    }

    /// Records that the node at `node_position` will consume the tensor variable `arg`.
    ///
    /// The use is counted only when `node_position` is at or after the position where the
    /// variable was registered. If the name is not tracked yet, it is inserted with a single
    /// pending reference at `node_position`.
    ///
    /// [`Scope::tensor_use_owned`] decides between moving and cloning from the remaining
    /// count, so every future use should be registered before any use is emitted.
    pub fn tensor_register_future_use(&mut self, arg: &Argument, node_position: usize) {
        if let Some(variable) = self.variables.get_mut(&arg.name) {
            if node_position >= variable.node_position {
                variable.references += 1;
            }
        } else {
            self.variables
                .insert(arg.name.clone(), TensorVariable::new(1, node_position));
        }
    }

    /// Emits the tokens that consume the tensor variable `arg` at `node_position`.
    ///
    /// The variable consumes one pending reference only when all of the following hold:
    /// - it is tracked;
    /// - it was registered at or before `node_position`;
    /// - it still has pending references.
    ///
    /// If references remain after that, the result is `name.clone()` because the tensor is
    /// needed again later. Otherwise the bare `name` is emitted so the value is moved.
    /// Untracked names are emitted bare.
    ///
    /// # Panics
    ///
    /// Panics if `arg.name` is not a valid Rust identifier (see [`Ident::new`]).
    pub fn tensor_use_owned(&mut self, arg: &Argument, node_position: usize) -> TokenStream {
        let name = Ident::new(&arg.name, Span::call_site());

        // Uses still pending after this one. It is zero when:
        // - the variable is unknown;
        // - it was registered after `node_position`;
        // - its count is already exhausted. The guard avoids usize underflow in that case.
        let remaining = match self.variables.get_mut(&arg.name) {
            Some(variable)
                if node_position >= variable.node_position && variable.references > 0 =>
            {
                variable.references -= 1;
                variable.references
            }
            _ => 0,
        };

        if remaining > 0 {
            quote! { #name.clone() }
        } else {
            quote! { #name }
        }
    }

    /// Returns a view of this scope bound to `node_position`, used to emit the arguments
    /// consumed by the node at that position.
    pub fn at_position(&mut self, node_position: usize) -> ScopeAtPosition<'_> {
        ScopeAtPosition {
            scope: self,
            node_position,
        }
    }
}
/// A mutable borrow of a [`Scope`] paired with the node position at which arguments are
/// being consumed; created by [`Scope::at_position`].
pub struct ScopeAtPosition<'a> {
    /// The scope whose reference counts are updated when arguments are emitted.
    scope: &'a mut Scope,
    /// Position of the node whose arguments this view emits.
    node_position: usize,
}
impl ScopeAtPosition<'_> {
    /// Emits the tokens that consume `arg` at this view's node position.
    ///
    /// Native scalars and shapes are emitted as their bare identifier. Tensor-valued
    /// arguments (`Tensor`, `ScalarTensor`) are routed through [`Scope::tensor_use_owned`],
    /// which may emit a `.clone()` when the tensor is used again later.
    pub fn arg(&mut self, arg: &Argument) -> TokenStream {
        let position = self.node_position;
        match arg.ty {
            ArgType::ScalarNative(_) | ArgType::Shape(_) => {
                let ident = arg_to_ident(arg);
                quote!(#ident)
            }
            ArgType::Tensor(_) | ArgType::ScalarTensor(_) => {
                self.scope.tensor_use_owned(arg, position)
            }
        }
    }

    /// Mutable access to the underlying [`Scope`].
    pub fn scope(&mut self) -> &mut Scope {
        &mut *self.scope
    }

    /// The node position this view was created for.
    pub fn node_position(&self) -> usize {
        self.node_position
    }
}