picodata_plugin/transport/rpc/
server.rs

1use super::Request;
2use super::Response;
3use crate::internal::ffi;
4use crate::plugin::interface::PicoContext;
5use crate::plugin::interface::ServiceId;
6use crate::transport::context::Context;
7use crate::transport::context::FfiSafeContext;
8#[allow(unused_imports)]
9use crate::transport::rpc::client::RequestBuilder;
10use crate::util::FfiSafeBytes;
11use crate::util::FfiSafeStr;
12use std::mem::MaybeUninit;
13use std::ptr::NonNull;
14use tarantool::error::BoxError;
15use tarantool::error::TarantoolErrorCode;
16
////////////////////////////////////////////////////////////////////////////////
// RouteBuilder
////////////////////////////////////////////////////////////////////////////////
20
/// Builder used on the server side to declare RPC endpoints.
///
/// A route is identified by the service registering it (plugin name, service
/// name and plugin version) together with a path.
///
/// The client-side counterpart is [`RequestBuilder`].
#[derive(Debug, Clone)]
pub struct RouteBuilder<'a> {
    /// Name of the plugin the registering service belongs to.
    plugin: &'a str,
    /// Name of the registering service.
    service: &'a str,
    /// Version of the plugin.
    version: &'a str,
    /// Route path, set with [`RouteBuilder::path`]; required by `register`.
    path: Option<&'a str>,
}
31
32impl<'a> RouteBuilder<'a> {
33    /// A route is required to contain information about the service from which
34    /// it is being registered. Currently it's only possible to automatically
35    /// extract this information from [`PicoContext`].
36    #[inline(always)]
37    pub fn from_pico_context(context: &'a PicoContext) -> Self {
38        Self {
39            plugin: context.plugin_name(),
40            service: context.service_name(),
41            version: context.plugin_version(),
42            path: None,
43        }
44    }
45
46    /// A route is required to contain information about the service from which
47    /// it is being registered. Use this method to specify this info explicitly.
48    /// Consider using [`RouteBuilder::from_pico_context`] instead.
49    ///
50    /// # Safety
51    /// The caller must make sure that the info is the actual service info of
52    /// the currently running service.
53    #[inline(always)]
54    pub unsafe fn from_service_info(plugin: &'a str, service: &'a str, version: &'a str) -> Self {
55        Self {
56            plugin,
57            service,
58            version,
59            path: None,
60        }
61    }
62
63    /// Specify a route path. The path must start with a `'/'` character.
64    #[inline]
65    pub fn path(mut self, path: &'a str) -> Self {
66        if let Some(old) = self.path.take() {
67            #[rustfmt::skip]
68            tarantool::say_warn!("RouteBuilder path is silently changed from {old:?} to {path:?}");
69        }
70        self.path = Some(path);
71        self
72    }
73
74    /// Register the RPC endpoint with the currently chosen parameters and the
75    /// provided handler.
76    ///
77    /// Note that `f` must implement `Fn`. This is required by rust's semantics
78    /// to allow the RPC handlers to yield. If a handler yields then another
79    /// concurrent RPC request may result in the same handler being executed,
80    /// so we must not hold any `&mut` references in those closures (other than
81    /// ones allowed by rust semantics, see official reference on undefined
82    /// behaviour [here]).
83    ///
84    /// Use [`RequestBuilder::send`] to invoke the RPC endpoint registerred with
85    /// this method.
86    ///
87    /// # Local execution
88    ///
89    /// Note that the RPC handler may be invoked locally if the caller specifies
90    /// the request target which matches the current instance. In that case
91    /// the handler is invoked in the same process without yielding from the
92    /// fiber. A special named field `"call_was_local": true` is added to the
93    /// [`Context`] argument of the handler if the call was local. Note that
94    /// this field may or may not be missing in case of non-local call.
95    ///
96    /// [here]: <https://doc.rust-lang.org/reference/behavior-considered-undefined.html>
97    #[track_caller]
98    pub fn register<F>(self, f: F) -> Result<(), BoxError>
99    where
100        F: Fn(Request<'_>, &mut Context) -> Result<Response, BoxError> + 'static,
101    {
102        let Some(path) = self.path else {
103            #[rustfmt::skip]
104            return Err(BoxError::new(TarantoolErrorCode::IllegalParams, "path must be specified for RPC endpoint"));
105        };
106
107        let identifier =
108            PackedServiceIdentifier::pack(path, self.plugin, self.service, self.version)?;
109        let handler = FfiRpcHandler::new(identifier, f);
110        if let Err(e) = register_rpc_handler(handler) {
111            // Note: recreating the error to capture the caller's source location
112            #[rustfmt::skip]
113            return Err(BoxError::new(e.error_code(), e.message()));
114        }
115
116        Ok(())
117    }
118}
119
120impl<'a> From<&'a PicoContext> for RouteBuilder<'a> {
121    #[inline(always)]
122    fn from(context: &'a PicoContext) -> Self {
123        Self::from_pico_context(context)
124    }
125}
126
127////////////////////////////////////////////////////////////////////////////////
128// ffi wrappers
129////////////////////////////////////////////////////////////////////////////////
130
131/// **For internal use**.
132#[inline]
133fn register_rpc_handler(handler: FfiRpcHandler) -> Result<(), BoxError> {
134    // This is safe.
135    let rc = unsafe { ffi::pico_ffi_register_rpc_handler(handler) };
136    if rc == -1 {
137        return Err(BoxError::last());
138    }
139
140    Ok(())
141}
142
143type RpcHandlerCallback = extern "C" fn(
144    handler: *const FfiRpcHandler,
145    input: FfiSafeBytes,
146    context: *const FfiSafeContext,
147    output: *mut FfiSafeBytes,
148) -> std::ffi::c_int;
149
150/// **For internal use**.
151///
152/// Use [`RouteBuilder`] instead.
153#[repr(C)]
154pub struct FfiRpcHandler {
155    /// The result data must either be statically allocated, or allocated using
156    /// the fiber region allocator (see [`box_region_alloc`]).
157    ///
158    /// [`box_region_alloc`]: tarantool::ffi::tarantool::box_region_alloc
159    callback: RpcHandlerCallback,
160    drop: extern "C" fn(*mut FfiRpcHandler),
161
162    /// The pointer to the closure object.
163    ///
164    /// Note that the pointer must be `mut` because we will at some point drop the data pointed to by it.
165    /// But when calling the closure, the `const` pointer should be used.
166    closure_pointer: *mut (),
167
168    /// This data is owned by this struct (freed on drop).
169    pub identifier: PackedServiceIdentifier,
170}
171
172impl Drop for FfiRpcHandler {
173    #[inline(always)]
174    fn drop(&mut self) {
175        (self.drop)(self)
176    }
177}
178
179impl FfiRpcHandler {
180    fn new<F>(identifier: PackedServiceIdentifier, f: F) -> Self
181    where
182        F: Fn(Request<'_>, &mut Context) -> Result<Response, BoxError> + 'static,
183    {
184        let closure = Box::new(f);
185        let closure_pointer: *mut F = Box::into_raw(closure);
186
187        FfiRpcHandler {
188            callback: Self::trampoline::<F>,
189            drop: Self::drop_handler::<F>,
190            closure_pointer: closure_pointer.cast(),
191
192            identifier,
193        }
194    }
195
196    extern "C" fn trampoline<F>(
197        handler: *const FfiRpcHandler,
198        input: FfiSafeBytes,
199        context: *const FfiSafeContext,
200        output: *mut FfiSafeBytes,
201    ) -> std::ffi::c_int
202    where
203        F: Fn(Request<'_>, &mut Context) -> Result<Response, BoxError> + 'static,
204    {
205        // This is safe. To verify see `register_rpc_handler` above.
206        let closure_pointer: *const F = unsafe { (*handler).closure_pointer.cast::<F>() };
207        let closure = unsafe { &*closure_pointer };
208        let input = unsafe { input.as_bytes() };
209        let context = unsafe { &*context };
210        let mut context = Context::new(context);
211
212        let request = Request::from_bytes(input);
213        let result = (|| {
214            let response = closure(request, &mut context)?;
215            response.to_region_slice()
216        })();
217
218        match result {
219            Ok(region_slice) => {
220                // This is safe. To verify see `FfiRpcHandler::call` bellow.
221                unsafe { std::ptr::write(output, region_slice.into()) }
222
223                0
224            }
225            Err(e) => {
226                e.set_last();
227                -1
228            }
229        }
230    }
231
232    extern "C" fn drop_handler<F>(handler: *mut FfiRpcHandler) {
233        unsafe {
234            let closure_pointer: *mut F = (*handler).closure_pointer.cast::<F>();
235            let closure = Box::from_raw(closure_pointer);
236            drop(closure);
237
238            if cfg!(debug_assertions) {
239                // Overwrite the pointer with garbage so that we fail loudly is case of a bug
240                (*handler).closure_pointer = 0xcccccccccccccccc_u64 as _;
241            }
242
243            (*handler).identifier.drop();
244        }
245    }
246
247    #[inline(always)]
248    pub fn identity(&self) -> usize {
249        self.callback as *const RpcHandlerCallback as _
250    }
251
252    #[inline(always)]
253    #[allow(clippy::result_unit_err)]
254    pub fn call(&self, input: &[u8], context: &FfiSafeContext) -> Result<&'static [u8], ()> {
255        let mut output = MaybeUninit::uninit();
256
257        let rc = (self.callback)(self, input.into(), context, output.as_mut_ptr());
258        if rc == -1 {
259            // Actual error is passed through tarantool. Can't return BoxError
260            // here, because tarantool-module version may be different in picodata.
261            return Err(());
262        }
263
264        // This is safe. To verify see `trampoline` above.
265        let result = unsafe { output.assume_init().as_bytes() };
266
267        Ok(result)
268    }
269}
270
271/// **For internal use**.
272///
273/// Use [`RouteBuilder`] instead.
274///
275/// This struct stores an RPC route identifier in the following packed format:
276/// `{plugin}.{service}{path}{version}`. This format allows for efficient
277/// extraction of the RPC route identifier for purposes of logging (note that
278/// version is not displayed), and also losslessly stores info about all the
279/// parts of the idenifier.
280///
281/// This represnetation also adds a constraint on the maximum length of the
282/// plugin name, service name and path, each one of them must not be longer than
283/// 65535 bytes (which is obviously engough for anybody).
284#[repr(C)]
285#[derive(Debug, Default, Clone, Copy)]
286pub struct PackedServiceIdentifier {
287    pub storage: FfiSafeStr,
288    pub plugin_len: u16,
289    pub service_len: u16,
290    pub path_len: u16,
291    pub version_len: u16,
292}
293
294impl PackedServiceIdentifier {
295    pub(crate) fn pack(
296        path: &str,
297        plugin: &str,
298        service: &str,
299        version: &str,
300    ) -> Result<Self, BoxError> {
301        let Ok(plugin_len) = plugin.len().try_into() else {
302            #[rustfmt::skip]
303            return Err(BoxError::new(TarantoolErrorCode::IllegalParams, format!("plugin name length must not exceed 65535, got {}", plugin.len())));
304        };
305        let Ok(service_len) = service.len().try_into() else {
306            #[rustfmt::skip]
307            return Err(BoxError::new(TarantoolErrorCode::IllegalParams, format!("service name length must not exceed 65535, got {}", service.len())));
308        };
309        let Ok(path_len) = path.len().try_into() else {
310            #[rustfmt::skip]
311            return Err(BoxError::new(TarantoolErrorCode::IllegalParams, format!("route path length must not exceed 65535, got {}", path.len())));
312        };
313        let Ok(version_len) = version.len().try_into() else {
314            #[rustfmt::skip]
315            return Err(BoxError::new(TarantoolErrorCode::IllegalParams, format!("version string length must not exceed 65535, got {}", version.len())));
316        };
317
318        let total_string_len = plugin_len
319            // For an extra '.' between plugin and service names
320            + 1
321            + service_len
322            + path_len
323            + version_len;
324        let mut string_storage = Vec::with_capacity(total_string_len as _);
325        string_storage.extend_from_slice(plugin.as_bytes());
326        string_storage.push(b'.');
327        string_storage.extend_from_slice(service.as_bytes());
328        string_storage.extend_from_slice(path.as_bytes());
329        string_storage.extend_from_slice(version.as_bytes());
330
331        let start = string_storage.as_mut_ptr();
332        let capacity = string_storage.capacity();
333
334        // Safety: vec has an allocated buffer, so the pointer cannot be null.
335        // Also a concatenation of utf8 strings is always a utf8 string.
336        let storage =
337            unsafe { FfiSafeStr::from_raw_parts(NonNull::new_unchecked(start), capacity) };
338
339        // Self now owns this data and will be freeing it in it's `drop`.
340        std::mem::forget(string_storage);
341
342        Ok(Self {
343            storage,
344            plugin_len,
345            service_len,
346            path_len,
347            version_len,
348        })
349    }
350
351    #[allow(unreachable_code)]
352    pub(crate) fn drop(&mut self) {
353        let (pointer, capacity) = self.storage.into_raw_parts();
354        if capacity == 0 {
355            #[cfg(debug_assertions)]
356            unreachable!("drop should only be called once");
357            return;
358        }
359
360        // Note: we pretend the original Vec was filled to capacity which
361        // may or may not be true, there might have been some unitialized
362        // data at the end. But this doesn't matter in this case because we
363        // just want to drop the data, and only the capacity matters.
364        // Safety: safe because drop only happens once and the next time the
365        // pointer will be set to null.
366        unsafe {
367            let string_storage = Vec::from_raw_parts(pointer, capacity, capacity);
368            drop(string_storage);
369        }
370        // Overwrite with len = 0, to guard against double free
371        self.storage = FfiSafeStr::from("");
372    }
373
374    #[inline(always)]
375    fn storage_slice(&self, start: u16, len: u16) -> &str {
376        // SAFETY: data is alive for the lifetime of `&self`, and borrow checker does it's thing
377        let storage = unsafe { self.storage.as_str() };
378        let end = (start + len) as usize;
379        &storage[start as usize..end]
380    }
381
382    #[inline(always)]
383    pub fn plugin(&self) -> &str {
384        self.storage_slice(0, self.plugin_len)
385    }
386
387    #[inline(always)]
388    pub fn service(&self) -> &str {
389        self.storage_slice(self.plugin_len + 1, self.service_len)
390    }
391
392    #[inline(always)]
393    pub fn service_id(&self) -> ServiceId {
394        ServiceId::new(self.plugin(), self.service(), self.version())
395    }
396
397    #[inline(always)]
398    pub fn path(&self) -> &str {
399        self.storage_slice(self.plugin_len + 1 + self.service_len, self.path_len)
400    }
401
402    #[inline(always)]
403    pub fn route_repr(&self) -> &str {
404        self.storage_slice(0, self.plugin_len + 1 + self.service_len + self.path_len)
405    }
406
407    /// Returns plugin version.
408    #[inline(always)]
409    pub fn version(&self) -> &str {
410        self.storage_slice(
411            self.plugin_len + 1 + self.service_len + self.path_len,
412            self.version_len,
413        )
414    }
415}
416
417impl std::fmt::Display for PackedServiceIdentifier {
418    #[inline(always)]
419    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
420        write!(
421            f,
422            "{}.{}:v{}{}",
423            self.plugin(),
424            self.service(),
425            self.version(),
426            self.path()
427        )
428    }
429}