#include <stdio.h>
#include <string.h>
#include <map>
#include <assert.h>
#include "normSocket.h"
#ifdef WIN32
#include "win32InputHandler.cpp"
#include <Winsock2.h>
#include <Ws2tcpip.h>
#else
#include <arpa/inet.h>
#include <sys/select.h>
#include <fcntl.h>
#include <errno.h>
#endif
// Prints the command-line synopsis for normServer to stderr.
void Usage()
{
    static const char USAGE_TEXT[] =
        "Usage: normServer [listen [<groupAddr>/]<port>][debug <level>][trace]\n";
    fputs(USAGE_TEXT, stderr);
}
// Identifies a remote NORM peer by address version, raw address bytes and port.
// It is used as the key type of ClientMap, so operator< supplies the ordering.
class ClientInfo
{
    public:
        ClientInfo(UINT8 ipVersion = 0, const char* theAddr = NULL, UINT16 thePort = 0);

        bool operator < (const ClientInfo& a) const;

        // AF_INET for version 4, AF_INET6 for anything else.
        int GetAddressFamily() const;

        // Raw address bytes (16-byte buffer; IPv4 uses only the first 4).
        const char* GetAddress() const
        {
            return client_addr;
        }

        UINT16 GetPort() const
        {
            return client_port;
        }

        const char* GetAddressString();
        void Print(FILE* filePtr);

    private:
        UINT8  addr_version;      // address version given to the constructor (4 or 6)
        char   client_addr[16];   // raw address bytes, zero-padded by the constructor
        UINT16 client_port;
};
// Builds a ClientInfo from a raw network address.
//   addrVersion - 4 for IPv4 (4 bytes read from clientAddr), 6 for IPv6 (16 bytes)
//   clientAddr  - raw address bytes, or NULL for "no address"
//   clientPort  - port number, stored unchanged
// A NULL address or an unrecognized version yields an all-zero address. In that case
// the stored version is set to 0, so addr_version always describes what client_addr
// actually holds. (Previously only the parameter was reset, leaving the member claiming
// IPv4/IPv6 for a zeroed address.)
ClientInfo::ClientInfo(UINT8 addrVersion, const char* clientAddr, UINT16 clientPort)
    : addr_version(addrVersion), client_port(clientPort)
{
    if (NULL == clientAddr)
        addr_version = 0;
    switch (addr_version)
    {
        case 4:
            memcpy(client_addr, clientAddr, 4);
            memset(client_addr + 4, 0, 12);  // keep the unused tail deterministic for operator<
            break;
        case 6:
            memcpy(client_addr, clientAddr, 16);
            break;
        default:
            addr_version = 0;
            memset(client_addr, 0, 16);
            break;
    }
}
// Ordering for use as a std::map key: compare address version first, then port,
// then the address bytes (only the first 4 bytes for IPv4, all 16 otherwise).
bool ClientInfo::operator <(const ClientInfo& a) const
{
    if (addr_version == a.addr_version)
    {
        if (client_port == a.client_port)
        {
            const size_t addrBytes = (4 == addr_version) ? 4 : 16;
            return memcmp(client_addr, a.client_addr, addrBytes) < 0;
        }
        return client_port < a.client_port;
    }
    return addr_version < a.addr_version;
}
// Maps the stored address version to a socket address family:
// AF_INET for version 4, AF_INET6 for every other version.
int ClientInfo::GetAddressFamily() const
{
    const bool isIPv4 = (4 == addr_version);
    return isIPv4 ? AF_INET : AF_INET6;
}
const char* ClientInfo::GetAddressString()
{
static char text[64];
text[63] = '\0';
int addrFamily;
if (4 == addr_version)
addrFamily = AF_INET;
else
addrFamily = AF_INET6;
inet_ntop(addrFamily, client_addr, text, 63);
return text;
}
void ClientInfo::Print(FILE* filePtr)
{
char text[64];
text[63] = '\0';
int addrFamily;
if (4 == addr_version)
addrFamily = AF_INET;
else
addrFamily = AF_INET6;
inet_ntop(addrFamily, client_addr, text, 63);
fprintf(filePtr, "%s/%hu", text, client_port);
}
// Per-connection state for an accepted client. It holds the client's NORM socket, whether
// the socket can currently take more data, and how many bytes of the current input chunk
// have been sent to it. The destructor does not close the socket.
class Client
{
    public:
        Client(NormSocketHandle clientSocket);
        ~Client();

        NormSocketHandle GetSocket() const
        {
            return client_socket;
        }

        bool GetWriteReady() const
        {
            return write_ready;
        }
        void SetWriteReady(bool state)
        {
            write_ready = state;
        }

