use std::borrow::Cow;
use std::sync::atomic::{AtomicBool, AtomicI64, AtomicU8, Ordering};
use std::sync::{Mutex, OnceLock, RwLock};
use std::time::Duration;

use bitflags::bitflags;
use polars_core::config::verbose;
use polars_core::prelude::*;
use polars_ops::prelude::ChunkJoinOptIds;
use polars_utils::unique_id::UniqueId;

use super::NodeTimer;

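/// Maps a string cache key to the (optional) row indices produced by a join,
/// so the same join result can be reused across expressions.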
pub type JoinTuplesCache = Arc<Mutex<PlHashMap<String, ChunkJoinOptIds>>>;

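/// Cache shared by the window expressions of a single projection: group
/// positions, join tuples, and gather ("map") indices, keyed by a string
/// derived from the partition-by columns.
///
/// A sketch of the intended lookup-or-compute pattern (illustrative only;
/// `compute_group_positions` is a hypothetical helper, the real call sites
/// live in the window-expression evaluator):
///
/// ```ignore
/// let groups = cache.get_groups(&key).unwrap_or_else(|| {
///     let g = compute_group_positions();
///     cache.insert_groups(key.clone(), g.clone());
///     g
/// });
/// ```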
#[derive(Default)]
pub struct WindowCache {
    groups: RwLock<PlHashMap<String, GroupPositions>>,
    join_tuples: RwLock<PlHashMap<String, Arc<ChunkJoinOptIds>>>,
    map_idx: RwLock<PlHashMap<String, Arc<IdxCa>>>,
}

impl WindowCache {
    pub(crate) fn clear(&self) {
        // Clear all three maps; `map_idx` must be cleared as well, or stale
        // gather indices would survive into the next evaluation.
        self.groups.write().unwrap().clear();
        self.join_tuples.write().unwrap().clear();
        self.map_idx.write().unwrap().clear();
    }

    pub fn get_groups(&self, key: &str) -> Option<GroupPositions> {
        let g = self.groups.read().unwrap();
        g.get(key).cloned()
    }

    pub fn insert_groups(&self, key: String, groups: GroupPositions) {
        let mut g = self.groups.write().unwrap();
        g.insert(key, groups);
    }

    pub fn get_join(&self, key: &str) -> Option<Arc<ChunkJoinOptIds>> {
        let g = self.join_tuples.read().unwrap();
        g.get(key).cloned()
    }

    pub fn insert_join(&self, key: String, join_tuples: Arc<ChunkJoinOptIds>) {
        let mut g = self.join_tuples.write().unwrap();
        g.insert(key, join_tuples);
    }

    pub fn get_map(&self, key: &str) -> Option<Arc<IdxCa>> {
        let g = self.map_idx.read().unwrap();
        g.get(key).cloned()
    }

    pub fn insert_map(&self, key: String, idx: Arc<IdxCa>) {
        let mut g = self.map_idx.write().unwrap();
        g.insert(key, idx);
    }
}

bitflags! {
    #[repr(transparent)]
    #[derive(Copy, Clone)]
    pub(super) struct StateFlags: u8 {
        /// More verbose logging.
        const VERBOSE = 0x01;
        /// Window expressions may cache their intermediate state.
        const CACHE_WINDOW_EXPR = 0x02;
        /// The expression contains a window function.
        const HAS_WINDOW = 0x04;
    }
}

impl Default for StateFlags {
    fn default() -> Self {
        StateFlags::CACHE_WINDOW_EXPR
    }
}

impl StateFlags {
    fn init() -> Self {
        let mut flags: StateFlags = Default::default();
        if verbose() {
            flags |= StateFlags::VERBOSE;
        }
        flags
    }

    fn as_u8(self) -> u8 {
        // `bitflags` 2.x exposes the raw bits safely; no transmute needed.
        self.bits()
    }
}

impl From<u8> for StateFlags {
    fn from(value: u8) -> Self {
        StateFlags::from_bits_retain(value)
    }
}
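
// `StateFlags` round-trips through the `AtomicU8` on `ExecutionState`; a
// sketch of the load pattern used throughout this file (illustrative only):
//
// ```ignore
// let flags: StateFlags = state.flags.load(Ordering::Relaxed).into();
// if flags.contains(StateFlags::VERBOSE) { /* log more */ }
// ```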

/// A cache slot: the remaining number of expected cache hits (so the entry
/// can be dropped once exhausted) plus the lazily materialized frame.
type CachedValue = Arc<(AtomicI64, OnceLock<DataFrame>)>;

/// State and caches maintained while executing the physical plan.
pub struct ExecutionState {
    // Results kept in memory by cache nodes for the duration of the plan.
    df_cache: Arc<RwLock<PlHashMap<UniqueId, CachedValue>>>,
    pub schema_cache: RwLock<Option<SchemaRef>>,
    /// Used by window expressions to cache intermediate state.
    pub window_cache: Arc<WindowCache>,
    // Distinguishes branches created when the plan splits (e.g. joins, unions).
    pub branch_idx: usize,
    // Packed `StateFlags`.
    pub flags: AtomicU8,
    pub ext_contexts: Arc<Vec<DataFrame>>,
    node_timer: Option<NodeTimer>,
    stop: Arc<AtomicBool>,
}

impl ExecutionState {
    pub fn new() -> Self {
        Self {
            df_cache: Default::default(),
            schema_cache: Default::default(),
            window_cache: Default::default(),
            branch_idx: 0,
            // `StateFlags::init()` already folds in the `verbose()` check.
            flags: AtomicU8::new(StateFlags::init().as_u8()),
            ext_contexts: Default::default(),
            node_timer: None,
            stop: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Enable per-node timing; results are collected with `finish_timer`.
    pub fn time_nodes(&mut self, start: std::time::Instant) {
        self.node_timer = Some(NodeTimer::new(start));
    }

    pub fn has_node_timer(&self) -> bool {
        self.node_timer.is_some()
    }

    pub fn finish_timer(self) -> PolarsResult<DataFrame> {
        self.node_timer.unwrap().finish()
    }

    /// Record externally measured timings as nanosecond (start, end, name)
    /// triples. Panics if node timing was not enabled.
    pub fn record_raw_timings(&self, timings: &[(u64, u64, String)]) {
        for &(start, end, ref name) in timings {
            self.node_timer.as_ref().unwrap().store_duration(
                Duration::from_nanos(start),
                Duration::from_nanos(end),
                name.to_string(),
            );
        }
    }

    /// Check for interruption; executors should call this periodically so
    /// long-running queries stay cancellable.
    pub fn should_stop(&self) -> PolarsResult<()> {
        try_raise_keyboard_interrupt();
        polars_ensure!(!self.stop.load(Ordering::Relaxed), ComputeError: "query interrupted");
        Ok(())
    }

    /// Token that can be set from another thread to cancel this query.
    pub fn cancel_token(&self) -> Arc<AtomicBool> {
        self.stop.clone()
    }
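
    // A sketch of external cancellation (illustrative only): hand the token
    // to another thread, which flips it to stop the running query.
    //
    // ```ignore
    // let token = state.cancel_token();
    // std::thread::spawn(move || token.store(true, Ordering::Relaxed));
    // // subsequent `state.should_stop()` calls now return an error
    // ```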

    /// Run `func`, timing it under `name` when node timing is enabled.
    pub fn record<T, F: FnOnce() -> T>(&self, func: F, name: Cow<'static, str>) -> T {
        match &self.node_timer {
            None => func(),
            Some(timer) => {
                let start = std::time::Instant::now();
                let out = func();
                let end = std::time::Instant::now();

                timer.store(start, end, name.as_ref().to_string());
                out
            },
        }
    }
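
    // Illustrative usage from an executor (`run_filter` is hypothetical):
    //
    // ```ignore
    // let df = state.record(|| run_filter(df), Cow::Borrowed("filter"));
    // ```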

    /// Partial clone for plan splits (joins, unions): the df-cache, external
    /// contexts, and stop token stay shared, while the schema and window
    /// caches start fresh for the new branch.
    pub fn split(&self) -> Self {
        Self {
            df_cache: self.df_cache.clone(),
            schema_cache: Default::default(),
            window_cache: Default::default(),
            branch_idx: self.branch_idx,
            flags: AtomicU8::new(self.flags.load(Ordering::Relaxed)),
            ext_contexts: self.ext_contexts.clone(),
            node_timer: self.node_timer.clone(),
            stop: self.stop.clone(),
        }
    }

    /// Set the schema cache, typically at the start of a projection.
    pub fn set_schema(&self, schema: SchemaRef) {
        let mut lock = self.schema_cache.write().unwrap();
        *lock = Some(schema);
    }

    /// Clear the schema cache, typically at the end of a projection.
    pub fn clear_schema_cache(&self) {
        let mut lock = self.schema_cache.write().unwrap();
        *lock = None;
    }

    pub fn get_schema(&self) -> Option<SchemaRef> {
        let lock = self.schema_cache.read().unwrap();
        lock.clone()
    }

    /// Fetch (or lazily create) the cache slot for a cache node. The common
    /// hit path only takes a read lock; the lock is upgraded to a write lock
    /// when the slot must be created first.
    pub fn get_df_cache(&self, key: &UniqueId, cache_hits: u32) -> CachedValue {
        let guard = self.df_cache.read().unwrap();

        match guard.get(key) {
            Some(v) => v.clone(),
            None => {
                drop(guard);
                let mut guard = self.df_cache.write().unwrap();

                guard
                    .entry(key.clone())
                    .or_insert_with(|| {
                        Arc::new((AtomicI64::new(cache_hits as i64), OnceLock::new()))
                    })
                    .clone()
            },
        }
    }

    /// Drop a cache slot; panics if the key is not present.
    pub fn remove_df_cache(&self, key: &UniqueId) {
        let mut guard = self.df_cache.write().unwrap();
        let _ = guard.remove(key).unwrap();
    }
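
    // Illustrative flow for a cache node (names hypothetical): materialize
    // the frame once, then drop the slot once the hit counter is exhausted.
    //
    // ```ignore
    // let slot = state.get_df_cache(&id, expected_hits);
    // let df = slot.1.get_or_init(|| compute_input()).clone();
    // if slot.0.fetch_sub(1, Ordering::Relaxed) == 0 {
    //     state.remove_df_cache(&id);
    // }
    // ```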

    /// Clear the caches populated by window expressions.
    pub fn clear_window_expr_cache(&self) {
        self.window_cache.clear();
    }

    // Not a single atomic read-modify-write, but the flag-mutating methods
    // take `&mut self`, so there is no concurrent writer.
    fn set_flags(&self, f: &dyn Fn(StateFlags) -> StateFlags) {
        let flags: StateFlags = self.flags.load(Ordering::Relaxed).into();
        let flags = f(flags);
        self.flags.store(flags.as_u8(), Ordering::Relaxed);
    }

    /// Whether window expressions are allowed to cache their state.
    pub fn cache_window(&self) -> bool {
        let flags: StateFlags = self.flags.load(Ordering::Relaxed).into();
        flags.contains(StateFlags::CACHE_WINDOW_EXPR)
    }

    /// Whether the plan contains a window function.
    pub fn has_window(&self) -> bool {
        let flags: StateFlags = self.flags.load(Ordering::Relaxed).into();
        flags.contains(StateFlags::HAS_WINDOW)
    }

    /// Whether verbose logging is enabled.
    pub fn verbose(&self) -> bool {
        let flags: StateFlags = self.flags.load(Ordering::Relaxed).into();
        flags.contains(StateFlags::VERBOSE)
    }

    pub fn remove_cache_window_flag(&mut self) {
        self.set_flags(&|mut flags| {
            flags.remove(StateFlags::CACHE_WINDOW_EXPR);
            flags
        });
    }

    pub fn insert_cache_window_flag(&mut self) {
        self.set_flags(&|mut flags| {
            flags.insert(StateFlags::CACHE_WINDOW_EXPR);
            flags
        });
    }

    pub fn insert_has_window_function_flag(&mut self) {
        self.set_flags(&|mut flags| {
            flags.insert(StateFlags::HAS_WINDOW);
            flags
        });
    }
}

impl Default for ExecutionState {
    fn default() -> Self {
        ExecutionState::new()
    }
}

impl Clone for ExecutionState {
    // Unlike `split`, a full clone keeps the window cache shared and
    // snapshots the current schema cache into the new state.
    fn clone(&self) -> Self {
        Self {
            df_cache: self.df_cache.clone(),
            schema_cache: self.schema_cache.read().unwrap().clone().into(),
            window_cache: self.window_cache.clone(),
            branch_idx: self.branch_idx,
            flags: AtomicU8::new(self.flags.load(Ordering::Relaxed)),
            ext_contexts: self.ext_contexts.clone(),
            node_timer: self.node_timer.clone(),
            stop: self.stop.clone(),
        }
    }
}