polars_expr/state/execution_state.rs

use std::borrow::Cow;
use std::sync::atomic::{AtomicI64, Ordering};
use std::sync::{Mutex, RwLock};
use std::time::Duration;

use bitflags::bitflags;
use polars_core::config::verbose;
use polars_core::prelude::*;
use polars_ops::prelude::ChunkJoinOptIds;
use polars_utils::relaxed_cell::RelaxedCell;
use polars_utils::unique_id::UniqueId;

use super::NodeTimer;

pub type JoinTuplesCache = Arc<Mutex<PlHashMap<String, ChunkJoinOptIds>>>;

#[derive(Default)]
pub struct WindowCache {
    groups: RwLock<PlHashMap<String, GroupPositions>>,
    join_tuples: RwLock<PlHashMap<String, Arc<ChunkJoinOptIds>>>,
    map_idx: RwLock<PlHashMap<String, Arc<IdxCa>>>,
}

impl WindowCache {
    pub(crate) fn clear(&self) {
        let mut g = self.groups.write().unwrap();
        g.clear();
        let mut g = self.join_tuples.write().unwrap();
        g.clear();
    }

    pub fn get_groups(&self, key: &str) -> Option<GroupPositions> {
        let g = self.groups.read().unwrap();
        g.get(key).cloned()
    }

    pub fn insert_groups(&self, key: String, groups: GroupPositions) {
        let mut g = self.groups.write().unwrap();
        g.insert(key, groups);
    }

    pub fn get_join(&self, key: &str) -> Option<Arc<ChunkJoinOptIds>> {
        let g = self.join_tuples.read().unwrap();
        g.get(key).cloned()
    }

    pub fn insert_join(&self, key: String, join_tuples: Arc<ChunkJoinOptIds>) {
        let mut g = self.join_tuples.write().unwrap();
        g.insert(key, join_tuples);
    }

    pub fn get_map(&self, key: &str) -> Option<Arc<IdxCa>> {
        let g = self.map_idx.read().unwrap();
        g.get(key).cloned()
    }

    pub fn insert_map(&self, key: String, idx: Arc<IdxCa>) {
        let mut g = self.map_idx.write().unwrap();
        g.insert(key, idx);
    }
}
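
// Illustrative sketch (not part of the original source): a window expression
// typically keys this cache by its own string representation, checks for a
// previously computed result, and inserts on a miss. `window_key`,
// `compute_groups`, and `cache` below are hypothetical names.
//
//     if let Some(groups) = cache.get_groups(&window_key) {
//         // Reuse the cached group positions.
//     } else {
//         let groups = compute_groups()?;
//         cache.insert_groups(window_key.clone(), groups.clone());
//     }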

bitflags! {
    #[repr(transparent)]
    #[derive(Copy, Clone)]
    pub(super) struct StateFlags: u8 {
        /// More verbose logging.
        const VERBOSE = 0x01;
        /// Indicates that a window expression's [`GroupTuples`] may be cached.
        const CACHE_WINDOW_EXPR = 0x02;
        /// Indicates that the expression has a window function.
        const HAS_WINDOW = 0x04;
    }
}

impl Default for StateFlags {
    fn default() -> Self {
        StateFlags::CACHE_WINDOW_EXPR
    }
}

impl StateFlags {
    fn init() -> Self {
        let verbose = verbose();
        let mut flags: StateFlags = Default::default();
        if verbose {
            flags |= StateFlags::VERBOSE;
        }
        flags
    }

    fn as_u8(self) -> u8 {
        unsafe { std::mem::transmute(self) }
    }
}

impl From<u8> for StateFlags {
    fn from(value: u8) -> Self {
        unsafe { std::mem::transmute(value) }
    }
}
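
// Because `StateFlags` is `#[repr(transparent)]` over its `u8` bits, the
// transmutes above are plain bit reinterpretations. Illustrative round trip
// (the values follow from the flag constants defined above):
//
//     let flags = StateFlags::CACHE_WINDOW_EXPR | StateFlags::VERBOSE;
//     let bits = flags.as_u8();            // 0x03
//     assert!(StateFlags::from(bits).contains(StateFlags::VERBOSE));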

struct CachedValue {
    /// The number of times the cache will still be read.
    /// Zero means that there will be no more reads and the cache can be dropped.
    remaining_hits: AtomicI64,
    df: DataFrame,
}

/// State and cache maintained during execution of the physical plan.
#[derive(Clone)]
pub struct ExecutionState {
    // Cached by a `.cache` call and kept in memory for the duration of the plan.
    df_cache: Arc<RwLock<PlHashMap<UniqueId, Arc<CachedValue>>>>,
    pub schema_cache: Arc<RwLock<Option<SchemaRef>>>,
    /// Used by window expressions to cache intermediate state.
    pub window_cache: Arc<WindowCache>,
    // Every join/union split gets an increment to distinguish between schema states.
    pub branch_idx: usize,
    pub flags: RelaxedCell<u8>,
    pub ext_contexts: Arc<Vec<DataFrame>>,
    node_timer: Option<NodeTimer>,
    stop: Arc<RelaxedCell<bool>>,
}

impl ExecutionState {
    pub fn new() -> Self {
        Self {
            df_cache: Default::default(),
            schema_cache: Default::default(),
            window_cache: Default::default(),
            branch_idx: 0,
            flags: RelaxedCell::from(StateFlags::init().as_u8()),
            ext_contexts: Default::default(),
            node_timer: None,
            stop: Arc::new(RelaxedCell::from(false)),
        }
    }

    /// Turn this on to measure execution times per node.
    pub fn time_nodes(&mut self, start: std::time::Instant) {
        self.node_timer = Some(NodeTimer::new(start))
    }

    pub fn has_node_timer(&self) -> bool {
        self.node_timer.is_some()
    }

    pub fn finish_timer(self) -> PolarsResult<DataFrame> {
        self.node_timer.unwrap().finish()
    }

    // Timings should be a list of (start, end, name) where `start` and `end`
    // are raw durations since the query start, in nanoseconds.
    pub fn record_raw_timings(&self, timings: &[(u64, u64, String)]) {
        for &(start, end, ref name) in timings {
            self.node_timer.as_ref().unwrap().store_duration(
                Duration::from_nanos(start),
                Duration::from_nanos(end),
                name.to_string(),
            );
        }
    }

    // This is wrong when the u64 overflows, which will never happen.
    pub fn should_stop(&self) -> PolarsResult<()> {
        try_raise_keyboard_interrupt();
        polars_ensure!(!self.stop.load(), ComputeError: "query interrupted");
        Ok(())
    }

    pub fn cancel_token(&self) -> Arc<RelaxedCell<bool>> {
        self.stop.clone()
    }

    pub fn record<T, F: FnOnce() -> T>(&self, func: F, name: Cow<'static, str>) -> T {
        match &self.node_timer {
            None => func(),
            Some(timer) => {
                let start = std::time::Instant::now();
                let out = func();
                let end = std::time::Instant::now();

                timer.store(start, end, name.as_ref().to_string());
                out
            },
        }
    }
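
    // Illustrative only (names are hypothetical): wrapping a node's work in
    // `record` attributes its wall time to `name` when a NodeTimer is
    // installed, and otherwise just runs the closure:
    //
    //     let df = state.record(|| run_node(), Cow::Borrowed("SORT"));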

    /// Partially clones and partially clears state.
    /// This should be used when splitting a node, like a join or union.
    pub fn split(&self) -> Self {
        Self {
            df_cache: self.df_cache.clone(),
            schema_cache: Default::default(),
            window_cache: Default::default(),
            branch_idx: self.branch_idx,
            flags: self.flags.clone(),
            ext_contexts: self.ext_contexts.clone(),
            node_timer: self.node_timer.clone(),
            stop: self.stop.clone(),
        }
    }

    pub fn set_schema(&self, schema: SchemaRef) {
        let mut lock = self.schema_cache.write().unwrap();
        *lock = Some(schema);
    }

    /// Clear the schema. Typically done at the end of a projection.
    pub fn clear_schema_cache(&self) {
        let mut lock = self.schema_cache.write().unwrap();
        *lock = None;
    }

    /// Get the schema.
    pub fn get_schema(&self) -> Option<SchemaRef> {
        let lock = self.schema_cache.read().unwrap();
        lock.clone()
    }

    pub fn set_df_cache(&self, id: &UniqueId, df: DataFrame, cache_hits: u32) {
        if self.verbose() {
            eprintln!("CACHE SET: cache id: {id}");
        }

        let value = Arc::new(CachedValue {
            remaining_hits: AtomicI64::new(cache_hits as i64),
            df,
        });

        let prev = self.df_cache.write().unwrap().insert(*id, value);
        assert!(prev.is_none(), "duplicate set cache: {id}");
    }

    pub fn get_df_cache(&self, id: &UniqueId) -> DataFrame {
        let guard = self.df_cache.read().unwrap();
        let value = guard.get(id).expect("cache not prefilled");
        // `fetch_sub` returns the value *before* the decrement.
        let remaining = value.remaining_hits.fetch_sub(1, Ordering::Relaxed);
        if remaining < 0 {
            panic!("cache used more times than expected: {id}");
        }
        if self.verbose() {
            eprintln!("CACHE HIT: cache id: {id}");
        }
        if remaining == 1 {
            // This was the last expected read: remove the entry and hand out
            // the owned DataFrame instead of a clone.
            drop(guard);
            let value = self.df_cache.write().unwrap().remove(id).unwrap();
            if self.verbose() {
                eprintln!("CACHE DROP: cache id: {id}");
            }
            Arc::into_inner(value).unwrap().df
        } else {
            value.df.clone()
        }
    }
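
    // Cache lifecycle sketch (illustrative; `id` and `df` are placeholders):
    // the producer registers the frame together with the number of expected
    // reads, and the final read takes ownership so the frame is dropped as
    // early as possible:
    //
    //     state.set_df_cache(&id, df, 2);
    //     let a = state.get_df_cache(&id); // clone, one hit remaining
    //     let b = state.get_df_cache(&id); // last hit: entry removed, owned frame returned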

    /// Clear the cache used by the window expressions.
    pub fn clear_window_expr_cache(&self) {
        self.window_cache.clear();
    }

    fn set_flags(&self, f: &dyn Fn(StateFlags) -> StateFlags) {
        let flags: StateFlags = self.flags.load().into();
        let flags = f(flags);
        self.flags.store(flags.as_u8());
    }

    /// Indicates that a window expression's [`GroupTuples`] may be cached.
    pub fn cache_window(&self) -> bool {
        let flags: StateFlags = self.flags.load().into();
        flags.contains(StateFlags::CACHE_WINDOW_EXPR)
    }

    /// Indicates whether the expression contains a window function.
    pub fn has_window(&self) -> bool {
        let flags: StateFlags = self.flags.load().into();
        flags.contains(StateFlags::HAS_WINDOW)
    }

    /// More verbose logging.
    pub fn verbose(&self) -> bool {
        let flags: StateFlags = self.flags.load().into();
        flags.contains(StateFlags::VERBOSE)
    }

    pub fn remove_cache_window_flag(&mut self) {
        self.set_flags(&|mut flags| {
            flags.remove(StateFlags::CACHE_WINDOW_EXPR);
            flags
        });
    }

    pub fn insert_cache_window_flag(&mut self) {
        self.set_flags(&|mut flags| {
            flags.insert(StateFlags::CACHE_WINDOW_EXPR);
            flags
        });
    }

    // This will trigger some conservative behavior during evaluation.
    pub fn insert_has_window_function_flag(&mut self) {
        self.set_flags(&|mut flags| {
            flags.insert(StateFlags::HAS_WINDOW);
            flags
        });
    }
}

impl Default for ExecutionState {
    fn default() -> Self {
        ExecutionState::new()
    }
}
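
// A minimal, self-contained sketch (not part of the original file) exercising
// only the public API defined above; the assertions assume the default flag
// state (CACHE_WINDOW_EXPR set, HAS_WINDOW unset) described in this module.
#[cfg(test)]
mod sketch_tests {
    use std::borrow::Cow;

    use super::*;

    #[test]
    fn window_flags_round_trip() {
        let mut state = ExecutionState::new();
        // CACHE_WINDOW_EXPR is set by default, HAS_WINDOW is not.
        assert!(state.cache_window());
        assert!(!state.has_window());

        state.insert_has_window_function_flag();
        assert!(state.has_window());

        state.remove_cache_window_flag();
        assert!(!state.cache_window());
    }

    #[test]
    fn record_without_timer_just_runs_the_closure() {
        let state = ExecutionState::new();
        // With no NodeTimer installed, `record` is a transparent wrapper.
        let out = state.record(|| 1 + 1, Cow::Borrowed("example"));
        assert_eq!(out, 2);
    }
}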