1use super::{NodeEndpoint, NodeId, ProxyError, Result};
7use crate::backend::{BackendClient, BackendConfig};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11use uuid::Uuid;
12
/// Snapshot of a client-declared cursor, kept so it can be re-declared on
/// another node after a failover.
#[derive(Debug, Clone)]
pub struct CursorState {
    /// Cursor name; the key under which the state is tracked, and the
    /// identifier (double-quoted) used when the cursor is re-declared.
    pub name: String,
    /// Client session that owns the cursor.
    pub session_id: Uuid,
    /// Query text of the cursor; `$N` references are filled from `parameters`
    /// when the cursor is re-declared.
    pub query: String,
    /// Values for the `$N` references in `query`.
    pub parameters: Vec<CursorParam>,
    /// Total row count, if known. NOTE(review): not read by `CursorRestore`.
    pub total_rows: Option<u64>,
    /// Number of rows already consumed; a restore skips this many rows with
    /// `MOVE FORWARD`.
    pub position: u64,
    /// Presumably whether the original cursor was declared `SCROLL`.
    /// NOTE(review): confirm against the code that populates it.
    pub scrollable: bool,
    /// Whether the cursor is `WITH HOLD`. Such cursors are re-declared
    /// `WITH HOLD` and without an explicit `BEGIN`.
    pub with_hold: bool,
    /// Fetch direction(s) used; anything other than `Forward` makes the
    /// restored cursor `SCROLL`.
    pub direction: CursorDirection,
    /// Rows per fetch. NOTE(review): not read by `CursorRestore`.
    pub fetch_size: u32,
    /// When the state was first recorded.
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// Time of the last `update_position` call, if any.
    pub last_fetch: Option<chrono::DateTime<chrono::Utc>>,
    /// Closed cursors are rejected by `restore_cursor` and excluded from the
    /// active count in `stats`.
    pub closed: bool,
}
43
/// One bound parameter of a cursor query.
#[derive(Debug, Clone)]
pub struct CursorParam {
    /// 1-based position; the value replaces `$index` in the query. A cursor's
    /// parameters must cover exactly `1..=N`.
    pub index: u32,
    /// Raw value bytes; inlined as a quoted text literal when valid UTF-8,
    /// otherwise as a hex (`'\x…'`) literal.
    pub value: Vec<u8>,
    /// Declared type name. NOTE(review): not used when inlining the value.
    pub type_name: String,
}
54
/// Direction(s) in which a cursor has been fetched.
///
/// Any value other than `Forward` requires the restored cursor to be
/// declared `SCROLL`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CursorDirection {
    /// Only forward fetches.
    Forward,
    /// Backward fetches.
    Backward,
    /// Both forward and backward fetches.
    Both,
}
65
/// Outcome of restoring a single cursor on a target node.
#[derive(Debug, Clone)]
pub struct CursorRestoreResult {
    /// Name of the cursor the restore was attempted for.
    pub name: String,
    /// Whether the restore succeeded.
    pub success: bool,
    /// Whether the cursor was re-declared (set together with `success`).
    pub recreated: bool,
    /// Rows skipped to return to the saved position (0 on failure).
    pub rows_skipped: u64,
    /// Wall-clock time spent on the attempt, in milliseconds.
    pub duration_ms: u64,
    /// Error description when the restore failed.
    pub error: Option<String>,
}
82
/// Tracks client cursors and re-declares them on another node on demand.
pub struct CursorRestore {
    /// Cursor state keyed by cursor name.
    /// NOTE(review): names are not scoped per session in this map, so
    /// same-named cursors from different sessions share one slot.
    cursors: Arc<RwLock<HashMap<String, CursorState>>>,
    /// Names of the cursors tracked for each session.
    session_cursors: Arc<RwLock<HashMap<Uuid, Vec<String>>>>,
    /// Per-session limit enforced by `save_cursor`.
    max_cursors_per_session: usize,
    /// When false, `save_cursor` and `update_position` do nothing.
    enabled: bool,
    /// Connection settings for restore targets; host and port are replaced
    /// per target node.
    backend_template: Option<BackendConfig>,
    /// Network endpoints of restore targets, added via `register_endpoint`.
    endpoints: Arc<RwLock<HashMap<NodeId, NodeEndpoint>>>,
}
101
102impl CursorRestore {
103 pub fn new() -> Self {
105 Self {
106 cursors: Arc::new(RwLock::new(HashMap::new())),
107 session_cursors: Arc::new(RwLock::new(HashMap::new())),
108 max_cursors_per_session: 100,
109 enabled: true,
110 backend_template: None,
111 endpoints: Arc::new(RwLock::new(HashMap::new())),
112 }
113 }
114
115 pub fn with_max_cursors(mut self, max: usize) -> Self {
117 self.max_cursors_per_session = max;
118 self
119 }
120
121 pub fn with_backend_template(mut self, template: BackendConfig) -> Self {
124 self.backend_template = Some(template);
125 self
126 }
127
128 pub async fn register_endpoint(&self, node_id: NodeId, endpoint: NodeEndpoint) {
131 self.endpoints.write().await.insert(node_id, endpoint);
132 }
133
134 fn build_config(&self, endpoint: &NodeEndpoint) -> Option<BackendConfig> {
135 self.backend_template.as_ref().map(|t| {
136 let mut c = t.clone();
137 c.host = endpoint.host.clone();
138 c.port = endpoint.port;
139 c
140 })
141 }
142
143 pub fn set_enabled(&mut self, enabled: bool) {
145 self.enabled = enabled;
146 }
147
148 pub async fn save_cursor(&self, state: CursorState) -> Result<()> {
150 if !self.enabled {
151 return Ok(());
152 }
153
154 let session_id = state.session_id;
155 let cursor_name = state.name.clone();
156
157 {
159 let session_cursors = self.session_cursors.read().await;
160 if let Some(cursors) = session_cursors.get(&session_id) {
161 if cursors.len() >= self.max_cursors_per_session
162 && !cursors.contains(&cursor_name)
163 {
164 return Err(ProxyError::CursorRestore(format!(
165 "Maximum cursors ({}) per session exceeded",
166 self.max_cursors_per_session
167 )));
168 }
169 }
170 }
171
172 self.cursors.write().await.insert(cursor_name.clone(), state);
174
175 self.session_cursors
177 .write()
178 .await
179 .entry(session_id)
180 .or_default()
181 .push(cursor_name.clone());
182
183 tracing::debug!("Saved cursor state: {}", cursor_name);
184
185 Ok(())
186 }
187
188 pub async fn update_position(&self, cursor_name: &str, new_position: u64) -> Result<()> {
190 if !self.enabled {
191 return Ok(());
192 }
193
194 let mut cursors = self.cursors.write().await;
195 let cursor = cursors.get_mut(cursor_name).ok_or_else(|| {
196 ProxyError::CursorRestore(format!("Cursor '{}' not found", cursor_name))
197 })?;
198
199 cursor.position = new_position;
200 cursor.last_fetch = Some(chrono::Utc::now());
201
202 Ok(())
203 }
204
205 pub async fn close_cursor(&self, cursor_name: &str) -> Result<()> {
207 let mut cursors = self.cursors.write().await;
208
209 if let Some(cursor) = cursors.get_mut(cursor_name) {
210 cursor.closed = true;
211
212 let session_id = cursor.session_id;
214 drop(cursors);
215
216 self.session_cursors
217 .write()
218 .await
219 .entry(session_id)
220 .and_modify(|v| v.retain(|n| n != cursor_name));
221
222 self.cursors.write().await.remove(cursor_name);
223
224 tracing::debug!("Closed cursor: {}", cursor_name);
225 }
226
227 Ok(())
228 }
229
230 pub async fn get_cursor(&self, cursor_name: &str) -> Option<CursorState> {
232 self.cursors.read().await.get(cursor_name).cloned()
233 }
234
235 pub async fn get_session_cursors(&self, session_id: &Uuid) -> Vec<CursorState> {
237 let session_cursors = self.session_cursors.read().await;
238 let cursor_names = match session_cursors.get(session_id) {
239 Some(names) => names.clone(),
240 None => return vec![],
241 };
242 drop(session_cursors);
243
244 let cursors = self.cursors.read().await;
245 cursor_names
246 .iter()
247 .filter_map(|name| cursors.get(name).cloned())
248 .collect()
249 }
250
251 pub async fn restore_cursor(
253 &self,
254 cursor_name: &str,
255 target_node: NodeId,
256 ) -> Result<CursorRestoreResult> {
257 let start = std::time::Instant::now();
258
259 let cursor = self.get_cursor(cursor_name).await.ok_or_else(|| {
260 ProxyError::CursorRestore(format!("Cursor '{}' not found", cursor_name))
261 })?;
262
263 if cursor.closed {
264 return Err(ProxyError::CursorRestore(format!(
265 "Cursor '{}' is already closed",
266 cursor_name
267 )));
268 }
269
270 let rows_to_skip = cursor.position;
277 let result = self.recreate_cursor(&cursor, target_node, rows_to_skip).await;
278
279 let duration_ms = start.elapsed().as_millis() as u64;
280
281 match result {
282 Ok(()) => {
283 tracing::info!(
284 "Restored cursor '{}' on node {:?}, skipped {} rows in {}ms",
285 cursor_name,
286 target_node,
287 rows_to_skip,
288 duration_ms
289 );
290
291 Ok(CursorRestoreResult {
292 name: cursor_name.to_string(),
293 success: true,
294 recreated: true,
295 rows_skipped: rows_to_skip,
296 duration_ms,
297 error: None,
298 })
299 }
300 Err(e) => {
301 tracing::error!("Failed to restore cursor '{}': {}", cursor_name, e);
302
303 Ok(CursorRestoreResult {
304 name: cursor_name.to_string(),
305 success: false,
306 recreated: false,
307 rows_skipped: 0,
308 duration_ms,
309 error: Some(e.to_string()),
310 })
311 }
312 }
313 }
314
315 async fn recreate_cursor(
337 &self,
338 cursor: &CursorState,
339 target_node: NodeId,
340 skip_rows: u64,
341 ) -> Result<()> {
342 let endpoint = self.endpoints.read().await.get(&target_node).cloned();
343 let cfg = match endpoint.as_ref().and_then(|e| self.build_config(e)) {
344 Some(c) => c,
345 None => {
346 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
347 return Ok(());
348 }
349 };
350
351 let mut client = BackendClient::connect(&cfg).await.map_err(|e| {
352 ProxyError::CursorRestore(format!("connect: {}", e))
353 })?;
354
355 let interpolated_query = interpolate_cursor_params(&cursor.query, &cursor.parameters)?;
357
358 let scroll = match cursor.direction {
359 CursorDirection::Forward => "NO SCROLL",
360 CursorDirection::Backward | CursorDirection::Both => "SCROLL",
361 };
362 let with_hold = if cursor.with_hold { "WITH HOLD" } else { "" };
363
364 if !cursor.with_hold {
365 client.execute("BEGIN").await.map_err(|e| {
367 ProxyError::CursorRestore(format!("BEGIN: {}", e))
368 })?;
369 }
370
371 let declare = format!(
372 "DECLARE {} {} CURSOR {} FOR {}",
373 quote_ident(&cursor.name),
374 scroll,
375 with_hold,
376 interpolated_query
377 );
378 client.execute(&declare).await.map_err(|e| {
379 ProxyError::CursorRestore(format!("DECLARE: {}", e))
380 })?;
381
382 if skip_rows > 0 {
383 let move_sql = format!(
384 "MOVE FORWARD {} IN {}",
385 skip_rows,
386 quote_ident(&cursor.name)
387 );
388 client.execute(&move_sql).await.map_err(|e| {
389 ProxyError::CursorRestore(format!("MOVE: {}", e))
390 })?;
391 }
392
393 client.close().await;
394 Ok(())
395 }
396
397 pub async fn restore_session_cursors(
399 &self,
400 session_id: &Uuid,
401 target_node: NodeId,
402 ) -> Vec<CursorRestoreResult> {
403 let cursors = self.get_session_cursors(session_id).await;
404 let mut results = Vec::new();
405
406 for cursor in cursors {
407 if !cursor.closed {
408 match self.restore_cursor(&cursor.name, target_node).await {
409 Ok(result) => results.push(result),
410 Err(e) => results.push(CursorRestoreResult {
411 name: cursor.name,
412 success: false,
413 recreated: false,
414 rows_skipped: 0,
415 duration_ms: 0,
416 error: Some(e.to_string()),
417 }),
418 }
419 }
420 }
421
422 results
423 }
424
425 pub async fn clear_session(&self, session_id: &Uuid) {
427 let cursor_names = {
429 let mut session_cursors = self.session_cursors.write().await;
430 session_cursors.remove(session_id).unwrap_or_default()
431 };
432
433 let mut cursors = self.cursors.write().await;
435 for name in cursor_names {
436 cursors.remove(&name);
437 }
438
439 tracing::debug!("Cleared cursors for session {:?}", session_id);
440 }
441
442 pub async fn stats(&self) -> CursorRestoreStats {
444 let cursors = self.cursors.read().await;
445 let sessions = self.session_cursors.read().await;
446
447 CursorRestoreStats {
448 total_cursors: cursors.len(),
449 active_cursors: cursors.values().filter(|c| !c.closed).count(),
450 sessions_with_cursors: sessions.len(),
451 enabled: self.enabled,
452 }
453 }
454}
455
456impl Default for CursorRestore {
457 fn default() -> Self {
458 Self::new()
459 }
460}
461
/// Render `name` as a PostgreSQL delimited (double-quoted) identifier.
///
/// Each embedded `"` is doubled and the result is wrapped in double quotes,
/// so arbitrary cursor names can be spliced into SQL as one identifier.
fn quote_ident(name: &str) -> String {
    let escaped = name.replace('"', "\"\"");
    format!("\"{}\"", escaped)
}
477
478fn interpolate_cursor_params(
482 query: &str,
483 params: &[CursorParam],
484) -> Result<String> {
485 let mut sorted: Vec<&CursorParam> = params.iter().collect();
487 sorted.sort_by_key(|p| p.index);
488 for (i, p) in sorted.iter().enumerate() {
489 if p.index as usize != i + 1 {
490 return Err(ProxyError::CursorRestore(format!(
491 "cursor params are not dense 1..N (got index {} at position {})",
492 p.index,
493 i + 1
494 )));
495 }
496 }
497
498 let literals: Vec<String> = sorted
500 .iter()
501 .map(|p| {
502 match std::str::from_utf8(&p.value) {
504 Ok(s) => {
505 let mut out = String::with_capacity(s.len() + 2);
506 out.push('\'');
507 for ch in s.chars() {
508 if ch == '\'' {
509 out.push_str("''");
510 } else {
511 out.push(ch);
512 }
513 }
514 out.push('\'');
515 out
516 }
517 Err(_) => {
518 let mut out = String::with_capacity(2 + p.value.len() * 2);
519 out.push_str("'\\x");
520 for byte in &p.value {
521 out.push_str(&format!("{:02x}", byte));
522 }
523 out.push('\'');
524 out
525 }
526 }
527 })
528 .collect();
529
530 let bytes = query.as_bytes();
532 let mut out = String::with_capacity(query.len());
533 let mut i = 0;
534 let mut in_string = false;
535 let mut quote = 0u8;
536 while i < bytes.len() {
537 let b = bytes[i];
538 if in_string {
539 out.push(b as char);
540 if b == quote {
541 if i + 1 < bytes.len() && bytes[i + 1] == quote {
542 out.push(quote as char);
543 i += 2;
544 continue;
545 }
546 in_string = false;
547 }
548 i += 1;
549 continue;
550 }
551 if b == b'\'' || b == b'"' {
552 in_string = true;
553 quote = b;
554 out.push(b as char);
555 i += 1;
556 continue;
557 }
558 if b == b'$' && i + 1 < bytes.len() && bytes[i + 1].is_ascii_digit() {
559 let mut j = i + 1;
560 while j < bytes.len() && bytes[j].is_ascii_digit() {
561 j += 1;
562 }
563 let idx: usize = std::str::from_utf8(&bytes[i + 1..j])
564 .unwrap()
565 .parse()
566 .map_err(|_| {
567 ProxyError::CursorRestore(format!(
568 "invalid parameter reference near byte {}",
569 i
570 ))
571 })?;
572 if idx == 0 || idx > literals.len() {
573 return Err(ProxyError::CursorRestore(format!(
574 "parameter ${} out of range (have {})",
575 idx,
576 literals.len()
577 )));
578 }
579 out.push_str(&literals[idx - 1]);
580 i = j;
581 continue;
582 }
583 out.push(b as char);
584 i += 1;
585 }
586 Ok(out)
587}
588
/// Counters reported by `CursorRestore::stats`.
#[derive(Debug, Clone)]
pub struct CursorRestoreStats {
    /// Number of tracked cursors.
    pub total_cursors: usize,
    /// Tracked cursors not marked `closed`.
    pub active_cursors: usize,
    /// Number of sessions with an entry in the per-session cursor index.
    pub sessions_with_cursors: usize,
    /// Whether tracking is currently enabled.
    pub enabled: bool,
}
601
#[cfg(test)]
mod tests {
    use super::*;

