use super::{NodeEndpoint, NodeId, ProxyError, Result};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
/// Renders `name` as a double-quoted SQL identifier.
///
/// Every embedded `"` is doubled, so the result always denotes exactly
/// `name` (including case and any punctuation) and cannot terminate early.
fn quote_session_ident(name: &str) -> String {
    let escaped = name.replace('"', "\"\"");
    let mut quoted = String::with_capacity(escaped.len() + 2);
    quoted.push('"');
    quoted.push_str(&escaped);
    quoted.push('"');
    quoted
}
/// Snapshot of the client-visible state of one proxied session, kept so the
/// session can be re-established on another node.
///
/// `set_parameter`, `add_prepared_statement` and `add_temp_table` refresh
/// `last_activity` when they change the state.
#[derive(Debug, Clone)]
pub struct SessionState {
/// Identifier of the session; used as the key in `SessionMigrate`'s session map.
pub session_id: Uuid,
/// User name supplied to `SessionState::new` (presumably the database role — confirm).
pub user: String,
/// Database name supplied to `SessionState::new`.
pub database: String,
/// `application_name` if the client set one; only replayed when `Some`.
pub application_name: Option<String>,
/// Client encoding; defaults to `UTF8`.
pub client_encoding: String,
/// Server encoding; defaults to `UTF8`. Reported by `get_parameter` but not
/// replayed by `generate_restore_statements`.
pub server_encoding: String,
/// Session time zone; defaults to `UTC`.
pub timezone: String,
/// Schema search path, one entry per comma-separated element; defaults to `public`.
pub search_path: Vec<String>,
/// Date style; defaults to `ISO, MDY`.
pub datestyle: String,
/// Interval style; defaults to `postgres`.
pub intervalstyle: String,
/// Parameters that `set_parameter` does not map to a dedicated field, keyed by name.
pub custom_parameters: HashMap<String, String>,
/// Temporary tables recorded via `add_temp_table`.
pub temp_tables: Vec<TempTableInfo>,
/// Prepared statements keyed by statement name (`add_prepared_statement` uses `info.name`).
pub prepared_statements: HashMap<String, PreparedStatementInfo>,
/// When this state object was created.
pub created_at: chrono::DateTime<chrono::Utc>,
/// When the state was last changed or migrated.
pub last_activity: chrono::DateTime<chrono::Utc>,
/// Node the session is associated with. Despite the name,
/// `SessionMigrate::migrate_session` overwrites it with the migration target.
pub original_node: NodeId,
}
/// Description of a temporary table created by a session, used to recreate
/// its structure on another node.
#[derive(Debug, Clone)]
pub struct TempTableInfo {
/// Table name; emitted as a quoted identifier when the table is recreated.
pub name: String,
/// Schema name; not used when the table is recreated.
pub schema: String,
/// Column definitions, in order, used to rebuild the table.
pub columns: Vec<ColumnDef>,
/// Whether the table held rows. Rows are never copied during migration; a
/// warning is logged instead.
pub has_data: bool,
/// Optional row count; not consulted during migration (presumably informational).
pub row_count: Option<u64>,
}
/// One column of a [`TempTableInfo`].
#[derive(Debug, Clone)]
pub struct ColumnDef {
/// Column name; emitted as a quoted identifier.
pub name: String,
/// SQL type text, emitted verbatim into `CREATE TEMP TABLE` — must be trusted.
pub data_type: String,
/// When `false`, the column is created with `NOT NULL`.
pub nullable: bool,
/// SQL default expression, emitted verbatim after `DEFAULT` — must be trusted.
pub default_expr: Option<String>,
}
/// A prepared statement recorded for a session so it can be re-`PREPARE`d
/// on another node.
#[derive(Debug, Clone)]
pub struct PreparedStatementInfo {
/// Statement name, used as the `PREPARE` name when restored.
pub name: String,
/// Statement body, emitted verbatim after `AS` when restored.
pub query: String,
/// Parameter type names, joined with `", "` into the parameter list; the list
/// is omitted entirely when empty.
pub param_types: Vec<String>,
/// When the statement was recorded.
pub created_at: chrono::DateTime<chrono::Utc>,
}
impl SessionState {
    /// Creates session state populated with PostgreSQL-style defaults:
    /// `UTF8` client/server encoding, `UTC` time zone, a `public` search path,
    /// `ISO, MDY` date style and `postgres` interval style.
    ///
    /// `created_at` and `last_activity` are initialised to the same instant.
    pub fn new(session_id: Uuid, user: String, database: String, node: NodeId) -> Self {
        let now = chrono::Utc::now();
        Self {
            session_id,
            user,
            database,
            application_name: None,
            client_encoding: "UTF8".to_string(),
            server_encoding: "UTF8".to_string(),
            timezone: "UTC".to_string(),
            search_path: vec!["public".to_string()],
            datestyle: "ISO, MDY".to_string(),
            intervalstyle: "postgres".to_string(),
            custom_parameters: HashMap::new(),
            temp_tables: Vec::new(),
            prepared_statements: HashMap::new(),
            created_at: now,
            last_activity: now,
            original_node: node,
        }
    }

