#define __CASCLIB_SELF__
#include "../CascLib.h"
#include "../CascCommon.h"
#ifdef CASCLIB_PLATFORM_WINDOWS
#include <ws2tcpip.h>
#endif
// Initial size (in bytes) of the response buffer allocated by CASC_SOCKET::ReadResponse();
// the same value is used as the default growth step of that buffer.
#define BUFFER_INITIAL_SIZE 0x8000
// Provide INVALID_SOCKET only if no previously included header has defined it
#ifndef INVALID_SOCKET
#define INVALID_SOCKET (SOCKET)(-1)
#endif
// Global socket cache instance, used by sockets_connect() and sockets_set_caching()
CASC_SOCKET_CACHE SocketCache;
// Converts a HANDLE value back to the SOCKET value it carries.
// The conversion goes through intptr_t, mirroring SocketToHandle().
static inline SOCKET HandleToSocket(HANDLE sock)
{
    intptr_t rawValue = (intptr_t)(sock);

    return (SOCKET)(rawValue);
}
// Packs a SOCKET value into a HANDLE.
// The conversion goes through intptr_t, mirroring HandleToSocket().
static inline HANDLE SocketToHandle(SOCKET sock)
{
    intptr_t rawValue = (intptr_t)(sock);

    return (HANDLE)(rawValue);
}
char * CASC_SOCKET::ReadResponse(const char * request, size_t request_length, CASC_MIME_RESPONSE & MimeResponse)
{
char * new_server_response = NULL;
char * server_response = NULL;
size_t total_received = 0;
size_t buffer_length = BUFFER_INITIAL_SIZE;
size_t buffer_delta = BUFFER_INITIAL_SIZE;
DWORD dwErrCode = ERROR_SUCCESS;
int bytes_received = 0;
if(request_length == 0)
request_length = strlen(request);
CascLock(Lock);
while(send(HandleToSocket(sock), request, (int)request_length, MSG_NOSIGNAL) == SOCKET_ERROR)
{
if(ReconnectAfterShutdown(sock, remoteItem) == SocketToHandle(INVALID_SOCKET))
{
SetCascError(ERROR_NETWORK_NOT_AVAILABLE);
CascUnlock(Lock);
return NULL;
}
}
if((server_response = CASC_ALLOC_ZERO<char>(buffer_length + 1)) != NULL)
{
for(;;)
{
if(total_received == buffer_length)
{
if((new_server_response = CASC_REALLOC(server_response, buffer_length + buffer_delta + 1)) == NULL)
{
dwErrCode = ERROR_NOT_ENOUGH_MEMORY;
CASC_FREE(server_response);
break;
}
server_response = new_server_response;
buffer_length += buffer_delta;
buffer_delta = BUFFER_INITIAL_SIZE;
}
bytes_received = recv(HandleToSocket(sock), server_response + total_received, (int)(buffer_length - total_received), 0);
if(bytes_received <= 0)
{
MimeResponse.ParseResponse(server_response, total_received, true);
break;
}
if((total_received + bytes_received) < total_received)
{
dwErrCode = ERROR_NOT_ENOUGH_MEMORY;
break;
}
total_received += bytes_received;
server_response[total_received] = 0;
if(MimeResponse.ParseResponse(server_response, total_received, false))
break;
if(MimeResponse.clength_presence == FieldPresencePresent && MimeResponse.content_length != CASC_INVALID_SIZE_T)
{
size_t content_end = MimeResponse.content_offset + MimeResponse.content_length + 2;
if(content_end > CASC_MAX_ONLINE_FILE_SIZE)
{
dwErrCode = ERROR_NOT_ENOUGH_MEMORY;
break;
}
if(content_end > buffer_length)
{
buffer_delta = content_end - buffer_length;
}
}
}
}
CascUnlock(Lock);
if(dwErrCode != ERROR_SUCCESS)
{
CASC_FREE(server_response);
SetCascError(dwErrCode);
total_received = 0;
}
return server_response;
}
// Adds one reference to the socket using CascInterlockedIncrement.
// Returns the value produced by the increment.
DWORD CASC_SOCKET::AddRef()
{
    DWORD dwNewRefCount = CascInterlockedIncrement(&dwRefCount);

    return dwNewRefCount;
}
// Drops one reference to the socket. When the decremented counter
// reaches zero, the socket is destroyed by calling Delete().
void CASC_SOCKET::Release()
{
    // Somebody else still holds a reference
    if(CascInterlockedDecrement(&dwRefCount) != 0)
        return;

    Delete();
}
// Returns the error code of the last socket operation:
// WSAGetLastError() on Windows builds, errno otherwise.
int CASC_SOCKET::GetSockError()
{
    int nSockError;

#ifdef CASCLIB_PLATFORM_WINDOWS
    nSockError = WSAGetLastError();
#else
    nSockError = errno;
#endif

    return nSockError;
}
// Resolves a host name and port number with getaddrinfo().
//
//  hostName   - Name of the remote host
//  portNum    - Port number; it is converted to its decimal string form for getaddrinfo
//  hints      - Hints passed to getaddrinfo
//  ppResult   - Receives the list of resolved addresses
//
// Returns the result of getaddrinfo (zero on success). Two results are handled here:
//  - WSANOTINITIALISED (Windows builds): WSAStartup is called once and the query is repeated
//  - EAI_AGAIN: the query is repeated, but only a limited number of times, so that
//    a name server that stays unreachable cannot keep the caller in this loop forever
DWORD CASC_SOCKET::GetAddrInfoWrapper(const char * hostName, unsigned portNum, PADDRINFO hints, PADDRINFO * ppResult)
{
    static const unsigned nMaxTransientRetries = 8;
    unsigned nTransientRetries = 0;
    char portNumString[16];
#ifdef CASCLIB_PLATFORM_WINDOWS
    bool bStartupDone = false;
#endif

    // The port number is unsigned, so format it as such
    CascStrPrintf(portNumString, _countof(portNumString), "%u", portNum);

    for(;;)
    {
        DWORD dwErrCode = getaddrinfo(hostName, portNumString, hints, ppResult);

        switch(dwErrCode)
        {
#ifdef CASCLIB_PLATFORM_WINDOWS
            case WSANOTINITIALISED:
            {
                WSADATA wsd;

                // If the sockets library cannot be started (or starting it did not help),
                // repeating the query would fail the same way forever
                if(bStartupDone || WSAStartup(MAKEWORD(2, 2), &wsd) != 0)
                    return dwErrCode;
                bStartupDone = true;
                continue;
            }
#endif
            case (DWORD)EAI_AGAIN:
            {
                // Temporary failure: retry, but give up after a bounded number of attempts
                if(++nTransientRetries > nMaxTransientRetries)
                    return dwErrCode;
                continue;
            }

            default:
                // Any other result, including success
                return dwErrCode;
        }
    }
}
// Creates a socket for the given address entry and connects it to the remote host.
//
// Returns the connected socket as HANDLE. If connect() fails, the socket is closed
// and the INVALID_SOCKET value is returned. If socket() does not return a value
// greater than zero, that value is returned unchanged.
HANDLE CASC_SOCKET::CreateAndConnect(PADDRINFO remoteItem)
{
    SOCKET newSocket = socket(remoteItem->ai_family, remoteItem->ai_socktype, remoteItem->ai_protocol);

    // Socket creation did not produce a usable value
    if(!(newSocket > 0))
        return SocketToHandle(newSocket);

    // Connected: hand the socket over to the caller
    if(connect(newSocket, remoteItem->ai_addr, (int)remoteItem->ai_addrlen) == 0)
        return SocketToHandle(newSocket);

    // Connect failed: do not leave the socket open
    closesocket(newSocket);
    newSocket = INVALID_SOCKET;
    return SocketToHandle(newSocket);
}
// Tries to re-create the connection after a failed socket operation.
//
//  sock        - [in/out] The current socket handle. Replaced by the result of CreateAndConnect
//  remoteItem  - Address entry to connect to
//
// A reconnect is only attempted when the last socket error is EPIPE or WSAECONNRESET.
// In that case, the old socket is closed (unless it is the invalid value) and the
// result of CreateAndConnect is both stored into 'sock' and returned.
// For any other error, 'sock' is left untouched and the INVALID_SOCKET value is returned.
HANDLE CASC_SOCKET::ReconnectAfterShutdown(HANDLE & sock, PADDRINFO remoteItem)
{
    int nSockError = GetSockError();

    // Only a connection shut down by the other side is worth reconnecting
    if(nSockError != EPIPE && nSockError != WSAECONNRESET)
        return SocketToHandle(INVALID_SOCKET);

    // Get rid of the old socket first
    if(sock != SocketToHandle(INVALID_SOCKET))
        closesocket(HandleToSocket(sock));

    sock = CreateAndConnect(remoteItem);
    return sock;
}
// Allocates and initializes a new CASC_SOCKET object.
//
//  remoteList  - Address list stored into the object
//  remoteItem  - Address entry stored into the object (ReadResponse reconnects to it)
//  hostName    - Host name; copied into the object
//  portNum     - Port number stored into the object
//  sock        - Socket handle stored into the object
//
// The object is allocated with strlen(hostName) extra bytes, zero-filled, and starts
// with a reference count of 1. Returns NULL if the allocation fails.
// NOTE(review): presumably 'hostName' is a trailing array member that the extra bytes extend - confirm in the header.
PCASC_SOCKET CASC_SOCKET::New(PADDRINFO remoteList, PADDRINFO remoteItem, const char * hostName, unsigned portNum, HANDLE sock)
{
    size_t nameLength = strlen(hostName);
    size_t totalSize = sizeof(CASC_SOCKET) + nameLength;
    PCASC_SOCKET pNewSocket = (PCASC_SOCKET)CASC_ALLOC<BYTE>(totalSize);

    if(pNewSocket == NULL)
        return NULL;

    // Start from an all-zero object
    memset(pNewSocket, 0, totalSize);

    // Remote address and connection
    pNewSocket->remoteList = remoteList;
    pNewSocket->remoteItem = remoteItem;
    pNewSocket->sock = sock;

    // Identification of the remote host
    pNewSocket->portNum = portNum;
    CascStrCopy((char *)pNewSocket->hostName, nameLength + 1, hostName);

    // The creator holds the first reference
    pNewSocket->dwRefCount = 1;
    CascInitLock(pNewSocket->Lock);

    return pNewSocket;
}
// Resolves the host name and connects to the first address that accepts the connection.
//
//  hostName  - Name of the remote host
//  portNum   - Port number to connect to
//
// Returns a new CASC_SOCKET object (created by CASC_SOCKET::New, which stores the
// resolved address list in it), or NULL on failure. On failure, the error code is
// stored with SetCascError(): either the result of the name resolution,
// or ERROR_NETWORK_NOT_AVAILABLE if none of the resolved addresses could be used.
PCASC_SOCKET CASC_SOCKET::Connect(const char * hostName, unsigned portNum)
{
    PCASC_SOCKET pSocket;
    addrinfo * remoteList = NULL;
    addrinfo * remoteItem;
    addrinfo hints = {0};
    HANDLE sock;
    int nErrCode;

    // Only IPv4 stream sockets are requested
    hints.ai_family = AF_INET;
    hints.ai_socktype = SOCK_STREAM;

    nErrCode = GetAddrInfoWrapper(hostName, portNum, &hints, &remoteList);
    if(nErrCode == 0)
    {
        // Try the resolved addresses one by one
        for(remoteItem = remoteList; remoteItem != NULL; remoteItem = remoteItem->ai_next)
        {
            // CreateAndConnect reports a failed connect with the INVALID_SOCKET value,
            // which is not zero. Both values must be treated as "no socket",
            // otherwise a failed attempt would be returned as a working connection.
            sock = CreateAndConnect(remoteItem);
            if(sock == 0 || sock == SocketToHandle(INVALID_SOCKET))
                continue;

            // On success, the socket object takes over the socket and the address list
            if((pSocket = CASC_SOCKET::New(remoteList, remoteItem, hostName, portNum, sock)) != NULL)
                return pSocket;

            // The object could not be created; do not leave the socket open
            closesocket(HandleToSocket(sock));
        }

        // Nobody took over the address list, so it must be released here
        if(remoteList != NULL)
            freeaddrinfo(remoteList);
        nErrCode = ERROR_NETWORK_NOT_AVAILABLE;
    }

    SetCascError(nErrCode);
    return NULL;
}
void CASC_SOCKET::Delete()
{
PCASC_SOCKET pThis = this;
if(pCache != NULL)
pCache->UnlinkSocket(this);
pCache = NULL;
if(sock != 0)
closesocket(HandleToSocket(sock));
sock = 0;
CascFreeLock(Lock);
CASC_FREE(pThis);
}
// Creates an empty cache: no sockets in the list, reference count zero
CASC_SOCKET_CACHE::CASC_SOCKET_CACHE()
{
    dwRefCount = 0;
    pFirst = NULL;
    pLast = NULL;
}
// Destroys the cache together with all sockets that are still linked in it
CASC_SOCKET_CACHE::~CASC_SOCKET_CACHE()
{
    this->PurgeAll();
}
// Searches the cache for a socket with the given host name and port number.
// Host names are compared with _stricmp; port numbers must be equal.
// Returns the first matching socket, or NULL if there is none.
// The reference count of the returned socket is not changed.
PCASC_SOCKET CASC_SOCKET_CACHE::Find(const char * hostName, unsigned portNum)
{
    PCASC_SOCKET pCurrent = pFirst;

    while(pCurrent != NULL)
    {
        bool bSameHost = (_stricmp(pCurrent->hostName, hostName) == 0);

        if(bSameHost && pCurrent->portNum == portNum)
            return pCurrent;
        pCurrent = pCurrent->pNext;
    }

    // Walked through the whole list without a match
    return NULL;
}
// Appends a socket to the end of the cache list.
//
// The socket is only inserted if it is not NULL, is not linked to any cache yet,
// and the reference count of the cache is greater than zero. An inserted socket
// gets one extra reference and its 'pCache' is set to this cache.
// Returns the socket that was passed in, regardless of whether it was inserted.
PCASC_SOCKET CASC_SOCKET_CACHE::InsertSocket(PCASC_SOCKET pSocket)
{
    // Nothing to insert, or the socket is already linked to a cache
    if(pSocket == NULL || pSocket->pCache != NULL)
        return pSocket;

    // The cache reference count is not positive: leave the socket as it is
    if(!(dwRefCount > 0))
        return pSocket;

    // The cache holds its own reference to the socket
    pSocket->AddRef();

    if(pFirst != NULL || pLast != NULL)
    {
        // Link the socket after the current last one
        pSocket->pPrev = pLast;
        pLast->pNext = pSocket;
        pLast = pSocket;
    }
    else
    {
        // The list was empty
        pFirst = pLast = pSocket;
    }

    pSocket->pCache = this;
    return pSocket;
}
// Removes a socket from the cache list: fixes the head and tail of the list
// and connects the neighbors of the socket to each other.
// The 'pPrev', 'pNext' and 'pCache' members of the socket itself are not modified,
// and neither is its reference count. Passing NULL does nothing.
void CASC_SOCKET_CACHE::UnlinkSocket(PCASC_SOCKET pSocket)
{
    if(pSocket == NULL)
        return;

    PCASC_SOCKET pBefore = pSocket->pPrev;
    PCASC_SOCKET pAfter = pSocket->pNext;

    // Fix the list head and tail
    if(pFirst == pSocket)
        pFirst = pAfter;
    if(pLast == pSocket)
        pLast = pBefore;

    // Connect the neighbors to each other
    if(pBefore != NULL)
        pBefore->pNext = pAfter;
    if(pAfter != NULL)
        pAfter->pPrev = pBefore;
}
void CASC_SOCKET_CACHE::SetCaching(bool bAddRef)
{
PCASC_SOCKET pSocket;
PCASC_SOCKET pNext;
if(bAddRef)
{
if(dwRefCount == 0)
{
for(pSocket = pFirst; pSocket != NULL; pSocket = pSocket->pNext)
pSocket->AddRef();
}
CascInterlockedIncrement(&dwRefCount);
}
else
{
assert(dwRefCount > 0);
if(CascInterlockedDecrement(&dwRefCount) == 0)
{
for(pSocket = pFirst; pSocket != NULL; pSocket = pNext)
{
pNext = pSocket->pNext;
pSocket->Release();
}
}
}
}
void CASC_SOCKET_CACHE::PurgeAll()
{
PCASC_SOCKET pSocket;
PCASC_SOCKET pNext;
for(pSocket = pFirst; pSocket != NULL; pSocket = pNext)
{
pNext = pSocket->pNext;
pSocket->Delete();
}
}
// Returns a socket for the given host name and port number.
//
// If the global socket cache contains a matching socket, it is returned with one
// extra reference. Otherwise, a new connection is made with CASC_SOCKET::Connect;
// a new socket whose port number is CASC_PORT_HTTP is also offered to the cache.
// Returns NULL if the connection could not be made.
PCASC_SOCKET sockets_connect(const char * hostName, unsigned portNum)
{
    PCASC_SOCKET pSocket = SocketCache.Find(hostName, portNum);

    // Cached socket: just add a reference for the caller
    if(pSocket != NULL)
    {
        pSocket->AddRef();
        return pSocket;
    }

    // Not cached: make a new connection
    pSocket = CASC_SOCKET::Connect(hostName, portNum);
    if(pSocket == NULL || pSocket->portNum != CASC_PORT_HTTP)
        return pSocket;

    return SocketCache.InsertSocket(pSocket);
}
void sockets_set_caching(bool caching)
{
SocketCache.SetCaching(caching);
}