1use std::{
21 borrow::Cow,
22 convert::TryFrom,
23 future::Future,
24 str::FromStr,
25 sync::{
26 atomic::{AtomicUsize, Ordering},
27 Arc,
28 },
29 time::Duration,
30};
31
32use base64::Engine;
33use futures::StreamExt;
34use http_types::Method;
35use serde::{Deserialize, Serialize};
36use tide::{http::Mime, sse::Sender, Request, Response, Server, StatusCode};
37use tokio::{task::JoinHandle, time::timeout};
38use zenoh::{
39 bytes::{Encoding, ZBytes},
40 internal::{
41 bail,
42 plugins::{RunningPluginTrait, ZenohPlugin},
43 runtime::DynamicRuntime,
44 zerror,
45 },
46 key_expr::{keyexpr, KeyExpr},
47 query::{Parameters, QueryConsolidation, Reply, Selector, ZenohParameters},
48 sample::{Sample, SampleKind},
49 session::Session,
50 Result as ZResult,
51};
52use zenoh_plugin_trait::{plugin_long_version, plugin_version, Plugin, PluginControl};
53
54mod config;
55pub use config::Config;
56use zenoh::query::ReplyError;
57
/// Version string of this build as produced by `git_version!`, prefixed with `v`.
const GIT_VERSION: &str = git_version::git_version!(prefix = "v", cargo_prefix = "v");
lazy_static::lazy_static! {
    /// `GIT_VERSION` followed by the rustc version this plugin was built with
    /// (`RUSTC_VERSION` must be provided at compile time, presumably by a build script).
    static ref LONG_VERSION: String = format!("{} built with {}", GIT_VERSION, env!("RUSTC_VERSION"));
}
/// Query parameter whose presence makes `query` return the first reply's raw payload
/// (with its encoding as content type) instead of a JSON or HTML listing.
const RAW_KEY: &str = "_raw";

