#include "config.h"
#include <errno.h>
#include <stdlib.h>
#include "libssh/priv.h"
#include "libssh/libssh.h"
#include "libssh/poll.h"
#include "libssh/socket.h"
#include "libssh/session.h"
#include "libssh/misc.h"
#ifdef WITH_SERVER
#include "libssh/server.h"
#endif
/*
 * Default grow/shrink step (in slots) used by ssh_poll_ctx_new() when it is
 * called with chunk_size == 0. May be overridden at build time.
 */
#ifndef SSH_POLL_CTX_CHUNK
#define SSH_POLL_CTX_CHUNK 5
#endif
/* One pollable descriptor together with the callback run when it fires. */
struct ssh_poll_handle_struct {
/* Context this handle is attached to, or NULL while detached. */
ssh_poll_ctx ctx;
/* Set while the handle has been moved into an ssh_event by
 * ssh_event_add_session(); used to hand it back to the session later. */
ssh_session session;
/* x.fd holds the descriptor while ctx == NULL; once attached the fd lives
 * in ctx->pollfds and x.idx holds the handle's slot in that array. */
union {
socket_t fd;
size_t idx;
} x;
/* Requested events; the copy in ctx->pollfds may be masked while locked. */
short events;
/* Callback nesting depth, incremented around the callback by
 * ssh_poll_ctx_dopoll(). */
uint32_t lock_cnt;
ssh_poll_callback cb;
void *cb_data;
};
/* A set of poll handles polled together by ssh_poll_ctx_dopoll(). */
struct ssh_poll_ctx_struct {
/* pollptrs[i] is the handle that owns pollfds[i], for i < polls_used. */
ssh_poll_handle *pollptrs;
/* Array passed directly to ssh_poll(). */
ssh_pollfd_t *pollfds;
/* Capacity of both arrays, in slots. */
size_t polls_allocated;
/* Number of occupied slots; slots [0, polls_used) are always dense. */
size_t polls_used;
/* Step by which the arrays grow and shrink. */
size_t chunk_size;
};
#ifdef HAVE_POLL
#include <poll.h>
/* The native poll() backend needs no global setup. */
void ssh_poll_init(void)
{
    /* nothing to initialise */
}

/* The native poll() backend holds no global state to release. */
void ssh_poll_cleanup(void)
{
    /* nothing to release */
}

/*
 * Thin wrapper over poll(2): ssh_pollfd_t is passed through unchanged as
 * struct pollfd. Returns whatever poll() returns (errno set by poll()).
 */
int ssh_poll(ssh_pollfd_t *fds, nfds_t nfds, int timeout)
{
    struct pollfd *native_fds = (struct pollfd *) fds;

    return poll(native_fds, nfds, timeout);
}
#else
/* Signature of a poll() implementation (matches ssh_poll()). */
typedef int (*poll_fn)(ssh_pollfd_t *, nfds_t, int);
/* Poll backend used by ssh_poll() when HAVE_POLL is not defined. It is
 * assigned by ssh_poll_init() and is NULL (static storage) until then. */
static poll_fn ssh_poll_emu;
#include <sys/types.h>
#include <stdbool.h>
#ifdef _WIN32
#ifndef STRICT
#define STRICT
#endif
#include <time.h>
#include <windows.h>
#include <winsock2.h>
#else
#include <sys/select.h>
#include <sys/socket.h>
# ifdef HAVE_SYS_TIME_H
# include <sys/time.h>
# endif
#endif
#ifdef HAVE_UNISTD_H
#include <unistd.h>
#endif
/*
 * Returns true when sock_err is the platform's "socket is not connected"
 * error (WSAENOTCONN on Windows, ENOTCONN elsewhere), false otherwise.
 */
static bool bsd_socket_not_connected(int sock_err)
{
#ifdef _WIN32
    return sock_err == WSAENOTCONN;
#else
    return sock_err == ENOTCONN;
#endif
}
/*
 * Returns true when sock_err is one of the errors treated as a connection
 * hang-up:
 *  - POSIX: aborted, reset, network reset, shut down;
 *  - Windows: the same four, plus refused and timed out.
 */