        unsigned int GetBytesWritten() const
        {
            return bytes_written;
        }
        void SetBytesWritten(unsigned long numBytes)
        {
            bytes_written = numBytes;
        }

    private:
        NormSocketHandle client_socket;
        bool             write_ready;
        unsigned int     bytes_written;
};

// A new client starts out write-ready with nothing sent yet.
Client::Client(NormSocketHandle clientSocket)
  : client_socket(clientSocket),
    write_ready(true),
    bytes_written(0)
{
}

// Nothing to release here; the socket handle is not closed by Client.
Client::~Client()
{
}

// Connected clients keyed by remote address/port; values are raw Client pointers.
typedef std::map<ClientInfo, Client*> ClientMap;
// Builds a ClientInfo for a remote NORM node from its address and port.
// A 4-byte address is treated as IPv4; any other length as IPv6.
// If NormNodeGetAddress() fails, a default ClientInfo (version 0, zero address, port 0)
// is returned instead of one built from uninitialized buffers.
ClientInfo NormGetClientInfo(NormNodeHandle client)
{
    char addr[16];
    unsigned int addrLen = sizeof(addr);
    UINT16 port = 0;
    if (!NormNodeGetAddress(client, addr, &addrLen, &port))
        return ClientInfo();
    UINT8 version = (4 == addrLen) ? 4 : 6;
    return ClientInfo(version, addr, port);
}
// Builds a ClientInfo for the peer of a NORM socket from NormGetPeerName().
// A 4-byte address is treated as IPv4; any other length as IPv6.
// The address buffer and port are zeroed first. If NormGetPeerName() does not fill them
// (e.g. no peer yet — NOTE(review): confirm when that happens), the key is deterministic
// rather than built from uninitialized stack memory.
static ClientInfo NormGetSocketInfo(NormSocketHandle socket)
{
    char addr[16];
    memset(addr, 0, sizeof(addr));
    unsigned int addrLen = sizeof(addr);
    UINT16 port = 0;
    NormGetPeerName(socket, addr, &addrLen, &port);
    UINT8 version = (4 == addrLen) ? 4 : 6;
    return ClientInfo(version, addr, port);
}
// Looks up the Client registered under clientInfo.
// Returns the stored pointer, or NULL when no entry exists.
Client* FindClient(ClientMap& clientMap, const ClientInfo& clientInfo)
{
    ClientMap::iterator entry = clientMap.find(clientInfo);
    if (entry == clientMap.end())
        return NULL;
    return entry->second;
}
// Returns the NORM socket of the client registered under clientInfo,
// or NORM_SOCKET_INVALID when there is no such client.
NormSocketHandle FindClientSocket(ClientMap& clientMap, const ClientInfo& clientInfo)
{
    const Client* const match = FindClient(clientMap, clientInfo);
    if (NULL != match)
        return match->GetSocket();
    return NORM_SOCKET_INVALID;
}
int main(int argc, char* argv[])
{
ClientMap clientMap;
UINT16 serverPort = 5000;
UINT16 serverInstanceId = 1;
char groupAddr[64];
const char* groupAddrPtr = NULL;
const char* mcastInterface = NULL;
bool trace = false;
unsigned int debugLevel = 0;
for (int i = 1; i < argc; i++)
{
const char* cmd = argv[i];
unsigned int len = strlen(cmd);
if (0 == strncmp(cmd, "listen", len))
{
const char* val = argv[++i];
const char* portPtr = strchr(val, '/');
if (NULL != portPtr)
portPtr++;
else
portPtr = val;
unsigned int addrTextLen = portPtr - val;
if (addrTextLen > 0)
{
addrTextLen -= 1;
strncpy(groupAddr, val, addrTextLen);
groupAddr[addrTextLen] = '\0';
groupAddrPtr = groupAddr;
}
if (1 != sscanf(portPtr, "%hu", &serverPort))
{
fprintf(stderr, "normServer error: invalid <port> \"%s\"\n", portPtr);
Usage();
return -1;
}
}
else if (0 == strncmp(cmd, "interface", len))
{
mcastInterface = argv[++i];
}
else if (0 == strncmp(cmd, "trace", len))
{
trace = true;
}
else if (0 == strncmp(cmd, "debug", len))
{
if (1 != sscanf(argv[++i], "%u", &debugLevel))
{
fprintf(stderr, "normServer error: invalid debug level\n");
Usage();
return -1;
}
}
else
{
fprintf(stderr, "normServer error: invalid command \"%s\"\n", cmd);
Usage();
return -1;
}
}
NormInstanceHandle instance = NormCreateInstance();
NormSocketHandle serverSocket = NormOpen(instance);
NormListen(serverSocket, serverPort, groupAddrPtr);
if (trace) NormSetMessageTrace(NormGetSocketSession(serverSocket), true);
if (0 != debugLevel) NormSetDebugLevel(debugLevel);
#ifdef WIN32
HANDLE hStdout = GetStdHandle(STD_OUTPUT_HANDLE);
Win32InputHandler inputHandler;
inputHandler.Open();
HANDLE handleArray[2];
handleArray[0] = NormGetDescriptor(instance);
handleArray[1] = inputHandler.GetEventHandle();
#else
fd_set fdset;
FD_ZERO(&fdset);
FILE* inputFile = stdin;
int inputfd = fileno(inputFile);
if (-1 == fcntl(inputfd, F_SETFL, fcntl(inputfd, F_GETFL, 0) | O_NONBLOCK))
perror("normClient: fcntl(inputfd, O_NONBLOCK) error");
int normfd = NormGetDescriptor(instance);
#endif
bool keepGoing = true;
bool writeReady = false;
int inputLength = 0;
unsigned int bytesWritten = 0;
const unsigned int BUFFER_LENGTH = 2048;
char inputBuffer[BUFFER_LENGTH];
bool inputNeeded = false; bool inputClosed = false;
unsigned int clientCount = 0;
while (keepGoing)
{
bool normEventPending = false;
bool inputEventPending = false;
#ifdef WIN32
DWORD handleCount = inputNeeded ? 2 : 1;
DWORD waitStatus =
MsgWaitForMultipleObjectsEx(handleCount, handleArray, INFINITE, QS_ALLINPUT, 0);
if ((WAIT_OBJECT_0 <= waitStatus) && (waitStatus < (WAIT_OBJECT_0 + handleCount)))
{
if (0 == (waitStatus - WAIT_OBJECT_0))
normEventPending = true;
else
inputEventPending = true;
}
else if (-1 == waitStatus)
{
perror("normServer: MsgWaitForMultipleObjectsEx() error");
break;
}
else
{
continue; }
#else
FD_SET(normfd, &fdset);
int maxfd = normfd;
if (inputNeeded)
{
FD_SET(inputfd, &fdset);
if (inputfd > maxfd) maxfd = inputfd;
}
else
{
FD_CLR(inputfd, &fdset);
}
int result = select(maxfd+1, &fdset, NULL, NULL, NULL);
if (result <= 0)
{
perror("normServer: select() error");
break;
}
if (FD_ISSET(inputfd, &fdset))
inputEventPending = true;
if (FD_ISSET(normfd, &fdset))
normEventPending = true;
#endif
if (inputEventPending)
{
#ifdef WIN32
inputLength = inputHandler.ReadData(inputBuffer, BUFFER_LENGTH);
if (inputLength > 0)
{
bytesWritten = 0;
inputNeeded = false;
}
else if (inputLength < 0)
{
inputHandler.Close();
inputClosed = true;
}
#else
inputLength = fread(inputBuffer, 1, BUFFER_LENGTH, inputFile);
if (inputLength > 0)
{
bytesWritten = 0;
inputNeeded = false;
}
else if (feof(inputFile))
{
if (stdin != inputFile)
{
fclose(inputFile);
inputFile = NULL;
}
inputClosed = true;
}
else if (ferror(inputFile))
{
switch (errno)
{
case EINTR:
break;
case EAGAIN:
break;
default:
perror("normServer: error reading input?!");
break;
}
}
#endif if (inputClosed)
{
inputNeeded = false;
if (clientMap.empty())
{
keepGoing = false;
continue;
}
else
{
ClientMap::iterator it;
for (it = clientMap.begin(); it != clientMap.end(); it++)
{
Client* client = it->second;
NormSocketHandle clientSocket = client->GetSocket();
NormShutdown(clientSocket);
}
}
}
}
if (normEventPending)
{
NormSocketEvent event;
if (NormGetSocketEvent(instance, &event))
{
ClientInfo clientInfo;
if (NORM_NODE_INVALID != event.sender)
clientInfo = NormGetClientInfo(event.sender);
else
clientInfo = NormGetSocketInfo(event.socket);
switch (event.type)
{
case NORM_SOCKET_ACCEPT:
{
if (event.socket == serverSocket)
{
if (NORM_SOCKET_INVALID != FindClientSocket(clientMap, clientInfo))
{
fprintf(stderr, "normServer: duplicative %s from client %s/%hu...\n",
(NORM_REMOTE_SENDER_NEW == event.event.type) ? "new" : "reset",
clientInfo.GetAddressString(), clientInfo.GetPort());
continue;
}
NormSocketHandle clientSocket = NormAccept(serverSocket, event.sender);
Client* client = new Client(clientSocket);
if (NULL == client)
{
perror("normServer: new Client() error");
NormClose(clientSocket);
continue;
}
if (trace) NormSetMessageTrace(NormGetSocketSession(clientSocket), true);
clientMap[clientInfo] = client;
client->SetWriteReady(true);
if (0 == clientCount)
{
inputNeeded = true;
writeReady = true;
}
clientCount++;
fprintf(stderr, "normServer: ACCEPTED connection from %s/%hu\n",
clientInfo.GetAddressString(), clientInfo.GetPort());
}
else
{
}
break;
}
case NORM_SOCKET_CONNECT:
{
fprintf(stderr, "normServer: CONNECTED to %s/%hu ...\n",
clientInfo.GetAddressString(), clientInfo.GetPort());
Client* client = FindClient(clientMap, clientInfo);
assert(NULL != client);
break;
}
case NORM_SOCKET_READ:
{
bool rxReady = true;
while (rxReady)
{
char buffer[1024];
ssize_t bytesRead = NormRead(event.socket, buffer, 1024);
if (bytesRead < 0)
{
fprintf(stderr, "normServer: broken stream ...\n");
continue;
}
if (bytesRead > 0)
{
#ifdef WIN32
DWORD dwWritten;
WriteFile(hStdout, buffer, bytesRead, &dwWritten, NULL);
#else
fwrite(buffer, sizeof(char), bytesRead, stdout);
#endif }
if (bytesRead < 1024) rxReady = false;
}
break;
}
case NORM_SOCKET_WRITE:
{
if (NULL != groupAddrPtr)
{
writeReady = true;
}
else
{
Client* client = FindClient(clientMap, clientInfo);
assert(NULL != client);
client->SetWriteReady(true);
}
break;
}
case NORM_SOCKET_CLOSING:
{
fprintf(stderr, "normServer: client %s/%hu CLOSING connection ...\n",
clientInfo.GetAddressString(), clientInfo.GetPort());
Client* client = FindClient(clientMap, clientInfo);
assert(NULL != client);
client->SetWriteReady(false);
break;
}
case NORM_SOCKET_CLOSE:
{
fprintf(stderr, "normServer: connection to client %s/%hu CLOSED ...\n",
clientInfo.GetAddressString(), clientInfo.GetPort());
clientMap.erase(clientInfo);
NormClose(event.socket);
if (inputClosed && clientMap.empty())
keepGoing = false;
break;
}
case NORM_SOCKET_NONE:
break;
} }
else
{
fprintf(stderr, "normServer: NormGetNextSocketEvent() returned false\n");
}
}
if ((inputLength > 0) && !inputNeeded)
{
if (NULL == groupAddrPtr)
{
bool clientPending = false;
ClientMap::iterator it;
for (it = clientMap.begin(); it != clientMap.end(); it++)
{
Client* client = it->second;
if (!client->GetWriteReady())
{
clientPending = true;
continue;
}
unsigned int numBytes = client->GetBytesWritten();
if (numBytes < inputLength)
{
NormSocketHandle clientSocket = client->GetSocket();
bytesWritten += NormWrite(clientSocket, inputBuffer + numBytes, inputLength - numBytes);
client->SetBytesWritten(numBytes);
if (bytesWritten < inputLength)
{
client->SetWriteReady(false);
clientPending = true;
}
else
{
NormFlush(clientSocket);
}
}
}
if (!clientPending)
{
inputLength = 0;
inputNeeded = true;
for (it = clientMap.begin(); it != clientMap.end(); it++)
it->second->SetBytesWritten(0);
}
}
else
{
NormSocketHandle sendSocket = serverSocket;
if (NORM_SOCKET_INVALID != sendSocket)
{
if (writeReady && (inputLength > 0))
{
bytesWritten += NormWrite(sendSocket, inputBuffer + bytesWritten, inputLength - bytesWritten);
if (bytesWritten < inputLength)
{
writeReady = false;
}
else
{
inputLength = 0;
inputNeeded = true;
NormFlush(sendSocket);
}
}
}
}
}
} #ifdef WIN32
inputHandler.Close();
#else
if ((stdin != inputFile) && (NULL != inputFile))
{
fclose(inputFile);
inputFile = NULL;
}
#endif NormClose(serverSocket);
serverSocket = NORM_SOCKET_INVALID;
}