#include "libssh2_priv.h"
#include <errno.h>
#include <fcntl.h>
#include <ctype.h>
#ifdef LIBSSH2DEBUG
#include <stdio.h>
#endif
#include <assert.h>
#include "transport.h"
#include "mac.h"
#define MAX_BLOCKSIZE 32    
#define MAX_MACSIZE 20      
#ifdef LIBSSH2DEBUG
#define UNPRINTABLE_CHAR '.'
/*
 * debugdump
 *
 * Emit a hex + ASCII dump of 'size' bytes starting at 'ptr', preceded by a
 * header line containing 'desc' and the byte count.  Nothing is emitted
 * unless LIBSSH2_TRACE_TRANS is enabled in session->showmask.  Each line is
 * delivered to session->tracehandler when one is installed, otherwise it is
 * written to stderr.
 */
static void
debugdump(LIBSSH2_SESSION * session,
          const char *desc, const unsigned char *ptr, size_t size)
{
    static const char *digits = "0123456789ABCDEF";
    unsigned int per_line = 0x10;   /* bytes shown on each output line */
    char line[256];                 /* one formatted output line */
    size_t pos;                     /* number of characters in 'line' */
    size_t offset;                  /* offset of the current line in 'ptr' */
    size_t col;                     /* byte position within the line */

    if (!(session->showmask & LIBSSH2_TRACE_TRANS)) {
        /* transport tracing is not enabled */
        return;
    }

    /* header line */
    pos = snprintf(line, sizeof(line), "=> %s (%d bytes)\n",
                   desc, (int) size);
    if (session->tracehandler)
        (session->tracehandler)(session, session->tracehandler_context,
                                line, pos);
    else
        fprintf(stderr, "%s", line);

    for (offset = 0; offset < size; offset += per_line) {
        pos = snprintf(line, sizeof(line), "%04lx: ", (long)offset);

        /* hex column: two digits and a space per byte, blank-padded on a
           short final line, with an extra space after the first half */
        for (col = 0; col < per_line; col++) {
            if (offset + col >= size) {
                line[pos++] = ' ';
                line[pos++] = ' ';
            }
            else {
                unsigned char byte = ptr[offset + col];
                line[pos++] = digits[(byte >> 4) & 0xF];
                line[pos++] = digits[byte & 0xF];
            }
            line[pos++] = ' ';
            if (col == (per_line / 2) - 1)
                line[pos++] = ' ';
        }

        line[pos++] = ':';
        line[pos++] = ' ';

        /* ASCII column: non-printable bytes are shown as UNPRINTABLE_CHAR */
        col = 0;
        while ((col < per_line) && (offset + col < size)) {
            unsigned char byte = ptr[offset + col];
            if (isprint(byte))
                line[pos++] = byte;
            else
                line[pos++] = UNPRINTABLE_CHAR;
            col++;
        }

        line[pos++] = '\n';
        line[pos] = 0;

        if (session->tracehandler)
            (session->tracehandler)(session, session->tracehandler_context,
                                    line, pos);
        else
            fprintf(stderr, "%s", line);
    }
}
#else
#define debugdump(a,x,y,z)
#endif
/*
 * decrypt
 *
 * Decrypt 'len' bytes from 'source' one cipher block at a time using the
 * remote crypt method, and copy each processed block to 'dest'.  The crypt
 * callback is handed the 'source' block itself, and the block is copied to
 * 'dest' afterwards.  'len' must be a multiple of the remote cipher's block
 * size (asserted).
 *
 * Returns LIBSSH2_ERROR_NONE on success.  If the crypt callback fails, the
 * session's current packet payload is freed and LIBSSH2_ERROR_DECRYPT is
 * returned.
 *
 * NOTE(review): the payload pointer is not reset after being freed here;
 * confirm that callers never touch or free it again on this error path.
 */