lazy_static::lazy_static! {
    /// Worker-thread count used when `TOKIO_RUNTIME` is built; `start` overwrites it from the
    /// plugin configuration.
    static ref WORKER_THREAD_NUM: AtomicUsize = AtomicUsize::new(config::DEFAULT_WORK_THREAD_NUM);
    /// Maximum blocking-thread count used when `TOKIO_RUNTIME` is built; `start` overwrites it
    /// from the plugin configuration.
    static ref MAX_BLOCK_THREAD_NUM: AtomicUsize = AtomicUsize::new(config::DEFAULT_MAX_BLOCK_THREAD_NUM);
    /// Fallback multi-thread runtime used by `blockon_runtime` / `spawn_runtime` when no Tokio
    /// runtime is current. The two thread counts above are read once, when this static is first
    /// accessed, so stores made after that point have no effect. Panics with
    /// "Unable to create runtime" if the runtime cannot be built.
    static ref TOKIO_RUNTIME: tokio::runtime::Runtime = tokio::runtime::Builder::new_multi_thread()
        .worker_threads(WORKER_THREAD_NUM.load(Ordering::SeqCst))
        .max_blocking_threads(MAX_BLOCK_THREAD_NUM.load(Ordering::SeqCst))
        .enable_all()
        .build()
        .expect("Unable to create runtime");
}
75
76#[inline(always)]
77pub(crate) fn blockon_runtime<F: Future>(task: F) -> F::Output {
78 match tokio::runtime::Handle::try_current() {
80 Ok(rt) => {
81 tokio::task::block_in_place(|| rt.block_on(task))
83 }
84 Err(_) => {
85 tokio::task::block_in_place(|| TOKIO_RUNTIME.block_on(task))
87 }
88 }
89}
90
91pub(crate) fn spawn_runtime<F>(task: F) -> JoinHandle<F::Output>
92where
93 F: Future + Send + 'static,
94 F::Output: Send + 'static,
95{
96 match tokio::runtime::Handle::try_current() {
98 Ok(rt) => {
99 rt.spawn(task)
101 }
102 Err(_) => {
103 TOKIO_RUNTIME.spawn(task)
105 }
106 }
107}
108
/// JSON form of a zenoh sample, or of a reply error, as returned by the REST API.
#[derive(Serialize, Deserialize)]
struct JSONSample {
    /// Key expression of the sample (`"ERROR"` for error replies).
    key: String,
    /// Payload converted by `payload_to_json` according to its encoding.
    value: serde_json::Value,
    /// The encoding rendered with `to_string()`.
    encoding: String,
    /// The sample's timestamp rendered as a string, if any (always `None` for error replies).
    timestamp: Option<String>,
}
116
117pub fn base64_encode(data: &[u8]) -> String {
118 use base64::engine::general_purpose;
119 general_purpose::STANDARD.encode(data)
120}
121
122fn payload_to_json(payload: &ZBytes, encoding: &Encoding) -> serde_json::Value {
123 if payload.is_empty() {
124 return serde_json::Value::Null;
125 }
126 match encoding {
127 &Encoding::APPLICATION_JSON | &Encoding::TEXT_JSON | &Encoding::TEXT_JSON5 => {
129 let bytes = payload.to_bytes();
130 serde_json::from_slice(&bytes).unwrap_or_else(|e| {
131 tracing::warn!(
132 "Encoding is JSON but data is not JSON, converting to base64, Error: {e:?}"
133 );
134 serde_json::Value::String(base64_encode(&bytes))
135 })
136 }
137 &Encoding::TEXT_PLAIN | &Encoding::ZENOH_STRING => serde_json::Value::String(
138 String::from_utf8(payload.to_bytes().into_owned()).unwrap_or_else(|e| {
139 tracing::warn!(
140 "Encoding is String but data is not String, converting to base64, Error: {e:?}"
141 );
142 base64_encode(e.as_bytes())
143 }),
144 ),
145 _ => serde_json::Value::String(base64_encode(&payload.to_bytes())),
147 }
148}
149
150fn sample_to_json(sample: &Sample) -> JSONSample {
151 JSONSample {
152 key: sample.key_expr().as_str().to_string(),
153 value: payload_to_json(sample.payload(), sample.encoding()),
154 encoding: sample.encoding().to_string(),
155 timestamp: sample.timestamp().map(|ts| ts.to_string()),
156 }
157}
158
159fn result_to_json(sample: Result<&Sample, &ReplyError>) -> JSONSample {
160 match sample {
161 Ok(sample) => sample_to_json(sample),
162 Err(err) => JSONSample {
163 key: "ERROR".into(),
164 value: payload_to_json(err.payload(), err.encoding()),
165 encoding: err.encoding().to_string(),
166 timestamp: None,
167 },
168 }
169}
170
171async fn to_json(results: flume::Receiver<Reply>) -> String {
172 let values = results
173 .stream()
174 .filter_map(move |reply| async move { Some(result_to_json(reply.result())) })
175 .collect::<Vec<JSONSample>>()
176 .await;
177
178 serde_json::to_string(&values).unwrap_or("[]".into())
179}
180
181async fn to_json_response(results: flume::Receiver<Reply>) -> Response {
182 response(StatusCode::Ok, "application/json", &to_json(results).await)
183}
184
185fn sample_to_html(sample: &Sample) -> String {
186 format!(
187 "<dt>{}</dt>\n<dd>{}</dd>\n",
188 sample.key_expr().as_str(),
189 sample.payload().try_to_string().unwrap_or_default()
190 )
191}
192
193fn result_to_html(sample: Result<&Sample, &ReplyError>) -> String {
194 match sample {
195 Ok(sample) => sample_to_html(sample),
196 Err(err) => {
197 format!(
198 "<dt>ERROR</dt>\n<dd>{}</dd>\n",
199 err.payload().try_to_string().unwrap_or_default()
200 )
201 }
202 }
203}
204
205async fn to_html(results: flume::Receiver<Reply>) -> String {
206 let values = results
207 .stream()
208 .filter_map(move |reply| async move { Some(result_to_html(reply.result())) })
209 .collect::<Vec<String>>()
210 .await
211 .join("\n");
212 format!("<dl>\n{values}\n</dl>\n")
213}
214
215async fn to_html_response(results: flume::Receiver<Reply>) -> Response {
216 response(StatusCode::Ok, "text/html", &to_html(results).await)
217}
218
219async fn to_raw_response(results: flume::Receiver<Reply>) -> Response {
220 match results.recv_async().await {
221 Ok(reply) => match reply.result() {
222 Ok(sample) => response(
223 StatusCode::Ok,
224 Cow::from(sample.encoding()).as_ref(),
225 &sample.payload().to_bytes(),
226 ),
227 Err(value) => response(
228 StatusCode::Ok,
229 Cow::from(value.encoding()).as_ref(),
230 &value.payload().to_bytes(),
231 ),
232 },
233 Err(_) => response(StatusCode::Ok, "", ""),
234 }
235}
236
237fn method_to_kind(method: Method) -> SampleKind {
238 match method {
239 Method::Put => SampleKind::Put,
240 Method::Delete => SampleKind::Delete,
241 _ => SampleKind::default(),
242 }
243}
244
245fn response<'a, S: Into<&'a str> + std::fmt::Debug>(
246 status: StatusCode,
247 content_type: S,
248 body: &(impl AsRef<[u8]> + ?Sized),
249) -> Response {
250 let body = body.as_ref();
251 let mut content_type = content_type.into().to_string();
252 tracing::trace!(
253 "Outgoing Response: {status} - {content_type:?} - body: {body}",
254 body = std::str::from_utf8(body).unwrap_or_default(),
255 );
256 let mut builder = Response::builder(status)
257 .header("content-length", body.len().to_string())
258 .header("Access-Control-Allow-Origin", "*")
259 .body(body);
260 for chunk in content_type.split(";") {
261 if let Some((_, header)) = chunk.split_once("content-encoding=") {
262 builder = builder.header("content-encoding", header);
263 content_type = content_type.replace(&[";", chunk].concat(), "");
264 break;
265 }
266 }
267 if let Ok(mime) = Mime::from_str(&content_type) {
268 builder = builder.content_type(mime);
269 }
270 builder.build()
271}
272
// Declare the plugin entry point only when built with the `dynamic_plugin` feature
// (presumably for loading as a dynamic library).
#[cfg(feature = "dynamic_plugin")]
zenoh_plugin_trait::declare_plugin!(RestPlugin);

