#include "config.h"
#include <limits.h>
#include <stdarg.h>
#include <stdbool.h>
#include <stdio.h>
#ifndef _WIN32
#include <netinet/in.h>
#include <arpa/inet.h>
#endif
#include "libssh/priv.h"
#include "libssh/buffer.h"
#include "libssh/misc.h"
#include "libssh/bignum.h"
/*
 * Growable byte buffer with a read cursor.
 *
 * Unread content lives in data[pos .. used); data[used .. allocated) is
 * spare capacity.
 */
struct ssh_buffer_struct {
/* Set by ssh_buffer_set_secure(); when true the code in this file wipes
 * buffer memory with explicit_bzero() before freeing, shifting or
 * resetting it. */
bool secure;
/* Offset one past the last valid byte in data. */
uint32_t used;
/* Size in bytes of the block data points to. */
uint32_t allocated;
/* Offset of the next byte to be read from data. */
uint32_t pos;
/* Heap storage, (re)allocated by realloc_buffer(). */
uint8_t *data;
};
/* Largest allocation (0x10000000 bytes = 256 MiB) realloc_buffer() will
 * accept; bigger requests are rejected. */
#define BUFFER_SIZE_MAX 0x10000000
#ifndef DEBUG_BUFFER
/* Consistency checks are compiled out unless DEBUG_BUFFER is defined. */
#define buffer_verify(x)
#else
/*
 * Debug helper: check the invariants pos <= used <= allocated.
 *
 * Every violated invariant is reported on stderr; if at least one was
 * violated the process is aborted. A buffer without storage is not checked.
 */
static void buffer_verify(ssh_buffer buf)
{
    unsigned int violations = 0;

    if (buf->data == NULL) {
        return;
    }

    if (buf->allocated < buf->used) {
        fprintf(stderr,
                "BUFFER ERROR: allocated %u, used %u\n",
                buf->allocated,
                buf->used);
        violations++;
    }

    if (buf->used < buf->pos) {
        fprintf(stderr,
                "BUFFER ERROR: position %u, used %u\n",
                buf->pos,
                buf->used);
        violations++;
    }

    if (buf->allocated < buf->pos) {
        fprintf(stderr,
                "BUFFER ERROR: position %u, allocated %u\n",
                buf->pos,
                buf->allocated);
        violations++;
    }

    if (violations != 0) {
        abort();
    }
}
#endif
/*
 * Create a new, empty buffer.
 *
 * Returns a zero-initialised buffer with initial storage already
 * allocated, or NULL if either allocation fails. Release it with
 * ssh_buffer_free().
 */
struct ssh_buffer_struct *ssh_buffer_new(void)
{
    struct ssh_buffer_struct *fresh = calloc(1, sizeof(*fresh));

    if (fresh == NULL) {
        return NULL;
    }

    /*
     * realloc_buffer() rounds up to the smallest power of two strictly
     * greater than the request, so asking for 63 bytes yields 64.
     */
    if (ssh_buffer_allocate_size(fresh, 64 - 1) != 0) {
        SAFE_FREE(fresh);
        return NULL;
    }

    buffer_verify(fresh);

    return fresh;
}
/*
 * Release a buffer and its storage.
 *
 * For a secure buffer that owns storage, both the storage and the
 * bookkeeping structure are wiped before being freed. NULL is a no-op.
 */
void ssh_buffer_free(struct ssh_buffer_struct *buffer)
{
    bool wipe;

    if (buffer == NULL) {
        return;
    }
    buffer_verify(buffer);

    wipe = buffer->secure && buffer->allocated > 0;

    if (wipe) {
        /* Erase the payload before handing the memory back. */
        explicit_bzero(buffer->data, buffer->allocated);
    }
    SAFE_FREE(buffer->data);

    if (wipe) {
        /* Erase the sizes and offsets as well. */
        explicit_bzero(buffer, sizeof(struct ssh_buffer_struct));
    }
    SAFE_FREE(buffer);
}
/*
 * Mark the buffer as holding sensitive data.
 *
 * Once the flag is set, the routines in this file wipe the buffer's memory
 * with explicit_bzero() before freeing, shifting or resetting it.
 * The buffer must not be NULL.
 */
void ssh_buffer_set_secure(ssh_buffer buffer)
{
    buffer->secure = true;
}
/*
 * Resize the buffer's storage so it can hold at least `needed` bytes.
 *
 * The new size is the smallest power of two strictly greater than `needed`.
 * Sizes above BUFFER_SIZE_MAX are rejected.
 *
 * For secure buffers the content is copied into a fresh block and the old
 * block is wiped before it is freed, so realloc() never leaves a stale copy
 * behind. The whole old allocation is wiped, not just the `used` prefix:
 * ssh_buffer_pass_bytes() and ssh_buffer_pass_bytes_end() lower `used`
 * without erasing, so sensitive bytes can remain beyond it.
 *
 * Returns 0 on success and -1 on failure; on failure the buffer is unchanged.
 */
