1use nla::asn1::{ASN1, Sequence, ExplicitTag, SequenceOf, ASN1Type, OctetString, Integer, to_der};
2use model::error::{RdpError, RdpErrorKind, Error, RdpResult};
3use num_bigint::{BigUint};
4use yasna::Tag;
5use x509_parser::{parse_x509_der, X509Certificate};
6use nla::sspi::AuthenticationProtocol;
7use model::link::Link;
8use std::io::{Read, Write};
9
10pub fn create_ts_request(nego: Vec<u8>) -> Vec<u8> {
22    let ts_request = sequence![
23        "version" => ExplicitTag::new(Tag::context(0), 2 as Integer),
24        "negoTokens" => ExplicitTag::new(Tag::context(1),
25            sequence_of![
26                sequence![
27                    "negoToken" => ExplicitTag::new(Tag::context(0), nego)
28                ]
29            ])
30    ];
31    to_der(&ts_request)
32}
33
34pub fn read_ts_server_challenge(stream: &[u8]) -> RdpResult<Vec<u8>> {
51    let mut ts_request = sequence![
52        "version" => ExplicitTag::new(Tag::context(0), 2 as Integer),
53        "negoTokens" => ExplicitTag::new(Tag::context(1),
54            SequenceOf::reader(|| {
55                Box::new(sequence![
56                    "negoToken" => ExplicitTag::new(Tag::context(0), OctetString::new())
57                ])
58            })
59         )
60    ];
61
62    yasna::parse_der(stream, |reader| {
63        if let Err(Error::ASN1Error(e)) = ts_request.read_asn1(reader) {
64            return Err(e)
65        }
66        Ok(())
67    })?;
68
69    let nego_tokens = cast!(ASN1Type::SequenceOf, ts_request["negoTokens"]).unwrap();
70    let first_nego_tokens = cast!(ASN1Type::Sequence, nego_tokens.inner[0]).unwrap();
71    let nego_token = cast!(ASN1Type::OctetString, first_nego_tokens["negoToken"]).unwrap();
72    Ok(nego_token.to_vec())
73}
74
75pub fn create_ts_authenticate(nego: Vec<u8>, pub_key_auth: Vec<u8>) -> Vec<u8> {
88    let ts_challenge = sequence![
89        "version" => ExplicitTag::new(Tag::context(0), 2 as Integer),
90        "negoTokens" => ExplicitTag::new(Tag::context(1),
91            sequence_of![
92                sequence![
93                    "negoToken" => ExplicitTag::new(Tag::context(0), nego as OctetString)
94                ]
95            ]),
96        "pubKeyAuth" => ExplicitTag::new(Tag::context(3), pub_key_auth as OctetString)
97    ];
98
99    to_der(&ts_challenge)
100}
101
102pub fn read_public_certificate(stream: &[u8]) -> RdpResult<X509Certificate> {
103    let res = parse_x509_der(stream).unwrap();
104    Ok(res.1)
105}
106
107pub fn read_ts_validate(request: &[u8]) -> RdpResult<Vec<u8>> {
120    let mut ts_challenge = sequence![
121        "version" => ExplicitTag::new(Tag::context(0), 2 as Integer),
122        "pubKeyAuth" => ExplicitTag::new(Tag::context(3), OctetString::new())
123    ];
124
125    yasna::parse_der(request, |reader| {
126        if let Err(Error::ASN1Error(e)) = ts_challenge.read_asn1(reader) {
127            return Err(e)
128        }
129        Ok(())
130    })?;
131    let pubkey = cast!(ASN1Type::OctetString, ts_challenge["pubKeyAuth"])?;
132    Ok(pubkey.to_vec())
133}
134
135fn create_ts_credentials(domain: Vec<u8>, user: Vec<u8>, password: Vec<u8>) -> Vec<u8> {
136    let ts_password_creds = sequence![
137        "domainName" => ExplicitTag::new(Tag::context(0), domain as OctetString),
138        "userName" => ExplicitTag::new(Tag::context(1), user as OctetString),
139        "password" => ExplicitTag::new(Tag::context(2), password as OctetString)
140    ];
141
142    let ts_password_cred_encoded = yasna::construct_der(|writer| {
143        ts_password_creds.write_asn1(writer).unwrap();
144    });
145
146    let ts_credentials = sequence![
147        "credType" => ExplicitTag::new(Tag::context(0), 1 as Integer),
148        "credentials" => ExplicitTag::new(Tag::context(1), ts_password_cred_encoded as OctetString)
149    ];
150
151    to_der(&ts_credentials)
152}
153
154fn create_ts_authinfo(auth_info: Vec<u8>) -> Vec<u8> {
155    let ts_authinfo = sequence![
156        "version" => ExplicitTag::new(Tag::context(0), 2 as Integer),
157        "authInfo" => ExplicitTag::new(Tag::context(2), auth_info)
158    ];
159
160    to_der(&ts_authinfo)
161}
162
163pub fn cssp_connect<S: Read + Write>(link: &mut Link<S>, authentication_protocol: &mut dyn AuthenticationProtocol, restricted_admin_mode: bool) -> RdpResult<()> {
167    let negotiate_message = create_ts_request(authentication_protocol.create_negotiate_message()?);
169    link.write(&negotiate_message)?;
170
171    let server_challenge = read_ts_server_challenge(&(link.read(0)?))?;
173
174    let client_challenge = authentication_protocol.read_challenge_message(&server_challenge)?;
176
177    let mut security_interface = authentication_protocol.build_security_interface();
179
180    let certificate_der = try_option!(link.get_peer_certificate()?, "No public certificate available")?.to_der()?;
182    let certificate = read_public_certificate(&certificate_der)?;
183
184    let challenge = create_ts_authenticate(client_challenge, security_interface.gss_wrapex(certificate.tbs_certificate.subject_pki.subject_public_key.data)?);
186    link.write(&challenge)?;
187
188    let inc_pub_key = security_interface.gss_unwrapex(&(read_ts_validate(&(link.read(0)?))?))?;
190
191    if BigUint::from_bytes_le(&inc_pub_key) != BigUint::from_bytes_le(certificate.tbs_certificate.subject_pki.subject_public_key.data) + BigUint::new(vec![1]) {
193        return Err(Error::RdpError(RdpError::new(RdpErrorKind::PossibleMITM, "Man in the middle detected")))
194    }
195
196    let domain = if restricted_admin_mode { vec![] } else { authentication_protocol.get_domain_name()};
199    let user = if restricted_admin_mode { vec![] } else { authentication_protocol.get_user_name() };
200    let password = if restricted_admin_mode { vec![] } else { authentication_protocol.get_password() };
201
202    let credentials = create_ts_authinfo(security_interface.gss_wrapex(&create_ts_credentials(domain, user, password))?);
203    link.write(&credentials)?;
204
205    Ok(())
206}
207
#[cfg(test)]
mod test {
    use super::*;

    /// Checks the DER layout produced for a domain/user/password triple
    #[test]
    fn test_create_ts_credentials() {
        let expected = [48, 41, 160, 3, 2, 1, 1, 161, 34, 4, 32, 48, 30, 160, 8, 4, 6, 100, 111, 109, 97, 105, 110, 161, 6, 4, 4, 117, 115, 101, 114, 162, 10, 4, 8, 112, 97, 115, 115, 119, 111, 114, 100];
        let encoded = create_ts_credentials(b"domain".to_vec(), b"user".to_vec(), b"password".to_vec());
        // NOTE(review): index 32 is covered by neither range below — confirm
        // whether skipping that byte is intentional.
        assert_eq!(encoded[0..32], expected[0..32]);
        assert_eq!(encoded[33..43], expected[33..43]);
    }

    /// Checks the DER layout produced for an `authInfo` payload
    #[test]
    fn test_create_ts_authinfo() {
        let expected: Vec<u8> = vec![48, 12, 160, 3, 2, 1, 2, 162, 5, 4, 3, 102, 111, 111];
        assert_eq!(create_ts_authinfo(b"foo".to_vec()), expected)
    }
}