1use std::sync::Arc;
2
3use tokio::io::{AsyncRead, AsyncWrite};
4use tokio::sync::mpsc::{Receiver, Sender};
5use tokio::sync::{Mutex, Notify};
6
7use crate::{ChannelId, ChannelOpenFailure, CryptoVec, Error, Pty, Sig};
8
9pub mod io;
10
11mod channel_ref;
12pub use channel_ref::ChannelRef;
13
14mod channel_stream;
15pub use channel_stream::ChannelStream;
16
17#[derive(Debug)]
18#[non_exhaustive]
19pub enum ChannelMsg {
21 Open {
22 id: ChannelId,
23 max_packet_size: u32,
24 window_size: u32,
25 },
26 Data {
27 data: CryptoVec,
28 },
29 ExtendedData {
30 data: CryptoVec,
31 ext: u32,
32 },
33 Eof,
34 Close,
35 RequestPty {
37 want_reply: bool,
38 term: String,
39 col_width: u32,
40 row_height: u32,
41 pix_width: u32,
42 pix_height: u32,
43 terminal_modes: Vec<(Pty, u32)>,
44 },
45 RequestShell {
47 want_reply: bool,
48 },
49 Exec {
51 want_reply: bool,
52 command: Vec<u8>,
53 },
54 Signal {
56 signal: Sig,
57 },
58 RequestSubsystem {
60 want_reply: bool,
61 name: String,
62 },
63 RequestX11 {
65 want_reply: bool,
66 single_connection: bool,
67 x11_authentication_protocol: String,
68 x11_authentication_cookie: String,
69 x11_screen_number: u32,
70 },
71 SetEnv {
73 want_reply: bool,
74 variable_name: String,
75 variable_value: String,
76 },
77 WindowChange {
79 col_width: u32,
80 row_height: u32,
81 pix_width: u32,
82 pix_height: u32,
83 },
84 AgentForward {
86 want_reply: bool,
87 },
88
89 XonXoff {
91 client_can_do: bool,
92 },
93 ExitStatus {
95 exit_status: u32,
96 },
97 ExitSignal {
99 signal_name: Sig,
100 core_dumped: bool,
101 error_message: String,
102 lang_tag: String,
103 },
104 WindowAdjusted {
106 new_size: u32,
107 },
108 Success,
110 Failure,
112 OpenFailure(ChannelOpenFailure),
113}
114
115#[derive(Clone, Debug)]
116pub(crate) struct WindowSizeRef {
117 value: Arc<Mutex<u32>>,
118 notifier: Arc<Notify>,
119}
120
121impl WindowSizeRef {
122 pub(crate) fn new(initial: u32) -> Self {
123 let notifier = Arc::new(Notify::new());
124 Self {
125 value: Arc::new(Mutex::new(initial)),
126 notifier,
127 }
128 }
129
130 pub(crate) async fn update(&self, value: u32) {
131 *self.value.lock().await = value;
132 self.notifier.notify_one();
133 }
134
135 pub(crate) fn subscribe(&self) -> Arc<Notify> {
136 Arc::clone(&self.notifier)
137 }
138}
139
140pub struct ChannelReadHalf {
144 pub(crate) receiver: Receiver<ChannelMsg>,
145}
146
147impl std::fmt::Debug for ChannelReadHalf {
148 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149 f.debug_struct("ChannelReadHalf").finish()
150 }
151}
152
153impl ChannelReadHalf {
154 pub async fn wait(&mut self) -> Option<ChannelMsg> {
156 self.receiver.recv().await
157 }
158
159 pub fn make_reader(&mut self) -> impl AsyncRead + '_ {
162 self.make_reader_ext(None)
163 }
164
165 pub fn make_reader_ext(&mut self, ext: Option<u32>) -> impl AsyncRead + '_ {
168 io::ChannelRx::new(self, ext)
169 }
170}
171
172pub struct ChannelWriteHalf<Send: From<(ChannelId, ChannelMsg)>> {
176 pub(crate) id: ChannelId,
177 pub(crate) sender: Sender<Send>,
178 pub(crate) max_packet_size: u32,
179 pub(crate) window_size: WindowSizeRef,
180}
181
182impl<S: From<(ChannelId, ChannelMsg)>> std::fmt::Debug for ChannelWriteHalf<S> {
183 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
184 f.debug_struct("ChannelWriteHalf")
185 .field("id", &self.id)
186 .finish()
187 }
188}
189
190impl<S: From<(ChannelId, ChannelMsg)> + Send + Sync + 'static> ChannelWriteHalf<S> {
191 pub async fn writable_packet_size(&self) -> usize {
194 self.max_packet_size
195 .min(*self.window_size.value.lock().await) as usize
196 }
197
198 pub fn id(&self) -> ChannelId {
199 self.id
200 }
201
202 #[allow(clippy::too_many_arguments)] pub async fn request_pty(
205 &self,
206 want_reply: bool,
207 term: &str,
208 col_width: u32,
209 row_height: u32,
210 pix_width: u32,
211 pix_height: u32,
212 terminal_modes: &[(Pty, u32)],
213 ) -> Result<(), Error> {
214 self.send_msg(ChannelMsg::RequestPty {
215 want_reply,
216 term: term.to_string(),
217 col_width,
218 row_height,
219 pix_width,
220 pix_height,
221 terminal_modes: terminal_modes.to_vec(),
222 })
223 .await
224 }
225
226 pub async fn request_shell(&self, want_reply: bool) -> Result<(), Error> {
228 self.send_msg(ChannelMsg::RequestShell { want_reply }).await
229 }
230
231 pub async fn exec<A: Into<Vec<u8>>>(&self, want_reply: bool, command: A) -> Result<(), Error> {
235 self.send_msg(ChannelMsg::Exec {
236 want_reply,
237 command: command.into(),
238 })
239 .await
240 }
241
242 pub async fn signal(&self, signal: Sig) -> Result<(), Error> {
244 self.send_msg(ChannelMsg::Signal { signal }).await
245 }
246
247 pub async fn request_subsystem<A: Into<String>>(
249 &self,
250 want_reply: bool,
251 name: A,
252 ) -> Result<(), Error> {
253 self.send_msg(ChannelMsg::RequestSubsystem {
254 want_reply,
255 name: name.into(),
256 })
257 .await
258 }
259
260 pub async fn request_x11<A: Into<String>, B: Into<String>>(
265 &self,
266 want_reply: bool,
267 single_connection: bool,
268 x11_authentication_protocol: A,
269 x11_authentication_cookie: B,
270 x11_screen_number: u32,
271 ) -> Result<(), Error> {
272 self.send_msg(ChannelMsg::RequestX11 {
273 want_reply,
274 single_connection,
275 x11_authentication_protocol: x11_authentication_protocol.into(),
276 x11_authentication_cookie: x11_authentication_cookie.into(),
277 x11_screen_number,
278 })
279 .await
280 }
281
282 pub async fn set_env<A: Into<String>, B: Into<String>>(
284 &self,
285 want_reply: bool,
286 variable_name: A,
287 variable_value: B,
288 ) -> Result<(), Error> {
289 self.send_msg(ChannelMsg::SetEnv {
290 want_reply,
291 variable_name: variable_name.into(),
292 variable_value: variable_value.into(),
293 })
294 .await
295 }
296
297 pub async fn window_change(
299 &self,
300 col_width: u32,
301 row_height: u32,
302 pix_width: u32,
303 pix_height: u32,
304 ) -> Result<(), Error> {
305 self.send_msg(ChannelMsg::WindowChange {
306 col_width,
307 row_height,
308 pix_width,
309 pix_height,
310 })
311 .await
312 }
313
314 pub async fn agent_forward(&self, want_reply: bool) -> Result<(), Error> {
316 self.send_msg(ChannelMsg::AgentForward { want_reply }).await
317 }
318
319 pub async fn data<R: tokio::io::AsyncRead + Unpin>(&self, data: R) -> Result<(), Error> {
321 self.send_data(None, data).await
322 }
323
324 pub async fn extended_data<R: tokio::io::AsyncRead + Unpin>(
328 &self,
329 ext: u32,
330 data: R,
331 ) -> Result<(), Error> {
332 self.send_data(Some(ext), data).await
333 }
334
335 async fn send_data<R: tokio::io::AsyncRead + Unpin>(
336 &self,
337 ext: Option<u32>,
338 mut data: R,
339 ) -> Result<(), Error> {
340 let mut tx = self.make_writer_ext(ext);
341
342 tokio::io::copy(&mut data, &mut tx).await?;
343
344 Ok(())
345 }
346
347 pub async fn eof(&self) -> Result<(), Error> {
348 self.send_msg(ChannelMsg::Eof).await
349 }
350
351 pub async fn exit_status(&self, exit_status: u32) -> Result<(), Error> {
352 self.send_msg(ChannelMsg::ExitStatus { exit_status }).await
353 }
354
355 pub async fn close(&self) -> Result<(), Error> {
357 self.send_msg(ChannelMsg::Close).await
358 }
359
360 async fn send_msg(&self, msg: ChannelMsg) -> Result<(), Error> {
361 self.sender
362 .send((self.id, msg).into())
363 .await
364 .map_err(|_| Error::SendError)
365 }
366
367 pub fn make_writer(&self) -> impl AsyncWrite + 'static {
370 self.make_writer_ext(None)
371 }
372
373 pub fn make_writer_ext(&self, ext: Option<u32>) -> impl AsyncWrite + 'static {
376 io::ChannelTx::new(
377 self.sender.clone(),
378 self.id,
379 self.window_size.value.clone(),
380 self.window_size.subscribe(),
381 self.max_packet_size,
382 ext,
383 )
384 }
385}
386
387pub struct Channel<Send: From<(ChannelId, ChannelMsg)>> {
391 pub(crate) read_half: ChannelReadHalf,
392 pub(crate) write_half: ChannelWriteHalf<Send>,
393}
394
395impl<T: From<(ChannelId, ChannelMsg)>> std::fmt::Debug for Channel<T> {
396 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
397 f.debug_struct("Channel")
398 .field("id", &self.write_half.id)
399 .finish()
400 }
401}
402
403impl<S: From<(ChannelId, ChannelMsg)> + Send + Sync + 'static> Channel<S> {
404 pub(crate) fn new(
405 id: ChannelId,
406 sender: Sender<S>,
407 max_packet_size: u32,
408 window_size: u32,
409 channel_buffer_size: usize,
410 ) -> (Self, ChannelRef) {
411 let (tx, rx) = tokio::sync::mpsc::channel(channel_buffer_size);
412 let window_size = WindowSizeRef::new(window_size);
413 let read_half = ChannelReadHalf { receiver: rx };
414 let write_half = ChannelWriteHalf {
415 id,
416 sender,
417 max_packet_size,
418 window_size: window_size.clone(),
419 };
420
421 (
422 Self {
423 write_half,
424 read_half,
425 },
426 ChannelRef {
427 sender: tx,
428 window_size,
429 },
430 )
431 }
432
433 pub async fn writable_packet_size(&self) -> usize {
436 self.write_half.writable_packet_size().await
437 }
438
439 pub fn id(&self) -> ChannelId {
440 self.write_half.id()
441 }
442
443 pub fn split(self) -> (ChannelReadHalf, ChannelWriteHalf<S>) {
446 (self.read_half, self.write_half)
447 }
448
449 #[allow(clippy::too_many_arguments)] pub async fn request_pty(
452 &self,
453 want_reply: bool,
454 term: &str,
455 col_width: u32,
456 row_height: u32,
457 pix_width: u32,
458 pix_height: u32,
459 terminal_modes: &[(Pty, u32)],
460 ) -> Result<(), Error> {
461 self.write_half
462 .request_pty(
463 want_reply,
464 term,
465 col_width,
466 row_height,
467 pix_width,
468 pix_height,
469 terminal_modes,
470 )
471 .await
472 }
473
474 pub async fn request_shell(&self, want_reply: bool) -> Result<(), Error> {
476 self.write_half.request_shell(want_reply).await
477 }
478
479 pub async fn exec<A: Into<Vec<u8>>>(&self, want_reply: bool, command: A) -> Result<(), Error> {
483 self.write_half.exec(want_reply, command).await
484 }
485
486 pub async fn signal(&self, signal: Sig) -> Result<(), Error> {
488 self.write_half.signal(signal).await
489 }
490
491 pub async fn request_subsystem<A: Into<String>>(
493 &self,
494 want_reply: bool,
495 name: A,
496 ) -> Result<(), Error> {
497 self.write_half.request_subsystem(want_reply, name).await
498 }
499
500 pub async fn request_x11<A: Into<String>, B: Into<String>>(
505 &self,
506 want_reply: bool,
507 single_connection: bool,
508 x11_authentication_protocol: A,
509 x11_authentication_cookie: B,
510 x11_screen_number: u32,
511 ) -> Result<(), Error> {
512 self.write_half
513 .request_x11(
514 want_reply,
515 single_connection,
516 x11_authentication_protocol,
517 x11_authentication_cookie,
518 x11_screen_number,
519 )
520 .await
521 }
522
523 pub async fn set_env<A: Into<String>, B: Into<String>>(
525 &self,
526 want_reply: bool,
527 variable_name: A,
528 variable_value: B,
529 ) -> Result<(), Error> {
530 self.write_half
531 .set_env(want_reply, variable_name, variable_value)
532 .await
533 }
534
535 pub async fn window_change(
537 &self,
538 col_width: u32,
539 row_height: u32,
540 pix_width: u32,
541 pix_height: u32,
542 ) -> Result<(), Error> {
543 self.write_half
544 .window_change(col_width, row_height, pix_width, pix_height)
545 .await
546 }
547
548 pub async fn agent_forward(&self, want_reply: bool) -> Result<(), Error> {
550 self.write_half.agent_forward(want_reply).await
551 }
552
553 pub async fn data<R: tokio::io::AsyncRead + Unpin>(&self, data: R) -> Result<(), Error> {
555 self.write_half.data(data).await
556 }
557
558 pub async fn extended_data<R: tokio::io::AsyncRead + Unpin>(
562 &self,
563 ext: u32,
564 data: R,
565 ) -> Result<(), Error> {
566 self.write_half.extended_data(ext, data).await
567 }
568
569 pub async fn eof(&self) -> Result<(), Error> {
570 self.write_half.eof().await
571 }
572
573 pub async fn exit_status(&self, exit_status: u32) -> Result<(), Error> {
574 self.write_half.exit_status(exit_status).await
575 }
576
577 pub async fn close(&self) -> Result<(), Error> {
579 self.write_half.close().await
580 }
581
582 pub async fn wait(&mut self) -> Option<ChannelMsg> {
584 self.read_half.wait().await
585 }
586
587 pub fn into_stream(self) -> ChannelStream<S> {
590 ChannelStream::new(
591 io::ChannelTx::new(
592 self.write_half.sender.clone(),
593 self.write_half.id,
594 self.write_half.window_size.value.clone(),
595 self.write_half.window_size.subscribe(),
596 self.write_half.max_packet_size,
597 None,
598 ),
599 io::ChannelRx::new(io::ChannelCloseOnDrop(self), None),
600 )
601 }
602
603 pub fn make_reader(&mut self) -> impl AsyncRead + '_ {
606 self.read_half.make_reader()
607 }
608
609 pub fn make_reader_ext(&mut self, ext: Option<u32>) -> impl AsyncRead + '_ {
612 self.read_half.make_reader_ext(ext)
613 }
614
615 pub fn make_writer(&self) -> impl AsyncWrite + 'static {
618 self.write_half.make_writer()
619 }
620
621 pub fn make_writer_ext(&self, ext: Option<u32>) -> impl AsyncWrite + 'static {
624 self.write_half.make_writer_ext(ext)
625 }
626}