polars_expr/state/
execution_state.rs1use std::borrow::Cow;
2use std::sync::atomic::{AtomicI64, Ordering};
3use std::sync::{Mutex, RwLock};
4use std::time::Duration;
5
6use bitflags::bitflags;
7use polars_core::config::verbose;
8use polars_core::prelude::*;
9use polars_ops::prelude::ChunkJoinOptIds;
10use polars_utils::relaxed_cell::RelaxedCell;
11use polars_utils::unique_id::UniqueId;
12
13use super::NodeTimer;
14
/// Thread-shared cache mapping string keys to join indices (`ChunkJoinOptIds`).
pub type JoinTuplesCache = Arc<Mutex<PlHashMap<String, ChunkJoinOptIds>>>;
16
/// Caches used by window expressions, keyed by string.
///
/// Each map sits behind its own `RwLock` so lookups of different kinds do not
/// contend with each other.
#[derive(Default)]
pub struct WindowCache {
    // Cached group positions per key.
    groups: RwLock<PlHashMap<String, GroupPositions>>,
    // Cached join indices per key; `Arc`-wrapped so clones are cheap.
    join_tuples: RwLock<PlHashMap<String, Arc<ChunkJoinOptIds>>>,
    // Cached index mappings per key; `Arc`-wrapped so clones are cheap.
    map_idx: RwLock<PlHashMap<String, Arc<IdxCa>>>,
}
23
24impl WindowCache {
25 pub(crate) fn clear(&self) {
26 let Self {
27 groups,
28 join_tuples,
29 map_idx,
30 } = self;
31 groups.write().unwrap().clear();
32 join_tuples.write().unwrap().clear();
33 map_idx.write().unwrap().clear();
34 }
35
36 pub fn get_groups(&self, key: &str) -> Option<GroupPositions> {
37 let g = self.groups.read().unwrap();
38 g.get(key).cloned()
39 }
40
41 pub fn insert_groups(&self, key: String, groups: GroupPositions) {
42 let mut g = self.groups.write().unwrap();
43 g.insert(key, groups);
44 }
45
46 pub fn get_join(&self, key: &str) -> Option<Arc<ChunkJoinOptIds>> {
47 let g = self.join_tuples.read().unwrap();
48 g.get(key).cloned()
49 }
50
51 pub fn insert_join(&self, key: String, join_tuples: Arc<ChunkJoinOptIds>) {
52 let mut g = self.join_tuples.write().unwrap();
53 g.insert(key, join_tuples);
54 }
55
56 pub fn get_map(&self, key: &str) -> Option<Arc<IdxCa>> {
57 let g = self.map_idx.read().unwrap();
58 g.get(key).cloned()
59 }
60
61 pub fn insert_map(&self, key: String, idx: Arc<IdxCa>) {
62 let mut g = self.map_idx.write().unwrap();
63 g.insert(key, idx);
64 }
65}
66
bitflags! {
    /// Bit flags describing per-query execution state, stored as a raw `u8`
    /// inside `ExecutionState::flags`.
    #[repr(transparent)]
    #[derive(Copy, Clone)]
    pub(super) struct StateFlags: u8 {
        /// Verbose logging is enabled (mirrors `polars_core::config::verbose`).
        const VERBOSE = 0x01;
        /// Window expression results should be cached (on by default).
        const CACHE_WINDOW_EXPR = 0x02;
        /// A window function was flagged via
        /// `ExecutionState::insert_has_window_function_flag`.
        const HAS_WINDOW = 0x04;
    }
}
79
80impl Default for StateFlags {
81 fn default() -> Self {
82 StateFlags::CACHE_WINDOW_EXPR
83 }
84}
85
86impl StateFlags {
87 fn init() -> Self {
88 let verbose = verbose();
89 let mut flags: StateFlags = Default::default();
90 if verbose {
91 flags |= StateFlags::VERBOSE;
92 }
93 flags
94 }
95 fn as_u8(self) -> u8 {
96 unsafe { std::mem::transmute(self) }
97 }
98}
99
100impl From<u8> for StateFlags {
101 fn from(value: u8) -> Self {
102 unsafe { std::mem::transmute(value) }
103 }
104}
105
/// A cached `DataFrame` together with a countdown of how many cache hits are
/// still expected before the entry can be removed
/// (see `ExecutionState::get_df_cache`).
struct CachedValue {
    // Decremented on every hit; signed so over-use (going below zero) can be
    // detected and turned into a panic.
    remaining_hits: AtomicI64,
    df: DataFrame,
}
112
/// State shared while executing a (physical) query plan.
///
/// Most members are behind `Arc`, so cloning shares them. Note that `flags`
/// is a plain `RelaxedCell<u8>`: cloning copies the current value, so flag
/// changes after a clone are not shared.
#[derive(Clone)]
pub struct ExecutionState {
    // DataFrames cached by unique id, each with a remaining-hit countdown.
    df_cache: Arc<RwLock<PlHashMap<UniqueId, Arc<CachedValue>>>>,
    // Cached schema of the current DataFrame, if known.
    pub schema_cache: Arc<RwLock<Option<SchemaRef>>>,
    // Caches used by window expressions.
    pub window_cache: Arc<WindowCache>,
    // Index of the current plan branch; copied (not shared) by `split`.
    pub branch_idx: usize,
    // Raw `StateFlags` bits; see `StateFlags` for the meaning of each bit.
    pub flags: RelaxedCell<u8>,
    // Extra DataFrames available as context during execution.
    pub ext_contexts: Arc<Vec<DataFrame>>,
    // Optional per-node timer; `Some` only when profiling is enabled.
    node_timer: Option<NodeTimer>,
    // Cooperative cancellation flag, polled via `should_stop`.
    stop: Arc<RelaxedCell<bool>>,
}
128
129impl ExecutionState {
130 pub fn new() -> Self {
131 let mut flags: StateFlags = Default::default();
132 if verbose() {
133 flags |= StateFlags::VERBOSE;
134 }
135 Self {
136 df_cache: Default::default(),
137 schema_cache: Default::default(),
138 window_cache: Default::default(),
139 branch_idx: 0,
140 flags: RelaxedCell::from(StateFlags::init().as_u8()),
141 ext_contexts: Default::default(),
142 node_timer: None,
143 stop: Arc::new(RelaxedCell::from(false)),
144 }
145 }
146
147 pub fn time_nodes(&mut self, start: std::time::Instant) {
149 self.node_timer = Some(NodeTimer::new(start))
150 }
151 pub fn has_node_timer(&self) -> bool {
152 self.node_timer.is_some()
153 }
154
155 pub fn finish_timer(self) -> PolarsResult<DataFrame> {
156 self.node_timer.unwrap().finish()
157 }
158
159 pub fn record_raw_timings(&self, timings: &[(u64, u64, String)]) {
162 for &(start, end, ref name) in timings {
163 self.node_timer.as_ref().unwrap().store_duration(
164 Duration::from_nanos(start),
165 Duration::from_nanos(end),
166 name.to_string(),
167 );
168 }
169 }
170
171 pub fn should_stop(&self) -> PolarsResult<()> {
173 try_raise_keyboard_interrupt();
174 polars_ensure!(!self.stop.load(), ComputeError: "query interrupted");
175 Ok(())
176 }
177
178 pub fn cancel_token(&self) -> Arc<RelaxedCell<bool>> {
179 self.stop.clone()
180 }
181
182 pub fn record<T, F: FnOnce() -> T>(&self, func: F, name: Cow<'static, str>) -> T {
183 match &self.node_timer {
184 None => func(),
185 Some(timer) => {
186 let start = std::time::Instant::now();
187 let out = func();
188 let end = std::time::Instant::now();
189
190 timer.store(start, end, name.as_ref().to_string());
191 out
192 },
193 }
194 }
195
196 pub fn split(&self) -> Self {
199 Self {
200 df_cache: self.df_cache.clone(),
201 schema_cache: Default::default(),
202 window_cache: Default::default(),
203 branch_idx: self.branch_idx,
204 flags: self.flags.clone(),
205 ext_contexts: self.ext_contexts.clone(),
206 node_timer: self.node_timer.clone(),
207 stop: self.stop.clone(),
208 }
209 }
210
211 pub fn set_schema(&self, schema: SchemaRef) {
212 let mut lock = self.schema_cache.write().unwrap();
213 *lock = Some(schema);
214 }
215
216 pub fn clear_schema_cache(&self) {
218 let mut lock = self.schema_cache.write().unwrap();
219 *lock = None;
220 }
221
222 pub fn get_schema(&self) -> Option<SchemaRef> {
224 let lock = self.schema_cache.read().unwrap();
225 lock.clone()
226 }
227
228 pub fn set_df_cache(&self, id: &UniqueId, df: DataFrame, cache_hits: u32) {
229 if self.verbose() {
230 eprintln!("CACHE SET: cache id: {id}");
231 }
232
233 let value = Arc::new(CachedValue {
234 remaining_hits: AtomicI64::new(cache_hits as i64),
235 df,
236 });
237
238 let prev = self.df_cache.write().unwrap().insert(*id, value);
239 assert!(prev.is_none(), "duplicate set cache: {id}");
240 }
241
242 pub fn get_df_cache(&self, id: &UniqueId) -> DataFrame {
243 let guard = self.df_cache.read().unwrap();
244 let value = guard.get(id).expect("cache not prefilled");
245 let remaining = value.remaining_hits.fetch_sub(1, Ordering::Relaxed);
246 if remaining < 0 {
247 panic!("cache used more times than expected: {id}");
248 }
249 if self.verbose() {
250 eprintln!("CACHE HIT: cache id: {id}");
251 }
252 if remaining == 1 {
253 drop(guard);
254 let value = self.df_cache.write().unwrap().remove(id).unwrap();
255 if self.verbose() {
256 eprintln!("CACHE DROP: cache id: {id}");
257 }
258 Arc::into_inner(value).unwrap().df
259 } else {
260 value.df.clone()
261 }
262 }
263
264 pub fn clear_window_expr_cache(&self) {
266 self.window_cache.clear();
267 }
268
269 fn set_flags(&self, f: &dyn Fn(StateFlags) -> StateFlags) {
270 let flags: StateFlags = self.flags.load().into();
271 let flags = f(flags);
272 self.flags.store(flags.as_u8());
273 }
274
275 pub fn cache_window(&self) -> bool {
277 let flags: StateFlags = self.flags.load().into();
278 flags.contains(StateFlags::CACHE_WINDOW_EXPR)
279 }
280
281 pub fn has_window(&self) -> bool {
283 let flags: StateFlags = self.flags.load().into();
284 flags.contains(StateFlags::HAS_WINDOW)
285 }
286
287 pub fn verbose(&self) -> bool {
289 let flags: StateFlags = self.flags.load().into();
290 flags.contains(StateFlags::VERBOSE)
291 }
292
293 pub fn remove_cache_window_flag(&mut self) {
294 self.set_flags(&|mut flags| {
295 flags.remove(StateFlags::CACHE_WINDOW_EXPR);
296 flags
297 });
298 }
299
300 pub fn insert_cache_window_flag(&mut self) {
301 self.set_flags(&|mut flags| {
302 flags.insert(StateFlags::CACHE_WINDOW_EXPR);
303 flags
304 });
305 }
306 pub fn insert_has_window_function_flag(&mut self) {
308 self.set_flags(&|mut flags| {
309 flags.insert(StateFlags::HAS_WINDOW);
310 flags
311 });
312 }
313}
314
315impl Default for ExecutionState {
316 fn default() -> Self {
317 ExecutionState::new()
318 }
319}