1use std::ops::DerefMut as _;
16
17use thiserror::Error;
18use zeroize::Zeroizing;
19
/// Growable byte buffer wrapped in [`Zeroizing`], so its contents are wiped when the
/// buffer is dropped.
///
/// NOTE(review): growing the inner `Vec` may reallocate; presumably only the allocation
/// alive at drop time gets wiped, so earlier copies could linger in freed memory —
/// confirm whether secret-bearing buffers should be pre-sized.
pub type Buffer = Zeroizing<Vec<u8>>;
22
23#[derive(Debug, Error)]
24pub enum Error {
25 #[error("Index out of bounds")]
27 IndexOutOfBounds,
28}
29
30pub trait Encodable: Sized {
31 type Error: std::error::Error + Send + Sync + 'static;
32
33 fn read(reader: &mut Cursor) -> Result<Self, Self::Error>;
35 fn write<E: Encoding>(&self, buf: &mut E);
37}
38
/// Sink for SSH wire-format values.
pub trait Encoding {
    /// Appends `s` as an SSH `string`: a big-endian `u32` byte count, then the bytes.
    fn extend_ssh_string(&mut self, s: &[u8]);

    /// Appends an SSH `string` header for `s` bytes followed by `s` zero bytes, and
    /// returns the zero-filled payload so the caller can fill it in place.
    fn extend_ssh_string_blank(&mut self, s: usize) -> &mut [u8];

    /// Appends `s`, an unsigned big-endian magnitude, as an SSH `mpint`.
    fn extend_ssh_mpint(&mut self, s: &[u8]);

    /// Appends the items joined with `,`, prefixed by the big-endian `u32` length of the
    /// joined bytes.
    fn extend_list<'a, I: Iterator<Item = &'a [u8]>>(&mut self, list: I);

    /// Appends `u` as four big-endian bytes.
    fn extend_u32(&mut self, u: u32);

    /// Appends an empty list, i.e. a zero `u32` length.
    fn write_empty_list(&mut self);

    /// Overwrites the first four bytes with the big-endian length of everything after
    /// them; the caller is expected to have reserved that 4-byte prefix.
    fn write_len(&mut self);

