use core::fmt;
use std::borrow::Cow;
use std::num::Wrapping;
use super::cipher::SealingKey;
use compression::Compress;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use super::*;
/// An SSH identification (version) string such as `SSH-2.0-acme`.
#[derive(Debug)]
pub enum SshId {
    /// An identification line without a terminator; `\r\n` is appended when
    /// the id is written out.
    Standard(Cow<'static, str>),
    /// An identification string written verbatim, including whatever line
    /// terminator (if any) the caller supplied.
    Raw(Cow<'static, str>),
}

impl SshId {
    /// Returns the identification bytes used in the key-exchange hash.
    ///
    /// For [`SshId::Raw`], any trailing `\n` / `\r` characters are stripped. Both
    /// variants therefore yield the id without its line terminator.
    pub(crate) fn as_kex_hash_bytes(&self) -> &[u8] {
        match self {
            Self::Standard(s) => s.as_bytes(),
            Self::Raw(s) => s.trim_end_matches(['\n', '\r']).as_bytes(),
        }
    }

    /// Appends the outgoing form of this id to `buffer`.
    ///
    /// [`SshId::Standard`] is followed by `\r\n`; [`SshId::Raw`] is copied as-is.
    /// Existing contents of `buffer` are kept.
    pub(crate) fn write(&self, buffer: &mut Vec<u8>) {
        match self {
            Self::Standard(s) => {
                // Append id and terminator directly instead of building an
                // intermediate `String` with `format!`.
                buffer.extend_from_slice(s.as_bytes());
                buffer.extend_from_slice(b"\r\n");
            }
            Self::Raw(s) => buffer.extend_from_slice(s.as_bytes()),
        }
    }
}
/// Checks the written form and the key-exchange hash bytes of both `SshId` variants.
#[test]
fn test_ssh_id() {
    let mut buffer = Vec::new();
    SshId::Standard("SSH-2.0-acme".into()).write(&mut buffer);
    assert_eq!(&buffer[..], b"SSH-2.0-acme\r\n");

    let mut buffer = Vec::new();
    SshId::Raw("SSH-2.0-raw\n".into()).write(&mut buffer);
    assert_eq!(&buffer[..], b"SSH-2.0-raw\n");

    assert_eq!(
        SshId::Standard("SSH-2.0-acme".into()).as_kex_hash_bytes(),
        b"SSH-2.0-acme"
    );
    assert_eq!(
        SshId::Raw("SSH-2.0-raw\n".into()).as_kex_hash_bytes(),
        b"SSH-2.0-raw"
    );

    // A CRLF-terminated raw id must have both terminator characters stripped.
    assert_eq!(
        SshId::Raw("SSH-2.0-raw\r\n".into()).as_kex_hash_bytes(),
        b"SSH-2.0-raw"
    );
    // A raw id without any terminator is hashed unchanged.
    assert_eq!(
        SshId::Raw("SSH-2.0-raw".into()).as_kex_hash_bytes(),
        b"SSH-2.0-raw"
    );

    // `write` appends to, rather than replaces, existing buffer contents.
    let mut buffer = b"x".to_vec();
    SshId::Standard("SSH-2.0-acme".into()).write(&mut buffer);
    assert_eq!(&buffer[..], b"xSSH-2.0-acme\r\n");
}
/// A byte buffer plus bookkeeping counters used while writing SSH data.
#[derive(Debug, Default)]
pub struct SSHBuffer {
    /// Pending bytes (for example, the identification string appended by
    /// `SSHBuffer::send_ssh_id`).
    pub buffer: Vec<u8>,
    /// NOTE(review): not used by the code visible here; presumably the length of
    /// the packet currently being processed. Confirm against the cipher code.
    pub len: usize,
    /// NOTE(review): not used by the code visible here; presumably a running byte
    /// count. Confirm against callers.
    pub bytes: usize,
    /// Wrapping 32-bit sequence number; starts at 0.
    pub seqn: Wrapping<u32>,
}
impl SSHBuffer {
pub fn new() -> Self {
SSHBuffer {
buffer: Vec::new(),
len: 0,
bytes: 0,
seqn: Wrapping(0),
}
}
pub fn send_ssh_id(&mut self, id: &SshId) {
id.write(&mut self.buffer);
}
}
/// A packet received from the peer, paired with its sequence number.
#[derive(Debug)]
pub(crate) struct IncomingSshPacket {
/// Packet bytes. NOTE(review): whether these are already decrypted and
/// decompressed is not visible here; confirm against the reader side.
pub buffer: Vec<u8>,
/// Sequence number associated with this packet (wrapping 32-bit).
pub seqn: Wrapping<u32>,
}
/// Builds outgoing packets. Each payload is passed through the configured
/// compression and then the sealing cipher. The result accumulates in an
/// internal write buffer until it is flushed to a writer.
pub(crate) struct PacketWriter {
/// Cipher that writes each (compressed) payload into `write_buffer`.
cipher: Box<dyn SealingKey + Send>,
/// Compression applied to every payload before sealing (`Compress::None` disables it).
compress: Compress,
/// Scratch buffer handed to `Compress::compress` for each packet.
compress_buffer: Vec<u8>,
/// Sealed bytes waiting to be flushed, plus the outgoing sequence number.
write_buffer: SSHBuffer,
}
impl Debug for PacketWriter {
    /// Prints only the type name. None of the fields (cipher, compression state,
    /// buffers) are shown, and `finish_non_exhaustive` marks that omission with
    /// `..` instead of implying the struct is empty.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("PacketWriter").finish_non_exhaustive()
    }
}
impl PacketWriter {
    /// Creates a writer that uses the `cipher::clear::Key` cipher and no
    /// compression.
    pub fn clear() -> Self {
        Self::new(Box::new(cipher::clear::Key {}), Compress::None)
    }

