Skip to main content

heliosdb_proxy/
cursor_restore.rs

1//! Cursor Restore - TR (Transaction Replay)
2//!
3//! Saves and restores cursor state after failover.
4//! Allows resuming result set iteration from the last position.
5
6use super::{NodeEndpoint, NodeId, ProxyError, Result};
7use crate::backend::{BackendClient, BackendConfig};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11use uuid::Uuid;
12
/// Snapshot of a client cursor, captured so it can be re-declared and
/// re-positioned on another node after failover.
#[derive(Debug, Clone)]
pub struct CursorState {
    /// Cursor name. `CursorRestore` keys its cursor map by this name alone
    /// (names are not namespaced by session).
    pub name: String,
    /// Session that owns the cursor.
    pub session_id: Uuid,
    /// Original query text; `$N` placeholders in it are filled from
    /// `parameters` when the cursor is restored.
    pub query: String,
    /// Values for the `$N` placeholders in `query`.
    pub parameters: Vec<CursorParam>,
    /// Total rows in the result set (if known).
    pub total_rows: Option<u64>,
    /// Rows consumed so far; restore moves the re-declared cursor forward
    /// by this many rows.
    pub position: u64,
    /// Whether the cursor was declared scrollable.
    pub scrollable: bool,
    /// Whether the cursor was declared `WITH HOLD`; restore re-declares it
    /// `WITH HOLD` without an enclosing `BEGIN`.
    pub with_hold: bool,
    /// Fetch direction; `Backward` and `Both` cause restore to re-declare
    /// the cursor with `SCROLL`.
    pub direction: CursorDirection,
    /// Fetch size (rows per fetch).
    pub fetch_size: u32,
    /// Creation timestamp (UTC).
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// Time of the last position update, if any (set by `update_position`).
    pub last_fetch: Option<chrono::DateTime<chrono::Utc>>,
    /// Whether the cursor is closed; `restore_cursor` refuses closed cursors.
    pub closed: bool,
}
43
/// One bound parameter of a cursor's query.
#[derive(Debug, Clone)]
pub struct CursorParam {
    /// Parameter index (1-based), matching `$N` in the query. Restore
    /// requires a cursor's parameter indices to be exactly `1..=N`.
    pub index: u32,
    /// Raw parameter value bytes; rendered as a SQL literal and substituted
    /// into the query text at restore time.
    pub value: Vec<u8>,
    /// Parameter type name (recorded only; not consulted during substitution).
    pub type_name: String,
}
54
/// Direction(s) in which a cursor is fetched.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CursorDirection {
    /// Forward only.
    Forward,
    /// Backward only (requires a scrollable cursor).
    Backward,
    /// Both directions (requires a scrollable cursor).
    Both,
}
65
/// Outcome of restoring a single cursor.
#[derive(Debug, Clone)]
pub struct CursorRestoreResult {
    /// Name of the cursor that was restored.
    pub name: String,
    /// Whether restoration succeeded.
    pub success: bool,
    /// Whether the cursor was re-declared on the target node (vs reopened);
    /// set together with `success` by the current restore paths.
    pub recreated: bool,
    /// Rows skipped to reach the saved position (0 on failure).
    pub rows_skipped: u64,
    /// Wall-clock restoration time in milliseconds.
    pub duration_ms: u64,
    /// Error message when the restore failed.
    pub error: Option<String>,
}
82
/// Cursor Restore Manager: tracks client cursor state per session and
/// re-creates cursors on another node after failover.
pub struct CursorRestore {
    /// Saved cursor states, keyed by cursor name.
    cursors: Arc<RwLock<HashMap<String, CursorState>>>,
    /// Session -> names of the cursors tracked for that session.
    session_cursors: Arc<RwLock<HashMap<Uuid, Vec<String>>>>,
    /// Maximum cursors per session (enforced by `save_cursor`).
    max_cursors_per_session: usize,
    /// Whether cursor tracking/restore is enabled.
    enabled: bool,
    /// Optional backend-connection template. Host/port are swapped to
    /// the target node's endpoint at `restore_cursor` time. When `None`,
    /// `recreate_cursor` returns success without opening a connection —
    /// the pre-T0-TR6 skeleton path used by unit tests.
    backend_template: Option<BackendConfig>,
    /// Per-node endpoints for resolving `target_node` → host:port.
    endpoints: Arc<RwLock<HashMap<NodeId, NodeEndpoint>>>,
}
101
102impl CursorRestore {
103    /// Create a new cursor restore manager
104    pub fn new() -> Self {
105        Self {
106            cursors: Arc::new(RwLock::new(HashMap::new())),
107            session_cursors: Arc::new(RwLock::new(HashMap::new())),
108            max_cursors_per_session: 100,
109            enabled: true,
110            backend_template: None,
111            endpoints: Arc::new(RwLock::new(HashMap::new())),
112        }
113    }
114
115    /// Configure max cursors per session
116    pub fn with_max_cursors(mut self, max: usize) -> Self {
117        self.max_cursors_per_session = max;
118        self
119    }
120
121    /// Attach a backend-connection template so real cursor recreation
122    /// can run `DECLARE` / `MOVE` against the target node.
123    pub fn with_backend_template(mut self, template: BackendConfig) -> Self {
124        self.backend_template = Some(template);
125        self
126    }
127
128    /// Register an endpoint for a node so restore can resolve where to
129    /// re-declare the cursor.
130    pub async fn register_endpoint(&self, node_id: NodeId, endpoint: NodeEndpoint) {
131        self.endpoints.write().await.insert(node_id, endpoint);
132    }
133
134    fn build_config(&self, endpoint: &NodeEndpoint) -> Option<BackendConfig> {
135        self.backend_template.as_ref().map(|t| {
136            let mut c = t.clone();
137            c.host = endpoint.host.clone();
138            c.port = endpoint.port;
139            c
140        })
141    }
142
143    /// Enable or disable cursor restore
144    pub fn set_enabled(&mut self, enabled: bool) {
145        self.enabled = enabled;
146    }
147
148    /// Save cursor state
149    pub async fn save_cursor(&self, state: CursorState) -> Result<()> {
150        if !self.enabled {
151            return Ok(());
152        }
153
154        let session_id = state.session_id;
155        let cursor_name = state.name.clone();
156
157        // Check session cursor limit
158        {
159            let session_cursors = self.session_cursors.read().await;
160            if let Some(cursors) = session_cursors.get(&session_id) {
161                if cursors.len() >= self.max_cursors_per_session
162                    && !cursors.contains(&cursor_name)
163                {
164                    return Err(ProxyError::CursorRestore(format!(
165                        "Maximum cursors ({}) per session exceeded",
166                        self.max_cursors_per_session
167                    )));
168                }
169            }
170        }
171
172        // Save cursor
173        self.cursors.write().await.insert(cursor_name.clone(), state);
174
175        // Update session mapping
176        self.session_cursors
177            .write()
178            .await
179            .entry(session_id)
180            .or_default()
181            .push(cursor_name.clone());
182
183        tracing::debug!("Saved cursor state: {}", cursor_name);
184
185        Ok(())
186    }
187
188    /// Update cursor position
189    pub async fn update_position(&self, cursor_name: &str, new_position: u64) -> Result<()> {
190        if !self.enabled {
191            return Ok(());
192        }
193
194        let mut cursors = self.cursors.write().await;
195        let cursor = cursors.get_mut(cursor_name).ok_or_else(|| {
196            ProxyError::CursorRestore(format!("Cursor '{}' not found", cursor_name))
197        })?;
198
199        cursor.position = new_position;
200        cursor.last_fetch = Some(chrono::Utc::now());
201
202        Ok(())
203    }
204
205    /// Close a cursor
206    pub async fn close_cursor(&self, cursor_name: &str) -> Result<()> {
207        let mut cursors = self.cursors.write().await;
208
209        if let Some(cursor) = cursors.get_mut(cursor_name) {
210            cursor.closed = true;
211
212            // Remove from session mapping
213            let session_id = cursor.session_id;
214            drop(cursors);
215
216            self.session_cursors
217                .write()
218                .await
219                .entry(session_id)
220                .and_modify(|v| v.retain(|n| n != cursor_name));
221
222            self.cursors.write().await.remove(cursor_name);
223
224            tracing::debug!("Closed cursor: {}", cursor_name);
225        }
226
227        Ok(())
228    }
229
230    /// Get cursor state
231    pub async fn get_cursor(&self, cursor_name: &str) -> Option<CursorState> {
232        self.cursors.read().await.get(cursor_name).cloned()
233    }
234
235    /// Get all cursors for a session
236    pub async fn get_session_cursors(&self, session_id: &Uuid) -> Vec<CursorState> {
237        let session_cursors = self.session_cursors.read().await;
238        let cursor_names = match session_cursors.get(session_id) {
239            Some(names) => names.clone(),
240            None => return vec![],
241        };
242        drop(session_cursors);
243
244        let cursors = self.cursors.read().await;
245        cursor_names
246            .iter()
247            .filter_map(|name| cursors.get(name).cloned())
248            .collect()
249    }
250
251    /// Restore a cursor on a new node
252    pub async fn restore_cursor(
253        &self,
254        cursor_name: &str,
255        target_node: NodeId,
256    ) -> Result<CursorRestoreResult> {
257        let start = std::time::Instant::now();
258
259        let cursor = self.get_cursor(cursor_name).await.ok_or_else(|| {
260            ProxyError::CursorRestore(format!("Cursor '{}' not found", cursor_name))
261        })?;
262
263        if cursor.closed {
264            return Err(ProxyError::CursorRestore(format!(
265                "Cursor '{}' is already closed",
266                cursor_name
267            )));
268        }
269
270        // TODO: Implement actual cursor restoration
271        // 1. Re-execute the query on the new node
272        // 2. Create cursor with same name
273        // 3. Skip to the saved position
274        // 4. Update internal state
275
276        let rows_to_skip = cursor.position;
277        let result = self.recreate_cursor(&cursor, target_node, rows_to_skip).await;
278
279        let duration_ms = start.elapsed().as_millis() as u64;
280
281        match result {
282            Ok(()) => {
283                tracing::info!(
284                    "Restored cursor '{}' on node {:?}, skipped {} rows in {}ms",
285                    cursor_name,
286                    target_node,
287                    rows_to_skip,
288                    duration_ms
289                );
290
291                Ok(CursorRestoreResult {
292                    name: cursor_name.to_string(),
293                    success: true,
294                    recreated: true,
295                    rows_skipped: rows_to_skip,
296                    duration_ms,
297                    error: None,
298                })
299            }
300            Err(e) => {
301                tracing::error!("Failed to restore cursor '{}': {}", cursor_name, e);
302
303                Ok(CursorRestoreResult {
304                    name: cursor_name.to_string(),
305                    success: false,
306                    recreated: false,
307                    rows_skipped: 0,
308                    duration_ms,
309                    error: Some(e.to_string()),
310                })
311            }
312        }
313    }
314
315    /// Recreate a cursor on the target node via `DECLARE` + `MOVE`.
316    ///
317    /// Emits SQL roughly equivalent to:
318    ///
319    /// ```sql
320    /// BEGIN;
321    /// DECLARE <name> [SCROLL] [NO SCROLL] CURSOR [WITH HOLD] FOR <query>;
322    /// MOVE FORWARD <skip_rows> IN <name>;
323    /// ```
324    ///
325    /// Parameters from `CursorState.parameters` are interpolated into
326    /// `<query>` as text-format literals — we don't use the extended
327    /// protocol for replay, matching the T0-TR5 design choice.
328    ///
329    /// The BEGIN is only emitted when the cursor is NOT `with_hold`; a
330    /// `WITH HOLD` cursor persists across commits and does not need an
331    /// enclosing transaction.
332    ///
333    /// When no backend template / endpoint is configured, returns
334    /// `Ok(())` after a short delay — the skeleton path retained for
335    /// unit tests that don't want to open real sockets.
336    async fn recreate_cursor(
337        &self,
338        cursor: &CursorState,
339        target_node: NodeId,
340        skip_rows: u64,
341    ) -> Result<()> {
342        let endpoint = self.endpoints.read().await.get(&target_node).cloned();
343        let cfg = match endpoint.as_ref().and_then(|e| self.build_config(e)) {
344            Some(c) => c,
345            None => {
346                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
347                return Ok(());
348            }
349        };
350
351        let mut client = BackendClient::connect(&cfg).await.map_err(|e| {
352            ProxyError::CursorRestore(format!("connect: {}", e))
353        })?;
354
355        // Substitute $N parameters in the query with text literals.
356        let interpolated_query = interpolate_cursor_params(&cursor.query, &cursor.parameters)?;
357
358        let scroll = match cursor.direction {
359            CursorDirection::Forward => "NO SCROLL",
360            CursorDirection::Backward | CursorDirection::Both => "SCROLL",
361        };
362        let with_hold = if cursor.with_hold { "WITH HOLD" } else { "" };
363
364        if !cursor.with_hold {
365            // Non-HOLD cursors require an enclosing transaction.
366            client.execute("BEGIN").await.map_err(|e| {
367                ProxyError::CursorRestore(format!("BEGIN: {}", e))
368            })?;
369        }
370
371        let declare = format!(
372            "DECLARE {} {} CURSOR {} FOR {}",
373            quote_ident(&cursor.name),
374            scroll,
375            with_hold,
376            interpolated_query
377        );
378        client.execute(&declare).await.map_err(|e| {
379            ProxyError::CursorRestore(format!("DECLARE: {}", e))
380        })?;
381
382        if skip_rows > 0 {
383            let move_sql = format!(
384                "MOVE FORWARD {} IN {}",
385                skip_rows,
386                quote_ident(&cursor.name)
387            );
388            client.execute(&move_sql).await.map_err(|e| {
389                ProxyError::CursorRestore(format!("MOVE: {}", e))
390            })?;
391        }
392
393        client.close().await;
394        Ok(())
395    }
396
397    /// Restore all cursors for a session
398    pub async fn restore_session_cursors(
399        &self,
400        session_id: &Uuid,
401        target_node: NodeId,
402    ) -> Vec<CursorRestoreResult> {
403        let cursors = self.get_session_cursors(session_id).await;
404        let mut results = Vec::new();
405
406        for cursor in cursors {
407            if !cursor.closed {
408                match self.restore_cursor(&cursor.name, target_node).await {
409                    Ok(result) => results.push(result),
410                    Err(e) => results.push(CursorRestoreResult {
411                        name: cursor.name,
412                        success: false,
413                        recreated: false,
414                        rows_skipped: 0,
415                        duration_ms: 0,
416                        error: Some(e.to_string()),
417                    }),
418                }
419            }
420        }
421
422        results
423    }
424
425    /// Clear all cursors for a session
426    pub async fn clear_session(&self, session_id: &Uuid) {
427        // Get cursor names
428        let cursor_names = {
429            let mut session_cursors = self.session_cursors.write().await;
430            session_cursors.remove(session_id).unwrap_or_default()
431        };
432
433        // Remove cursors
434        let mut cursors = self.cursors.write().await;
435        for name in cursor_names {
436            cursors.remove(&name);
437        }
438
439        tracing::debug!("Cleared cursors for session {:?}", session_id);
440    }
441
442    /// Get statistics
443    pub async fn stats(&self) -> CursorRestoreStats {
444        let cursors = self.cursors.read().await;
445        let sessions = self.session_cursors.read().await;
446
447        CursorRestoreStats {
448            total_cursors: cursors.len(),
449            active_cursors: cursors.values().filter(|c| !c.closed).count(),
450            sessions_with_cursors: sessions.len(),
451            enabled: self.enabled,
452        }
453    }
454}
455
456impl Default for CursorRestore {
457    fn default() -> Self {
458        Self::new()
459    }
460}
461
/// Render `name` as a double-quoted PostgreSQL identifier (table, cursor or
/// column name). Every embedded `"` is escaped by doubling it, e.g.
/// `my"cur` becomes `"my""cur"`.
fn quote_ident(name: &str) -> String {
    let escaped = name.replace('"', "\"\"");
    format!("\"{}\"", escaped)
}
477
478/// Substitute `$N` placeholders in a cursor's declared query with
479/// text-format literals taken from `params`. Reuses PG's simple-query
480/// convention — single-quoted strings with doubled quotes for escape.
481fn interpolate_cursor_params(
482    query: &str,
483    params: &[CursorParam],
484) -> Result<String> {
485    // Sort params by index (1-based) to match $N ordering.
486    let mut sorted: Vec<&CursorParam> = params.iter().collect();
487    sorted.sort_by_key(|p| p.index);
488    for (i, p) in sorted.iter().enumerate() {
489        if p.index as usize != i + 1 {
490            return Err(ProxyError::CursorRestore(format!(
491                "cursor params are not dense 1..N (got index {} at position {})",
492                p.index,
493                i + 1
494            )));
495        }
496    }
497
498    // Build the replacements table as text literals.
499    let literals: Vec<String> = sorted
500        .iter()
501        .map(|p| {
502            // Try UTF-8; fall back to hex-escaped bytea text literal.
503            match std::str::from_utf8(&p.value) {
504                Ok(s) => {
505                    let mut out = String::with_capacity(s.len() + 2);
506                    out.push('\'');
507                    for ch in s.chars() {
508                        if ch == '\'' {
509                            out.push_str("''");
510                        } else {
511                            out.push(ch);
512                        }
513                    }
514                    out.push('\'');
515                    out
516                }
517                Err(_) => {
518                    let mut out = String::with_capacity(2 + p.value.len() * 2);
519                    out.push_str("'\\x");
520                    for byte in &p.value {
521                        out.push_str(&format!("{:02x}", byte));
522                    }
523                    out.push('\'');
524                    out
525                }
526            }
527        })
528        .collect();
529
530    // Walk the query, replacing $N tokens outside of string literals.
531    let bytes = query.as_bytes();
532    let mut out = String::with_capacity(query.len());
533    let mut i = 0;
534    let mut in_string = false;
535    let mut quote = 0u8;
536    while i < bytes.len() {
537        let b = bytes[i];
538        if in_string {
539            out.push(b as char);
540            if b == quote {
541                if i + 1 < bytes.len() && bytes[i + 1] == quote {
542                    out.push(quote as char);
543                    i += 2;
544                    continue;
545                }
546                in_string = false;
547            }
548            i += 1;
549            continue;
550        }
551        if b == b'\'' || b == b'"' {
552            in_string = true;
553            quote = b;
554            out.push(b as char);
555            i += 1;
556            continue;
557        }
558        if b == b'$' && i + 1 < bytes.len() && bytes[i + 1].is_ascii_digit() {
559            let mut j = i + 1;
560            while j < bytes.len() && bytes[j].is_ascii_digit() {
561                j += 1;
562            }
563            let idx: usize = std::str::from_utf8(&bytes[i + 1..j])
564                .unwrap()
565                .parse()
566                .map_err(|_| {
567                    ProxyError::CursorRestore(format!(
568                        "invalid parameter reference near byte {}",
569                        i
570                    ))
571                })?;
572            if idx == 0 || idx > literals.len() {
573                return Err(ProxyError::CursorRestore(format!(
574                    "parameter ${} out of range (have {})",
575                    idx,
576                    literals.len()
577                )));
578            }
579            out.push_str(&literals[idx - 1]);
580            i = j;
581            continue;
582        }
583        out.push(b as char);
584        i += 1;
585    }
586    Ok(out)
587}
588
/// Cursor restore statistics: point-in-time counters reported by
/// `CursorRestore::stats`.
#[derive(Debug, Clone)]
pub struct CursorRestoreStats {
    /// Total cursor states currently tracked.
    pub total_cursors: usize,
    /// Tracked cursors whose `closed` flag is false. `close_cursor` removes
    /// entries outright, so this normally equals `total_cursors`.
    pub active_cursors: usize,
    /// Number of sessions with an entry in the session -> cursor map.
    pub sessions_with_cursors: usize,
    /// Whether cursor restore is enabled.
    pub enabled: bool,
}
601
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an open, forward-only cursor state over `SELECT * FROM users`.
    fn make_cursor_state(name: &str, session_id: Uuid) -> CursorState {
        CursorState {
            name: name.to_string(),
            session_id,
            query: "SELECT * FROM users".to_string(),
            parameters: vec![],
            total_rows: Some(1000),
            position: 0,
            scrollable: false,
            with_hold: false,
            direction: CursorDirection::Forward,
            fetch_size: 100,
            created_at: chrono::Utc::now(),
            last_fetch: None,
            closed: false,
        }
    }

    #[tokio::test]
    async fn test_save_cursor() {
        let restore = CursorRestore::new();
        let session_id = Uuid::new_v4();
        let state = make_cursor_state("test_cursor", session_id);

        restore.save_cursor(state).await.unwrap();

        let cursor = restore.get_cursor("test_cursor").await;
        assert!(cursor.is_some());
        assert_eq!(cursor.unwrap().name, "test_cursor");
    }

    #[tokio::test]
    async fn test_save_cursor_respects_session_limit() {
        let restore = CursorRestore::new().with_max_cursors(1);
        let session_id = Uuid::new_v4();

        restore
            .save_cursor(make_cursor_state("cursor1", session_id))
            .await
            .unwrap();
        // A second, different cursor exceeds the limit...
        let err = restore
            .save_cursor(make_cursor_state("cursor2", session_id))
            .await
            .unwrap_err();
        assert!(matches!(err, ProxyError::CursorRestore(_)));
        // ...but re-saving the cursor that is already tracked is allowed.
        restore
            .save_cursor(make_cursor_state("cursor1", session_id))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_disabled_save_is_noop() {
        let mut restore = CursorRestore::new();
        restore.set_enabled(false);
        let session_id = Uuid::new_v4();

        restore
            .save_cursor(make_cursor_state("cursor1", session_id))
            .await
            .unwrap();

        assert!(restore.get_cursor("cursor1").await.is_none());
        assert!(!restore.stats().await.enabled);
    }

    #[tokio::test]
    async fn test_update_position() {
        let restore = CursorRestore::new();
        let session_id = Uuid::new_v4();
        let state = make_cursor_state("test_cursor", session_id);

        restore.save_cursor(state).await.unwrap();
        restore.update_position("test_cursor", 500).await.unwrap();

        let cursor = restore.get_cursor("test_cursor").await.unwrap();
        assert_eq!(cursor.position, 500);
        assert!(cursor.last_fetch.is_some());
    }

    #[tokio::test]
    async fn test_close_cursor() {
        let restore = CursorRestore::new();
        let session_id = Uuid::new_v4();
        let state = make_cursor_state("test_cursor", session_id);

        restore.save_cursor(state).await.unwrap();
        restore.close_cursor("test_cursor").await.unwrap();

        assert!(restore.get_cursor("test_cursor").await.is_none());
    }

    #[tokio::test]
    async fn test_close_unknown_cursor_is_ok() {
        let restore = CursorRestore::new();
        restore.close_cursor("missing").await.unwrap();
    }

    #[tokio::test]
    async fn test_get_session_cursors() {
        let restore = CursorRestore::new();
        let session_id = Uuid::new_v4();

        restore.save_cursor(make_cursor_state("cursor1", session_id)).await.unwrap();
        restore.save_cursor(make_cursor_state("cursor2", session_id)).await.unwrap();

        let cursors = restore.get_session_cursors(&session_id).await;
        assert_eq!(cursors.len(), 2);
    }

    #[tokio::test]
    async fn test_clear_session() {
        let restore = CursorRestore::new();
        let session_id = Uuid::new_v4();

        restore.save_cursor(make_cursor_state("cursor1", session_id)).await.unwrap();
        restore.save_cursor(make_cursor_state("cursor2", session_id)).await.unwrap();

        restore.clear_session(&session_id).await;

        let cursors = restore.get_session_cursors(&session_id).await;
        assert!(cursors.is_empty());
    }

    #[tokio::test]
    async fn test_restore_cursor() {
        let restore = CursorRestore::new();
        let session_id = Uuid::new_v4();
        let mut state = make_cursor_state("test_cursor", session_id);
        state.position = 500;

        restore.save_cursor(state).await.unwrap();

        // No backend template: exercises the no-network skeleton path.
        let target = NodeId::new();
        let result = restore.restore_cursor("test_cursor", target).await.unwrap();

        assert!(result.success);
        assert!(result.recreated);
        assert_eq!(result.rows_skipped, 500);
    }

    #[tokio::test]
    async fn test_restore_unknown_cursor_errors() {
        let restore = CursorRestore::new();
        let err = restore
            .restore_cursor("missing", NodeId::new())
            .await
            .unwrap_err();
        assert!(matches!(err, ProxyError::CursorRestore(_)));
    }

    #[tokio::test]
    async fn test_stats() {
        let restore = CursorRestore::new();
        let session_id = Uuid::new_v4();

        restore.save_cursor(make_cursor_state("cursor1", session_id)).await.unwrap();

        let stats = restore.stats().await;
        assert_eq!(stats.total_cursors, 1);
        assert_eq!(stats.active_cursors, 1);
        assert_eq!(stats.sessions_with_cursors, 1);
    }

    #[test]
    fn test_quote_ident_doubles_embedded_quotes() {
        assert_eq!(quote_ident("users"), "\"users\"");
        assert_eq!(quote_ident(r#"my"cursor"#), r#""my""cursor""#);
    }

    #[test]
    fn test_interpolate_cursor_params_no_params() {
        let out = interpolate_cursor_params("SELECT * FROM users", &[]).unwrap();
        assert_eq!(out, "SELECT * FROM users");
    }

    #[test]
    fn test_interpolate_cursor_params_utf8() {
        let params = vec![
            CursorParam {
                index: 1,
                value: b"alice".to_vec(),
                type_name: "text".into(),
            },
            CursorParam {
                index: 2,
                value: b"42".to_vec(),
                type_name: "int4".into(),
            },
        ];
        let out = interpolate_cursor_params(
            "SELECT * FROM t WHERE name = $1 AND age = $2",
            &params,
        )
        .unwrap();
        assert_eq!(out, "SELECT * FROM t WHERE name = 'alice' AND age = '42'");
    }

    #[test]
    fn test_interpolate_cursor_params_escapes_quote() {
        let params = vec![CursorParam {
            index: 1,
            value: b"o'brien".to_vec(),
            type_name: "text".into(),
        }];
        let out = interpolate_cursor_params("SELECT $1", &params).unwrap();
        assert_eq!(out, "SELECT 'o''brien'");
    }

    #[test]
    fn test_interpolate_cursor_params_binary_hex() {
        let params = vec![CursorParam {
            index: 1,
            value: vec![0xDE, 0xAD, 0xBE, 0xEF],
            type_name: "bytea".into(),
        }];
        let out = interpolate_cursor_params("SELECT $1", &params).unwrap();
        // 0xDE 0xAD decodes as a two-byte UTF-8 sequence, but 0xBE is a stray
        // continuation byte, so the value is NOT valid UTF-8 and must be
        // rendered as a hex bytea literal.
        assert!(std::str::from_utf8(&params[0].value).is_err());
        assert_eq!(out, r"SELECT '\xdeadbeef'");
    }

    #[test]
    fn test_interpolate_cursor_params_missing_index_rejected() {
        let params = vec![CursorParam {
            index: 2, // should be 1
            value: b"x".to_vec(),
            type_name: "text".into(),
        }];
        let err = interpolate_cursor_params("SELECT $1", &params).unwrap_err();
        assert!(matches!(err, ProxyError::CursorRestore(_)));
    }

    #[test]
    fn test_interpolate_cursor_params_out_of_range() {
        let params = vec![CursorParam {
            index: 1,
            value: b"a".to_vec(),
            type_name: "text".into(),
        }];
        let err = interpolate_cursor_params("SELECT $2", &params).unwrap_err();
        assert!(matches!(err, ProxyError::CursorRestore(_)));
    }
}