// picodata_plugin/transport/rpc/server.rs

1use super::Request;
2use super::Response;
3use crate::internal::ffi;
4use crate::plugin::interface::PicoContext;
5use crate::plugin::interface::ServiceId;
6use crate::transport::context::Context;
7use crate::transport::context::FfiSafeContext;
8#[allow(unused_imports)]
9use crate::transport::rpc::client::RequestBuilder;
10use crate::util::FfiSafeBytes;
11use crate::util::FfiSafeStr;
12use std::mem::MaybeUninit;
13use std::ptr::NonNull;
14use tarantool::error::BoxError;
15use tarantool::error::TarantoolErrorCode;
16
////////////////////////////////////////////////////////////////////////////////
// RouteBuilder
////////////////////////////////////////////////////////////////////////////////
20
/// Builder used by a service to declare its RPC endpoints.
///
/// The client-side counterpart is [`RequestBuilder`].
#[derive(Clone, Debug)]
pub struct RouteBuilder<'a> {
    /// Name of the plugin the registering service belongs to.
    plugin: &'a str,
    /// Name of the service registering the route.
    service: &'a str,
    /// Version of the plugin.
    version: &'a str,
    /// Route path; registration fails if it was never set.
    path: Option<&'a str>,
}
31
32impl<'a> RouteBuilder<'a> {
33    /// A route is required to contain information about the service from which
34    /// it is being registered. Currently it's only possible to automatically
35    /// extract this information from [`PicoContext`].
36    #[inline(always)]
37    pub fn from_pico_context(context: &'a PicoContext) -> Self {
38        Self {
39            plugin: context.plugin_name(),
40            service: context.service_name(),
41            version: context.plugin_version(),
42            path: None,
43        }
44    }
45
46    /// A route is required to contain information about the service from which
47    /// it is being registered. Use this method to specify this info explicitly.
48    /// Consider using [`RouteBuilder::from_pico_context`] instead.
49    ///
50    /// # Safety
51    /// The caller must make sure that the info is the actual service info of
52    /// the currently running service.
53    #[inline(always)]
54    pub unsafe fn from_service_info(plugin: &'a str, service: &'a str, version: &'a str) -> Self {
55        Self {
56            plugin,
57            service,
58            version,
59            path: None,
60        }
61    }
62
63    /// Specify a route path. The path must start with a `'/'` character.
64    #[inline]
65    pub fn path(mut self, path: &'a str) -> Self {
66        if let Some(old) = self.path.take() {
67            #[rustfmt::skip]
68            tarantool::say_warn!("RouteBuilder path is silently changed from {old:?} to {path:?}");
69        }
70        self.path = Some(path);
71        self
72    }
73
74    /// Register the RPC endpoint with the currently chosen parameters and the
75    /// provided handler.
76    ///
77    /// Note that `f` must implement `Fn`. This is required by rust's semantics
78    /// to allow the RPC handlers to yield. If a handler yields then another
79    /// concurrent RPC request may result in the same handler being executed,
80    /// so we must not hold any `&mut` references in those closures (other than
81    /// ones allowed by rust semantics, see official reference on undefined
82    /// behaviour [here]).
83    ///
84    /// Use [`RequestBuilder::send`] to invoke the RPC endpoint registerred with
85    /// this method.
86    ///
87    /// # Local execution
88    ///
89    /// Note that the RPC handler may be invoked locally if the caller specifies
90    /// the request target which matches the current instance. In that case
91    /// the handler is invoked in the same process without yielding from the
92    /// fiber. A special named field `"call_was_local": true` is added to the
93    /// [`Context`] argument of the handler if the call was local. Note that
94    /// this field may or may not be missing in case of non-local call.
95    ///
96    /// [here]: <https://doc.rust-lang.org/reference/behavior-considered-undefined.html>
97    #[track_caller]
98    pub fn register<F>(self, f: F) -> Result<(), BoxError>
99    where
100        F: Fn(Request<'_>, &mut Context) -> Result<Response, BoxError> + 'static,
101    {
102        let Some(path) = self.path else {
103            #[rustfmt::skip]
104            return Err(BoxError::new(TarantoolErrorCode::IllegalParams, "path must be specified for RPC endpoint"));
105        };
106
107        let identifier = PackedServiceIdentifier::pack(
108            path.into(),
109            self.plugin.into(),
110            self.service.into(),
111            self.version.into(),
112        )?;
113        let handler = FfiRpcHandler::new(identifier, f);
114        if let Err(e) = register_rpc_handler(handler) {
115            // Note: recreating the error to capture the caller's source location
116            #[rustfmt::skip]
117            return Err(BoxError::new(e.error_code(), e.message()));
118        }
119
120        Ok(())
121    }
122}
123
124impl<'a> From<&'a PicoContext> for RouteBuilder<'a> {
125    #[inline(always)]
126    fn from(context: &'a PicoContext) -> Self {
127        Self::from_pico_context(context)
128    }
129}
130
////////////////////////////////////////////////////////////////////////////////
// ffi wrappers
////////////////////////////////////////////////////////////////////////////////
134
135/// **For internal use**.
136#[inline]
137fn register_rpc_handler(handler: FfiRpcHandler) -> Result<(), BoxError> {
138    // This is safe.
139    let rc = unsafe { ffi::pico_ffi_register_rpc_handler(handler) };
140    if rc == -1 {
141        return Err(BoxError::last());
142    }
143
144    return Ok(());
145}
146
147type RpcHandlerCallback = extern "C" fn(
148    handler: *const FfiRpcHandler,
149    input: FfiSafeBytes,
150    context: *const FfiSafeContext,
151    output: *mut FfiSafeBytes,
152) -> std::ffi::c_int;
153
154/// **For internal use**.
155///
156/// Use [`RouteBuilder`] instead.
157#[repr(C)]
158pub struct FfiRpcHandler {
159    /// The result data must either be statically allocated, or allocated using
160    /// the fiber region allocator (see [`box_region_alloc`]).
161    ///
162    /// [`box_region_alloc`]: tarantool::ffi::tarantool::box_region_alloc
163    callback: RpcHandlerCallback,
164    drop: extern "C" fn(*mut FfiRpcHandler),
165
166    /// The pointer to the closure object.
167    ///
168    /// Note that the pointer must be `mut` because we will at some point drop the data pointed to by it.
169    /// But when calling the closure, the `const` pointer should be used.
170    closure_pointer: *mut (),
171
172    /// This data is owned by this struct (freed on drop).
173    pub identifier: PackedServiceIdentifier,
174}
175
176impl Drop for FfiRpcHandler {
177    #[inline(always)]
178    fn drop(&mut self) {
179        (self.drop)(self)
180    }
181}
182
183impl FfiRpcHandler {
184    fn new<F>(identifier: PackedServiceIdentifier, f: F) -> Self
185    where
186        F: Fn(Request<'_>, &mut Context) -> Result<Response, BoxError> + 'static,
187    {
188        let closure = Box::new(f);
189        let closure_pointer: *mut F = Box::into_raw(closure);
190
191        FfiRpcHandler {
192            callback: Self::trampoline::<F>,
193            drop: Self::drop_handler::<F>,
194            closure_pointer: closure_pointer.cast(),
195
196            identifier,
197        }
198    }
199
200    extern "C" fn trampoline<F>(
201        handler: *const FfiRpcHandler,
202        input: FfiSafeBytes,
203        context: *const FfiSafeContext,
204        output: *mut FfiSafeBytes,
205    ) -> std::ffi::c_int
206    where
207        F: Fn(Request<'_>, &mut Context) -> Result<Response, BoxError> + 'static,
208    {
209        // This is safe. To verify see `register_rpc_handler` above.
210        let closure_pointer: *const F = unsafe { (*handler).closure_pointer.cast::<F>() };
211        let closure = unsafe { &*closure_pointer };
212        let input = unsafe { input.as_bytes() };
213        let context = unsafe { &*context };
214        let mut context = Context::new(context);
215
216        let request = Request::from_bytes(input);
217        let result = (|| {
218            let response = closure(request, &mut context)?;
219            response.to_region_slice()
220        })();
221
222        match result {
223            Ok(region_slice) => {
224                // This is safe. To verify see `FfiRpcHandler::call` bellow.
225                unsafe { std::ptr::write(output, region_slice.into()) }
226
227                return 0;
228            }
229            Err(e) => {
230                e.set_last();
231                return -1;
232            }
233        }
234    }
235
236    extern "C" fn drop_handler<F>(handler: *mut FfiRpcHandler) {
237        unsafe {
238            let closure_pointer: *mut F = (*handler).closure_pointer.cast::<F>();
239            let closure = Box::from_raw(closure_pointer);
240            drop(closure);
241
242            if cfg!(debug_assertions) {
243                // Overwrite the pointer with garbage so that we fail loudly is case of a bug
244                (*handler).closure_pointer = 0xcccccccccccccccc_u64 as _;
245            }
246
247            (*handler).identifier.drop();
248        }
249    }
250
251    #[inline(always)]
252    pub fn identity(&self) -> usize {
253        self.callback as *const RpcHandlerCallback as _
254    }
255
256    #[inline(always)]
257    pub fn call(&self, input: &[u8], context: &FfiSafeContext) -> Result<&'static [u8], ()> {
258        let mut output = MaybeUninit::uninit();
259
260        let rc = (self.callback)(self, input.into(), context, output.as_mut_ptr());
261        if rc == -1 {
262            // Actual error is passed through tarantool. Can't return BoxError
263            // here, because tarantool-module version may be different in picodata.
264            return Err(());
265        }
266
267        // This is safe. To verify see `trampoline` above.
268        let result = unsafe { output.assume_init().as_bytes() };
269
270        Ok(result)
271    }
272}
273
274/// **For internal use**.
275///
276/// Use [`RouteBuilder`] instead.
277///
278/// This struct stores an RPC route identifier in the following packed format:
279/// `{plugin}.{service}{path}{version}`. This format allows for efficient
280/// extraction of the RPC route identifier for purposes of logging (note that
281/// version is not displayed), and also losslessly stores info about all the
282/// parts of the idenifier.
283///
284/// This represnetation also adds a constraint on the maximum length of the
285/// plugin name, service name and path, each one of them must not be longer than
286/// 65535 bytes (which is obviously engough for anybody).
287#[repr(C)]
288#[derive(Debug, Default, Clone, Copy)]
289pub struct PackedServiceIdentifier {
290    pub storage: FfiSafeStr,
291    pub plugin_len: u16,
292    pub service_len: u16,
293    pub path_len: u16,
294    pub version_len: u16,
295}
296
297impl PackedServiceIdentifier {
298    pub(crate) fn pack(
299        path: &str,
300        plugin: &str,
301        service: &str,
302        version: &str,
303    ) -> Result<Self, BoxError> {
304        let Ok(plugin_len) = plugin.len().try_into() else {
305            #[rustfmt::skip]
306            return Err(BoxError::new(TarantoolErrorCode::IllegalParams, format!("plugin name length must not exceed 65535, got {}", plugin.len())));
307        };
308        let Ok(service_len) = service.len().try_into() else {
309            #[rustfmt::skip]
310            return Err(BoxError::new(TarantoolErrorCode::IllegalParams, format!("service name length must not exceed 65535, got {}", service.len())));
311        };
312        let Ok(path_len) = path.len().try_into() else {
313            #[rustfmt::skip]
314            return Err(BoxError::new(TarantoolErrorCode::IllegalParams, format!("route path length must not exceed 65535, got {}", path.len())));
315        };
316        let Ok(version_len) = version.len().try_into() else {
317            #[rustfmt::skip]
318            return Err(BoxError::new(TarantoolErrorCode::IllegalParams, format!("version string length must not exceed 65535, got {}", version.len())));
319        };
320
321        let total_string_len = plugin_len
322            // For an extra '.' between plugin and service names
323            + 1
324            + service_len
325            + path_len
326            + version_len;
327        let mut string_storage = Vec::with_capacity(total_string_len as _);
328        string_storage.extend_from_slice(plugin.as_bytes());
329        string_storage.push(b'.');
330        string_storage.extend_from_slice(service.as_bytes());
331        string_storage.extend_from_slice(path.as_bytes());
332        string_storage.extend_from_slice(version.as_bytes());
333
334        let start = string_storage.as_mut_ptr();
335        let capacity = string_storage.capacity();
336
337        // Safety: vec has an allocated buffer, so the pointer cannot be null.
338        // Also a concatenation of utf8 strings is always a utf8 string.
339        let storage =
340            unsafe { FfiSafeStr::from_raw_parts(NonNull::new_unchecked(start), capacity) };
341
342        // Self now owns this data and will be freeing it in it's `drop`.
343        std::mem::forget(string_storage);
344
345        Ok(Self {
346            storage,
347            plugin_len,
348            service_len,
349            path_len,
350            version_len,
351        })
352    }
353
354    #[allow(unreachable_code)]
355    pub(crate) fn drop(&mut self) {
356        let (pointer, capacity) = self.storage.into_raw_parts();
357        if capacity == 0 {
358            #[cfg(debug_assertions)]
359            unreachable!("drop should only be called once");
360            return;
361        }
362
363        // Note: we pretend the original Vec was filled to capacity which
364        // may or may not be true, there might have been some unitialized
365        // data at the end. But this doesn't matter in this case because we
366        // just want to drop the data, and only the capacity matters.
367        // Safety: safe because drop only happens once and the next time the
368        // pointer will be set to null.
369        unsafe {
370            let string_storage = Vec::from_raw_parts(pointer, capacity, capacity);
371            drop(string_storage);
372        }
373        // Overwrite with len = 0, to guard against double free
374        self.storage = FfiSafeStr::from("");
375    }
376
377    #[inline(always)]
378    fn storage_slice(&self, start: u16, len: u16) -> &str {
379        // SAFETY: data is alive for the lifetime of `&self`, and borrow checker does it's thing
380        let storage = unsafe { self.storage.as_str() };
381        let end = (start + len) as usize;
382        &storage[start as usize..end]
383    }
384
385    #[inline(always)]
386    pub fn plugin(&self) -> &str {
387        self.storage_slice(0, self.plugin_len)
388    }
389
390    #[inline(always)]
391    pub fn service(&self) -> &str {
392        self.storage_slice(self.plugin_len + 1, self.service_len)
393    }
394
395    #[inline(always)]
396    pub fn service_id(&self) -> ServiceId {
397        ServiceId::new(self.plugin(), self.service(), self.version())
398    }
399
400    #[inline(always)]
401    pub fn path(&self) -> &str {
402        self.storage_slice(self.plugin_len + 1 + self.service_len, self.path_len)
403    }
404
405    #[inline(always)]
406    pub fn route_repr(&self) -> &str {
407        self.storage_slice(0, self.plugin_len + 1 + self.service_len + self.path_len)
408    }
409
410    /// Returns plugin version.
411    #[inline(always)]
412    pub fn version(&self) -> &str {
413        self.storage_slice(
414            self.plugin_len + 1 + self.service_len + self.path_len,
415            self.version_len,
416        )
417    }
418}
419
420impl std::fmt::Display for PackedServiceIdentifier {
421    #[inline(always)]
422    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
423        write!(
424            f,
425            "{}.{}:v{}{}",
426            self.plugin(),
427            self.service(),
428            self.version(),
429            self.path()
430        )
431    }
432}