static int realloc_buffer(struct ssh_buffer_struct *buffer, uint32_t needed)
{
    uint32_t smallest = 1;
    uint8_t *new = NULL;

    buffer_verify(buffer);

    /* Find the smallest power of two strictly greater than needed. */
    while (smallest <= needed) {
        if (smallest == 0) {
            /* The shift wrapped around: the request is too large. */
            return -1;
        }
        smallest <<= 1;
    }
    needed = smallest;

    if (needed > BUFFER_SIZE_MAX) {
        return -1;
    }

    if (buffer->secure) {
        new = malloc(needed);
        if (new == NULL) {
            return -1;
        }
        /* Passing a null pointer to memcpy() is undefined, even for 0 bytes. */
        if (buffer->data != NULL) {
            memcpy(new, buffer->data, buffer->used);
            /* buffer->allocated still holds the size of the old block here. */
            explicit_bzero(buffer->data, buffer->allocated);
            SAFE_FREE(buffer->data);
        }
    } else {
        new = realloc(buffer->data, needed);
        if (new == NULL) {
            return -1;
        }
    }

    buffer->data = new;
    buffer->allocated = needed;

    buffer_verify(buffer);

    return 0;
}
/*
 * Move the unread content to the start of the storage, reclaiming the
 * space occupied by bytes that were already read.
 *
 * For secure buffers the tail vacated by the move is wiped.
 */
static void buffer_shift(ssh_buffer buffer)
{
    const size_t consumed = buffer->pos;
    uint32_t remaining;

    buffer_verify(buffer);

    if (consumed == 0) {
        /* Content already starts at offset 0. */
        return;
    }

    remaining = buffer->used - buffer->pos;
    memmove(buffer->data, buffer->data + buffer->pos, remaining);
    buffer->used = remaining;
    buffer->pos = 0;

    if (buffer->secure) {
        /* The `consumed` bytes after the new end still hold old content. */
        explicit_bzero(buffer->data + buffer->used, consumed);
    }

    buffer_verify(buffer);
}
/*
 * Reset the buffer to the empty state so it can be reused.
 *
 * Secure buffers have their storage wiped. Storage larger than 65536
 * bytes is shrunk back to 65536.
 *
 * Returns 0 on success, -1 if buffer is NULL or the shrink fails.
 */
int ssh_buffer_reinit(struct ssh_buffer_struct *buffer)
{
    if (buffer == NULL) {
        return -1;
    }

    buffer_verify(buffer);

    if (buffer->secure && buffer->allocated > 0) {
        explicit_bzero(buffer->data, buffer->allocated);
    }

    buffer->pos = 0;
    buffer->used = 0;

    /* realloc_buffer() rounds 65535 up to 65536. */
    if (buffer->allocated > 65536 &&
        realloc_buffer(buffer, 65536 - 1) != 0) {
        return -1;
    }

    buffer_verify(buffer);

    return 0;
}
/*
 * Append `len` bytes from `data` to the end of the buffer.
 *
 * The storage grows as needed; already-read bytes at the front are
 * discarded first when growth is required.
 *
 * Returns 0 on success, -1 if buffer or data is NULL, if the new length
 * would overflow 32 bits, or if the reallocation fails.
 */
int ssh_buffer_add_data(struct ssh_buffer_struct *buffer, const void *data, uint32_t len)
{
    uint32_t new_end;

    if (buffer == NULL) {
        return -1;
    }

    buffer_verify(buffer);

    if (data == NULL) {
        return -1;
    }

    new_end = buffer->used + len;
    if (new_end < len) {
        /* The 32-bit sum wrapped around. */
        return -1;
    }

    if (new_end > buffer->allocated) {
        if (buffer->pos > 0) {
            buffer_shift(buffer);
        }
        /* buffer->used may have changed in buffer_shift(): recompute. */
        if (realloc_buffer(buffer, buffer->used + len) < 0) {
            return -1;
        }
    }

    memcpy(buffer->data + buffer->used, data, len);
    buffer->used += len;

    buffer_verify(buffer);

    return 0;
}
/*
 * Ensure the buffer's storage is at least `len` bytes large.
 *
 * Only the capacity changes; no data is added. Already-read bytes are
 * discarded before growing.
 *
 * Returns 0 on success and -1 if the reallocation fails.
 */
int ssh_buffer_allocate_size(struct ssh_buffer_struct *buffer,
                             uint32_t len)
{
    buffer_verify(buffer);

    if (len > buffer->allocated) {
        if (buffer->pos != 0) {
            buffer_shift(buffer);
        }
        if (realloc_buffer(buffer, len) < 0) {
            return -1;
        }
    }

    buffer_verify(buffer);

    return 0;
}
/*
 * Reserve `len` bytes at the end of the buffer and return a pointer to them.
 *
 * The reserved bytes count as used immediately; the caller is expected to
 * fill them in. Their content is not initialised here.
 *
 * Returns NULL if the new length would overflow 32 bits or if the
 * reallocation fails.
 */
