1use thiserror::Error;
88
89#[cfg(feature = "async")]
90pub mod async_api;
91mod client;
92mod protocol;
93
/// Default EOS system name (`"ar"`); it is the initial sysname of the
/// Unix-socket client builder.
static SYSNAME: &str = "ar";
95
96#[cfg(feature = "blocking")]
97pub use client::{HttpClient, UdsClient};
98
99#[derive(Debug, Error)]
101pub enum Error {
102 #[error("protocol error: {0}")]
104 ProtocolError(#[from] protocol::Error),
105 #[error("communication error: {0}")]
107 ClientError(#[from] client::Error),
108}
109
/// Output format requested for command results.
///
/// Derives the common value-type traits so callers can print, compare and
/// freely copy a format value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ResultFormat {
    /// Structured JSON output (the default); sent as `"format": "json"`.
    #[default]
    Json,
    /// Plain-text CLI output.
    Text,
}
119
/// Outcome of a `runCmds` request.
///
/// `Clone` is derived in addition to comparison so callers can keep a copy of
/// a response while handing the original off elsewhere.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Response {
    /// The endpoint returned a JSON-RPC error object.
    Error {
        /// Error message reported by the endpoint.
        message: String,
        /// JSON-RPC error code.
        code: i64,
        /// Entries of the error's `data` array, each kept as serialized JSON
        /// text (e.g. `"\"a\""`).
        errors: Vec<String>,
    },
    /// Per-command results, each kept as serialized JSON text (for
    /// `ResultFormat::Json`; the Text shape is not shown here — confirm in
    /// `protocol`).
    Result(Vec<String>),
}
136
137pub trait Runner {
139 fn run<S: AsRef<str>>(self, cmds: &[S], format: ResultFormat) -> Result<Response, Error>;
148}
149
150impl<T: client::Requester> Runner for T {
151 fn run<S: AsRef<str>>(self, cmds: &[S], format: ResultFormat) -> Result<Response, Error> {
152 let request = protocol::make_run_request(cmds, format);
153 let response = self.do_request(request)?;
154 protocol::parse_response(&response).map_err(|e| e.into())
155 }
156}
157
/// Builder state for a Unix-domain-socket client.
#[doc(hidden)]
pub struct UdsClientBuilder {
    /// EOS system name handed to the client on connect.
    sysname: String,
    /// Explicit socket path; `None` means "derive it".
    socket_name: Option<String>,
}
163
164impl ClientBuilder<UdsClientBuilder> {
165 pub fn set_sysname(mut self, sysname: String) -> Self {
167 self.0.sysname = sysname;
168 self
169 }
170
171 pub fn set_socket_name(mut self, socket_name: String) -> Self {
173 self.0.socket_name = Some(socket_name);
174 self
175 }
176
177 #[cfg(feature = "blocking")]
178 pub fn build_blocking(self) -> Result<UdsClient, Error> {
180 let socket_name = self
181 .0
182 .socket_name
183 .unwrap_or_else(|| protocol::make_socket_name(SYSNAME));
184 UdsClient::connect(self.0.sysname, socket_name).map_err(|e| e.into())
185 }
186}
187
/// Type-state marker: the HTTP builder targets plain HTTP.
#[doc(hidden)]
pub struct UseHttp(());

/// Type-state marker: the HTTP builder targets HTTPS.
#[doc(hidden)]
pub struct UseHttps {
    /// Value forwarded as the `insecure` flag when the HTTPS client is built.
    insecure: bool,
}