    /// Records a parameter change, mirroring a client-side `SET`.
    ///
    /// Names are matched case-insensitively, as PostgreSQL does for parameter
    /// names. Well-known parameters update their dedicated field, and
    /// `search_path` is split on commas with each element trimmed. Any other
    /// parameter is stored in `custom_parameters` under its lower-cased name,
    /// so `get_parameter` finds it regardless of the caller's casing.
    pub fn set_parameter(&mut self, name: String, value: String) {
        let key = name.to_lowercase();
        match key.as_str() {
            "timezone" => self.timezone = value,
            "search_path" => {
                self.search_path = value.split(',').map(|s| s.trim().to_string()).collect();
            }
            "client_encoding" => self.client_encoding = value,
            "datestyle" => self.datestyle = value,
            "intervalstyle" => self.intervalstyle = value,
            "application_name" => self.application_name = Some(value),
            _ => {
                self.custom_parameters.insert(key, value);
            }
        }
        self.last_activity = chrono::Utc::now();
    }

    /// Returns the current value of a parameter, matching names case-insensitively.
    ///
    /// `search_path` is rendered joined with `", "`. Returns `None` when
    /// `application_name` is unset or a custom parameter is unknown.
    pub fn get_parameter(&self, name: &str) -> Option<String> {
        let key = name.to_lowercase();
        match key.as_str() {
            "timezone" => Some(self.timezone.clone()),
            "search_path" => Some(self.search_path.join(", ")),
            "client_encoding" => Some(self.client_encoding.clone()),
            "server_encoding" => Some(self.server_encoding.clone()),
            "datestyle" => Some(self.datestyle.clone()),
            "intervalstyle" => Some(self.intervalstyle.clone()),
            "application_name" => self.application_name.clone(),
            // Fall back to the exact spelling for entries inserted directly
            // into the public map with non-lower-case keys.
            _ => self
                .custom_parameters
                .get(&key)
                .or_else(|| self.custom_parameters.get(name))
                .cloned(),
        }
    }

    /// Records a prepared statement under `info.name`, replacing any previous
    /// statement with that name.
    pub fn add_prepared_statement(&mut self, info: PreparedStatementInfo) {
        self.prepared_statements.insert(info.name.clone(), info);
        self.last_activity = chrono::Utc::now();
    }

    /// Forgets the prepared statement called `name`, if present.
    pub fn remove_prepared_statement(&mut self, name: &str) {
        self.prepared_statements.remove(name);
        // Deallocation is session activity too, like every other mutator here.
        self.last_activity = chrono::Utc::now();
    }

    /// Records a temporary table created by the session.
    pub fn add_temp_table(&mut self, info: TempTableInfo) {
        self.temp_tables.push(info);
        self.last_activity = chrono::Utc::now();
    }