void *ssh_buffer_allocate(struct ssh_buffer_struct *buffer, uint32_t len)
{
    uint8_t *region = NULL;

    buffer_verify(buffer);

    if (buffer->used + len < len) {
        /* The 32-bit sum wrapped around. */
        return NULL;
    }

    if (buffer->used + len > buffer->allocated) {
        if (buffer->pos > 0) {
            buffer_shift(buffer);
        }
        if (realloc_buffer(buffer, buffer->used + len) < 0) {
            return NULL;
        }
    }

    region = buffer->data + buffer->used;
    buffer->used += len;

    buffer_verify(buffer);

    return region;
}
/*
 * Append an SSH string to the buffer.
 *
 * The string object is copied as-is: sizeof(uint32_t) leading bytes plus
 * ssh_string_len(string) payload bytes, starting at `string` itself.
 *
 * Returns 0 on success, -1 if string is NULL or the append fails.
 */
int ssh_buffer_add_ssh_string(struct ssh_buffer_struct *buffer,
                              struct ssh_string_struct *string)
{
    uint32_t payload_len;
    int rc;

    if (string == NULL) {
        return -1;
    }

    payload_len = ssh_string_len(string);

    rc = ssh_buffer_add_data(buffer, string, payload_len + sizeof(uint32_t));

    return (rc < 0) ? -1 : 0;
}
/*
 * Append a 32-bit value to the buffer exactly as it is laid out in memory.
 * No byte-order conversion is done here.
 *
 * Returns 0 on success, -1 on failure.
 */
int ssh_buffer_add_u32(struct ssh_buffer_struct *buffer,uint32_t data)
{
    if (ssh_buffer_add_data(buffer, &data, sizeof(data)) < 0) {
        return -1;
    }

    return 0;
}
/*
 * Append a 16-bit value to the buffer exactly as it is laid out in memory.
 * No byte-order conversion is done here.
 *
 * Returns 0 on success, -1 on failure.
 */
int ssh_buffer_add_u16(struct ssh_buffer_struct *buffer,uint16_t data)
{
    if (ssh_buffer_add_data(buffer, &data, sizeof(data)) < 0) {
        return -1;
    }

    return 0;
}
/*
 * Append a 64-bit value to the buffer exactly as it is laid out in memory.
 * No byte-order conversion is done here.
 *
 * Returns 0 on success, -1 on failure.
 */
int ssh_buffer_add_u64(struct ssh_buffer_struct *buffer, uint64_t data)
{
    if (ssh_buffer_add_data(buffer, &data, sizeof(data)) < 0) {
        return -1;
    }

    return 0;
}
/*
 * Append a single byte to the buffer.
 *
 * Returns 0 on success, -1 on failure.
 */
int ssh_buffer_add_u8(struct ssh_buffer_struct *buffer,uint8_t data)
{
    if (ssh_buffer_add_data(buffer, &data, sizeof(data)) < 0) {
        return -1;
    }

    return 0;
}
/*
 * Insert `len` bytes from `data` in front of the unread content.
 *
 * If enough already-read bytes sit before the read position, the data is
 * written into that gap and the read position moves back. Otherwise the
 * unread content is moved to offset `len` (growing the storage if needed)
 * and the data is written at offset 0.
 *
 * Returns 0 on success; -1 if buffer or data is NULL (the same contract as
 * ssh_buffer_add_data()), if the resulting length would overflow 32 bits,
 * or if the reallocation fails.
 */
int ssh_buffer_prepend_data(struct ssh_buffer_struct *buffer, const void *data,
                            uint32_t len)
{
    if (buffer == NULL) {
        return -1;
    }

    buffer_verify(buffer);

    if (data == NULL) {
        return -1;
    }

    if (len <= buffer->pos) {
        /* The data fits between the start of storage and the read position. */
        memcpy(buffer->data + (buffer->pos - len), data, len);
        buffer->pos -= len;
        buffer_verify(buffer);
        return 0;
    }

    /* Not enough room in front: check the new total length for wrap-around. */
    if (buffer->used - buffer->pos + len < len) {
        return -1;
    }

    if (buffer->allocated < (buffer->used - buffer->pos + len)) {
        if (realloc_buffer(buffer, buffer->used - buffer->pos + len) < 0) {
            return -1;
        }
    }

    /* Source and destination may overlap, hence memmove(). */
    memmove(buffer->data + len,
            buffer->data + buffer->pos,
            buffer->used - buffer->pos);
    memcpy(buffer->data, data, len);
    buffer->used += len - buffer->pos;
    buffer->pos = 0;

    buffer_verify(buffer);

    return 0;
}
/*
 * Append the unread content of `source` to `buffer`.
 *
 * `source` is not modified.
 *
 * Returns 0 on success, -1 if source is NULL or if ssh_buffer_add_data()
 * fails (which includes a NULL buffer).
 */
