// polars_expr/state/execution_state.rs

use std::borrow::Cow;
use std::sync::atomic::{AtomicI64, Ordering};
use std::sync::{Mutex, RwLock};
use std::time::Duration;

use arrow::bitmap::Bitmap;
use bitflags::bitflags;
use polars_core::config::verbose;
use polars_core::prelude::*;
use polars_ops::prelude::ChunkJoinOptIds;
use polars_utils::relaxed_cell::RelaxedCell;
use polars_utils::unique_id::UniqueId;

use super::NodeTimer;
use crate::prelude::AggregationContext;

pub type JoinTuplesCache = Arc<Mutex<PlHashMap<String, ChunkJoinOptIds>>>;

#[derive(Default)]
pub struct WindowCache {
    groups: RwLock<PlHashMap<String, GroupPositions>>,
    join_tuples: RwLock<PlHashMap<String, Arc<ChunkJoinOptIds>>>,
    map_idx: RwLock<PlHashMap<String, Arc<IdxCa>>>,
}

impl WindowCache {
    pub(crate) fn clear(&self) {
        let Self {
            groups,
            join_tuples,
            map_idx,
        } = self;
        groups.write().unwrap().clear();
        join_tuples.write().unwrap().clear();
        map_idx.write().unwrap().clear();
    }

    pub fn get_groups(&self, key: &str) -> Option<GroupPositions> {
        let g = self.groups.read().unwrap();
        g.get(key).cloned()
    }

    pub fn insert_groups(&self, key: String, groups: GroupPositions) {
        let mut g = self.groups.write().unwrap();
        g.insert(key, groups);
    }

    pub fn get_join(&self, key: &str) -> Option<Arc<ChunkJoinOptIds>> {
        let g = self.join_tuples.read().unwrap();
        g.get(key).cloned()
    }

    pub fn insert_join(&self, key: String, join_tuples: Arc<ChunkJoinOptIds>) {
        let mut g = self.join_tuples.write().unwrap();
        g.insert(key, join_tuples);
    }

    pub fn get_map(&self, key: &str) -> Option<Arc<IdxCa>> {
        let g = self.map_idx.read().unwrap();
        g.get(key).cloned()
    }

    pub fn insert_map(&self, key: String, idx: Arc<IdxCa>) {
        let mut g = self.map_idx.write().unwrap();
        g.insert(key, idx);
    }
}
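
// A minimal usage sketch, not part of the implementation above: a window expression is
// expected to derive a string key for its partition, probe the cache first, and insert
// on a miss. `compute_groups` is a hypothetical stand-in for whatever computes the
// group positions.
#[allow(dead_code)]
fn window_cache_sketch(
    cache: &WindowCache,
    key: &str,
    compute_groups: impl FnOnce() -> GroupPositions,
) -> GroupPositions {
    match cache.get_groups(key) {
        Some(groups) => groups,
        None => {
            let groups = compute_groups();
            cache.insert_groups(key.to_string(), groups.clone());
            groups
        },
    }
}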

bitflags! {
    #[repr(transparent)]
    #[derive(Copy, Clone)]
    pub(super) struct StateFlags: u8 {
        /// More verbose logging
        const VERBOSE = 0x01;
        /// Indicates that window expression's [`GroupTuples`] may be cached.
        const CACHE_WINDOW_EXPR = 0x02;
        /// Indicates the expression has a window function
        const HAS_WINDOW = 0x04;
    }
}

impl Default for StateFlags {
    fn default() -> Self {
        StateFlags::CACHE_WINDOW_EXPR
    }
}

impl StateFlags {
    fn init() -> Self {
        let verbose = verbose();
        let mut flags: StateFlags = Default::default();
        if verbose {
            flags |= StateFlags::VERBOSE;
        }
        flags
    }
    fn as_u8(self) -> u8 {
        unsafe { std::mem::transmute(self) }
    }
}

impl From<u8> for StateFlags {
    fn from(value: u8) -> Self {
        unsafe { std::mem::transmute(value) }
    }
}
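
// A small round-trip sketch (illustrative only): `ExecutionState` stores these flags as
// a plain `u8` in a relaxed atomic cell, so they are packed on every write and unpacked
// again on every read.
#[allow(dead_code)]
fn state_flags_roundtrip_sketch() {
    let mut flags = StateFlags::default();
    flags |= StateFlags::HAS_WINDOW;
    let raw: u8 = flags.as_u8();
    let restored: StateFlags = raw.into();
    debug_assert!(restored.contains(StateFlags::CACHE_WINDOW_EXPR | StateFlags::HAS_WINDOW));
}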

struct CachedValue {
    /// The number of times the cache will still be read.
    /// Zero means that there will be no more reads and the cache can be dropped.
    remaining_hits: AtomicI64,
    df: DataFrame,
}

/// State / cache that is maintained during execution of the physical plan.
#[derive(Clone)]
pub struct ExecutionState {
    // Cached by a `.cache` call and kept in memory for the duration of the plan.
    df_cache: Arc<RwLock<PlHashMap<UniqueId, Arc<CachedValue>>>>,
    pub schema_cache: Arc<RwLock<Option<SchemaRef>>>,
    /// Used by window expressions to cache intermediate state.
    pub window_cache: Arc<WindowCache>,
    // Every join/union split gets an increment to distinguish between schema states.
    pub branch_idx: usize,
    pub flags: RelaxedCell<u8>,
    #[cfg(feature = "dtype-struct")]
    pub with_fields: Option<Arc<StructChunked>>,
    #[cfg(feature = "dtype-struct")]
    pub with_fields_ac: Option<Arc<AggregationContext<'static>>>,
    pub ext_contexts: Arc<Vec<DataFrame>>,
    pub element: Arc<Option<(Column, Option<Bitmap>)>>,
    node_timer: Option<NodeTimer>,
    stop: Arc<RelaxedCell<bool>>,
}

impl ExecutionState {
    pub fn new() -> Self {
        Self {
            df_cache: Default::default(),
            schema_cache: Default::default(),
            window_cache: Default::default(),
            branch_idx: 0,
            flags: RelaxedCell::from(StateFlags::init().as_u8()),
            #[cfg(feature = "dtype-struct")]
            with_fields: Default::default(),
            #[cfg(feature = "dtype-struct")]
            with_fields_ac: Default::default(),
            ext_contexts: Default::default(),
            element: Default::default(),
            node_timer: None,
            stop: Arc::new(RelaxedCell::from(false)),
        }
    }

    /// Enable this to measure execution times per node.
    pub fn time_nodes(&mut self, start: std::time::Instant) {
        self.node_timer = Some(NodeTimer::new(start))
    }
    pub fn has_node_timer(&self) -> bool {
        self.node_timer.is_some()
    }

    pub fn finish_timer(self) -> PolarsResult<DataFrame> {
        self.node_timer.unwrap().finish()
    }

    // Timings should be a list of (start, end, name) where start and end are raw
    // durations in nanoseconds since the query start.
    pub fn record_raw_timings(&self, timings: &[(u64, u64, String)]) {
        for &(start, end, ref name) in timings {
            self.node_timer.as_ref().unwrap().store_duration(
                Duration::from_nanos(start),
                Duration::from_nanos(end),
                name.to_string(),
            );
        }
    }

    // Returns an error if the query was interrupted, either by a keyboard interrupt or
    // via the cancel token.
    pub fn should_stop(&self) -> PolarsResult<()> {
        try_raise_keyboard_interrupt();
        polars_ensure!(!self.stop.load(), ComputeError: "query interrupted");
        Ok(())
    }

    pub fn cancel_token(&self) -> Arc<RelaxedCell<bool>> {
        self.stop.clone()
    }
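
    // A minimal cancellation sketch, not part of the actual implementation: a consumer
    // can hand the cancel token to another thread and flip it to interrupt a running
    // query; the executor then errors at its next `should_stop` check point.
    #[allow(dead_code)]
    fn cancellation_sketch(&self) -> PolarsResult<()> {
        let token = self.cancel_token();
        // Elsewhere, e.g. in a signal handler or on another thread:
        token.store(true);
        // The next check point now raises a ComputeError.
        self.should_stop()
    }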

    pub fn record<T, F: FnOnce() -> T>(&self, func: F, name: Cow<'static, str>) -> T {
        match &self.node_timer {
            None => func(),
            Some(timer) => {
                let start = std::time::Instant::now();
                let out = func();
                let end = std::time::Instant::now();

                timer.store(start, end, name.as_ref().to_string());
                out
            },
        }
    }
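
    // A small profiling sketch (illustrative only, assuming the caller drives the
    // timer): enable node timing, wrap work in `record`, then collect the timings
    // into a DataFrame.
    #[allow(dead_code)]
    fn profiling_sketch(mut self) -> PolarsResult<DataFrame> {
        self.time_nodes(std::time::Instant::now());
        let _value = self.record(|| 1 + 1, Cow::Borrowed("toy node"));
        self.finish_timer()
    }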

    /// Partially clones and partially clears state.
    /// This should be used when splitting a node, like a join or union.
    pub fn split(&self) -> Self {
        Self {
            df_cache: self.df_cache.clone(),
            schema_cache: Default::default(),
            window_cache: Default::default(),
            branch_idx: self.branch_idx,
            flags: self.flags.clone(),
            ext_contexts: self.ext_contexts.clone(),
            // Retain input values for `pl.element` in Eval context
            element: self.element.clone(),
            #[cfg(feature = "dtype-struct")]
            with_fields: self.with_fields.clone(),
            #[cfg(feature = "dtype-struct")]
            with_fields_ac: self.with_fields_ac.clone(),
            node_timer: self.node_timer.clone(),
            stop: self.stop.clone(),
        }
    }
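
    // Illustrative sketch of how a binary node (join/union) might use `split`: both
    // inputs keep the shared `.cache` results and cancel token but get fresh window and
    // schema caches, and one branch gets a bumped `branch_idx`. The exact increment
    // scheme here is an assumption.
    #[allow(dead_code)]
    fn split_sketch(&self) -> (Self, Self) {
        let lhs_state = self.split();
        let mut rhs_state = self.split();
        rhs_state.branch_idx += 1;
        (lhs_state, rhs_state)
    }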

    pub fn set_schema(&self, schema: SchemaRef) {
        let mut lock = self.schema_cache.write().unwrap();
        *lock = Some(schema);
    }

    /// Clear the schema. Typically at the end of a projection.
    pub fn clear_schema_cache(&self) {
        let mut lock = self.schema_cache.write().unwrap();
        *lock = None;
    }

    /// Get the schema.
    pub fn get_schema(&self) -> Option<SchemaRef> {
        let lock = self.schema_cache.read().unwrap();
        lock.clone()
    }

    pub fn set_df_cache(&self, id: &UniqueId, df: DataFrame, cache_hits: u32) {
        if self.verbose() {
            eprintln!("CACHE SET: cache id: {id}");
        }

        let value = Arc::new(CachedValue {
            remaining_hits: AtomicI64::new(cache_hits as i64),
            df,
        });

        let prev = self.df_cache.write().unwrap().insert(*id, value);
        assert!(prev.is_none(), "duplicate set cache: {id}");
    }

    pub fn get_df_cache(&self, id: &UniqueId) -> DataFrame {
        let guard = self.df_cache.read().unwrap();
        let value = guard.get(id).expect("cache not prefilled");
        // `fetch_sub` returns the previous value, i.e. the number of reads that were
        // still outstanding before this one.
        let remaining = value.remaining_hits.fetch_sub(1, Ordering::Relaxed);
        if remaining < 0 {
            panic!("cache used more times than expected: {id}");
        }
        if self.verbose() {
            eprintln!("CACHE HIT: cache id: {id}");
        }
        if remaining == 1 {
            // This was the last expected read: remove the entry and hand the DataFrame
            // out by value instead of cloning it.
            drop(guard);
            let value = self.df_cache.write().unwrap().remove(id).unwrap();
            if self.verbose() {
                eprintln!("CACHE DROP: cache id: {id}");
            }
            Arc::into_inner(value).unwrap().df
        } else {
            value.df.clone()
        }
    }
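
    // Sketch of the intended `.cache` lifecycle (illustrative only): the producer stores
    // the DataFrame once together with the number of expected reads, each consumer takes
    // one hit, and the final read removes the entry and returns the DataFrame by value.
    #[allow(dead_code)]
    fn df_cache_sketch(&self, id: UniqueId, df: DataFrame) -> DataFrame {
        self.set_df_cache(&id, df, 2);
        let _first = self.get_df_cache(&id); // clones; one expected read remains
        self.get_df_cache(&id) // last read: the entry is dropped and moved out
    }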

    /// Clear the cache used by window expressions.
    pub fn clear_window_expr_cache(&self) {
        self.window_cache.clear();
    }

    fn set_flags(&self, f: &dyn Fn(StateFlags) -> StateFlags) {
        let flags: StateFlags = self.flags.load().into();
        let flags = f(flags);
        self.flags.store(flags.as_u8());
    }

    /// Indicates that window expression's [`GroupTuples`] may be cached.
    pub fn cache_window(&self) -> bool {
        let flags: StateFlags = self.flags.load().into();
        flags.contains(StateFlags::CACHE_WINDOW_EXPR)
    }

    /// Indicates the expression has a window function.
    pub fn has_window(&self) -> bool {
        let flags: StateFlags = self.flags.load().into();
        flags.contains(StateFlags::HAS_WINDOW)
    }

    /// More verbose logging
    pub fn verbose(&self) -> bool {
        let flags: StateFlags = self.flags.load().into();
        flags.contains(StateFlags::VERBOSE)
    }

    pub fn remove_cache_window_flag(&mut self) {
        self.set_flags(&|mut flags| {
            flags.remove(StateFlags::CACHE_WINDOW_EXPR);
            flags
        });
    }

    pub fn insert_cache_window_flag(&mut self) {
        self.set_flags(&|mut flags| {
            flags.insert(StateFlags::CACHE_WINDOW_EXPR);
            flags
        });
    }

    // Setting this flag triggers some conservative behavior downstream.
    pub fn insert_has_window_function_flag(&mut self) {
        self.set_flags(&|mut flags| {
            flags.insert(StateFlags::HAS_WINDOW);
            flags
        });
    }
}

impl Default for ExecutionState {
    fn default() -> Self {
        ExecutionState::new()
    }
}