#include <filesystem>
#include <fstream>
#include <functional>
#include <random>
#include <thread>
#include <gtest/gtest.h>
#include <openssl/crypto.h>
#if MGCLIENT_ON_WINDOWS
#define WIN32_LEAN_AND_MEAN
#include <openssl/x509.h>
#endif
extern "C" {
#include "mgclient.h"
#include "mgcommon.h"
#include "mgsocket.h"
#include "mgtransport.h"
}
#include "gmock_wrapper.h"
#include "test-common.hpp"
std::pair<X509 *, EVP_PKEY *> MakeCertAndKey(const char *name) {
#if OPENSSL_VERSION_NUMBER < 0x30000000L
RSA *rsa = RSA_new();
BIGNUM *bne = BN_new();
BN_set_word(bne, RSA_F4);
RSA_generate_key_ex(rsa, 2048, bne, nullptr);
BN_free(bne);
EVP_PKEY *pkey = EVP_PKEY_new();
EVP_PKEY_assign_RSA(pkey, rsa);
#else
EVP_PKEY *pkey = EVP_RSA_gen(2048);
#endif
X509 *x509 = X509_new();
ASN1_INTEGER_set(X509_get_serialNumber(x509), 1);
X509_gmtime_adj(X509_get_notBefore(x509), 0);
X509_gmtime_adj(X509_get_notAfter(x509), 86400L);
X509_set_pubkey(x509, pkey);
X509_NAME *subject_name;
subject_name = X509_get_subject_name(x509);
X509_NAME_add_entry_by_txt(subject_name, "C", MBSTRING_ASC,
(unsigned char *)"CA", -1, -1, 0);
X509_NAME_add_entry_by_txt(subject_name, "O", MBSTRING_ASC,
(unsigned char *)name, -1, -1, 0);
X509_NAME_add_entry_by_txt(subject_name, "CN", MBSTRING_ASC,
(unsigned char *)"localhost", -1, -1, 0);
return std::make_pair(x509, pkey);
}
extern int mg_init_ssl;
class SecureTransportTest : public ::testing::Test {
protected:
static void SetUpTestCase() {
#if OPENSSL_VERSION_NUMBER < 0x10100000L
mg_init_ssl = 0;
SSL_library_init();
SSL_load_error_strings();
ERR_load_crypto_strings();
#endif
}
virtual void SetUp() override {
OPENSSL_init_crypto(OPENSSL_INIT_ADD_ALL_CIPHERS, nullptr);
int sv[2];
ASSERT_EQ(mg_socket_pair(AF_UNIX, SOCK_STREAM, 0, sv), 0);
sc = sv[0];
ss = sv[1];
std::tie(server_cert, server_key) = MakeCertAndKey("server");
X509_set_issuer_name(server_cert, X509_get_subject_name(server_cert));
X509_sign(server_cert, server_key, EVP_sha512());
X509 *client_cert;
EVP_PKEY *client_key;
std::tie(client_cert, client_key) = MakeCertAndKey("client");
EVP_PKEY *ca_key;
std::tie(ca_cert, ca_key) = MakeCertAndKey("ca");
X509_set_issuer_name(ca_cert, X509_get_subject_name(ca_cert));
X509_sign(ca_cert, ca_key, EVP_sha512());
X509_set_issuer_name(client_cert, X509_get_subject_name(ca_cert));
X509_sign(client_cert, ca_key, EVP_sha512());
client_cert_path = std::filesystem::temp_directory_path() / "client.crt";
BIO *cert_file = BIO_new_file(client_cert_path.string().c_str(), "w");
PEM_write_bio_X509(cert_file, client_cert);
BIO_free(cert_file);
client_key_path = std::filesystem::temp_directory_path() / "client.key";
BIO *key_file = BIO_new_file(client_key_path.string().c_str(), "w");
PEM_write_bio_PrivateKey(key_file, client_key, nullptr, nullptr, 0, nullptr,
nullptr);
BIO_free(key_file);
X509_free(client_cert);
EVP_PKEY_free(client_key);
EVP_PKEY_free(ca_key);
}
virtual void TearDown() override {
X509_free(server_cert);
X509_free(ca_cert);
EVP_PKEY_free(server_key);
}
void RunServer(const std::function<void(void)> &server_function) {
server_thread = std::thread(server_function);
}
void StopServer() {
if (server_thread.joinable()) {
server_thread.join();
}
}
X509 *server_cert;
X509 *ca_cert;
EVP_PKEY *server_key;
std::filesystem::path client_cert_path;
std::filesystem::path client_key_path;
int sc;
int ss;
std::thread server_thread;
tracking_allocator allocator;
};
TEST_F(SecureTransportTest, NoCertificate) {
RunServer([this] {
SSL_CTX *ctx;
#if OPENSSL_VERSION_NUMBER < 0x10100000L
ctx = SSL_CTX_new(SSLv23_server_method());
#else
ctx = SSL_CTX_new(TLS_server_method());
#endif
SSL_CTX_use_certificate(ctx, server_cert);
SSL_CTX_use_PrivateKey(ctx, server_key);
ASSERT_TRUE(ctx);
SSL *ssl = SSL_new(ctx);
ASSERT_TRUE(ssl);
SSL_set_fd(ssl, ss);
ASSERT_EQ(SSL_accept(ssl), 1);
char request[5];
ASSERT_GT(SSL_read(ssl, request, 5), 0);
ASSERT_EQ(strncmp(request, "hello", 5), 0);
ASSERT_GT(SSL_write(ssl, "hello", 5), 0);
SSL_free(ssl);
SSL_CTX_free(ctx);
});
mg_transport *transport;
ASSERT_EQ(mg_secure_transport_init(sc, nullptr, nullptr,
(mg_secure_transport **)&transport,
(mg_allocator *)&allocator),
0);
ASSERT_EQ(mg_transport_send((mg_transport *)transport, "hello", 5), 0);
char response[5];
ASSERT_EQ(mg_transport_recv(transport, response, 5), 0);
ASSERT_EQ(strncmp(response, "hello", 5), 0);
mg_transport_destroy(transport);
StopServer();
}
TEST_F(SecureTransportTest, WithCertificate) {
RunServer([this] {
SSL_CTX *ctx;
#if OPENSSL_VERSION_NUMBER < 0x10100000L
ctx = SSL_CTX_new(SSLv23_server_method());
#else
ctx = SSL_CTX_new(TLS_server_method());
#endif
SSL_CTX_use_certificate(ctx, server_cert);
SSL_CTX_use_PrivateKey(ctx, server_key);
{
X509_STORE *store = SSL_CTX_get_cert_store(ctx);
X509_STORE_add_cert(store, ca_cert);
}
SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
nullptr);
ASSERT_TRUE(ctx);
SSL *ssl = SSL_new(ctx);
ASSERT_TRUE(ssl);
SSL_set_fd(ssl, ss);
if (SSL_accept(ssl) != 1) {
ERR_print_errors_fp(stderr);
FAIL();
}
char request[5];
ASSERT_GT(SSL_read(ssl, request, 5), 0);
ASSERT_EQ(strncmp(request, "hello", 5), 0);
ASSERT_GT(SSL_write(ssl, "hello", 5), 0);
SSL_free(ssl);
SSL_CTX_free(ctx);
});
mg_transport *transport;
ASSERT_EQ(mg_secure_transport_init(sc, client_cert_path.string().c_str(),
client_key_path.string().c_str(),
(mg_secure_transport **)&transport,
(mg_allocator *)&allocator),
0);
ASSERT_EQ(mg_transport_send((mg_transport *)transport, "hello", 5), 0);
char response[5];
ASSERT_EQ(mg_transport_recv(transport, response, 5), 0);
ASSERT_EQ(strncmp(response, "hello", 5), 0);
mg_transport_destroy(transport);
StopServer();
}