1use std::io::{Read, Write};
2use std::time::Duration;
3use tracing::*;
4
5use crate::{
6 constant::{self, CLIENT_VERSION, SSH_MAGIC},
7 error::{SshError, SshResult},
8 model::Timeout,
9};
10
/// The pair of SSH identification strings involved in version negotiation.
///
/// `client_ver` is the string this side announces; `server_ver` holds the
/// string received from the peer (empty until it has been read).
#[derive(Clone, Debug)]
pub(crate) struct SshVersion {
    /// Our identification string; the CR LF terminator is appended only when
    /// it is written to the wire.
    pub client_ver: String,
    /// The server's identification string as stored after reading (trimmed).
    pub server_ver: String,
}
16
17impl Default for SshVersion {
18 fn default() -> Self {
19 Self {
20 client_ver: CLIENT_VERSION.to_owned(),
21 server_ver: String::new(),
22 }
23 }
24}
25
26fn read_version<S>(stream: &mut S, tm: Option<Duration>) -> SshResult<Vec<u8>>
41where
42 S: Read,
43{
44 let mut ch = vec![0; 1];
45 const LF: u8 = 0xa;
46 let crlf = vec![0xd, 0xa];
47 let mut outbuf = vec![];
48 let mut timeout = Timeout::new(tm);
49 loop {
50 match stream.read(&mut ch) {
51 Ok(i) => {
52 if 0 == i {
53 return Ok(outbuf);
55 }
56
57 outbuf.extend_from_slice(&ch);
58
59 if LF == ch[0] && outbuf.len() > 1 && outbuf.ends_with(&crlf) {
60 if outbuf.len() < 4 || &outbuf[0..4] != SSH_MAGIC {
67 outbuf.clear();
70 continue;
71 }
72 return Ok(outbuf);
73 }
74 timeout.renew();
75 }
76 Err(e) => {
77 if let std::io::ErrorKind::WouldBlock = e.kind() {
78 timeout.till_next_tick()?;
79 continue;
80 } else {
81 return Err(e.into());
82 }
83 }
84 };
85 }
86}
87
88impl SshVersion {
89 pub fn read_server_version<S>(
90 &mut self,
91 stream: &mut S,
92 timeout: Option<Duration>,
93 ) -> SshResult<()>
94 where
95 S: Read,
96 {
97 let buf = read_version(stream, timeout)?;
98 if buf.len() < 4 || &buf[0..4] != SSH_MAGIC {
99 error!("SSH version magic doesn't match");
100 error!("Probably not an ssh server");
101 }
102 let from_utf8 = String::from_utf8(buf)?;
103 let version_str = from_utf8.trim();
104 info!("server version: [{}]", version_str);
105
106 self.server_ver = version_str.to_owned();
107 Ok(())
108 }
109
110 pub fn send_our_version<S>(&self, stream: &mut S) -> SshResult<()>
111 where
112 S: Write,
113 {
114 info!("client version: [{}]", self.client_ver);
115 let ver_string = format!("{}\r\n", self.client_ver);
116 let _ = stream.write(ver_string.as_bytes())?;
117 Ok(())
118 }
119
120 pub fn validate(&self) -> SshResult<()> {
121 if self.server_ver.contains("SSH-2.0") {
122 Ok(())
123 } else {
124 error!("error in version negotiation, version mismatch.");
125 Err(SshError::VersionDismatchError {
126 our: constant::CLIENT_VERSION.to_owned(),
127 their: self.server_ver.clone(),
128 })
129 }
130 }
131}