int ssh_buffer_add_buffer(struct ssh_buffer_struct *buffer,
                          struct ssh_buffer_struct *source)
{
    int rc;

    /* The accessors below dereference source unconditionally. */
    if (source == NULL) {
        return -1;
    }

    rc = ssh_buffer_add_data(buffer,
                             ssh_buffer_get(source),
                             ssh_buffer_get_len(source));
    if (rc < 0) {
        return -1;
    }

    return 0;
}
/*
 * Return a pointer to the first unread byte of the buffer.
 *
 * The pointer refers to the buffer's own storage; nothing is copied.
 */
void *ssh_buffer_get(struct ssh_buffer_struct *buffer)
{
    uint8_t *cursor = buffer->data + buffer->pos;

    return cursor;
}
/*
 * Return the number of unread bytes in the buffer (used - pos).
 */
uint32_t ssh_buffer_get_len(struct ssh_buffer_struct *buffer)
{
    uint32_t unread;

    buffer_verify(buffer);

    unread = buffer->used - buffer->pos;

    return unread;
}
/*
 * Skip `len` unread bytes by advancing the read position.
 *
 * When the skip consumes all remaining content, the buffer is reset to
 * empty (pos = used = 0) so the storage can be reused from the start.
 *
 * Returns len on success, or 0 if fewer than len bytes are unread or the
 * position would overflow 32 bits.
 */
uint32_t ssh_buffer_pass_bytes(struct ssh_buffer_struct *buffer, uint32_t len)
{
    uint32_t target;

    buffer_verify(buffer);

    target = buffer->pos + len;
    if (target < len || target > buffer->used) {
        return 0;
    }

    buffer->pos = target;

    if (buffer->pos == buffer->used) {
        /* Everything was consumed: rewind. */
        buffer->pos = 0;
        buffer->used = 0;
    }

    buffer_verify(buffer);

    return len;
}
/*
 * Drop `len` bytes from the end of the buffer by lowering `used`.
 *
 * Returns len on success, or 0 if the buffer holds fewer than len bytes.
 */
uint32_t ssh_buffer_pass_bytes_end(struct ssh_buffer_struct *buffer, uint32_t len)
{
    buffer_verify(buffer);

    if (len > buffer->used) {
        return 0;
    }

    buffer->used = buffer->used - len;

    buffer_verify(buffer);

    return len;
}
/*
 * Copy `len` unread bytes out of the buffer into `data` and advance the
 * read position past them.
 *
 * Returns len on success, or 0 if ssh_buffer_validate_length() rejects the
 * request (NULL buffer, overflow, or not enough unread bytes). Nothing is
 * copied in that case.
 */
uint32_t ssh_buffer_get_data(struct ssh_buffer_struct *buffer, void *data, uint32_t len)
{
    const uint8_t *src = NULL;

    if (ssh_buffer_validate_length(buffer, len) != SSH_OK) {
        return 0;
    }

    src = buffer->data + buffer->pos;
    memcpy(data, src, len);
    buffer->pos += len;

    return len;
}
/*
 * Read one byte from the buffer into *data.
 *
 * Returns 1 on success and 0 on failure.
 */
uint32_t ssh_buffer_get_u8(struct ssh_buffer_struct *buffer, uint8_t *data)
{
    return ssh_buffer_get_data(buffer, data, sizeof(*data));
}
/*
 * Read four bytes from the buffer into *data, without byte-order conversion.
 *
 * Returns 4 on success and 0 on failure.
 */
uint32_t ssh_buffer_get_u32(struct ssh_buffer_struct *buffer, uint32_t *data)
{
    return ssh_buffer_get_data(buffer, data, sizeof(*data));
}
/*
 * Read eight bytes from the buffer into *data, without byte-order conversion.
 *
 * Returns 8 on success and 0 on failure.
 */
uint32_t ssh_buffer_get_u64(struct ssh_buffer_struct *buffer, uint64_t *data)
{
    return ssh_buffer_get_data(buffer, data, sizeof(*data));
}
/*
 * Check that `len` bytes can be read from the current read position.
 *
 * Returns SSH_OK if they can; SSH_ERROR if buffer is NULL, if pos + len
 * wraps around, or if pos + len reaches past the used part of the buffer.
 */
int ssh_buffer_validate_length(struct ssh_buffer_struct *buffer, size_t len)
{
    if (buffer == NULL) {
        return SSH_ERROR;
    }

    if (buffer->pos + len < len) {
        /* The addition wrapped around. */
        return SSH_ERROR;
    }

    if (buffer->pos + len > buffer->used) {
        return SSH_ERROR;
    }

    return SSH_OK;
}
/*
 * Read a length-prefixed string from the buffer.
 *
 * The 32-bit length is read first and converted with ntohl(); that many
 * bytes are then copied into a newly created ssh string.
 *
 * Returns the new string, which the caller owns, or NULL if the length
 * cannot be read, exceeds the unread data, or the allocation fails.
 */
