// picodata_plugin/transport/rpc/server.rs

use super::Request;
use super::Response;
use crate::internal::ffi;
use crate::plugin::interface::PicoContext;
use crate::plugin::interface::ServiceId;
use crate::transport::context::Context;
use crate::transport::context::FfiSafeContext;
use crate::util::FfiSafeBytes;
use crate::util::FfiSafeStr;
use std::mem::MaybeUninit;
use std::ptr::NonNull;
use tarantool::error::BoxError;
use tarantool::error::TarantoolErrorCode;

////////////////////////////////////////////////////////////////////////////////
// RouteBuilder
////////////////////////////////////////////////////////////////////////////////

/// A helper struct for declaring RPC endpoints.
///
/// Holds the identity of the service registering the endpoint together with
/// the route path, and registers a handler via [`RouteBuilder::register`].
#[derive(Debug, Clone)]
pub struct RouteBuilder<'a> {
    /// Name of the plugin which owns the endpoint.
    plugin: &'a str,
    /// Name of the service which owns the endpoint.
    service: &'a str,
    /// Version of the plugin.
    version: &'a str,
    /// Route path; stays `None` until [`RouteBuilder::path`] is called.
    path: Option<&'a str>,
}

28impl<'a> RouteBuilder<'a> {
29    /// A route is required to contain information about the service from which
30    /// it is being registered. Currently it's only possible to automatically
31    /// extract this information from [`PicoContext`].
32    #[inline(always)]
33    pub fn from_pico_context(context: &'a PicoContext) -> Self {
34        Self {
35            plugin: context.plugin_name(),
36            service: context.service_name(),
37            version: context.plugin_version(),
38            path: None,
39        }
40    }
41
42    /// A route is required to contain information about the service from which
43    /// it is being registered. Use this method to specify this info explicitly.
44    /// Consider using [`RouteBuilder::from_pico_context`] instead.
45    ///
46    /// # Safety
47    /// The caller must make sure that the info is the actual service info of
48    /// the currently running service.
49    #[inline(always)]
50    pub unsafe fn from_service_info(plugin: &'a str, service: &'a str, version: &'a str) -> Self {
51        Self {
52            plugin,
53            service,
54            version,
55            path: None,
56        }
57    }
58
59    /// Specify a route path. The path must start with a `'/'` character.
60    #[inline]
61    pub fn path(mut self, path: &'a str) -> Self {
62        if let Some(old) = self.path.take() {
63            #[rustfmt::skip]
64            tarantool::say_warn!("RouteBuilder path is silently changed from {old:?} to {path:?}");
65        }
66        self.path = Some(path);
67        self
68    }
69
70    /// Register the RPC endpoint with the currently chosen parameters and the
71    /// provided handler.
72    ///
73    /// Note that `f` must implement `Fn`. This is required by rust's semantics
74    /// to allow the RPC handlers to yield. If a handler yields then another
75    /// concurrent RPC request may result in the same handler being executed,
76    /// so we must not hold any `&mut` references in those closures (other than
77    /// ones allowed by rust semantics, see official reference on undefined
78    /// behaviour [here]).
79    ///
80    /// [here]: <https://doc.rust-lang.org/reference/behavior-considered-undefined.html>
81    #[track_caller]
82    pub fn register<F>(self, f: F) -> Result<(), BoxError>
83    where
84        F: Fn(Request<'_>, &mut Context) -> Result<Response, BoxError> + 'static,
85    {
86        let Some(path) = self.path else {
87            #[rustfmt::skip]
88            return Err(BoxError::new(TarantoolErrorCode::IllegalParams, "path must be specified for RPC endpoint"));
89        };
90
91        let identifier = PackedServiceIdentifier::pack(
92            path.into(),
93            self.plugin.into(),
94            self.service.into(),
95            self.version.into(),
96        )?;
97        let handler = FfiRpcHandler::new(identifier, f);
98        if let Err(e) = register_rpc_handler(handler) {
99            // Note: recreating the error to capture the caller's source location
100            #[rustfmt::skip]
101            return Err(BoxError::new(e.error_code(), e.message()));
102        }
103
104        Ok(())
105    }
106}
107
108impl<'a> From<&'a PicoContext> for RouteBuilder<'a> {
109    #[inline(always)]
110    fn from(context: &'a PicoContext) -> Self {
111        Self::from_pico_context(context)
112    }
113}
114
115////////////////////////////////////////////////////////////////////////////////
116// ffi wrappers
117////////////////////////////////////////////////////////////////////////////////
118
119/// **For internal use**.
120#[inline]
121fn register_rpc_handler(handler: FfiRpcHandler) -> Result<(), BoxError> {
122    // This is safe.
123    let rc = unsafe { ffi::pico_ffi_register_rpc_handler(handler) };
124    if rc == -1 {
125        return Err(BoxError::last());
126    }
127
128    return Ok(());
129}
130
131type RpcHandlerCallback = extern "C" fn(
132    handler: *const FfiRpcHandler,
133    input: FfiSafeBytes,
134    context: *const FfiSafeContext,
135    output: *mut FfiSafeBytes,
136) -> std::ffi::c_int;
137
138/// **For internal use**.
139///
140/// Use [`RouteBuilder`] instead.
141#[repr(C)]
142pub struct FfiRpcHandler {
143    /// The result data must either be statically allocated, or allocated using
144    /// the fiber region allocator (see [`box_region_alloc`]).
145    ///
146    /// [`box_region_alloc`]: tarantool::ffi::tarantool::box_region_alloc
147    callback: RpcHandlerCallback,
148    drop: extern "C" fn(*mut FfiRpcHandler),
149
150    /// The pointer to the closure object.
151    ///
152    /// Note that the pointer must be `mut` because we will at some point drop the data pointed to by it.
153    /// But when calling the closure, the `const` pointer should be used.
154    closure_pointer: *mut (),
155
156    /// This data is owned by this struct (freed on drop).
157    pub identifier: PackedServiceIdentifier,
158}
159
160impl Drop for FfiRpcHandler {
161    #[inline(always)]
162    fn drop(&mut self) {
163        (self.drop)(self)
164    }
165}
166
167impl FfiRpcHandler {
168    fn new<F>(identifier: PackedServiceIdentifier, f: F) -> Self
169    where
170        F: Fn(Request<'_>, &mut Context) -> Result<Response, BoxError> + 'static,
171    {
172        let closure = Box::new(f);
173        let closure_pointer: *mut F = Box::into_raw(closure);
174
175        FfiRpcHandler {
176            callback: Self::trampoline::<F>,
177            drop: Self::drop_handler::<F>,
178            closure_pointer: closure_pointer.cast(),
179
180            identifier,
181        }
182    }
183
184    extern "C" fn trampoline<F>(
185        handler: *const FfiRpcHandler,
186        input: FfiSafeBytes,
187        context: *const FfiSafeContext,
188        output: *mut FfiSafeBytes,
189    ) -> std::ffi::c_int
190    where
191        F: Fn(Request<'_>, &mut Context) -> Result<Response, BoxError> + 'static,
192    {
193        // This is safe. To verify see `register_rpc_handler` above.
194        let closure_pointer: *const F = unsafe { (*handler).closure_pointer.cast::<F>() };
195        let closure = unsafe { &*closure_pointer };
196        let input = unsafe { input.as_bytes() };
197        let context = unsafe { &*context };
198        let mut context = Context::new(context);
199
200        let request = Request::from_bytes(input);
201        let result = (|| {
202            let response = closure(request, &mut context)?;
203            response.to_region_slice()
204        })();
205
206        match result {
207            Ok(region_slice) => {
208                // This is safe. To verify see `FfiRpcHandler::call` bellow.
209                unsafe { std::ptr::write(output, region_slice.into()) }
210
211                return 0;
212            }
213            Err(e) => {
214                e.set_last();
215                return -1;
216            }
217        }
218    }
219
220    extern "C" fn drop_handler<F>(handler: *mut FfiRpcHandler) {
221        unsafe {
222            let closure_pointer: *mut F = (*handler).closure_pointer.cast::<F>();
223            let closure = Box::from_raw(closure_pointer);
224            drop(closure);
225
226            if cfg!(debug_assertions) {
227                // Overwrite the pointer with garbage so that we fail loudly is case of a bug
228                (*handler).closure_pointer = 0xcccccccccccccccc_u64 as _;
229            }
230
231            (*handler).identifier.drop();
232        }
233    }
234
235    #[inline(always)]
236    pub fn identity(&self) -> usize {
237        self.callback as *const RpcHandlerCallback as _
238    }
239
240    #[inline(always)]
241    pub fn call(&self, input: &[u8], context: &FfiSafeContext) -> Result<&'static [u8], ()> {
242        let mut output = MaybeUninit::uninit();
243
244        let rc = (self.callback)(self, input.into(), context, output.as_mut_ptr());
245        if rc == -1 {
246            // Actual error is passed through tarantool. Can't return BoxError
247            // here, because tarantool-module version may be different in picodata.
248            return Err(());
249        }
250
251        // This is safe. To verify see `trampoline` above.
252        let result = unsafe { output.assume_init().as_bytes() };
253
254        Ok(result)
255    }
256}
257
258/// **For internal use**.
259///
260/// Use [`RouteBuilder`] instead.
261///
262/// This struct stores an RPC route identifier in the following packed format:
263/// `{plugin}.{service}{path}{version}`. This format allows for efficient
264/// extraction of the RPC route identifier for purposes of logging (note that
265/// version is not displayed), and also losslessly stores info about all the
266/// parts of the idenifier.
267///
268/// This represnetation also adds a constraint on the maximum length of the
269/// plugin name, service name and path, each one of them must not be longer than
270/// 65535 bytes (which is obviously engough for anybody).
271#[repr(C)]
272#[derive(Debug, Default, Clone, Copy)]
273pub struct PackedServiceIdentifier {
274    pub storage: FfiSafeStr,
275    pub plugin_len: u16,
276    pub service_len: u16,
277    pub path_len: u16,
278    pub version_len: u16,
279}
280
281impl PackedServiceIdentifier {
282    pub(crate) fn pack(
283        path: &str,
284        plugin: &str,
285        service: &str,
286        version: &str,
287    ) -> Result<Self, BoxError> {
288        let Ok(plugin_len) = plugin.len().try_into() else {
289            #[rustfmt::skip]
290            return Err(BoxError::new(TarantoolErrorCode::IllegalParams, format!("plugin name length must not exceed 65535, got {}", plugin.len())));
291        };
292        let Ok(service_len) = service.len().try_into() else {
293            #[rustfmt::skip]
294            return Err(BoxError::new(TarantoolErrorCode::IllegalParams, format!("service name length must not exceed 65535, got {}", service.len())));
295        };
296        let Ok(path_len) = path.len().try_into() else {
297            #[rustfmt::skip]
298            return Err(BoxError::new(TarantoolErrorCode::IllegalParams, format!("route path length must not exceed 65535, got {}", path.len())));
299        };
300        let Ok(version_len) = version.len().try_into() else {
301            #[rustfmt::skip]
302            return Err(BoxError::new(TarantoolErrorCode::IllegalParams, format!("version string length must not exceed 65535, got {}", version.len())));
303        };
304
305        let total_string_len = plugin_len
306            // For an extra '.' between plugin and service names
307            + 1
308            + service_len
309            + path_len
310            + version_len;
311        let mut string_storage = Vec::with_capacity(total_string_len as _);
312        string_storage.extend_from_slice(plugin.as_bytes());
313        string_storage.push(b'.');
314        string_storage.extend_from_slice(service.as_bytes());
315        string_storage.extend_from_slice(path.as_bytes());
316        string_storage.extend_from_slice(version.as_bytes());
317
318        let start = string_storage.as_mut_ptr();
319        let capacity = string_storage.capacity();
320
321        // Safety: vec has an allocated buffer, so the pointer cannot be null.
322        // Also a concatenation of utf8 strings is always a utf8 string.
323        let storage =
324            unsafe { FfiSafeStr::from_raw_parts(NonNull::new_unchecked(start), capacity) };
325
326        // Self now owns this data and will be freeing it in it's `drop`.
327        std::mem::forget(string_storage);
328
329        Ok(Self {
330            storage,
331            plugin_len,
332            service_len,
333            path_len,
334            version_len,
335        })
336    }
337
338    #[allow(unreachable_code)]
339    pub(crate) fn drop(&mut self) {
340        let (pointer, capacity) = self.storage.into_raw_parts();
341        if capacity == 0 {
342            #[cfg(debug_assertions)]
343            unreachable!("drop should only be called once");
344            return;
345        }
346
347        // Note: we pretend the original Vec was filled to capacity which
348        // may or may not be true, there might have been some unitialized
349        // data at the end. But this doesn't matter in this case because we
350        // just want to drop the data, and only the capacity matters.
351        // Safety: safe because drop only happens once and the next time the
352        // pointer will be set to null.
353        unsafe {
354            let string_storage = Vec::from_raw_parts(pointer, capacity, capacity);
355            drop(string_storage);
356        }
357        // Overwrite with len = 0, to guard against double free
358        self.storage = FfiSafeStr::from("");
359    }
360
361    #[inline(always)]
362    fn storage_slice(&self, start: u16, len: u16) -> &str {
363        // SAFETY: data is alive for the lifetime of `&self`, and borrow checker does it's thing
364        let storage = unsafe { self.storage.as_str() };
365        let end = (start + len) as usize;
366        &storage[start as usize..end]
367    }
368
369    #[inline(always)]
370    pub fn plugin(&self) -> &str {
371        self.storage_slice(0, self.plugin_len)
372    }
373
374    #[inline(always)]
375    pub fn service(&self) -> &str {
376        self.storage_slice(self.plugin_len + 1, self.service_len)
377    }
378
379    #[inline(always)]
380    pub fn service_id(&self) -> ServiceId {
381        ServiceId::new(self.plugin(), self.service(), self.version())
382    }
383
384    #[inline(always)]
385    pub fn path(&self) -> &str {
386        self.storage_slice(self.plugin_len + 1 + self.service_len, self.path_len)
387    }
388
389    #[inline(always)]
390    pub fn route_repr(&self) -> &str {
391        self.storage_slice(0, self.plugin_len + 1 + self.service_len + self.path_len)
392    }
393
394    /// Returns plugin version.
395    #[inline(always)]
396    pub fn version(&self) -> &str {
397        self.storage_slice(
398            self.plugin_len + 1 + self.service_len + self.path_len,
399            self.version_len,
400        )
401    }
402}
403
404impl std::fmt::Display for PackedServiceIdentifier {
405    #[inline(always)]
406    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
407        write!(
408            f,
409            "{}.{}:v{}{}",
410            self.plugin(),
411            self.service(),
412            self.version(),
413            self.path()
414        )
415    }
416}