    /// Creates a writer with the given sealing cipher and compression, and an
    /// empty write buffer.
    pub fn new(cipher: Box<dyn SealingKey + Send>, compress: Compress) -> Self {
        Self {
            cipher,
            compress,
            compress_buffer: Vec::new(),
            write_buffer: SSHBuffer::new(),
        }
    }

    /// Compresses `buf` and seals it into the write buffer.
    ///
    /// An empty `buf` is ignored. Nothing is sent here; call
    /// [`PacketWriter::flush_into`] to write the buffered bytes out.
    ///
    /// # Errors
    /// Propagates any error returned by the compressor.
    pub fn packet_raw(&mut self, buf: &[u8]) -> Result<(), Error> {
        // The first byte is logged as the message type; an empty payload is skipped.
        if let Some(message_type) = buf.first() {
            debug!("> msg type {message_type:?}, len {}", buf.len());
            let packet = self.compress.compress(buf, &mut self.compress_buffer)?;
            self.cipher.write(packet, &mut self.write_buffer);
        }
        Ok(())
    }

    /// Builds a payload with `f` and buffers it via [`PacketWriter::packet_raw`].
    /// Returns the payload as built by `f`, before compression.
    ///
    /// # Errors
    /// Returns the error from `f`, in which case nothing is buffered, or any
    /// error from compression.
    pub fn packet<F: FnOnce(&mut Vec<u8>) -> Result<(), Error>>(
        &mut self,
        f: F,
    ) -> Result<Vec<u8>, Error> {
        let mut buf = Vec::new();
        f(&mut buf)?;
        self.packet_raw(&buf)?;
        Ok(buf)
    }

    /// Mutable access to the underlying write buffer.
    pub fn buffer(&mut self) -> &mut SSHBuffer {
        &mut self.write_buffer
    }

    /// Mutable access to the compression state.
    pub fn compress(&mut self) -> &mut Compress {
        &mut self.compress
    }

    /// Replaces the sealing cipher used for subsequent packets.
    pub fn set_cipher(&mut self, cipher: Box<dyn SealingKey + Send>) {
        self.cipher = cipher;
    }

    /// Resets the outgoing sequence number to 0.
    pub fn reset_seqn(&mut self) {
        self.write_buffer.seqn = Wrapping(0);
    }

    /// Writes all buffered bytes to `w` and then flushes `w`. Does nothing if the
    /// buffer is empty.
    ///
    /// The buffer is cleared as soon as `write_all` succeeds. A failure (or
    /// cancellation) during the following `flush` therefore cannot cause the
    /// same bytes to be written a second time by a later call.
    ///
    /// # Errors
    /// Returns any I/O error from writing to or flushing `w`. If `write_all`
    /// itself fails, the buffer is left intact, and how much of it reached `w`
    /// is unknown.
    pub async fn flush_into<W: AsyncWrite + Unpin>(&mut self, w: &mut W) -> std::io::Result<()> {
        if self.write_buffer.buffer.is_empty() {
            return Ok(());
        }
        w.write_all(&self.write_buffer.buffer).await?;
        // `w` now holds these bytes; drop our copy before flushing so a flush
        // error followed by a retry cannot duplicate them on the wire.
        self.write_buffer.buffer.clear();
        w.flush().await
    }
}