    /// Builds the SQL statements that reproduce this session's state on another backend.
    ///
    /// Output order:
    /// 1. The built-in `SET`s: time zone, search path, client encoding, date
    ///    style, interval style, then `application_name` if it is set.
    /// 2. Custom parameters, sorted by name.
    /// 3. `PREPARE` statements, sorted by name.
    ///
    /// Sorting makes the output deterministic despite the `HashMap` storage.
    ///
    /// Parameter values are emitted as escaped string literals. Parameter names,
    /// search-path entries and statement names are emitted as identifiers,
    /// double-quoted when they are not plain identifiers. Stored text therefore
    /// cannot terminate a statement early. Prepared-statement query text and
    /// parameter type names are SQL by nature and are emitted verbatim, so they
    /// must come from a trusted source.
    pub fn generate_restore_statements(&self) -> Vec<String> {
        let mut statements =
            Vec::with_capacity(6 + self.custom_parameters.len() + self.prepared_statements.len());
        statements.push(format!("SET timezone TO {}", Self::sql_literal(&self.timezone)));
        statements.push(format!("SET search_path TO {}", self.search_path_sql()));
        statements.push(format!(
            "SET client_encoding TO {}",
            Self::sql_literal(&self.client_encoding)
        ));
        statements.push(format!("SET datestyle TO {}", Self::sql_literal(&self.datestyle)));
        statements.push(format!(
            "SET intervalstyle TO {}",
            Self::sql_literal(&self.intervalstyle)
        ));
        if let Some(app_name) = &self.application_name {
            statements.push(format!("SET application_name TO {}", Self::sql_literal(app_name)));
        }

        let mut custom: Vec<(&String, &String)> = self.custom_parameters.iter().collect();
        custom.sort_unstable_by(|a, b| a.0.cmp(b.0));
        for (name, value) in custom {
            statements.push(format!(
                "SET {} TO {}",
                Self::parameter_name_sql(name),
                Self::sql_literal(value)
            ));
        }

        let mut prepared: Vec<(&String, &PreparedStatementInfo)> =
            self.prepared_statements.iter().collect();
        prepared.sort_unstable_by(|a, b| a.0.cmp(b.0));
        for (_, prep) in prepared {
            let name = Self::sql_identifier(&prep.name);
            if prep.param_types.is_empty() {
                statements.push(format!("PREPARE {} AS {}", name, prep.query));
            } else {
                statements.push(format!(
                    "PREPARE {} ({}) AS {}",
                    name,
                    prep.param_types.join(", "),
                    prep.query
                ));
            }
        }
        statements
    }

    /// Renders `value` as a SQL string literal.
    ///
    /// Embedded single quotes are doubled. If the value contains a backslash,
    /// an `E'...'` literal with doubled backslashes is produced instead. The
    /// result then parses identically whether or not
    /// `standard_conforming_strings` is on.
    fn sql_literal(value: &str) -> String {
        let escaped = value.replace('\'', "''");
        if value.contains('\\') {
            format!("E'{}'", escaped.replace('\\', "\\\\"))
        } else {
            format!("'{}'", escaped)
        }
    }

    /// True if `s` can be written as an unquoted SQL identifier.
    ///
    /// The first character must be a letter, `_` or a non-ASCII character.
    /// The remaining characters may be any of those, ASCII digits or `$`.
    fn is_bare_identifier(s: &str) -> bool {
        let mut chars = s.chars();
        let first_ok = matches!(
            chars.next(),
            Some(c) if c == '_' || c.is_ascii_alphabetic() || !c.is_ascii()
        );
        first_ok
            && chars.all(|c| c == '_' || c == '$' || c.is_ascii_alphanumeric() || !c.is_ascii())
    }

    /// True if `s` is a complete, non-empty double-quoted identifier in which
    /// every embedded `"` is doubled (e.g. `"$user"`).
    fn is_quoted_identifier(s: &str) -> bool {
        let inner = match s.strip_prefix('"').and_then(|rest| rest.strip_suffix('"')) {
            Some(inner) if !inner.is_empty() => inner,
            _ => return false,
        };
        let mut chars = inner.chars();
        while let Some(c) = chars.next() {
            if c == '"' && chars.next() != Some('"') {
                return false;
            }
        }
        true
    }