struct ssh_string_struct *
ssh_buffer_get_ssh_string(struct ssh_buffer_struct *buffer)
{
    uint32_t wire_len;
    uint32_t payload_len;
    uint32_t copied;
    struct ssh_string_struct *result = NULL;

    if (ssh_buffer_get_u32(buffer, &wire_len) == 0) {
        return NULL;
    }
    payload_len = ntohl(wire_len);

    /* Check the claimed length against the unread data before allocating. */
    if (ssh_buffer_validate_length(buffer, payload_len) != SSH_OK) {
        return NULL;
    }

    result = ssh_string_new(payload_len);
    if (result == NULL) {
        return NULL;
    }

    copied = ssh_buffer_get_data(buffer, ssh_string_data(result), payload_len);
    if (copied != payload_len) {
        SAFE_FREE(result);
        return NULL;
    }

    return result;
}
/*
 * Walk the pack arguments once to estimate the space they need and
 * pre-size the buffer accordingly. No data is written to the buffer.
 *
 * Format characters and the arguments they consume:
 *   'b' unsigned int         - 1 byte
 *   'w' unsigned int         - 2 bytes
 *   'd' uint32_t             - 4 bytes
 *   'q' uint64_t             - 8 bytes
 *   'S' ssh_string           - 4 bytes + ssh_string_len()
 *   's' char *               - 4 bytes + strlen()
 *   'P' size_t len, void *   - len bytes (two arguments)
 *   'B' bignum               - fixed estimate of 64 bytes
 *   't' char *               - strlen() bytes
 *
 * `argc` is the number of variadic arguments, not counting the trailing
 * SSH_BUFFER_PACK_END canary. The process is aborted if the canary is
 * missing after a fully parsed argument list.
 *
 * Returns SSH_OK on success. Returns SSH_ERROR for an unknown format
 * character, an argument-count mismatch, an estimate that does not fit in
 * 32 bits, or an allocation failure. An unknown format character is always
 * reported; it is never masked by a subsequent successful allocation.
 */
static int ssh_buffer_pack_allocate_va(struct ssh_buffer_struct *buffer,
                                       const char *format,
                                       size_t argc,
                                       va_list ap)
{
    const char *p = NULL;
    ssh_string string = NULL;
    char *cstring = NULL;
    size_t needed_size = 0;
    size_t len;
    size_t count;
    uint32_t canary;
    int rc;

    for (p = format, count = 0; *p != '\0'; p++, count++) {
        /* More format characters than arguments were supplied. */
        if (count > argc) {
            return SSH_ERROR;
        }

        switch (*p) {
        case 'b':
            /* Variadic arguments narrower than int arrive promoted. */
            (void)va_arg(ap, unsigned int);
            needed_size += sizeof(uint8_t);
            break;
        case 'w':
            (void)va_arg(ap, unsigned int);
            needed_size += sizeof(uint16_t);
            break;
        case 'd':
            (void)va_arg(ap, uint32_t);
            needed_size += sizeof(uint32_t);
            break;
        case 'q':
            (void)va_arg(ap, uint64_t);
            needed_size += sizeof(uint64_t);
            break;
        case 'S':
            string = va_arg(ap, ssh_string);
            needed_size += 4 + ssh_string_len(string);
            string = NULL;
            break;
        case 's':
            cstring = va_arg(ap, char *);
            needed_size += sizeof(uint32_t) + strlen(cstring);
            cstring = NULL;
            break;
        case 'P':
            len = va_arg(ap, size_t);
            needed_size += len;
            (void)va_arg(ap, void *);
            /* 'P' consumes two arguments. */
            count++;
            break;
        case 'B':
            (void)va_arg(ap, bignum);
            /* The encoded size is not computed here; use a fixed estimate. */
            needed_size += 64;
            break;
        case 't':
            cstring = va_arg(ap, char *);
            needed_size += strlen(cstring);
            cstring = NULL;
            break;
        default:
            SSH_LOG(SSH_LOG_TRACE, "Invalid buffer format %c", *p);
            /* Fail now so the error cannot be overwritten further down. */
            return SSH_ERROR;
        }
    }

    if (argc != count) {
        return SSH_ERROR;
    }

    /* All arguments were consumed: the next one must be the end marker. */
    canary = va_arg(ap, uint32_t);
    if (canary != SSH_BUFFER_PACK_END) {
        abort();
    }

    /* Refuse sizes that would be silently truncated by the cast below. */
    if ((uint64_t)needed_size > (uint64_t)UINT32_MAX) {
        return SSH_ERROR;
    }

    rc = ssh_buffer_allocate_size(buffer, (uint32_t)needed_size);
    if (rc != 0) {
        return SSH_ERROR;
    }

    return SSH_OK;
}
/*
 * Append the variadic arguments in `ap` to the buffer as described by
 * `format`.
 *
 * Format characters and the arguments they consume:
 *   'b' unsigned int       - stored as one byte
 *   'w' unsigned int       - stored as 16 bits after htons()
 *   'd' uint32_t           - stored as 32 bits after htonl()
 *   'q' uint64_t           - stored as 64 bits after htonll()
 *   'S' ssh_string         - added with ssh_buffer_add_ssh_string()
 *   's' char *             - htonl(strlen) as 32 bits, then the bytes
 *                            without the terminating NUL
 *   'P' size_t len, void * - len raw bytes (two arguments)
 *   'B' bignum             - converted with ssh_make_bignum_string(), added
 *                            as an ssh string, temporary string freed
 *   't' char *             - the string's bytes only: no length prefix,
 *                            no terminating NUL
 *
 * `argc` is the number of variadic arguments, not counting the trailing
 * SSH_BUFFER_PACK_END canary; more than 256 is rejected.
 *
 * Returns the status of the last processed item, or SSH_ERROR on an
 * argument-count mismatch. Items appended before a failure stay in the
 * buffer. Aborts the process if the canary does not follow the arguments.
 *
 * NOTE(review): rc starts as SSH_ERROR and is only assigned inside the
 * loop, so an empty format returns SSH_ERROR - confirm that is intended.
 * NOTE(review): the rc comparisons assume the ssh_buffer_add_*() return
 * values 0 / -1 coincide with SSH_OK / SSH_ERROR - confirm in the headers.
 */