    /// Appends `u` as a `u32` through [`Encoding::extend_u32`].
    ///
    /// # Panics
    ///
    /// Panics if `u` exceeds `u32::MAX`. SSH length fields are 32 bits wide, so such a
    /// value cannot be represented; reaching this is a caller bug, not a recoverable
    /// error.
    #[allow(clippy::expect_used)] // deliberate panic on a broken invariant, see `# Panics`
    fn extend_usize(&mut self, u: usize) {
        let u = u32::try_from(u).expect("extend_usize: value does not fit in an SSH u32 field");
        self.extend_u32(u)
    }
}
66
/// Returns the number of bytes needed to write `s` as an SSH `mpint`, including the
/// 4-byte length prefix.
///
/// `s` is an unsigned big-endian magnitude. Leading zero bytes are not counted. One
/// extra `0x00` pad byte is counted when the most significant remaining byte has its
/// high bit set. An empty or all-zero input is the value zero, which encodes as a bare
/// zero length (4 bytes).
pub fn mpint_len(s: &[u8]) -> usize {
    let start = s.iter().position(|&b| b != 0).unwrap_or(s.len());
    let digits = &s[start..];
    // `first()` is `None` for the value zero, which never needs a pad byte.
    let pad = match digits.first() {
        Some(&msb) if msb & 0x80 != 0 => 1,
        _ => 0,
    };
    4 + pad + digits.len()
}
75
76impl Encoding for Vec<u8> {
77 fn extend_ssh_string(&mut self, s: &[u8]) {
78 self.extend_usize(s.len());
79 self.extend(s);
80 }
81
82 fn extend_ssh_string_blank(&mut self, len: usize) -> &mut [u8] {
83 self.extend_usize(len);
84 let current = self.len();
85 self.resize(current + len, 0u8);
86
87 &mut self[current..]
88 }
89
90 fn extend_ssh_mpint(&mut self, s: &[u8]) {
91 let mut i = 0;
93 while i < s.len() && s[i] == 0 {
94 i += 1
95 }
96 if s[i] & 0x80 != 0 {
98 self.extend_usize(s.len() - i + 1);
99 self.push(0)
100 } else {
101 self.extend_usize(s.len() - i);
102 }
103 self.extend(&s[i..]);
104 }
105
106 fn extend_u32(&mut self, s: u32) {
107 self.extend(s.to_be_bytes());
108 }
109
110 fn extend_list<'a, I: Iterator<Item = &'a [u8]>>(&mut self, list: I) {
111 let len0 = self.len();
112
113 let mut first = true;
114 for i in list {
115 if !first {
116 self.push(b',')
117 } else {
118 first = false;
119 }
120 self.extend(i)
121 }
122 let len = (self.len() - len0 - 4) as u32;
123
124 self.splice(len0..len0, len.to_be_bytes());
125 }
126
127 fn write_empty_list(&mut self) {
128 self.extend([0, 0, 0, 0]);
129 }
130
131 fn write_len(&mut self) {
132 let len = self.len() - 4;
133 self[..4].copy_from_slice((len as u32).to_be_bytes().as_slice());
134 }
135}
136
137impl Encoding for Buffer {
138 fn extend_ssh_string(&mut self, s: &[u8]) {
139 self.deref_mut().extend_ssh_string(s)
140 }
141
142 fn extend_ssh_string_blank(&mut self, len: usize) -> &mut [u8] {
143 self.deref_mut().extend_ssh_string_blank(len)
144 }
145
146 fn extend_ssh_mpint(&mut self, s: &[u8]) {
147 self.deref_mut().extend_ssh_mpint(s)
148 }
149
150 fn extend_list<'a, I: Iterator<Item = &'a [u8]>>(&mut self, list: I) {
151 self.deref_mut().extend_list(list)
152 }
153
154 fn write_empty_list(&mut self) {
155 self.deref_mut().write_empty_list()
156 }
157
158 fn extend_u32(&mut self, s: u32) {
159 self.deref_mut().extend_u32(s);
160 }
161
162 fn write_len(&mut self) {
163 self.deref_mut().write_len()
164 }
165}
166
167pub trait Reader {
169 fn reader(&self, starting_at: usize) -> Cursor;
171}
172
173impl Reader for Buffer {
174 fn reader(&self, starting_at: usize) -> Cursor {
175 Cursor {
176 s: self,
177 position: starting_at,
178 }
179 }
180}
181
182impl Reader for [u8] {
183 fn reader(&self, starting_at: usize) -> Cursor {
184 Cursor {
185 s: self,
186 position: starting_at,
187 }
188 }
189}
190
/// Read position over a borrowed byte slice, used to decode SSH wire-format fields.
#[derive(Debug)]
pub struct Cursor<'a> {
    /// The bytes being read (borrowed, never modified).
    s: &'a [u8],
    /// Offset of the next byte to read. Public but hidden from docs — presumably so
    /// sibling code can save/restore or skip ahead; verify against callers.
    #[doc(hidden)]
    pub position: usize,
}
198
199impl<'a> Cursor<'a> {
200 pub fn read_string(&mut self) -> Result<&'a [u8], Error> {
202 let len = self.read_u32()? as usize;
203 if self.position + len <= self.s.len() {
204 let result = &self.s[self.position..(self.position + len)];
205 self.position += len;
206 Ok(result)
207 } else {
208 Err(Error::IndexOutOfBounds)
209 }
210 }
211
212 pub fn read_u32(&mut self) -> Result<u32, Error> {
214 if self.position + 4 <= self.s.len() {
215 let u =
216 u32::from_be_bytes(self.s[self.position..self.position + 4].try_into().unwrap());
217 self.position += 4;
218 Ok(u)
219 } else {
220 Err(Error::IndexOutOfBounds)
221 }
222 }
223
224 pub fn read_byte(&mut self) -> Result<u8, Error> {
226 if self.position < self.s.len() {
227 let u = self.s[self.position];
228 self.position += 1;
229 Ok(u)
230 } else {
231 Err(Error::IndexOutOfBounds)
232 }
233 }
234
235 pub fn read_bytes<const S: usize>(&mut self) -> Result<[u8; S], Error> {
236 let mut buf = [0; S];
237 for b in buf.iter_mut() {
238 *b = self.read_byte()?;
239 }
240 Ok(buf)
241 }
242
243 pub fn read_mpint(&mut self) -> Result<&'a [u8], Error> {
245 let len = self.read_u32()? as usize;
246 if self.position + len <= self.s.len() {
247 let result = &self.s[self.position..(self.position + len)];
248 self.position += len;
249 Ok(result)
250 } else {
251 Err(Error::IndexOutOfBounds)
252 }
253 }
254}