/// Builder state shared by HTTP and HTTPS clients; `T` is [`UseHttp`] or
/// [`UseHttps`].
#[doc(hidden)]
pub struct HttpClientBuilder<T> {
    /// Target host, optionally with a `:port` suffix.
    hostname: String,
    /// Optional `(username, password)` credentials.
    auth: Option<(String, String)>,
    /// Timeout handed to the HTTP client.
    timeout: std::time::Duration,
    /// Scheme-specific state.
    https: T,
}
201
202impl<T> ClientBuilder<HttpClientBuilder<T>> {
203 pub fn set_authentication(mut self, username: String, password: String) -> Self {
205 self.0.auth = Some((username, password));
206 self
207 }
208
209 pub fn set_timeout(mut self, timeout: std::time::Duration) -> Self {
212 self.0.timeout = timeout;
213 self
214 }
215
216 pub fn enable_https(self) -> ClientBuilder<HttpClientBuilder<UseHttps>> {
218 ClientBuilder(HttpClientBuilder {
219 hostname: self.0.hostname,
220 auth: self.0.auth,
221 timeout: self.0.timeout,
222 https: UseHttps { insecure: false },
223 })
224 }
225}
226
227impl ClientBuilder<HttpClientBuilder<UseHttp>> {
228 #[cfg(feature = "blocking")]
229 pub fn build_blocking(self) -> HttpClient {
231 HttpClient::new_http(self.0.hostname, self.0.auth, self.0.timeout)
232 }
233}
234
235impl ClientBuilder<HttpClientBuilder<UseHttps>> {
236 pub fn set_insecure(mut self, value: bool) -> Self {
239 self.0.https.insecure = value;
240 self
241 }
242
243 #[cfg(feature = "blocking")]
244 pub fn build_blocking(self) -> HttpClient {
246 HttpClient::new_https(
247 self.0.hostname,
248 self.0.auth,
249 self.0.timeout,
250 self.0.https.insecure,
251 )
252 }
253}
254
/// Type-state builder for eAPI clients.
///
/// Start from `ClientBuilder::unix_socket()` or `ClientBuilder::http(..)`.
/// The type parameter tracks the transport being configured (and, for HTTP,
/// whether HTTPS is enabled), so only the relevant setters are available.
///
/// Every setter takes the builder by value and returns the updated one, so
/// ignoring a returned builder silently drops configuration. `#[must_use]`
/// turns that mistake into a compiler warning.
#[must_use = "builder methods return the updated builder; ignoring it discards the configuration"]
pub struct ClientBuilder<T>(T);
257
258impl ClientBuilder<()> {
259 pub fn unix_socket() -> ClientBuilder<UdsClientBuilder> {
261 ClientBuilder(UdsClientBuilder {
262 sysname: SYSNAME.to_string(),
263 socket_name: None,
264 })
265 }
266
267 pub fn http(hostname: String) -> ClientBuilder<HttpClientBuilder<UseHttp>> {
269 ClientBuilder(HttpClientBuilder {
270 hostname,
271 auth: None,
272 timeout: std::time::Duration::from_secs(30),
273 https: UseHttp(()),
274 })
275 }
276}
277
/// Runs `commands` against the local Cli agent over a Unix domain socket.
///
/// `sysname` overrides the default `"ar"` when given.
///
/// # Errors
///
/// Returns an [`Error`] if connecting, sending or parsing fails.
#[deprecated(since = "0.2.0", note = "please use the `ClientBuilder`")]
#[cfg(feature = "blocking")]
pub fn eapi_run<T: AsRef<str>>(
    sysname: Option<&str>,
    commands: &[T],
    format: ResultFormat,
) -> Result<Response, Error> {
    let builder = match sysname {
        Some(name) => ClientBuilder::unix_socket().set_sysname(name.to_owned()),
        None => ClientBuilder::unix_socket(),
    };
    let client = builder.build_blocking()?;
    client.run(commands, format)
}
300
#[cfg(test)]
mod tests {
    use super::*;
    use nix::sys;
    use std::convert::Infallible;
    use std::io::IoSliceMut;
    use std::os::unix::io::RawFd;

    /// Reads one length-prefixed string from `socket`: a little-endian `i32`
    /// length followed by that many bytes.
    ///
    /// Fails if the peer closes the connection early or sends a negative
    /// length, instead of silently returning a zero-filled buffer.
    fn rcv_string(socket: RawFd) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
        let mut len = [0; 4];
        let got = sys::socket::recv(socket, &mut len, sys::socket::MsgFlags::MSG_WAITALL)?;
        if got != len.len() {
            return Err("connection closed while reading string length".into());
        }
        let len = i32::from_le_bytes(len);
        // A negative length would wrap to a huge usize and abort the test on
        // allocation instead of reporting a protocol error.
        if len < 0 {
            return Err("received negative string length".into());
        }
        let len = len as usize;

