polars_expr/state/execution_state.rs

use std::borrow::Cow;
use std::sync::atomic::{AtomicI64, Ordering};
use std::sync::{Mutex, RwLock};
use std::time::Duration;

use bitflags::bitflags;
use polars_core::config::verbose;
use polars_core::prelude::*;
use polars_ops::prelude::ChunkJoinOptIds;
use polars_utils::relaxed_cell::RelaxedCell;
use polars_utils::unique_id::UniqueId;

use super::NodeTimer;

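/// Shared cache mapping a cache-key string to the row indices ([`ChunkJoinOptIds`])
/// produced by a join.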
pub type JoinTuplesCache = Arc<Mutex<PlHashMap<String, ChunkJoinOptIds>>>;

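/// Cache used by window expressions to store intermediate state: group
/// positions, join tuples, and index mappings, each keyed by a cache-key string.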
#[derive(Default)]
pub struct WindowCache {
    groups: RwLock<PlHashMap<String, GroupPositions>>,
    join_tuples: RwLock<PlHashMap<String, Arc<ChunkJoinOptIds>>>,
    map_idx: RwLock<PlHashMap<String, Arc<IdxCa>>>,
}

impl WindowCache {
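    /// Drop all cached window state (group positions, join tuples, and index mappings).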
    pub(crate) fn clear(&self) {
        let Self {
            groups,
            join_tuples,
            map_idx,
        } = self;
        groups.write().unwrap().clear();
        join_tuples.write().unwrap().clear();
        map_idx.write().unwrap().clear();
    }

    pub fn get_groups(&self, key: &str) -> Option<GroupPositions> {
        let g = self.groups.read().unwrap();
        g.get(key).cloned()
    }

    pub fn insert_groups(&self, key: String, groups: GroupPositions) {
        let mut g = self.groups.write().unwrap();
        g.insert(key, groups);
    }

    pub fn get_join(&self, key: &str) -> Option<Arc<ChunkJoinOptIds>> {
        let g = self.join_tuples.read().unwrap();
        g.get(key).cloned()
    }

    pub fn insert_join(&self, key: String, join_tuples: Arc<ChunkJoinOptIds>) {
        let mut g = self.join_tuples.write().unwrap();
        g.insert(key, join_tuples);
    }

    pub fn get_map(&self, key: &str) -> Option<Arc<IdxCa>> {
        let g = self.map_idx.read().unwrap();
        g.get(key).cloned()
    }

    pub fn insert_map(&self, key: String, idx: Arc<IdxCa>) {
        let mut g = self.map_idx.write().unwrap();
        g.insert(key, idx);
    }
}

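// Runtime flags for `ExecutionState`, stored as a single `u8` behind a `RelaxedCell`.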
bitflags! {
    #[repr(transparent)]
    #[derive(Copy, Clone)]
    pub(super) struct StateFlags: u8 {
        /// More verbose logging
        const VERBOSE = 0x01;
        /// Indicates that window expression's [`GroupTuples`] may be cached.
        const CACHE_WINDOW_EXPR = 0x02;
        /// Indicates the expression has a window function
        const HAS_WINDOW = 0x04;
    }
}

impl Default for StateFlags {
    fn default() -> Self {
        StateFlags::CACHE_WINDOW_EXPR
    }
}

impl StateFlags {
    fn init() -> Self {
        let verbose = verbose();
        let mut flags: StateFlags = Default::default();
        if verbose {
            flags |= StateFlags::VERBOSE;
        }
        flags
    }
    fn as_u8(self) -> u8 {
        unsafe { std::mem::transmute(self) }
    }
}

impl From<u8> for StateFlags {
    fn from(value: u8) -> Self {
        unsafe { std::mem::transmute(value) }
    }
}

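/// A cached [`DataFrame`] together with the number of reads that are still
/// expected before the entry can be dropped.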
struct CachedValue {
    /// The number of times the cache will still be read.
    /// Zero means that there will be no more reads and the cache can be dropped.
    remaining_hits: AtomicI64,
    df: DataFrame,
}

/// State / cache that is maintained during the execution of the physical plan.
#[derive(Clone)]
pub struct ExecutionState {
    // Cached by a `.cache` call and kept in memory for the duration of the plan.
    df_cache: Arc<RwLock<PlHashMap<UniqueId, Arc<CachedValue>>>>,
    pub schema_cache: Arc<RwLock<Option<SchemaRef>>>,
    /// Used by window expressions to cache intermediate state.
    pub window_cache: Arc<WindowCache>,
    // Every join/union split gets an increment to distinguish between schema state.
    pub branch_idx: usize,
    pub flags: RelaxedCell<u8>,
    pub ext_contexts: Arc<Vec<DataFrame>>,
    node_timer: Option<NodeTimer>,
    stop: Arc<RelaxedCell<bool>>,
}

impl ExecutionState {
    pub fn new() -> Self {
        Self {
            df_cache: Default::default(),
            schema_cache: Default::default(),
            window_cache: Default::default(),
            branch_idx: 0,
            flags: RelaxedCell::from(StateFlags::init().as_u8()),
            ext_contexts: Default::default(),
            node_timer: None,
            stop: Arc::new(RelaxedCell::from(false)),
        }
    }

    /// Toggle this to measure execution times.
    pub fn time_nodes(&mut self, start: std::time::Instant) {
        self.node_timer = Some(NodeTimer::new(start))
    }

    pub fn has_node_timer(&self) -> bool {
        self.node_timer.is_some()
    }

    pub fn finish_timer(self) -> PolarsResult<DataFrame> {
        self.node_timer.unwrap().finish()
    }

    // Timings should be a list of (start, end, name) where the start
    // and end are raw durations since the query start, in nanoseconds.
    pub fn record_raw_timings(&self, timings: &[(u64, u64, String)]) {
        for &(start, end, ref name) in timings {
            self.node_timer.as_ref().unwrap().store_duration(
                Duration::from_nanos(start),
                Duration::from_nanos(end),
                name.to_string(),
            );
        }
    }

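    /// Check whether the query should be aborted, either because a keyboard
    /// interrupt was raised or because the cancel token was set.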
    // This is wrong when the u64 overflows, which will never happen.
    pub fn should_stop(&self) -> PolarsResult<()> {
        try_raise_keyboard_interrupt();
        polars_ensure!(!self.stop.load(), ComputeError: "query interrupted");
        Ok(())
    }

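    /// Returns a handle to the shared stop flag so the query can be cancelled
    /// from another thread.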
    pub fn cancel_token(&self) -> Arc<RelaxedCell<bool>> {
        self.stop.clone()
    }

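    /// Run `func`, recording its wall-clock duration under `name` when node
    /// timing is enabled; otherwise just run it.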
    pub fn record<T, F: FnOnce() -> T>(&self, func: F, name: Cow<'static, str>) -> T {
        match &self.node_timer {
            None => func(),
            Some(timer) => {
                let start = std::time::Instant::now();
                let out = func();
                let end = std::time::Instant::now();

                timer.store(start, end, name.as_ref().to_string());
                out
            },
        }
    }

    /// Partially clones and partially clears state.
    /// This should be used when splitting a node, like a join or union.
    pub fn split(&self) -> Self {
        Self {
            df_cache: self.df_cache.clone(),
            schema_cache: Default::default(),
            window_cache: Default::default(),
            branch_idx: self.branch_idx,
            flags: self.flags.clone(),
            ext_contexts: self.ext_contexts.clone(),
            node_timer: self.node_timer.clone(),
            stop: self.stop.clone(),
        }
    }

    pub fn set_schema(&self, schema: SchemaRef) {
        let mut lock = self.schema_cache.write().unwrap();
        *lock = Some(schema);
    }

    /// Clear the schema. Typically at the end of a projection.
    pub fn clear_schema_cache(&self) {
        let mut lock = self.schema_cache.write().unwrap();
        *lock = None;
    }

    /// Get the schema.
    pub fn get_schema(&self) -> Option<SchemaRef> {
        let lock = self.schema_cache.read().unwrap();
        lock.clone()
    }

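    /// Insert a [`DataFrame`] into the plan-wide cache under `id`, recording
    /// that it is expected to be read `cache_hits` times. Panics if an entry
    /// for `id` already exists.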
    pub fn set_df_cache(&self, id: &UniqueId, df: DataFrame, cache_hits: u32) {
        if self.verbose() {
            eprintln!("CACHE SET: cache id: {id}");
        }

        let value = Arc::new(CachedValue {
            remaining_hits: AtomicI64::new(cache_hits as i64),
            df,
        });

        let prev = self.df_cache.write().unwrap().insert(*id, value);
        assert!(prev.is_none(), "duplicate set cache: {id}");
    }

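    /// Fetch a cached [`DataFrame`] by `id`. The last expected read removes the
    /// entry and returns the owned frame; earlier reads return a clone.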
    pub fn get_df_cache(&self, id: &UniqueId) -> DataFrame {
        let guard = self.df_cache.read().unwrap();
        let value = guard.get(id).expect("cache not prefilled");
        let remaining = value.remaining_hits.fetch_sub(1, Ordering::Relaxed);
        if remaining < 0 {
            panic!("cache used more times than expected: {id}");
        }
        if self.verbose() {
            eprintln!("CACHE HIT: cache id: {id}");
        }
        if remaining == 1 {
            drop(guard);
            let value = self.df_cache.write().unwrap().remove(id).unwrap();
            if self.verbose() {
                eprintln!("CACHE DROP: cache id: {id}");
            }
            Arc::into_inner(value).unwrap().df
        } else {
            value.df.clone()
        }
    }

    /// Clear the cache used by the window expressions.
    pub fn clear_window_expr_cache(&self) {
        self.window_cache.clear();
    }

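    /// Apply `f` to the current [`StateFlags`] and store the result.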
    fn set_flags(&self, f: &dyn Fn(StateFlags) -> StateFlags) {
        let flags: StateFlags = self.flags.load().into();
        let flags = f(flags);
        self.flags.store(flags.as_u8());
    }

    /// Indicates that window expression's [`GroupTuples`] may be cached.
    pub fn cache_window(&self) -> bool {
        let flags: StateFlags = self.flags.load().into();
        flags.contains(StateFlags::CACHE_WINDOW_EXPR)
    }

    /// Indicates that the expression has a window function.
    pub fn has_window(&self) -> bool {
        let flags: StateFlags = self.flags.load().into();
        flags.contains(StateFlags::HAS_WINDOW)
    }

    /// More verbose logging
    pub fn verbose(&self) -> bool {
        let flags: StateFlags = self.flags.load().into();
        flags.contains(StateFlags::VERBOSE)
    }

    pub fn remove_cache_window_flag(&mut self) {
        self.set_flags(&|mut flags| {
            flags.remove(StateFlags::CACHE_WINDOW_EXPR);
            flags
        });
    }

    pub fn insert_cache_window_flag(&mut self) {
        self.set_flags(&|mut flags| {
            flags.insert(StateFlags::CACHE_WINDOW_EXPR);
            flags
        });
    }
    // This will trigger some conservative behavior.
    pub fn insert_has_window_function_flag(&mut self) {
        self.set_flags(&|mut flags| {
            flags.insert(StateFlags::HAS_WINDOW);
            flags
        });
    }
}

impl Default for ExecutionState {
    fn default() -> Self {
        ExecutionState::new()
    }
}