use crate::{
TypeExpr,
inference::{FlowSourceLocation, InferenceStep, Scopes, infer},
nodety::{Edge, IntoNode, Node, Nodety, NodetyError, inference::InferenceConfig},
scope::ScopePointer,
r#type::Type,
type_expr::Unscoped,
validation::ValidationError,
};
use petgraph::graph::{EdgeIndex, NodeIndex};
use petgraph::prelude::StableDiGraph;
use std::cell::RefCell;
#[cfg(feature = "json-schema")]
use schemars::JsonSchema;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "tsify")]
use tsify::Tsify;
/// A [`Nodety`] wrapper that memoizes the result of whole-graph inference.
///
/// Every mutating method clears the cache on success. [`infer`](Self::infer) and
/// [`validate`](Self::validate) reuse the cached scopes when present and populate the cache
/// otherwise. The cache is held in a [`RefCell`], so this type is not `Sync`.
pub struct NodetyCached<T: Type> {
    /// The wrapped graph; edits go through this wrapper so the cache stays coherent.
    nodety: Nodety<T>,
    /// Inference configuration: the caller-supplied steps, everything else defaulted.
    config: InferenceConfig<T>,
    /// Scopes from the last full inference run, or `None` if the graph changed since.
    cache: RefCell<Option<Scopes<T>>>,
}

impl<T: Type> NodetyCached<T> {
    /// Creates an empty graph whose inference runs the given `steps`.
    pub fn new(steps: Vec<InferenceStep>) -> Self {
        Self::from_nodety(Nodety::new(), steps)
    }

    /// Wraps an existing graph; inference runs the given `steps`. The cache starts empty.
    pub fn from_nodety(nodety: Nodety<T>, steps: Vec<InferenceStep>) -> Self {
        Self {
            nodety,
            config: InferenceConfig { steps, ..Default::default() },
            cache: RefCell::new(None),
        }
    }

    /// Creates an empty graph via [`Nodety::with_capacity`]`(nodes, edges)`; inference runs
    /// the given `steps`.
    pub fn with_capacity(nodes: usize, edges: usize, steps: Vec<InferenceStep>) -> Self {
        Self::from_nodety(Nodety::with_capacity(nodes, edges), steps)
    }

    /// Drops the cached scopes so the next query re-runs inference.
    fn invalidate_cache(&self) {
        *self.cache.borrow_mut() = None;
    }

    /// Adds `node` to the graph and returns its index.
    ///
    /// # Errors
    /// Propagates any [`NodetyError`] from [`Nodety::add_node`]; the cache is only cleared
    /// when the insertion succeeds.
    pub fn add_node(&mut self, node: impl IntoNode<T>) -> Result<NodeIndex, NodetyError> {
        let result = self.nodety.add_node(node)?;
        self.invalidate_cache();
        Ok(result)
    }

    /// Replaces the node at `node_id` with `node`.
    ///
    /// # Errors
    /// Propagates any [`NodetyError`] from [`Nodety::update_node`]; the cache is only
    /// cleared when the update succeeds.
    pub fn update_node(&mut self, node_id: NodeIndex, node: impl IntoNode<T>) -> Result<(), NodetyError> {
        self.nodety.update_node(node_id, node)?;
        self.invalidate_cache();
        Ok(())
    }

    /// Removes the node at `node_id`.
    ///
    /// # Errors
    /// Propagates any [`NodetyError`] from [`Nodety::remove_node`]; the cache is only
    /// cleared when the removal succeeds.
    pub fn remove_node(&mut self, node_id: NodeIndex) -> Result<(), NodetyError> {
        self.nodety.remove_node(node_id)?;
        self.invalidate_cache();
        Ok(())
    }

    /// Connects `source` to `target` with `edge` and returns the new edge's index.
    ///
    /// # Errors
    /// Propagates any [`NodetyError`] from [`Nodety::add_edge`]; the cache is only cleared
    /// when the edge is added.
    pub fn add_edge(&mut self, source: NodeIndex, target: NodeIndex, edge: Edge) -> Result<EdgeIndex, NodetyError> {
        let idx = self.nodety.add_edge(source, target, edge)?;
        self.invalidate_cache();
        Ok(idx)
    }