static bool bsd_socket_reset(int sock_err)
{
#ifdef _WIN32
    return sock_err == WSAECONNABORTED ||
           sock_err == WSAECONNRESET ||
           sock_err == WSAENETRESET ||
           sock_err == WSAESHUTDOWN ||
           sock_err == WSAECONNREFUSED ||
           sock_err == WSAETIMEDOUT;
#else
    return sock_err == ECONNABORTED ||
           sock_err == ECONNRESET ||
           sock_err == ENETRESET ||
           sock_err == ESHUTDOWN;
#endif
}
/*
 * Derive poll()-style revents for a descriptor that select() reported as
 * readable. It peeks at the socket with MSG_PEEK, so no data is consumed.
 *
 * Returns:
 *  - (POLLIN | POLLRDNORM) & events when data is pending or the socket
 *    reports "not connected";
 *  - POLLHUP on orderly shutdown (recv() == 0) or a reset-class error.
 *    errno is then set to the socket error, which is 0 for a clean EOF;
 *  - POLLERR for any other failure.
 * On the other paths the caller's errno is preserved.
 */
static short bsd_socket_compute_revents(int fd, short events)
{
    int save_errno = errno;
    int sock_errno = 0;
    char data[64] = {0};
    short revents = 0;
    int flags = MSG_PEEK;
    int ret;

#ifdef MSG_NOSIGNAL
    flags |= MSG_NOSIGNAL;
#endif

    /* Clear the error slot so a successful recv() leaves sock_errno at 0
     * rather than some unrelated earlier error. */
#ifdef _WIN32
    WSASetLastError(0);
#else
    errno = 0;
#endif

    ret = recv(fd, data, 64, flags);

    /* Capture the error produced by recv() itself *before* restoring the
     * caller's errno. Previously, on POSIX, sock_errno held the errno from
     * before the call, so reset/not-connected errors were never seen. */
#ifdef _WIN32
    sock_errno = WSAGetLastError();
    WSASetLastError(0);
#else
    sock_errno = errno;
#endif
    errno = save_errno;

    if (ret > 0 || bsd_socket_not_connected(sock_errno)) {
        revents = (POLLIN | POLLRDNORM) & events;
    } else if (ret == 0 || bsd_socket_reset(sock_errno)) {
        errno = sock_errno;
        revents = POLLHUP;
    } else {
        revents = POLLERR;
    }

    return revents;
}
/*
 * select()-based emulation of poll(2), installed when HAVE_POLL is not
 * defined.
 *
 * Every usable descriptor is placed in the read set, so hang-ups and errors
 * are reported (via bsd_socket_compute_revents()) even when POLLIN was not
 * requested. The write and exception sets receive a descriptor only when
 * the corresponding events are requested. Entries whose fd is
 * SSH_INVALID_SOCKET are ignored on input.
 *
 * Returns:
 *  - the number of entries with non-zero revents;
 *  - 0 on timeout;
 *  - -1 with errno set: EFAULT for a NULL array; EINVAL when no usable
 *    descriptor is given or, on POSIX, a descriptor lies outside
 *    [0, FD_SETSIZE); otherwise errno as set by select().
 */