static int
decrypt(LIBSSH2_SESSION * session, unsigned char *source,
        unsigned char *dest, int len)
{
    struct transportpacket *pkt = &session->packet;
    int bs = session->remote.crypt->blocksize;
    int done;

    /* a trailing partial block would be silently dropped */
    assert((len % bs) == 0);

    for (done = 0; len - done >= bs; done += bs) {
        unsigned char *in = source + done;

        if (session->remote.crypt->crypt(session, in, bs,
                                         &session->remote.crypt_abstract)) {
            LIBSSH2_FREE(session, pkt->payload);
            return LIBSSH2_ERROR_DECRYPT;
        }

        /* hand the processed block over to the caller's buffer */
        memcpy(dest + done, in, bs);
    }

    return LIBSSH2_ERROR_NONE;
}
/*
 * fullpacket
 *
 * Called once every byte of a packet is present in session->packet.  On the
 * first pass (fullpacket_state idle) it verifies the MAC when 'encrypted' is
 * set, bumps the remote sequence number, strips the padding from the payload
 * length and decompresses the payload when compression is active.  It then
 * passes the payload to _libssh2_packet_add().
 *
 * The work is split across two non-blocking states so that a
 * LIBSSH2_ERROR_EAGAIN from _libssh2_packet_add() can be resumed without
 * repeating the MAC/decompression step.
 *
 * Returns the packet type (first payload byte) on success, the
 * decompressor's error code if decompression fails, or the non-zero result
 * of _libssh2_packet_add().
 */
static int
fullpacket(LIBSSH2_SESSION * session, int encrypted  )
{
    struct transportpacket *pkt = &session->packet;
    unsigned char expected_mac[MAX_MACSIZE];
    int status;
    int use_comp;

    if (session->fullpacket_state == libssh2_NB_state_idle) {
        session->fullpacket_macstate = LIBSSH2_MAC_CONFIRMED;
        session->fullpacket_payload_len = pkt->packet_length - 1;

        if (encrypted) {
            const unsigned char *received_mac;

            /* compute the MAC over the saved first 5 bytes plus the data */
            session->remote.mac->hash(session, expected_mac,
                                      session->remote.seqno,
                                      pkt->init, 5,
                                      pkt->payload,
                                      session->fullpacket_payload_len,
                                      &session->remote.mac_abstract);

            /* the received MAC sits right behind payload + padding; a
               mismatch is only recorded here and passed on below */
            received_mac = pkt->payload + session->fullpacket_payload_len;
            if (memcmp(expected_mac, received_mac,
                       session->remote.mac->mac_len) != 0) {
                session->fullpacket_macstate = LIBSSH2_MAC_INVALID;
            }
        }

        session->remote.seqno++;

        /* from here on the length no longer includes the padding */
        session->fullpacket_payload_len -= pkt->padding_length;

        /* compression applies when a compressing method is configured and
           either authentication is done or the method is used during auth */
        use_comp = 0;
        if (session->local.comp != NULL && session->local.comp->compress) {
            use_comp = (session->state & LIBSSH2_STATE_AUTHENTICATED) ||
                session->local.comp->use_in_auth;
        }

        if (use_comp && session->remote.comp_abstract) {
            unsigned char *inflated;
            size_t inflated_len;

            status = session->remote.comp->decomp(session,
                                                  &inflated, &inflated_len,
                                                  LIBSSH2_PACKET_MAXDECOMP,
                                                  pkt->payload,
                                                  session->fullpacket_payload_len,
                                                  &session->remote.comp_abstract);
            /* the compressed buffer is released whether or not decomp
               succeeded */
            LIBSSH2_FREE(session, pkt->payload);
            if (status)
                return status;

            pkt->payload = inflated;
            session->fullpacket_payload_len = inflated_len;
        }

        session->fullpacket_packet_type = pkt->payload[0];

        debugdump(session, "libssh2_transport_read() plain",
                  pkt->payload, session->fullpacket_payload_len);

        session->fullpacket_state = libssh2_NB_state_created;
    }

    if (session->fullpacket_state == libssh2_NB_state_created) {
        status = _libssh2_packet_add(session, pkt->payload,
                                     session->fullpacket_payload_len,
                                     session->fullpacket_macstate);
        if (status == LIBSSH2_ERROR_EAGAIN) {
            /* stay in the 'created' state so the next call resumes here */
            return status;
        }

        session->fullpacket_state = libssh2_NB_state_idle;
        if (status)
            return status;
        return session->fullpacket_packet_type;
    }

    session->fullpacket_state = libssh2_NB_state_idle;
    return session->fullpacket_packet_type;
}
/*
 * _libssh2_transport_read
 *
 * Receive data from the socket into the session's packet buffer, assemble
 * one transport packet from it (decrypting block by block once
 * LIBSSH2_STATE_NEWKEYS is set) and pass the completed packet to
 * fullpacket().
 *
 * If a key re-exchange is pending and not already active, the call is first
 * redirected into _libssh2_kex_exchange(); a non-zero result from it is
 * returned as-is.
 *
 * Returns:
 *   - whatever fullpacket() returns once a packet is complete (the packet
 *     type on success, otherwise an error code);
 *   - LIBSSH2_ERROR_EAGAIN when more socket data is needed or when adding
 *     the packet would block (LIBSSH2_SESSION_BLOCK_INBOUND is set in
 *     session->socket_block_directions for the socket cases);
 *   - LIBSSH2_ERROR_NONE when the socket is already marked disconnected;
 *   - LIBSSH2_ERROR_SOCKET_RECV when the receive fails or returns 0;
 *   - LIBSSH2_ERROR_DECRYPT, LIBSSH2_ERROR_OUT_OF_BOUNDARY or
 *     LIBSSH2_ERROR_ALLOC when the packet header is invalid, the packet is
 *     too large, decryption fails or the payload cannot be allocated.
 *
 * The length and padding fields come from the peer and are validated before
 * they are used for allocation or copying.
 */
