use std::io::Cursor;
use binrw::{BinRead, BinReaderExt, BinWrite};
use ntlmclient::Flags;
use rasn::{AsnType, Decode, Encode, Decoder, Encoder, oid, der};
use rasn::Codec::Der;
use rasn::types::{Enumerated, ObjectIdentifier, Oid};
use rasn::types::OctetString;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use crate::{command, Header, SessionSetup1Req, SmbOptions};
/// SPNEGO mechanism OID, 1.3.6.1.5.5.2 (RFC 4178); used as the outer OID of `SecBlob1`.
const OID1: &[u32] = &[
    1, 3, 6, 1, 5, 5, 2,
];
/// NTLMSSP mechanism OID, 1.3.6.1.4.1.311.2.2.10; the single entry of the `mechTypes` list.
const OID2: &[u32] = &[
    1, 3, 6, 1, 4, 1, 311, 2, 2, 10,
];
impl SmbOptions<'_> {
    /// Sends the first SMB2 SESSION_SETUP request.
    ///
    /// The request carries an SPNEGO initial token wrapping an NTLMSSP NEGOTIATE
    /// message, built by `GeneraeSecBlob`.
    ///
    /// `n` is the header of the previous message on this connection (presumably the
    /// NEGOTIATE response — confirm at the call site). Its `message_id` and
    /// `session_id` seed the new request header.
    ///
    /// # Errors
    /// Fails if the security blob cannot be built or serialized, if the framed message
    /// does not fit the 4-byte length prefix, or if writing to `stream` fails.
    pub async fn SessionSetupRequset1(&mut self, stream: &mut TcpStream, n: Header) -> anyhow::Result<()> {
        // Build the security blob first so its length can go straight into the request.
        let (sec_blob, sec_blob_len) = self.GeneraeSecBlob()?;

        let mut header = Header::new();
        header.credit_charge = 1;
        header.command = command::CommandSessionSetup as u16;
        // Message IDs form a per-connection sequence, so this request takes the ID after
        // the previous message's. It must not be derived from the session ID.
        header.message_id = n.message_id + 1;
        header.session_id = n.session_id;

        let req = SessionSetup1Req {
            Header: header,
            StructureSize: 25,
            Flags: 0x00,
            SecurityMode: 1,
            Capabilities: 0,
            Channel: 0,
            // 64-byte SMB2 header + 24-byte fixed SESSION_SETUP request body.
            SecurityBufferOffset: 88,
            SecurityBufferLength: sec_blob_len,
            PreviousSessionID: 0,
        };

        let mut cursor = Cursor::new(Vec::<u8>::new());
        req.write_le(&mut cursor)?;
        let mut smb_message = cursor.into_inner();
        smb_message.extend_from_slice(&sec_blob);

        // Session-service framing: 4-byte big-endian length prefix before the SMB2 message.
        let mut frame = u32::try_from(smb_message.len())?.to_be_bytes().to_vec();
        frame.extend_from_slice(&smb_message);
        stream.write_all(&frame).await?;
        stream.flush().await?;
        Ok(())
    }

    /// Reads the first SMB2 SESSION_SETUP response from `stream`.
    ///
    /// Stores the response header in `self.sesionSetup1RespHeader`. Stores the mechanism
    /// token bytes from the SPNEGO response (`SecBlob2::secPro`) in
    /// `self.sessionSetup1RespSecProvider`.
    ///
    /// # Errors
    /// Fails in any of these cases:
    /// - an I/O error occurs;
    /// - the frame is shorter than the SMB2 header plus the fixed response body;
    /// - the security-buffer offset/length points outside the frame;
    /// - DER decoding fails;
    /// - the response token is absent.
    pub async fn SessionSetUpResponse1(&mut self, stream: &mut TcpStream) -> anyhow::Result<()> {
        // Bytes preceding the security buffer: the 64-byte SMB2 header plus the 8-byte
        // fixed response body (StructureSize, SessionFlags, SecurityBufferOffset,
        // SecurityBufferLength, each a little-endian u16).
        const FIXED_LEN: usize = 64 + 8;

        let mut length_header = [0u8; 4];
        stream.read_exact(&mut length_header).await?;
        let resp_length = u32::from_be_bytes(length_header) as usize;
        if resp_length < FIXED_LEN {
            anyhow::bail!(
                "SESSION_SETUP response too short: {} bytes (need at least {})",
                resp_length,
                FIXED_LEN
            );
        }

        let mut header_bytes = [0u8; 64];
        stream.read_exact(&mut header_bytes).await?;
        let mut cur = Cursor::new(header_bytes);
        let session_resp_header: Header = cur.read_le()?;

        let mut fixed = [0u8; 8];
        stream.read_exact(&mut fixed).await?;
        let blob_offset = usize::from(u16::from_le_bytes([fixed[4], fixed[5]]));
        let blob_length = usize::from(u16::from_le_bytes([fixed[6], fixed[7]]));

        // Read the whole remainder of the frame before validating, so the stream stays
        // aligned on the next message even if the blob bounds turn out to be bogus.
        let mut body = vec![0u8; resp_length - FIXED_LEN];
        stream.read_exact(&mut body).await?;

        // The offset counts from the start of the SMB2 header, but `body` starts
        // FIXED_LEN bytes later.
        let start = blob_offset.saturating_sub(FIXED_LEN);
        let sec_blob_data = body.get(start..start + blob_length).ok_or_else(|| {
            anyhow::anyhow!(
                "security buffer (offset {}, length {}) exceeds SESSION_SETUP response of {} bytes",
                blob_offset,
                blob_length,
                resp_length
            )
        })?;

        let sec_blob: SecBlob2 = der::decode(sec_blob_data)?;
        let sec_provider = sec_blob
            .secPro
            .data
            .ok_or_else(|| anyhow::anyhow!("SPNEGO response carries no response token"))?
            .to_vec();

        self.sesionSetup1RespHeader = session_resp_header;
        self.sessionSetup1RespSecProvider = sec_provider;
        Ok(())
    }
}
/// Initial SPNEGO token sent in the first SESSION_SETUP request.
///
/// The `rasn` attribute below encodes it under an `[APPLICATION 0]` tag. The content is
/// the mechanism OID followed by the `[0]`-tagged `NegoInit`. Field order is the DER
/// encoding order — do not reorder or rename fields.
#[derive(AsnType, Decode, Encode)]
#[rasn(tag(application, 0))]
struct SecBlob1 {
/// Outer mechanism OID; `SecBlob1::new` sets it from `OID1`.
Oid: Option<ObjectIdentifier>,
/// Negotiation token body, wrapped in an explicit `[0]` tag.
#[rasn(tag(explicit(0)))]
negoInit: NegoInit
}
impl SecBlob1 {
    /// Builds the initial SPNEGO token skeleton.
    ///
    /// The outer OID is `OID1`. The single offered mechanism is `OID2`. The mechanism
    /// token is left as an empty OCTET STRING, which the caller fills in.
    ///
    /// # Errors
    /// Fails only if `OID1` or `OID2` is not a valid object identifier. This cannot
    /// happen with the current constants, but it is reported as an error rather than
    /// a panic because this function already returns `Result`.
    pub fn new() -> anyhow::Result<Self> {
        let spnego_oid = Oid::new(OID1)
            .ok_or_else(|| anyhow::anyhow!("OID1 is not a valid object identifier"))?;
        let ntlmssp_oid = Oid::new(OID2)
            .ok_or_else(|| anyhow::anyhow!("OID2 is not a valid object identifier"))?;

        Ok(SecBlob1 {
            Oid: Some(ObjectIdentifier::from(spnego_oid)),
            negoInit: NegoInit {
                mechTypes: MechTypes {
                    mechType: Some(ObjectIdentifier::from(ntlmssp_oid)),
                },
                mechTokens: MechTokens {
                    data: Some(OctetString::new()),
                },
            },
        })
    }
}
/// Body of the initial SPNEGO token: the offered mechanism list followed by the first
/// mechanism token.
///
/// Field order is the DER encoding order. NOTE(review): the optional `[1]` and `[3]`
/// fields of an RFC 4178 NegTokenInit are not modeled here.
#[derive(AsnType, Decode, Encode)]
struct NegoInit {
/// Offered mechanisms, explicitly tagged `[0]`.
#[rasn(tag(explicit(0)))]
mechTypes: MechTypes,
/// First mechanism token, tagged `[2]`. `GeneraeSecBlob` fills it with the serialized
/// NTLMSSP message.
#[rasn(tag(context, 2))]
mechTokens: MechTokens,
}
/// Mechanism list modeled as a SEQUENCE holding at most one OID.
///
/// With one OID this yields the same DER bytes as a one-element
/// `SEQUENCE OF OBJECT IDENTIFIER`, but only a single mechanism can be offered.
#[derive(AsnType, Decode, Encode)]
struct MechTypes {
/// The offered mechanism; `SecBlob1::new` sets it from `OID2`.
mechType: Option<ObjectIdentifier>,
}
/// Wrapper SEQUENCE around the mechanism-token bytes.
///
/// NOTE(review): combined with the implicit `[2]` tag applied in `NegoInit`, this appears
/// to produce the same bytes as an explicitly tagged OCTET STRING — confirm if the tagging
/// is changed.
#[derive(AsnType, Decode, Encode)]
struct MechTokens {
/// Raw mechanism token (the serialized NTLMSSP message in the outgoing request).
data: Option<OctetString>
}
/// NTLMSSP NEGOTIATE message as laid out on the wire.
///
/// The fixed fields come first, followed by the domain and workstation payloads. Field
/// order is the byte order — do not reorder. `GeneraeSecBlob` fills the multi-byte
/// fields with little-endian values.
///
/// NOTE(review): when *reading*, binrw takes the two payloads immediately after the fixed
/// fields, sized by the `*Len` counts. The `*Offset` fields are not consulted, so input
/// whose payloads are placed elsewhere will be misparsed.
#[derive(BinRead, BinWrite, Debug)]
struct NtmlSecProvider {
/// Signature; `GeneraeSecBlob` writes `b"NTLMSSP\x00"`.
identifier: [u8; 8],
/// Message type (`GeneraeSecBlob` writes 1).
messageType: [u8; 4],
/// Negotiate flags bitmask.
negoFlags: [u8; 4],
/// Length of `domainName` in bytes.
domainLen: [u8; 2],
/// Max length of the domain field; set equal to `domainLen` by `GeneraeSecBlob`.
domainMaxLen: [u8; 2],
/// Offset of `domainName` from the start of this message.
domainOffset: [u8; 4],
/// Length of `workstationName` in bytes.
workstationLen: [u8; 2],
/// Max length of the workstation field; set equal to `workstationLen` by `GeneraeSecBlob`.
workstationMaxLen: [u8; 2],
/// Offset of `workstationName` from the start of this message.
workstationOffset: [u8; 4],
/// Domain name payload; its length comes from `domainLen` when reading.
#[br(count = u16::from_le_bytes(domainLen))]
domainName: Vec<u8>,
/// Workstation name payload; its length comes from `workstationLen` when reading.
#[br(count = u16::from_le_bytes(workstationLen))]
workstationName: Vec<u8>,
}
impl SmbOptions<'_> {
    /// Builds the SPNEGO security blob for the first SESSION_SETUP request.
    ///
    /// The blob is a DER-encoded `SecBlob1` whose mechanism token is an NTLMSSP message of
    /// type 1 (NEGOTIATE). That message carries the bytes of `self.Domain` and
    /// `self.Workstation`, with the domain payload placed first.
    ///
    /// Returns the encoded blob and its length as a `u16`, ready for the request's
    /// `SecurityBufferLength` field.
    ///
    /// # Errors
    /// Fails in either of these cases:
    /// - the domain name, workstation name or encoded blob exceeds `u16::MAX` bytes
    ///   (previously these lengths were silently truncated);
    /// - NTLMSSP serialization or DER encoding fails.
    fn GeneraeSecBlob(&mut self) -> anyhow::Result<(Vec<u8>, u16)> {
        // Fixed NEGOTIATE prefix: signature (8) + message type (4) + flags (4)
        // + domain len/max/offset (2+2+4) + workstation len/max/offset (2+2+4).
        // No Version field is written, so the payload starts right after this prefix.
        const FIXED_LEN: u32 = 8 + 4 + 4 + 2 + 2 + 4 + 2 + 2 + 4;

        let flags = Flags::NEGOTIATE_56BIT
            | Flags::NEGOTIATE_128BIT
            | Flags::NEGOTIATE_TARGET_INFO
            | Flags::NEGOTIATE_NTLM2_KEY
            | Flags::NEGOTIATE_DOMAIN_SUPPLIED
            | Flags::NEGOTIATE_NTLM
            | Flags::REQUEST_TARGET
            | Flags::NEGOTIATE_UNICODE;

        let domain_name = self.Domain.as_bytes().to_vec();
        let workstation = self.Workstation.to_string().into_bytes();
        let domain_len = u16::try_from(domain_name.len())
            .map_err(|_| anyhow::anyhow!("domain name is too long for an NTLM NEGOTIATE message"))?;
        let workstation_len = u16::try_from(workstation.len())
            .map_err(|_| anyhow::anyhow!("workstation name is too long for an NTLM NEGOTIATE message"))?;

        // Payload order matches the struct's field order: domain first, then workstation.
        let domain_offset = FIXED_LEN;
        let workstation_offset = domain_offset + u32::from(domain_len);

        let provider = NtmlSecProvider {
            identifier: *b"NTLMSSP\x00",
            messageType: 1u32.to_le_bytes(),
            negoFlags: flags.bits().to_le_bytes(),
            domainLen: domain_len.to_le_bytes(),
            domainMaxLen: domain_len.to_le_bytes(),
            domainOffset: domain_offset.to_le_bytes(),
            workstationLen: workstation_len.to_le_bytes(),
            workstationMaxLen: workstation_len.to_le_bytes(),
            workstationOffset: workstation_offset.to_le_bytes(),
            domainName: domain_name,
            workstationName: workstation,
        };

        let mut cursor = Cursor::new(Vec::<u8>::new());
        provider.write_le(&mut cursor)?;

        let mut blob = SecBlob1::new()?;
        blob.negoInit.mechTokens.data = Some(OctetString::from(cursor.into_inner()));

        let data = der::encode(&blob)?;
        let len = u16::try_from(data.len())
            .map_err(|_| anyhow::anyhow!("SPNEGO security blob exceeds u16::MAX bytes"))?;
        Ok((data, len))
    }
}
/// `negState` value of an SPNEGO response token (RFC 4178, NegTokenResp).
///
/// The Rust variant names are kept for compatibility; each variant's RFC name is given
/// below. `RequestMic` was missing, so a response carrying value 3 failed to decode.
#[derive(AsnType, Debug, PartialEq, Encode, Decode, Copy, Clone)]
#[rasn(enumerated)]
enum ErrCode {
    /// RFC 4178 `accept-completed`: negotiation finished successfully.
    InitialRequest = 0,
    /// RFC 4178 `accept-incomplete`: more tokens are needed, e.g. while the NTLM
    /// challenge/response exchange is still in progress.
    NtlmChallenge = 1,
    /// RFC 4178 `reject`: the acceptor rejected the offered mechanism(s).
    KerberosAuth = 2,
    /// RFC 4178 `request-mic`: the acceptor asks for a `mechListMIC`.
    RequestMic = 3,
}
/// SPNEGO response token decoded from the SESSION_SETUP response.
///
/// It is wrapped in an explicit `[1]` tag. Each field is a one-field SEQUENCE under an
/// implicit context tag, which appears to yield the same bytes as explicitly tagging the
/// inner value.
///
/// NOTE(review): all three fields are required here. RFC 4178 makes them optional and
/// also defines a `[3] mechListMIC`, so a response that omits a field or carries a MIC
/// will likely fail to decode — confirm against the servers in use.
#[derive(AsnType, Encode, Decode, Debug, PartialEq)]
#[rasn(tag(explicit(1)))]
struct SecBlob2 {
/// Negotiation state, tagged `[0]`.
#[rasn(tag(context, 0))]
negReuslt: NegResult,
/// Mechanism chosen by the server, tagged `[1]`.
#[rasn(tag(context, 1))]
supportMech: SuportMech,
/// Response token, tagged `[2]`. `SessionSetUpResponse1` extracts its bytes.
#[rasn(tag(context, 2))]
pub secPro: SecProvider,
}
/// One-field SEQUENCE wrapper around the negotiation-state enumeration.
#[derive(AsnType, Encode, Decode, Debug, PartialEq)]
struct NegResult {
/// Negotiation state reported by the server.
result: ErrCode,
}
/// One-field SEQUENCE wrapper around the mechanism OID selected by the server.
#[derive(AsnType, Encode, Decode, Debug, PartialEq)]
struct SuportMech {
/// Selected mechanism OID, if present.
pub oid: Option<ObjectIdentifier>,
}
/// One-field SEQUENCE wrapper around the server's mechanism-token bytes.
#[derive(AsnType, Encode, Decode, Debug, PartialEq)]
struct SecProvider {
/// Raw mechanism token. `SessionSetUpResponse1` copies it into
/// `sessionSetup1RespSecProvider`.
pub data: Option<OctetString>,
}