        let mut buf = vec![0; len];
        let got = sys::socket::recv(socket, &mut buf, sys::socket::MsgFlags::MSG_WAITALL)?;
        if got != len {
            return Err("connection closed while reading string body".into());
        }

        Ok(buf)
    }

    /// Receives at least `N` file descriptors passed via `SCM_RIGHTS`. They
    /// may arrive spread across several messages.
    fn rcv_fd<const N: usize>(
        socket: RawFd,
    ) -> Result<Vec<RawFd>, Box<dyn std::error::Error + Send + Sync>> {
        let mut buf = [0];
        let mut iov = [IoSliceMut::new(&mut buf)];
        let mut cmsg_buf: Vec<u8> = nix::cmsg_space!([RawFd; N]);

        let mut result = Vec::with_capacity(N);
        loop {
            let rcv = sys::socket::recvmsg::<()>(
                socket,
                &mut iov,
                Some(&mut cmsg_buf),
                sys::socket::MsgFlags::empty(),
            )?;

            for cmsg in rcv.cmsgs() {
                if let sys::socket::ControlMessageOwned::ScmRights(fds) = cmsg {
                    result.extend(fds)
                } else {
                    return Err("didn't receive SCM_RIGHTS message".into());
                }
            }
            // `>=` rather than `==`: receiving extra fds must not make the
            // loop wait forever.
            if result.len() >= N {
                break Ok(result);
            }
            // Zero bytes on a stream socket means the peer hung up. Without
            // this check the loop would spin forever.
            if rcv.bytes == 0 {
                return Err("connection closed before all fds were received".into());
            }
        }
    }

    /// Creates a Unix stream socket bound to `path` and listening for a
    /// single connection.
    fn listen_uds(path: &str) -> Result<RawFd, Box<dyn std::error::Error + Send + Sync>> {
        let socket = sys::socket::socket(
            sys::socket::AddressFamily::Unix,
            sys::socket::SockType::Stream,
            sys::socket::SockFlag::empty(),
            None,
        )?;
        sys::socket::bind(socket, &sys::socket::UnixAddr::new(path)?)?;
        sys::socket::listen(socket, 1)?;
        Ok(socket)
    }

    /// Spawns a fake Cli agent on `socket_name`.
    ///
    /// The agent validates the client's handshake for `sysname`, replies with
    /// `response`, and returns the raw request it received. The caller must
    /// `wait()` on the returned barrier before connecting.
    pub fn run_uds_server<T: AsRef<str>>(
        socket_name: T,
        sysname: &str,
        response: &str,
    ) -> (
        std::sync::Arc<std::sync::Barrier>,
        std::thread::JoinHandle<Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>>>,
    ) {
        let socket_name = socket_name.as_ref().to_string();
        let sysname = sysname.to_string();
        let response = response.to_string();

        let barrier = std::sync::Arc::new(std::sync::Barrier::new(2));
        let ready = barrier.clone();
        let handle = std::thread::spawn(move || {
            let listener = listen_uds(&socket_name);
            // Release the test thread even if setup failed. Otherwise it would
            // block in `ready.wait()` forever instead of surfacing the error.
            ready.wait();
            let socket = listener?;

            let stream = sys::socket::accept(socket)?;

            let signal = rcv_fd::<1>(stream)?[0];

            if rcv_string(stream)? != protocol::make_args(sysname) {
                return Err("received invalid args".into());
            }

            if rcv_string(stream)? != protocol::make_env()? {
                return Err("received invalid env".into());
            }

            if rcv_string(stream)? != "0".as_bytes() {
                return Err("received invalid uid".into());
            }

            if rcv_string(stream)? != "0".as_bytes() {
                return Err("received invalid gid".into());
            }

            if rcv_string(stream)? != "".as_bytes() {
                return Err("received invalid terminal name".into());
            }

            let mut buf = [0];
            sys::socket::recv(stream, &mut buf, sys::socket::MsgFlags::empty())?;
            if buf[0] != b'c' {
                return Err("received invalid mode".into());
            }

            let sockets = rcv_fd::<3>(stream)?;
            let resp_socket = sockets[0];
            let req_socket = sockets[1];
            let stats = sockets[2];

            nix::unistd::close(stream)?;

            let request = rcv_string(req_socket)?;
            sys::socket::send(
                resp_socket,
                response.as_bytes(),
                sys::socket::MsgFlags::empty(),
            )?;

            nix::unistd::close(signal)?;
            nix::unistd::close(req_socket)?;
            nix::unistd::close(resp_socket)?;
            nix::unistd::close(stats)?;

            nix::unistd::close(socket)?;

            Ok(request)
        });

        (barrier, handle)
    }

    /// Starts a local HTTP server that answers every request with `response`.
    ///
    /// Returns the bound port, a shutdown trigger, and a channel yielding
    /// each request's `(Authorization header, body)`.
    fn run_http_server(
        response: &str,
    ) -> (
        u16,
        tokio::sync::oneshot::Sender<()>,
        tokio::sync::mpsc::Receiver<(Vec<u8>, Vec<u8>)>,
    ) {
        let rt = tokio::runtime::Runtime::new().unwrap();
        let (sender, receiver) = tokio::sync::mpsc::channel(1);
        let (tx_shut, rx_shut) = tokio::sync::oneshot::channel::<()>();

        let addr = ([127, 0, 0, 1], 0).into();
        let incoming = {
            let _guard = rt.enter();
            hyper::server::conn::AddrIncoming::bind(&addr).unwrap()
        };
        let port = incoming.local_addr().port();
        let response = response.to_string();

        std::thread::spawn(move || {
            rt.block_on(async move {
                let make_service = hyper::service::make_service_fn(move |_conn| {
                    let response = response.clone();
                    let sender = sender.clone();
                    async move {
                        Ok::<_, Infallible>(hyper::service::service_fn(
                            move |req: hyper::Request<hyper::Body>| {
                                let auth = req
                                    .headers()
                                    .get("Authorization")
                                    .unwrap_or(&hyper::header::HeaderValue::from_static(""))
                                    .as_bytes()
                                    .to_vec();
                                let response = response.clone();
                                let sender = sender.clone();
                                async move {
                                    let body =
                                        hyper::body::to_bytes(req.into_body()).await?.to_vec();
                                    sender.send((auth, body)).await.unwrap();
                                    Ok::<_, hyper::Error>(hyper::Response::new(hyper::Body::from(
                                        response,
                                    )))
                                }
                            },
                        ))
                    }
                });

                let server = hyper::server::Server::builder(incoming)
                    .serve(make_service)
                    .with_graceful_shutdown(async {
                        rx_shut.await.ok();
                    });
                server.await.unwrap();
            })
        });

        (port, tx_shut, receiver)
    }

    #[test]
    fn test_uds_ok() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let tmp_dir = tempfile::tempdir()?;
        let socket_name = tmp_dir
            .path()
            .join(SYSNAME)
            .to_str()
            .ok_or("can't convert path to string")?
            .to_string();

        let response = r#"{
            "jsonrpc": "2.0",
            "result": ["test1", "test2", {"a": "b"}],
            "id": "1"
        }"#;
        let (ready, handle) = run_uds_server(&socket_name, SYSNAME, response);
        ready.wait();
        let result = ClientBuilder::unix_socket()
            .set_sysname(SYSNAME.to_owned())
            .set_socket_name(socket_name)
            .build_blocking()?
            .run(&["show run", "show int", "show clock"], ResultFormat::Json)?;
        let request = match handle.join() {
            Ok(r) => r?,
            Err(e) => std::panic::resume_unwind(e),
        };
        let expected = serde_json::json!({
            "jsonrpc": "2.0",
            "method": "runCmds",
            "params": {
                "version": "latest",
                "cmds": ["show run", "show int", "show clock"],
                "format": "json",
            },
            "id": "1"
        })
        .to_string();
        assert_eq!(request, expected.as_bytes());
        assert_eq!(
            result,
            Response::Result(vec![
                "\"test1\"".to_string(),
                "\"test2\"".to_string(),
                "{\"a\":\"b\"}".to_string()
            ])
        );

        Ok(())
    }

    #[test]
    fn test_uds_error() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let tmp_dir = tempfile::tempdir()?;
        let socket_name = tmp_dir
            .path()
            .join(SYSNAME)
            .to_str()
            .ok_or("can't convert path to string")?
            .to_string();

        let response = r#"{
            "jsonrpc": "2.0",
            "error": {
                "message": "error message",
                "code": 3,
                "data": ["a", "b"]
            },
            "id": "1"
        }"#;
        let (ready, handle) = run_uds_server(&socket_name, SYSNAME, response);
        ready.wait();
        let result = ClientBuilder::unix_socket()
            .set_sysname(SYSNAME.to_owned())
            .set_socket_name(socket_name)
            .build_blocking()?
            .run(&["show run", "show int", "show clock"], ResultFormat::Json)?;
        let request = match handle.join() {
            Ok(r) => r?,
            Err(e) => std::panic::resume_unwind(e),
        };
        let expected = serde_json::json!({
            "jsonrpc": "2.0",
            "method": "runCmds",
            "params": {
                "version": "latest",
                "cmds": ["show run", "show int", "show clock"],
                "format": "json",
            },
            "id": "1"
        })
        .to_string();
        assert_eq!(request, expected.as_bytes());
        assert_eq!(
            result,
            Response::Error {
                message: "error message".to_string(),
                code: 3,
                errors: vec!["\"a\"".to_string(), "\"b\"".to_string()]
            }
        );

        Ok(())
    }

    #[test]
    fn test_http_ok() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let response = r#"{
            "jsonrpc": "2.0",
            "result": ["test1", "test2", {"a": "b"}],
            "id": "1"
        }"#;
        let (port, shutdown, mut receiver) = run_http_server(response);
        let result = ClientBuilder::http("localhost:".to_owned() + &port.to_string())
            .set_authentication("admin".to_owned(), "pass".to_owned())
            .build_blocking()
            .run(&["show run", "show int", "show clock"], ResultFormat::Json)?;
        let request = receiver.blocking_recv().unwrap();
        let expected = serde_json::json!({
            "jsonrpc": "2.0",
            "method": "runCmds",
            "params": {
                "version": "latest",
                "cmds": ["show run", "show int", "show clock"],
                "format": "json",
            },
            "id": "1"
        })
        .to_string();
        assert_eq!(request.0, "Basic YWRtaW46cGFzcw==".as_bytes());
        assert_eq!(request.1, expected.as_bytes());
        assert_eq!(
            result,
            Response::Result(vec![
                "\"test1\"".to_string(),
                "\"test2\"".to_string(),
                "{\"a\":\"b\"}".to_string()
            ])
        );

        let _ = shutdown.send(());

        Ok(())
    }

    #[test]
    fn test_http_error() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        let response = r#"{
            "jsonrpc": "2.0",
            "error": {
                "message": "error message",
                "code": 3,
                "data": ["a", "b"]
            },
            "id": "1"
        }"#;
        let (port, shutdown, mut receiver) = run_http_server(response);
        let result = ClientBuilder::http("localhost:".to_owned() + &port.to_string())
            .set_authentication("admin".to_owned(), "pass".to_owned())
            .build_blocking()
            .run(&["show run", "show int", "show clock"], ResultFormat::Json)?;
        let request = receiver.blocking_recv().unwrap();
        let expected = serde_json::json!({
            "jsonrpc": "2.0",
            "method": "runCmds",
            "params": {
                "version": "latest",
                "cmds": ["show run", "show int", "show clock"],
                "format": "json",
            },
            "id": "1"
        })
        .to_string();
        assert_eq!(request.0, "Basic YWRtaW46cGFzcw==".as_bytes());
        assert_eq!(request.1, expected.as_bytes());
        assert_eq!(
            result,
            Response::Error {
                message: "error message".to_string(),
                code: 3,
                errors: vec!["\"a\"".to_string(), "\"b\"".to_string()]
            }
        );

        let _ = shutdown.send(());

        Ok(())
    }
}