Skip to main content

polars_stream/
graph.rs

1use std::time::Instant;
2
3use parking_lot::Mutex;
4use polars_error::PolarsResult;
5use slotmap::{Key, SecondaryMap, SlotMap};
6
7use crate::execute::StreamingExecutionState;
8use crate::metrics::GraphMetrics;
9use crate::nodes::ComputeNode;
10
11slotmap::new_key_type! {
12    pub struct GraphNodeKey;
13    pub struct LogicalPipeKey;
14}
15
16/// Represents the compute graph.
17///
18/// The `nodes` perform computation and the `pipes` form the connections between nodes
19/// that data is sent through.
20#[derive(Default)]
21pub struct Graph {
22    pub nodes: SlotMap<GraphNodeKey, GraphNode>,
23    pub pipes: SlotMap<LogicalPipeKey, LogicalPipe>,
24}
25
26impl Graph {
27    /// Allocate the needed `capacity` for the `Graph`.
28    pub fn with_capacity(capacity: usize) -> Self {
29        Self {
30            nodes: SlotMap::with_capacity_and_key(capacity),
31            pipes: SlotMap::with_capacity_and_key(capacity),
32        }
33    }
34
35    /// Add a new `GraphNode` to the `Graph` and connect the inputs and outputs
36    /// to their respective `LogicalPipe`s.
37    pub fn add_node<N: ComputeNode + 'static>(
38        &mut self,
39        node: N,
40        inputs: impl IntoIterator<Item = (GraphNodeKey, usize)>,
41    ) -> GraphNodeKey {
42        // Add the GraphNode.
43        let node_key = self.nodes.insert(GraphNode {
44            compute: Box::new(node),
45            inputs: Vec::new(),
46            outputs: Vec::new(),
47        });
48
49        // Create and add pipes that connect input to output.
50        for (recv_port, (sender, send_port)) in inputs.into_iter().enumerate() {
51            let pipe = LogicalPipe {
52                sender,
53                send_port,
54                send_state: PortState::Blocked,
55                receiver: node_key,
56                recv_port,
57                recv_state: PortState::Blocked,
58            };
59
60            // Add the pipe.
61            let pipe_key = self.pipes.insert(pipe);
62
63            // And connect input to output.
64            self.nodes[node_key].inputs.push(pipe_key);
65            if self.nodes[sender].outputs.len() <= send_port {
66                self.nodes[sender]
67                    .outputs
68                    .resize(send_port + 1, LogicalPipeKey::null());
69            }
70            assert!(self.nodes[sender].outputs[send_port].is_null());
71            self.nodes[sender].outputs[send_port] = pipe_key;
72        }
73
74        node_key
75    }
76
77    /// Updates all the nodes' states until a fixed point is reached.
78    pub fn update_all_states(
79        &mut self,
80        state: &StreamingExecutionState,
81        metrics: Option<&Mutex<GraphMetrics>>,
82    ) -> PolarsResult<()> {
83        let mut to_update: Vec<_> = self.nodes.keys().collect();
84        let mut scheduled_for_update: SecondaryMap<GraphNodeKey, ()> =
85            self.nodes.keys().map(|k| (k, ())).collect();
86
87        let verbose = std::env::var("POLARS_VERBOSE_STATE_UPDATE").as_deref() == Ok("1");
88
89        let mut recv_state = Vec::new();
90        let mut send_state = Vec::new();
91        while let Some(node_key) = to_update.pop() {
92            scheduled_for_update.remove(node_key);
93            let node = &mut self.nodes[node_key];
94
95            // Get the states of nodes this node is connected to.
96            recv_state.clear();
97            send_state.clear();
98            recv_state.extend(node.inputs.iter().map(|i| self.pipes[*i].send_state));
99            send_state.extend(node.outputs.iter().map(|o| self.pipes[*o].recv_state));
100
101            // Compute the new state of this node given its environment.
102            if verbose {
103                eprintln!(
104                    "updating {}, before: {recv_state:?} {send_state:?}",
105                    node.compute.name()
106                );
107            }
108            let start = (metrics.is_some() || verbose).then(Instant::now);
109            if let Some(lock) = metrics {
110                lock.lock().start_state_update(node_key);
111            }
112
113            node.compute
114                .update_state(&mut recv_state, &mut send_state, state)?;
115            let elapsed = start.map(|s| s.elapsed());
116            if let Some(lock) = metrics {
117                let is_done = recv_state.iter().all(|s| *s == PortState::Done)
118                    && send_state.iter().all(|s| *s == PortState::Done);
119                lock.lock()
120                    .stop_state_update(node_key, elapsed.unwrap(), is_done);
121            }
122            if verbose {
123                eprintln!(
124                    "updating {}, after: {recv_state:?} {send_state:?} (took {:?})",
125                    node.compute.name(),
126                    elapsed.unwrap()
127                );
128            }
129
130            // Propagate information.
131            for (input, state) in node.inputs.iter().zip(recv_state.iter()) {
132                let pipe = &mut self.pipes[*input];
133                if pipe.recv_state != *state {
134                    assert!(
135                        pipe.recv_state != PortState::Done,
136                        "implementation error: state transition from Done to Blocked/Ready attempted"
137                    );
138                    pipe.recv_state = *state;
139                    if scheduled_for_update.insert(pipe.sender, ()).is_none() {
140                        to_update.push(pipe.sender);
141                    }
142                }
143            }
144
145            for (output, state) in node.outputs.iter().zip(send_state.iter()) {
146                let pipe = &mut self.pipes[*output];
147                if pipe.send_state != *state {
148                    assert!(
149                        pipe.send_state != PortState::Done,
150                        "implementation error: state transition from Done to Blocked/Ready attempted"
151                    );
152                    pipe.send_state = *state;
153                    if scheduled_for_update.insert(pipe.receiver, ()).is_none() {
154                        to_update.push(pipe.receiver);
155                    }
156                }
157            }
158        }
159        Ok(())
160    }
161}
162
163/// A node in the graph represents a computation performed on the stream of morsels
164/// that flow through it.
165pub struct GraphNode {
166    pub compute: Box<dyn ComputeNode>,
167    pub inputs: Vec<LogicalPipeKey>,
168    pub outputs: Vec<LogicalPipeKey>,
169}
170
171/// A pipe sends data between nodes.
172#[allow(unused)] // TODO: remove.
173#[derive(Clone)]
174pub struct LogicalPipe {
175    // Node that we send data to.
176    pub sender: GraphNodeKey,
177    // Output location:
178    // graph[x].output[i].send_port == i
179    pub send_port: usize,
180    pub send_state: PortState,
181
182    // Node that we receive data from.
183    pub receiver: GraphNodeKey,
184    // Input location:
185    // graph[x].inputs[i].recv_port == i
186    pub recv_port: usize,
187    pub recv_state: PortState,
188}
189
/// State of one end of a pipe, as reported by the node on that end.
///
/// The derived ordering follows declaration order: `Blocked < Ready < Done`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum PortState {
    /// Initial state of both ends of every newly created pipe.
    Blocked,
    /// NOTE(review): presumably the port can make progress — confirm against the
    /// `ComputeNode` documentation.
    Ready,
    /// Terminal state: the graph's state update panics if a pipe end leaves it.
    Done,
}