static int bsd_poll(ssh_pollfd_t *fds, nfds_t nfds, int timeout)
{
    fd_set readfds, writefds, exceptfds;
    struct timeval tv, *ptv = NULL;
    socket_t max_fd = 0;
    bool have_fd = false;
    int rc;
    nfds_t i;

    if (fds == NULL) {
        errno = EFAULT;
        return -1;
    }

    ZERO_STRUCT(readfds);
    FD_ZERO(&readfds);
    ZERO_STRUCT(writefds);
    FD_ZERO(&writefds);
    ZERO_STRUCT(exceptfds);
    FD_ZERO(&exceptfds);

    /* Build the fd_sets and find the largest descriptor. */
    for (i = 0; i < nfds; i++) {
        if (fds[i].fd == SSH_INVALID_SOCKET) {
            continue;
        }
#ifndef _WIN32
        /* FD_SET() on a descriptor outside [0, FD_SETSIZE) is undefined. */
        if (fds[i].fd < 0 || fds[i].fd >= FD_SETSIZE) {
            errno = EINVAL;
            return -1;
        }
#endif

        FD_SET(fds[i].fd, &readfds);

        if (fds[i].events & (POLLOUT | POLLWRNORM | POLLWRBAND)) {
            FD_SET(fds[i].fd, &writefds);
        }
        if (fds[i].events & (POLLPRI | POLLRDBAND)) {
            FD_SET(fds[i].fd, &exceptfds);
        }

        /* Track "saw a descriptor" separately from max_fd. Descriptor 0 is
         * valid; the old "fd > max_fd (starting at 0)" test rejected a lone
         * fd 0 with EINVAL. */
        if (!have_fd || fds[i].fd > max_fd) {
            max_fd = fds[i].fd;
            have_fd = true;
        }
    }

    if (!have_fd) {
        errno = EINVAL;
        return -1;
    }

    /* A negative timeout blocks indefinitely (NULL timeval). */
    if (timeout >= 0) {
        tv.tv_sec = timeout / 1000;
        tv.tv_usec = (timeout % 1000) * 1000;
        ptv = &tv;
    }

    rc = select(max_fd + 1, &readfds, &writefds, &exceptfds, ptv);
    if (rc < 0) {
        return -1;
    }
    if (rc == 0) {
        return 0;
    }

    /* Translate the select() result back into per-entry revents. */
    for (rc = 0, i = 0; i < nfds; i++) {
        if (fds[i].fd >= 0) {
            fds[i].revents = 0;

            if (FD_ISSET(fds[i].fd, &readfds)) {
                fds[i].revents = bsd_socket_compute_revents(fds[i].fd,
                                                            fds[i].events);
            }
            if (FD_ISSET(fds[i].fd, &writefds)) {
                fds[i].revents |= fds[i].events & (POLLOUT | POLLWRNORM | POLLWRBAND);
            }
            if (FD_ISSET(fds[i].fd, &exceptfds)) {
                fds[i].revents |= fds[i].events & (POLLPRI | POLLRDBAND);
            }

            if (fds[i].revents != 0) {
                rc++;
            }
        } else {
            fds[i].revents = POLLNVAL;
        }
    }

    return rc;
}
/* Select the select()-based emulation as the active poll backend. */
void ssh_poll_init(void)
{
    ssh_poll_emu = bsd_poll;
}

/* Leave the emulation installed; there is nothing else to release. */
void ssh_poll_cleanup(void)
{
    ssh_poll_emu = bsd_poll;
}

/* Dispatch to the backend installed by ssh_poll_init(). */
int ssh_poll(ssh_pollfd_t *fds, nfds_t nfds, int timeout)
{
    poll_fn backend = ssh_poll_emu;

    return backend(fds, nfds, timeout);
}
#endif
/*
 * Allocate a detached poll handle for fd, waiting for events.
 *
 * The handle starts zeroed: no context, no session, lock count 0. When one
 * of the events fires, cb is invoked with userdata.
 *
 * Returns NULL if allocation fails.
 */
ssh_poll_handle
ssh_poll_new(socket_t fd, short events, ssh_poll_callback cb, void *userdata)
{
    ssh_poll_handle handle = malloc(sizeof(struct ssh_poll_handle_struct));

    if (handle != NULL) {
        ZERO_STRUCTP(handle);
        handle->x.fd = fd;
        handle->events = events;
        handle->cb = cb;
        handle->cb_data = userdata;
    }

    return handle;
}
/* Detach p from its context (if it has one), then release it. */
void ssh_poll_free(ssh_poll_handle p)
{
    ssh_poll_ctx owner = p->ctx;

    if (owner != NULL) {
        ssh_poll_ctx_remove(owner, p);
        p->ctx = NULL;
    }

    SAFE_FREE(p);
}
/* Context the handle is attached to, or NULL while detached. */
ssh_poll_ctx ssh_poll_get_ctx(ssh_poll_handle p)
{
    ssh_poll_ctx owner = p->ctx;

    return owner;
}

/* Events requested for the handle (not the possibly masked ctx copy). */
short ssh_poll_get_events(ssh_poll_handle p)
{
    short requested = p->events;

    return requested;
}
/*
 * Record events as the handle's requested events and, if the handle is
 * attached, update its pollfd entry.
 *
 * While the handle is locked (lock_cnt != 0) the entry is changed only if
 * it is not already waiting for POLLOUT, and then only POLLOUT may be
 * armed. This is presumably to avoid re-entering a running callback.
 */
