// polars_expr/state/execution_state.rs

1use std::borrow::Cow;
2use std::sync::atomic::{AtomicBool, AtomicI64, AtomicU8, Ordering};
3use std::sync::{Mutex, OnceLock, RwLock};
4use std::time::Duration;
5
6use bitflags::bitflags;
7use polars_core::config::verbose;
8use polars_core::prelude::*;
9use polars_ops::prelude::ChunkJoinOptIds;
10use polars_utils::unique_id::UniqueId;
11
12use super::NodeTimer;
13
/// Shared, mutex-guarded map from cache key to join tuple indices.
pub type JoinTuplesCache = Arc<Mutex<PlHashMap<String, ChunkJoinOptIds>>>;
15
/// Cache of intermediate window-expression state, keyed by string cache keys.
/// All maps are independently `RwLock`-guarded so reads can proceed concurrently.
#[derive(Default)]
pub struct WindowCache {
    // Cached group positions per window-expression cache key.
    groups: RwLock<PlHashMap<String, GroupPositions>>,
    // Cached join tuples per cache key, shared cheaply via `Arc`.
    join_tuples: RwLock<PlHashMap<String, Arc<ChunkJoinOptIds>>>,
    // Cached index mappings per cache key.
    map_idx: RwLock<PlHashMap<String, Arc<IdxCa>>>,
}
22
23impl WindowCache {
24    pub(crate) fn clear(&self) {
25        let mut g = self.groups.write().unwrap();
26        g.clear();
27        let mut g = self.join_tuples.write().unwrap();
28        g.clear();
29    }
30
31    pub fn get_groups(&self, key: &str) -> Option<GroupPositions> {
32        let g = self.groups.read().unwrap();
33        g.get(key).cloned()
34    }
35
36    pub fn insert_groups(&self, key: String, groups: GroupPositions) {
37        let mut g = self.groups.write().unwrap();
38        g.insert(key, groups);
39    }
40
41    pub fn get_join(&self, key: &str) -> Option<Arc<ChunkJoinOptIds>> {
42        let g = self.join_tuples.read().unwrap();
43        g.get(key).cloned()
44    }
45
46    pub fn insert_join(&self, key: String, join_tuples: Arc<ChunkJoinOptIds>) {
47        let mut g = self.join_tuples.write().unwrap();
48        g.insert(key, join_tuples);
49    }
50
51    pub fn get_map(&self, key: &str) -> Option<Arc<IdxCa>> {
52        let g = self.map_idx.read().unwrap();
53        g.get(key).cloned()
54    }
55
56    pub fn insert_map(&self, key: String, idx: Arc<IdxCa>) {
57        let mut g = self.map_idx.write().unwrap();
58        g.insert(key, idx);
59    }
60}
61
bitflags! {
    // `repr(transparent)` guarantees the same layout as the underlying `u8`,
    // which the surrounding code relies on for `u8` round-tripping.
    #[repr(transparent)]
    #[derive(Copy, Clone)]
    pub(super) struct StateFlags: u8 {
        /// More verbose logging
        const VERBOSE = 0x01;
        /// Indicates that window expression's [`GroupTuples`] may be cached.
        const CACHE_WINDOW_EXPR = 0x02;
        /// Indicates the expression has a window function
        const HAS_WINDOW = 0x04;
    }
}
74
75impl Default for StateFlags {
76    fn default() -> Self {
77        StateFlags::CACHE_WINDOW_EXPR
78    }
79}
80
81impl StateFlags {
82    fn init() -> Self {
83        let verbose = verbose();
84        let mut flags: StateFlags = Default::default();
85        if verbose {
86            flags |= StateFlags::VERBOSE;
87        }
88        flags
89    }
90    fn as_u8(self) -> u8 {
91        unsafe { std::mem::transmute(self) }
92    }
93}
94
95impl From<u8> for StateFlags {
96    fn from(value: u8) -> Self {
97        unsafe { std::mem::transmute(value) }
98    }
99}
100
/// A `.cache` slot: an atomic counter (seeded from `cache_hits` in
/// `get_df_cache`; presumably counts remaining expected hits — confirm at call
/// sites) plus the lazily-initialized cached `DataFrame`.
type CachedValue = Arc<(AtomicI64, OnceLock<DataFrame>)>;
102
/// State/ cache that is maintained during the Execution of the physical plan.
pub struct ExecutionState {
    // cached by a `.cache` call and kept in memory for the duration of the plan.
    df_cache: Arc<RwLock<PlHashMap<UniqueId, CachedValue>>>,
    // Schema of the current node, set/cleared around projections.
    pub schema_cache: RwLock<Option<SchemaRef>>,
    /// Used by Window Expressions to cache intermediate state
    pub window_cache: Arc<WindowCache>,
    // every join/union split gets an increment to distinguish between schema state
    pub branch_idx: usize,
    // Packed `StateFlags` bits; accessed atomically so `&self` methods can mutate.
    pub flags: AtomicU8,
    // Extra DataFrame contexts supplied from outside the plan.
    pub ext_contexts: Arc<Vec<DataFrame>>,
    // Present only when node timing is enabled via `time_nodes`.
    node_timer: Option<NodeTimer>,
    // Shared cancellation flag; see `cancel_token` / `should_stop`.
    stop: Arc<AtomicBool>,
}
117
118impl ExecutionState {
119    pub fn new() -> Self {
120        let mut flags: StateFlags = Default::default();
121        if verbose() {
122            flags |= StateFlags::VERBOSE;
123        }
124        Self {
125            df_cache: Default::default(),
126            schema_cache: Default::default(),
127            window_cache: Default::default(),
128            branch_idx: 0,
129            flags: AtomicU8::new(StateFlags::init().as_u8()),
130            ext_contexts: Default::default(),
131            node_timer: None,
132            stop: Arc::new(AtomicBool::new(false)),
133        }
134    }
135
136    /// Toggle this to measure execution times.
137    pub fn time_nodes(&mut self, start: std::time::Instant) {
138        self.node_timer = Some(NodeTimer::new(start))
139    }
140    pub fn has_node_timer(&self) -> bool {
141        self.node_timer.is_some()
142    }
143
144    pub fn finish_timer(self) -> PolarsResult<DataFrame> {
145        self.node_timer.unwrap().finish()
146    }
147
148    // Timings should be a list of (start, end, name) where the start
149    // and end are raw durations since the query start as nanoseconds.
150    pub fn record_raw_timings(&self, timings: &[(u64, u64, String)]) {
151        for &(start, end, ref name) in timings {
152            self.node_timer.as_ref().unwrap().store_duration(
153                Duration::from_nanos(start),
154                Duration::from_nanos(end),
155                name.to_string(),
156            );
157        }
158    }
159
160    // This is wrong when the U64 overflows which will never happen.
161    pub fn should_stop(&self) -> PolarsResult<()> {
162        try_raise_keyboard_interrupt();
163        polars_ensure!(!self.stop.load(Ordering::Relaxed), ComputeError: "query interrupted");
164        Ok(())
165    }
166
167    pub fn cancel_token(&self) -> Arc<AtomicBool> {
168        self.stop.clone()
169    }
170
171    pub fn record<T, F: FnOnce() -> T>(&self, func: F, name: Cow<'static, str>) -> T {
172        match &self.node_timer {
173            None => func(),
174            Some(timer) => {
175                let start = std::time::Instant::now();
176                let out = func();
177                let end = std::time::Instant::now();
178
179                timer.store(start, end, name.as_ref().to_string());
180                out
181            },
182        }
183    }
184
185    /// Partially clones and partially clears state
186    /// This should be used when splitting a node, like a join or union
187    pub fn split(&self) -> Self {
188        Self {
189            df_cache: self.df_cache.clone(),
190            schema_cache: Default::default(),
191            window_cache: Default::default(),
192            branch_idx: self.branch_idx,
193            flags: AtomicU8::new(self.flags.load(Ordering::Relaxed)),
194            ext_contexts: self.ext_contexts.clone(),
195            node_timer: self.node_timer.clone(),
196            stop: self.stop.clone(),
197        }
198    }
199
200    pub fn set_schema(&self, schema: SchemaRef) {
201        let mut lock = self.schema_cache.write().unwrap();
202        *lock = Some(schema);
203    }
204
205    /// Clear the schema. Typically at the end of a projection.
206    pub fn clear_schema_cache(&self) {
207        let mut lock = self.schema_cache.write().unwrap();
208        *lock = None;
209    }
210
211    /// Get the schema.
212    pub fn get_schema(&self) -> Option<SchemaRef> {
213        let lock = self.schema_cache.read().unwrap();
214        lock.clone()
215    }
216
217    pub fn get_df_cache(&self, key: &UniqueId, cache_hits: u32) -> CachedValue {
218        let guard = self.df_cache.read().unwrap();
219
220        match guard.get(key) {
221            Some(v) => v.clone(),
222            None => {
223                drop(guard);
224                let mut guard = self.df_cache.write().unwrap();
225
226                guard
227                    .entry(key.clone())
228                    .or_insert_with(|| {
229                        Arc::new((AtomicI64::new(cache_hits as i64), OnceLock::new()))
230                    })
231                    .clone()
232            },
233        }
234    }
235
236    pub fn remove_df_cache(&self, key: &UniqueId) {
237        let mut guard = self.df_cache.write().unwrap();
238        let _ = guard.remove(key).unwrap();
239    }
240
241    /// Clear the cache used by the Window expressions
242    pub fn clear_window_expr_cache(&self) {
243        self.window_cache.clear();
244    }
245
246    fn set_flags(&self, f: &dyn Fn(StateFlags) -> StateFlags) {
247        let flags: StateFlags = self.flags.load(Ordering::Relaxed).into();
248        let flags = f(flags);
249        self.flags.store(flags.as_u8(), Ordering::Relaxed);
250    }
251
252    /// Indicates that window expression's [`GroupTuples`] may be cached.
253    pub fn cache_window(&self) -> bool {
254        let flags: StateFlags = self.flags.load(Ordering::Relaxed).into();
255        flags.contains(StateFlags::CACHE_WINDOW_EXPR)
256    }
257
258    /// Indicates that window expression's [`GroupTuples`] may be cached.
259    pub fn has_window(&self) -> bool {
260        let flags: StateFlags = self.flags.load(Ordering::Relaxed).into();
261        flags.contains(StateFlags::HAS_WINDOW)
262    }
263
264    /// More verbose logging
265    pub fn verbose(&self) -> bool {
266        let flags: StateFlags = self.flags.load(Ordering::Relaxed).into();
267        flags.contains(StateFlags::VERBOSE)
268    }
269
270    pub fn remove_cache_window_flag(&mut self) {
271        self.set_flags(&|mut flags| {
272            flags.remove(StateFlags::CACHE_WINDOW_EXPR);
273            flags
274        });
275    }
276
277    pub fn insert_cache_window_flag(&mut self) {
278        self.set_flags(&|mut flags| {
279            flags.insert(StateFlags::CACHE_WINDOW_EXPR);
280            flags
281        });
282    }
283    // this will trigger some conservative
284    pub fn insert_has_window_function_flag(&mut self) {
285        self.set_flags(&|mut flags| {
286            flags.insert(StateFlags::HAS_WINDOW);
287            flags
288        });
289    }
290}
291
292impl Default for ExecutionState {
293    fn default() -> Self {
294        ExecutionState::new()
295    }
296}
297
298impl Clone for ExecutionState {
299    /// clones, but clears no state.
300    fn clone(&self) -> Self {
301        Self {
302            df_cache: self.df_cache.clone(),
303            schema_cache: self.schema_cache.read().unwrap().clone().into(),
304            window_cache: self.window_cache.clone(),
305            branch_idx: self.branch_idx,
306            flags: AtomicU8::new(self.flags.load(Ordering::Relaxed)),
307            ext_contexts: self.ext_contexts.clone(),
308            node_timer: self.node_timer.clone(),
309            stop: self.stop.clone(),
310        }
311    }
312}