hydroflow_plus 0.10.0

Functional programming API for hydroflow
Documentation
use std::cell::RefCell;
use std::collections::HashSet;

use crate::ir::*;

fn persist_pullup_node(
    node: &mut HfPlusNode,
    persist_pulled_tees: &mut HashSet<*const RefCell<HfPlusNode>>,
) {
    *node = match std::mem::replace(node, HfPlusNode::Placeholder) {
        HfPlusNode::Unpersist(box HfPlusNode::Persist(box behind_persist)) => behind_persist,

        HfPlusNode::Delta(box HfPlusNode::Persist(box behind_persist)) => behind_persist,

        HfPlusNode::Tee { inner } => {
            if persist_pulled_tees.contains(&(inner.0.as_ref() as *const RefCell<HfPlusNode>)) {
                HfPlusNode::Persist(Box::new(HfPlusNode::Tee {
                    inner: TeeNode(inner.0.clone()),
                }))
            } else if matches!(*inner.0.borrow(), HfPlusNode::Persist(_)) {
                persist_pulled_tees.insert(inner.0.as_ref() as *const RefCell<HfPlusNode>);
                if let HfPlusNode::Persist(box behind_persist) =
                    inner.0.replace(HfPlusNode::Placeholder)
                {
                    *inner.0.borrow_mut() = behind_persist;
                } else {
                    unreachable!()
                }

                HfPlusNode::Persist(Box::new(HfPlusNode::Tee {
                    inner: TeeNode(inner.0.clone()),
                }))
            } else {
                HfPlusNode::Tee { inner }
            }
        }

        HfPlusNode::Map {
            f,
            input: box HfPlusNode::Persist(behind_persist),
        } => HfPlusNode::Persist(Box::new(HfPlusNode::Map {
            f,
            input: behind_persist,
        })),

        HfPlusNode::FilterMap {
            f,
            input: box HfPlusNode::Persist(behind_persist),
        } => HfPlusNode::Persist(Box::new(HfPlusNode::FilterMap {
            f,
            input: behind_persist,
        })),

        HfPlusNode::FlatMap {
            f,
            input: box HfPlusNode::Persist(behind_persist),
        } => HfPlusNode::Persist(Box::new(HfPlusNode::FlatMap {
            f,
            input: behind_persist,
        })),

        HfPlusNode::Filter {
            f,
            input: box HfPlusNode::Persist(behind_persist),
        } => HfPlusNode::Persist(Box::new(HfPlusNode::Filter {
            f,
            input: behind_persist,
        })),

        HfPlusNode::Network {
            from_location,
            from_key,
            to_location,
            to_key,
            serialize_pipeline,
            instantiate_fn,
            deserialize_pipeline,
            input: box HfPlusNode::Persist(behind_persist),
            ..
        } => HfPlusNode::Persist(Box::new(HfPlusNode::Network {
            from_location,
            from_key,
            to_location,
            to_key,
            serialize_pipeline,
            instantiate_fn,
            deserialize_pipeline,
            input: behind_persist,
        })),

        HfPlusNode::Union(box HfPlusNode::Persist(left), box HfPlusNode::Persist(right)) => {
            HfPlusNode::Persist(Box::new(HfPlusNode::Union(left, right)))
        }

        HfPlusNode::CrossProduct(box HfPlusNode::Persist(left), box HfPlusNode::Persist(right)) => {
            HfPlusNode::Persist(Box::new(HfPlusNode::Delta(Box::new(
                HfPlusNode::CrossProduct(
                    Box::new(HfPlusNode::Persist(left)),
                    Box::new(HfPlusNode::Persist(right)),
                ),
            ))))
        }

        HfPlusNode::Join(box HfPlusNode::Persist(left), box HfPlusNode::Persist(right)) => {
            HfPlusNode::Persist(Box::new(HfPlusNode::Delta(Box::new(HfPlusNode::Join(
                Box::new(HfPlusNode::Persist(left)),
                Box::new(HfPlusNode::Persist(right)),
            )))))
        }

        HfPlusNode::Unique(box HfPlusNode::Persist(inner)) => {
            HfPlusNode::Persist(Box::new(HfPlusNode::Delta(Box::new(HfPlusNode::Unique(
                Box::new(HfPlusNode::Persist(inner)),
            )))))
        }

        node => node,
    };
}

pub fn persist_pullup(ir: Vec<HfPlusLeaf>) -> Vec<HfPlusLeaf> {
    let mut seen_tees = Default::default();
    let mut persist_pulled_tees = Default::default();
    ir.into_iter()
        .map(|l| {
            l.transform_children(
                |n, s| n.transform_bottom_up(persist_pullup_node, s, &mut persist_pulled_tees),
                &mut seen_tees,
            )
        })
        .collect()
}

#[cfg(test)]
mod tests {
    use stageleft::*;

    use crate::deploy::MultiGraph;
    use crate::location::Location;

    #[test]
    fn persist_pullup_through_map() {
        let flow = crate::builder::FlowBuilder::new();
        let process = flow.process::<()>();

        process
            .source_iter(q!(0..10))
            .map(q!(|v| v + 1))
            .for_each(q!(|n| println!("{}", n)));

        let built = flow.finalize();

        insta::assert_debug_snapshot!(built.ir());

        let optimized = built.optimize_with(super::persist_pullup);

        insta::assert_debug_snapshot!(optimized.ir());
        for (id, graph) in optimized.compile_no_network::<MultiGraph>().hydroflow_ir() {
            insta::with_settings!({snapshot_suffix => format!("surface_graph_{id}")}, {
                insta::assert_snapshot!(graph.surface_syntax_string());
            });
        }
    }

    #[test]
    fn persist_pullup_behind_tee() {
        let flow = crate::builder::FlowBuilder::new();
        let process = flow.process::<()>();

        let tick = process.tick();
        let before_tee = process.source_iter(q!(0..10)).tick_batch(&tick).persist();

        before_tee
            .clone()
            .map(q!(|v| v + 1))
            .all_ticks()
            .for_each(q!(|n| println!("{}", n)));

        before_tee
            .clone()
            .map(q!(|v| v + 1))
            .all_ticks()
            .for_each(q!(|n| println!("{}", n)));

        let built = flow.finalize();

        insta::assert_debug_snapshot!(built.ir());

        let optimized = built.optimize_with(super::persist_pullup);

        insta::assert_debug_snapshot!(optimized.ir());

        for (id, graph) in optimized.compile_no_network::<MultiGraph>().hydroflow_ir() {
            insta::with_settings!({snapshot_suffix => format!("surface_graph_{id}")}, {
                insta::assert_snapshot!(graph.surface_syntax_string());
            });
        }
    }
}