use std::collections::HashMap;
use std::net::SocketAddr;
use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use super::store::SessionStore;
/// One temporary table as stored in a `TempTableRegistry`: its Arrow schema,
/// its commit-time action, and the record batches currently held for it.
pub struct TempTableEntry {
/// Arrow schema of the table.
/// NOTE(review): presumably every batch in `batches` conforms to this schema;
/// nothing in this type enforces it — confirm at the insertion sites.
pub schema: SchemaRef,
/// Action associated with transaction commit for this table.
pub on_commit: OnCommitAction,
/// Record batches held for this table.
pub batches: Vec<RecordBatch>,
}
impl std::fmt::Debug for TempTableEntry {
    /// Formats the entry for diagnostics.
    ///
    /// The `batches` field is rendered as the number of batches rather than
    /// their contents, so debug output stays short regardless of table size.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to the struct forces this
        // impl to be revisited.
        let Self {
            schema,
            on_commit,
            batches,
        } = self;
        let batch_count = batches.len();

        let mut out = f.debug_struct("TempTableEntry");
        out.field("schema", schema);
        out.field("on_commit", on_commit);
        out.field("batches", &batch_count);
        out.finish()
    }
}
/// Commit-time action attached to a temporary table.
///
/// NOTE(review): the variant names appear to mirror SQL's
/// `ON COMMIT { PRESERVE ROWS | DELETE ROWS | DROP }` clause — confirm against
/// the code that parses `CREATE TEMPORARY TABLE`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OnCommitAction {
/// Presumably: the table and its rows are kept across a commit.
PreserveRows,
/// Presumably: the table is kept but its rows are discarded at commit.
DeleteRows,
/// The table itself is removed at commit; `TempTableRegistry::on_commit`
/// removes every entry carrying this action and reports its name.
Drop,
}
/// Name-keyed collection of temporary tables.
///
/// Names are used verbatim as map keys, so lookups are exact-match
/// (case-sensitive); any identifier normalization has to happen in the caller.
pub struct TempTableRegistry {
// Table name -> entry.
tables: HashMap<String, TempTableEntry>,
}
impl TempTableRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            tables: HashMap::new(),
        }
    }

    /// Stores `entry` under `name`.
    ///
    /// If a table with the same name is already registered it is replaced and
    /// the previous entry is dropped.
    pub fn register(&mut self, name: String, entry: TempTableEntry) {
        self.tables.insert(name, entry);
    }

    /// Returns `true` if a table is registered under exactly `name`.
    pub fn exists(&self, name: &str) -> bool {
        self.tables.contains_key(name)
    }

    /// Returns a shared reference to the table registered under `name`, if any.
    pub fn get(&self, name: &str) -> Option<&TempTableEntry> {
        self.tables.get(name)
    }

    /// Returns a mutable reference to the table registered under `name`, if any.
    pub fn get_mut(&mut self, name: &str) -> Option<&mut TempTableEntry> {
        self.tables.get_mut(name)
    }

    /// Removes the table registered under `name`.
    ///
    /// Returns `true` if a table was removed, `false` if none was registered.
    pub fn remove(&mut self, name: &str) -> bool {
        self.tables.remove(name).is_some()
    }

    /// Returns the names of all registered tables, in unspecified order
    /// (hash-map iteration order).
    pub fn names(&self) -> Vec<String> {
        self.tables.keys().cloned().collect()
    }

    /// Applies every table's [`OnCommitAction`] at transaction commit.
    ///
    /// * [`OnCommitAction::PreserveRows`] — the table is left untouched.
    /// * [`OnCommitAction::DeleteRows`] — the table stays registered but its
    ///   batches are discarded.
    /// * [`OnCommitAction::Drop`] — the table is removed from the registry.
    ///
    /// Returns the names of the tables that were dropped, in unspecified
    /// order, so the caller can clean up anything else tied to those names.
    pub fn on_commit(&mut self) -> Vec<String> {
        let mut dropped = Vec::new();
        // Single pass: `retain` lets us mutate surviving entries and remove
        // dropped ones without collecting names first. The match is
        // exhaustive on purpose so a new action cannot be silently ignored.
        self.tables.retain(|name, entry| match entry.on_commit {
            OnCommitAction::PreserveRows => true,
            OnCommitAction::DeleteRows => {
                entry.batches.clear();
                true
            }
            OnCommitAction::Drop => {
                dropped.push(name.clone());
                false
            }
        });
        dropped
    }

    /// Removes every table from the registry.
    pub fn clear(&mut self) {
        self.tables.clear();
    }
}
impl Default for TempTableRegistry {
fn default() -> Self {
Self::new()
}
}
impl SessionStore {
    /// Registers `entry` as temporary table `name` in the session for `addr`.
    ///
    /// The result of `write_session` is discarded, so the caller is not told
    /// whether the closure actually ran (presumably it does not when no
    /// session exists for `addr` — confirm against `write_session`).
    pub fn register_temp_table(&self, addr: &SocketAddr, name: String, entry: TempTableEntry) {
        self.write_session(addr, |sess| {
            sess.temp_tables.register(name, entry);
        });
    }

    /// Returns `true` if the session for `addr` has a temporary table named
    /// `name`; `false` when it does not or when `read_session` yields no value.
    pub fn has_temp_table(&self, addr: &SocketAddr, name: &str) -> bool {
        let found = self.read_session(addr, |sess| sess.temp_tables.exists(name));
        found.unwrap_or(false)
    }

    /// Removes temporary table `name` from the session for `addr`.
    ///
    /// Returns `true` only if a table was actually removed; `false` when the
    /// table was not registered or when `write_session` yields no value.
    pub fn remove_temp_table(&self, addr: &SocketAddr, name: &str) -> bool {
        let removed = self.write_session(addr, |sess| sess.temp_tables.remove(name));
        removed.unwrap_or(false)
    }

    /// Returns the names of all temporary tables in the session for `addr`,
    /// or an empty list when `read_session` yields no value.
    pub fn temp_table_names(&self, addr: &SocketAddr) -> Vec<String> {
        let names = self.read_session(addr, |sess| sess.temp_tables.names());
        names.unwrap_or_default()
    }
}