// bssh_russh/compression.rs

1use std::convert::TryFrom;
2
3use delegate::delegate;
4use ssh_encoding::Encode;
5
/// Compression scheme selected from an algorithm [`Name`].
///
/// Only `None` exists unless the crate is built with the `flate2` feature.
#[derive(Clone, Debug)]
pub enum Compression {
    /// Payloads are passed through unchanged.
    None,
    /// zlib/deflate compression.
    #[cfg(feature = "flate2")]
    Zlib,
}
12
/// Compressor state used when sending data.
///
/// Holds a live deflate stream only when zlib is in use and the `flate2`
/// feature is enabled.
#[derive(Debug)]
pub enum Compress {
    /// No compressor; payloads pass through untouched.
    None,
    /// A running zlib/deflate compressor.
    #[cfg(feature = "flate2")]
    Zlib(flate2::Compress),
}
19
/// Decompressor state used when receiving data.
///
/// Holds a live inflate stream only when zlib is in use and the `flate2`
/// feature is enabled.
#[derive(Debug)]
pub enum Decompress {
    /// No decompressor; payloads pass through untouched.
    None,
    /// A running zlib/inflate decompressor.
    #[cfg(feature = "flate2")]
    Zlib(flate2::Decompress),
}
26
/// Name of a compression algorithm, such as `"none"` or `"zlib"`.
///
/// Wraps a `'static` string so it is cheap to copy and compare.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Name(&'static str);

impl AsRef<str> for Name {
    /// Returns the algorithm name as a string slice.
    fn as_ref(&self) -> &str {
        let Name(text) = *self;
        text
    }
}
34
35impl Encode for Name {
36    delegate! { to self.as_ref() {
37        fn encoded_len(&self) -> Result<usize, ssh_encoding::Error>;
38        fn encode(&self, writer: &mut impl ssh_encoding::Writer) -> Result<(), ssh_encoding::Error>;
39    }}
40}
41
42impl TryFrom<&str> for Name {
43    type Error = ();
44    fn try_from(s: &str) -> Result<Name, ()> {
45        ALL_COMPRESSION_ALGORITHMS
46            .iter()
47            .find(|x| x.0 == s)
48            .map(|x| **x)
49            .ok_or(())
50    }
51}
52
53pub const NONE: Name = Name("none");
54#[cfg(feature = "flate2")]
55pub const ZLIB: Name = Name("zlib");
56#[cfg(feature = "flate2")]
57pub const ZLIB_LEGACY: Name = Name("zlib@openssh.com");
58
59pub const ALL_COMPRESSION_ALGORITHMS: &[&Name] = &[
60    &NONE,
61    #[cfg(feature = "flate2")]
62    &ZLIB,
63    #[cfg(feature = "flate2")]
64    &ZLIB_LEGACY,
65];
66
#[cfg(feature = "flate2")]
impl Compression {
    /// Maps an algorithm [`Name`] to the compression scheme it selects.
    ///
    /// `zlib` and `zlib@openssh.com` both select [`Compression::Zlib`]. Any
    /// other name, including `none`, selects [`Compression::None`].
    pub fn new(name: &Name) -> Self {
        let zlib_names = [ZLIB, ZLIB_LEGACY];
        if zlib_names.contains(name) {
            Self::Zlib
        } else {
            Self::None
        }
    }

    /// Prepares `comp` to start compressing with this scheme.
    ///
    /// For zlib, an existing zlib compressor is reset in place. Otherwise a
    /// new zlib-wrapped compressor is created at `flate2::Compression::fast()`.
    /// With no compression, `comp` becomes [`Compress::None`].
    pub fn init_compress(&self, comp: &mut Compress) {
        match self {
            Self::None => *comp = Compress::None,
            Self::Zlib => match *comp {
                Compress::Zlib(ref mut deflate) => deflate.reset(),
                _ => {
                    *comp =
                        Compress::Zlib(flate2::Compress::new(flate2::Compression::fast(), true))
                }
            },
        }
    }