void ssh_poll_set_events(ssh_poll_handle p, short events)
{
    ssh_pollfd_t *entry;

    p->events = events;
    if (p->ctx == NULL) {
        return;
    }

    entry = &p->ctx->pollfds[p->x.idx];
    if (p->lock_cnt == 0) {
        entry->events = events;
    } else if ((entry->events & POLLOUT) == 0) {
        entry->events = events & POLLOUT;
    }
}
/* Replace the handle's descriptor wherever it is currently stored. */
void ssh_poll_set_fd(ssh_poll_handle p, socket_t fd)
{
    if (p->ctx == NULL) {
        /* Detached: the fd lives in the handle itself. */
        p->x.fd = fd;
        return;
    }

    /* Attached: the fd lives in the context's pollfd slot. */
    p->ctx->pollfds[p->x.idx].fd = fd;
}
/* OR events into the handle's requested events. */
void ssh_poll_add_events(ssh_poll_handle p, short events)
{
    short merged = ssh_poll_get_events(p) | events;

    ssh_poll_set_events(p, merged);
}

/* Clear events from the handle's requested events. */
void ssh_poll_remove_events(ssh_poll_handle p, short events)
{
    short remaining = ssh_poll_get_events(p) & ~events;

    ssh_poll_set_events(p, remaining);
}
/* Descriptor of the handle: from its ctx slot when attached, else stored inline. */
socket_t ssh_poll_get_fd(ssh_poll_handle p)
{
    return (p->ctx != NULL) ? p->ctx->pollfds[p->x.idx].fd : p->x.fd;
}
/*
 * Install cb/userdata on the handle. A NULL cb is ignored, leaving the
 * current callback and its data unchanged.
 */
void ssh_poll_set_callback(ssh_poll_handle p, ssh_poll_callback cb, void *userdata)
{
    if (cb == NULL) {
        return;
    }

    p->cb = cb;
    p->cb_data = userdata;
}
/*
 * Create an empty poll context whose arrays grow and shrink in steps of
 * chunk_size slots. A chunk_size of 0 selects SSH_POLL_CTX_CHUNK.
 * Returns NULL if allocation fails.
 */
ssh_poll_ctx ssh_poll_ctx_new(size_t chunk_size)
{
    ssh_poll_ctx ctx = malloc(sizeof(struct ssh_poll_ctx_struct));

    if (ctx != NULL) {
        ZERO_STRUCTP(ctx);
        ctx->chunk_size = (chunk_size != 0) ? chunk_size : SSH_POLL_CTX_CHUNK;
    }

    return ctx;
}
/*
 * Free a poll context together with every handle still attached to it.
 * Passing NULL is a no-op, mirroring free() and ssh_event_free().
 *
 * Remaining handles are released with ssh_poll_free(), so anything that
 * must keep its handle has to detach it before calling this.
 */
void ssh_poll_ctx_free(ssh_poll_ctx ctx)
{
    if (ctx == NULL) {
        return;
    }

    if (ctx->polls_allocated > 0) {
        /* ssh_poll_free() calls ssh_poll_ctx_remove(), which decrements
         * polls_used and moves the last handle into the freed slot. */
        while (ctx->polls_used > 0) {
            ssh_poll_free(ctx->pollptrs[0]);
        }

        SAFE_FREE(ctx->pollptrs);
        SAFE_FREE(ctx->pollfds);
    }

    SAFE_FREE(ctx);
}
/*
 * Resize both slot arrays of ctx to new_size entries.
 *
 * Returns 0 on success. Returns -1 if either realloc() fails; polls_allocated
 * then still describes a capacity that both arrays actually have.
 */
static int ssh_poll_ctx_resize(ssh_poll_ctx ctx, size_t new_size)
{
    ssh_poll_handle *pollptrs = NULL;
    ssh_pollfd_t *pollfds = NULL;

    pollptrs = realloc(ctx->pollptrs, sizeof(ssh_poll_handle) * new_size);
    if (pollptrs == NULL) {
        return -1;
    }
    ctx->pollptrs = pollptrs;

    pollfds = realloc(ctx->pollfds, sizeof(ssh_pollfd_t) * new_size);
    if (pollfds == NULL) {
        /*
         * pollptrs now holds new_size entries and pollfds still holds the
         * old count; both buffers stay valid. Record the smaller one as the
         * usable capacity instead of reallocating pollptrs back. The old
         * rollback did realloc(ptr, 0) when polls_allocated was 0, which may
         * free the buffer and leave ctx->pollptrs dangling.
         */
        if (new_size < ctx->polls_allocated) {
            ctx->polls_allocated = new_size;
        }
        return -1;
    }
    ctx->pollfds = pollfds;
    ctx->polls_allocated = new_size;

    return 0;
}
/*
 * Attach handle p to ctx, growing the arrays by chunk_size if they are full.
 *
 * Returns 0 on success. Returns -1 if p is already attached to a context or
 * the arrays cannot grow.
 */
