1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
//! Rust representation of `UDF_INIT`

#![allow(clippy::useless_conversion, clippy::unnecessary_cast)]

use std::cell::UnsafeCell;
use std::ffi::c_ulong;
use std::fmt::Debug;
use std::marker::PhantomData;
#[cfg(feature = "logging-debug")]
use std::{any::type_name, mem::size_of};

use udf_sys::UDF_INIT;

#[cfg(feature = "logging-debug")]
use crate::udf_log;
use crate::{Init, UdfState};

/// Named values that are commonly used for the `max_length` parameter
///
/// Cast a variant to an integer (for example `MaxLenOptions::Blob as u64`) to
/// pass it to [`UdfCfg::set_max_len()`]
#[repr(u32)]
#[non_exhaustive]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum MaxLenOptions {
    /// Default maximum length of an integer result: 21
    IntDefault = 21,

    /// Base maximum length of a real result: 13. The default for a real value
    /// is this base plus the result of [`UdfCfg::get_decimals()`]
    RealBase = 13,

    /// Length limit used for a `blob`: 64 KiB (65,536 bytes)
    Blob = 0x0001_0000,

    /// Length limit used for a `mediumblob`: 16 MiB (16,777,216 bytes)
    MediumBlob = 0x0100_0000,
}

/// Configuration of an in-progress UDF call
///
/// This is a rusty wrapper around SQL's `UDF_INIT` struct, providing methods to
/// easily and safely read and update its settings (`maybe_null`, `decimals`,
/// `max_length` and `const_item`).
///
/// The wrapper is `#[repr(transparent)]` over an `UnsafeCell<UDF_INIT>`, so the
/// settings can be updated through a shared reference. `S` is a zero-sized
/// marker for the UDF state; setters that are only valid during initialization
/// are implemented for `UdfCfg<Init>` only.
#[repr(transparent)]
pub struct UdfCfg<S: UdfState>(pub(crate) UnsafeCell<UDF_INIT>, PhantomData<S>);

impl<S: UdfState> UdfCfg<S> {
    /// Reinterpret a pointer to a raw `UDF_INIT` as a reference to a `UdfCfg`
    ///
    /// # Safety
    ///
    /// The caller must guarantee that `ptr` is valid and remains valid for the
    /// lifetime `'p` of the returned reference
    #[inline]
    pub(crate) unsafe fn from_raw_ptr<'p>(ptr: *const UDF_INIT) -> &'p Self {
        // Only the pointee type changes here; the address stays the same
        let cfg_ptr: *const Self = ptr.cast();
        &*cfg_ptr
    }

    /// Hand ownership of a boxed value over to this `UDF_INIT`
    ///
    /// The box is turned into a raw pointer, which is written to the `ptr`
    /// field of the wrapped struct. The allocation is _not_ freed until
    /// [`Self::retrieve_box`] is called, so that call _must_ happen later or
    /// the memory is leaked!
    pub(crate) fn store_box<T>(&self, b: Box<T>) {
        let box_ptr = Box::into_raw(b);

        // Note: for a zero-sized T no allocation exists, so the address printed
        // here is only a placeholder (e.g. `0x1`)
        #[cfg(feature = "logging-debug")]
        udf_log!(
            Debug: "{box_ptr:p} {} bytes udf->server control transfer ({})",
            size_of::<T>(), type_name::<T>()
        );

        let init = self.0.get();
        // SAFETY: this would race if called from different threads, but the
        // type is `!Sync` so that cannot happen
        unsafe {
            (*init).ptr = box_ptr.cast();
        }
    }