    /// Open, forward-only cursor over `users` at position 0.
    fn make_cursor_state(name: &str, session_id: Uuid) -> CursorState {
        CursorState {
            name: name.to_string(),
            session_id,
            query: "SELECT * FROM users".to_string(),
            parameters: vec![],
            total_rows: Some(1000),
            position: 0,
            scrollable: false,
            with_hold: false,
            direction: CursorDirection::Forward,
            fetch_size: 100,
            created_at: chrono::Utc::now(),
            last_fetch: None,
            closed: false,
        }
    }

    #[tokio::test]
    async fn test_save_cursor() {
        let restore = CursorRestore::new();
        let session_id = Uuid::new_v4();
        let state = make_cursor_state("test_cursor", session_id);

        restore.save_cursor(state).await.unwrap();

        let cursor = restore.get_cursor("test_cursor").await;
        assert!(cursor.is_some());
        assert_eq!(cursor.unwrap().name, "test_cursor");
    }

    #[tokio::test]
    async fn test_max_cursors_per_session() {
        let restore = CursorRestore::new().with_max_cursors(1);
        let session_id = Uuid::new_v4();

        restore
            .save_cursor(make_cursor_state("cursor1", session_id))
            .await
            .unwrap();
        // Re-saving an already tracked cursor is not a new cursor.
        assert!(restore
            .save_cursor(make_cursor_state("cursor1", session_id))
            .await
            .is_ok());
        assert!(restore
            .save_cursor(make_cursor_state("cursor2", session_id))
            .await
            .is_err());
    }

