1use std::{
2 collections::{hash_map::Entry, HashMap},
3 fmt::{self, Debug, Display},
4 future::Future,
5 mem,
6 num::NonZeroU32,
7 str::FromStr,
8 sync::Arc,
9};
10
11#[cfg(feature = "tls")]
12use rustls_pki_types::{CertificateDer, PrivateKeyDer, ServerName};
13
14use tokio::{net::ToSocketAddrs, sync::Mutex};
15
16use crate::{
17 capabilities::{Base, Capabilities},
18 message::{
19 rpc::{
20 self,
21 operation::{Builder, CloseSession},
22 IntoResult,
23 },
24 ClientHello, ClientMsg, ReadError, ServerHello, ServerMsg,
25 },
26 transport::Transport,
27 Error,
28};
29
30#[cfg(feature = "tls")]
31use crate::transport::Tls;
32
33#[cfg(feature = "ssh")]
34use crate::transport::{Password, Ssh};
35
36#[cfg(feature = "junos")]
37use crate::transport::JunosLocal;
38
/// Identifier of a NETCONF session, as reported in the server's `<hello>`.
///
/// Backed by [`NonZeroU32`], so a session-id of `0` cannot be represented.
// Allowed because the type name deliberately repeats its (presumably `session`) module name.
#[allow(clippy::module_name_repetitions)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct SessionId(NonZeroU32);
43
44impl SessionId {
45 pub(crate) fn new(n: u32) -> Result<Self, Error> {
46 NonZeroU32::new(n)
47 .ok_or(Error::InvalidSessionId { session_id: n })
48 .map(Self)
49 }
50}
51
52impl FromStr for SessionId {
53 type Err = ReadError;
54
55 fn from_str(s: &str) -> Result<Self, Self::Err> {
56 Ok(Self(s.parse().map_err(Self::Err::SessionIdParse)?))
57 }
58}
59
60impl Display for SessionId {
61 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62 Display::fmt(&self.0, f)
63 }
64}
65
/// A NETCONF client session running over the transport `T`.
///
/// The transport is split into a send half and a receive half. Each half sits behind its
/// own shared `tokio` mutex, so the reply futures returned by `Session::rpc` can read from
/// the transport independently of the session that issued them.
#[derive(Debug)]
pub struct Session<T: Transport> {
    // Write half of the split transport.
    transport_tx: Arc<Mutex<T::SendHandle>>,
    // Read half of the split transport; cloned into every in-flight reply future.
    transport_rx: Arc<Mutex<T::RecvHandle>>,
    // Parameters fixed by the `<hello>` exchange.
    context: Context,
    // Most recently allocated message-id; advanced once per `rpc` call.
    last_message_id: rpc::MessageId,
    // State of each request by message-id (awaiting reply / reply buffered / reply taken).
    // Shared with reply futures so whichever one reads a reply can file it for its owner.
    requests: Arc<Mutex<HashMap<rpc::MessageId, OutstandingRequest>>>,
}
81
/// Per-session parameters established by the NETCONF `<hello>` exchange.
#[derive(Debug)]
pub struct Context {
    // Session-id taken from the server's `<hello>`.
    session_id: SessionId,
    // Base protocol version chosen via `Capabilities::highest_common_version`.
    protocol_version: Base,
    // Capabilities advertised in the client's `<hello>`.
    client_capabilities: Capabilities,
    // Capabilities advertised in the server's `<hello>`.
    server_capabilities: Capabilities,
}
90
91impl Context {
92 const fn new(
93 session_id: SessionId,
94 protocol_version: Base,
95 client_capabilities: Capabilities,
96 server_capabilities: Capabilities,
97 ) -> Self {
98 Self {
99 session_id,
100 protocol_version,
101 client_capabilities,
102 server_capabilities,
103 }
104 }
105
106 #[must_use]
108 pub const fn session_id(&self) -> SessionId {
109 self.session_id
110 }
111
112 #[must_use]
114 pub const fn protocol_version(&self) -> Base {
115 self.protocol_version
116 }
117
118 #[must_use]
120 pub const fn client_capabilities(&self) -> &Capabilities {
121 &self.client_capabilities
122 }
123
124 #[must_use]
126 pub const fn server_capabilities(&self) -> &Capabilities {
127 &self.server_capabilities
128 }
129}
130
/// Lifecycle of one request, stored by message-id in the session's request map.
#[derive(Debug)]
enum OutstandingRequest {
    /// Sent; no reply for it has been read yet.
    Pending,
    /// A reply was read from the transport (possibly by another request's future)
    /// and is buffered until the owning future collects it.
    Ready(rpc::PartialReply),
    /// The reply has been handed to its owner; a further reply for this id is an error.
    // NOTE(review): completed entries are never removed from the map in this module,
    // so it grows by one entry per request over the life of a session.
    Complete,
}
137
138impl OutstandingRequest {
139 #[tracing::instrument(level = "trace")]
140 fn take(&mut self) -> Result<Option<rpc::PartialReply>, Error> {
141 match mem::replace(self, Self::Complete) {
142 mut pending @ Self::Pending => {
143 mem::swap(self, &mut pending);
144 Ok(None)
145 }
146 Self::Complete => Err(Error::RequestComplete),
147 Self::Ready(reply) => Ok(Some(reply)),
148 }
149 }
150}
151
#[cfg(feature = "ssh")]
impl Session<Ssh> {
    /// Opens a NETCONF session over SSH using password authentication.
    ///
    /// Connects the SSH transport to `addr`, then performs the `<hello>` exchange.
    ///
    /// # Errors
    ///
    /// Fails if the SSH transport cannot connect, or if session setup
    /// (`<hello>` exchange / version negotiation) fails.
    // NOTE(review): the span records `password` via its `Debug` impl; confirm that impl
    // redacts the secret, or add it to `skip(...)`.
    #[tracing::instrument(level = "debug")]
    pub async fn ssh<A>(addr: A, username: String, password: Password) -> Result<Self, Error>
    where
        A: ToSocketAddrs + Send + Debug,
    {
        tracing::info!("starting ssh transport");
        let ssh_transport = Ssh::connect(addr, username, password).await?;
        Self::new(ssh_transport).await
    }
}
165
#[cfg(feature = "tls")]
impl Session<Tls> {
    /// Opens a NETCONF session over TLS with client-certificate authentication.
    ///
    /// `server_name` is converted to a rustls [`ServerName`].
    /// `ca_cert`, `client_cert` and `client_key` are handed to the TLS transport;
    /// presumably `ca_cert` is the trust anchor for the server certificate (TODO confirm).
    /// The certificate and key arguments are excluded from the tracing span.
    ///
    /// # Errors
    ///
    /// Fails if `server_name` cannot be converted, if the TLS transport cannot connect,
    /// or if session setup (`<hello>` exchange / version negotiation) fails.
    #[tracing::instrument(skip(ca_cert, client_cert, client_key), level = "debug")]
    pub async fn tls<A, S>(
        addr: A,
        server_name: S,
        ca_cert: CertificateDer<'_>,
        client_cert: CertificateDer<'static>,
        client_key: PrivateKeyDer<'static>,
    ) -> Result<Self, Error>
    where
        A: ToSocketAddrs + Debug + Send,
        S: TryInto<ServerName<'static>> + Debug + Send,
        Error: From<S::Error>,
    {
        tracing::info!("starting tls transport");
        let tls_transport =
            Tls::connect(addr, server_name, ca_cert, client_cert, client_key).await?;
        Self::new(tls_transport).await
    }
}
187
#[cfg(feature = "junos")]
impl Session<JunosLocal> {
    /// Opens a NETCONF session over the local Junos transport.
    ///
    /// # Errors
    ///
    /// Fails if the local transport cannot connect, or if session setup
    /// (`<hello>` exchange / version negotiation) fails.
    #[tracing::instrument(level = "debug")]
    pub async fn junos_local() -> Result<Self, Error> {
        tracing::info!("starting local junos transport");
        let local_transport = JunosLocal::connect().await?;
        Self::new(local_transport).await
    }
}
198
199impl<T: Transport> Session<T> {
200 #[tracing::instrument(skip(transport), level = "trace")]
201 async fn new(transport: T) -> Result<Self, Error> {
202 let client_hello = ClientHello::default();
203 let (mut tx, mut rx) = transport.split();
204 let ((), server_hello) =
205 tokio::try_join!(client_hello.send(&mut tx), ServerHello::recv(&mut rx))?;
206 let transport_tx = Arc::new(Mutex::new(tx));
207 let transport_rx = Arc::new(Mutex::new(rx));
208 let session_id = server_hello.session_id();
209 let server_capabilities = server_hello.capabilities();
210 let client_capabilities = client_hello.capabilities();
211 let protocol_version = client_capabilities.highest_common_version(&server_capabilities)?;
212 let context = Context::new(
213 session_id,
214 protocol_version,
215 client_capabilities,
216 server_capabilities,
217 );
218 let requests = Arc::new(Mutex::new(HashMap::default()));
219 Ok(Self {
220 transport_tx,
221 transport_rx,
222 context,
223 requests,
224 last_message_id: rpc::MessageId::default(),
225 })
226 }
227
228 #[must_use]
230 pub const fn context(&self) -> &Context {
231 &self.context
232 }
233
234 #[tracing::instrument(skip(self, build_fn), level = "debug")]
259 pub async fn rpc<O, F>(
260 &mut self,
261 build_fn: F,
262 ) -> Result<impl Future<Output = Result<<O::Reply as IntoResult>::Ok, Error>>, Error>
263 where
264 O: rpc::Operation,
265 F: FnOnce(O::Builder<'_>) -> Result<O, Error> + Send,
266 {
267 let message_id = self.last_message_id.increment();
268 let request = O::new(&self.context, build_fn)
269 .map(|operation| rpc::Request::new(message_id, operation))?;
270 #[allow(clippy::significant_drop_in_scrutinee)]
271 match self.requests.lock().await.entry(message_id) {
272 Entry::Occupied(_) => return Err(Error::MessageIdCollision { message_id }),
273 Entry::Vacant(entry) => {
274 request.send(&mut *self.transport_tx.lock().await).await?;
275 _ = entry.insert(OutstandingRequest::Pending);
276 }
277 };
278 let requests = self.requests.clone();
279 let rx = self.transport_rx.clone();
280 Ok(Self::recv::<O>(message_id, requests, rx))
281 }
282
283 #[tracing::instrument(skip(requests, rx), level = "debug")]
284 async fn recv<O>(
285 message_id: rpc::MessageId,
286 requests: Arc<Mutex<HashMap<rpc::MessageId, OutstandingRequest>>>,
287 rx: Arc<Mutex<<T as Transport>::RecvHandle>>,
288 ) -> Result<<O::Reply as IntoResult>::Ok, Error>
289 where
290 O: rpc::Operation,
291 {
292 loop {
296 let mut rx_guard = rx.lock().await;
297 tracing::trace!(?requests);
298 tracing::debug!("checking for ready response");
299 if let Some(partial) = requests
300 .lock()
301 .await
302 .get_mut(&message_id)
303 .ok_or(Error::RequestNotFound { message_id })?
304 .take()?
305 {
306 tracing::debug!("found ready response");
307 let reply: rpc::Reply<O> = partial.try_into()?;
308 break reply.into_result();
309 };
310 tracing::debug!("response to {message_id:?} not yet ready");
311 let reply = rpc::PartialReply::recv(&mut *rx_guard).await?;
312 #[allow(clippy::significant_drop_in_scrutinee)]
313 match requests
314 .lock()
315 .await
316 .get_mut(&reply.message_id())
317 .ok_or_else(|| Error::RequestNotFound {
318 message_id: reply.message_id(),
319 })? {
320 OutstandingRequest::Complete => break Err(Error::RequestComplete),
321 OutstandingRequest::Ready(_) => {
322 break Err(Error::MessageIdCollision {
323 message_id: reply.message_id(),
324 })
325 }
326 pending @ OutstandingRequest::Pending => {
327 tracing::debug!("storing response to {:?}", reply.message_id());
328 _ = mem::replace(pending, OutstandingRequest::Ready(reply));
329 }
330 };
331 drop(rx_guard);
332 }
333 }
334
335 #[tracing::instrument(skip(self), level = "debug")]
337 pub async fn close(mut self) -> Result<impl Future<Output = Result<(), Error>>, Error> {
338 self.rpc::<CloseSession, _>(Builder::finish)
339 .await
340 .map(|fut| async move { fut.await.map(|()| drop(self)) })
341 }
342}