sqlx-data-params 0.1.0

Data parameter utilities for SQLx-Data - advanced pagination (Serial/Slice/Cursor), dynamic filtering, sorting, and type-safe query parameters for database operations
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
use crate::FilterValue;
use crate::{IntoParams, Params};

#[cfg(feature = "json")]
use serde::{Deserialize, Serialize};

// ================================================================================================
// CORE TYPES
// ================================================================================================

/// A single cursor component wrapping one [`CursorValue`].
///
/// A [`Cursor`] stores its entries in order; entries are appended by
/// `Cursor::and_field` or supplied as a whole to `Cursor::new_multi`.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "json", derive(Serialize, Deserialize))]
pub struct CursorEntry {
    /// The value captured for this position of the cursor.
    pub value: CursorValue,
}

/// Client-facing cursor data - contains only the serializable data that goes to the client.
///
/// With the `json` feature enabled, empty/unset fields are omitted from the serialized
/// form. `entries` is additionally marked `default`, so a token produced from an empty
/// cursor (where `entries` was skipped on output) can be deserialized again instead of
/// failing with a missing-field error.
#[derive(Clone, Debug, Default, PartialEq)]
#[cfg_attr(feature = "json", derive(Serialize, Deserialize))]
pub struct Cursor {
    /// Ordered cursor values, one [`CursorEntry`] per value.
    #[cfg_attr(
        feature = "json",
        serde(default, skip_serializing_if = "Vec::is_empty")
    )]
    pub entries: Vec<CursorEntry>,
    /// Optional version tag. The constructors in this module leave it as `None`.
    #[cfg_attr(feature = "json", serde(skip_serializing_if = "Option::is_none"))]
    pub version: Option<u8>,
    /// Optional fingerprint. The constructors in this module leave it as `None`.
    #[cfg_attr(feature = "json", serde(skip_serializing_if = "Option::is_none"))]
    pub fingerprint: Option<u64>,
}

/// Server-side cursor parameters: the decoded cursor values plus processing metadata.
///
/// Unlike [`Cursor`], this type has no serde derives. The derived `Default` yields
/// no direction, no values and no error.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct CursorParams {
    /// Paging direction relative to the cursor position, if one was set.
    pub direction: Option<CursorDirection>,
    /// Decoded cursor values - used when building queries.
    pub values: Vec<FilterValue>,
    /// Error message recorded when cursor processing failed, if any.
    pub error: Option<String>,
}

/// Direction of a cursor-based page request.
///
/// The enum is fieldless, so it is `Copy` and can be used as a hash-map/set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CursorDirection {
    /// Page positioned after the cursor.
    After,
    /// Page positioned before the cursor.
    Before,
}

/// Scalar value stored inside a [`CursorEntry`].
///
/// With the `json` feature the enum is (de)serialized `untagged`, i.e. as the bare
/// value without a variant name.
// NOTE(review): serde tries untagged variants in declaration order, so a JSON integer
// that fits in `i64` presumably comes back as `Int` even if it was written as `UInt`
// - confirm against the serde documentation before relying on the variant.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "json", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "json", serde(untagged))]
pub enum CursorValue {
    /// Signed 64-bit integer.
    Int(i64),
    /// Unsigned 64-bit integer.
    UInt(u64),
    /// 64-bit floating point number.
    Float(f64),
    /// Boolean flag.
    Bool(bool),
    /// Owned UTF-8 text.
    String(String),
}

/// Stores an `i64` as [`CursorValue::Int`].
impl From<i64> for CursorValue {
    fn from(int: i64) -> Self {
        Self::Int(int)
    }
}

/// Stores a `u64` as [`CursorValue::UInt`].
impl From<u64> for CursorValue {
    fn from(uint: u64) -> Self {
        Self::UInt(uint)
    }
}

/// Stores an `f64` as [`CursorValue::Float`].
impl From<f64> for CursorValue {
    fn from(float: f64) -> Self {
        Self::Float(float)
    }
}

/// Stores a `bool` as [`CursorValue::Bool`].
impl From<bool> for CursorValue {
    fn from(flag: bool) -> Self {
        Self::Bool(flag)
    }
}

