zenoh_plugin_rest/
lib.rs

1//
2// Copyright (c) 2023 ZettaScale Technology
3//
4// This program and the accompanying materials are made available under the
5// terms of the Eclipse Public License 2.0 which is available at
6// http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
7// which is available at https://www.apache.org/licenses/LICENSE-2.0.
8//
9// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
10//
11// Contributors:
12//   ZettaScale Zenoh Team, <zenoh@zettascale.tech>
13//
14
15//! ⚠️ WARNING ⚠️
16//!
17//! This crate is intended for Zenoh's internal use.
18//!
19//! [Click here for Zenoh's documentation](https://docs.rs/zenoh/latest/zenoh)
20use std::{
21    borrow::Cow,
22    convert::TryFrom,
23    future::Future,
24    str::FromStr,
25    sync::{
26        atomic::{AtomicUsize, Ordering},
27        Arc,
28    },
29    time::Duration,
30};
31
32use base64::Engine;
33use futures::StreamExt;
34use http_types::Method;
35use serde::{Deserialize, Serialize};
36use tide::{http::Mime, sse::Sender, Request, Response, Server, StatusCode};
37use tokio::{task::JoinHandle, time::timeout};
38use zenoh::{
39    bytes::{Encoding, ZBytes},
40    internal::{
41        bail,
42        plugins::{RunningPluginTrait, ZenohPlugin},
43        runtime::DynamicRuntime,
44        zerror,
45    },
46    key_expr::{keyexpr, KeyExpr},
47    query::{Parameters, QueryConsolidation, Reply, Selector, ZenohParameters},
48    sample::{Sample, SampleKind},
49    session::Session,
50    Result as ZResult,
51};
52use zenoh_plugin_trait::{plugin_long_version, plugin_version, Plugin, PluginControl};
53
54mod config;
55pub use config::Config;
56use zenoh::query::ReplyError;
57
/// Version string produced by `git_version!` (prefixed with `v`).
const GIT_VERSION: &str = git_version::git_version!(prefix = "v", cargo_prefix = "v");
lazy_static::lazy_static! {
    // `GIT_VERSION` followed by the compiler version taken from the `RUSTC_VERSION`
    // environment variable at build time; logged when the plugin starts.
    static ref LONG_VERSION: String = format!("{} built with {}", GIT_VERSION, env!("RUSTC_VERSION"));
}
/// Query parameter whose presence makes `query` answer with the first reply's raw payload
/// instead of a JSON/HTML listing of all replies.
const RAW_KEY: &str = "_raw";
63
lazy_static::lazy_static! {
    // Worker-thread count used when `TOKIO_RUNTIME` is built. The runtime reads it only once,
    // on first use, so later stores have no effect on an already-built runtime.
    static ref WORKER_THREAD_NUM: AtomicUsize = AtomicUsize::new(config::DEFAULT_WORK_THREAD_NUM);
    // Maximum blocking-thread count used when `TOKIO_RUNTIME` is built (same caveat as above).
    static ref MAX_BLOCK_THREAD_NUM: AtomicUsize = AtomicUsize::new(config::DEFAULT_MAX_BLOCK_THREAD_NUM);
    // Fallback multi-threaded runtime, used when no current Tokio runtime can be found
    // (e.g. when the plugin is loaded as a dynamic library).
    static ref TOKIO_RUNTIME: tokio::runtime::Runtime = tokio::runtime::Builder::new_multi_thread()
               .worker_threads(WORKER_THREAD_NUM.load(Ordering::SeqCst))
               .max_blocking_threads(MAX_BLOCK_THREAD_NUM.load(Ordering::SeqCst))
               .enable_all()
               .build()
               .expect("Unable to create runtime");
}
75
76#[inline(always)]
77pub(crate) fn blockon_runtime<F: Future>(task: F) -> F::Output {
78    // Check whether able to get the current runtime
79    match tokio::runtime::Handle::try_current() {
80        Ok(rt) => {
81            // Able to get the current runtime (standalone binary), use the current runtime
82            tokio::task::block_in_place(|| rt.block_on(task))
83        }
84        Err(_) => {
85            // Unable to get the current runtime (dynamic plugins), reuse the global runtime
86            tokio::task::block_in_place(|| TOKIO_RUNTIME.block_on(task))
87        }
88    }
89}
90
91pub(crate) fn spawn_runtime<F>(task: F) -> JoinHandle<F::Output>
92where
93    F: Future + Send + 'static,
94    F::Output: Send + 'static,
95{
96    // Check whether able to get the current runtime
97    match tokio::runtime::Handle::try_current() {
98        Ok(rt) => {
99            // Able to get the current runtime (standalone binary), spawn on the current runtime
100            rt.spawn(task)
101        }
102        Err(_) => {
103            // Unable to get the current runtime (dynamic plugins), spawn on the global runtime
104            TOKIO_RUNTIME.spawn(task)
105        }
106    }
107}
108
/// JSON shape of one sample (or error reply) in REST responses and SSE events.
#[derive(Serialize, Deserialize)]
struct JSONSample {
    /// Key expression of the sample, or `"ERROR"` for an error reply.
    key: String,
    /// Payload as converted by `payload_to_json` (JSON value, string, base64 string or `null`).
    value: serde_json::Value,
    /// Encoding of the payload, in its string form.
    encoding: String,
    /// Timestamp of the sample in string form, if any (`None` for error replies).
    timestamp: Option<String>,
}
116
117pub fn base64_encode(data: &[u8]) -> String {
118    use base64::engine::general_purpose;
119    general_purpose::STANDARD.encode(data)
120}
121
122fn payload_to_json(payload: &ZBytes, encoding: &Encoding) -> serde_json::Value {
123    if payload.is_empty() {
124        return serde_json::Value::Null;
125    }
126    match encoding {
127        // If it is a JSON try to deserialize as json, if it fails fallback to base64
128        &Encoding::APPLICATION_JSON | &Encoding::TEXT_JSON | &Encoding::TEXT_JSON5 => {
129            let bytes = payload.to_bytes();
130            serde_json::from_slice(&bytes).unwrap_or_else(|e| {
131                tracing::warn!(
132                    "Encoding is JSON but data is not JSON, converting to base64, Error: {e:?}"
133                );
134                serde_json::Value::String(base64_encode(&bytes))
135            })
136        }
137        &Encoding::TEXT_PLAIN | &Encoding::ZENOH_STRING => serde_json::Value::String(
138            String::from_utf8(payload.to_bytes().into_owned()).unwrap_or_else(|e| {
139                tracing::warn!(
140                    "Encoding is String but data is not String, converting to base64, Error: {e:?}"
141                );
142                base64_encode(e.as_bytes())
143            }),
144        ),
145        // otherwise convert to JSON string
146        _ => serde_json::Value::String(base64_encode(&payload.to_bytes())),
147    }
148}
149
150fn sample_to_json(sample: &Sample) -> JSONSample {
151    JSONSample {
152        key: sample.key_expr().as_str().to_string(),
153        value: payload_to_json(sample.payload(), sample.encoding()),
154        encoding: sample.encoding().to_string(),
155        timestamp: sample.timestamp().map(|ts| ts.to_string()),
156    }
157}
158
159fn result_to_json(sample: Result<&Sample, &ReplyError>) -> JSONSample {
160    match sample {
161        Ok(sample) => sample_to_json(sample),
162        Err(err) => JSONSample {
163            key: "ERROR".into(),
164            value: payload_to_json(err.payload(), err.encoding()),
165            encoding: err.encoding().to_string(),
166            timestamp: None,
167        },
168    }
169}
170
171async fn to_json(results: flume::Receiver<Reply>) -> String {
172    let values = results
173        .stream()
174        .filter_map(move |reply| async move { Some(result_to_json(reply.result())) })
175        .collect::<Vec<JSONSample>>()
176        .await;
177
178    serde_json::to_string(&values).unwrap_or("[]".into())
179}
180
181async fn to_json_response(results: flume::Receiver<Reply>) -> Response {
182    response(StatusCode::Ok, "application/json", &to_json(results).await)
183}
184
185fn sample_to_html(sample: &Sample) -> String {
186    format!(
187        "<dt>{}</dt>\n<dd>{}</dd>\n",
188        sample.key_expr().as_str(),
189        sample.payload().try_to_string().unwrap_or_default()
190    )
191}
192
193fn result_to_html(sample: Result<&Sample, &ReplyError>) -> String {
194    match sample {
195        Ok(sample) => sample_to_html(sample),
196        Err(err) => {
197            format!(
198                "<dt>ERROR</dt>\n<dd>{}</dd>\n",
199                err.payload().try_to_string().unwrap_or_default()
200            )
201        }
202    }
203}
204
205async fn to_html(results: flume::Receiver<Reply>) -> String {
206    let values = results
207        .stream()
208        .filter_map(move |reply| async move { Some(result_to_html(reply.result())) })
209        .collect::<Vec<String>>()
210        .await
211        .join("\n");
212    format!("<dl>\n{values}\n</dl>\n")
213}
214
215async fn to_html_response(results: flume::Receiver<Reply>) -> Response {
216    response(StatusCode::Ok, "text/html", &to_html(results).await)
217}
218
219async fn to_raw_response(results: flume::Receiver<Reply>) -> Response {
220    match results.recv_async().await {
221        Ok(reply) => match reply.result() {
222            Ok(sample) => response(
223                StatusCode::Ok,
224                Cow::from(sample.encoding()).as_ref(),
225                &sample.payload().to_bytes(),
226            ),
227            Err(value) => response(
228                StatusCode::Ok,
229                Cow::from(value.encoding()).as_ref(),
230                &value.payload().to_bytes(),
231            ),
232        },
233        Err(_) => response(StatusCode::Ok, "", ""),
234    }
235}
236
237fn method_to_kind(method: Method) -> SampleKind {
238    match method {
239        Method::Put => SampleKind::Put,
240        Method::Delete => SampleKind::Delete,
241        _ => SampleKind::default(),
242    }
243}
244
245fn response<'a, S: Into<&'a str> + std::fmt::Debug>(
246    status: StatusCode,
247    content_type: S,
248    body: &(impl AsRef<[u8]> + ?Sized),
249) -> Response {
250    let body = body.as_ref();
251    let mut content_type = content_type.into().to_string();
252    tracing::trace!(
253        "Outgoing Response: {status} - {content_type:?} - body: {body}",
254        body = std::str::from_utf8(body).unwrap_or_default(),
255    );
256    let mut builder = Response::builder(status)
257        .header("content-length", body.len().to_string())
258        .header("Access-Control-Allow-Origin", "*")
259        .body(body);
260    for chunk in content_type.split(";") {
261        if let Some((_, header)) = chunk.split_once("content-encoding=") {
262            builder = builder.header("content-encoding", header);
263            content_type = content_type.replace(&[";", chunk].concat(), "");
264            break;
265        }
266    }
267    if let Ok(mime) = Mime::from_str(&content_type) {
268        builder = builder.content_type(mime);
269    }
270    builder.build()
271}
272
// When built with the `dynamic_plugin` feature, register `RestPlugin` through
// `declare_plugin!` so it can be loaded as a dynamic library.
#[cfg(feature = "dynamic_plugin")]
zenoh_plugin_trait::declare_plugin!(RestPlugin);

