1use super::capabilities::Capabilities;
2#[cfg(feature = "dangerous_configuration")]
3use super::dangerous::DangerousClientBuilder;
4use crate::auth::AuthInterceptor;
5use crate::error::GinmiError;
6use crate::gen::gnmi::g_nmi_client::GNmiClient;
7use crate::gen::gnmi::CapabilityRequest;
8use hyper::body::Bytes;
9use std::str::FromStr;
10use tonic::codegen::{Body, InterceptedService, StdError};
11use tonic::metadata::AsciiMetadataValue;
12use tonic::transport::{Certificate, Channel, ClientTlsConfig, Uri};
13
14#[derive(Debug, Clone)]
17pub struct Client<T> {
18 pub(crate) inner: GNmiClient<T>,
19}
20
21impl<'a> Client<InterceptedService<Channel, AuthInterceptor>> {
22 pub fn builder(target: &'a str) -> ClientBuilder<'a> {
24 ClientBuilder::new(target)
25 }
26}
27
28impl<T> Client<T>
29where
30 T: tonic::client::GrpcService<tonic::body::BoxBody>,
31 T::Error: Into<StdError>,
32 T::ResponseBody: Body<Data = Bytes> + Send + 'static,
33 <T::ResponseBody as Body>::Error: Into<StdError> + Send,
34{
35 pub async fn capabilities(&mut self) -> Result<Capabilities, GinmiError> {
54 let req = CapabilityRequest::default();
55 let res = self.inner.capabilities(req).await?;
56 Ok(Capabilities(res.into_inner()))
57 }
58}
59
/// Username/password pair set through `ClientBuilder::credentials`.
///
/// `Debug` is implemented by hand so the password is never written to logs or
/// debug output (a derived `Debug` would print it in plain text, and it would
/// also leak through any type that embeds `Credentials` and derives `Debug`).
#[derive(Copy, Clone)]
pub struct Credentials<'a> {
    pub(crate) username: &'a str,
    pub(crate) password: &'a str,
}

impl std::fmt::Debug for Credentials<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Credentials")
            .field("username", &self.username)
            // Deliberately redacted: never expose the secret in debug output.
            .field("password", &"<redacted>")
            .finish()
    }
}
65
66#[derive(Debug, Clone)]
70pub struct ClientBuilder<'a> {
71 pub(crate) target: &'a str,
72 pub(crate) creds: Option<Credentials<'a>>,
73 tls_settings: Option<ClientTlsConfig>,
74}
75
76impl<'a> ClientBuilder<'a> {
77 pub fn new(target: &'a str) -> Self {
78 Self {
79 target,
80 creds: None,
81 tls_settings: None,
82 }
83 }
84
85 pub fn credentials(mut self, username: &'a str, password: &'a str) -> Self {
87 self.creds = Some(Credentials { username, password });
88 self
89 }
90
91 pub fn tls(mut self, ca_certificate: impl AsRef<[u8]>, domain_name: impl Into<String>) -> Self {
93 let cert = Certificate::from_pem(ca_certificate);
94 let settings = ClientTlsConfig::new()
95 .ca_certificate(cert)
96 .domain_name(domain_name);
97 self.tls_settings = Some(settings);
98 self
99 }
100
101 #[cfg(feature = "dangerous_configuration")]
102 #[cfg_attr(docsrs, doc(cfg(feature = "dangerous_configuration")))]
103 pub fn dangerous(self) -> DangerousClientBuilder<'a> {
105 DangerousClientBuilder::from(self)
106 }
107
108 pub async fn build(
116 self,
117 ) -> Result<Client<InterceptedService<Channel, AuthInterceptor>>, GinmiError> {
118 let uri = match Uri::from_str(self.target) {
119 Ok(u) => u,
120 Err(e) => return Err(GinmiError::InvalidUriError(e.to_string())),
121 };
122
123 let mut endpoint = Channel::builder(uri);
124
125 if self.tls_settings.is_some() {
126 endpoint = endpoint.tls_config(self.tls_settings.unwrap())?;
127 }
128
129 let channel = endpoint.connect().await?;
130 let (username, password) = match self.creds {
131 Some(c) => (
132 Some(AsciiMetadataValue::from_str(c.username)?),
133 Some(AsciiMetadataValue::from_str(c.password)?),
134 ),
135 None => (None, None),
136 };
137
138 Ok(Client {
139 inner: GNmiClient::with_interceptor(channel, AuthInterceptor::new(username, password)),
140 })
141 }
142}
143
#[cfg(test)]
mod tests {
    use super::*;

    /// A target that cannot be used as a gRPC endpoint must make `build` fail.
    #[tokio::test]
    async fn invalid_uri() {
        let outcome = Client::builder("$$$$").build().await;
        assert!(outcome.is_err());
    }

    /// Garbage TLS material (and an unresolvable host) must make `build` fail.
    #[tokio::test]
    async fn invalid_tls_settings() {
        let builder = Client::builder("https://test:57400").tls("invalid cert", "invalid domain");
        let outcome = builder.build().await;
        assert!(outcome.is_err());
    }
}