/// Takes ownership of a `String` as [`CursorValue::String`].
impl From<String> for CursorValue {
    fn from(text: String) -> Self {
        Self::String(text)
    }
}

/// Copies a string slice into an owned [`CursorValue::String`].
impl From<&str> for CursorValue {
    fn from(text: &str) -> Self {
        Self::String(String::from(text))
    }
}

/// Widens an `i32` losslessly into [`CursorValue::Int`].
impl From<i32> for CursorValue {
    fn from(int: i32) -> Self {
        Self::Int(i64::from(int))
    }
}

/// Widens a `u32` losslessly into [`CursorValue::UInt`].
impl From<u32> for CursorValue {
    fn from(uint: u32) -> Self {
        Self::UInt(u64::from(uint))
    }
}

/// Widens an `f32` losslessly into [`CursorValue::Float`].
impl From<f32> for CursorValue {
    fn from(float: f32) -> Self {
        Self::Float(f64::from(float))
    }
}

/// Result alias used throughout this module; the error type defaults to [`CursorError`]
/// but can be overridden by naming a second type argument.
pub type Result<T, E = CursorError> = ::std::result::Result<T, E>;

// ================================================================================================
// ERROR HANDLING
// ================================================================================================

/// Alias for the database-layer error type. Only defined when at least one database
/// feature (`sqlite`, `postgres` or `mysql`) is enabled.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mysql"))]
pub type SqlxError = sqlx_data_integration::Error;

/// Private error representation wrapped by [`CursorError`].
///
/// `Display` and `std::error::Error` are derived through `thiserror`; the `#[error(...)]`
/// strings below are the messages rendered to callers.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
enum CursorErrorKind {
    /// Underlying database error, forwarded transparently. Only compiled in when a
    /// database feature (`sqlite`, `postgres` or `mysql`) is enabled.
    #[error(transparent)]
    #[cfg(any(feature = "sqlite", feature = "postgres", feature = "mysql"))]
    Sqlx(#[from] SqlxError),

    /// The payload is interpolated into the message as the offending field name.
    #[error("Field '{0}' not allowed for cursor pagination")]
    InvalidField(String),

    /// No item was available to build a cursor from.
    #[error("Data is empty")]
    EmptyData,

    /// Turning a cursor into a token failed; the payload is the detail message.
    #[error("Encoding cursor failed: {0}")]
    EncodeError(String),

    /// Turning a token back into a cursor failed; the payload is the detail message.
    #[error("Decoding cursor failed: {0}")]
    DecodeError(String),
}

/// Public, opaque error type for cursor handling.
///
/// The concrete kind is kept in a private field, so values are created through the
/// constructor functions on this type (or via `From<SqlxError>` when a database feature
/// is enabled) rather than by naming a variant.
#[derive(Debug)]
pub struct CursorError(CursorErrorKind);

impl CursorError {
    /// Create an InvalidField error with automatic type conversion based on features
    pub fn invalid_field(field: impl Into<String>) -> Self {
        Self(CursorErrorKind::InvalidField(field.into()))
    }

    pub fn empty_data() -> Self {
        Self(CursorErrorKind::EmptyData)
    }

    pub fn encode_error(msg: impl Into<String>) -> Self {
        Self(CursorErrorKind::EncodeError(msg.into()))
    }

    pub fn decode_error(msg: impl Into<String>) -> Self {
        Self(CursorErrorKind::DecodeError(msg.into()))
    }
}

/// Wraps a database-layer error (`SqlxError`) in a [`CursorError`]. Only available when
/// a database feature is enabled.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mysql"))]
impl From<SqlxError> for CursorError {
    fn from(source: SqlxError) -> Self {
        let kind = CursorErrorKind::Sqlx(source);
        CursorError(kind)
    }
}

impl std::fmt::Display for CursorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.0.fmt(f)
    }
}

impl std::error::Error for CursorError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        self.0.source()
    }
}