    /// Removes the edge at `edge_idx` and returns its payload.
    ///
    /// The cache is cleared only if an edge was actually removed.
    pub fn remove_edge(&mut self, edge_idx: EdgeIndex) -> Option<Edge> {
        self.nodety.remove_edge(edge_idx).inspect(|_| self.invalidate_cache())
    }

    /// Returns the node at `node_idx`, if it exists.
    pub fn get_node(&self, node_idx: NodeIndex) -> Option<&Node<T, Unscoped>> {
        self.nodety.get_node(node_idx)
    }

    /// Read-only access to the underlying graph.
    pub fn program(&self) -> &StableDiGraph<Node<T, Unscoped>, Edge> {
        self.nodety.program()
    }

    /// Renders the graph via [`Nodety::to_dot`] (presumably Graphviz DOT text).
    pub fn to_dot(&self) -> String {
        self.nodety.to_dot()
    }

    /// Returns the inferred scopes of every node, running inference only on a cache miss.
    pub fn infer(&self) -> Scopes<T> {
        self.with_scopes(|scopes| scopes.clone())
    }

    /// Validates the graph against its inferred scopes, which are taken from the cache or
    /// computed and then cached.
    pub fn validate(&self) -> Vec<ValidationError<T>> {
        self.with_scopes(|scopes| self.nodety.validate(scopes))
    }

    /// Returns the scope of `node_idx`, optionally inferred as if some of its ports were
    /// disconnected.
    ///
    /// `exclude_input` drops the flows entering `node_idx` through the selected input
    /// ports. `exclude_output` does the same for the flows leaving it through the selected
    /// output ports. [`ExcludePorts::Index`] selects one port. [`ExcludePorts::Vargs`]
    /// selects every port at or past the length of the node's explicit port list (all ports
    /// if the signature is not a port list).
    ///
    /// If nothing is excluded, or no flow matches the exclusion, the (cached) whole-graph
    /// inference is used. Otherwise inference runs afresh on the remaining flows, and the
    /// cache is neither read nor updated.
    ///
    /// # Errors
    /// [`NodetyError::NodeNotFound`] if `node_idx` is not in the graph or has no inferred
    /// scope.
    pub fn infer_node_scope(
        &self,
        node_idx: NodeIndex,
        exclude_input: Option<ExcludePorts>,
        exclude_output: Option<ExcludePorts>,
    ) -> Result<ScopePointer<T>, NodetyError> {
        if exclude_input.is_none() && exclude_output.is_none() {
            return self.with_scopes(|scopes| Self::node_scope(scopes, node_idx));
        }

        // Resolve the node before the comparatively expensive flow collection, so an
        // unknown index fails fast.
        let node = self.nodety.get_node(node_idx).ok_or(NodetyError::NodeNotFound)?;
        let fixed_inputs = match &node.signature.inputs {
            TypeExpr::PortTypes(inputs) => inputs.ports.len(),
            _ => 0,
        };
        let fixed_outputs = match &node.signature.outputs {
            TypeExpr::PortTypes(outputs) => outputs.ports.len(),
            _ => 0,
        };

        let scopes = self.nodety.build_scopes();
        let mut flows = self.nodety.collect_flows(&scopes);
        let flow_count = flows.len();
        flows.retain(|flow| {
            let target = &flow.target_location;
            if target.node_idx == node_idx
                && Self::is_port_excluded(exclude_input.as_ref(), target.input_idx, fixed_inputs)
            {
                return false;
            }
            match flow.source_location {
                FlowSourceLocation::Output(source_node_idx, output_idx) if source_node_idx == node_idx => {
                    !Self::is_port_excluded(exclude_output.as_ref(), output_idx, fixed_outputs)
                }
                _ => true,
            }
        });

        if flows.len() == flow_count {
            // Nothing was excluded, so the whole-graph result applies unchanged.
            return self.with_scopes(|scopes| Self::node_scope(scopes, node_idx));
        }

        let raw_flows = flows.into_iter().map(|flow| flow.flow).collect();
        // NOTE(review): the return value of `infer` is ignored and `scopes` is read
        // afterwards. This assumes `infer` updates the scopes the flows were collected from
        // in place (e.g. through shared `ScopePointer`s) — confirm.
        infer(raw_flows, &self.config);
        Self::node_scope(&scopes, node_idx)
    }

