1use super::{NodeEndpoint, NodeId, ProxyError, Result};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tokio::sync::RwLock;
10use uuid::Uuid;
11
/// Renders `name` as a double-quoted SQL identifier.
///
/// Every embedded `"` is doubled, so the result always parses as a single
/// delimited identifier no matter what characters `name` contains. Quoting
/// also keeps the exact case of `name`.
fn quote_session_ident(name: &str) -> String {
    let escaped = name.replace('"', "\"\"");
    format!("\"{}\"", escaped)
}
28
/// Snapshot of the client-visible state of one proxied session.
///
/// The state is captured so that it can be replayed, via
/// `generate_restore_statements`, onto a different node.
#[derive(Debug, Clone)]
pub struct SessionState {
    /// Identifier of the session; also the key under which `SessionMigrate` tracks it.
    pub session_id: Uuid,
    /// User the session was opened for.
    pub user: String,
    /// Database the session is connected to.
    pub database: String,
    /// `application_name`; `None` until set through `set_parameter`.
    pub application_name: Option<String>,
    /// `client_encoding`; defaults to "UTF8".
    pub client_encoding: String,
    /// `server_encoding`; defaults to "UTF8". Reported by `get_parameter`, but
    /// never changed by `set_parameter` and never replayed.
    pub server_encoding: String,
    /// `timezone`; defaults to "UTC".
    pub timezone: String,
    /// `search_path` entries; defaults to `["public"]`. `set_parameter` splits a
    /// comma-separated value into trimmed entries.
    pub search_path: Vec<String>,
    /// `datestyle`; defaults to "ISO, MDY".
    pub datestyle: String,
    /// `intervalstyle`; defaults to "postgres".
    pub intervalstyle: String,
    /// Parameters without a dedicated field above, replayed as `SET` statements.
    pub custom_parameters: HashMap<String, String>,
    /// Temporary tables recorded for this session.
    pub temp_tables: Vec<TempTableInfo>,
    /// Prepared statements recorded for this session, keyed by statement name.
    pub prepared_statements: HashMap<String, PreparedStatementInfo>,
    /// When this state was created.
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// Last time the state was modified through one of its mutators.
    pub last_activity: chrono::DateTime<chrono::Utc>,
    /// Node currently serving the session. Despite the name,
    /// `SessionMigrate::migrate_session` overwrites it with the migration target.
    pub original_node: NodeId,
}
65
/// Description of a temporary table created inside a session.
#[derive(Debug, Clone)]
pub struct TempTableInfo {
    /// Table name; emitted as a quoted identifier when the table is re-created.
    pub name: String,
    /// Schema the table lives in.
    /// NOTE(review): not read by the migration code in this file.
    pub schema: String,
    /// Column definitions, in declaration order.
    pub columns: Vec<ColumnDef>,
    /// Whether the table held rows. When true, migration only logs a warning;
    /// row data is not copied.
    pub has_data: bool,
    /// Row count, if known.
    /// NOTE(review): not read by the migration code in this file.
    pub row_count: Option<u64>,
}
80
/// One column of a recorded temporary table.
#[derive(Debug, Clone)]
pub struct ColumnDef {
    /// Column name; emitted as a quoted identifier in the re-created DDL.
    pub name: String,
    /// SQL type name; spliced into the DDL verbatim (not escaped).
    pub data_type: String,
    /// When false, the column is re-created with `NOT NULL`.
    pub nullable: bool,
    /// Default expression; spliced verbatim after `DEFAULT` when present.
    pub default_expr: Option<String>,
}
93
/// A named prepared statement recorded for a session.
#[derive(Debug, Clone)]
pub struct PreparedStatementInfo {
    /// Statement name; `SessionState::add_prepared_statement` uses it as the map key.
    pub name: String,
    /// Statement body, emitted after `AS` in the replayed `PREPARE`.
    pub query: String,
    /// Parameter type names. When non-empty they are emitted in parentheses,
    /// joined with ", "; when empty the parameter list is omitted.
    pub param_types: Vec<String>,
    /// When the statement was recorded.
    pub created_at: chrono::DateTime<chrono::Utc>,
}
106
107impl SessionState {
108 pub fn new(session_id: Uuid, user: String, database: String, node: NodeId) -> Self {
110 Self {
111 session_id,
112 user,
113 database,
114 application_name: None,
115 client_encoding: "UTF8".to_string(),
116 server_encoding: "UTF8".to_string(),
117 timezone: "UTC".to_string(),
118 search_path: vec!["public".to_string()],
119 datestyle: "ISO, MDY".to_string(),
120 intervalstyle: "postgres".to_string(),
121 custom_parameters: HashMap::new(),
122 temp_tables: Vec::new(),
123 prepared_statements: HashMap::new(),
124 created_at: chrono::Utc::now(),
125 last_activity: chrono::Utc::now(),
126 original_node: node,
127 }
128 }
129
130 pub fn set_parameter(&mut self, name: String, value: String) {
132 match name.to_lowercase().as_str() {
134 "timezone" => self.timezone = value,
135 "search_path" => {
136 self.search_path = value.split(',').map(|s| s.trim().to_string()).collect()
137 }
138 "client_encoding" => self.client_encoding = value,
139 "datestyle" => self.datestyle = value,
140 "intervalstyle" => self.intervalstyle = value,
141 "application_name" => self.application_name = Some(value),
142 _ => {
143 self.custom_parameters.insert(name, value);
144 }
145 }
146 self.last_activity = chrono::Utc::now();
147 }
148
149 pub fn get_parameter(&self, name: &str) -> Option<String> {
151 match name.to_lowercase().as_str() {
152 "timezone" => Some(self.timezone.clone()),
153 "search_path" => Some(self.search_path.join(", ")),
154 "client_encoding" => Some(self.client_encoding.clone()),
155 "server_encoding" => Some(self.server_encoding.clone()),
156 "datestyle" => Some(self.datestyle.clone()),
157 "intervalstyle" => Some(self.intervalstyle.clone()),
158 "application_name" => self.application_name.clone(),
159 _ => self.custom_parameters.get(name).cloned(),
160 }
161 }
162
163 pub fn add_prepared_statement(&mut self, info: PreparedStatementInfo) {
165 self.prepared_statements.insert(info.name.clone(), info);
166 self.last_activity = chrono::Utc::now();
167 }
168
169 pub fn remove_prepared_statement(&mut self, name: &str) {
171 self.prepared_statements.remove(name);
172 }
173
174 pub fn add_temp_table(&mut self, info: TempTableInfo) {
176 self.temp_tables.push(info);
177 self.last_activity = chrono::Utc::now();
178 }
179
180 pub fn generate_restore_statements(&self) -> Vec<String> {
182 let mut statements = Vec::new();
183
184 statements.push(format!("SET timezone TO '{}'", self.timezone));
186 statements.push(format!(
187 "SET search_path TO {}",
188 self.search_path.join(", ")
189 ));
190 statements.push(format!("SET client_encoding TO '{}'", self.client_encoding));
191 statements.push(format!("SET datestyle TO '{}'", self.datestyle));
192 statements.push(format!("SET intervalstyle TO '{}'", self.intervalstyle));
193
194 if let Some(ref app_name) = self.application_name {
195 statements.push(format!("SET application_name TO '{}'", app_name));
196 }
197
198 for (name, value) in &self.custom_parameters {
200 statements.push(format!("SET {} TO '{}'", name, value));
201 }
202
203 for prep in self.prepared_statements.values() {
205 if prep.param_types.is_empty() {
206 statements.push(format!("PREPARE {} AS {}", prep.name, prep.query));
207 } else {
208 statements.push(format!(
209 "PREPARE {} ({}) AS {}",
210 prep.name,
211 prep.param_types.join(", "),
212 prep.query
213 ));
214 }
215 }
216
217 statements
218 }
219}
220
/// Outcome of `SessionMigrate::migrate_session`.
#[derive(Debug, Clone)]
pub struct SessionMigrateResult {
    /// The migrated session.
    pub session_id: Uuid,
    /// Whether the migration is reported as successful.
    pub success: bool,
    /// Node the session was migrated to.
    pub target_node: NodeId,
    /// Number of `SET` restore statements that executed without error.
    pub parameters_restored: usize,
    /// Number of `PREPARE` restore statements that executed without error.
    pub prepared_statements_restored: usize,
    /// Temp tables re-created on the target (always 0 when temp-table migration is disabled).
    pub temp_tables_migrated: usize,
    /// Temp tables whose re-creation failed (always 0 when temp-table migration is disabled).
    pub temp_tables_failed: usize,
    /// Wall-clock duration of the migration, in milliseconds.
    pub duration_ms: u64,
    /// Description of what went wrong, if anything.
    pub error: Option<String>,
}
243
/// Tracks per-session state and replays it onto another node when a session
/// is migrated.
pub struct SessionMigrate {
    /// Tracked session states, keyed by session id.
    sessions: Arc<RwLock<HashMap<Uuid, SessionState>>>,
    /// When false, the tracking methods return early without recording state.
    enabled: bool,
    /// Whether `migrate_session` also re-creates the session's temp tables.
    migrate_temp_tables: bool,
    /// Upper bound on the number of tracked sessions, checked on registration.
    max_sessions: usize,
    /// Connection template, cloned per target node with host/port taken from
    /// `endpoints`. `None` means restore work is not sent to any backend.
    backend_template: Option<crate::backend::BackendConfig>,
    /// Network endpoints of the nodes sessions can be migrated to.
    endpoints: Arc<RwLock<HashMap<NodeId, NodeEndpoint>>>,
}
261
262impl SessionMigrate {
263 pub fn new() -> Self {
265 Self {
266 sessions: Arc::new(RwLock::new(HashMap::new())),
267 enabled: true,
268 migrate_temp_tables: false, max_sessions: 10000,
270 backend_template: None,
271 endpoints: Arc::new(RwLock::new(HashMap::new())),
272 }
273 }
274
275 pub fn with_max_sessions(mut self, max: usize) -> Self {
277 self.max_sessions = max;
278 self
279 }
280
281 pub fn with_backend_template(mut self, template: crate::backend::BackendConfig) -> Self {
284 self.backend_template = Some(template);
285 self
286 }
287
288 pub async fn register_endpoint(&self, node_id: NodeId, endpoint: NodeEndpoint) {
290 self.endpoints.write().await.insert(node_id, endpoint);
291 }
292
293 fn build_config(&self, endpoint: &NodeEndpoint) -> Option<crate::backend::BackendConfig> {
294 self.backend_template.as_ref().map(|t| {
295 let mut c = t.clone();
296 c.host = endpoint.host.clone();
297 c.port = endpoint.port;
298 c
299 })
300 }
301
302 pub fn with_temp_table_migration(mut self, enabled: bool) -> Self {
304 self.migrate_temp_tables = enabled;
305 self
306 }
307
308 pub fn set_enabled(&mut self, enabled: bool) {
310 self.enabled = enabled;
311 }
312
313 pub async fn register_session(&self, state: SessionState) -> Result<()> {
315 if !self.enabled {
316 return Ok(());
317 }
318
319 let session_id = state.session_id;
320
321 {
323 let sessions = self.sessions.read().await;
324 if sessions.len() >= self.max_sessions && !sessions.contains_key(&session_id) {
325 return Err(ProxyError::SessionMigration(format!(
326 "Maximum sessions ({}) exceeded",
327 self.max_sessions
328 )));
329 }
330 }
331
332 self.sessions.write().await.insert(session_id, state);
333 tracing::debug!("Registered session {:?}", session_id);
334
335 Ok(())
336 }
337
338 pub async fn set_parameter(&self, session_id: Uuid, name: String, value: String) -> Result<()> {
340 if !self.enabled {
341 return Ok(());
342 }
343
344 let mut sessions = self.sessions.write().await;
345 let session = sessions.get_mut(&session_id).ok_or_else(|| {
346 ProxyError::SessionMigration(format!("Session {:?} not found", session_id))
347 })?;
348
349 session.set_parameter(name, value);
350 Ok(())
351 }
352
353 pub async fn add_prepared_statement(
355 &self,
356 session_id: Uuid,
357 info: PreparedStatementInfo,
358 ) -> Result<()> {
359 if !self.enabled {
360 return Ok(());
361 }
362
363 let mut sessions = self.sessions.write().await;
364 let session = sessions.get_mut(&session_id).ok_or_else(|| {
365 ProxyError::SessionMigration(format!("Session {:?} not found", session_id))
366 })?;
367
368 session.add_prepared_statement(info);
369 Ok(())
370 }
371
372 pub async fn remove_prepared_statement(&self, session_id: Uuid, name: &str) -> Result<()> {
374 if !self.enabled {
375 return Ok(());
376 }
377
378 let mut sessions = self.sessions.write().await;
379 if let Some(session) = sessions.get_mut(&session_id) {
380 session.remove_prepared_statement(name);
381 }
382 Ok(())
383 }
384
385 pub async fn add_temp_table(&self, session_id: Uuid, info: TempTableInfo) -> Result<()> {
387 if !self.enabled {
388 return Ok(());
389 }
390
391 let mut sessions = self.sessions.write().await;
392 let session = sessions.get_mut(&session_id).ok_or_else(|| {
393 ProxyError::SessionMigration(format!("Session {:?} not found", session_id))
394 })?;
395
396 session.add_temp_table(info);
397 Ok(())
398 }
399
400 pub async fn get_session(&self, session_id: &Uuid) -> Option<SessionState> {
402 self.sessions.read().await.get(session_id).cloned()
403 }
404
405 pub async fn close_session(&self, session_id: &Uuid) {
407 self.sessions.write().await.remove(session_id);
408 tracing::debug!("Closed session {:?}", session_id);
409 }
410
411 pub async fn migrate_session(
413 &self,
414 session_id: Uuid,
415 target_node: NodeId,
416 ) -> Result<SessionMigrateResult> {
417 let start = std::time::Instant::now();
418
419 let session = self.get_session(&session_id).await.ok_or_else(|| {
420 ProxyError::SessionMigration(format!("Session {:?} not found", session_id))
421 })?;
422
423 let statements = session.generate_restore_statements();
425
426 let mut parameters_restored = 0;
428 let mut prepared_statements_restored = 0;
429
430 for stmt in &statements {
431 match self.execute_statement(target_node, stmt).await {
432 Ok(()) => {
433 if stmt.starts_with("SET ") {
434 parameters_restored += 1;
435 } else if stmt.starts_with("PREPARE ") {
436 prepared_statements_restored += 1;
437 }
438 }
439 Err(e) => {
440 tracing::warn!("Failed to execute restore statement: {} - {}", stmt, e);
441 }
442 }
443 }
444
445 let mut temp_tables_migrated = 0;
447 let mut temp_tables_failed = 0;
448
449 if self.migrate_temp_tables {
450 for table in &session.temp_tables {
451 match self.migrate_temp_table(target_node, table).await {
452 Ok(()) => temp_tables_migrated += 1,
453 Err(e) => {
454 temp_tables_failed += 1;
455 tracing::warn!("Failed to migrate temp table {}: {}", table.name, e);
456 }
457 }
458 }
459 }
460
461 {
463 let mut sessions = self.sessions.write().await;
464 if let Some(s) = sessions.get_mut(&session_id) {
465 s.original_node = target_node;
466 s.last_activity = chrono::Utc::now();
467 }
468 }
469
470 let duration_ms = start.elapsed().as_millis() as u64;
471
472 tracing::info!(
473 "Migrated session {:?} to node {:?}: {} params, {} prepared, {}ms",
474 session_id,
475 target_node,
476 parameters_restored,
477 prepared_statements_restored,
478 duration_ms
479 );
480
481 Ok(SessionMigrateResult {
482 session_id,
483 success: true,
484 target_node,
485 parameters_restored,
486 prepared_statements_restored,
487 temp_tables_migrated,
488 temp_tables_failed,
489 duration_ms,
490 error: None,
491 })
492 }
493
494 async fn execute_statement(&self, node: NodeId, stmt: &str) -> Result<()> {
501 let endpoint = self.endpoints.read().await.get(&node).cloned();
502 let cfg = match endpoint.as_ref().and_then(|e| self.build_config(e)) {
503 Some(c) => c,
504 None => {
505 tokio::time::sleep(std::time::Duration::from_millis(1)).await;
506 return Ok(());
507 }
508 };
509
510 let mut client = crate::backend::BackendClient::connect(&cfg)
511 .await
512 .map_err(|e| ProxyError::SessionMigration(format!("connect: {}", e)))?;
513 let outcome = client.execute(stmt).await;
514 client.close().await;
515 outcome
516 .map(|_| ())
517 .map_err(|e| ProxyError::SessionMigration(format!("execute: {}", e)))
518 }
519
520 async fn migrate_temp_table(&self, node: NodeId, table: &TempTableInfo) -> Result<()> {
529 let endpoint = self.endpoints.read().await.get(&node).cloned();
530 let cfg = match endpoint.as_ref().and_then(|e| self.build_config(e)) {
531 Some(c) => c,
532 None => {
533 tracing::debug!(
534 table = %table.name,
535 "migrate_temp_table: skeleton path (no backend template)"
536 );
537 tokio::time::sleep(std::time::Duration::from_millis(5)).await;
538 return Ok(());
539 }
540 };
541
542 let mut stmt = String::with_capacity(64 + table.name.len());
547 stmt.push_str("CREATE TEMP TABLE IF NOT EXISTS ");
548 stmt.push_str("e_session_ident(&table.name));
549 stmt.push_str(" (");
550 for (i, col) in table.columns.iter().enumerate() {
551 if i > 0 {
552 stmt.push_str(", ");
553 }
554 stmt.push_str("e_session_ident(&col.name));
555 stmt.push(' ');
556 stmt.push_str(&col.data_type);
557 if !col.nullable {
558 stmt.push_str(" NOT NULL");
559 }
560 if let Some(default) = &col.default_expr {
561 stmt.push_str(" DEFAULT ");
562 stmt.push_str(default);
563 }
564 }
565 stmt.push(')');
566
567 let mut client = crate::backend::BackendClient::connect(&cfg)
568 .await
569 .map_err(|e| ProxyError::SessionMigration(format!("connect: {}", e)))?;
570 let outcome = client.execute(&stmt).await;
571 client.close().await;
572 outcome
573 .map(|_| ())
574 .map_err(|e| ProxyError::SessionMigration(format!("create temp table: {}", e)))?;
575
576 if table.has_data {
577 tracing::warn!(
578 table = %table.name,
579 "temp table has data but migration intentionally does not copy it — route writes through the journal and use failover replay"
580 );
581 }
582 Ok(())
583 }
584
585 pub async fn stats(&self) -> SessionMigrateStats {
587 let sessions = self.sessions.read().await;
588
589 let total_prepared: usize = sessions.values().map(|s| s.prepared_statements.len()).sum();
590
591 let total_temp_tables: usize = sessions.values().map(|s| s.temp_tables.len()).sum();
592
593 SessionMigrateStats {
594 active_sessions: sessions.len(),
595 total_prepared_statements: total_prepared,
596 total_temp_tables,
597 enabled: self.enabled,
598 temp_table_migration_enabled: self.migrate_temp_tables,
599 }
600 }
601}
602
603impl Default for SessionMigrate {
604 fn default() -> Self {
605 Self::new()
606 }
607}
608
/// Snapshot of aggregate counters, as returned by `SessionMigrate::stats`.
#[derive(Debug, Clone)]
pub struct SessionMigrateStats {
    /// Number of tracked sessions.
    pub active_sessions: usize,
    /// Prepared statements summed across all tracked sessions.
    pub total_prepared_statements: usize,
    /// Temp tables summed across all tracked sessions.
    pub total_temp_tables: usize,
    /// Whether tracking is enabled.
    pub enabled: bool,
    /// Whether temp tables are re-created during migration.
    pub temp_table_migration_enabled: bool,
}
623
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a state for user "user" on database "db", attached to a freshly
    /// generated node id.
    fn sample_state(session_id: Uuid) -> SessionState {
        SessionState::new(
            session_id,
            "user".to_string(),
            "db".to_string(),
            NodeId::new(),
        )
    }

    #[test]
    fn test_session_state_new() {
        let state = sample_state(Uuid::new_v4());

        assert_eq!(state.user, "user");
        assert_eq!(state.database, "db");
        assert_eq!(state.timezone, "UTC");
        assert_eq!(state.search_path, vec!["public"]);
    }

    #[test]
    fn test_set_get_parameter() {
        let mut state = sample_state(Uuid::new_v4());

        state.set_parameter("timezone".to_string(), "America/New_York".to_string());
        assert_eq!(
            state.get_parameter("timezone").as_deref(),
            Some("America/New_York")
        );

        state.set_parameter("custom_param".to_string(), "custom_value".to_string());
        assert_eq!(
            state.get_parameter("custom_param").as_deref(),
            Some("custom_value")
        );
    }

    #[test]
    fn test_generate_restore_statements() {
        let mut state = sample_state(Uuid::new_v4());
        state.set_parameter("timezone".to_string(), "UTC".to_string());
        state.add_prepared_statement(PreparedStatementInfo {
            name: "my_query".to_string(),
            query: "SELECT * FROM users WHERE id = $1".to_string(),
            param_types: vec!["integer".to_string()],
            created_at: chrono::Utc::now(),
        });

        let statements = state.generate_restore_statements();
        let mentions = |needle: &str| statements.iter().any(|s| s.contains(needle));

        assert!(mentions("timezone"));
        assert!(mentions("PREPARE my_query"));
    }

    #[tokio::test]
    async fn test_register_session() {
        let migrate = SessionMigrate::new();
        let session_id = Uuid::new_v4();

        migrate
            .register_session(sample_state(session_id))
            .await
            .unwrap();

        assert!(migrate.get_session(&session_id).await.is_some());
    }

    #[tokio::test]
    async fn test_set_parameter() {
        let migrate = SessionMigrate::new();
        let session_id = Uuid::new_v4();
        migrate
            .register_session(sample_state(session_id))
            .await
            .unwrap();

        migrate
            .set_parameter(
                session_id,
                "timezone".to_string(),
                "Europe/London".to_string(),
            )
            .await
            .unwrap();

        let stored = migrate.get_session(&session_id).await.unwrap();
        assert_eq!(stored.timezone, "Europe/London");
    }

    #[tokio::test]
    async fn test_migrate_session() {
        let migrate = SessionMigrate::new();
        let session_id = Uuid::new_v4();
        migrate
            .register_session(sample_state(session_id))
            .await
            .unwrap();

        let outcome = migrate
            .migrate_session(session_id, NodeId::new())
            .await
            .unwrap();

        assert!(outcome.success);
        assert!(outcome.parameters_restored > 0);
    }

    #[tokio::test]
    async fn test_close_session() {
        let migrate = SessionMigrate::new();
        let session_id = Uuid::new_v4();
        migrate
            .register_session(sample_state(session_id))
            .await
            .unwrap();

        migrate.close_session(&session_id).await;

        assert!(migrate.get_session(&session_id).await.is_none());
    }

    #[tokio::test]
    async fn test_stats() {
        let migrate = SessionMigrate::new();
        let mut state = sample_state(Uuid::new_v4());
        state.add_prepared_statement(PreparedStatementInfo {
            name: "ps1".to_string(),
            query: "SELECT 1".to_string(),
            param_types: Vec::new(),
            created_at: chrono::Utc::now(),
        });
        migrate.register_session(state).await.unwrap();

        let stats = migrate.stats().await;

        assert_eq!(stats.active_sessions, 1);
        assert_eq!(stats.total_prepared_statements, 1);
    }
}