/// Converts a [`CursorError`] into the database-layer error type.
///
/// A wrapped database error is handed back as-is; every other kind is boxed into the
/// `Decode` variant.
#[cfg(any(feature = "sqlite", feature = "postgres", feature = "mysql"))]
impl From<CursorError> for sqlx_data_integration::Error {
    fn from(err: CursorError) -> Self {
        let CursorError(kind) = err;
        match kind {
            CursorErrorKind::Sqlx(inner) => inner.into(),
            non_db => sqlx_data_integration::Error::Decode(non_db.into()),
        }
    }
}

// ================================================================================================
// IMPLEMENTATIONS
// ================================================================================================



impl Cursor {
    /// Creates an empty cursor (no entries, no version, no fingerprint).
    pub fn new() -> Self {
        Default::default()
    }

    /// Creates a cursor from the given entries, leaving `version` and `fingerprint` unset.
    pub fn new_multi(entries: Vec<CursorEntry>) -> Self {
        Self {
            entries,
            ..Self::default()
        }
    }

    /// Appends one value as a new entry and returns the cursor for further chaining.
    pub fn and_field(mut self, value: impl Into<CursorValue>) -> Self {
        let entry = CursorEntry {
            value: value.into(),
        };
        self.entries.push(entry);
        self
    }
}

impl CursorParams {
    /// Creates cursor params holding a single `value` for the given `direction`.
    pub fn new(value: FilterValue, direction: CursorDirection) -> Self {
        Self::from_values(vec![value], direction)
    }

    /// Creates cursor params from already collected `values` for the given `direction`.
    pub fn from_values(values: Vec<FilterValue>, direction: CursorDirection) -> Self {
        Self {
            direction: Some(direction),
            values,
            error: None,
        }
    }

    /// Creates cursor params that carry no values, only an error message.
    pub fn with_error(direction: CursorDirection, error: impl Into<String>) -> Self {
        let message: String = error.into();
        Self {
            direction: Some(direction),
            values: Vec::new(),
            error: Some(message),
        }
    }

    /// Appends one more value and returns the params for further chaining.
    pub fn and_field(mut self, value: FilterValue) -> Self {
        self.values.push(value);
        self
    }

    /// Borrows the cursor values in insertion order.
    pub fn values(&self) -> &[FilterValue] {
        self.values.as_slice()
    }

    /// Returns `true` when no values are stored.
    pub fn is_empty(&self) -> bool {
        self.values.is_empty()
    }

    /// Returns how many values are stored.
    pub fn len(&self) -> usize {
        self.values.len()
    }

    /// Returns `true` when an error message was recorded.
    pub fn has_error(&self) -> bool {
        self.error.is_some()
    }

    /// Borrows the recorded error message, if any.
    pub fn error(&self) -> Option<&str> {
        self.error.as_deref()
    }

    /// Builds a [`Cursor`] from the item of `data` chosen by `get_item`.
    ///
    /// Returns `Ok(None)` when `has_more` is `false` or `data` is empty. Otherwise the
    /// sort field names are taken from `sorting_params` and their values are read from
    /// the chosen item through [`CursorSecureExtract::extract_whitelisted_fields`].
    ///
    /// # Errors
    ///
    /// - no sort fields are configured,
    /// - `get_item` yields no item,
    /// - field extraction fails,
    /// - the number of extracted values differs from the number of sort fields.
    fn generate_cursor<T: CursorSecureExtract>(
        data: &[T],
        has_more: bool,
        sorting_params: &crate::sort::SortingParams,
        get_item: impl FnOnce(&[T]) -> Option<&T>,
    ) -> Result<Option<Cursor>> {
        // Nothing to point at: either there is no further page or no row to anchor on.
        if data.is_empty() || !has_more {
            return Ok(None);
        }

        // The cursor is made of exactly the fields the result set is ordered by.
        let sort_fields: Vec<String> = sorting_params
            .sorts()
            .iter()
            .map(|sort| sort.field.clone())
            .collect();

        // NOTE(review): the two fixed messages below are passed to `invalid_field`, so
        // they are rendered inside the "Field '...' not allowed" template.
        if sort_fields.is_empty() {
            let err = CursorError::invalid_field("Cursor pagination requires ORDER BY fields");
            return Err(err);
        }

        let anchor = get_item(data).ok_or_else(CursorError::empty_data)?;
        let extracted = anchor.extract_whitelisted_fields(&sort_fields)?;

        // An implementor returning too few/many values would yield an unusable cursor.
        if extracted.len() != sort_fields.len() {
            let err = CursorError::invalid_field("Cursor fields mismatch with sorting params");
            return Err(err);
        }

        let entries = extracted
            .into_iter()
            .map(|value| CursorEntry { value })
            .collect();

        Ok(Some(Cursor::new_multi(entries)))
    }