static int
ssh_buffer_pack_va(struct ssh_buffer_struct *buffer,
const char *format,
size_t argc,
va_list ap)
{
int rc = SSH_ERROR;
const char *p;
/* Scratch storage for the argument currently being processed. */
union {
uint8_t byte;
uint16_t word;
uint32_t dword;
uint64_t qword;
ssh_string string;
void *data;
} o;
char *cstring;
bignum b;
size_t len;
size_t count;
if (argc > 256) {
return SSH_ERROR;
}
for (p = format, count = 0; *p != '\0'; p++, count++) {
/* More format characters than arguments were supplied. */
if (count > argc) {
return SSH_ERROR;
}
switch(*p) {
case 'b':
/* Variadic arguments narrower than int arrive promoted. */
o.byte = (uint8_t)va_arg(ap, unsigned int);
rc = ssh_buffer_add_u8(buffer, o.byte);
break;
case 'w':
o.word = (uint16_t)va_arg(ap, unsigned int);
o.word = htons(o.word);
rc = ssh_buffer_add_u16(buffer, o.word);
break;
case 'd':
o.dword = va_arg(ap, uint32_t);
o.dword = htonl(o.dword);
rc = ssh_buffer_add_u32(buffer, o.dword);
break;
case 'q':
o.qword = va_arg(ap, uint64_t);
o.qword = htonll(o.qword);
rc = ssh_buffer_add_u64(buffer, o.qword);
break;
case 'S':
o.string = va_arg(ap, ssh_string);
rc = ssh_buffer_add_ssh_string(buffer, o.string);
o.string = NULL;
break;
case 's':
cstring = va_arg(ap, char *);
len = strlen(cstring);
/* Length prefix first, then the bytes without the NUL. */
rc = ssh_buffer_add_u32(buffer, htonl(len));
if (rc == SSH_OK){
rc = ssh_buffer_add_data(buffer, cstring, len);
}
cstring = NULL;
break;
case 'P':
len = va_arg(ap, size_t);
o.data = va_arg(ap, void *);
/* 'P' consumes two arguments. */
count++;
rc = ssh_buffer_add_data(buffer, o.data, len);
o.data = NULL;
break;
case 'B':
b = va_arg(ap, bignum);
o.string = ssh_make_bignum_string(b);
if(o.string == NULL){
rc = SSH_ERROR;
break;
}
rc = ssh_buffer_add_ssh_string(buffer, o.string);
/* The temporary string was copied into the buffer; release it. */
SAFE_FREE(o.string);
break;
case 't':
cstring = va_arg(ap, char *);
len = strlen(cstring);
rc = ssh_buffer_add_data(buffer, cstring, len);
cstring = NULL;
break;
default:
SSH_LOG(SSH_LOG_TRACE, "Invalid buffer format %c", *p);
rc = SSH_ERROR;
}
/* Stop at the first failing item. */
if (rc != SSH_OK){
break;
}
}
if (argc != count) {
return SSH_ERROR;
}
/* The canary is only read when the argument list was consumed
 * successfully, i.e. when ap is positioned right before it. */
if (rc != SSH_ERROR){
uint32_t canary = va_arg(ap, uint32_t);
if (canary != SSH_BUFFER_PACK_END) {
abort();
}
}
return rc;
}
/*
 * Variadic entry point for packing values into a buffer.
 *
 * The argument list is traversed twice: first by
 * ssh_buffer_pack_allocate_va() to pre-size the buffer, then by
 * ssh_buffer_pack_va() to append the data. `argc` is the number of
 * variadic arguments; more than 256 is rejected.
 *
 * Returns SSH_ERROR if argc is too large, otherwise the status of the
 * first pass if it failed, otherwise the status of the second pass.
 */