    #[tokio::test]
    async fn test_update_position() {
        let restore = CursorRestore::new();
        let session_id = Uuid::new_v4();
        let state = make_cursor_state("test_cursor", session_id);

        restore.save_cursor(state).await.unwrap();
        restore.update_position("test_cursor", 500).await.unwrap();

        let cursor = restore.get_cursor("test_cursor").await.unwrap();
        assert_eq!(cursor.position, 500);
        assert!(cursor.last_fetch.is_some());
    }

    #[tokio::test]
    async fn test_close_cursor() {
        let restore = CursorRestore::new();
        let session_id = Uuid::new_v4();
        let state = make_cursor_state("test_cursor", session_id);

        restore.save_cursor(state).await.unwrap();
        restore.close_cursor("test_cursor").await.unwrap();

        assert!(restore.get_cursor("test_cursor").await.is_none());
    }

    #[tokio::test]
    async fn test_close_unknown_cursor_is_ok() {
        let restore = CursorRestore::new();
        assert!(restore.close_cursor("missing").await.is_ok());
    }

    #[tokio::test]
    async fn test_get_session_cursors() {
        let restore = CursorRestore::new();
        let session_id = Uuid::new_v4();

        restore.save_cursor(make_cursor_state("cursor1", session_id)).await.unwrap();
        restore.save_cursor(make_cursor_state("cursor2", session_id)).await.unwrap();

        let cursors = restore.get_session_cursors(&session_id).await;
        assert_eq!(cursors.len(), 2);
    }