int ssh_poll_ctx_add(ssh_poll_ctx ctx, ssh_poll_handle p)
{
    socket_t fd;
    size_t slot;

    if (p->ctx != NULL) {
        return -1;
    }

    if (ctx->polls_used == ctx->polls_allocated) {
        if (ssh_poll_ctx_resize(ctx, ctx->polls_allocated + ctx->chunk_size) < 0) {
            return -1;
        }
    }

    /* x is a union: read the fd before it is overwritten by the slot index. */
    fd = p->x.fd;
    slot = ctx->polls_used;
    ctx->polls_used = slot + 1;
    p->x.idx = slot;

    ctx->pollptrs[slot] = p;
    ctx->pollfds[slot].fd = fd;
    ctx->pollfds[slot].events = p->events;
    ctx->pollfds[slot].revents = 0;
    p->ctx = ctx;

    return 0;
}
/*
 * Attach the poll handle owned by socket s to ctx.
 * Returns -1 if the socket has no handle; otherwise the result of
 * ssh_poll_ctx_add().
 */
int ssh_poll_ctx_add_socket(ssh_poll_ctx ctx, ssh_socket s)
{
    ssh_poll_handle handle = ssh_socket_get_poll_handle(s);

    return (handle != NULL) ? ssh_poll_ctx_add(ctx, handle) : -1;
}
/*
 * Detach p from ctx.
 *
 * The fd is copied back into the handle, and the last occupied slot is moved
 * into the vacated one to keep the arrays dense. The arrays are shrunk by
 * chunk_size when more than chunk_size slots would sit unused.
 */
void ssh_poll_ctx_remove(ssh_poll_ctx ctx, ssh_poll_handle p)
{
    size_t slot = p->x.idx;
    size_t last;

    /* x is a union: the slot index was saved above before storing the fd. */
    p->x.fd = ctx->pollfds[slot].fd;
    p->ctx = NULL;

    last = --ctx->polls_used;

    if (last > 0 && last != slot) {
        ctx->pollfds[slot] = ctx->pollfds[last];
        ctx->pollptrs[slot] = ctx->pollptrs[last];
        ctx->pollptrs[slot]->x.idx = slot;
    }

    if (ctx->polls_allocated - last > ctx->chunk_size) {
        (void)ssh_poll_ctx_resize(ctx, ctx->polls_allocated - ctx->chunk_size);
    }
}
/*
 * Poll every handle in ctx once and run the callback of each handle that
 * fired.
 *
 * timeout is passed through ssh_timeout_update() before each ssh_poll() call,
 * so the remaining time is presumably recomputed across EINTR retries. Its
 * exact zero/negative semantics come from ssh_timeout_update(), which is not
 * visible here.
 *
 * Returns:
 *  - SSH_ERROR if ctx is empty or ssh_poll() fails;
 *  - SSH_AGAIN if nothing fired;
 *  - -1 if a callback returned -2;
 *  - otherwise the number of ready entries left undispatched. This is 0
 *    once all ready entries are handled; entries whose revents were masked
 *    to 0 are not counted down.
 */