/// The zenoh REST plugin. It serves zenoh get/put/delete and SSE subscriptions over HTTP
/// (see `run`, `query` and `write`).
pub struct RestPlugin {}

// Empty impl: `RestPlugin` relies only on `ZenohPlugin`'s provided items.
impl ZenohPlugin for RestPlugin {}
279
280impl Plugin for RestPlugin {
281 type StartArgs = DynamicRuntime;
282 type Instance = zenoh::internal::plugins::RunningPlugin;
283 const DEFAULT_NAME: &'static str = "rest";
284 const PLUGIN_VERSION: &'static str = plugin_version!();
285 const PLUGIN_LONG_VERSION: &'static str = plugin_long_version!();
286
287 fn start(
288 name: &str,
289 runtime: &Self::StartArgs,
290 ) -> ZResult<zenoh::internal::plugins::RunningPlugin> {
291 zenoh::init_log_from_env_or("error");
295 tracing::debug!("REST plugin {}", LONG_VERSION.as_str());
296
297 let plugin_conf = runtime
298 .get_config()
299 .get_plugin_config(name)
300 .map_err(|_| zerror!("Plugin `{}`: missing config", name))?;
301
302 let conf: Config = serde_json::from_value(plugin_conf)
303 .map_err(|e| zerror!("Plugin `{}` configuration error: {}", name, e))?;
304 WORKER_THREAD_NUM.store(conf.work_thread_num, Ordering::SeqCst);
305 MAX_BLOCK_THREAD_NUM.store(conf.max_block_thread_num, Ordering::SeqCst);
306
307 let task = run(runtime.clone(), conf.clone());
308 let task =
309 blockon_runtime(async { timeout(Duration::from_millis(1), spawn_runtime(task)).await });
310
311 if let Ok(Ok(Err(e))) = task {
314 bail!("REST server failed within 1ms: {e}")
315 }
316
317 Ok(Box::new(RunningPlugin(conf)))
318 }
319}
320
/// Handle returned by `RestPlugin::start`. It keeps the plugin configuration so that the
/// admin space can report it.
struct RunningPlugin(Config);