    #[tokio::test]
    async fn test_clear_session() {
        let restore = CursorRestore::new();
        let session_id = Uuid::new_v4();

        restore.save_cursor(make_cursor_state("cursor1", session_id)).await.unwrap();
        restore.save_cursor(make_cursor_state("cursor2", session_id)).await.unwrap();

        restore.clear_session(&session_id).await;

        let cursors = restore.get_session_cursors(&session_id).await;
        assert!(cursors.is_empty());
    }

    #[tokio::test]
    async fn test_restore_cursor() {
        let restore = CursorRestore::new();
        let session_id = Uuid::new_v4();
        let mut state = make_cursor_state("test_cursor", session_id);
        state.position = 500;

        restore.save_cursor(state).await.unwrap();

        let target = NodeId::new();
        let result = restore.restore_cursor("test_cursor", target).await.unwrap();

        assert!(result.success);
        assert!(result.recreated);
        assert_eq!(result.rows_skipped, 500);
    }

    #[tokio::test]
    async fn test_stats() {
        let restore = CursorRestore::new();
        let session_id = Uuid::new_v4();

        restore.save_cursor(make_cursor_state("cursor1", session_id)).await.unwrap();

        let stats = restore.stats().await;
        assert_eq!(stats.total_cursors, 1);
        assert_eq!(stats.active_cursors, 1);
        assert_eq!(stats.sessions_with_cursors, 1);
    }