    /// Generates the encoded "next" cursor from the last item in `data`.
    ///
    /// Returns `Ok(None)` when `has_next` is `false` or `data` is empty; the token is
    /// produced by `T::encode`.
    pub fn generate_next_cursor<T: CursorSecureExtract>(
        &self,
        data: &[T],
        has_next: bool,
        sorting_params: &crate::sort::SortingParams,
    ) -> Result<Option<String>> {
        let built = Self::generate_cursor(data, has_next, sorting_params, |items| items.last())?;
        if let Some(cursor) = built {
            let token = T::encode(&cursor)?;
            Ok(Some(token))
        } else {
            Ok(None)
        }
    }

    /// Generates the encoded "prev" cursor from the first item in `data`.
    ///
    /// Returns `Ok(None)` when `has_prev` is `false` or `data` is empty; the token is
    /// produced by `T::encode`.
    pub fn generate_prev_cursor<T: CursorSecureExtract>(
        &self,
        data: &[T],
        has_prev: bool,
        sorting_params: &crate::sort::SortingParams,
    ) -> Result<Option<String>> {
        let built = Self::generate_cursor(data, has_prev, sorting_params, |items| items.first())?;
        if let Some(cursor) = built {
            let token = T::encode(&cursor)?;
            Ok(Some(token))
        } else {
            Ok(None)
        }
    }
}

// ================================================================================================
// SECURITY TRAITS
// ================================================================================================