    /// Prepares `comp` to start decompressing with this scheme.
    ///
    /// For zlib, an existing zlib decompressor is reset in place (expecting a
    /// zlib header). Otherwise a new zlib-wrapped decompressor is created.
    /// With no compression, `comp` becomes [`Decompress::None`].
    pub fn init_decompress(&self, comp: &mut Decompress) {
        match self {
            Self::None => *comp = Decompress::None,
            Self::Zlib => match *comp {
                Decompress::Zlib(ref mut inflate) => inflate.reset(true),
                _ => *comp = Decompress::Zlib(flate2::Decompress::new(true)),
            },
        }
    }
}
101
102#[cfg(not(feature = "flate2"))]
103impl Compression {
104    pub fn new(_name: &Name) -> Self {
105        Compression::None
106    }
107
108    pub fn init_compress(&self, _: &mut Compress) {}
109
110    pub fn init_decompress(&self, _: &mut Decompress) {}
111}
112
113#[cfg(not(feature = "flate2"))]
114impl Compress {
115    pub fn compress<'a>(
116        &mut self,
117        input: &'a [u8],
118        _: &'a mut russh_cryptovec::CryptoVec,
119    ) -> Result<&'a [u8], crate::Error> {
120        Ok(input)
121    }
122}
123
124#[cfg(not(feature = "flate2"))]
125impl Decompress {
126    pub fn decompress<'a>(
127        &mut self,
128        input: &'a [u8],
129        _: &'a mut russh_cryptovec::CryptoVec,
130    ) -> Result<&'a [u8], crate::Error> {
131        Ok(input)
132    }
133}
134
#[cfg(feature = "flate2")]
impl Compress {
    /// Compresses one payload.
    ///
    /// With [`Compress::None`], `input` is returned unchanged and `output` is
    /// not touched.
    ///
    /// With [`Compress::Zlib`], `input` is fed into the running deflate stream
    /// with a partial flush. The compressed bytes are written into `output`,
    /// which is cleared first and grown as needed. The return value is a slice
    /// of exactly the bytes produced for this payload.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the deflate stream.
    pub fn compress<'a>(
        &mut self,
        input: &'a [u8],
        output: &'a mut russh_cryptovec::CryptoVec,
    ) -> Result<&'a [u8], crate::Error> {
        match *self {
            Compress::None => Ok(input),
            Compress::Zlib(ref mut z) => {
                output.clear();
                // `total_in`/`total_out` are cumulative over the whole stream,
                // so remember where they stood before this payload. Subtract
                // in u64 before narrowing: casting the totals to `usize` first
                // would wrap on 32-bit targets once a stream passes 4 GiB.
                let start_in = z.total_in();
                let start_out = z.total_out();
                // Start with room for the input plus a little framing overhead
                // and grow on demand below.
                output.resize(input.len() + 10);
                let flush = flate2::FlushCompress::Partial;
                loop {
                    let consumed = (z.total_in() - start_in) as usize;
                    let produced = (z.total_out() - start_out) as usize;
                    // consumed <= input.len(), and produced < output.len()
                    // because the buffer is grown whenever the loop continues.
                    #[allow(clippy::indexing_slicing)] // bounds established above
                    let status = z.compress(&input[consumed..], &mut output[produced..], flush)?;
                    let consumed = (z.total_in() - start_in) as usize;
                    let produced = (z.total_out() - start_out) as usize;
                    // `Status::Ok` can mean "the output buffer is full" while
                    // input or flush bytes are still pending. The payload is only
                    // complete once all input was taken and the flush finished
                    // with spare room left. This also stops the loop on a
                    // `BufError` that merely reports there was nothing left to do.
                    let finished = match status {
                        flate2::Status::StreamEnd => true,
                        _ => consumed == input.len() && produced < output.len(),
                    };
                    if finished {
                        break;
                    }
                    // Pending output or no progress possible: make room and call
                    // again with the same flush mode.
                    output.resize(output.len() * 2);
                }
                let produced = (z.total_out() - start_out) as usize;
                #[allow(clippy::indexing_slicing)] // produced <= output.len()
                Ok(&output[..produced])
            }
        }
    }
}
169
#[cfg(feature = "flate2")]
impl Decompress {
    /// Decompresses one payload.
    ///
    /// With [`Decompress::None`], `input` is returned unchanged and `output`
    /// is not touched.
    ///
    /// With [`Decompress::Zlib`], `input` is fed into the running inflate
    /// stream. The decompressed bytes are written into `output`, which is
    /// cleared first and grown as needed. The return value is a slice of
    /// exactly the bytes produced for this payload.
    ///
    /// NOTE(review): output growth is unbounded. A small compressed payload
    /// from the peer can inflate to a very large buffer. Consider capping it
    /// at the maximum uncompressed payload size (RFC 4253 §6.1).
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the inflate stream (e.g. corrupt data).
    pub fn decompress<'a>(
        &mut self,
        input: &'a [u8],
        output: &'a mut russh_cryptovec::CryptoVec,
    ) -> Result<&'a [u8], crate::Error> {
        match *self {
            Decompress::None => Ok(input),
            Decompress::Zlib(ref mut z) => {
                output.clear();
                // Cumulative stream counters. Subtract in u64 before narrowing
                // so 32-bit targets stay correct after 4 GiB of traffic.
                let start_in = z.total_in();
                let start_out = z.total_out();
                output.resize(input.len());
                let flush = flate2::FlushDecompress::None;
                loop {
                    let consumed = (z.total_in() - start_in) as usize;
                    let produced = (z.total_out() - start_out) as usize;
                    // Grow only when the buffer is actually full. The floor
                    // keeps an empty buffer (empty input) from "doubling" to
                    // zero and never gaining room.
                    if produced == output.len() {
                        output.resize((output.len() * 2).max(256));
                    }
                    #[allow(clippy::indexing_slicing)] // consumed <= input.len(), produced < output.len()
                    let status = z.decompress(&input[consumed..], &mut output[produced..], flush)?;
                    // Keep calling while the inflater reports progress. It
                    // signals "nothing more can be done" with `BufError` (or
                    // `StreamEnd` at the end of the stream).
                    if !matches!(status, flate2::Status::Ok) {
                        break;
                    }
                }
                let produced = (z.total_out() - start_out) as usize;
                #[allow(clippy::indexing_slicing)] // produced <= output.len()
                Ok(&output[..produced])
            }
        }
    }
}