int _ssh_buffer_pack(struct ssh_buffer_struct *buffer,
                     const char *format,
                     size_t argc,
                     ...)
{
    va_list args;
    int status;

    if (argc > 256) {
        return SSH_ERROR;
    }

    /* First pass: size estimation and pre-allocation. */
    va_start(args, argc);
    status = ssh_buffer_pack_allocate_va(buffer, format, argc, args);
    va_end(args);

    if (status == SSH_OK) {
        /* Second pass: restart the argument list and write the data. */
        va_start(args, argc);
        status = ssh_buffer_pack_va(buffer, format, argc, args);
        va_end(args);
    }

    return status;
}
/*
 * Read values from the buffer as described by `format`, storing each one
 * through the pointer supplied in `ap`.
 *
 * Format characters and the arguments they consume:
 *   'b' uint8_t *           - one byte
 *   'w' uint16_t *          - 16 bits, converted with ntohs()
 *   'd' uint32_t *          - 32 bits, converted with ntohl()
 *   'q' uint64_t *          - 64 bits, converted with ntohll()
 *   'B' bignum *            - length-prefixed string converted with
 *                             ssh_make_string_bn()
 *   'S' ssh_string *        - result of ssh_buffer_get_ssh_string()
 *   's' char **             - 32-bit length, then that many bytes copied into
 *                             a malloc()ed, NUL-terminated string
 *   'P' size_t len, void ** - len raw bytes copied into a malloc()ed block
 *                             (two arguments)
 *
 * `argc` is the number of variadic arguments, not counting the trailing
 * SSH_BUFFER_PACK_END canary; more than 256 is rejected.
 *
 * Returns SSH_OK on success. On failure it returns SSH_ERROR after rolling
 * back the outputs of the items that preceded the failing one: allocated
 * outputs are freed, and when the buffer is secure, outputs are wiped first.
 * The buffer's read position is not rewound. Aborts the process if the
 * canary does not follow a successfully consumed argument list.
 */