/// **Security-First Cursor Field Whitelist Trait**
///
/// This trait enforces a whitelist-based security model for cursor pagination fields.
/// Implementors MUST explicitly whitelist each allowed field to prevent field injection attacks.
///
/// # Feature-dependent signatures
///
/// The error type of each method depends on the enabled crate features:
///
/// - `extract_whitelisted_fields` returns `sqlx_data_integration::Error` when a database
///   feature (`sqlite`, `postgres` or `mysql`) is enabled, and [`CursorError`] otherwise.
/// - `encode` and `decode` return `sqlx_data_integration::Error` when the `json` feature
///   is enabled, and [`CursorError`] otherwise.
///
/// Within this module, `CursorParams::generate_next_cursor` and
/// `CursorParams::generate_prev_cursor` call `extract_whitelisted_fields` with the field
/// names taken from the sorting parameters and then `encode` on the resulting cursor.
pub trait CursorSecureExtract {
    /// **SECURITY CRITICAL**: Extract values ONLY for explicitly whitelisted cursor fields.
    ///
    /// **THIS IS A SECURITY WHITELIST** - Only return values for fields you explicitly allow.
    /// **ALWAYS** return `Err` for any field not in your whitelist to prevent field injection.
    ///
    /// The caller in this module expects exactly one value per entry of `fields`, in the
    /// same order; a different count is rejected as a field mismatch.
    ///
    /// # Security Model
    ///
    /// This method acts as the primary defense against field injection attacks via cursor pagination.
    /// Even if malicious field names are injected through `from_encoded()` or other vectors,
    /// this whitelist ensures only safe, predefined fields can be accessed.
    ///
    /// # Implementation Requirements
    ///
    /// - **MUST** use explicit `match field.as_str()` with hardcoded field names
    /// - **MUST** return `Err` for the default case (`_`)
    /// - **NEVER** use dynamic field resolution or reflection
    /// - **ONLY** allow fields that are safe for cursor-based ordering
    ///
    /// # Example
    ///
    /// The two `cfg` branches differ only in the error type they return.
    /// ```rust, ignore
    /// # use sqlx_data_params::{CursorSecureExtract, CursorValue, CursorError, SqlxError};
    /// struct User {
    ///     id: i64,
    ///     name: String,
    ///     email: String,
    ///     password_hash: String, // ← NEVER include sensitive fields!
    /// }
    ///
    /// impl CursorSecureExtract for User {
    ///     // Database feature enabled: the error type is `SqlxError`.
    ///     #[cfg(any(feature = "sqlite", feature = "postgres", feature = "mysql"))]
    ///     fn extract_whitelisted_fields(
    ///         &self,
    ///         fields: &[String],
    ///     ) -> Result<Vec<CursorValue>, SqlxError> {
    ///         let mut values = Vec::with_capacity(fields.len());
    ///         for field in fields {
    ///             // 🛡️ SECURITY WHITELIST: Only these fields are allowed
    ///             match field.as_str() {
    ///                 "id" => values.push(self.id.into()),           // ✅ Safe: Primary key
    ///                 "name" => values.push(self.name.clone().into()), // ✅ Safe: Public field
    ///                 "email" => values.push(self.email.clone().into()), // ✅ Safe: Public field
    ///                 // password_hash is NOT in whitelist - cannot be accessed via cursor
    ///                 _ => return Err(CursorError::invalid_field(field.clone()).into()), // 🚫 REJECT: All non-whitelisted fields
    ///             }
    ///         }
    ///         Ok(values)
    ///     }
    ///
    ///     // No database feature: the error type is `CursorError`.
    ///     #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mysql")))]
    ///     fn extract_whitelisted_fields(
    ///         &self,
    ///         fields: &[String],
    ///     ) -> Result<Vec<CursorValue>, CursorError> {
    ///         let mut values = Vec::with_capacity(fields.len());
    ///         for field in fields {
    ///             // 🛡️ SECURITY WHITELIST: Only these fields are allowed
    ///             match field.as_str() {
    ///                 "id" => values.push(self.id.into()),           // ✅ Safe: Primary key
    ///                 "name" => values.push(self.name.clone().into()), // ✅ Safe: Public field
    ///                 "email" => values.push(self.email.clone().into()), // ✅ Safe: Public field
    ///                 // password_hash is NOT in whitelist - cannot be accessed via cursor
    ///                 _ => return Err(CursorError::invalid_field(field.clone())), // 🚫 REJECT: All non-whitelisted fields
    ///             }
    ///         }
    ///         Ok(values)
    ///     }
    /// }
    /// ```
    ///
    /// # Security Benefits
    ///
    /// - **Field Injection Prevention**: Malicious fields from `from_encoded()` are rejected
    /// - **Data Exposure Control**: Sensitive fields cannot be accessed via cursor pagination
    /// - **Explicit Security Model**: Developers must consciously choose which fields to expose
    /// - **Defense in Depth**: Multiple layers protect against various attack vectors
    #[cfg(any(feature = "sqlite", feature = "postgres", feature = "mysql"))]
    fn extract_whitelisted_fields(
        &self,
        fields: &[String],
    ) -> Result<Vec<CursorValue>, sqlx_data_integration::Error>;

    /// Variant of `extract_whitelisted_fields` compiled when no database feature is
    /// enabled; identical contract, but the error type is [`CursorError`].
    #[cfg(not(any(feature = "sqlite", feature = "postgres", feature = "mysql")))]
    fn extract_whitelisted_fields(&self, fields: &[String]) -> Result<Vec<CursorValue>>;

    /// Encode cursor to string token
    ///
    /// Example implementation:
    /// ```rust,ignore
    /// fn encode(cursor: &Cursor) -> Result<String, sqlx_data_integration::Error> {
    ///     use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD as BASE64};
    ///     let json_bytes = serde_json::to_vec(&cursor)
    ///         .map_err(|e| CursorError::encode_error(format!("JSON serialization failed: {}", e)))?;
    ///     Ok(BASE64.encode(json_bytes))
    /// }
    /// ```
    #[cfg(feature = "json")]
    fn encode(cursor: &Cursor) -> Result<String, sqlx_data_integration::Error>;

    /// Encode cursor to string token (JSON feature disabled)
    ///
    /// Same contract as the `json` variant, but the error type is [`CursorError`].
    #[cfg(not(feature = "json"))]
    fn encode(_cursor: &Cursor) -> Result<String>;