    /// Emits `name` unchanged when it is a bare identifier, otherwise as a
    /// double-quoted identifier, so arbitrary text stays in identifier position.
    fn sql_identifier(name: &str) -> String {
        if Self::is_bare_identifier(name) {
            name.to_string()
        } else {
            format!("\"{}\"", name.replace('"', "\"\""))
        }
    }

    /// Renders a (possibly dotted, e.g. `myapp.setting`) parameter name, with
    /// each dot-separated component treated as an identifier.
    fn parameter_name_sql(name: &str) -> String {
        name.split('.')
            .map(Self::sql_identifier)
            .collect::<Vec<_>>()
            .join(".")
    }

    /// Renders the stored search path for `SET search_path TO ...`.
    ///
    /// Entries that are already well-formed quoted identifiers (as PostgreSQL
    /// reports `"$user"`) pass through unchanged; other entries go through
    /// `sql_identifier`. Empty entries are dropped, and an empty path becomes
    /// `''` instead of producing an incomplete statement.
    fn search_path_sql(&self) -> String {
        let elements: Vec<String> = self
            .search_path
            .iter()
            .map(|e| e.trim())
            .filter(|e| !e.is_empty())
            .map(|e| {
                if Self::is_quoted_identifier(e) {
                    e.to_string()
                } else {
                    Self::sql_identifier(e)
                }
            })
            .collect();
        if elements.is_empty() {
            "''".to_string()
        } else {
            elements.join(", ")
        }
    }
}
/// Outcome of `SessionMigrate::migrate_session`.
#[derive(Debug, Clone)]
pub struct SessionMigrateResult {
/// Session that was migrated.
pub session_id: Uuid,
/// Overall outcome flag as reported by `migrate_session`.
pub success: bool,
/// Node the session's state was replayed on.
pub target_node: NodeId,
/// Number of `SET` restore statements that executed without error.
pub parameters_restored: usize,
/// Number of `PREPARE` restore statements that executed without error.
pub prepared_statements_restored: usize,
/// Temp tables recreated on the target (zero unless temp-table migration is enabled).
pub temp_tables_migrated: usize,
/// Temp tables whose recreation on the target failed.
pub temp_tables_failed: usize,
/// Elapsed time of the migration, in milliseconds.
pub duration_ms: u64,
/// Failure description, if `migrate_session` reported one.
pub error: Option<String>,
}
/// Tracks per-session state and replays it on another node when a session
/// is migrated.
pub struct SessionMigrate {
/// Tracked sessions keyed by `SessionState::session_id`.
sessions: Arc<RwLock<HashMap<Uuid, SessionState>>>,
/// When `false`, registration and per-session mutation calls are no-ops.
enabled: bool,
/// Whether `migrate_session` recreates recorded temp tables on the target.
migrate_temp_tables: bool,
/// Upper bound on tracked sessions, enforced by `register_session`.
max_sessions: usize,
/// Connection settings cloned per target node (host/port overridden). When
/// `None`, restore statements are not actually executed.
backend_template: Option<crate::backend::BackendConfig>,
/// Network endpoint for each node, combined with `backend_template` to connect.
endpoints: Arc<RwLock<HashMap<NodeId, NodeEndpoint>>>,
}
impl SessionMigrate {
pub fn new() -> Self {
Self {
sessions: Arc::new(RwLock::new(HashMap::new())),
enabled: true,
migrate_temp_tables: false, max_sessions: 10000,
backend_template: None,
endpoints: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn with_max_sessions(mut self, max: usize) -> Self {
self.max_sessions = max;
self
}
pub fn with_backend_template(
mut self,
template: crate::backend::BackendConfig,
) -> Self {
self.backend_template = Some(template);
self
}
pub async fn register_endpoint(&self, node_id: NodeId, endpoint: NodeEndpoint) {
self.endpoints.write().await.insert(node_id, endpoint);
}
fn build_config(
&self,
endpoint: &NodeEndpoint,
) -> Option<crate::backend::BackendConfig> {
self.backend_template.as_ref().map(|t| {
let mut c = t.clone();
c.host = endpoint.host.clone();
c.port = endpoint.port;
c
})
}
pub fn with_temp_table_migration(mut self, enabled: bool) -> Self {
self.migrate_temp_tables = enabled;
self
}
pub fn set_enabled(&mut self, enabled: bool) {
self.enabled = enabled;
}
pub async fn register_session(&self, state: SessionState) -> Result<()> {
if !self.enabled {
return Ok(());
}
let session_id = state.session_id;
{
let sessions = self.sessions.read().await;
if sessions.len() >= self.max_sessions && !sessions.contains_key(&session_id) {
return Err(ProxyError::SessionMigration(format!(
"Maximum sessions ({}) exceeded",
self.max_sessions
)));
}
}
self.sessions.write().await.insert(session_id, state);
tracing::debug!("Registered session {:?}", session_id);
Ok(())
}
pub async fn set_parameter(
&self,
session_id: Uuid,
name: String,
value: String,
) -> Result<()> {
if !self.enabled {
return Ok(());
}
let mut sessions = self.sessions.write().await;
let session = sessions.get_mut(&session_id).ok_or_else(|| {
ProxyError::SessionMigration(format!("Session {:?} not found", session_id))
})?;
session.set_parameter(name, value);
Ok(())
}
pub async fn add_prepared_statement(
&self,
session_id: Uuid,
info: PreparedStatementInfo,
) -> Result<()> {
if !self.enabled {
return Ok(());
}
let mut sessions = self.sessions.write().await;
let session = sessions.get_mut(&session_id).ok_or_else(|| {
ProxyError::SessionMigration(format!("Session {:?} not found", session_id))
})?;
session.add_prepared_statement(info);
Ok(())
}
pub async fn remove_prepared_statement(&self, session_id: Uuid, name: &str) -> Result<()> {
if !self.enabled {
return Ok(());
}
let mut sessions = self.sessions.write().await;
if let Some(session) = sessions.get_mut(&session_id) {
session.remove_prepared_statement(name);
}
Ok(())
}
pub async fn add_temp_table(&self, session_id: Uuid, info: TempTableInfo) -> Result<()> {
if !self.enabled {
return Ok(());
}
let mut sessions = self.sessions.write().await;
let session = sessions.get_mut(&session_id).ok_or_else(|| {
ProxyError::SessionMigration(format!("Session {:?} not found", session_id))
})?;
session.add_temp_table(info);
Ok(())
}
pub async fn get_session(&self, session_id: &Uuid) -> Option<SessionState> {
self.sessions.read().await.get(session_id).cloned()
}
pub async fn close_session(&self, session_id: &Uuid) {
self.sessions.write().await.remove(session_id);
tracing::debug!("Closed session {:?}", session_id);
}
pub async fn migrate_session(
&self,
session_id: Uuid,
target_node: NodeId,
) -> Result<SessionMigrateResult> {
let start = std::time::Instant::now();
let session = self.get_session(&session_id).await.ok_or_else(|| {
ProxyError::SessionMigration(format!("Session {:?} not found", session_id))
})?;
let statements = session.generate_restore_statements();
let mut parameters_restored = 0;
let mut prepared_statements_restored = 0;
for stmt in &statements {
match self.execute_statement(target_node, stmt).await {
Ok(()) => {
if stmt.starts_with("SET ") {
parameters_restored += 1;
} else if stmt.starts_with("PREPARE ") {
prepared_statements_restored += 1;
}
}
Err(e) => {
tracing::warn!("Failed to execute restore statement: {} - {}", stmt, e);
}
}
}
let mut temp_tables_migrated = 0;
let mut temp_tables_failed = 0;
if self.migrate_temp_tables {
for table in &session.temp_tables {
match self.migrate_temp_table(target_node, table).await {
Ok(()) => temp_tables_migrated += 1,
Err(e) => {
temp_tables_failed += 1;
tracing::warn!(
"Failed to migrate temp table {}: {}",
table.name,
e
);
}
}
}
}
{
let mut sessions = self.sessions.write().await;
if let Some(s) = sessions.get_mut(&session_id) {
s.original_node = target_node;
s.last_activity = chrono::Utc::now();
}
}
let duration_ms = start.elapsed().as_millis() as u64;
tracing::info!(
"Migrated session {:?} to node {:?}: {} params, {} prepared, {}ms",
session_id,
target_node,
parameters_restored,
prepared_statements_restored,
duration_ms
);
Ok(SessionMigrateResult {
session_id,
success: true,
target_node,
parameters_restored,
prepared_statements_restored,
temp_tables_migrated,
temp_tables_failed,
duration_ms,
error: None,
})
}
async fn execute_statement(&self, node: NodeId, stmt: &str) -> Result<()> {
let endpoint = self.endpoints.read().await.get(&node).cloned();
let cfg = match endpoint.as_ref().and_then(|e| self.build_config(e)) {
Some(c) => c,
None => {
tokio::time::sleep(std::time::Duration::from_millis(1)).await;
return Ok(());
}
};
let mut client = crate::backend::BackendClient::connect(&cfg)
.await
.map_err(|e| ProxyError::SessionMigration(format!("connect: {}", e)))?;
let outcome = client.execute(stmt).await;
client.close().await;
outcome
.map(|_| ())
.map_err(|e| ProxyError::SessionMigration(format!("execute: {}", e)))
}
async fn migrate_temp_table(
&self,
node: NodeId,
table: &TempTableInfo,
) -> Result<()> {
let endpoint = self.endpoints.read().await.get(&node).cloned();
let cfg = match endpoint.as_ref().and_then(|e| self.build_config(e)) {
Some(c) => c,
None => {
tracing::debug!(
table = %table.name,
"migrate_temp_table: skeleton path (no backend template)"
);
tokio::time::sleep(std::time::Duration::from_millis(5)).await;
return Ok(());
}
};
let mut stmt = String::with_capacity(64 + table.name.len());
stmt.push_str("CREATE TEMP TABLE IF NOT EXISTS ");
stmt.push_str("e_session_ident(&table.name));
stmt.push_str(" (");
for (i, col) in table.columns.iter().enumerate() {
if i > 0 {
stmt.push_str(", ");
}
stmt.push_str("e_session_ident(&col.name));
stmt.push(' ');
stmt.push_str(&col.data_type);
if !col.nullable {
stmt.push_str(" NOT NULL");
}
if let Some(default) = &col.default_expr {
stmt.push_str(" DEFAULT ");
stmt.push_str(default);
}
}
stmt.push(')');
let mut client = crate::backend::BackendClient::connect(&cfg)
.await
.map_err(|e| ProxyError::SessionMigration(format!("connect: {}", e)))?;
let outcome = client.execute(&stmt).await;
client.close().await;
outcome.map(|_| ()).map_err(|e| {
ProxyError::SessionMigration(format!("create temp table: {}", e))
})?;
if table.has_data {
tracing::warn!(
table = %table.name,
"temp table has data but migration intentionally does not copy it — route writes through the journal and use failover replay"
);
}
Ok(())
}
pub async fn stats(&self) -> SessionMigrateStats {
let sessions = self.sessions.read().await;
let total_prepared: usize = sessions
.values()
.map(|s| s.prepared_statements.len())
.sum();
let total_temp_tables: usize = sessions.values().map(|s| s.temp_tables.len()).sum();
SessionMigrateStats {
active_sessions: sessions.len(),
total_prepared_statements: total_prepared,
total_temp_tables,
enabled: self.enabled,
temp_table_migration_enabled: self.migrate_temp_tables,
}
}
}
impl Default for SessionMigrate {
fn default() -> Self {
Self::new()
}
}
/// Point-in-time summary returned by `SessionMigrate::stats`.
#[derive(Debug, Clone)]
pub struct SessionMigrateStats {
/// Number of tracked sessions.
pub active_sessions: usize,
/// Prepared statements summed over all tracked sessions.
pub total_prepared_statements: usize,
/// Temp tables summed over all tracked sessions.
pub total_temp_tables: usize,
/// Whether session tracking is enabled.
pub enabled: bool,
/// Whether temp tables are recreated during migration.
pub temp_table_migration_enabled: bool,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Session state for user "user" on database "db", attached to a fresh node.
    fn sample_state(session_id: Uuid) -> SessionState {
        SessionState::new(session_id, "user".to_string(), "db".to_string(), NodeId::new())
    }

    /// Prepared-statement fixture with the given name, query and parameter types.
    fn prepared(name: &str, query: &str, param_types: &[&str]) -> PreparedStatementInfo {
        PreparedStatementInfo {
            name: name.to_string(),
            query: query.to_string(),
            param_types: param_types.iter().map(|t| t.to_string()).collect(),
            created_at: chrono::Utc::now(),
        }
    }

    /// A manager holding one freshly registered session, plus that session's id.
    async fn manager_with_session() -> (SessionMigrate, Uuid) {
        let manager = SessionMigrate::new();
        let session_id = Uuid::new_v4();
        manager.register_session(sample_state(session_id)).await.unwrap();
        (manager, session_id)
    }

    #[test]
    fn test_session_state_new() {
        let state = sample_state(Uuid::new_v4());
        assert_eq!(state.user, "user");
        assert_eq!(state.database, "db");
        assert_eq!(state.timezone, "UTC");
        assert_eq!(state.search_path, vec!["public"]);
    }

    #[test]
    fn test_set_get_parameter() {
        let mut state = sample_state(Uuid::new_v4());
        let cases = [
            ("timezone", "America/New_York"),
            ("custom_param", "custom_value"),
        ];
        for &(name, value) in &cases {
            state.set_parameter(name.to_string(), value.to_string());
            assert_eq!(state.get_parameter(name), Some(value.to_string()));
        }
    }

    #[test]
    fn test_generate_restore_statements() {
        let mut state = sample_state(Uuid::new_v4());
        state.set_parameter("timezone".to_string(), "UTC".to_string());
        state.add_prepared_statement(prepared(
            "my_query",
            "SELECT * FROM users WHERE id = $1",
            &["integer"],
        ));
        let statements = state.generate_restore_statements();
        let mentions = |needle: &str| statements.iter().any(|s| s.contains(needle));
        assert!(mentions("timezone"));
        assert!(mentions("PREPARE my_query"));
    }

    #[tokio::test]
    async fn test_register_session() {
        let (manager, session_id) = manager_with_session().await;
        assert!(manager.get_session(&session_id).await.is_some());
    }

    #[tokio::test]
    async fn test_set_parameter() {
        let (manager, session_id) = manager_with_session().await;
        manager
            .set_parameter(session_id, "timezone".to_string(), "Europe/London".to_string())
            .await
            .unwrap();
        let timezone = manager.get_session(&session_id).await.map(|s| s.timezone);
        assert_eq!(timezone.as_deref(), Some("Europe/London"));
    }

    #[tokio::test]
    async fn test_migrate_session() {
        let (manager, session_id) = manager_with_session().await;
        let result = manager
            .migrate_session(session_id, NodeId::new())
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.parameters_restored > 0);
    }

    #[tokio::test]
    async fn test_close_session() {
        let (manager, session_id) = manager_with_session().await;
        manager.close_session(&session_id).await;
        assert!(manager.get_session(&session_id).await.is_none());
    }

    #[tokio::test]
    async fn test_stats() {
        let manager = SessionMigrate::new();
        let mut state = sample_state(Uuid::new_v4());
        state.add_prepared_statement(prepared("ps1", "SELECT 1", &[]));
        manager.register_session(state).await.unwrap();
        let stats = manager.stats().await;
        assert_eq!(stats.active_sessions, 1);
        assert_eq!(stats.total_prepared_statements, 1);
    }
}