#include <interfaces/init.h>
#include <ipc/capnp/context.h>
#include <ipc/capnp/init.capnp.h>
#include <ipc/capnp/init.capnp.proxy.h>
#include <ipc/capnp/protocol.h>
#include <ipc/exception.h>
#include <ipc/protocol.h>
#include <kj/async.h>
#include <logging.h>
#include <mp/proxy-io.h>
#include <mp/proxy-types.h>
#include <mp/util.h>
#include <util/threadnames.h>
#include <cassert>
#include <cerrno>
#include <future>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <sys/socket.h>
#include <system_error>
#include <thread>
namespace ipc {
namespace capnp {
namespace {
//! Translate the verbosity currently enabled for the BCLog::IPC category into
//! the mp::Log threshold handed to libmultiprocess as its log_level.
//! Trace takes precedence over Debug; if neither is enabled, Info is used.
mp::Log GetRequestedIPCLogLevel()
{
    const bool trace_enabled{LogAcceptCategory(BCLog::IPC, BCLog::Level::Trace)};
    if (trace_enabled) return mp::Log::Trace;

    const bool debug_enabled{LogAcceptCategory(BCLog::IPC, BCLog::Level::Debug)};
    return debug_enabled ? mp::Log::Debug : mp::Log::Info;
}
//! Log sink installed into libmultiprocess (as LogOptions::log_fn).
//! Trace and Debug messages are logged under the BCLog::IPC category.
//! Info, Warning and Error messages are logged with an "ipc: " prefix.
//! A Raise message is logged as an error and then thrown as an Exception.
void IpcLogFn(mp::LogMessage message)
{
    switch (message.level) {
    case mp::Log::Error:
    case mp::Log::Raise:
        LogError("ipc: %s", message.message);
        if (message.level == mp::Log::Raise) throw Exception(message.message);
        return;
    case mp::Log::Warning:
        LogWarning("ipc: %s", message.message);
        return;
    case mp::Log::Info:
        LogInfo("ipc: %s", message.message);
        return;
    case mp::Log::Debug:
        LogDebug(BCLog::IPC, "%s", message.message);
        return;
    case mp::Log::Trace:
        LogTrace(BCLog::IPC, "%s", message.message);
        return;
    } // intentionally no default, so the compiler can warn about unhandled levels

    // Any level not listed above (e.g. one added to libmultiprocess later) is
    // treated conservatively and only shown at the most verbose setting.
    LogTrace(BCLog::IPC, "%s", message.message);
}
class CapnpProtocol : public Protocol
{
public:
~CapnpProtocol() noexcept(true)
{
m_loop_ref.reset();
if (m_loop_thread.joinable()) m_loop_thread.join();
assert(!m_loop);
};
std::unique_ptr<interfaces::Init> connect(int fd, const char* exe_name) override
{
startLoop(exe_name);
return mp::ConnectStream<messages::Init>(*m_loop, fd);
}
void listen(int listen_fd, const char* exe_name, interfaces::Init& init) override
{
startLoop(exe_name);
if (::listen(listen_fd, 5) != 0) {
throw std::system_error(errno, std::system_category());
}
mp::ListenConnections<messages::Init>(*m_loop, listen_fd, init);
}
void serve(int fd, const char* exe_name, interfaces::Init& init, const std::function<void()>& ready_fn = {}) override
{
assert(!m_loop);
mp::g_thread_context.thread_name = mp::ThreadName(exe_name);
mp::LogOptions opts = {
.log_fn = IpcLogFn,
.log_level = GetRequestedIPCLogLevel()
};
m_loop.emplace(exe_name, std::move(opts), &m_context);
if (ready_fn) ready_fn();
mp::ServeStream<messages::Init>(*m_loop, fd, init);
m_parent_connection = &m_loop->m_incoming_connections.back();
m_loop->loop();
m_loop.reset();
}
void disconnectIncoming() override
{
if (!m_loop) return;
m_loop->sync([&] {
m_loop->m_incoming_connections.remove_if([this](mp::Connection& c) { return &c != m_parent_connection; });
});
}
void addCleanup(std::type_index type, void* iface, std::function<void()> cleanup) override
{
mp::ProxyTypeRegister::types().at(type)(iface).cleanup_fns.emplace_back(std::move(cleanup));
}
Context& context() override { return m_context; }
void startLoop(const char* exe_name)
{
if (m_loop) return;
std::promise<void> promise;
m_loop_thread = std::thread([&] {
util::ThreadRename("capnp-loop");
mp::LogOptions opts = {
.log_fn = IpcLogFn,
.log_level = GetRequestedIPCLogLevel()
};
m_loop.emplace(exe_name, std::move(opts), &m_context);
m_loop_ref.emplace(*m_loop);
promise.set_value();
m_loop->loop();
m_loop.reset();
});
promise.get_future().wait();
}
Context m_context;
std::thread m_loop_thread;
std::optional<mp::EventLoop> m_loop;
std::optional<mp::EventLoopRef> m_loop_ref;
mp::Connection* m_parent_connection{nullptr};
};
}
//! Factory for the Cap'n Proto-backed Protocol implementation.
//! The caller receives sole ownership of the new object.
std::unique_ptr<Protocol> MakeCapnpProtocol()
{
    std::unique_ptr<Protocol> protocol{std::make_unique<CapnpProtocol>()};
    return protocol;
}
} }