#include <ableton/discovery/AsioTypes.hpp>
#include <ableton/link_audio/ChannelAnnouncements.hpp>
#include <ableton/link_audio/Channels.hpp>
#include <ableton/link_audio/Id.hpp>
#include <ableton/link_audio/PeerAnnouncement.hpp>
#include <ableton/platforms/stl/Random.hpp>
#include <ableton/test/CatchWrapper.hpp>
#include <ableton/test/serial_io/Fixture.hpp>
#include <algorithm>
#include <chrono>
#include <functional>
#include <memory>
#include <vector>
namespace ableton
{
namespace link_audio
{
TEST_CASE("Channels")
{
// Invocation counter. The test hands it to Channels through std::ref(), so every
// call made by the object under test is visible as an increment of mNumCalls.
struct Callback
{
  size_t mNumCalls{0};

  void operator()() { mNumCalls += 1; }
};
// Everything a test passes to sawAnnouncement() for one simulated peer.
// The data members must stay in this order: the fixtures build Input with
// positional aggregate initialization.
struct Input
{
  // Node id, session id, peer info and the list of announced channels.
  PeerAnnouncement announcement;
  // Doubled in GetSendHandlerForBestConnection to make the second gateway's
  // endpoint the one returned by channelSendHandler().
  double networkQuality;
  // Left empty ({}) by every fixture in this file.
  std::shared_ptr<int> pInterface;
  // Sender endpoint; the tests compare it with the send handler's endpoint().
  discovery::UdpEndpoint from;
  // NOTE(review): presumably a time-to-live in seconds — confirm.
  int ttl;

  // Forwards to the announcement's own identifier.
  auto ident() const -> Id { return announcement.ident(); }
};
// Random source plugged into Id::random<>() to generate test identifiers.
using Random = ableton::platforms::stl::Random;
// Io context type from the serial_io test support code; first template argument
// of the Channels instantiation below.
using IoContext = test::serial_io::Context;
// Channels under test: test io context, reference-wrapped Callback, and int as the
// third parameter. NOTE(review): int presumably stands in for the interface type
// (cf. Input::pInterface being std::shared_ptr<int>) — confirm.
using TestChannels = Channels<IoContext, std::reference_wrapper<Callback>, int>;
// Entry type built by checkChannel and looked up in the vectors it is given.
using Channel = TestChannels::Channel;
// Verifies (non-fatally) that each channel announced by `test` has an equal entry in
// `channels`. The expected entry is rebuilt from the announcement's own fields.
auto checkChannel = [](const Input& test, const std::vector<Channel>& channels)
{
  const auto& source = test.announcement;
  for (const auto& announced : source.channels.channels)
  {
    const auto wanted = Channel{announced.name,
                                announced.id,
                                source.peerInfo.name,
                                source.nodeId,
                                source.sessionId};
    const auto matchesWanted = [&](const auto& candidate) { return wanted == candidate; };
    CHECK(std::any_of(channels.begin(), channels.end(), matchesWanted));
  }
};
// Session id shared by the fixture peers below, so their channels land in the
// same session.
const auto sessionId = Id::random<Random>();
// Peer "foo": random first id, shared session, one channel "fooChannel" with a
// random id, network quality 100, empty interface pointer, sender 1.1.1.1 / 1,
// ttl 2.
const auto foo = Input{{Id::random<Random>(),
sessionId,
{"foo"},
{{ChannelAnnouncement{{"fooChannel"}, Id::random<Random>()}}}},
100.,
{},
{discovery::makeAddress("1.1.1.1"), 1},
2};
// Peer "bar": same session, one channel "barChannel"; network quality, interface
// pointer and sender endpoint are value-initialized; ttl 5.
const auto bar = Input{{Id::random<Random>(),
sessionId,
{"bar"},
{{ChannelAnnouncement{{"barChannel"}, Id::random<Random>()}}}},
{},
{},
{},
5};
// Two distinct gateway addresses, used to create separate gateway observers.
const auto gateway1 = discovery::makeAddress("123.123.123.123");
const auto gateway2 = discovery::makeAddress("210.210.210.210");
// Call counter; passed by reference so the sections can inspect mNumCalls.
auto callback = Callback{};
// Serial io fixture: supplies the io context and lets tests advance time manually.
test::serial_io::Fixture io;
// Object under test.
auto channels = TestChannels(util::injectVal(io.makeIoContext()), std::ref(callback));
// Constructing Channels on its own must not have invoked the callback.
SECTION("EmptyChannelsAfterInit")
{
  CHECK(callback.mNumCalls == 0);
}
// A peer that announces no channels must still be reachable via peerSendHandler(),
// must contribute nothing to the session's unique channel list, and must lose its
// send handler once it is pruned.
SECTION("PeerSendHandlerForPeerWithNoChannels")
{
  const auto sourceOnlyPeer =
    Input{{Id::random<Random>(), sessionId, {"sourceOnly"}, {}},
          100.,
          {},
          {discovery::makeAddress("2.2.2.2"), 2},
          2};
  auto observer = makeGatewayObserver(channels, gateway1);
  sawAnnouncement(observer, sourceOnlyPeer);

  const auto handler = channels.peerSendHandler(sourceOnlyPeer.announcement.nodeId);
  // REQUIRE instead of CHECK: the next assertion dereferences `handler`, so the test
  // has to stop here rather than read through an empty result.
  REQUIRE(handler.has_value());
  CHECK(sourceOnlyPeer.from == handler->endpoint());

  const auto uniqueChannels =
    channels.uniqueSessionChannels(sourceOnlyPeer.announcement.sessionId);
  CHECK(0 == uniqueChannels.size());

  // Pruning with an empty list of connected peers drops the peer's handler.
  SECTION("PruneDisconnectedSourceOnlyPeer")
  {
    const auto connectedPeers = std::vector<Id>{};
    channels.prunePeerChannels(begin(connectedPeers), end(connectedPeers));
    const auto prunedHandler =
      channels.peerSendHandler(sourceOnlyPeer.announcement.nodeId);
    CHECK(!prunedHandler.has_value());
  }
}
SECTION("AddChannel")
{
auto observer = makeGatewayObserver(channels, gateway1);
sawAnnouncement(observer, foo);
CHECK(1 == callback.mNumCalls);
// After foo's single announcement the session exposes exactly foo's channel, and the
// send handler for that channel targets the endpoint the announcement came from.
SECTION("UniqueChannels")
{
  const auto uniqueChannels =
    channels.uniqueSessionChannels(foo.announcement.sessionId);
  CHECK(1 == uniqueChannels.size());
  checkChannel(foo, uniqueChannels);

  SECTION("GetSendHandler")
  {
    const auto handler =
      channels.channelSendHandler(foo.announcement.channels.channels[0].id);
    // REQUIRE instead of CHECK: `handler` is dereferenced on the next line, so an
    // empty result must abort the test instead of being read through.
    REQUIRE(handler.has_value());
    CHECK(foo.from == handler->endpoint());

    // The same announcement arrives via a second gateway with twice the network
    // quality and a different sender endpoint. Both entries are kept, only one is
    // unique, and the send handler now targets the higher-quality endpoint.
    SECTION("GetSendHandlerForBestConnection")
    {
      auto fasterFoo = foo;
      fasterFoo.networkQuality = foo.networkQuality * 2;
      const auto expectedSendHandler =
        TestChannels::SendHandler({discovery::makeAddress("3.3.3.3"), 3333}, {});
      fasterFoo.from = expectedSendHandler.endpoint();
      auto observer2 = makeGatewayObserver(channels, gateway2);
      sawAnnouncement(observer2, fasterFoo);
      CHECK(2 == channels.sessionChannels(sessionId).size());
      CHECK(1 == channels.uniqueSessionChannels(sessionId).size());
      const auto sendHandler =
        channels.channelSendHandler(foo.announcement.channels.channels[0].id);
      // Fatal for the same reason as above: dereferenced immediately afterwards.
      REQUIRE(sendHandler.has_value());
      CHECK(expectedSendHandler.endpoint() == sendHandler->endpoint());
    }
  }
}
// foo's announcement is replayed with a fresh node id through the same observer.
SECTION("ReAddChannelWithChangedPeerId")
{
  auto changedFoo = foo;
  changedFoo.announcement.nodeId = Id::random<Random>();
  sawAnnouncement(observer, changedFoo);

  // The new node id resolves to a handler for the original sender endpoint.
  SECTION("GetSendHandler")
  {
    const auto handler = channels.peerSendHandler(changedFoo.announcement.nodeId);
    // REQUIRE instead of CHECK: `handler` is dereferenced on the next line, so an
    // empty result must abort the test instead of being read through.
    REQUIRE(handler.has_value());
    CHECK(foo.from == handler->endpoint());
  }

  // A channel id that was never announced yields no handler.
  SECTION("GetSendHandlerWithUnknownChannelId")
  {
    const auto handler = channels.channelSendHandler(Id::random<Random>());
    CHECK(!handler.has_value());
  }
}
// Querying a session id that no peer announced yields an empty channel list.
SECTION("UniqueChannelsWithUnknownSessionId")
{
  const auto unknownSession = Id::random<Random>();
  CHECK(channels.uniqueSessionChannels(unknownSession).empty());
}
// Reporting foo's only channel as gone fires the callback a second time and leaves
// the session without channels.
SECTION("RemoveChannel")
{
  auto departed = std::vector<Id>{foo.announcement.channels.channels[0].id};
  channelsLeft(observer, begin(departed), end(departed));
  const auto remaining = channels.uniqueSessionChannels(foo.announcement.sessionId);
  CHECK(callback.mNumCalls == 2);
  CHECK(remaining.empty());
}
// A second peer in the same session adds its channel next to foo's and triggers one
// more callback.
SECTION("AddSecondPeer")
{
  sawAnnouncement(observer, bar);
  const auto bothPeersChannels = channels.uniqueSessionChannels(sessionId);
  CHECK(callback.mNumCalls == 2);
  CHECK(bothPeersChannels.size() == 2);
  checkChannel(foo, bothPeersChannels);
  checkChannel(bar, bothPeersChannels);

  // Pruning with only foo listed as connected removes bar's channel and its peer
  // send handler, and fires the callback a third time.
  SECTION("PruneDisconnectedPeerChannels")
  {
    const auto stillConnected = std::vector<Id>{foo.announcement.nodeId};
    channels.prunePeerChannels(begin(stillConnected), end(stillConnected));
    const auto remainingChannels = channels.uniqueSessionChannels(sessionId);
    CHECK(callback.mNumCalls == 3);
    CHECK(remainingChannels.size() == 1);
    checkChannel(foo, remainingChannels);
    const auto barHandler = channels.peerSendHandler(bar.announcement.nodeId);
    CHECK(!barHandler.has_value());
  }
}
// Announcing foo again through an observer for the second gateway must not duplicate
// its channel in the unique list.
SECTION("AddSecondChannel")
{
  auto secondGateway = makeGatewayObserver(channels, gateway2);
  sawAnnouncement(secondGateway, foo);
  const auto deduplicated = channels.uniqueSessionChannels(sessionId);
  CHECK(deduplicated.size() == 1);
  checkChannel(foo, deduplicated);
}
// foo was announced with ttl 2; after the fixture clock moves forward by five seconds
// its channel is gone and the callback has fired a second time.
SECTION("Timeout")
{
  io.advanceTime(std::chrono::seconds{5});
  const auto afterTimeout = channels.uniqueSessionChannels(sessionId);
  CHECK(callback.mNumCalls == 2);
  CHECK(afterTimeout.empty());
}
}
}
} }