codegame 0.7.0

CodeGame framework
Documentation
#include "TcpStream.hpp"
#include <cstring>
#include <iostream>
#include <stdexcept>
#include <string>

#ifdef _WIN32
typedef int RECV_SEND_T;
#else
typedef ssize_t RECV_SEND_T;
#endif

TcpStream::TcpStream(const std::string& host, int port)
{
#ifdef _WIN32
    WSADATA wsa_data;
    if (WSAStartup(MAKEWORD(1, 1), &wsa_data) != 0) {
        throw std::runtime_error("Failed to initialize sockets");
    }
#endif
    sock = socket(AF_INET, SOCK_STREAM, 0);
    if (sock == -1) {
        throw std::runtime_error("Failed to create socket");
    }
    int yes = 1;
    if (setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, (char*)&yes, sizeof(int)) < 0) {
        throw std::runtime_error("Failed to set TCP_NODELAY");
    }
    addrinfo hints, *servinfo;
    std::memset(&hints, 0, sizeof(hints));
    hints.ai_family = AF_INET;
    hints.ai_socktype = SOCK_STREAM;
    if (getaddrinfo(host.c_str(), std::to_string(port).c_str(), &hints,
            &servinfo)
        != 0) {
        throw std::runtime_error("Failed to get addr info");
    }
    if (connect(sock, servinfo->ai_addr, servinfo->ai_addrlen) == -1) {
        throw std::runtime_error("Failed to connect");
    }
    freeaddrinfo(servinfo);
}

class TcpInputStream : public InputStream {
public:
    TcpInputStream(std::shared_ptr<TcpStream> tcpStream)
        : tcpStream(tcpStream)
        , bufferPos(0)
        , bufferSize(0)
    {
    }
    void readBytes(char* buffer, size_t byteCount)
    {
        while (byteCount > 0) {
            if (bufferSize > 0) {
                if (bufferSize >= byteCount) {
                    memcpy(buffer, this->buffer + bufferPos, byteCount);
                    bufferPos += byteCount;
                    bufferSize -= byteCount;
                    return;
                }
                memcpy(buffer, this->buffer + bufferPos, bufferSize);
                buffer += bufferSize;
                byteCount -= bufferSize;
                bufferPos += bufferSize;
                bufferSize = 0;
            }
            if (bufferPos == BUFFER_CAPACITY) {
                bufferPos = 0;
            }
            RECV_SEND_T received = recv(tcpStream->sock, this->buffer + bufferPos + bufferSize,
                BUFFER_CAPACITY - bufferPos - bufferSize, 0);
            if (received < 0) {
                throw std::runtime_error("Failed to read from socket");
            }
            bufferSize += received;
        }
    }

private:
    static const size_t BUFFER_CAPACITY = 8 * 1024;
    char buffer[BUFFER_CAPACITY];
    size_t bufferPos;
    size_t bufferSize;
    std::shared_ptr<TcpStream> tcpStream;
};

class TcpOutputStream : public OutputStream {
public:
    TcpOutputStream(std::shared_ptr<TcpStream> tcpStream)
        : tcpStream(tcpStream)
        , bufferPos(0)
        , bufferSize(0)
    {
    }
    void writeBytes(const char* buffer, size_t byteCount)
    {
        while (byteCount > 0) {
            size_t capacity = BUFFER_CAPACITY - bufferPos - bufferSize;
            if (capacity >= byteCount) {
                memcpy(this->buffer + bufferPos + bufferSize, buffer, byteCount);
                bufferSize += byteCount;
                return;
            }
            memcpy(this->buffer + bufferPos + bufferSize, buffer, capacity);
            bufferSize += capacity;
            byteCount -= capacity;
            buffer += capacity;
            flush();
        }
    }
    void flush()
    {
        while (bufferSize > 0) {
            RECV_SEND_T sent = send(tcpStream->sock, buffer + bufferPos, bufferSize, 0);
            if (sent < 0) {
                throw std::runtime_error("Failed to write to socket");
            }
            bufferPos += sent;
            bufferSize -= sent;
        }
        bufferPos = 0;
    }

private:
    static const size_t BUFFER_CAPACITY = 8 * 1024;
    char buffer[BUFFER_CAPACITY];
    size_t bufferPos;
    size_t bufferSize;
    std::shared_ptr<TcpStream> tcpStream;
};

std::shared_ptr<InputStream>
getInputStream(std::shared_ptr<TcpStream> tcpStream)
{
    return std::shared_ptr<TcpInputStream>(new TcpInputStream(tcpStream));
}

std::shared_ptr<OutputStream>
getOutputStream(std::shared_ptr<TcpStream> tcpStream)
{
    return std::shared_ptr<TcpOutputStream>(new TcpOutputStream(tcpStream));
}