int ssh_poll_ctx_dopoll(ssh_poll_ctx ctx, int timeout)
{
int rc;
size_t i, used;
ssh_poll_handle p;
socket_t fd;
int revents;
struct ssh_timestamp ts;
if (ctx->polls_used == 0) {
return SSH_ERROR;
}
/* A handle whose callback is currently running (lock_cnt > 0, so this call
 * is nested inside that callback) may only wait for POLLOUT. */
for (i = 0; i < ctx->polls_used; i++) {
if (ctx->pollptrs[i]->lock_cnt > 0) {
ctx->pollfds[i].events &= POLLOUT;
}
}
ssh_timestamp_init(&ts);
/* Retry on EINTR, recomputing the timeout from ts on every attempt. */
do {
int tm = ssh_timeout_update(&ts, timeout);
rc = ssh_poll(ctx->pollfds, ctx->polls_used, tm);
} while (rc == -1 && errno == EINTR);
if (rc < 0) {
return SSH_ERROR;
}
if (rc == 0) {
return SSH_AGAIN;
}
used = ctx->polls_used;
/* rc counts ready entries still to dispatch; stop once it reaches 0. */
for (i = 0; i < used && rc > 0; ) {
revents = ctx->pollfds[i].revents;
/* Beyond a nesting depth of 2, only POLLOUT is passed to the callback.
 * NOTE(review): the threshold bounds recursion; its rationale is not
 * visible here. */
if (ctx->pollptrs[i]->lock_cnt > 2) {
revents &= POLLOUT;
}
if (revents == 0) {
i++;
} else {
int ret;
p = ctx->pollptrs[i];
fd = ctx->pollfds[i].fd;
/* Mask the slot while its callback runs; restored below on success. */
ctx->pollfds[i].events = 0;
p->lock_cnt++;
if (p->cb && (ret = p->cb(p, fd, revents, p->cb_data)) < 0) {
/* -2 aborts the whole dispatch immediately. */
if (ret == -2) {
return -1;
}
/* Other negative returns: the callback may have removed handles
 * (possibly p itself), so reload the count and rescan. lock_cnt
 * is deliberately not touched since p may be gone. */
used = ctx->polls_used;
i = 0;
} else {
ctx->pollfds[i].revents = 0;
ctx->pollfds[i].events = p->events;
p->lock_cnt--;
i++;
}
rc--;
}
}
return rc;
}
/*
 * Return the session's default poll context. On first use it is created
 * lazily with a chunk size of 2. Returns NULL if that creation fails.
 */
ssh_poll_ctx ssh_poll_get_default_ctx(ssh_session session)
{
    if (session->default_poll_ctx == NULL) {
        session->default_poll_ctx = ssh_poll_ctx_new(2);
    }

    return session->default_poll_ctx;
}
/* Adapter that lets an ssh_event_callback (fd, revents, userdata) sit behind
 * a poll handle. Allocated by ssh_event_add_fd(). */
struct ssh_event_fd_wrapper {
ssh_event_callback cb;
void * userdata;
};
/* An event loop: one poll context shared by fds, sessions and connectors. */
struct ssh_event_struct {
/* Holds every handle polled by ssh_event_dopoll(). */
ssh_poll_ctx ctx;
#ifdef WITH_SERVER
/* Sessions registered with ssh_event_add_session() (no duplicates). */
struct ssh_list *sessions;
#endif
};
/*
 * Allocate an empty event with a poll context (chunk size 2) and, in server
 * builds, an empty session list. Returns NULL if any allocation fails.
 */
ssh_event ssh_event_new(void)
{
    ssh_event event = malloc(sizeof(struct ssh_event_struct));

    if (event == NULL) {
        return NULL;
    }
    ZERO_STRUCTP(event);

    event->ctx = ssh_poll_ctx_new(2);
    if (event->ctx == NULL) {
        goto fail;
    }

#ifdef WITH_SERVER
    event->sessions = ssh_list_new();
    if (event->sessions == NULL) {
        ssh_poll_ctx_free(event->ctx);
        goto fail;
    }
#endif

    return event;

fail:
    free(event);
    return NULL;
}
/*
 * Poll callback installed by ssh_event_add_fd(): forwards (fd, revents) to
 * the user's ssh_event_callback stored in the wrapper. Returns 0 if the
 * wrapper has no callback.
 */
static int
ssh_event_fd_wrapper_callback(ssh_poll_handle p, socket_t fd, int revents,
                              void *userdata)
{
    struct ssh_event_fd_wrapper *wrapper = userdata;

    (void)p;

    if (wrapper->cb == NULL) {
        return 0;
    }
    return wrapper->cb(fd, revents, wrapper->userdata);
}
/*
 * Watch fd for events on the event and call cb with userdata when they fire.
 *
 * The wrapper allocated here is released by ssh_event_remove_fd().
 *
 * Returns SSH_OK, or SSH_ERROR on bad arguments or allocation failure.
 */
