valkey_module/context/
blocked.rs

1use crate::redismodule::AUTH_HANDLED;
2use crate::{raw, Context, ValkeyError, ValkeyString};
3use std::os::raw::{c_int, c_void};
4
5// Callback types for handling blocked client operations
6// Currently supports authentication reply callback for block_client_on_auth
7#[derive(Debug)]
8pub enum ReplyCallback<T> {
9    Auth(fn(&Context, ValkeyString, ValkeyString, Option<&T>) -> Result<c_int, ValkeyError>),
10}
11
12#[derive(Debug)]
13struct BlockedClientPrivateData<T: 'static> {
14    reply_callback: Option<ReplyCallback<T>>,
15    free_callback: Option<FreePrivateDataCallback<T>>,
16    data: Option<Box<T>>,
17}
18
19// Callback type for freeing private data associated with a blocked client
20type FreePrivateDataCallback<T> = fn(&Context, T);
21
22pub struct BlockedClient<T: 'static = ()> {
23    pub(crate) inner: *mut raw::RedisModuleBlockedClient,
24    reply_callback: Option<ReplyCallback<T>>,
25    free_callback: Option<FreePrivateDataCallback<T>>,
26    data: Option<Box<T>>,
27}
28
29#[allow(dead_code)]
30unsafe extern "C" fn auth_reply_wrapper<T: 'static>(
31    ctx: *mut raw::RedisModuleCtx,
32    username: *mut raw::RedisModuleString,
33    password: *mut raw::RedisModuleString,
34    err: *mut *mut raw::RedisModuleString,
35) -> c_int {
36    let context = Context::new(ctx);
37    let ctx_ptr = std::ptr::NonNull::new_unchecked(ctx);
38    let username = ValkeyString::new(Some(ctx_ptr), username);
39    let password = ValkeyString::new(Some(ctx_ptr), password);
40
41    let module_private_data = context.get_blocked_client_private_data();
42    if module_private_data.is_null() {
43        panic!("[auth_reply_wrapper] Module private data is null; this should not happen!");
44    }
45
46    let user_private_data = &*(module_private_data as *const BlockedClientPrivateData<T>);
47
48    let cb = match user_private_data.reply_callback.as_ref() {
49        Some(ReplyCallback::Auth(cb)) => cb,
50        None => panic!("[auth_reply_wrapper] Reply callback is null; this should not happen!"),
51    };
52
53    let data_ref = user_private_data.data.as_deref();
54
55    match cb(&context, username, password, data_ref) {
56        Ok(result) => result,
57        Err(error) => {
58            let error_msg = ValkeyString::create_and_retain(&error.to_string());
59            *err = error_msg.inner;
60            AUTH_HANDLED
61        }
62    }
63}
64
65#[allow(dead_code)]
66unsafe extern "C" fn free_callback_wrapper<T: 'static>(
67    ctx: *mut raw::RedisModuleCtx,
68    module_private_data: *mut c_void,
69) {
70    let context = Context::new(ctx);
71
72    if module_private_data.is_null() {
73        panic!("[free_callback_wrapper] Module private data is null; this should not happen!");
74    }
75
76    let user_private_data = Box::from_raw(module_private_data as *mut BlockedClientPrivateData<T>);
77
78    // Execute free_callback only if both callback and data exist
79    // Note: free_callback can exist without data - this is a valid state
80    if let Some(free_cb) = user_private_data.free_callback {
81        if let Some(data) = user_private_data.data {
82            free_cb(&context, *data);
83        }
84    }
85}
86
87// We need to be able to send the inner pointer to another thread
88unsafe impl<T> Send for BlockedClient<T> {}
89
90impl<T> BlockedClient<T> {
91    pub(crate) fn new(inner: *mut raw::RedisModuleBlockedClient) -> Self {
92        Self {
93            inner,
94            reply_callback: None,
95            free_callback: None,
96            data: None,
97        }
98    }
99
100    #[allow(dead_code)]
101    pub(crate) fn with_auth_callback(
102        inner: *mut raw::RedisModuleBlockedClient,
103        auth_reply_callback: fn(
104            &Context,
105            ValkeyString,
106            ValkeyString,
107            Option<&T>,
108        ) -> Result<c_int, ValkeyError>,
109        free_callback: Option<FreePrivateDataCallback<T>>,
110    ) -> Self
111    where
112        T: 'static,
113    {
114        Self {
115            inner,
116            reply_callback: Some(ReplyCallback::Auth(auth_reply_callback)),
117            free_callback,
118            data: None,
119        }
120    }
121
122    /// Sets private data for the blocked client.
123    ///
124    /// # Arguments
125    /// * `data` - The private data to store
126    ///
127    /// # Returns
128    /// * `Ok(())` - If the private data was successfully set
129    /// * `Err(ValkeyError)` - If setting the private data failed (e.g., no free callback)
130    pub fn set_blocked_private_data(&mut self, data: T) -> Result<(), ValkeyError> {
131        if self.free_callback.is_none() {
132            return Err(ValkeyError::Str(
133                "Cannot set private data without a free callback - this would leak memory",
134            ));
135        }
136        self.data = Some(Box::new(data));
137        Ok(())
138    }
139
140    /// Aborts the blocked client operation
141    ///
142    /// # Returns
143    /// * `Ok(())` - If the blocked client was successfully aborted
144    /// * `Err(ValkeyError)` - If the abort operation failed
145    pub fn abort(mut self) -> Result<(), ValkeyError> {
146        unsafe {
147            // Clear references to data and callbacks
148            self.data = None;
149            self.reply_callback = None;
150            self.free_callback = None;
151
152            if raw::RedisModule_AbortBlock.unwrap()(self.inner) == raw::REDISMODULE_OK as c_int {
153                // Prevent the normal Drop from running
154                self.inner = std::ptr::null_mut();
155                Ok(())
156            } else {
157                Err(ValkeyError::Str("Failed to abort blocked client"))
158            }
159        }
160    }
161}
162
163impl<T: 'static> Drop for BlockedClient<T> {
164    fn drop(&mut self) {
165        if !self.inner.is_null() {
166            let callback_data_ptr = if self.reply_callback.is_some() || self.free_callback.is_some()
167            {
168                Box::into_raw(Box::new(BlockedClientPrivateData {
169                    reply_callback: self.reply_callback.take(),
170                    free_callback: self.free_callback.take(),
171                    data: self.data.take(),
172                })) as *mut c_void
173            } else {
174                std::ptr::null_mut()
175            };
176
177            unsafe {
178                raw::RedisModule_UnblockClient.unwrap()(self.inner, callback_data_ptr);
179            }
180        }
181    }
182}
183
184impl Context {
185    #[must_use]
186    pub fn block_client(&self) -> BlockedClient {
187        let blocked_client = unsafe {
188            raw::RedisModule_BlockClient.unwrap()(
189                self.ctx, // ctx
190                None,     // reply_func
191                None,     // timeout_func
192                None, 0,
193            )
194        };
195
196        BlockedClient::new(blocked_client)
197    }
198
199    /// Blocks a client during authentication and registers callbacks
200    ///
201    /// Wrapper around ValkeyModule_BlockClientOnAuth. Used for asynchronous authentication
202    /// processing.
203    ///
204    /// # Arguments
205    /// * `auth_reply_callback` - Callback executed when authentication completes
206    /// * `free_callback` - Optional callback for cleaning up private data
207    ///
208    /// # Returns
209    /// * `BlockedClient<T>` - Handle to manage the blocked client
210    #[must_use]
211    #[cfg(all(any(
212        feature = "min-redis-compatibility-version-7-2",
213        feature = "min-valkey-compatibility-version-8-0"
214    ),))]
215    pub fn block_client_on_auth<T: 'static + Send>(
216        &self,
217        auth_reply_callback: fn(
218            &Context,
219            ValkeyString,
220            ValkeyString,
221            Option<&T>,
222        ) -> Result<c_int, ValkeyError>,
223        free_callback: Option<FreePrivateDataCallback<T>>,
224    ) -> BlockedClient<T> {
225        unsafe {
226            let blocked_client = raw::RedisModule_BlockClientOnAuth.unwrap()(
227                self.ctx,
228                Some(auth_reply_wrapper::<T>),
229                Some(free_callback_wrapper::<T>),
230            );
231
232            BlockedClient::with_auth_callback(blocked_client, auth_reply_callback, free_callback)
233        }
234    }
235
236    /// Retrieves the private data associated with a blocked client in the current context.
237    /// This is an internal function used primarily by reply callbacks to access user-provided data.
238    ///
239    /// # Safety
240    /// This function returns a raw pointer that must be properly cast to the expected type.
241    /// The caller must ensure the pointer is not null before dereferencing.
242    ///
243    /// # Implementation Detail
244    /// Wraps the Valkey Module C API function `ValkeyModule_GetBlockedClientPrivateData`
245    pub(crate) fn get_blocked_client_private_data(&self) -> *mut c_void {
246        unsafe { raw::RedisModule_GetBlockedClientPrivateData.unwrap()(self.ctx) }
247    }
248}