// Empty impl: relies only on `PluginControl`'s provided items.
impl PluginControl for RunningPlugin {}
324
325impl RunningPluginTrait for RunningPlugin {
326 fn adminspace_getter<'a>(
327 &'a self,
328 key_expr: &'a KeyExpr<'a>,
329 plugin_status_key: &str,
330 ) -> ZResult<Vec<zenoh::internal::plugins::Response>> {
331 let mut responses = Vec::new();
332 let mut key = String::from(plugin_status_key);
333 with_extended_string(&mut key, &["/version"], |key| {
334 if keyexpr::new(key.as_str()).unwrap().intersects(key_expr) {
335 responses.push(zenoh::internal::plugins::Response::new(
336 key.clone(),
337 GIT_VERSION.into(),
338 ))
339 }
340 });
341 with_extended_string(&mut key, &["/port"], |port_key| {
342 if keyexpr::new(port_key.as_str())
343 .unwrap()
344 .intersects(key_expr)
345 {
346 responses.push(zenoh::internal::plugins::Response::new(
347 port_key.clone(),
348 (&self.0).into(),
349 ))
350 }
351 });
352 Ok(responses)
353 }
354}
355
/// Temporarily appends every string in `suffixes` to `prefix`, runs `closure` on the extended
/// string, then truncates `prefix` back to its original length and returns the closure's result.
fn with_extended_string<R, F: FnMut(&mut String) -> R>(
    prefix: &mut String,
    suffixes: &[&str],
    mut closure: F,
) -> R {
    let original_len = prefix.len();
    prefix.extend(suffixes.iter().copied());
    let output = closure(prefix);
    prefix.truncate(original_len);
    output
}
369
370async fn query(mut req: Request<(Arc<Session>, String)>) -> tide::Result<Response> {
371 tracing::trace!("Incoming GET request: {:?}", req);
372
373 let first_accept = match req.header("accept") {
374 Some(accept) => accept[0]
375 .to_string()
376 .split(';')
377 .next()
378 .unwrap()
379 .split(',')
380 .next()
381 .unwrap()
382 .to_string(),
383 None => "application/json".to_string(),
384 };
385 if first_accept == "text/event-stream" {
386 Ok(tide::sse::upgrade(
387 req,
388 move |req: Request<(Arc<Session>, String)>, sender: Sender| async move {
389 let key_expr = match path_to_key_expr(req.url().path(), &req.state().1) {
390 Ok(ke) => ke.into_owned(),
391 Err(e) => {
392 return Err(tide::Error::new(
393 tide::StatusCode::BadRequest,
394 anyhow::anyhow!("{}", e),
395 ))
396 }
397 };
398 spawn_runtime(async move {
399 tracing::debug!("Subscribe to {} for SSE stream", key_expr);
400 let sender = &sender;
401 let sub = req.state().0.declare_subscriber(&key_expr).await.unwrap();
402 loop {
403 let sample = sub.recv_async().await.unwrap();
404 let json_sample =
405 serde_json::to_string(&sample_to_json(&sample)).unwrap_or("{}".into());
406
407 match timeout(
408 std::time::Duration::new(10, 0),
409 sender.send(&sample.kind().to_string(), json_sample, None),
410 )
411 .await
412 {
413 Ok(Ok(_)) => {}
414 Ok(Err(e)) => {
415 tracing::debug!("SSE error ({})! Unsubscribe and terminate", e);
416 if let Err(e) = sub.undeclare().await {
417 tracing::error!("Error undeclaring subscriber: {}", e);
418 }
419 break;
420 }
421 Err(_) => {
422 tracing::debug!("SSE timeout! Unsubscribe and terminate",);
423 if let Err(e) = sub.undeclare().await {
424 tracing::error!("Error undeclaring subscriber: {}", e);
425 }
426 break;
427 }
428 }
429 }
430 });
431 Ok(())
432 },
433 ))
434 } else {
435 let body = req.body_bytes().await.unwrap_or_default();
436 let url = req.url();
437 let key_expr = match path_to_key_expr(url.path(), &req.state().1) {
438 Ok(ke) => ke,
439 Err(e) => {
440 return Ok(response(
441 StatusCode::BadRequest,
442 "text/plain",
443 &e.to_string(),
444 ))
445 }
446 };
447 let query_part = url.query();
448 let parameters = Parameters::from(query_part.unwrap_or_default());
449 let consolidation = if parameters.time_range().is_some() {
450 QueryConsolidation::from(zenoh::query::ConsolidationMode::None)
451 } else {
452 QueryConsolidation::from(zenoh::query::ConsolidationMode::Latest)
453 };
454 let raw = parameters.contains_key(RAW_KEY);
455 let mut query = req
456 .state()
457 .0
458 .get(Selector::borrowed(&key_expr, ¶meters))
459 .consolidation(consolidation)
460 .with(flume::unbounded());
461 if !body.is_empty() {
462 let encoding: Encoding = req
463 .content_type()
464 .map(|m| Encoding::from(m.to_string()))
465 .unwrap_or_default();
466 query = query.payload(body).encoding(encoding);
467 }
468 match query.await {
469 Ok(receiver) => {
470 if raw {
471 Ok(to_raw_response(receiver).await)
472 } else if first_accept == "text/html" {
473 Ok(to_html_response(receiver).await)
474 } else {
475 Ok(to_json_response(receiver).await)
476 }
477 }
478 Err(e) => Ok(response(
479 StatusCode::InternalServerError,
480 "text/plain",
481 &e.to_string(),
482 )),
483 }
484 }
485}
486
487async fn write(mut req: Request<(Arc<Session>, String)>) -> tide::Result<Response> {
488 tracing::trace!("Incoming PUT request: {:?}", req);
489 match req.body_bytes().await {
490 Ok(bytes) => {
491 let key_expr = match path_to_key_expr(req.url().path(), &req.state().1) {
492 Ok(ke) => ke,
493 Err(e) => {
494 return Ok(response(
495 StatusCode::BadRequest,
496 "text/plain",
497 &e.to_string(),
498 ))
499 }
500 };
501
502 let encoding: Encoding = req
503 .content_type()
504 .map(|m| Encoding::from(m.to_string()))
505 .unwrap_or_default();
506
507 let session = &req.state().0;
509 let res = match method_to_kind(req.method()) {
510 SampleKind::Put => session.put(&key_expr, bytes).encoding(encoding).await,
511 SampleKind::Delete => session.delete(&key_expr).await,
512 };
513 match res {
514 Ok(_) => Ok(Response::new(StatusCode::Ok)),
515 Err(e) => Ok(response(
516 StatusCode::InternalServerError,
517 "text/plain",
518 &e.to_string(),
519 )),
520 }
521 }
522 Err(e) => Ok(response(
523 StatusCode::NoContent,
524 "text/plain",
525 &e.to_string(),
526 )),
527 }
528}
529
530pub async fn run(runtime: DynamicRuntime, conf: Config) -> ZResult<()> {
531 zenoh::init_log_from_env_or("error");
535
536 let zid = runtime.zid().to_string();
537 let session = zenoh::session::init(runtime).await.unwrap();
538
539 let mut app = Server::with_state((Arc::new(session), zid));
540 app.with(
541 tide::security::CorsMiddleware::new()
542 .allow_methods(
543 "GET, POST, PUT, PATCH, DELETE"
544 .parse::<http_types::headers::HeaderValue>()
545 .unwrap(),
546 )
547 .allow_origin(tide::security::Origin::from("*"))
548 .allow_credentials(false),
549 );
550
551 app.at("/")
552 .get(query)
553 .post(query)
554 .put(write)
555 .patch(write)
556 .delete(write);
557 app.at("*")
558 .get(query)
559 .post(query)
560 .put(write)
561 .patch(write)
562 .delete(write);
563
564 if let Err(e) = app.listen(conf.http_port).await {
565 tracing::error!("Unable to start http server for REST: {:?}", e);
566 return Err(e.into());
567 }
568 Ok(())
569}
570
571fn path_to_key_expr<'a>(path: &'a str, zid: &str) -> ZResult<KeyExpr<'a>> {
572 let path = path.strip_prefix('/').unwrap_or(path);
573 if path == "@/local" {
574 KeyExpr::try_from(format!("@/{zid}"))
575 } else if let Some(suffix) = path.strip_prefix("@/local/") {
576 KeyExpr::try_from(format!("@/{zid}/{suffix}"))
577 } else {
578 KeyExpr::try_from(path)
579 }
580}