radicle_ssh/
encoding.rs

1// Copyright 2016 Pierre-Étienne Meunier
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15use std::ops::DerefMut as _;
16
17use thiserror::Error;
18use zeroize::Zeroizing;
19
/// General purpose writable byte buffer we use everywhere.
///
/// The `Vec<u8>` is wrapped in [`Zeroizing`] (from the `zeroize` crate) so its
/// contents are wiped when the buffer is dropped.
/// NOTE(review): bytes left behind when the inner `Vec` reallocates while
/// growing are presumably not covered by the wrapper — confirm whether callers
/// that store secrets pre-size their buffers.
pub type Buffer = Zeroizing<Vec<u8>>;
22
23#[derive(Debug, Error)]
24pub enum Error {
25    /// Index out of bounds
26    #[error("Index out of bounds")]
27    IndexOutOfBounds,
28}
29
30pub trait Encodable: Sized {
31    type Error: std::error::Error + Send + Sync + 'static;
32
33    /// Read from the SSH format.
34    fn read(reader: &mut Cursor) -> Result<Self, Self::Error>;
35    /// Write to the SSH format.
36    fn write<E: Encoding>(&self, buf: &mut E);
37}
38
/// Encode in the SSH format.
///
/// Implementors append SSH wire-format values to themselves; each method
/// documents the layout it produces.
pub trait Encoding {
    /// Push an SSH-encoded string to `self`: a `u32` length prefix followed by
    /// the bytes of `s`.
    fn extend_ssh_string(&mut self, s: &[u8]);
    /// Push the length prefix of an SSH string of `s` bytes followed by `s`
    /// blank bytes, and return that blank region so the caller can fill it
    /// in place.
    fn extend_ssh_string_blank(&mut self, s: usize) -> &mut [u8];
    /// Push an SSH-encoded multiple-precision integer whose unsigned
    /// big-endian magnitude is `s`.
    fn extend_ssh_mpint(&mut self, s: &[u8]);
    /// Push an SSH-encoded list: the items of `list` joined by commas and
    /// preceded by the total payload length as a `u32`.
    fn extend_list<'a, I: Iterator<Item = &'a [u8]>>(&mut self, list: I);
    /// Push an SSH-encoded unsigned 32-bit integer.
    fn extend_u32(&mut self, u: u32);
    /// Push an SSH-encoded empty list.
    fn write_empty_list(&mut self);
    /// Overwrite the first four bytes of the buffer with the length of
    /// everything that follows them. The caller must have reserved those four
    /// bytes beforehand.
    fn write_len(&mut self);
    /// Push a [`usize`] as an SSH-encoded unsigned 32-bit integer.
    ///
    /// This is a convenience method that spares callers from converting
    /// [`usize`] to [`u32`] themselves. A value that does not fit in 32 bits
    /// cannot be represented in the SSH format at all, so such a value is
    /// treated as a caller bug rather than a recoverable error.
    ///
    /// # Panics
    ///
    /// Panics if `u` is greater than [`u32::MAX`].
    fn extend_usize(&mut self, u: usize) {
        let u = u32::try_from(u).expect("value does not fit in an SSH uint32");
        self.extend_u32(u)
    }
}
66
/// Number of bytes needed to SSH-encode `s` as an `mpint`, including the
/// 4-byte length prefix.
///
/// `s` is treated as an unsigned big-endian magnitude:
/// - leading zero bytes are not counted;
/// - one extra `0x00` pad byte is counted when the most significant remaining
///   byte has its high bit set;
/// - zero (an empty or all-zero `s`) is just the 4-byte prefix.
pub fn mpint_len(s: &[u8]) -> usize {
    // Strip redundant leading zeros; an all-zero (or empty) input leaves
    // `digits` empty instead of indexing past the end of `s`.
    let start = s.iter().position(|&b| b != 0).unwrap_or(s.len());
    let digits = &s[start..];
    // A set high bit would make the value read as negative, so a pad byte is
    // required to keep it positive.
    let needs_pad = digits.first().map_or(false, |&b| b & 0x80 != 0);
    4 + usize::from(needs_pad) + digits.len()
}
75
76impl Encoding for Vec<u8> {
77    fn extend_ssh_string(&mut self, s: &[u8]) {
78        self.extend_usize(s.len());
79        self.extend(s);
80    }
81
82    fn extend_ssh_string_blank(&mut self, len: usize) -> &mut [u8] {
83        self.extend_usize(len);
84        let current = self.len();
85        self.resize(current + len, 0u8);
86
87        &mut self[current..]
88    }
89
90    fn extend_ssh_mpint(&mut self, s: &[u8]) {
91        // Skip initial 0s.
92        let mut i = 0;
93        while i < s.len() && s[i] == 0 {
94            i += 1
95        }
96        // If the first non-zero is >= 128, write its length (u32, BE), followed by 0.
97        if s[i] & 0x80 != 0 {
98            self.extend_usize(s.len() - i + 1);
99            self.push(0)
100        } else {
101            self.extend_usize(s.len() - i);
102        }
103        self.extend(&s[i..]);
104    }
105
106    fn extend_u32(&mut self, s: u32) {
107        self.extend(s.to_be_bytes());
108    }
109
110    fn extend_list<'a, I: Iterator<Item = &'a [u8]>>(&mut self, list: I) {
111        let len0 = self.len();
112
113        let mut first = true;
114        for i in list {
115            if !first {
116                self.push(b',')
117            } else {
118                first = false;
119            }
120            self.extend(i)
121        }
122        let len = (self.len() - len0 - 4) as u32;
123
124        self.splice(len0..len0, len.to_be_bytes());
125    }
126
127    fn write_empty_list(&mut self) {
128        self.extend([0, 0, 0, 0]);
129    }
130
131    fn write_len(&mut self) {
132        let len = self.len() - 4;
133        self[..4].copy_from_slice((len as u32).to_be_bytes().as_slice());
134    }
135}
136
137impl Encoding for Buffer {
138    fn extend_ssh_string(&mut self, s: &[u8]) {
139        self.deref_mut().extend_ssh_string(s)
140    }
141
142    fn extend_ssh_string_blank(&mut self, len: usize) -> &mut [u8] {
143        self.deref_mut().extend_ssh_string_blank(len)
144    }
145
146    fn extend_ssh_mpint(&mut self, s: &[u8]) {
147        self.deref_mut().extend_ssh_mpint(s)
148    }
149
150    fn extend_list<'a, I: Iterator<Item = &'a [u8]>>(&mut self, list: I) {
151        self.deref_mut().extend_list(list)
152    }
153
154    fn write_empty_list(&mut self) {
155        self.deref_mut().write_empty_list()
156    }
157
158    fn extend_u32(&mut self, s: u32) {
159        self.deref_mut().extend_u32(s);
160    }
161
162    fn write_len(&mut self) {
163        self.deref_mut().write_len()
164    }
165}
166
167/// A cursor-like trait to read SSH-encoded things.
168pub trait Reader {
169    /// Create an SSH reader for `self`.
170    fn reader(&self, starting_at: usize) -> Cursor;
171}
172
173impl Reader for Buffer {
174    fn reader(&self, starting_at: usize) -> Cursor {
175        Cursor {
176            s: self,
177            position: starting_at,
178        }
179    }
180}
181
182impl Reader for [u8] {
183    fn reader(&self, starting_at: usize) -> Cursor {
184        Cursor {
185            s: self,
186            position: starting_at,
187        }
188    }
189}
190
/// A cursor-like type to read SSH-encoded values.
///
/// Holds a borrowed input slice and a byte offset into it. The `read_*`
/// methods decode one value starting at `position` and, on success, advance
/// `position` past it.
#[derive(Debug)]
pub struct Cursor<'a> {
    // The input being decoded; slices returned by the readers borrow from it.
    s: &'a [u8],
    // Byte offset of the next value to read. It is public (but hidden from
    // docs), so it may be set to any value, including one past the end of `s`.
    #[doc(hidden)]
    pub position: usize,
}
198
199impl<'a> Cursor<'a> {
200    /// Read one string from this reader.
201    pub fn read_string(&mut self) -> Result<&'a [u8], Error> {
202        let len = self.read_u32()? as usize;
203        if self.position + len <= self.s.len() {
204            let result = &self.s[self.position..(self.position + len)];
205            self.position += len;
206            Ok(result)
207        } else {
208            Err(Error::IndexOutOfBounds)
209        }
210    }
211
212    /// Read a `u32` from this reader.
213    pub fn read_u32(&mut self) -> Result<u32, Error> {
214        if self.position + 4 <= self.s.len() {
215            let u =
216                u32::from_be_bytes(self.s[self.position..self.position + 4].try_into().unwrap());
217            self.position += 4;
218            Ok(u)
219        } else {
220            Err(Error::IndexOutOfBounds)
221        }
222    }
223
224    /// Read one byte from this reader.
225    pub fn read_byte(&mut self) -> Result<u8, Error> {
226        if self.position < self.s.len() {
227            let u = self.s[self.position];
228            self.position += 1;
229            Ok(u)
230        } else {
231            Err(Error::IndexOutOfBounds)
232        }
233    }
234
235    pub fn read_bytes<const S: usize>(&mut self) -> Result<[u8; S], Error> {
236        let mut buf = [0; S];
237        for b in buf.iter_mut() {
238            *b = self.read_byte()?;
239        }
240        Ok(buf)
241    }
242
243    /// Read one byte from this reader.
244    pub fn read_mpint(&mut self) -> Result<&'a [u8], Error> {
245        let len = self.read_u32()? as usize;
246        if self.position + len <= self.s.len() {
247            let result = &self.s[self.position..(self.position + len)];
248            self.position += len;
249            Ok(result)
250        } else {
251            Err(Error::IndexOutOfBounds)
252        }
253    }
254}