import argparse
import asyncio
import logging
import struct
import socket
import ipaddress
from collections import namedtuple
from dataclasses import dataclass
import numpy as np
from . import DAC_VOLTS_PER_LSB
logger = logging.getLogger(__name__)

# One named channel of sample data: the raw sample values, a scale factor to
# apply to them (AdcDac.to_traces passes DAC_VOLTS_PER_LSB), and a label.
Trace = namedtuple("Trace", ["values", "scale", "label"])
def wrap(wide):
    """Reduce an integer to its low 32 bits.

    Used for modular arithmetic on 32-bit frame sequence counters, so that
    sums and differences stay in ``[0, 2**32)`` across counter roll-over.

    >>> wrap(-1) == 0xFFFFFFFF
    True
    """
    low32_mask = (1 << 32) - 1
    return wide & low32_mask
def get_local_ip(remote):
    """Return the local IPv4 address used to reach ``remote``.

    Connecting a UDP socket sends no packets. It only makes the kernel choose
    a route and a local address, which is then read back. Port 1883 is used
    for the route lookup.

    :param remote: host name or IPv4 address of the remote peer.
    :return: the local address as a dotted-quad string.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
        probe.connect((remote, 1883))
        local_addr, _local_port = probe.getsockname()
    return local_addr
class AdcDac:
    """Parser for stream format 1: interleaved ADC and DAC samples.

    The frame body is a sequence of little-endian 16-bit words laid out as
    ``[batch][channel][sample]`` with four channels per batch, in the order
    ADC0, ADC1, DAC0, DAC1.
    """

    format_id = 1

    def __init__(self, header, body):
        """
        :param header: the parsed frame header; ``header.batches`` gives the
            number of batches in ``body``.
        :param body: the frame payload following the header (bytes-like).
        """
        self.header = header
        self.body = body

    def size(self):
        """Return the payload length in bytes."""
        return len(self.body)

    def to_mu(self):
        """Return the samples in machine units as a ``(4, N)`` int16 array.

        Rows are ADC0, ADC1, DAC0, DAC1. Batches are concatenated along the
        sample axis. The DAC rows have their sign bit flipped, which converts
        offset binary to two's complement. The result is always a private,
        writable array; ``self.body`` is never modified.
        """
        raw = np.frombuffer(self.body, "<i2")
        # (batch, channel, sample) -> (channel, batch * sample)
        data = raw.reshape(self.header.batches, 4, -1)
        data = data.swapaxes(0, 1).reshape(4, -1)
        # With a single batch the reshape above is still a view of self.body.
        # For `bytes` that view is read-only, and for a bytearray it would
        # alias (and mutate) the payload, so the in-place XOR below needs a
        # copy. With several batches the reshape has already copied.
        if np.may_share_memory(data, raw):
            data = data.copy()
        # 0x8000 does not fit in int16; -0x8000 is the same bit pattern.
        data[2:] ^= np.int16(-0x8000)
        return data

    def to_si(self):
        """Return the samples scaled by DAC_VOLTS_PER_LSB.

        :return: dict with ``"adc"`` (rows ADC0, ADC1) and ``"dac"``
            (rows DAC0, DAC1).
        """
        # NOTE(review): the ADC rows are also scaled with DAC_VOLTS_PER_LSB;
        # confirm the ADC and DAC share the same volts-per-LSB.
        data = self.to_mu() * DAC_VOLTS_PER_LSB
        return {
            "adc": data[:2],
            "dac": data[2:],
        }

    def to_traces(self):
        """Return one Trace (raw values, scale, label) per channel."""
        labels = ("ADC0", "ADC1", "DAC0", "DAC1")
        return [
            Trace(row, scale=DAC_VOLTS_PER_LSB, label=label)
            for row, label in zip(self.to_mu(), labels)
        ]
class StabilizerStream(asyncio.DatagramProtocol):
    """asyncio datagram protocol that receives and parses Stabilizer frames.

    Parsed frames are put on ``self.queue``. When the queue is full, the
    oldest queued frame is discarded to make room for the new one.
    """

    magic = 0x057B
    # Header: magic (u16), format_id (u8), batches (u8), sequence (u32),
    # little-endian.
    header_fmt = struct.Struct("<HBBI")
    header = namedtuple("Header", "magic format_id batches sequence")
    parsers = {
        AdcDac.format_id: AdcDac,
    }

    @classmethod
    async def open(cls, addr, port, broker, maxsize=1):
        """Bind a UDP socket and attach a new protocol instance to it.

        :param addr: local IP address to bind. If it is a multicast group,
            the socket joins the group on the interface used to reach
            ``broker`` and binds to all addresses.
        :param port: local UDP port.
        :param broker: host used only to choose the multicast interface.
        :param maxsize: capacity of the frame queue.
        :return: ``(transport, protocol)`` from
            ``loop.create_datagram_endpoint``.
        :raises: any socket/asyncio error; the socket is closed in that case.
        """
        loop = asyncio.get_running_loop()
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            # Request a 4 MiB receive buffer to absorb bursts.
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 4 << 20)
            if ipaddress.ip_address(addr).is_multicast:
                logger.info("Subscribing to multicast")
                group = socket.inet_aton(addr)
                # get_local_ip() yields a dotted-quad string, which
                # inet_aton accepts directly.
                iface = socket.inet_aton(get_local_ip(broker))
                sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, group + iface)
                sock.bind(("", port))
            else:
                sock.bind((addr, port))
            return await loop.create_datagram_endpoint(lambda: cls(maxsize), sock=sock)
        except BaseException:
            # Don't leak the socket if setup, bind or endpoint creation fails
            # (or the coroutine is cancelled).
            sock.close()
            raise

    def __init__(self, maxsize):
        """:param maxsize: capacity of ``self.queue`` (asyncio.Queue)."""
        self.queue = asyncio.Queue(maxsize)

    def connection_made(self, _transport):
        logger.info("Connection made (listening)")

    def connection_lost(self, _exc):
        logger.info("Connection lost")

    def datagram_received(self, data, _addr):
        """Parse one datagram and enqueue the resulting frame.

        Datagrams that are shorter than the header, have the wrong magic, or
        use a format without a parser are logged and dropped.
        """
        if len(data) < self.header_fmt.size:
            logger.warning("Short frame (%d bytes), ignoring", len(data))
            return
        header = self.header._make(self.header_fmt.unpack_from(data))
        if header.magic != self.magic:
            logger.warning("Bad frame magic: %#04x, ignoring", header.magic)
            return
        try:
            parser = self.parsers[header.format_id]
        except KeyError:
            logger.warning("No parser for format %s, ignoring", header.format_id)
            return
        frame = parser(header, data[self.header_fmt.size:])
        if self.queue.full():
            # Drop the oldest frame so the newest one always fits.
            old = self.queue.get_nowait()
            logger.debug("Dropping frame: %#08x", old.header.sequence)
        self.queue.put_nowait(frame)
async def measure(stream, duration):
    """Consume frames from ``stream`` for ``duration`` seconds and log stats.

    Loss is inferred from gaps in the 32-bit frame sequence numbers. Each
    frame's ``sequence`` is expected to equal the previous frame's
    ``sequence + batches`` (modulo 2**32).

    :param stream: object whose ``.queue`` is an ``asyncio.Queue`` of frames
        (e.g. a StabilizerStream). Each frame must provide
        ``.header.sequence``, ``.header.batches`` and ``.size()``.
    :param duration: measurement time in seconds.
    :return: fraction of batches lost (1 if nothing was received).
    """

    @dataclass
    class _Statistics:
        # Sequence number expected for the next frame; None before the first.
        expect: "int | None" = None
        received: int = 0  # batches received
        lost: int = 0  # batches missing from the sequence
        nbytes: int = 0  # payload bytes received

    stat = _Statistics()

    async def _record():
        while True:
            frame = await stream.queue.get()
            if stat.expect is not None:
                # NOTE(review): a reordered or duplicated frame makes this gap
                # wrap to nearly 2**32, which would count as a huge loss.
                stat.lost += wrap(frame.header.sequence - stat.expect)
            stat.received += frame.header.batches
            stat.expect = wrap(frame.header.sequence + frame.header.batches)
            stat.nbytes += frame.size()

    try:
        await asyncio.wait_for(_record(), timeout=duration)
    except asyncio.TimeoutError:
        pass  # expected: the measurement window elapsed

    megabytes = stat.nbytes / 1e6
    # Guard against a zero/negative duration instead of dividing by zero.
    rate = megabytes / duration if duration > 0 else float("nan")
    logger.info("Received %g MB, %g MB/s", megabytes, rate)

    sent = stat.received + stat.lost
    loss = stat.lost / sent if sent else 1
    logger.info("Loss: %s/%s batches (%g %%)", stat.lost, sent, loss * 1e2)
    return loss
async def main():
    """Command-line entry point: listen for a stream and measure loss/throughput."""
    parser = argparse.ArgumentParser(description="Stabilizer streaming demo")
    parser.add_argument("--port", type=int, default=9293,
                        help="Local port to listen on")
    parser.add_argument("--host", default="0.0.0.0",
                        help="Local address to listen on")
    parser.add_argument("--broker", default="mqtt",
                        help="The MQTT broker address")
    parser.add_argument("--maxsize", type=int, default=1,
                        help="Frame queue size")
    parser.add_argument("--duration", type=float, default=1.,
                        help="Test duration")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    transport, stream = await StabilizerStream.open(
        args.host, args.port, args.broker, args.maxsize)
    try:
        await measure(stream, args.duration)
    finally:
        # Release the UDP socket even if the measurement fails or is cancelled.
        transport.close()


if __name__ == "__main__":
    asyncio.run(main())