    /// Read-only access to the wrapped [`Nodety`].
    ///
    /// Mutations should go through this wrapper so the cache gets invalidated.
    pub fn inner(&self) -> &Nodety<T> {
        &self.nodety
    }

    /// Runs `f` on the inferred scopes.
    ///
    /// If the cache is empty, full inference runs first and its result is stored in the
    /// cache after `f` returns. On a cache hit, `f` runs while the cache is immutably
    /// borrowed.
    fn with_scopes<R>(&self, f: impl FnOnce(&Scopes<T>) -> R) -> R {
        if let Some(cached) = self.cache.borrow().as_ref() {
            return f(cached);
        }
        let scopes = self.nodety.infer(&self.config);
        let result = f(&scopes);
        *self.cache.borrow_mut() = Some(scopes);
        result
    }

    /// Looks up the scope of `node_idx` in `scopes`, mapping absence to `NodeNotFound`.
    fn node_scope(scopes: &Scopes<T>, node_idx: NodeIndex) -> Result<ScopePointer<T>, NodetyError> {
        scopes
            .get(&node_idx)
            .map(|scope| ScopePointer::clone(scope))
            .ok_or(NodetyError::NodeNotFound)
    }

    /// Whether `exclude` selects the port at `port_idx`.
    ///
    /// `fixed_ports` is the length of the signature's explicit port list. Ports at or past
    /// that index count as variadic.
    fn is_port_excluded(exclude: Option<&ExcludePorts>, port_idx: usize, fixed_ports: usize) -> bool {
        match exclude {
            None => false,
            Some(ExcludePorts::Index(idx)) => *idx == port_idx,
            // The first variadic port has index `fixed_ports` itself, hence `>=`.
            Some(ExcludePorts::Vargs) => port_idx >= fixed_ports,
        }
    }
}
/// Port selector for [`NodetyCached::infer_node_scope`]: which of a node's input or
/// output ports to ignore when inferring its scope.
///
/// With the `serde` feature this is (de)serialized adjacently tagged, e.g.
/// `{"type": "Index", "index": 2}` or `{"type": "Vargs"}`.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(tag = "type", content = "index")
)]
#[cfg_attr(feature = "json-schema", derive(JsonSchema))]
#[cfg_attr(feature = "tsify", derive(Tsify), tsify(into_wasm_abi, from_wasm_abi))]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ExcludePorts {
    /// Exactly the port at this index.
    Index(usize),
    /// The ports beyond the node's explicitly declared port list (its variadic ports).
    Vargs,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        demo_type::DemoType,
        notation::parse::{expr, sig_u},
    };

    /// The child node's `I`, looked up via `infer_node_scope` with input port 0 excluded,
    /// should normalize to `Integer`, matching the parent signature's default
    /// `Array<Integer>`.
    #[test]
    pub fn test_normalize_type_from_parent() {
        let mut graph = NodetyCached::<DemoType>::new(InferenceStep::default_steps());

        let parent_sig = sig_u("<I>(Array<I> = Array<Integer>) -> ()");
        let parent = graph.add_node(Node::new(parent_sig)).unwrap();

        let child_sig = sig_u("() -> (I)");
        let child = graph.add_node(Node::new_child(child_sig, parent)).unwrap();

        let child_scope = graph
            .infer_node_scope(child, Some(ExcludePorts::Index(0)), None)
            .unwrap();
        let resolved = expr("I").normalize(&child_scope);
        assert_eq!(resolved, expr("Integer"));
    }
}