1use super::{NodeEndpoint, NodeId, ProxyError, Result};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10use uuid::Uuid;
11
/// Renders `name` as a PostgreSQL delimited (double-quoted) identifier.
///
/// Every embedded `"` is doubled and the whole name is wrapped in `"…"`, so the
/// server takes the name literally instead of folding its case or parsing it.
fn quote_session_ident(name: &str) -> String {
    let escaped = name.replace('"', "\"\"");
    format!("\"{}\"", escaped)
}
28
29#[derive(Debug, Clone)]
31pub struct SessionState {
32 pub session_id: Uuid,
34 pub user: String,
36 pub database: String,
38 pub application_name: Option<String>,
40 pub client_encoding: String,
42 pub server_encoding: String,
44 pub timezone: String,
46 pub search_path: Vec<String>,
48 pub datestyle: String,
50 pub intervalstyle: String,
52 pub custom_parameters: HashMap<String, String>,
54 pub temp_tables: Vec<TempTableInfo>,
56 pub prepared_statements: HashMap<String, PreparedStatementInfo>,
58 pub created_at: chrono::DateTime<chrono::Utc>,
60 pub last_activity: chrono::DateTime<chrono::Utc>,
62 pub original_node: NodeId,
64}
65
66#[derive(Debug, Clone)]
68pub struct TempTableInfo {
69 pub name: String,
71 pub schema: String,
73 pub columns: Vec<ColumnDef>,
75 pub has_data: bool,
77 pub row_count: Option<u64>,
79}
80
/// One column of a tracked temporary table.
#[derive(Clone, Debug)]
pub struct ColumnDef {
    /// Column name.
    pub name: String,
    /// SQL type text, e.g. `integer`.
    pub data_type: String,
    /// `false` means the column carries a `NOT NULL` constraint.
    pub nullable: bool,
    /// Optional SQL default expression.
    pub default_expr: Option<String>,
}
93
94#[derive(Debug, Clone)]
96pub struct PreparedStatementInfo {
97 pub name: String,
99 pub query: String,
101 pub param_types: Vec<String>,
103 pub created_at: chrono::DateTime<chrono::Utc>,
105}
106
107impl SessionState {
108 pub fn new(session_id: Uuid, user: String, database: String, node: NodeId) -> Self {
110 Self {
111 session_id,
112 user,
113 database,
114 application_name: None,
115 client_encoding: "UTF8".to_string(),
116 server_encoding: "UTF8".to_string(),
117 timezone: "UTC".to_string(),
118 search_path: vec!["public".to_string()],
119 datestyle: "ISO, MDY".to_string(),
120 intervalstyle: "postgres".to_string(),
121 custom_parameters: HashMap::new(),
122 temp_tables: Vec::new(),
123 prepared_statements: HashMap::new(),
124 created_at: chrono::Utc::now(),
125 last_activity: chrono::Utc::now(),
126 original_node: node,
127 }
128 }
129
130 pub fn set_parameter(&mut self, name: String, value: String) {
132 match name.to_lowercase().as_str() {
134 "timezone" => self.timezone = value,
135 "search_path" => {
136 self.search_path = value.split(',').map(|s| s.trim().to_string()).collect()
137 }
138 "client_encoding" => self.client_encoding = value,
139 "datestyle" => self.datestyle = value,
140 "intervalstyle" => self.intervalstyle = value,
141 "application_name" => self.application_name = Some(value),
142 _ => {
143 self.custom_parameters.insert(name, value);
144 }
145 }
146 self.last_activity = chrono::Utc::now();
147 }
148
149 pub fn get_parameter(&self, name: &str) -> Option<String> {
151 match name.to_lowercase().as_str() {
152 "timezone" => Some(self.timezone.clone()),
153 "search_path" => Some(self.search_path.join(", ")),
154 "client_encoding" => Some(self.client_encoding.clone()),
155 "server_encoding" => Some(self.server_encoding.clone()),
156 "datestyle" => Some(self.datestyle.clone()),
157 "intervalstyle" => Some(self.intervalstyle.clone()),
158 "application_name" => self.application_name.clone(),
159 _ => self.custom_parameters.get(name).cloned(),
160 }
161 }
162
163 pub fn add_prepared_statement(&mut self, info: PreparedStatementInfo) {
165 self.prepared_statements.insert(info.name.clone(), info);
166 self.last_activity = chrono::Utc::now();
167 }
168
169 pub fn remove_prepared_statement(&mut self, name: &str) {
171 self.prepared_statements.remove(name);
172 }
173
174 pub fn add_temp_table(&mut self, info: TempTableInfo) {
176 self.temp_tables.push(info);
177 self.last_activity = chrono::Utc::now();
178 }
179
180 pub fn generate_restore_statements(&self) -> Vec<String> {
182 let mut statements = Vec::new();
183
184 statements.push(format!("SET timezone TO '{}'", self.timezone));
186 statements.push(format!(
187 "SET search_path TO {}",
188 self.search_path.join(", ")
189 ));
190 statements.push(format!("SET client_encoding TO '{}'", self.client_encoding));
191 statements.push(format!("SET datestyle TO '{}'", self.datestyle));
192 statements.push(format!("SET intervalstyle TO '{}'", self.intervalstyle));
193
194 if let Some(ref app_name) = self.application_name {
195 statements.push(format!("SET application_name TO '{}'", app_name));
196 }
197
198 for (name, value) in &self.custom_parameters {
200 statements.push(format!("SET {} TO '{}'", name, value));
201 }
202
203 for prep in self.prepared_statements.values() {
205 if prep.param_types.is_empty() {
206 statements.push(format!("PREPARE {} AS {}", prep.name, prep.query));
207 } else {
208 statements.push(format!(
209 "PREPARE {} ({}) AS {}",
210 prep.name,
211 prep.param_types.join(", "),
212 prep.query
213 ));
214 }
215 }
216
217 statements
218 }
219}
220
221#[derive(Debug, Clone)]
223pub struct SessionMigrateResult {
224 pub session_id: Uuid,
226 pub success: bool,
228 pub target_node: NodeId,
230 pub parameters_restored: usize,
232 pub prepared_statements_restored: usize,
234 pub temp_tables_migrated: usize,
236 pub temp_tables_failed: usize,
238 pub duration_ms: u64,
240 pub error: Option<String>,
242}
243
244pub struct SessionMigrate {
246 sessions: Arc<RwLock<HashMap<Uuid, SessionState>>>,
248 enabled: bool,
250 migrate_temp_tables: bool,
252 max_sessions: usize,
254 backend_template: Option<crate::backend::BackendConfig>,
258 endpoints: Arc<RwLock<HashMap<NodeId, NodeEndpoint>>>,
260}
261
262impl SessionMigrate {
263 pub fn new() -> Self {
265 Self {
266 sessions: Arc::new(RwLock::new(HashMap::new())),
267 enabled: true,
268 migrate_temp_tables: false, max_sessions: 10000,
270 backend_template: None,
271 endpoints: Arc::new(RwLock::new(HashMap::new())),
272 }
273 }
274
275 pub fn with_max_sessions(mut self, max: usize) -> Self {
277 self.max_sessions = max;
278 self
279 }
280
281 pub fn with_backend_template(
284 mut self,
285 template: crate::backend::BackendConfig,
286 ) -> Self {
287 self.backend_template = Some(template);
288 self
289 }
290
291 pub async fn register_endpoint(&self, node_id: NodeId, endpoint: NodeEndpoint) {
293 self.endpoints.write().await.insert(node_id, endpoint);
294 }
295
296 fn build_config(
297 &self,
298 endpoint: &NodeEndpoint,
299 ) -> Option<crate::backend::BackendConfig> {
300 self.backend_template.as_ref().map(|t| {
301 let mut c = t.clone();
302 c.host = endpoint.host.clone();
303 c.port = endpoint.port;
304 c
305 })
306 }
307
308 pub fn with_temp_table_migration(mut self, enabled: bool) -> Self {
310 self.migrate_temp_tables = enabled;
311 self
312 }
313
314 pub fn set_enabled(&mut self, enabled: bool) {
316 self.enabled = enabled;
317 }
318
319 pub async fn register_session(&self, state: SessionState) -> Result<()> {
321 if !self.enabled {
322 return Ok(());
323 }
324
325 let session_id = state.session_id;
326
327 {
329 let sessions = self.sessions.read().await;
330 if sessions.len() >= self.max_sessions && !sessions.contains_key(&session_id) {
331 return Err(ProxyError::SessionMigration(format!(
332 "Maximum sessions ({}) exceeded",
333 self.max_sessions
334 )));
335 }
336 }
337
338 self.sessions.write().await.insert(session_id, state);
339 tracing::debug!("Registered session {:?}", session_id);
340
341 Ok(())
342 }
343
344 pub async fn set_parameter(
346 &self,
347 session_id: Uuid,
348 name: String,
349 value: String,
350 ) -> Result<()> {
351 if !self.enabled {
352 return Ok(());
353 }
354
355 let mut sessions = self.sessions.write().await;
356 let session = sessions.get_mut(&session_id).ok_or_else(|| {
357 ProxyError::SessionMigration(format!("Session {:?} not found", session_id))
358 })?;
359
360 session.set_parameter(name, value);
361 Ok(())
362 }
363
364 pub async fn add_prepared_statement(
366 &self,
367 session_id: Uuid,
368 info: PreparedStatementInfo,
369 ) -> Result<()> {
370 if !self.enabled {
371 return Ok(());
372 }
373
374 let mut sessions = self.sessions.write().await;
375 let session = sessions.get_mut(&session_id).ok_or_else(|| {
376 ProxyError::SessionMigration(format!("Session {:?} not found", session_id))
377 })?;
378
379 session.add_prepared_statement(info);
380 Ok(())
381 }
382
383 pub async fn remove_prepared_statement(&self, session_id: Uuid, name: &str) -> Result<()> {
385 if !self.enabled {
386 return Ok(());
387 }
388
389 let mut sessions = self.sessions.write().await;
390 if let Some(session) = sessions.get_mut(&session_id) {
391 session.remove_prepared_statement(name);
392 }
393 Ok(())
394 }
395
396 pub async fn add_temp_table(&self, session_id: Uuid, info: TempTableInfo) -> Result<()> {
398 if !self.enabled {
399 return Ok(());
400 }
401
402 let mut sessions = self.sessions.write().await;
403 let session = sessions.get_mut(&session_id).ok_or_else(|| {
404 ProxyError::SessionMigration(format!("Session {:?} not found", session_id))
405 })?;
406
407 session.add_temp_table(info);
408 Ok(())
409 }
410
411 pub async fn get_session(&self, session_id: &Uuid) -> Option<SessionState> {
413 self.sessions.read().await.get(session_id).cloned()
414 }
415
416 pub async fn close_session(&self, session_id: &Uuid) {
418 self.sessions.write().await.remove(session_id);
419 tracing::debug!("Closed session {:?}", session_id);
420 }
421
422 pub async fn migrate_session(
424 &self,
425 session_id: Uuid,
426 target_node: NodeId,
427 ) -> Result<SessionMigrateResult> {
428 let start = std::time::Instant::now();
429
430 let session = self.get_session(&session_id).await.ok_or_else(|| {
431 ProxyError::SessionMigration(format!("Session {:?} not found", session_id))
432 })?;
433
434 let statements = session.generate_restore_statements();
436
437 let mut parameters_restored = 0;
439 let mut prepared_statements_restored = 0;
440
441 for stmt in &statements {
442 match self.execute_statement(target_node, stmt).await {
443 Ok(()) => {
444 if stmt.starts_with("SET ") {
445 parameters_restored += 1;
446 } else if stmt.starts_with("PREPARE ") {
447 prepared_statements_restored += 1;
448 }
449 }
450 Err(e) => {
451 tracing::warn!("Failed to execute restore statement: {} - {}", stmt, e);
452 }
453 }
454 }
455
456 let mut temp_tables_migrated = 0;
458 let mut temp_tables_failed = 0;
459
460 if self.migrate_temp_tables {
461 for table in &session.temp_tables {
462 match self.migrate_temp_table(target_node, table).await {
463 Ok(()) => temp_tables_migrated += 1,
464 Err(e) => {
465 temp_tables_failed += 1;
466 tracing::warn!(
467 "Failed to migrate temp table {}: {}",
468 table.name,
469 e
470 );
471 }
472 }
473 }
474 }
475
476 {
478 let mut sessions = self.sessions.write().await;
479 if let Some(s) = sessions.get_mut(&session_id) {
480 s.original_node = target_node;
481 s.last_activity = chrono::Utc::now();
482 }
483 }
484
485 let duration_ms = start.elapsed().as_millis() as u64;
486
487 tracing::info!(
488 "Migrated session {:?} to node {:?}: {} params, {} prepared, {}ms",
489 session_id,
490 target_node,
491 parameters_restored,
492 prepared_statements_restored,
493 duration_ms
494 );
495
496 Ok(SessionMigrateResult {
497 session_id,
498 success: true,
499 target_node,
500 parameters_restored,
501 prepared_statements_restored,
502 temp_tables_migrated,
503 temp_tables_failed,
504 duration_ms,
505 error: None,
506 })
507 }
508
509 async fn execute_statement(&self, node: NodeId, stmt: &str) -> Result<()> {
516 let endpoint = self.endpoints.read().await.get(&node).cloned();
517 let cfg = match endpoint.as_ref().and_then(|e| self.build_config(e)) {
518 Some(c) => c,
519 None => {
520 tokio::time::sleep(std::time::Duration::from_millis(1)).await;
521 return Ok(());
522 }
523 };
524
525 let mut client = crate::backend::BackendClient::connect(&cfg)
526 .await
527 .map_err(|e| ProxyError::SessionMigration(format!("connect: {}", e)))?;
528 let outcome = client.execute(stmt).await;
529 client.close().await;
530 outcome
531 .map(|_| ())
532 .map_err(|e| ProxyError::SessionMigration(format!("execute: {}", e)))
533 }
534
535 async fn migrate_temp_table(
544 &self,
545 node: NodeId,
546 table: &TempTableInfo,
547 ) -> Result<()> {
548 let endpoint = self.endpoints.read().await.get(&node).cloned();
549 let cfg = match endpoint.as_ref().and_then(|e| self.build_config(e)) {
550 Some(c) => c,
551 None => {
552 tracing::debug!(
553 table = %table.name,
554 "migrate_temp_table: skeleton path (no backend template)"
555 );
556 tokio::time::sleep(std::time::Duration::from_millis(5)).await;
557 return Ok(());
558 }
559 };
560
561 let mut stmt = String::with_capacity(64 + table.name.len());
566 stmt.push_str("CREATE TEMP TABLE IF NOT EXISTS ");
567 stmt.push_str("e_session_ident(&table.name));
568 stmt.push_str(" (");
569 for (i, col) in table.columns.iter().enumerate() {
570 if i > 0 {
571 stmt.push_str(", ");
572 }
573 stmt.push_str("e_session_ident(&col.name));
574 stmt.push(' ');
575 stmt.push_str(&col.data_type);
576 if !col.nullable {
577 stmt.push_str(" NOT NULL");
578 }
579 if let Some(default) = &col.default_expr {
580 stmt.push_str(" DEFAULT ");
581 stmt.push_str(default);
582 }
583 }
584 stmt.push(')');
585
586 let mut client = crate::backend::BackendClient::connect(&cfg)
587 .await
588 .map_err(|e| ProxyError::SessionMigration(format!("connect: {}", e)))?;
589 let outcome = client.execute(&stmt).await;
590 client.close().await;
591 outcome.map(|_| ()).map_err(|e| {
592 ProxyError::SessionMigration(format!("create temp table: {}", e))
593 })?;
594
595 if table.has_data {
596 tracing::warn!(
597 table = %table.name,
598 "temp table has data but migration intentionally does not copy it — route writes through the journal and use failover replay"
599 );
600 }
601 Ok(())
602 }
603
604 pub async fn stats(&self) -> SessionMigrateStats {
606 let sessions = self.sessions.read().await;
607
608 let total_prepared: usize = sessions
609 .values()
610 .map(|s| s.prepared_statements.len())
611 .sum();
612
613 let total_temp_tables: usize = sessions.values().map(|s| s.temp_tables.len()).sum();
614
615 SessionMigrateStats {
616 active_sessions: sessions.len(),
617 total_prepared_statements: total_prepared,
618 total_temp_tables,
619 enabled: self.enabled,
620 temp_table_migration_enabled: self.migrate_temp_tables,
621 }
622 }
623}
624
625impl Default for SessionMigrate {
626 fn default() -> Self {
627 Self::new()
628 }
629}
630
/// Point-in-time summary of what the session tracker holds.
#[derive(Clone, Debug)]
pub struct SessionMigrateStats {
    /// Number of tracked sessions.
    pub active_sessions: usize,
    /// Prepared statements summed across all tracked sessions.
    pub total_prepared_statements: usize,
    /// Temp tables summed across all tracked sessions.
    pub total_temp_tables: usize,
    /// Whether session tracking is enabled.
    pub enabled: bool,
    /// Whether migration recreates temp tables.
    pub temp_table_migration_enabled: bool,
}
645
#[cfg(test)]
mod tests {
    use super::*;

    /// A session owned by "user" on database "db", bound to a fresh node.
    fn make_state(session_id: Uuid) -> SessionState {
        SessionState::new(session_id, "user".to_string(), "db".to_string(), NodeId::new())
    }

    /// Registers a brand-new session with `migrate` and returns its id.
    async fn register_fresh(migrate: &SessionMigrate) -> Uuid {
        let id = Uuid::new_v4();
        migrate.register_session(make_state(id)).await.unwrap();
        id
    }

    #[test]
    fn test_session_state_new() {
        let state = make_state(Uuid::new_v4());

        assert_eq!(state.user, "user");
        assert_eq!(state.database, "db");
        assert_eq!(state.timezone, "UTC");
        assert_eq!(state.search_path, vec!["public"]);
    }

    #[test]
    fn test_set_get_parameter() {
        let mut state = make_state(Uuid::new_v4());

        state.set_parameter("timezone".to_string(), "America/New_York".to_string());
        assert_eq!(
            state.get_parameter("timezone"),
            Some("America/New_York".to_string())
        );

        state.set_parameter("custom_param".to_string(), "custom_value".to_string());
        assert_eq!(
            state.get_parameter("custom_param"),
            Some("custom_value".to_string())
        );
    }

    #[test]
    fn test_generate_restore_statements() {
        let mut state = make_state(Uuid::new_v4());
        state.set_parameter("timezone".to_string(), "UTC".to_string());
        state.add_prepared_statement(PreparedStatementInfo {
            name: "my_query".to_string(),
            query: "SELECT * FROM users WHERE id = $1".to_string(),
            param_types: vec!["integer".to_string()],
            created_at: chrono::Utc::now(),
        });

        let sql = state.generate_restore_statements();

        assert!(sql.iter().any(|s| s.contains("timezone")));
        assert!(sql.iter().any(|s| s.contains("PREPARE my_query")));
    }

    #[tokio::test]
    async fn test_register_session() {
        let migrate = SessionMigrate::new();
        let id = register_fresh(&migrate).await;

        assert!(migrate.get_session(&id).await.is_some());
    }

    #[tokio::test]
    async fn test_set_parameter() {
        let migrate = SessionMigrate::new();
        let id = register_fresh(&migrate).await;

        migrate
            .set_parameter(id, "timezone".to_string(), "Europe/London".to_string())
            .await
            .unwrap();

        let session = migrate.get_session(&id).await.unwrap();
        assert_eq!(session.timezone, "Europe/London");
    }

    #[tokio::test]
    async fn test_migrate_session() {
        let migrate = SessionMigrate::new();
        let id = register_fresh(&migrate).await;

        let outcome = migrate.migrate_session(id, NodeId::new()).await.unwrap();

        assert!(outcome.success);
        assert!(outcome.parameters_restored > 0);
    }

    #[tokio::test]
    async fn test_close_session() {
        let migrate = SessionMigrate::new();
        let id = register_fresh(&migrate).await;

        migrate.close_session(&id).await;

        assert!(migrate.get_session(&id).await.is_none());
    }

    #[tokio::test]
    async fn test_stats() {
        let migrate = SessionMigrate::new();
        let mut state = make_state(Uuid::new_v4());
        state.add_prepared_statement(PreparedStatementInfo {
            name: "ps1".to_string(),
            query: "SELECT 1".to_string(),
            param_types: vec![],
            created_at: chrono::Utc::now(),
        });
        migrate.register_session(state).await.unwrap();

        let stats = migrate.stats().await;

        assert_eq!(stats.active_sessions, 1);
        assert_eq!(stats.total_prepared_statements, 1);
    }
}