int
ssh_event_add_fd(ssh_event event, socket_t fd, short events,
                 ssh_event_callback cb, void *userdata)
{
    struct ssh_event_fd_wrapper *wrapper = NULL;
    ssh_poll_handle handle = NULL;

    if (event == NULL || event->ctx == NULL || cb == NULL ||
        fd == SSH_INVALID_SOCKET) {
        return SSH_ERROR;
    }

    wrapper = malloc(sizeof(struct ssh_event_fd_wrapper));
    if (wrapper == NULL) {
        return SSH_ERROR;
    }
    wrapper->cb = cb;
    wrapper->userdata = userdata;

    handle = ssh_poll_new(fd, events, ssh_event_fd_wrapper_callback, wrapper);
    if (handle == NULL) {
        goto fail;
    }

    if (ssh_poll_ctx_add(event->ctx, handle) < 0) {
        ssh_poll_free(handle);
        goto fail;
    }

    return SSH_OK;

fail:
    free(wrapper);
    return SSH_ERROR;
}
/* Attach an existing poll handle to the event's context (see ssh_poll_ctx_add). */
int ssh_event_add_poll(ssh_event event, ssh_poll_handle p)
{
    ssh_poll_ctx ctx = event->ctx;

    return ssh_poll_ctx_add(ctx, p);
}

/* Detach a poll handle from the event's context without freeing it. */
void ssh_event_remove_poll(ssh_event event, ssh_poll_handle p)
{
    ssh_poll_ctx ctx = event->ctx;

    ssh_poll_ctx_remove(ctx, p);
}
/*
 * Move every poll handle of the session's default context into the event,
 * tagging each with the session so it can be handed back later. In server
 * builds, the session is also recorded in event->sessions (once).
 *
 * Returns SSH_OK, or SSH_ERROR on bad arguments, when the session has no
 * default poll context, or when a handle cannot be added to the event.
 *
 * On that last failure the handle is put back into the session's default
 * context rather than being left in no context at all. Handles moved before
 * the failure stay in the event, still tagged with the session, so
 * ssh_event_remove_session() or ssh_event_free() can return them.
 */
int ssh_event_add_session(ssh_event event, ssh_session session)
{
    ssh_poll_handle p;
#ifdef WITH_SERVER
    struct ssh_iterator *iterator;
#endif

    if (event == NULL || event->ctx == NULL || session == NULL) {
        return SSH_ERROR;
    }
    if (session->default_poll_ctx == NULL) {
        return SSH_ERROR;
    }

    /* ssh_poll_ctx_remove() decrements default_poll_ctx->polls_used. */
    while (session->default_poll_ctx->polls_used > 0) {
        p = session->default_poll_ctx->pollptrs[0];
        ssh_poll_ctx_remove(session->default_poll_ctx, p);
        if (ssh_poll_ctx_add(event->ctx, p) < 0) {
            /* The removal above left a free slot, so this cannot need to
             * grow the default context. */
            (void)ssh_poll_ctx_add(session->default_poll_ctx, p);
            return SSH_ERROR;
        }
        /* Remember the owner so the handle can be returned to it. */
        p->session = session;
    }

#ifdef WITH_SERVER
    iterator = ssh_list_get_iterator(event->sessions);
    while (iterator != NULL) {
        if ((ssh_session)iterator->data == session) {
            /* already registered: keep a single instance */
            return SSH_OK;
        }
        iterator = iterator->next;
    }
    if (ssh_list_append(event->sessions, session) == SSH_ERROR) {
        return SSH_ERROR;
    }
#endif

    return SSH_OK;
}
/* Register the connector with the event (delegates to the connector). */
int ssh_event_add_connector(ssh_event event, ssh_connector connector)
{
    int rc = ssh_connector_set_event(connector, event);

    return rc;
}
/*
 * Poll the event's handles once (see ssh_poll_ctx_dopoll()).
 * Returns SSH_ERROR for a NULL event or an event without a context.
 */
int ssh_event_dopoll(ssh_event event, int timeout)
{
    if (event != NULL && event->ctx != NULL) {
        return ssh_poll_ctx_dopoll(event->ctx, timeout);
    }

    return SSH_ERROR;
}
/*
 * Remove and free every handle in the event watching fd, together with the
 * wrapper allocated by ssh_event_add_fd(). Handles that belong to a session
 * are skipped, since the session owns them.
 *
 * Returns SSH_OK if at least one handle was removed, SSH_ERROR otherwise.
 */