int _libssh2_transport_read(LIBSSH2_SESSION * session)
{
    int rc;
    struct transportpacket *p = &session->packet;
    int remainbuf;
    int remainpack;
    int numbytes;
    int numdecrypt;
    unsigned char block[MAX_BLOCKSIZE];
    int blocksize;
    int encrypted = 1;
    size_t total_num;

    /* default: this call is not waiting for inbound data */
    session->socket_block_directions &= ~LIBSSH2_SESSION_BLOCK_INBOUND;

    /* a key re-exchange was started elsewhere and must be driven to
       completion before regular packets are read */
    if (session->state & LIBSSH2_STATE_EXCHANGING_KEYS &&
        !(session->state & LIBSSH2_STATE_KEX_ACTIVE)) {

        _libssh2_debug(session, LIBSSH2_TRACE_TRANS, "Redirecting into the"
                       " key re-exchange from _libssh2_transport_read");
        rc = _libssh2_kex_exchange(session, 1, &session->startup_key_state);
        if (rc)
            return rc;
    }

    /* resume a fullpacket() call that previously returned EAGAIN */
    if (session->readPack_state == libssh2_NB_state_jump1) {
        session->readPack_state = libssh2_NB_state_idle;
        encrypted = session->readPack_encrypted;
        goto libssh2_transport_read_point1;
    }

    do {
        if (session->socket_state == LIBSSH2_SOCKET_DISCONNECTED) {
            return LIBSSH2_ERROR_NONE;
        }

        if (session->state & LIBSSH2_STATE_NEWKEYS) {
            blocksize = session->remote.crypt->blocksize;
        } else {
            /* no keys yet: plain text, and only the 5 header bytes
               (length + padding length) are needed to get started */
            encrypted = 0;
            blocksize = 5;
        }

        /* number of unread bytes currently held in p->buf */
        remainbuf = p->writeidx - p->readidx;

        /* if remainbuf turns negative we have a bad internal error */
        assert(remainbuf >= 0);

        if (remainbuf < blocksize) {
            /* less than one block buffered: read more from the socket */
            ssize_t nread;

            /* move any leftover bytes to the start of the buffer */
            if (remainbuf) {
                memmove(p->buf, &p->buf[p->readidx], remainbuf);
                p->readidx = 0;
                p->writeidx = remainbuf;
            } else {
                /* nothing buffered, just reset both indexes */
                p->readidx = p->writeidx = 0;
            }

            /* fill the rest of the buffer */
            nread =
                LIBSSH2_RECV(session, &p->buf[remainbuf],
                              PACKETBUFSIZE - remainbuf,
                              LIBSSH2_SOCKET_RECV_FLAGS(session));
            if (nread <= 0) {
                /* LIBSSH2_RECV reports errors as negative errno values */
                if ((nread < 0) && (nread == -EAGAIN)) {
                    session->socket_block_directions |=
                        LIBSSH2_SESSION_BLOCK_INBOUND;
                    return LIBSSH2_ERROR_EAGAIN;
                }
                /* cast: a ssize_t must not be passed for a %d conversion */
                _libssh2_debug(session, LIBSSH2_TRACE_SOCKET,
                               "Error recving %d bytes (got %d)",
                               PACKETBUFSIZE - remainbuf, (int)-nread);
                return LIBSSH2_ERROR_SOCKET_RECV;
            }
            _libssh2_debug(session, LIBSSH2_TRACE_SOCKET,
                           "Recved %d/%d bytes to %p+%d", (int)nread,
                           PACKETBUFSIZE - remainbuf, p->buf, remainbuf);

            debugdump(session, "libssh2_transport_read() raw",
                      &p->buf[remainbuf], nread);

            /* account for the newly received bytes */
            p->writeidx += nread;

            /* recompute the amount of buffered data */
            remainbuf = p->writeidx - p->readidx;
        }

        /* bytes available for the current packet */
        numbytes = remainbuf;

        if (!p->total_num) {
            /* no packet in progress: parse the first block of a new one */

            if (numbytes < blocksize) {
                /* still not a whole first block: wait for more data */
                session->socket_block_directions |=
                    LIBSSH2_SESSION_BLOCK_INBOUND;
                return LIBSSH2_ERROR_EAGAIN;
            }

            if (encrypted) {
                rc = decrypt(session, &p->buf[p->readidx], block, blocksize);
                if (rc != LIBSSH2_ERROR_NONE) {
                    return rc;
                }
                /* keep the first 5 bytes of the processed block;
                   fullpacket() feeds them to the MAC */
                memcpy(p->init, &p->buf[p->readidx], 5);
            } else {
                /* plain text: the block can be used as received */
                memcpy(block, &p->buf[p->readidx], blocksize);
            }

            /* the first block has been consumed from the buffer */
            p->readidx += blocksize;

            /* packet_length counts padding_length byte + payload + padding */
            p->packet_length = _libssh2_ntohu32(block);
            if (p->packet_length < 1)
                return LIBSSH2_ERROR_DECRYPT;

            /* Reject oversized lengths before doing any arithmetic on
               them: otherwise "packet_length - 1 + mac_len" can wrap around
               and slip past the total_num check below while fullpacket()
               still uses the huge packet_length. */
            if (p->packet_length > LIBSSH2_PACKET_MAXPAYLOAD)
                return LIBSSH2_ERROR_OUT_OF_BOUNDARY;

            p->padding_length = block[4];

            /* The padding can never be longer than the packet body;
               accepting it would underflow the payload length that
               fullpacket() derives as packet_length - 1 - padding_length. */
            if (p->padding_length > p->packet_length - 1)
                return LIBSSH2_ERROR_DECRYPT;

            /* bytes still to be stored in p->payload: body plus MAC */
            total_num =
                p->packet_length - 1 +
                (encrypted ? session->remote.mac->mac_len : 0);

            /* An empty packet would lead to a zero-size allocation whose
               first byte is later read as the packet type. */
            if (total_num == 0 || total_num > LIBSSH2_PACKET_MAXPAYLOAD) {
                return LIBSSH2_ERROR_OUT_OF_BOUNDARY;
            }

            /* The remainder of the first block (blocksize - 5 bytes) is
               copied into the payload below; make sure it fits so a short
               packet_length cannot overflow the allocation. */
            if (blocksize > 5 && (size_t)(blocksize - 5) > total_num) {
                return LIBSSH2_ERROR_DECRYPT;
            }

            p->payload = LIBSSH2_ALLOC(session, total_num);
            if (!p->payload) {
                return LIBSSH2_ERROR_ALLOC;
            }
            p->total_num = total_num;

            /* start writing at the beginning of the payload buffer */
            p->wptr = p->payload;

            if (blocksize > 5) {
                /* bytes 5.. of the first block already belong to the
                   packet body */
                memcpy(p->wptr, &block[5], blocksize - 5);
                p->wptr += blocksize - 5;
            }

            /* amount of packet data stored so far */
            p->data_num = p->wptr - p->payload;

            /* the first block no longer counts as available input */
            numbytes -= blocksize;
        }

        /* never take more than what this packet still needs */
        remainpack = p->total_num - p->data_num;
        if (numbytes > remainpack) {
            numbytes = remainpack;
        }

        if (encrypted) {
            /* the trailing MAC bytes are stored as-is, not decrypted */
            int skip = session->remote.mac->mac_len;

            if ((p->data_num + numbytes) > (p->total_num - skip)) {
                /* input reaches into the MAC: decrypt only up to it */
                numdecrypt = (p->total_num - skip) - p->data_num;
            } else {
                int frac;
                numdecrypt = numbytes;
                frac = numdecrypt % blocksize;
                if (frac) {
                    /* only whole blocks can be decrypted now */
                    numdecrypt -= frac;
                    /* and nothing may be copied undecrypted: after the
                       subtraction below numbytes ends up <= 0 */
                    numbytes = 0;
                }
            }
        } else {
            /* plain text: nothing to decrypt */
            numdecrypt = 0;
        }

        if (numdecrypt > 0) {
            /* decrypt straight into the payload buffer */
            rc = decrypt(session, &p->buf[p->readidx], p->wptr, numdecrypt);
            if (rc != LIBSSH2_ERROR_NONE) {
                p->total_num = 0;   /* no packet in progress any more */
                return rc;
            }

            p->readidx += numdecrypt;
            p->wptr += numdecrypt;
            p->data_num += numdecrypt;

            /* what is left (if anything) is copied verbatim below */
            numbytes -= numdecrypt;
        }

        if (numbytes > 0) {
            /* unencrypted data or MAC bytes: copy as-is */
            memcpy(p->wptr, &p->buf[p->readidx], numbytes);

            p->readidx += numbytes;
            p->wptr += numbytes;
            p->data_num += numbytes;
        }

        /* is the packet complete now? */
        remainpack = p->total_num - p->data_num;

        if (!remainpack) {
            /* entry point used when resuming after EAGAIN */
          libssh2_transport_read_point1:
            rc = fullpacket(session, encrypted);
            if (rc == LIBSSH2_ERROR_EAGAIN) {

                if (session->packAdd_state != libssh2_NB_state_idle)
                {
                    /* the packet-add step is mid-way: remember to jump
                       straight back to fullpacket() on the next call */
                    session->readPack_encrypted = encrypted;
                    session->readPack_state = libssh2_NB_state_jump1;
                }

                return rc;
            }

            p->total_num = 0;   /* no packet in progress any more */

            return rc;
        }
    } while (1);                /* loop until a packet completes or we
                                   return */

    return LIBSSH2_ERROR_SOCKET_RECV; /* never reached */
}
/*
 * send_existing
 *
 * Continue sending a packet that a previous _libssh2_transport_send() call
 * left partially unsent in session->packet.outbuf.
 *
 * 'data' and 'data_len' must be identical to the values passed when the
 * pending packet was first queued; anything else is rejected with
 * LIBSSH2_ERROR_BAD_USE.
 *
 * '*ret' is set to 0 when nothing was pending and to 1 when a pending
 * packet was handled by this call.  It is left untouched when
 * LIBSSH2_ERROR_BAD_USE is returned.
 *
 * Returns LIBSSH2_ERROR_NONE when nothing was pending or the remainder was
 * sent completely, LIBSSH2_ERROR_EAGAIN when data is still outstanding
 * (LIBSSH2_SESSION_BLOCK_OUTBOUND is then set in
 * session->socket_block_directions), or LIBSSH2_ERROR_SOCKET_SEND on a send
 * failure.
 */