    /// Rebuild the box whose pointer is currently stored in the `ptr` field
    ///
    /// The `ptr` field itself is left untouched by this call.
    ///
    /// # Safety
    ///
    /// - `T` _must_ be the type of the value behind this struct's pointer,
    ///   likely stored with [`Self::store_box`]
    /// - The returned box owns the allocation, so the same stored pointer must
    ///   not be turned back into a box a second time
    pub(crate) unsafe fn retrieve_box<T>(&self) -> Box<T> {
        let init = self.0.get();
        let box_ptr: *mut T = (*init).ptr.cast();

        #[cfg(feature = "logging-debug")]
        udf_log!(
            Debug: "{box_ptr:p} {} bytes server->udf control transfer ({})",
            size_of::<T>(), type_name::<T>()
        );

        Box::from_raw(box_ptr)
    }

    /// Check whether this UDF is currently marked as able to return `null`
    ///
    /// This defaults to true if any argument is nullable, false otherwise
    #[inline]
    pub fn get_maybe_null(&self) -> bool {
        let init = self.0.get();
        // SAFETY: this would race if called from different threads, but the
        // type is `!Sync` so that cannot happen
        unsafe { (*init).maybe_null }
    }

    /// Read the current number of decimal places
    ///
    /// This defaults to the longest number of digits of any argument, or 31 if
    /// there is no fixed number
    #[inline]
    pub fn get_decimals(&self) -> u32 {
        let init = self.0.get();
        // SAFETY: this would race if called from different threads, but the
        // type is `!Sync` so that cannot happen
        unsafe { (*init).decimals as u32 }
    }

    /// Change the number of decimals this function returns
    ///
    /// This can be changed at any point in the UDF (init or process)
    #[inline]
    pub fn set_decimals(&self, v: u32) {
        let init = self.0.get();
        // SAFETY: this would race if called from different threads, but the
        // type is `!Sync` so that cannot happen
        unsafe {
            (*init).decimals = v.into();
        }
    }

    /// Read the current maximum length setting for this in-progress UDF
    #[inline]
    pub fn get_max_len(&self) -> u64 {
        let init = self.0.get();
        // SAFETY: this would race if called from different threads, but the
        // type is `!Sync` so that cannot happen
        unsafe { (*init).max_length as u64 }
    }

    /// Read the current `const_item` value
    #[inline]
    pub fn get_is_const(&self) -> bool {
        let init = self.0.get();
        // SAFETY: this would race if called from different threads, but the
        // type is `!Sync` so that cannot happen
        unsafe { (*init).const_item }
    }
}

/// Setters on a `UdfCfg` that are only available while the UDF is being
/// initialized
impl UdfCfg<Init> {
    /// Declare whether or not this function may return null
    #[inline]
    pub fn set_maybe_null(&self, v: bool) {
        let init = self.0.get();
        // SAFETY: this would race if called from different threads, but the
        // type is `!Sync` so that cannot happen
        unsafe {
            (*init).maybe_null = v;
        }
    }

    /// Declare the maximum possible length of this UDF's result
    ///
    /// This is mostly relevant for String and Decimal return types. See
    /// [`MaxLenOptions`] for possible defaults, including `BLOB` sizes.
    ///
    /// A value that does not fit in the platform's `c_ulong` is stored as
    /// `c_ulong::MAX`.
    #[inline]
    pub fn set_max_len(&self, v: u64) {
        // A fallible conversion is needed because ulong is 64 bits in GNU but
        // 32 bits MSVC; anything too large saturates at the maximum
        let converted: Result<c_ulong, _> = v.try_into();
        let set = converted.unwrap_or(c_ulong::MAX);

        let init = self.0.get();
        // SAFETY: this would race if called from different threads, but the
        // type is `!Sync` so that cannot happen
        unsafe {
            (*init).max_length = set;
        }
    }

    /// Store a new `const_item` value
    ///
    /// Set this to true if your function always returns the same values with
    /// the same arguments
    #[inline]
    pub fn set_is_const(&self, v: bool) {
        let init = self.0.get();
        // SAFETY: this would race if called from different threads, but the
        // type is `!Sync` so that cannot happen
        unsafe {
            (*init).const_item = v;
        }
    }
}