    /// Decode string token to FilterValue vector
    ///
    /// Example implementation:
    /// ```rust,ignore
    /// fn decode(encoded: &str) -> Result<Vec<FilterValue>, sqlx_data_integration::Error> {
    ///     use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD as BASE64};
    ///     let bytes = BASE64
    ///         .decode(encoded)
    ///         .map_err(|e| CursorError::decode_error(format!("Base64 decode failed: {}", e)))?;
    ///
    ///     let cursor: Cursor = serde_json::from_slice(&bytes).map_err(|e| {
    ///         CursorError::decode_error(format!("JSON deserialization failed: {}", e))
    ///     })?;
    ///
    ///     // Convert CursorValue to FilterValue
    ///     let filter_values: Vec<FilterValue> = cursor.entries.into_iter().map(|entry| {
    ///         match entry.value {
    ///             CursorValue::Int(v) => FilterValue::Int(v),
    ///             CursorValue::UInt(v) => FilterValue::UInt(v),
    ///             CursorValue::Float(v) => FilterValue::Float(v),
    ///             CursorValue::Bool(v) => FilterValue::Bool(v),
    ///             CursorValue::String(v) => v.into(), // Or whatever conversion is appropriate
    ///         }
    ///     }).collect();
    ///
    ///     Ok(filter_values)
    /// }
    /// ```
    #[cfg(feature = "json")]
    fn decode(encoded: &str) -> Result<Vec<FilterValue>, sqlx_data_integration::Error>;

    /// Decode string token to FilterValue vector (JSON feature disabled)
    ///
    /// Same contract as the `json` variant, but the error type is [`CursorError`].
    #[cfg(not(feature = "json"))]
    fn decode(_encoded: &str) -> Result<Vec<FilterValue>>;
}

// ================================================================================================
// PARAMS INTEGRATION
// ================================================================================================

/// Wraps the cursor params in a [`Params`] value that carries only cursor pagination
/// and a limit; filters, search, sorting and offset are left unset.
impl IntoParams for CursorParams {
    fn into_params(self) -> Params {
        Params {
            pagination: Some(crate::pagination::Pagination::Cursor(self)),
            // Default page size used when converting bare cursor params.
            limit: Some(crate::pagination::LimitParam(20)),
            filters: None,
            search: None,
            sort_by: None,
            offset: None,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Chained `and_field` calls append to the value passed to `new`.
    #[test]
    fn test_cursor_builder_pattern() {
        let first = FilterValue::String("alice".into());
        let cursor = CursorParams::new(first, CursorDirection::Before)
            .and_field(FilterValue::Int(25))
            .and_field(FilterValue::Float(99.5));

        assert_eq!(cursor.len(), 3);
        assert_eq!(cursor.direction, Some(CursorDirection::Before));
    }

    /// Params built with a value are non-empty and error-free; params built with
    /// `with_error` are empty and flagged.
    #[test]
    fn test_cursor_state_detection() {
        let populated = CursorParams::new(FilterValue::Int(123), CursorDirection::After);
        assert!(!populated.is_empty());
        assert!(!populated.has_error());

        let failed = CursorParams::with_error(CursorDirection::After, "decode failed");
        assert!(failed.is_empty());
        assert!(failed.has_error());
    }

    /// `len`, `values` and `direction` agree after adding a second value.
    #[test]
    fn test_cursor_values() {
        let second = FilterValue::String("test".into());
        let cursor =
            CursorParams::new(FilterValue::Int(123), CursorDirection::After).and_field(second);

        assert_eq!(cursor.len(), 2);
        assert_eq!(cursor.values().len(), 2);
        assert_eq!(cursor.direction, Some(CursorDirection::After));
    }

    /// `error()` exposes exactly the message given to `with_error`, and `None` otherwise.
    #[test]
    fn test_error_workflow() {
        let healthy = CursorParams::new(FilterValue::Int(123), CursorDirection::After);
        assert!(!healthy.has_error());
        assert_eq!(healthy.error(), None);

        let broken = CursorParams::with_error(CursorDirection::Before, "Invalid token");
        assert!(broken.has_error());
        assert_eq!(broken.error(), Some("Invalid token"));
    }
}