static int
send_existing(LIBSSH2_SESSION *session, const unsigned char *data,
              size_t data_len, ssize_t *ret)
{
    ssize_t rc;
    ssize_t length;
    struct transportpacket *p = &session->packet;

    if (!p->olen) {
        /* no pending packet */
        *ret = 0;
        return LIBSSH2_ERROR_NONE;
    }

    /* a send is pending: the caller must retry with the same arguments */
    if ((data != p->odata) || (data_len != p->olen)) {
        return LIBSSH2_ERROR_BAD_USE;
    }

    *ret = 1;                   /* a pending packet is being handled */

    /* number of bytes of the pending packet that are still unsent */
    length = p->ototal_num - p->osent;

    rc = LIBSSH2_SEND(session, &p->outbuf[p->osent], length,
                       LIBSSH2_SOCKET_SEND_FLAGS(session));
    /* casts: ssize_t values must not be passed for %d conversions */
    if (rc < 0)
        _libssh2_debug(session, LIBSSH2_TRACE_SOCKET,
                       "Error sending %d bytes: %d", (int)length, (int)-rc);
    else {
        _libssh2_debug(session, LIBSSH2_TRACE_SOCKET,
                       "Sent %d/%d bytes at %p+%d", (int)rc, (int)length,
                       p->outbuf, (int)p->osent);
        debugdump(session, "libssh2_transport_write send()",
                  &p->outbuf[p->osent], rc);
    }

    if (rc == length) {
        /* the remainder went out: nothing is pending any more.  *ret stays
           1 so the caller reports success for this packet instead of
           building a new one. */
        p->ototal_num = 0;
        p->olen = 0;
        return LIBSSH2_ERROR_NONE;
    }
    else if (rc < 0) {
        /* nothing was sent; LIBSSH2_SEND reports negative errno values */
        if (rc != -EAGAIN)
            return LIBSSH2_ERROR_SOCKET_SEND;

        session->socket_block_directions |= LIBSSH2_SESSION_BLOCK_OUTBOUND;
        return LIBSSH2_ERROR_EAGAIN;
    }

    p->osent += rc;         /* this much more has been sent */

    if (rc < length) {
        /* short write: the caller has to wait until the socket is writable
           again, so flag the outbound direction exactly as the -EAGAIN
           path above does */
        session->socket_block_directions |= LIBSSH2_SESSION_BLOCK_OUTBOUND;
        return LIBSSH2_ERROR_EAGAIN;
    }
    return LIBSSH2_ERROR_NONE;
}
/*
 * _libssh2_transport_send
 *
 * Build one transport packet from 'data' (optionally followed by 'data2')
 * in session->packet.outbuf and send it.  The packet consists of a 4-byte
 * length, a 1-byte padding length, the (possibly compressed) payload and
 * random padding; once LIBSSH2_STATE_NEWKEYS is set a MAC is appended and
 * everything except the MAC is encrypted.
 *
 * If an earlier call left a packet partially unsent, this call only tries
 * to flush that packet; the caller must pass the same 'data' and
 * 'data_len' as in that earlier call.
 *
 * NOTE(review): data2_len is added to the packet size even when data2 is
 * NULL in the uncompressed path - presumably callers pass 0 in that case;
 * confirm against the callers.
 *
 * Returns LIBSSH2_ERROR_NONE when the whole packet was sent,
 * LIBSSH2_ERROR_EAGAIN when (part of) it is still pending,
 * LIBSSH2_ERROR_INVAL when the payload is too large,
 * LIBSSH2_ERROR_ENCRYPT or LIBSSH2_ERROR_SOCKET_SEND on failure, or the
 * non-zero result of the key re-exchange, send_existing() or the
 * compressor.
 */
