#include "protoPktTCP.h"
// Construct a TCP packet view over an optional caller-supplied buffer.
//   bufferPtr      - storage to wrap (may be NULL, in which case nothing is initialized)
//   numBytes       - size of that storage in bytes
//   initFromBuffer - true: parse an existing TCP segment already in the buffer;
//                    false: write a fresh header (data offset, flags, checksum reset)
//   freeOnDestruct - forwarded to the ProtoPkt base class
// The bool result of the Init*() call is not reported from a constructor; note
// that a failed InitFromBuffer() sets the packet length to 0, which callers can inspect.
ProtoPktTCP::ProtoPktTCP(void*         bufferPtr,
                         unsigned int  numBytes,
                         bool          initFromBuffer,
                         bool          freeOnDestruct)
    : ProtoPkt(bufferPtr, numBytes, freeOnDestruct)
{
    // Without storage there is nothing to parse or build yet.
    if (NULL == bufferPtr)
        return;

    // The no-argument Init*() calls operate on the buffer the base class just attached.
    if (!initFromBuffer)
    {
        InitIntoBuffer();
        return;
    }
    InitFromBuffer();
}
// Nothing TCP-specific to release. Buffer ownership (freeOnDestruct) was handed
// to the ProtoPkt base class at construction, which presumably manages it.
ProtoPktTCP::~ProtoPktTCP() {}
// Locate the TCP segment carried inside an IP packet and attach to it (without
// taking ownership of the memory).
//   ipPkt - an IPv4 or IPv6 packet
// Returns true if a TCP header was found and InitFromBuffer() accepted it.
// Returns false for non-TCP payloads, for an IPv6 extension chain that never
// reaches TCP or overruns the payload, and for unsupported IP versions (logged).
bool ProtoPktTCP::InitFromPacket(ProtoPktIP& ipPkt)
{
    switch (ipPkt.GetVersion())
    {
        case 4:
        {
            ProtoPktIPv4 ip4Pkt(ipPkt);
            if (ProtoPktIP::TCP != ip4Pkt.GetProtocol())
                return false;  // not a TCP datagram
            return InitFromBuffer(ip4Pkt.AccessPayload(), ip4Pkt.GetPayloadLength(), false);
        }
        case 6:
        {
            ProtoPktIPv6 ip6Pkt(ipPkt);
            if (!ip6Pkt.HasExtendedHeader())
            {
                // TCP (if present) immediately follows the fixed IPv6 header.
                if (ProtoPktIP::TCP != ip6Pkt.GetNextHeader())
                    return false;
                return InitFromBuffer(ip6Pkt.AccessPayload(), ip6Pkt.GetPayloadLength(), false);
            }
            // Walk the extension-header chain, accumulating its length, until an
            // extension announces TCP as its next header.
            unsigned int payloadLength = ip6Pkt.GetPayloadLength();
            unsigned int extHeaderLength = 0;
            ProtoPktIPv6::Extension::Iterator extIterator(ip6Pkt);
            ProtoPktIPv6::Extension ext;
            while (extIterator.GetNextExtension(ext))
            {
                extHeaderLength += ext.GetLength();
                // Guard the unsigned subtraction below: an extension chain longer
                // than the payload would otherwise wrap to a huge TCP length.
                if (extHeaderLength > payloadLength)
                    return false;
                if (ProtoPktIP::TCP == ext.GetNextHeader())
                {
                    void* tcpBuffer = (char*)ip6Pkt.AccessPayload() + extHeaderLength;
                    unsigned int tcpLength = payloadLength - extHeaderLength;
                    return InitFromBuffer(tcpBuffer, tcpLength, false);
                }
            }
            return false;  // chain ended without reaching a TCP header
        }
        default:
            PLOG(PL_ERROR, "ProtoPktTCP::InitFromPacket() error: bad IP packet version: %d\n", ipPkt.GetVersion());
            return false;
    }
}
bool ProtoPktTCP::InitFromBuffer(void* bufferPtr,
unsigned int numBytes,
bool freeOnDestruct)
{
if (NULL != bufferPtr)
AttachBuffer(bufferPtr, numBytes, freeOnDestruct);
UINT16 totalLen = GetPayloadLength() + (OffsetPayload() << 2);
if (totalLen > GetBufferLength())
{
ProtoPkt::SetLength(0);
if (NULL != bufferPtr) DetachBuffer();
return false;
}
else
{
ProtoPkt::SetLength(totalLen);
return true;
}
}
// Prepare a buffer to hold a new TCP segment with a minimal header.
//   bufferPtr      - if non-NULL, attached (after a size check); if NULL, the
//                    currently attached buffer is reused
//   numBytes       - size of bufferPtr in bytes
//   freeOnDestruct - ownership flag passed to AttachBuffer()
// Returns false, without modifying anything, if the storage cannot hold a
// 20-byte header. Otherwise it writes data offset 5 (5 x 32-bit words = 20
// bytes, no options), clears the flags and zeroes the checksum.
bool ProtoPktTCP::InitIntoBuffer(void*         bufferPtr,
                                 unsigned int  numBytes,
                                 bool          freeOnDestruct)
{
    const int minHeaderLen = 20;

    if (NULL != bufferPtr)
    {
        // Refuse to attach a caller buffer that is too small.
        if (numBytes < minHeaderLen)
            return false;
        AttachBuffer(bufferPtr, numBytes, freeOnDestruct);
    }

    // Re-check against whatever buffer is now attached (covers the NULL case).
    if (GetBufferLength() < minHeaderLen)
        return false;

    SetDataOffset(5);
    ClearFlags();
    SetChecksum(0);
    return true;
}
// Compute the TCP checksum over the IP pseudo-header plus this segment
// (header and payload), treating the checksum field itself as zero.
//   ipPkt - the IPv4 or IPv6 packet carrying this segment; its addresses feed
//           the pseudo-header
// Returns the 16-bit one's-complement checksum, or 0 for an unsupported IP
// version. The computation uses this segment's GetLength() as the TCP length.
UINT16 ProtoPktTCP::ComputeChecksum(ProtoPktIP& ipPkt) const
{
    // 16-bit one's-complement sum, accumulated in 32 bits and folded at the end.
    UINT32 sum = 0;

    // 1) Pseudo-header
    switch (ipPkt.GetVersion())
    {
        case 4:
        {
            ProtoPktIPv4 ipv4Pkt(ipPkt);
            // Source and destination addresses are adjacent in the IP header, so
            // (assuming ADDR_LEN is the address size in bytes) ADDR_LEN 16-bit
            // words starting at the source address cover both of them.
            const UINT16* ptr = (const UINT16*)ipv4Pkt.GetSrcAddrPtr();
            int addrEndex = ProtoPktIPv4::ADDR_LEN;
            for (int i = 0; i < addrEndex; i++)
                sum += GetUINT16(ptr++);
            sum += (UINT16)ipv4Pkt.GetProtocol();  // zero byte + protocol
            sum += (UINT16)GetLength();            // TCP length
            break;
        }
        case 6:
        {
            ProtoPktIPv6 ipv6Pkt(ipPkt);
            const UINT16* ptr = (const UINT16*)ipv6Pkt.GetSrcAddrPtr();
            int addrEndex = ProtoPktIPv6::ADDR_LEN;
            for (int i = 0; i < addrEndex; i++)
                sum += GetUINT16(ptr++);
            sum += (UINT16)GetLength();  // upper-layer packet length
            // RFC 8200 section 8.1: the pseudo-header Next Header is the upper-layer
            // protocol (TCP). It is NOT ipv6Pkt.GetNextHeader(), which names the
            // first extension header when extension headers are present.
            sum += (UINT16)ProtoPktIP::TCP;
            break;
        }
        default:
            return 0;
    }

    // 2) TCP header words before the checksum field (OFFSET_CHECKSUM is used here
    //    as a 16-bit word index); the checksum word itself is skipped.
    unsigned int i;
    for (i = 0; i < OFFSET_CHECKSUM; i++)
        sum += GetWord16(i);

    // 3) Remaining header words and payload. An odd trailing byte is summed as
    //    if padded with a zero low-order byte.
    unsigned int dataEndex = GetLength();
    if (0 != (dataEndex & 0x01))
        sum += ((UINT16)GetUINT8(dataEndex-1) << 8);
    dataEndex >>= 1;  // byte count -> 16-bit word count
    for (i = (OFFSET_CHECKSUM+1); i < dataEndex; i++)
        sum += GetWord16(i);

    // Fold carries back into the low 16 bits.
    while (0 != (sum >> 16))
        sum = (sum & 0x0000ffff) + (sum >> 16);

    // Complement only the low 16 bits. A zero result is returned as 0: unlike UDP
    // (RFC 768), TCP has no rule substituting 0xffff for a computed zero checksum.
    return (UINT16)~sum;
}