1use super::{NodeEndpoint, NodeId, ProxyError, Result};
7use crate::backend::{BackendClient, BackendConfig};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11use uuid::Uuid;
12
/// Saved state of one cursor, kept so the cursor can be re-declared on another node.
#[derive(Debug, Clone)]
pub struct CursorState {
    /// Cursor name used in the re-issued `DECLARE`; also the key the state is stored under.
    pub name: String,
    /// Session that owns the cursor.
    pub session_id: Uuid,
    /// Query text the cursor was declared for; may contain `$N` placeholders.
    pub query: String,
    /// Values for the `$N` placeholders in `query`.
    pub parameters: Vec<CursorParam>,
    /// Total number of rows in the result set, when known.
    pub total_rows: Option<u64>,
    /// Rows already consumed; a restore skips this many rows with `MOVE FORWARD`.
    pub position: u64,
    /// Whether the cursor was declared scrollable.
    pub scrollable: bool,
    /// Whether the cursor was declared `WITH HOLD`.
    pub with_hold: bool,
    /// Direction(s) in which the cursor is fetched.
    pub direction: CursorDirection,
    /// Rows per fetch. NOTE(review): not read anywhere in this module.
    pub fetch_size: u32,
    /// When the cursor state was created (UTC).
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// Time of the last position update, if any.
    pub last_fetch: Option<chrono::DateTime<chrono::Utc>>,
    /// Whether the cursor has been marked closed.
    pub closed: bool,
}
43
/// A bound value for one positional placeholder (`$index`) of a cursor query.
#[derive(Debug, Clone)]
pub struct CursorParam {
    /// 1-based placeholder number; the parameters of one query must cover `1..=N`.
    pub index: u32,
    /// Raw value bytes; inlined as a text literal if valid UTF-8, otherwise as hex.
    pub value: Vec<u8>,
    /// Parameter type name (e.g. `"text"`). NOTE(review): not used when restoring.
    pub type_name: String,
}
54
/// Direction(s) in which a cursor is fetched.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CursorDirection {
    /// Forward-only fetching.
    Forward,
    /// Backward fetching; a restored cursor is declared `SCROLL`.
    Backward,
    /// Fetching in both directions; a restored cursor is declared `SCROLL`.
    Both,
}
65
/// Outcome of an attempt to restore one cursor.
#[derive(Debug, Clone)]
pub struct CursorRestoreResult {
    /// Name of the cursor.
    pub name: String,
    /// Whether the restore succeeded.
    pub success: bool,
    /// Whether the cursor was re-declared (every construction in this module sets it
    /// equal to `success`).
    pub recreated: bool,
    /// Rows skipped to reach the saved position; 0 on failure.
    pub rows_skipped: u64,
    /// Wall-clock duration of the attempt in milliseconds (0 when the attempt failed
    /// before it was timed).
    pub duration_ms: u64,
    /// Error description when the restore failed.
    pub error: Option<String>,
}
82
/// Tracks cursor state per session and re-creates cursors on a target node
/// (presumably after a failover — confirm against callers).
pub struct CursorRestore {
    /// Saved cursor states keyed by cursor name.
    cursors: Arc<RwLock<HashMap<String, CursorState>>>,
    /// Names of the cursors saved for each session.
    session_cursors: Arc<RwLock<HashMap<Uuid, Vec<String>>>>,
    /// Per-session cursor limit enforced by `save_cursor`.
    max_cursors_per_session: usize,
    /// When false, `save_cursor` and `update_position` are no-ops.
    enabled: bool,
    /// Connection template used for restores; host and port are replaced per node.
    /// When `None`, restores do not contact a backend.
    backend_template: Option<BackendConfig>,
    /// Known endpoint of each node, used to build restore connection configs.
    endpoints: Arc<RwLock<HashMap<NodeId, NodeEndpoint>>>,
}
101
102impl CursorRestore {
103 pub fn new() -> Self {
105 Self {
106 cursors: Arc::new(RwLock::new(HashMap::new())),
107 session_cursors: Arc::new(RwLock::new(HashMap::new())),
108 max_cursors_per_session: 100,
109 enabled: true,
110 backend_template: None,
111 endpoints: Arc::new(RwLock::new(HashMap::new())),
112 }
113 }
114
115 pub fn with_max_cursors(mut self, max: usize) -> Self {
117 self.max_cursors_per_session = max;
118 self
119 }
120
121 pub fn with_backend_template(mut self, template: BackendConfig) -> Self {
124 self.backend_template = Some(template);
125 self
126 }
127
128 pub async fn register_endpoint(&self, node_id: NodeId, endpoint: NodeEndpoint) {
131 self.endpoints.write().await.insert(node_id, endpoint);
132 }
133
134 fn build_config(&self, endpoint: &NodeEndpoint) -> Option<BackendConfig> {
135 self.backend_template.as_ref().map(|t| {
136 let mut c = t.clone();
137 c.host = endpoint.host.clone();
138 c.port = endpoint.port;
139 c
140 })
141 }
142
143 pub fn set_enabled(&mut self, enabled: bool) {
145 self.enabled = enabled;
146 }
147
148 pub async fn save_cursor(&self, state: CursorState) -> Result<()> {
150 if !self.enabled {
151 return Ok(());
152 }
153
154 let session_id = state.session_id;
155 let cursor_name = state.name.clone();
156
157 {
159 let session_cursors = self.session_cursors.read().await;
160 if let Some(cursors) = session_cursors.get(&session_id) {
161 if cursors.len() >= self.max_cursors_per_session && !cursors.contains(&cursor_name)
162 {
163 return Err(ProxyError::CursorRestore(format!(
164 "Maximum cursors ({}) per session exceeded",
165 self.max_cursors_per_session
166 )));
167 }
168 }
169 }
170
171 self.cursors
173 .write()
174 .await
175 .insert(cursor_name.clone(), state);
176
177 self.session_cursors
179 .write()
180 .await
181 .entry(session_id)
182 .or_default()
183 .push(cursor_name.clone());
184
185 tracing::debug!("Saved cursor state: {}", cursor_name);
186
187 Ok(())
188 }
189
190 pub async fn update_position(&self, cursor_name: &str, new_position: u64) -> Result<()> {
192 if !self.enabled {
193 return Ok(());
194 }
195
196 let mut cursors = self.cursors.write().await;
197 let cursor = cursors.get_mut(cursor_name).ok_or_else(|| {
198 ProxyError::CursorRestore(format!("Cursor '{}' not found", cursor_name))
199 })?;
200
201 cursor.position = new_position;
202 cursor.last_fetch = Some(chrono::Utc::now());
203
204 Ok(())
205 }
206
207 pub async fn close_cursor(&self, cursor_name: &str) -> Result<()> {
209 let mut cursors = self.cursors.write().await;
210
211 if let Some(cursor) = cursors.get_mut(cursor_name) {
212 cursor.closed = true;
213
214 let session_id = cursor.session_id;
216 drop(cursors);
217
218 self.session_cursors
219 .write()
220 .await
221 .entry(session_id)
222 .and_modify(|v| v.retain(|n| n != cursor_name));
223
224 self.cursors.write().await.remove(cursor_name);
225
226 tracing::debug!("Closed cursor: {}", cursor_name);
227 }
228
229 Ok(())
230 }
231
232 pub async fn get_cursor(&self, cursor_name: &str) -> Option<CursorState> {
234 self.cursors.read().await.get(cursor_name).cloned()
235 }
236
237 pub async fn get_session_cursors(&self, session_id: &Uuid) -> Vec<CursorState> {
239 let session_cursors = self.session_cursors.read().await;
240 let cursor_names = match session_cursors.get(session_id) {
241 Some(names) => names.clone(),
242 None => return vec![],
243 };
244 drop(session_cursors);
245
246 let cursors = self.cursors.read().await;
247 cursor_names
248 .iter()
249 .filter_map(|name| cursors.get(name).cloned())
250 .collect()
251 }
252
253 pub async fn restore_cursor(
255 &self,
256 cursor_name: &str,
257 target_node: NodeId,
258 ) -> Result<CursorRestoreResult> {
259 let start = std::time::Instant::now();
260
261 let cursor = self.get_cursor(cursor_name).await.ok_or_else(|| {
262 ProxyError::CursorRestore(format!("Cursor '{}' not found", cursor_name))
263 })?;
264
265 if cursor.closed {
266 return Err(ProxyError::CursorRestore(format!(
267 "Cursor '{}' is already closed",
268 cursor_name
269 )));
270 }
271
272 let rows_to_skip = cursor.position;
279 let result = self
280 .recreate_cursor(&cursor, target_node, rows_to_skip)
281 .await;
282
283 let duration_ms = start.elapsed().as_millis() as u64;
284
285 match result {
286 Ok(()) => {
287 tracing::info!(
288 "Restored cursor '{}' on node {:?}, skipped {} rows in {}ms",
289 cursor_name,
290 target_node,
291 rows_to_skip,
292 duration_ms
293 );
294
295 Ok(CursorRestoreResult {
296 name: cursor_name.to_string(),
297 success: true,
298 recreated: true,
299 rows_skipped: rows_to_skip,
300 duration_ms,
301 error: None,
302 })
303 }
304 Err(e) => {
305 tracing::error!("Failed to restore cursor '{}': {}", cursor_name, e);
306
307 Ok(CursorRestoreResult {
308 name: cursor_name.to_string(),
309 success: false,
310 recreated: false,
311 rows_skipped: 0,
312 duration_ms,
313 error: Some(e.to_string()),
314 })
315 }
316 }
317 }
318
319 async fn recreate_cursor(
341 &self,
342 cursor: &CursorState,
343 target_node: NodeId,
344 skip_rows: u64,
345 ) -> Result<()> {
346 let endpoint = self.endpoints.read().await.get(&target_node).cloned();
347 let cfg = match endpoint.as_ref().and_then(|e| self.build_config(e)) {
348 Some(c) => c,
349 None => {
350 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
351 return Ok(());
352 }
353 };
354
355 let mut client = BackendClient::connect(&cfg)
356 .await
357 .map_err(|e| ProxyError::CursorRestore(format!("connect: {}", e)))?;
358
359 let interpolated_query = interpolate_cursor_params(&cursor.query, &cursor.parameters)?;
361
362 let scroll = match cursor.direction {
363 CursorDirection::Forward => "NO SCROLL",
364 CursorDirection::Backward | CursorDirection::Both => "SCROLL",
365 };
366 let with_hold = if cursor.with_hold { "WITH HOLD" } else { "" };
367
368 if !cursor.with_hold {
369 client
371 .execute("BEGIN")
372 .await
373 .map_err(|e| ProxyError::CursorRestore(format!("BEGIN: {}", e)))?;
374 }
375
376 let declare = format!(
377 "DECLARE {} {} CURSOR {} FOR {}",
378 quote_ident(&cursor.name),
379 scroll,
380 with_hold,
381 interpolated_query
382 );
383 client
384 .execute(&declare)
385 .await
386 .map_err(|e| ProxyError::CursorRestore(format!("DECLARE: {}", e)))?;
387
388 if skip_rows > 0 {
389 let move_sql = format!(
390 "MOVE FORWARD {} IN {}",
391 skip_rows,
392 quote_ident(&cursor.name)
393 );
394 client
395 .execute(&move_sql)
396 .await
397 .map_err(|e| ProxyError::CursorRestore(format!("MOVE: {}", e)))?;
398 }
399
400 client.close().await;
401 Ok(())
402 }
403
404 pub async fn restore_session_cursors(
406 &self,
407 session_id: &Uuid,
408 target_node: NodeId,
409 ) -> Vec<CursorRestoreResult> {
410 let cursors = self.get_session_cursors(session_id).await;
411 let mut results = Vec::new();
412
413 for cursor in cursors {
414 if !cursor.closed {
415 match self.restore_cursor(&cursor.name, target_node).await {
416 Ok(result) => results.push(result),
417 Err(e) => results.push(CursorRestoreResult {
418 name: cursor.name,
419 success: false,
420 recreated: false,
421 rows_skipped: 0,
422 duration_ms: 0,
423 error: Some(e.to_string()),
424 }),
425 }
426 }
427 }
428
429 results
430 }
431
432 pub async fn clear_session(&self, session_id: &Uuid) {
434 let cursor_names = {
436 let mut session_cursors = self.session_cursors.write().await;
437 session_cursors.remove(session_id).unwrap_or_default()
438 };
439
440 let mut cursors = self.cursors.write().await;
442 for name in cursor_names {
443 cursors.remove(&name);
444 }
445
446 tracing::debug!("Cleared cursors for session {:?}", session_id);
447 }
448
449 pub async fn stats(&self) -> CursorRestoreStats {
451 let cursors = self.cursors.read().await;
452 let sessions = self.session_cursors.read().await;
453
454 CursorRestoreStats {
455 total_cursors: cursors.len(),
456 active_cursors: cursors.values().filter(|c| !c.closed).count(),
457 sessions_with_cursors: sessions.len(),
458 enabled: self.enabled,
459 }
460 }
461}
462
463impl Default for CursorRestore {
464 fn default() -> Self {
465 Self::new()
466 }
467}
468
/// Quotes `name` as a delimited (double-quoted) SQL identifier.
///
/// Every embedded `"` is doubled, so the result always forms exactly one identifier
/// regardless of what `name` contains.
fn quote_ident(name: &str) -> String {
    let escaped = name.replace('"', "\"\"");
    format!("\"{}\"", escaped)
}
484
485fn interpolate_cursor_params(query: &str, params: &[CursorParam]) -> Result<String> {
489 let mut sorted: Vec<&CursorParam> = params.iter().collect();
491 sorted.sort_by_key(|p| p.index);
492 for (i, p) in sorted.iter().enumerate() {
493 if p.index as usize != i + 1 {
494 return Err(ProxyError::CursorRestore(format!(
495 "cursor params are not dense 1..N (got index {} at position {})",
496 p.index,
497 i + 1
498 )));
499 }
500 }
501
502 let literals: Vec<String> = sorted
504 .iter()
505 .map(|p| {
506 match std::str::from_utf8(&p.value) {
508 Ok(s) => {
509 let mut out = String::with_capacity(s.len() + 2);
510 out.push('\'');
511 for ch in s.chars() {
512 if ch == '\'' {
513 out.push_str("''");
514 } else {
515 out.push(ch);
516 }
517 }
518 out.push('\'');
519 out
520 }
521 Err(_) => {
522 let mut out = String::with_capacity(2 + p.value.len() * 2);
523 out.push_str("'\\x");
524 for byte in &p.value {
525 out.push_str(&format!("{:02x}", byte));
526 }
527 out.push('\'');
528 out
529 }
530 }
531 })
532 .collect();
533
534 let bytes = query.as_bytes();
536 let mut out = String::with_capacity(query.len());
537 let mut i = 0;
538 let mut in_string = false;
539 let mut quote = 0u8;
540 while i < bytes.len() {
541 let b = bytes[i];
542 if in_string {
543 out.push(b as char);
544 if b == quote {
545 if i + 1 < bytes.len() && bytes[i + 1] == quote {
546 out.push(quote as char);
547 i += 2;
548 continue;
549 }
550 in_string = false;
551 }
552 i += 1;
553 continue;
554 }
555 if b == b'\'' || b == b'"' {
556 in_string = true;
557 quote = b;
558 out.push(b as char);
559 i += 1;
560 continue;
561 }
562 if b == b'$' && i + 1 < bytes.len() && bytes[i + 1].is_ascii_digit() {
563 let mut j = i + 1;
564 while j < bytes.len() && bytes[j].is_ascii_digit() {
565 j += 1;
566 }
567 let idx: usize = std::str::from_utf8(&bytes[i + 1..j])
568 .unwrap()
569 .parse()
570 .map_err(|_| {
571 ProxyError::CursorRestore(format!(
572 "invalid parameter reference near byte {}",
573 i
574 ))
575 })?;
576 if idx == 0 || idx > literals.len() {
577 return Err(ProxyError::CursorRestore(format!(
578 "parameter ${} out of range (have {})",
579 idx,
580 literals.len()
581 )));
582 }
583 out.push_str(&literals[idx - 1]);
584 i = j;
585 continue;
586 }
587 out.push(b as char);
588 i += 1;
589 }
590 Ok(out)
591}
592
/// Snapshot of the counters reported by [`CursorRestore::stats`].
#[derive(Debug, Clone)]
pub struct CursorRestoreStats {
    /// Number of saved cursor states.
    pub total_cursors: usize,
    /// Saved cursors not marked closed.
    pub active_cursors: usize,
    /// Number of sessions with an entry in the per-session cursor index.
    pub sessions_with_cursors: usize,
    /// Whether cursor tracking is enabled.
    pub enabled: bool,
}
605
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an open, forward-only cursor over `users` at position 0.
    fn make_cursor_state(name: &str, session_id: Uuid) -> CursorState {
        CursorState {
            name: name.to_string(),
            session_id,
            query: "SELECT * FROM users".to_string(),
            parameters: vec![],
            total_rows: Some(1000),
            position: 0,
            scrollable: false,
            with_hold: false,
            direction: CursorDirection::Forward,
            fetch_size: 100,
            created_at: chrono::Utc::now(),
            last_fetch: None,
            closed: false,
        }
    }

    #[tokio::test]
    async fn test_save_cursor() {
        let restore = CursorRestore::new();
        let session_id = Uuid::new_v4();
        let state = make_cursor_state("test_cursor", session_id);

        restore.save_cursor(state).await.unwrap();

        let cursor = restore.get_cursor("test_cursor").await;
        assert!(cursor.is_some());
        assert_eq!(cursor.unwrap().name, "test_cursor");
    }

    #[tokio::test]
    async fn test_save_cursor_enforces_session_limit() {
        let restore = CursorRestore::new().with_max_cursors(1);
        let session_id = Uuid::new_v4();

        restore
            .save_cursor(make_cursor_state("cursor1", session_id))
            .await
            .unwrap();
        // Re-saving an already tracked cursor is allowed at the limit.
        restore
            .save_cursor(make_cursor_state("cursor1", session_id))
            .await
            .unwrap();
        let err = restore
            .save_cursor(make_cursor_state("cursor2", session_id))
            .await
            .unwrap_err();
        assert!(matches!(err, ProxyError::CursorRestore(_)));
    }

    #[tokio::test]
    async fn test_disabled_save_is_noop() {
        let mut restore = CursorRestore::new();
        restore.set_enabled(false);

        restore
            .save_cursor(make_cursor_state("cursor1", Uuid::new_v4()))
            .await
            .unwrap();

        assert!(restore.get_cursor("cursor1").await.is_none());
        assert!(!restore.stats().await.enabled);
    }

    #[tokio::test]
    async fn test_update_position() {
        let restore = CursorRestore::new();
        let session_id = Uuid::new_v4();
        let state = make_cursor_state("test_cursor", session_id);

        restore.save_cursor(state).await.unwrap();
        restore.update_position("test_cursor", 500).await.unwrap();

        let cursor = restore.get_cursor("test_cursor").await.unwrap();
        assert_eq!(cursor.position, 500);
        assert!(cursor.last_fetch.is_some());
    }

    #[tokio::test]
    async fn test_close_cursor() {
        let restore = CursorRestore::new();
        let session_id = Uuid::new_v4();
        let state = make_cursor_state("test_cursor", session_id);

        restore.save_cursor(state).await.unwrap();
        restore.close_cursor("test_cursor").await.unwrap();

        assert!(restore.get_cursor("test_cursor").await.is_none());
    }

    #[tokio::test]
    async fn test_close_unknown_cursor_is_ok() {
        let restore = CursorRestore::new();
        restore.close_cursor("missing").await.unwrap();
    }

    #[tokio::test]
    async fn test_get_session_cursors() {
        let restore = CursorRestore::new();
        let session_id = Uuid::new_v4();

        restore
            .save_cursor(make_cursor_state("cursor1", session_id))
            .await
            .unwrap();
        restore
            .save_cursor(make_cursor_state("cursor2", session_id))
            .await
            .unwrap();

        let cursors = restore.get_session_cursors(&session_id).await;
        assert_eq!(cursors.len(), 2);
    }

    #[tokio::test]
    async fn test_clear_session() {
        let restore = CursorRestore::new();
        let session_id = Uuid::new_v4();

        restore
            .save_cursor(make_cursor_state("cursor1", session_id))
            .await
            .unwrap();
        restore
            .save_cursor(make_cursor_state("cursor2", session_id))
            .await
            .unwrap();

        restore.clear_session(&session_id).await;

        let cursors = restore.get_session_cursors(&session_id).await;
        assert!(cursors.is_empty());
    }

    #[tokio::test]
    async fn test_restore_cursor() {
        let restore = CursorRestore::new();
        let session_id = Uuid::new_v4();
        let mut state = make_cursor_state("test_cursor", session_id);
        state.position = 500;

        restore.save_cursor(state).await.unwrap();

        let target = NodeId::new();
        let result = restore.restore_cursor("test_cursor", target).await.unwrap();

        assert!(result.success);
        assert!(result.recreated);
        assert_eq!(result.rows_skipped, 500);
    }

    #[tokio::test]
    async fn test_restore_unknown_cursor_fails() {
        let restore = CursorRestore::new();
        let err = restore
            .restore_cursor("missing", NodeId::new())
            .await
            .unwrap_err();
        assert!(matches!(err, ProxyError::CursorRestore(_)));
    }

    #[tokio::test]
    async fn test_stats() {
        let restore = CursorRestore::new();
        let session_id = Uuid::new_v4();

        restore
            .save_cursor(make_cursor_state("cursor1", session_id))
            .await
            .unwrap();

        let stats = restore.stats().await;
        assert_eq!(stats.total_cursors, 1);
        assert_eq!(stats.active_cursors, 1);
        assert_eq!(stats.sessions_with_cursors, 1);
    }

    #[test]
    fn test_quote_ident_doubles_embedded_quotes() {
        assert_eq!(quote_ident("users"), "\"users\"");
        assert_eq!(quote_ident(r#"my"cursor"#), r#""my""cursor""#);
        assert_eq!(quote_ident(""), "\"\"");
    }

    #[test]
    fn test_interpolate_cursor_params_no_params() {
        let out = interpolate_cursor_params("SELECT * FROM users", &[]).unwrap();
        assert_eq!(out, "SELECT * FROM users");
    }

    #[test]
    fn test_interpolate_cursor_params_utf8() {
        let params = vec![
            CursorParam {
                index: 1,
                value: b"alice".to_vec(),
                type_name: "text".into(),
            },
            CursorParam {
                index: 2,
                value: b"42".to_vec(),
                type_name: "int4".into(),
            },
        ];
        let out =
            interpolate_cursor_params("SELECT * FROM t WHERE name = $1 AND age = $2", &params)
                .unwrap();
        assert_eq!(out, "SELECT * FROM t WHERE name = 'alice' AND age = '42'");
    }

    #[test]
    fn test_interpolate_cursor_params_escapes_quote() {
        let params = vec![CursorParam {
            index: 1,
            value: b"o'brien".to_vec(),
            type_name: "text".into(),
        }];
        let out = interpolate_cursor_params("SELECT $1", &params).unwrap();
        assert_eq!(out, "SELECT 'o''brien'");
    }

    #[test]
    fn test_interpolate_cursor_params_binary_hex() {
        let params = vec![CursorParam {
            index: 1,
            value: vec![0xDE, 0xAD, 0xBE, 0xEF],
            type_name: "bytea".into(),
        }];
        let out = interpolate_cursor_params("SELECT $1", &params).unwrap();
        assert!(out.starts_with("SELECT '") && out.ends_with('\''));
        assert_eq!(out, "SELECT '\\xdeadbeef'");
    }

    #[test]
    fn test_interpolate_cursor_params_missing_index_rejected() {
        let params = vec![CursorParam {
            index: 2,
            value: b"x".to_vec(),
            type_name: "text".into(),
        }];
        let err = interpolate_cursor_params("SELECT $1", &params).unwrap_err();
        assert!(matches!(err, ProxyError::CursorRestore(_)));
    }

    #[test]
    fn test_interpolate_cursor_params_out_of_range() {
        let params = vec![CursorParam {
            index: 1,
            value: b"a".to_vec(),
            type_name: "text".into(),
        }];
        let err = interpolate_cursor_params("SELECT $2", &params).unwrap_err();
        assert!(matches!(err, ProxyError::CursorRestore(_)));
    }
}
816}