use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::RwLock;
use nodedb_types::DatabaseId;
use super::state::{PgSession, TransactionState};
/// Registry of [`PgSession`] state, keyed by the connection's `SocketAddr`.
///
/// All access goes through an `RwLock`. The methods on this type recover the
/// map from a poisoned lock instead of panicking, so a panic while one caller
/// holds the lock does not make the store unusable for other callers.
pub struct SessionStore {
    // One entry per registered address; see `ensure_session` / `remove`.
    sessions: RwLock<HashMap<SocketAddr, PgSession>>,
}
impl Default for SessionStore {
fn default() -> Self {
Self::new()
}
}
impl SessionStore {
    /// Creates a store with no registered sessions.
    pub fn new() -> Self {
        let sessions = RwLock::new(HashMap::new());
        Self { sessions }
    }

    /// Registers a fresh [`PgSession`] for `addr` unless one is already present.
    ///
    /// An existing session for `addr` is left untouched.
    pub fn ensure_session(&self, addr: SocketAddr) {
        let mut map = self
            .sessions
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        map.entry(addr).or_insert_with(PgSession::new);
    }

    /// Forgets the session registered for `addr`; a missing entry is a no-op.
    pub fn remove(&self, addr: &SocketAddr) {
        let mut map = self
            .sessions
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        map.remove(addr);
    }

    /// Snapshots every session as an `(address, transaction state)` pair of strings.
    ///
    /// The state label is one of `"idle"`, `"in_transaction"` or `"failed"`.
    /// Ordering follows `HashMap` iteration and is therefore unspecified.
    pub fn all_sessions(&self) -> Vec<(String, String)> {
        let map = self
            .sessions
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let mut rows = Vec::with_capacity(map.len());
        for (addr, session) in map.iter() {
            let label = match session.tx_state {
                TransactionState::Idle => "idle",
                TransactionState::InBlock => "in_transaction",
                TransactionState::Failed => "failed",
            };
            rows.push((addr.to_string(), label.to_string()));
        }
        rows
    }

    /// Number of sessions currently registered.
    pub fn count(&self) -> usize {
        self.sessions
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .len()
    }

    /// Asks the plan cache of the session at `addr` for a plan matching `sql`.
    ///
    /// `current_version` is handed to the session's plan cache unchanged
    /// (presumably so it can compare recorded descriptor versions against
    /// current ones — confirm in the plan cache implementation). Returns
    /// `None` when `addr` has no session or the cache returns `None`.
    ///
    /// The write lock is used because the lookup goes through `&mut PgSession`.
    pub fn get_cached_plan<F>(
        &self,
        addr: &SocketAddr,
        sql: &str,
        current_version: F,
    ) -> Option<(
        Vec<crate::control::planner::physical::PhysicalTask>,
        crate::control::planner::descriptor_set::DescriptorVersionSet,
    )>
    where
        F: Fn(&nodedb_cluster::DescriptorId) -> Option<u64>,
    {
        self.write_session(addr, |session| {
            session.plan_cache.get(sql, current_version)
        })
        .flatten()
    }

    /// Stores `tasks` and their `versions` in the plan cache of the session at `addr`.
    ///
    /// If `addr` has no session, the plan is simply dropped.
    pub fn put_cached_plan(
        &self,
        addr: &SocketAddr,
        sql: &str,
        tasks: Vec<crate::control::planner::physical::PhysicalTask>,
        versions: crate::control::planner::descriptor_set::DescriptorVersionSet,
    ) {
        self.write_session(addr, move |session| {
            session.plan_cache.put(sql, tasks, versions);
        });
    }

    /// Database currently selected by the session at `addr`.
    ///
    /// `None` when `addr` has no session or the session has no database set.
    pub fn get_current_database(&self, addr: &SocketAddr) -> Option<DatabaseId> {
        self.read_session(addr, |session| session.current_database)
            .flatten()
    }

    /// Records `db_id` as the selected database of the session at `addr`.
    ///
    /// No-op when `addr` has no session.
    pub fn set_current_database(&self, addr: &SocketAddr, db_id: DatabaseId) {
        self.write_session(addr, |session| session.current_database = Some(db_id));
    }

    /// Switches the session at `addr` to `new_db`, discarding per-database state.
    ///
    /// The transaction state is forced back to `Idle`. The transaction
    /// buffer, snapshot LSN, read set, savepoints, pending offset commits,
    /// pending notifies, prepared statements and plan cache are all cleared
    /// before `current_database` is set. No-op when `addr` has no session.
    pub fn reset_for_database_switch(&self, addr: &SocketAddr, new_db: DatabaseId) {
        self.write_session(addr, |session| {
            session.tx_state = TransactionState::Idle;
            session.tx_buffer.clear();
            session.tx_snapshot_lsn = None;
            session.tx_read_set.clear();
            session.savepoints.clear();
            session.pending_offset_commits.clear();
            session.pending_notifies.clear();
            session.prepared_stmts.clear();
            session.plan_cache.clear();
            session.current_database = Some(new_db);
        });
    }

    /// Runs `f` on the session at `addr` while holding the read lock.
    ///
    /// Returns `None` (without calling `f`) when `addr` has no session. The lock
    /// is held for the whole call, so `f` must not call back into this store:
    /// `std::sync::RwLock` is not reentrant.
    pub(super) fn read_session<R>(
        &self,
        addr: &SocketAddr,
        f: impl FnOnce(&PgSession) -> R,
    ) -> Option<R> {
        let map = self
            .sessions
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let session = map.get(addr)?;
        Some(f(session))
    }

    /// Runs `f` on the session at `addr` while holding the write lock.
    ///
    /// Returns `None` (without calling `f`) when `addr` has no session. The lock
    /// is held for the whole call, so `f` must not call back into this store:
    /// `std::sync::RwLock` is not reentrant.
    pub(super) fn write_session<R>(
        &self,
        addr: &SocketAddr,
        f: impl FnOnce(&mut PgSession) -> R,
    ) -> Option<R> {
        let mut map = self
            .sessions
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let session = map.get_mut(addr)?;
        Some(f(session))
    }
}