polars_expr/state/
execution_state.rs1use std::borrow::Cow;
2use std::sync::atomic::{AtomicI64, Ordering};
3use std::sync::{Mutex, RwLock};
4use std::time::Duration;
5
6use bitflags::bitflags;
7use polars_core::config::verbose;
8use polars_core::prelude::*;
9use polars_ops::prelude::ChunkJoinOptIds;
10use polars_utils::relaxed_cell::RelaxedCell;
11use polars_utils::unique_id::UniqueId;
12
13use super::NodeTimer;
14
/// Reference-counted, mutex-guarded map from a string key to cached join tuples
/// (`ChunkJoinOptIds`).
pub type JoinTuplesCache = Arc<Mutex<PlHashMap<String, ChunkJoinOptIds>>>;
16
/// Three independent string-keyed caches, each behind its own `RwLock` so that
/// readers of one map do not block users of another.
///
/// An `ExecutionState` owns one of these through its `window_cache` field.
#[derive(Default)]
pub struct WindowCache {
    // Cached group positions, looked up by string key.
    groups: RwLock<PlHashMap<String, GroupPositions>>,
    // Cached join tuples; stored behind `Arc` so a lookup only bumps a refcount.
    join_tuples: RwLock<PlHashMap<String, Arc<ChunkJoinOptIds>>>,
    // Cached index arrays (`IdxCa`); also stored behind `Arc`.
    map_idx: RwLock<PlHashMap<String, Arc<IdxCa>>>,
}
23
24impl WindowCache {
25 pub(crate) fn clear(&self) {
26 let mut g = self.groups.write().unwrap();
27 g.clear();
28 let mut g = self.join_tuples.write().unwrap();
29 g.clear();
30 }
31
32 pub fn get_groups(&self, key: &str) -> Option<GroupPositions> {
33 let g = self.groups.read().unwrap();
34 g.get(key).cloned()
35 }
36
37 pub fn insert_groups(&self, key: String, groups: GroupPositions) {
38 let mut g = self.groups.write().unwrap();
39 g.insert(key, groups);
40 }
41
42 pub fn get_join(&self, key: &str) -> Option<Arc<ChunkJoinOptIds>> {
43 let g = self.join_tuples.read().unwrap();
44 g.get(key).cloned()
45 }
46
47 pub fn insert_join(&self, key: String, join_tuples: Arc<ChunkJoinOptIds>) {
48 let mut g = self.join_tuples.write().unwrap();
49 g.insert(key, join_tuples);
50 }
51
52 pub fn get_map(&self, key: &str) -> Option<Arc<IdxCa>> {
53 let g = self.map_idx.read().unwrap();
54 g.get(key).cloned()
55 }
56
57 pub fn insert_map(&self, key: String, idx: Arc<IdxCa>) {
58 let mut g = self.map_idx.write().unwrap();
59 g.insert(key, idx);
60 }
61}
62
bitflags! {
    /// Boolean switches of an `ExecutionState`, packed into a single `u8`.
    #[repr(transparent)]
    #[derive(Copy, Clone)]
    pub(super) struct StateFlags: u8 {
        /// Enables the cache diagnostics that `ExecutionState` prints to stderr.
        const VERBOSE = 0x01;
        /// Reported by `ExecutionState::cache_window`; the only flag set by default.
        const CACHE_WINDOW_EXPR = 0x02;
        /// Reported by `ExecutionState::has_window`.
        const HAS_WINDOW = 0x04;
    }
}
75
76impl Default for StateFlags {
77 fn default() -> Self {
78 StateFlags::CACHE_WINDOW_EXPR
79 }
80}
81
82impl StateFlags {
83 fn init() -> Self {
84 let verbose = verbose();
85 let mut flags: StateFlags = Default::default();
86 if verbose {
87 flags |= StateFlags::VERBOSE;
88 }
89 flags
90 }
91 fn as_u8(self) -> u8 {
92 unsafe { std::mem::transmute(self) }
93 }
94}
95
96impl From<u8> for StateFlags {
97 fn from(value: u8) -> Self {
98 unsafe { std::mem::transmute(value) }
99 }
100}
101
/// A cached `DataFrame` together with a countdown of the lookups it still has to serve.
struct CachedValue {
    // Initialised from `cache_hits` in `set_df_cache` and decremented by one on every
    // `get_df_cache`; the lookup that observes the value 1 removes the entry.
    remaining_hits: AtomicI64,
    // The cached frame itself.
    df: DataFrame,
}
108
/// Execution state: caches, flags, an optional node timer and a cancellation flag.
#[derive(Clone)]
pub struct ExecutionState {
    // Frames stored by `set_df_cache` and consumed by `get_df_cache`, keyed by
    // `UniqueId`. States produced by `split` share this same map.
    df_cache: Arc<RwLock<PlHashMap<UniqueId, Arc<CachedValue>>>>,
    // Optional schema managed through `set_schema` / `get_schema` /
    // `clear_schema_cache`. `split` starts with an empty one.
    pub schema_cache: Arc<RwLock<Option<SchemaRef>>>,
    // Window caches, emptied by `clear_window_expr_cache`. `split` starts with a
    // fresh one.
    pub window_cache: Arc<WindowCache>,
    // Starts at 0 and is copied unchanged by `split`.
    // NOTE(review): the meaning of the index is not visible in this file — confirm
    // against the callers that set it.
    pub branch_idx: usize,
    // `StateFlags` stored as their raw `u8` bit pattern.
    pub flags: RelaxedCell<u8>,
    // Extra frames shared (via `Arc`) with states produced by `split`.
    // NOTE(review): how these are consumed is not visible in this file.
    pub ext_contexts: Arc<Vec<DataFrame>>,
    // `None` until `time_nodes` is called; when present, `record` stores timings in it.
    node_timer: Option<NodeTimer>,
    // Shared flag read by `should_stop` and handed out by `cancel_token`.
    stop: Arc<RelaxedCell<bool>>,
}
124
125impl ExecutionState {
126 pub fn new() -> Self {
127 let mut flags: StateFlags = Default::default();
128 if verbose() {
129 flags |= StateFlags::VERBOSE;
130 }
131 Self {
132 df_cache: Default::default(),
133 schema_cache: Default::default(),
134 window_cache: Default::default(),
135 branch_idx: 0,
136 flags: RelaxedCell::from(StateFlags::init().as_u8()),
137 ext_contexts: Default::default(),
138 node_timer: None,
139 stop: Arc::new(RelaxedCell::from(false)),
140 }
141 }
142
143 pub fn time_nodes(&mut self, start: std::time::Instant) {
145 self.node_timer = Some(NodeTimer::new(start))
146 }
147 pub fn has_node_timer(&self) -> bool {
148 self.node_timer.is_some()
149 }
150
151 pub fn finish_timer(self) -> PolarsResult<DataFrame> {
152 self.node_timer.unwrap().finish()
153 }
154
155 pub fn record_raw_timings(&self, timings: &[(u64, u64, String)]) {
158 for &(start, end, ref name) in timings {
159 self.node_timer.as_ref().unwrap().store_duration(
160 Duration::from_nanos(start),
161 Duration::from_nanos(end),
162 name.to_string(),
163 );
164 }
165 }
166
167 pub fn should_stop(&self) -> PolarsResult<()> {
169 try_raise_keyboard_interrupt();
170 polars_ensure!(!self.stop.load(), ComputeError: "query interrupted");
171 Ok(())
172 }
173
174 pub fn cancel_token(&self) -> Arc<RelaxedCell<bool>> {
175 self.stop.clone()
176 }
177
178 pub fn record<T, F: FnOnce() -> T>(&self, func: F, name: Cow<'static, str>) -> T {
179 match &self.node_timer {
180 None => func(),
181 Some(timer) => {
182 let start = std::time::Instant::now();
183 let out = func();
184 let end = std::time::Instant::now();
185
186 timer.store(start, end, name.as_ref().to_string());
187 out
188 },
189 }
190 }
191
192 pub fn split(&self) -> Self {
195 Self {
196 df_cache: self.df_cache.clone(),
197 schema_cache: Default::default(),
198 window_cache: Default::default(),
199 branch_idx: self.branch_idx,
200 flags: self.flags.clone(),
201 ext_contexts: self.ext_contexts.clone(),
202 node_timer: self.node_timer.clone(),
203 stop: self.stop.clone(),
204 }
205 }
206
207 pub fn set_schema(&self, schema: SchemaRef) {
208 let mut lock = self.schema_cache.write().unwrap();
209 *lock = Some(schema);
210 }
211
212 pub fn clear_schema_cache(&self) {
214 let mut lock = self.schema_cache.write().unwrap();
215 *lock = None;
216 }
217
218 pub fn get_schema(&self) -> Option<SchemaRef> {
220 let lock = self.schema_cache.read().unwrap();
221 lock.clone()
222 }
223
224 pub fn set_df_cache(&self, id: &UniqueId, df: DataFrame, cache_hits: u32) {
225 if self.verbose() {
226 eprintln!("CACHE SET: cache id: {id}");
227 }
228
229 let value = Arc::new(CachedValue {
230 remaining_hits: AtomicI64::new(cache_hits as i64),
231 df,
232 });
233
234 let prev = self.df_cache.write().unwrap().insert(*id, value);
235 assert!(prev.is_none(), "duplicate set cache: {id}");
236 }
237
238 pub fn get_df_cache(&self, id: &UniqueId) -> DataFrame {
239 let guard = self.df_cache.read().unwrap();
240 let value = guard.get(id).expect("cache not prefilled");
241 let remaining = value.remaining_hits.fetch_sub(1, Ordering::Relaxed);
242 if remaining < 0 {
243 panic!("cache used more times than expected: {id}");
244 }
245 if self.verbose() {
246 eprintln!("CACHE HIT: cache id: {id}");
247 }
248 if remaining == 1 {
249 drop(guard);
250 let value = self.df_cache.write().unwrap().remove(id).unwrap();
251 if self.verbose() {
252 eprintln!("CACHE DROP: cache id: {id}");
253 }
254 Arc::into_inner(value).unwrap().df
255 } else {
256 value.df.clone()
257 }
258 }
259
260 pub fn clear_window_expr_cache(&self) {
262 self.window_cache.clear();
263 }
264
265 fn set_flags(&self, f: &dyn Fn(StateFlags) -> StateFlags) {
266 let flags: StateFlags = self.flags.load().into();
267 let flags = f(flags);
268 self.flags.store(flags.as_u8());
269 }
270
271 pub fn cache_window(&self) -> bool {
273 let flags: StateFlags = self.flags.load().into();
274 flags.contains(StateFlags::CACHE_WINDOW_EXPR)
275 }
276
277 pub fn has_window(&self) -> bool {
279 let flags: StateFlags = self.flags.load().into();
280 flags.contains(StateFlags::HAS_WINDOW)
281 }
282
283 pub fn verbose(&self) -> bool {
285 let flags: StateFlags = self.flags.load().into();
286 flags.contains(StateFlags::VERBOSE)
287 }
288
289 pub fn remove_cache_window_flag(&mut self) {
290 self.set_flags(&|mut flags| {
291 flags.remove(StateFlags::CACHE_WINDOW_EXPR);
292 flags
293 });
294 }
295
296 pub fn insert_cache_window_flag(&mut self) {
297 self.set_flags(&|mut flags| {
298 flags.insert(StateFlags::CACHE_WINDOW_EXPR);
299 flags
300 });
301 }
302 pub fn insert_has_window_function_flag(&mut self) {
304 self.set_flags(&|mut flags| {
305 flags.insert(StateFlags::HAS_WINDOW);
306 flags
307 });
308 }
309}
310
311impl Default for ExecutionState {
312 fn default() -> Self {
313 ExecutionState::new()
314 }
315}