/// The Zenoh REST plugin: serves a Zenoh session over HTTP (see `run`).
pub struct RestPlugin {}

impl ZenohPlugin for RestPlugin {}
279
280impl Plugin for RestPlugin {
281    type StartArgs = DynamicRuntime;
282    type Instance = zenoh::internal::plugins::RunningPlugin;
283    const DEFAULT_NAME: &'static str = "rest";
284    const PLUGIN_VERSION: &'static str = plugin_version!();
285    const PLUGIN_LONG_VERSION: &'static str = plugin_long_version!();
286
287    fn start(
288        name: &str,
289        runtime: &Self::StartArgs,
290    ) -> ZResult<zenoh::internal::plugins::RunningPlugin> {
291        // Try to initiate login.
292        // Required in case of dynamic lib, otherwise no logs.
293        // But cannot be done twice in case of static link.
294        zenoh::init_log_from_env_or("error");
295        tracing::debug!("REST plugin {}", LONG_VERSION.as_str());
296
297        let plugin_conf = runtime
298            .get_config()
299            .get_plugin_config(name)
300            .map_err(|_| zerror!("Plugin `{}`: missing config", name))?;
301
302        let conf: Config = serde_json::from_value(plugin_conf)
303            .map_err(|e| zerror!("Plugin `{}` configuration error: {}", name, e))?;
304        WORKER_THREAD_NUM.store(conf.work_thread_num, Ordering::SeqCst);
305        MAX_BLOCK_THREAD_NUM.store(conf.max_block_thread_num, Ordering::SeqCst);
306
307        let task = run(runtime.clone(), conf.clone());
308        let task =
309            blockon_runtime(async { timeout(Duration::from_millis(1), spawn_runtime(task)).await });
310
311        // The spawn task (spawn_runtime(task)).await) should not return immediately. The server should block inside.
312        // If it returns immediately (for example, address already in use), we can get the error inside Ok
313        if let Ok(Ok(Err(e))) = task {
314            bail!("REST server failed within 1ms: {e}")
315        }
316
317        Ok(Box::new(RunningPlugin(conf)))
318    }
319}
320
321struct RunningPlugin(Config);
322
323impl PluginControl for RunningPlugin {}
324
325impl RunningPluginTrait for RunningPlugin {
326    fn adminspace_getter<'a>(
327        &'a self,
328        key_expr: &'a KeyExpr<'a>,
329        plugin_status_key: &str,
330    ) -> ZResult<Vec<zenoh::internal::plugins::Response>> {
331        let mut responses = Vec::new();
332        let mut key = String::from(plugin_status_key);
333        with_extended_string(&mut key, &["/version"], |key| {
334            if keyexpr::new(key.as_str()).unwrap().intersects(key_expr) {
335                responses.push(zenoh::internal::plugins::Response::new(
336                    key.clone(),
337                    GIT_VERSION.into(),
338                ))
339            }
340        });
341        with_extended_string(&mut key, &["/port"], |port_key| {
342            if keyexpr::new(port_key.as_str())
343                .unwrap()
344                .intersects(key_expr)
345            {
346                responses.push(zenoh::internal::plugins::Response::new(
347                    port_key.clone(),
348                    (&self.0).into(),
349                ))
350            }
351        });
352        Ok(responses)
353    }
354}
355
/// Temporarily appends every suffix to `prefix`, runs `closure` on the extended string,
/// then truncates `prefix` back to its original length and returns the closure's result.
fn with_extended_string<R, F: FnMut(&mut String) -> R>(
    prefix: &mut String,
    suffixes: &[&str],
    mut closure: F,
) -> R {
    let original_len = prefix.len();
    prefix.extend(suffixes.iter().copied());
    let output = closure(prefix);
    prefix.truncate(original_len);
    output
}
369
370async fn query(mut req: Request<(Arc<Session>, String)>) -> tide::Result<Response> {
371    tracing::trace!("Incoming GET request: {:?}", req);
372
373    let first_accept = match req.header("accept") {
374        Some(accept) => accept[0]
375            .to_string()
376            .split(';')
377            .next()
378            .unwrap()
379            .split(',')
380            .next()
381            .unwrap()
382            .to_string(),
383        None => "application/json".to_string(),
384    };
385    if first_accept == "text/event-stream" {
386        Ok(tide::sse::upgrade(
387            req,
388            move |req: Request<(Arc<Session>, String)>, sender: Sender| async move {
389                let key_expr = match path_to_key_expr(req.url().path(), &req.state().1) {
390                    Ok(ke) => ke.into_owned(),
391                    Err(e) => {
392                        return Err(tide::Error::new(
393                            tide::StatusCode::BadRequest,
394                            anyhow::anyhow!("{}", e),
395                        ))
396                    }
397                };
398                spawn_runtime(async move {
399                    tracing::debug!("Subscribe to {} for SSE stream", key_expr);
400                    let sender = &sender;
401                    let sub = req.state().0.declare_subscriber(&key_expr).await.unwrap();
402                    loop {
403                        let sample = sub.recv_async().await.unwrap();
404                        let json_sample =
405                            serde_json::to_string(&sample_to_json(&sample)).unwrap_or("{}".into());
406
407                        match timeout(
408                            std::time::Duration::new(10, 0),
409                            sender.send(&sample.kind().to_string(), json_sample, None),
410                        )
411                        .await
412                        {
413                            Ok(Ok(_)) => {}
414                            Ok(Err(e)) => {
415                                tracing::debug!("SSE error ({})! Unsubscribe and terminate", e);
416                                if let Err(e) = sub.undeclare().await {
417                                    tracing::error!("Error undeclaring subscriber: {}", e);
418                                }
419                                break;
420                            }
421                            Err(_) => {
422                                tracing::debug!("SSE timeout! Unsubscribe and terminate",);
423                                if let Err(e) = sub.undeclare().await {
424                                    tracing::error!("Error undeclaring subscriber: {}", e);
425                                }
426                                break;
427                            }
428                        }
429                    }
430                });
431                Ok(())
432            },
433        ))
434    } else {
435        let body = req.body_bytes().await.unwrap_or_default();
436        let url = req.url();
437        let key_expr = match path_to_key_expr(url.path(), &req.state().1) {
438            Ok(ke) => ke,
439            Err(e) => {
440                return Ok(response(
441                    StatusCode::BadRequest,
442                    "text/plain",
443                    &e.to_string(),
444                ))
445            }
446        };
447        let query_part = url.query();
448        let parameters = Parameters::from(query_part.unwrap_or_default());
449        let consolidation = if parameters.time_range().is_some() {
450            QueryConsolidation::from(zenoh::query::ConsolidationMode::None)
451        } else {
452            QueryConsolidation::from(zenoh::query::ConsolidationMode::Latest)
453        };
454        let raw = parameters.contains_key(RAW_KEY);
455        let mut query = req
456            .state()
457            .0
458            .get(Selector::borrowed(&key_expr, &parameters))
459            .consolidation(consolidation)
460            .with(flume::unbounded());
461        if !body.is_empty() {
462            let encoding: Encoding = req
463                .content_type()
464                .map(|m| Encoding::from(m.to_string()))
465                .unwrap_or_default();
466            query = query.payload(body).encoding(encoding);
467        }
468        match query.await {
469            Ok(receiver) => {
470                if raw {
471                    Ok(to_raw_response(receiver).await)
472                } else if first_accept == "text/html" {
473                    Ok(to_html_response(receiver).await)
474                } else {
475                    Ok(to_json_response(receiver).await)
476                }
477            }
478            Err(e) => Ok(response(
479                StatusCode::InternalServerError,
480                "text/plain",
481                &e.to_string(),
482            )),
483        }
484    }
485}
486
487async fn write(mut req: Request<(Arc<Session>, String)>) -> tide::Result<Response> {
488    tracing::trace!("Incoming PUT request: {:?}", req);
489    match req.body_bytes().await {
490        Ok(bytes) => {
491            let key_expr = match path_to_key_expr(req.url().path(), &req.state().1) {
492                Ok(ke) => ke,
493                Err(e) => {
494                    return Ok(response(
495                        StatusCode::BadRequest,
496                        "text/plain",
497                        &e.to_string(),
498                    ))
499                }
500            };
501
502            let encoding: Encoding = req
503                .content_type()
504                .map(|m| Encoding::from(m.to_string()))
505                .unwrap_or_default();
506
507            // @TODO: Define the right congestion control value
508            let session = &req.state().0;
509            let res = match method_to_kind(req.method()) {
510                SampleKind::Put => session.put(&key_expr, bytes).encoding(encoding).await,
511                SampleKind::Delete => session.delete(&key_expr).await,
512            };
513            match res {
514                Ok(_) => Ok(Response::new(StatusCode::Ok)),
515                Err(e) => Ok(response(
516                    StatusCode::InternalServerError,
517                    "text/plain",
518                    &e.to_string(),
519                )),
520            }
521        }
522        Err(e) => Ok(response(
523            StatusCode::NoContent,
524            "text/plain",
525            &e.to_string(),
526        )),
527    }
528}
529
530pub async fn run(runtime: DynamicRuntime, conf: Config) -> ZResult<()> {
531    // Try to initiate login.
532    // Required in case of dynamic lib, otherwise no logs.
533    // But cannot be done twice in case of static link.
534    zenoh::init_log_from_env_or("error");
535
536    let zid = runtime.zid().to_string();
537    let session = zenoh::session::init(runtime).await.unwrap();
538
539    let mut app = Server::with_state((Arc::new(session), zid));
540    app.with(
541        tide::security::CorsMiddleware::new()
542            .allow_methods(
543                "GET, POST, PUT, PATCH, DELETE"
544                    .parse::<http_types::headers::HeaderValue>()
545                    .unwrap(),
546            )
547            .allow_origin(tide::security::Origin::from("*"))
548            .allow_credentials(false),
549    );
550
551    app.at("/")
552        .get(query)
553        .post(query)
554        .put(write)
555        .patch(write)
556        .delete(write);
557    app.at("*")
558        .get(query)
559        .post(query)
560        .put(write)
561        .patch(write)
562        .delete(write);
563
564    if let Err(e) = app.listen(conf.http_port).await {
565        tracing::error!("Unable to start http server for REST: {:?}", e);
566        return Err(e.into());
567    }
568    Ok(())
569}
570
571fn path_to_key_expr<'a>(path: &'a str, zid: &str) -> ZResult<KeyExpr<'a>> {
572    let path = path.strip_prefix('/').unwrap_or(path);
573    if path == "@/local" {
574        KeyExpr::try_from(format!("@/{zid}"))
575    } else if let Some(suffix) = path.strip_prefix("@/local/") {
576        KeyExpr::try_from(format!("@/{zid}/{suffix}"))
577    } else {
578        KeyExpr::try_from(path)
579    }
580}