int ssh_buffer_unpack_va(struct ssh_buffer_struct *buffer,
const char *format,
size_t argc,
va_list ap)
{
int rc = SSH_ERROR;
const char *p = format, *last;
/* Typed view of the output pointer currently being processed. */
union {
uint8_t *byte;
uint16_t *word;
uint32_t *dword;
uint64_t *qword;
ssh_string *string;
char **cstring;
bignum *bignum;
void **data;
} o;
size_t len;
uint32_t rlen, max_len;
ssh_string tmp_string = NULL;
va_list ap_copy;
size_t count;
/* Amount of unread data at entry; used as a bound on 's' and 'P' lengths. */
max_len = ssh_buffer_get_len(buffer);
/* Second cursor over the same arguments, used by the rollback below. */
va_copy(ap_copy, ap);
if (argc > 256) {
rc = SSH_ERROR;
goto cleanup;
}
for (count = 0; *p != '\0'; p++, count++) {
/* More format characters than arguments were supplied. */
if (count > argc) {
rc = SSH_ERROR;
goto cleanup;
}
/* Each case must set rc to SSH_OK explicitly on success. */
rc = SSH_ERROR;
switch (*p) {
case 'b':
o.byte = va_arg(ap, uint8_t *);
rlen = ssh_buffer_get_u8(buffer, o.byte);
rc = rlen==1 ? SSH_OK : SSH_ERROR;
break;
case 'w':
o.word = va_arg(ap, uint16_t *);
rlen = ssh_buffer_get_data(buffer, o.word, sizeof(uint16_t));
if (rlen == 2) {
*o.word = ntohs(*o.word);
rc = SSH_OK;
}
break;
case 'd':
o.dword = va_arg(ap, uint32_t *);
rlen = ssh_buffer_get_u32(buffer, o.dword);
if (rlen == 4) {
*o.dword = ntohl(*o.dword);
rc = SSH_OK;
}
break;
case 'q':
o.qword = va_arg(ap, uint64_t*);
rlen = ssh_buffer_get_u64(buffer, o.qword);
if (rlen == 8) {
*o.qword = ntohll(*o.qword);
rc = SSH_OK;
}
break;
case 'B':
o.bignum = va_arg(ap, bignum *);
*o.bignum = NULL;
tmp_string = ssh_buffer_get_ssh_string(buffer);
if (tmp_string == NULL) {
break;
}
*o.bignum = ssh_make_string_bn(tmp_string);
/* The intermediate string is burned before being freed,
 * whether or not the buffer is secure. */
ssh_string_burn(tmp_string);
SSH_STRING_FREE(tmp_string);
rc = (*o.bignum != NULL) ? SSH_OK : SSH_ERROR;
break;
case 'S':
o.string = va_arg(ap, ssh_string *);
*o.string = ssh_buffer_get_ssh_string(buffer);
rc = *o.string != NULL ? SSH_OK : SSH_ERROR;
o.string = NULL;
break;
case 's': {
uint32_t u32len = 0;
o.cstring = va_arg(ap, char **);
*o.cstring = NULL;
rlen = ssh_buffer_get_u32(buffer, &u32len);
if (rlen != 4){
break;
}
len = ntohl(u32len);
/* Reject lengths that cannot fit in the data present at entry. */
if (len > max_len - 1) {
break;
}
rc = ssh_buffer_validate_length(buffer, len);
if (rc != SSH_OK) {
break;
}
/* One extra byte for the terminating NUL. */
*o.cstring = malloc(len + 1);
if (*o.cstring == NULL){
rc = SSH_ERROR;
break;
}
rlen = ssh_buffer_get_data(buffer, *o.cstring, len);
if (rlen != len){
SAFE_FREE(*o.cstring);
rc = SSH_ERROR;
break;
}
(*o.cstring)[len] = '\0';
o.cstring = NULL;
rc = SSH_OK;
break;
}
case 'P':
len = va_arg(ap, size_t);
/* NOTE(review): when max_len is 0, max_len - 1 wraps to UINT32_MAX;
 * the ssh_buffer_validate_length() call below still bounds len. */
if (len > max_len - 1) {
rc = SSH_ERROR;
break;
}
rc = ssh_buffer_validate_length(buffer, len);
if (rc != SSH_OK) {
break;
}
o.data = va_arg(ap, void **);
/* 'P' consumes two arguments. */
count++;
*o.data = malloc(len);
if(*o.data == NULL){
rc = SSH_ERROR;
break;
}
rlen = ssh_buffer_get_data(buffer, *o.data, len);
if (rlen != len){
SAFE_FREE(*o.data);
rc = SSH_ERROR;
break;
}
o.data = NULL;
rc = SSH_OK;
break;
default:
/* rc is still SSH_ERROR here, so the loop stops below. */
SSH_LOG(SSH_LOG_TRACE, "Invalid buffer format %c", *p);
}
if (rc != SSH_OK) {
break;
}
}
if (argc != count) {
rc = SSH_ERROR;
}
cleanup:
/* The canary is only read when every argument was consumed
 * successfully, i.e. when ap is positioned right before it. */
if (rc != SSH_ERROR){
uint32_t canary = va_arg(ap, uint32_t);
if (canary != SSH_BUFFER_PACK_END){
abort();
}
}
if (rc != SSH_OK){
/* Roll back: p stopped at the failing format character, so replay the
 * characters before it against ap_copy and undo their outputs. */
last = p;
for(p=format;p<last;++p){
switch(*p){
case 'b':
o.byte = va_arg(ap_copy, uint8_t *);
if (buffer->secure) {
explicit_bzero(o.byte, sizeof(uint8_t));
break;
}
break;
case 'w':
o.word = va_arg(ap_copy, uint16_t *);
if (buffer->secure) {
explicit_bzero(o.word, sizeof(uint16_t));
break;
}
break;
case 'd':
o.dword = va_arg(ap_copy, uint32_t *);
if (buffer->secure) {
explicit_bzero(o.dword, sizeof(uint32_t));
break;
}
break;
case 'q':
o.qword = va_arg(ap_copy, uint64_t *);
if (buffer->secure) {
explicit_bzero(o.qword, sizeof(uint64_t));
break;
}
break;
case 'B':
o.bignum = va_arg(ap_copy, bignum *);
/* Released with bignum_safe_free() whether or not the buffer is secure. */
bignum_safe_free(*o.bignum);
break;
case 'S':
o.string = va_arg(ap_copy, ssh_string *);
if (buffer->secure) {
ssh_string_burn(*o.string);
}
SAFE_FREE(*o.string);
break;
case 's':
o.cstring = va_arg(ap_copy, char **);
if (buffer->secure) {
explicit_bzero(*o.cstring, strlen(*o.cstring));
}
SAFE_FREE(*o.cstring);
break;
case 'P':
len = va_arg(ap_copy, size_t);
o.data = va_arg(ap_copy, void **);
if (buffer->secure) {
explicit_bzero(*o.data, len);
}
SAFE_FREE(*o.data);
break;
default:
/* Unknown character: skip one pointer-sized argument. */
(void)va_arg(ap_copy, void *);
break;
}
}
}
va_end(ap_copy);
return rc;
}
/*
 * Variadic entry point for unpacking values from a buffer.
 *
 * Collects the variadic arguments and forwards them to
 * ssh_buffer_unpack_va(), whose status is returned unchanged.
 */
int _ssh_buffer_unpack(struct ssh_buffer_struct *buffer,
                       const char *format,
                       size_t argc,
                       ...)
{
    va_list args;
    int status;

    va_start(args, argc);
    status = ssh_buffer_unpack_va(buffer, format, argc, args);
    va_end(args);

    return status;
}