    #[test]
    fn test_quote_ident_doubles_embedded_quotes() {
        assert_eq!(quote_ident("users"), "\"users\"");
        assert_eq!(quote_ident(r#"my"cursor"#), r#""my""cursor""#);
    }

    #[test]
    fn test_quote_ident_empty() {
        assert_eq!(quote_ident(""), "\"\"");
    }

    #[test]
    fn test_interpolate_cursor_params_no_params() {
        let out = interpolate_cursor_params("SELECT * FROM users", &[]).unwrap();
        assert_eq!(out, "SELECT * FROM users");
    }

    #[test]
    fn test_interpolate_cursor_params_utf8() {
        let params = vec![
            CursorParam {
                index: 1,
                value: b"alice".to_vec(),
                type_name: "text".into(),
            },
            CursorParam {
                index: 2,
                value: b"42".to_vec(),
                type_name: "int4".into(),
            },
        ];
        let out = interpolate_cursor_params(
            "SELECT * FROM t WHERE name = $1 AND age = $2",
            &params,
        )
        .unwrap();
        assert_eq!(out, "SELECT * FROM t WHERE name = 'alice' AND age = '42'");
    }

    #[test]
    fn test_interpolate_cursor_params_escapes_quote() {
        let params = vec![CursorParam {
            index: 1,
            value: b"o'brien".to_vec(),
            type_name: "text".into(),
        }];
        let out = interpolate_cursor_params("SELECT $1", &params).unwrap();
        assert_eq!(out, "SELECT 'o''brien'");
    }

    #[test]
    fn test_interpolate_cursor_params_ignores_params_in_strings() {
        let params = vec![CursorParam {
            index: 1,
            value: b"a".to_vec(),
            type_name: "text".into(),
        }];
        let out = interpolate_cursor_params("SELECT '$1', $1", &params).unwrap();
        assert_eq!(out, "SELECT '$1', 'a'");
    }

    #[test]
    fn test_interpolate_cursor_params_binary_hex() {
        let params = vec![CursorParam {
            index: 1,
            value: vec![0xDE, 0xAD, 0xBE, 0xEF],
            type_name: "bytea".into(),
        }];
        let out = interpolate_cursor_params("SELECT $1", &params).unwrap();
        assert!(out.starts_with("SELECT '") && out.ends_with('\''));
    }

    #[test]
    fn test_interpolate_cursor_params_missing_index_rejected() {
        let params = vec![CursorParam {
            index: 2,
            value: b"x".to_vec(),
            type_name: "text".into(),
        }];
        let err = interpolate_cursor_params("SELECT $1", &params).unwrap_err();
        assert!(matches!(err, ProxyError::CursorRestore(_)));
    }

    #[test]
    fn test_interpolate_cursor_params_out_of_range() {
        let params = vec![CursorParam {
            index: 1,
            value: b"a".to_vec(),
            type_name: "text".into(),
        }];
        let err = interpolate_cursor_params("SELECT $2", &params).unwrap_err();
        assert!(matches!(err, ProxyError::CursorRestore(_)));
    }
}