int ssh_event_remove_fd(ssh_event event, socket_t fd)
{
    size_t i = 0;
    int rc = SSH_ERROR;

    if (event == NULL || event->ctx == NULL) {
        return SSH_ERROR;
    }

    while (i < event->ctx->polls_used) {
        ssh_poll_handle p = event->ctx->pollptrs[i];

        if (event->ctx->pollfds[i].fd != fd || p->session != NULL) {
            i++;
            continue;
        }

        if (p->cb == ssh_event_fd_wrapper_callback) {
            struct ssh_event_fd_wrapper *pw = p->cb_data;
            SAFE_FREE(pw);
        }

        /* ssh_poll_free() -> ssh_poll_ctx_remove() moves the last entry into
         * slot i, so re-examine slot i instead of advancing. The old
         * "i = 0" restart was followed by i++, so that entry was skipped. */
        ssh_poll_free(p);
        rc = SSH_OK;
    }

    return rc;
}
/*
 * Hand every poll handle tagged with session back to the session's default
 * poll context. In server builds, also drop the session from
 * event->sessions.
 *
 * Returns SSH_OK if at least one handle was moved back, SSH_ERROR otherwise
 * (including on bad arguments).
 */
int ssh_event_remove_session(ssh_event event, ssh_session session)
{
    ssh_poll_handle p;
    size_t i = 0;
    int rc = SSH_ERROR;
#ifdef WITH_SERVER
    struct ssh_iterator *iterator;
#endif

    if (event == NULL || event->ctx == NULL || session == NULL) {
        return SSH_ERROR;
    }

    while (i < event->ctx->polls_used) {
        p = event->ctx->pollptrs[i];
        if (p->session != session) {
            i++;
            continue;
        }

        ssh_poll_ctx_remove(event->ctx, p);
        p->session = NULL;
        ssh_poll_ctx_add(session->default_poll_ctx, p);
        rc = SSH_OK;
        /* ssh_poll_ctx_remove() moved the last entry into slot i; re-examine
         * it rather than advancing. A session can own several handles, and
         * the old restart ("i = 0" followed by i++) could miss one. */
    }

#ifdef WITH_SERVER
    iterator = ssh_list_get_iterator(event->sessions);
    while (iterator != NULL) {
        if ((ssh_session)iterator->data == session) {
            ssh_list_remove(event->sessions, iterator);
            break;
        }
        iterator = iterator->next;
    }
#endif

    return rc;
}
/* Detach the connector from its event; the event argument is unused. */
int ssh_event_remove_connector(ssh_event event, ssh_connector connector)
{
    int rc;

    (void)event;
    rc = ssh_connector_remove_event(connector);
    return rc;
}
/*
 * Destroy event. NULL is a no-op.
 *
 * Handles lent by sessions via ssh_event_add_session() are returned to
 * their session's default poll context. Every other handle is freed with
 * the event's context. For handles created by ssh_event_add_fd(), the
 * wrapper that event owns is freed too.
 */
void ssh_event_free(ssh_event event)
{
    size_t i = 0;
    ssh_poll_handle p;

    if (event == NULL) {
        return;
    }

    if (event->ctx != NULL) {
        while (i < event->ctx->polls_used) {
            p = event->ctx->pollptrs[i];

            if (p->session != NULL) {
                ssh_poll_ctx_remove(event->ctx, p);
                ssh_poll_ctx_add(p->session->default_poll_ctx, p);
                p->session = NULL;
                /* Slot i now holds the former last entry: re-examine it.
                 * The old "used = 0" ended the scan after the first session
                 * handle, so later sessions' handles were freed below. */
                continue;
            }

            /* The wrapper is owned by the event; ssh_poll_ctx_free() below
             * frees only the handle. */
            if (p->cb == ssh_event_fd_wrapper_callback) {
                struct ssh_event_fd_wrapper *pw = p->cb_data;
                SAFE_FREE(pw);
                p->cb_data = NULL;
            }
            i++;
        }

        ssh_poll_ctx_free(event->ctx);
    }

#ifdef WITH_SERVER
    if (event->sessions != NULL) {
        ssh_list_free(event->sessions);
    }
#endif

    free(event);
}