// polars_expr/state/execution_state.rs

1use std::borrow::Cow;
2use std::sync::atomic::{AtomicBool, AtomicI64, AtomicU8, Ordering};
3use std::sync::{Mutex, RwLock};
4
5use bitflags::bitflags;
6use once_cell::sync::OnceCell;
7use polars_core::config::verbose;
8use polars_core::error::check_signals;
9use polars_core::prelude::*;
10use polars_ops::prelude::ChunkJoinOptIds;
11
12use super::NodeTimer;
13
/// Shared, thread-safe map from a string cache key to the join indices of a
/// previously executed join, so identical joins can reuse their result ids.
pub type JoinTuplesCache = Arc<Mutex<PlHashMap<String, ChunkJoinOptIds>>>;
15
/// Cache of intermediate results produced while evaluating window
/// expressions, shared between expressions via [`ExecutionState`].
/// All maps are keyed by a string cache key computed by the caller.
#[derive(Default)]
pub struct WindowCache {
    // Cached group positions per window/group-by key.
    groups: RwLock<PlHashMap<String, GroupPositions>>,
    // Cached join indices used to map aggregated results back to rows.
    join_tuples: RwLock<PlHashMap<String, Arc<ChunkJoinOptIds>>>,
    // Cached index arrays; presumably row-order mappings — confirm with callers.
    map_idx: RwLock<PlHashMap<String, Arc<IdxCa>>>,
}
22
23impl WindowCache {
24    pub(crate) fn clear(&self) {
25        let mut g = self.groups.write().unwrap();
26        g.clear();
27        let mut g = self.join_tuples.write().unwrap();
28        g.clear();
29    }
30
31    pub fn get_groups(&self, key: &str) -> Option<GroupPositions> {
32        let g = self.groups.read().unwrap();
33        g.get(key).cloned()
34    }
35
36    pub fn insert_groups(&self, key: String, groups: GroupPositions) {
37        let mut g = self.groups.write().unwrap();
38        g.insert(key, groups);
39    }
40
41    pub fn get_join(&self, key: &str) -> Option<Arc<ChunkJoinOptIds>> {
42        let g = self.join_tuples.read().unwrap();
43        g.get(key).cloned()
44    }
45
46    pub fn insert_join(&self, key: String, join_tuples: Arc<ChunkJoinOptIds>) {
47        let mut g = self.join_tuples.write().unwrap();
48        g.insert(key, join_tuples);
49    }
50
51    pub fn get_map(&self, key: &str) -> Option<Arc<IdxCa>> {
52        let g = self.map_idx.read().unwrap();
53        g.get(key).cloned()
54    }
55
56    pub fn insert_map(&self, key: String, idx: Arc<IdxCa>) {
57        let mut g = self.map_idx.write().unwrap();
58        g.insert(key, idx);
59    }
60}
61
bitflags! {
    /// Per-query execution settings, stored packed in an `AtomicU8`
    /// on [`ExecutionState`].
    #[repr(transparent)]
    #[derive(Copy, Clone)]
    pub(super) struct StateFlags: u8 {
        /// More verbose logging
        const VERBOSE = 0x01;
        /// Indicates that window expression's [`GroupTuples`] may be cached.
        const CACHE_WINDOW_EXPR = 0x02;
        /// Indicates the expression has a window function
        const HAS_WINDOW = 0x04;
        /// If set, the expression is evaluated in the
        /// streaming engine.
        const IN_STREAMING = 0x08;
    }
}
77
78impl Default for StateFlags {
79    fn default() -> Self {
80        StateFlags::CACHE_WINDOW_EXPR
81    }
82}
83
84impl StateFlags {
85    fn init() -> Self {
86        let verbose = verbose();
87        let mut flags: StateFlags = Default::default();
88        if verbose {
89            flags |= StateFlags::VERBOSE;
90        }
91        flags
92    }
93    fn as_u8(self) -> u8 {
94        unsafe { std::mem::transmute(self) }
95    }
96}
97
98impl From<u8> for StateFlags {
99    fn from(value: u8) -> Self {
100        unsafe { std::mem::transmute(value) }
101    }
102}
103
/// A `.cache` slot: the atomic is seeded with the expected cache-hit count
/// (see `get_df_cache`; presumably decremented by consumers — not visible
/// here), the `OnceCell` holds the `DataFrame` once it is materialized.
type CachedValue = Arc<(AtomicI64, OnceCell<DataFrame>)>;
105
/// State/ cache that is maintained during the Execution of the physical plan.
pub struct ExecutionState {
    // cached by a `.cache` call and kept in memory for the duration of the plan.
    df_cache: Arc<Mutex<PlHashMap<usize, CachedValue>>>,
    // Schema of the current node, set/cleared around projections.
    pub schema_cache: RwLock<Option<SchemaRef>>,
    /// Used by Window Expressions to cache intermediate state
    pub window_cache: Arc<WindowCache>,
    // every join/union split gets an increment to distinguish between schema state
    pub branch_idx: usize,
    // Packed `StateFlags`, see `StateFlags::as_u8` / `From<u8>`.
    pub flags: AtomicU8,
    // Extra DataFrame contexts; semantics defined by callers — not visible here.
    pub ext_contexts: Arc<Vec<DataFrame>>,
    // Present only while node timing is enabled via `time_nodes`.
    node_timer: Option<NodeTimer>,
    // Shared cancellation token; checked by `should_stop`.
    stop: Arc<AtomicBool>,
}
120
121impl ExecutionState {
122    pub fn new() -> Self {
123        let mut flags: StateFlags = Default::default();
124        if verbose() {
125            flags |= StateFlags::VERBOSE;
126        }
127        Self {
128            df_cache: Default::default(),
129            schema_cache: Default::default(),
130            window_cache: Default::default(),
131            branch_idx: 0,
132            flags: AtomicU8::new(StateFlags::init().as_u8()),
133            ext_contexts: Default::default(),
134            node_timer: None,
135            stop: Arc::new(AtomicBool::new(false)),
136        }
137    }
138
139    /// Toggle this to measure execution times.
140    pub fn time_nodes(&mut self) {
141        self.node_timer = Some(NodeTimer::new())
142    }
143    pub fn has_node_timer(&self) -> bool {
144        self.node_timer.is_some()
145    }
146
147    pub fn finish_timer(self) -> PolarsResult<DataFrame> {
148        self.node_timer.unwrap().finish()
149    }
150
151    // This is wrong when the U64 overflows which will never happen.
152    pub fn should_stop(&self) -> PolarsResult<()> {
153        check_signals()?;
154        polars_ensure!(!self.stop.load(Ordering::Relaxed), ComputeError: "query interrupted");
155        Ok(())
156    }
157
158    pub fn cancel_token(&self) -> Arc<AtomicBool> {
159        self.stop.clone()
160    }
161
162    pub fn record<T, F: FnOnce() -> T>(&self, func: F, name: Cow<'static, str>) -> T {
163        match &self.node_timer {
164            None => func(),
165            Some(timer) => {
166                let start = std::time::Instant::now();
167                let out = func();
168                let end = std::time::Instant::now();
169
170                timer.store(start, end, name.as_ref().to_string());
171                out
172            },
173        }
174    }
175
176    /// Partially clones and partially clears state
177    /// This should be used when splitting a node, like a join or union
178    pub fn split(&self) -> Self {
179        Self {
180            df_cache: self.df_cache.clone(),
181            schema_cache: Default::default(),
182            window_cache: Default::default(),
183            branch_idx: self.branch_idx,
184            flags: AtomicU8::new(self.flags.load(Ordering::Relaxed)),
185            ext_contexts: self.ext_contexts.clone(),
186            node_timer: self.node_timer.clone(),
187            stop: self.stop.clone(),
188        }
189    }
190
191    pub fn set_schema(&self, schema: SchemaRef) {
192        let mut lock = self.schema_cache.write().unwrap();
193        *lock = Some(schema);
194    }
195
196    /// Clear the schema. Typically at the end of a projection.
197    pub fn clear_schema_cache(&self) {
198        let mut lock = self.schema_cache.write().unwrap();
199        *lock = None;
200    }
201
202    /// Get the schema.
203    pub fn get_schema(&self) -> Option<SchemaRef> {
204        let lock = self.schema_cache.read().unwrap();
205        lock.clone()
206    }
207
208    pub fn get_df_cache(&self, key: usize, cache_hits: u32) -> CachedValue {
209        let mut guard = self.df_cache.lock().unwrap();
210        guard
211            .entry(key)
212            .or_insert_with(|| Arc::new((AtomicI64::new(cache_hits as i64), OnceCell::new())))
213            .clone()
214    }
215
216    pub fn remove_df_cache(&self, key: usize) {
217        let mut guard = self.df_cache.lock().unwrap();
218        let _ = guard.remove(&key).unwrap();
219    }
220
221    /// Clear the cache used by the Window expressions
222    pub fn clear_window_expr_cache(&self) {
223        self.window_cache.clear();
224    }
225
226    fn set_flags(&self, f: &dyn Fn(StateFlags) -> StateFlags) {
227        let flags: StateFlags = self.flags.load(Ordering::Relaxed).into();
228        let flags = f(flags);
229        self.flags.store(flags.as_u8(), Ordering::Relaxed);
230    }
231
232    /// Indicates that window expression's [`GroupTuples`] may be cached.
233    pub fn cache_window(&self) -> bool {
234        let flags: StateFlags = self.flags.load(Ordering::Relaxed).into();
235        flags.contains(StateFlags::CACHE_WINDOW_EXPR)
236    }
237
238    /// Indicates that window expression's [`GroupTuples`] may be cached.
239    pub fn has_window(&self) -> bool {
240        let flags: StateFlags = self.flags.load(Ordering::Relaxed).into();
241        flags.contains(StateFlags::HAS_WINDOW)
242    }
243
244    /// More verbose logging
245    pub fn verbose(&self) -> bool {
246        let flags: StateFlags = self.flags.load(Ordering::Relaxed).into();
247        flags.contains(StateFlags::VERBOSE)
248    }
249
250    pub fn remove_cache_window_flag(&mut self) {
251        self.set_flags(&|mut flags| {
252            flags.remove(StateFlags::CACHE_WINDOW_EXPR);
253            flags
254        });
255    }
256
257    pub fn insert_cache_window_flag(&mut self) {
258        self.set_flags(&|mut flags| {
259            flags.insert(StateFlags::CACHE_WINDOW_EXPR);
260            flags
261        });
262    }
263    // this will trigger some conservative
264    pub fn insert_has_window_function_flag(&mut self) {
265        self.set_flags(&|mut flags| {
266            flags.insert(StateFlags::HAS_WINDOW);
267            flags
268        });
269    }
270}
271
272impl Default for ExecutionState {
273    fn default() -> Self {
274        ExecutionState::new()
275    }
276}
277
278impl Clone for ExecutionState {
279    /// clones, but clears no state.
280    fn clone(&self) -> Self {
281        Self {
282            df_cache: self.df_cache.clone(),
283            schema_cache: self.schema_cache.read().unwrap().clone().into(),
284            window_cache: self.window_cache.clone(),
285            branch_idx: self.branch_idx,
286            flags: AtomicU8::new(self.flags.load(Ordering::Relaxed)),
287            ext_contexts: self.ext_contexts.clone(),
288            node_timer: self.node_timer.clone(),
289            stop: self.stop.clone(),
290        }
291    }
292}