int _libssh2_transport_send(LIBSSH2_SESSION *session,
                            const unsigned char *data, size_t data_len,
                            const unsigned char *data2, size_t data2_len)
{
    /* before keys are in place, packets are padded to a multiple of 8 */
    int blocksize =
        (session->state & LIBSSH2_STATE_NEWKEYS) ?
        session->local.crypt->blocksize : 8;
    int padding_length;
    size_t packet_length;
    int total_length;
#ifdef RANDOM_PADDING
    int rand_max;
    int seed = data[0];         /* varies the amount of extra padding */
#endif
    struct transportpacket *p = &session->packet;
    int encrypted;
    int compressed;
    ssize_t ret;
    int rc;
    /* remembered so a resumed send can be matched in send_existing() */
    const unsigned char *orgdata = data;
    size_t orgdata_len = data_len;

    /* a key re-exchange was started elsewhere and must complete first */
    if (session->state & LIBSSH2_STATE_EXCHANGING_KEYS &&
        !(session->state & LIBSSH2_STATE_KEX_ACTIVE)) {

        _libssh2_debug(session, LIBSSH2_TRACE_TRANS, "Redirecting into the"
                       " key re-exchange from _libssh2_transport_send");
        rc = _libssh2_kex_exchange(session, 1, &session->startup_key_state);
        if (rc)
            return rc;
    }

    debugdump(session, "libssh2_transport_write plain", data, data_len);
    if(data2)
        debugdump(session, "libssh2_transport_write plain2", data2, data2_len);

    /* first try to flush a packet left over from an earlier call */
    rc = send_existing(session, data, data_len, &ret);
    if (rc)
        return rc;

    session->socket_block_directions &= ~LIBSSH2_SESSION_BLOCK_OUTBOUND;

    if (ret)
        /* the pending packet was just sent completely; rc is 0 here */
        return rc;

    encrypted = (session->state & LIBSSH2_STATE_NEWKEYS) ? 1 : 0;

    /* compression applies when a compressing method is configured and
       either authentication is done or the method is used during auth */
    compressed =
        session->local.comp != NULL &&
        session->local.comp->compress &&
        ((session->state & LIBSSH2_STATE_AUTHENTICATED) ||
         session->local.comp->use_in_auth);

    if (encrypted && compressed) {
        /* room available for compressed output after the 5 header bytes */
        size_t dest_len = MAX_SSH_PACKET_LEN-5-256;
        size_t dest2_len = dest_len;

        /* compress straight into the output buffer */
        rc = session->local.comp->comp(session,
                                       &p->outbuf[5], &dest_len,
                                       data, data_len,
                                       &session->local.comp_abstract);
        if(rc)
            return rc;     /* compression failure */

        if(data2 && data2_len) {
            /* the second chunk gets whatever room the first left over */
            dest2_len -= dest_len;

            rc = session->local.comp->comp(session,
                                           &p->outbuf[5+dest_len], &dest2_len,
                                           data2, data2_len,
                                           &session->local.comp_abstract);
        }
        else
            dest2_len = 0;
        if(rc)
            return rc;     /* compression failure */

        data_len = dest_len + dest2_len; /* total compressed size */
    }
    else {
        /* Check both lengths separately so that a wrapped size_t sum
           cannot pass the limit. */
        if((data_len >= (MAX_SSH_PACKET_LEN-0x100)) ||
           (data2_len >= (MAX_SSH_PACKET_LEN-0x100) - data_len))
            /* too large a packet */
            return LIBSSH2_ERROR_INVAL;

        /* copy the payload behind the 5 header bytes */
        memcpy(&p->outbuf[5], data, data_len);
        if(data2 && data2_len)
            memcpy(&p->outbuf[5+data_len], data2, data2_len);
        data_len += data2_len; /* combined payload size */
    }

    /* payload + padding-length byte + packet-length field */
    packet_length = data_len + 1 + 4;

    /* pad the packet (length field included) to a block multiple */
    padding_length = blocksize - (packet_length % blocksize);

    /* less than 4 bytes of padding gets one more whole block */
    if (padding_length < 4) {
        padding_length += blocksize;
    }
#ifdef RANDOM_PADDING
    /* add a data-dependent number of extra blocks of padding while keeping
       the total within what the 1-byte padding field can hold */
    rand_max = (255 - padding_length) / blocksize + 1;
    padding_length += blocksize * (seed % rand_max);
#endif

    packet_length += padding_length;

    /* bytes that go on the wire: packet plus MAC when encrypted */
    total_length =
        packet_length + (encrypted ? session->local.mac->mac_len : 0);

    /* the length field does not count its own 4 bytes */
    _libssh2_htonu32(p->outbuf, packet_length - 4);

    p->outbuf[4] = (unsigned char)padding_length;

    /* fill the padding with random bytes */
    _libssh2_random(p->outbuf + 5 + data_len, padding_length);

    if (encrypted) {
        size_t i;

        /* the MAC is computed over the plain-text packet and stored right
           behind it, before the packet is encrypted in place */
        session->local.mac->hash(session, p->outbuf + packet_length,
                                 session->local.seqno, p->outbuf,
                                 packet_length, NULL, 0,
                                 &session->local.mac_abstract);

        /* encrypt the packet block by block; the MAC stays as it is */
        for(i = 0; i < packet_length; i += session->local.crypt->blocksize) {
            unsigned char *ptr = &p->outbuf[i];
            if (session->local.crypt->crypt(session, ptr,
                                            session->local.crypt->blocksize,
                                            &session->local.crypt_abstract))
                return LIBSSH2_ERROR_ENCRYPT;     /* encryption failure */
        }
    }

    session->local.seqno++;

    ret = LIBSSH2_SEND(session, p->outbuf, total_length,
                        LIBSSH2_SOCKET_SEND_FLAGS(session));
    /* casts: ssize_t values must not be passed for %d conversions */
    if (ret < 0)
        _libssh2_debug(session, LIBSSH2_TRACE_SOCKET,
                       "Error sending %d bytes: %d", total_length, (int)-ret);
    else {
        _libssh2_debug(session, LIBSSH2_TRACE_SOCKET, "Sent %d/%d bytes at %p",
                       (int)ret, total_length, p->outbuf);
        debugdump(session, "libssh2_transport_write send()", p->outbuf, ret);
    }

    if (ret != total_length) {
        if (ret >= 0 || ret == -EAGAIN) {
            /* short write or would-block: remember what is pending so the
               next call can continue through send_existing() */
            session->socket_block_directions |= LIBSSH2_SESSION_BLOCK_OUTBOUND;
            p->odata = orgdata;
            p->olen = orgdata_len;
            p->osent = ret <= 0 ? 0 : ret;
            p->ototal_num = total_length;
            return LIBSSH2_ERROR_EAGAIN;
        }
        return LIBSSH2_ERROR_SOCKET_SEND;
    }

    /* everything was sent: nothing is pending */
    p->odata = NULL;
    p->olen = 0;

    return LIBSSH2_ERROR_NONE;         /* all is good */
}