impl<T: UdfState> Debug for UdfCfg<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // SAFETY: unsafe when called from different threads, but we are `!Sync`
        // here
        let base = unsafe { &*self.0.get() };
        f.debug_struct("UdfCfg")
            .field("maybe_null", &base.maybe_null)
            .field("decimals", &base.decimals)
            .field("max_len", &base.max_length)
            .field("is_const", &base.const_item)
            .field("ptr", &base.ptr)
            .finish()
    }
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::mem::{align_of, size_of};

    use super::*;
    use crate::mock::MockUdfCfg;
    use crate::{Init, Process};

    /// The `Init` wrapper must have the same size and alignment as `UDF_INIT`
    #[test]
    fn cfg_init_size() {
        let (raw_size, raw_align) = (size_of::<UDF_INIT>(), align_of::<UDF_INIT>());

        assert_eq!(raw_size, size_of::<UdfCfg<Init>>(), "Size of: UDF_INIT");
        assert_eq!(
            raw_align,
            align_of::<UdfCfg<Init>>(),
            "Alignment of UDF_INIT"
        );
    }

    /// The `Process` wrapper must have the same size and alignment as
    /// `UDF_INIT`
    #[test]
    fn cfg_proc_size() {
        let (raw_size, raw_align) = (size_of::<UDF_INIT>(), align_of::<UDF_INIT>());

        assert_eq!(raw_size, size_of::<UdfCfg<Process>>(), "Size of: UDF_INIT");
        assert_eq!(
            raw_align,
            align_of::<UdfCfg<Process>>(),
            "Alignment of UDF_INIT"
        );
    }

    /// A value stored with `store_box` must come back unchanged from
    /// `retrieve_box`
    #[test]
    fn test_box_load_store() {
        #[derive(PartialEq, Debug, Clone)]
        struct X {
            s: String,
            map: HashMap<i64, f64>,
        }

        let stored = X {
            s: "This is a string".to_owned(),
            map: HashMap::from([
                (930_984_098, 4_525_435_435.900_981),
                (12_341_234, -234.090_909_092),
                (-23_412_343_453, 838_383.6),
            ]),
        };

        let mut m = MockUdfCfg::new();
        let cfg = m.as_init();
        cfg.store_box(Box::new(stored.clone()));

        // SAFETY: an `X` was stored just above and is retrieved exactly once
        let loaded = unsafe { *cfg.retrieve_box::<X>() };
        assert_eq!(stored, loaded);
    }

    /// The getter must reflect whatever is written to the raw field
    #[test]
    fn maybe_null() {
        let mut m = MockUdfCfg::new();

        for flag in [false, true] {
            *m.maybe_null() = flag;
            assert_eq!(m.as_init().get_maybe_null(), flag);
        }
    }

    /// Reads must follow the raw field, and the setter must write back to it
    #[test]
    fn decimals() {
        let mut m = MockUdfCfg::new();

        // Raw field and getter each get their own literal so that their
        // integer types are inferred independently
        for (raw, expected) in [(1234, 1234), (0, 0), (1, 1)] {
            *m.decimals() = raw;
            assert_eq!(m.as_init().get_decimals(), expected);
        }

        m.as_init().set_decimals(4);
        assert_eq!(*m.decimals(), 4);
    }

    /// The getter must reflect whatever is written to the raw field
    #[test]
    fn max_len() {
        let mut m = MockUdfCfg::new();

        // Raw field and getter each get their own literal so that their
        // integer types are inferred independently
        for (raw, expected) in [(1234, 1234), (0, 0), (1, 1)] {
            *m.max_len() = raw;
            assert_eq!(m.as_init().get_max_len(), expected);
        }
    }

    /// The getter must reflect whatever is written to the raw field
    #[test]
    fn test_const() {
        let mut m = MockUdfCfg::new();

        for flag in [false, true] {
            *m.is_const() = flag;
            assert_eq!(m.as_init().get_is_const(), flag);
        }
    }
}