use std::sync::Arc;
use bytes::Bytes;
use tl_proto::{TlRead, TlWrite};
use tycho_util::tl;
use crate::types::{PeerId, PeerInfo};
use crate::util::check_peer_signature;
/// Name component of a [`PeerValueKey`], encoded as a boxed TL enum.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, TlRead, TlWrite)]
#[tl(boxed, scheme = "proto.tl")]
pub enum PeerValueKeyName {
/// TL constructor `dht.peerValueKeyName.nodeInfo`.
#[tl(id = "dht.peerValueKeyName.nodeInfo")]
NodeInfo,
}
/// Name component of a [`MergedValueKey`], encoded as a boxed TL enum.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, TlRead, TlWrite)]
#[tl(boxed, scheme = "proto.tl")]
pub enum MergedValueKeyName {
/// TL constructor `dht.mergedValueKeyName.publicOverlayEntries`.
#[tl(id = "dht.mergedValueKeyName.publicOverlayEntries")]
PublicOverlayEntries,
}
/// Owned key of a [`PeerValue`], serialized as boxed `dht.peerValueKey`.
///
/// `Value::verify_ext` compares `tl_proto::hash` of this key with the
/// expected key hash.
#[derive(Debug, Clone, PartialEq, Eq, TlRead, TlWrite)]
#[tl(boxed, id = "dht.peerValueKey", scheme = "proto.tl")]
pub struct PeerValueKey {
/// Which kind of peer value this key addresses.
pub name: PeerValueKeyName,
/// Peer the value belongs to; `Value::verify_ext` checks the value
/// signature against this id.
pub peer_id: PeerId,
}
/// Borrowed counterpart of [`PeerValueKey`].
///
/// Uses the same TL id (`dht.peerValueKey`) but holds the peer id by
/// reference with lifetime `'tl`.
#[derive(Debug, Clone, PartialEq, Eq, TlRead, TlWrite)]
#[tl(boxed, id = "dht.peerValueKey", scheme = "proto.tl")]
pub struct PeerValueKeyRef<'tl> {
/// Which kind of peer value this key addresses.
pub name: PeerValueKeyName,
/// Borrowed id of the peer the value belongs to.
pub peer_id: &'tl PeerId,
}
impl PeerValueKeyRef<'_> {
    /// Builds an owned [`PeerValueKey`] by copying the name and the
    /// referenced peer id.
    pub fn as_owned(&self) -> PeerValueKey {
        let &Self { name, peer_id } = self;
        PeerValueKey {
            name,
            peer_id: *peer_id,
        }
    }
}
/// Owned key of a [`MergedValue`], serialized as boxed `dht.mergedValueKey`.
///
/// `Value::verify_ext` compares `tl_proto::hash` of this key with the
/// expected key hash.
#[derive(Debug, Clone, PartialEq, Eq, TlRead, TlWrite)]
#[tl(boxed, id = "dht.mergedValueKey", scheme = "proto.tl")]
pub struct MergedValueKey {
/// Which kind of merged value this key addresses.
pub name: MergedValueKeyName,
/// 32-byte identifier of the group the value belongs to.
pub group_id: [u8; 32],
}
/// Borrowed counterpart of [`MergedValueKey`].
///
/// Uses the same TL id (`dht.mergedValueKey`) but holds the group id by
/// reference with lifetime `'tl`.
#[derive(Debug, Clone, PartialEq, Eq, TlRead, TlWrite)]
#[tl(boxed, id = "dht.mergedValueKey", scheme = "proto.tl")]
pub struct MergedValueKeyRef<'tl> {
/// Which kind of merged value this key addresses.
pub name: MergedValueKeyName,
/// Borrowed 32-byte identifier of the group.
pub group_id: &'tl [u8; 32],
}
impl MergedValueKeyRef<'_> {
    /// Builds an owned [`MergedValueKey`] by copying the name and the
    /// referenced 32-byte group id.
    pub fn as_owned(&self) -> MergedValueKey {
        let &Self { name, group_id } = self;
        MergedValueKey {
            name,
            group_id: *group_id,
        }
    }
}
/// Owned, signed value bound to a single peer, serialized as boxed
/// `dht.peerValue`.
#[derive(Debug, Clone, PartialEq, Eq, TlRead, TlWrite)]
#[tl(boxed, id = "dht.peerValue", scheme = "proto.tl")]
pub struct PeerValue {
/// Key identifying the value and the peer it belongs to.
pub key: PeerValueKey,
/// Raw value payload.
pub data: Box<[u8]>,
/// Expiration mark; `Value::verify_ext` rejects the value when this is
/// below its `at` argument.
pub expires_at: u32,
/// 64-byte signature, flagged as the TL `signature` field.
/// `Value::verify_ext` checks it against `key.peer_id`.
#[tl(signature, with = "tl::signature_owned")]
pub signature: Box<[u8; 64]>,
}
/// Borrowed counterpart of [`PeerValue`].
///
/// Uses the same TL id (`dht.peerValue`); the key, payload and signature
/// are held by reference with lifetime `'tl`.
#[derive(Debug, Clone, PartialEq, Eq, TlRead, TlWrite)]
#[tl(boxed, id = "dht.peerValue", scheme = "proto.tl")]
pub struct PeerValueRef<'tl> {
/// Borrowed key identifying the value and its peer.
pub key: PeerValueKeyRef<'tl>,
/// Borrowed raw value payload.
pub data: &'tl [u8],
/// Expiration mark, same meaning as `PeerValue::expires_at`.
pub expires_at: u32,
/// Borrowed 64-byte signature, flagged as the TL `signature` field.
#[tl(signature, with = "tl::signature_ref")]
pub signature: &'tl [u8; 64],
}
impl PeerValueRef<'_> {
    /// Builds an owned [`PeerValue`], copying the payload and the signature
    /// into fresh boxes.
    pub fn as_owned(&self) -> PeerValue {
        let key = self.key.as_owned();
        let data: Box<[u8]> = self.data.into();
        let signature: Box<[u8; 64]> = Box::new(*self.signature);
        PeerValue {
            key,
            data,
            expires_at: self.expires_at,
            signature,
        }
    }
}
/// Owned value addressed by a group id, serialized as boxed
/// `dht.mergedValue`. Unlike [`PeerValue`] it carries no signature field.
#[derive(Debug, Clone, PartialEq, Eq, TlRead, TlWrite)]
#[tl(boxed, id = "dht.mergedValue", scheme = "proto.tl")]
pub struct MergedValue {
/// Key identifying the value and its group.
pub key: MergedValueKey,
/// Raw value payload.
pub data: Box<[u8]>,
/// Expiration mark; `Value::verify_ext` rejects the value when this is
/// below its `at` argument.
pub expires_at: u32,
}
/// Borrowed counterpart of [`MergedValue`].
///
/// Uses the same TL id (`dht.mergedValue`); the key and payload are held
/// by reference with lifetime `'tl`.
#[derive(Debug, Clone, PartialEq, Eq, TlRead, TlWrite)]
#[tl(boxed, id = "dht.mergedValue", scheme = "proto.tl")]
pub struct MergedValueRef<'tl> {
/// Borrowed key identifying the value and its group.
pub key: MergedValueKeyRef<'tl>,
/// Borrowed raw value payload.
pub data: &'tl [u8],
/// Expiration mark, same meaning as `MergedValue::expires_at`.
pub expires_at: u32,
}
impl MergedValueRef<'_> {
    /// Builds an owned [`MergedValue`], copying the payload into a fresh box.
    pub fn as_owned(&self) -> MergedValue {
        let key = self.key.as_owned();
        let data: Box<[u8]> = self.data.into();
        MergedValue {
            key,
            data,
            expires_at: self.expires_at,
        }
    }
}
/// Owned DHT value: either a signed per-peer value or a merged value.
///
/// The TL (de)serialization is implemented by hand and dispatches on the
/// constructor id of the inner type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Value {
/// Value owned by a single peer; carries a signature.
Peer(PeerValue),
/// Value addressed by a group id; carries no signature.
Merged(MergedValue),
}
impl Value {
pub fn verify(&self, at: u32, key_hash: &[u8; 32]) -> bool {
self.verify_ext(at, key_hash, &mut false)
}
pub fn verify_ext(&self, at: u32, key_hash: &[u8; 32], signature_checked: &mut bool) -> bool {
match self {
Self::Peer(value) => {
let timings_ok = value.expires_at >= at && key_hash == &tl_proto::hash(&value.key);
if !timings_ok {
return false;
}
*signature_checked = true;
check_peer_signature(&value.key.peer_id, &value.signature, value)
}
Self::Merged(value) => {
value.expires_at >= at && key_hash == &tl_proto::hash(&value.key)
}
}
}
pub const fn expires_at(&self) -> u32 {
match self {
Self::Peer(value) => value.expires_at,
Self::Merged(value) => value.expires_at,
}
}
}
/// Serializes a [`Value`] by delegating to the inner boxed type; no extra
/// constructor id is written for the enum itself.
impl TlWrite for Value {
    type Repr = tl_proto::Boxed;

    fn max_size_hint(&self) -> usize {
        match *self {
            Self::Peer(ref peer) => peer.max_size_hint(),
            Self::Merged(ref merged) => merged.max_size_hint(),
        }
    }

    fn write_to<P>(&self, packet: &mut P)
    where
        P: tl_proto::TlPacket,
    {
        match *self {
            Self::Peer(ref peer) => peer.write_to(packet),
            Self::Merged(ref merged) => merged.write_to(packet),
        }
    }
}
/// Deserializes a [`Value`] by dispatching on the leading constructor id.
impl<'a> TlRead<'a> for Value {
    type Repr = tl_proto::Boxed;

    fn read_from(packet: &mut &'a [u8]) -> tl_proto::TlResult<Self> {
        // Read the constructor id from a copy of the slice reference so
        // that `packet` is not advanced: the inner boxed readers expect
        // the id to still be in front of them.
        let mut lookahead: &'a [u8] = *packet;
        let constructor = u32::read_from(&mut lookahead)?;

        if constructor == PeerValue::TL_ID {
            PeerValue::read_from(packet).map(Self::Peer)
        } else if constructor == MergedValue::TL_ID {
            MergedValue::read_from(packet).map(Self::Merged)
        } else {
            Err(tl_proto::TlError::UnknownConstructor)
        }
    }
}
/// Borrowed counterpart of [`Value`]: either a borrowed per-peer value or a
/// borrowed merged value, both tied to lifetime `'tl`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValueRef<'tl> {
/// Borrowed value owned by a single peer; carries a signature.
Peer(PeerValueRef<'tl>),
/// Borrowed value addressed by a group id; carries no signature.
Merged(MergedValueRef<'tl>),
}
impl ValueRef<'_> {
pub const fn expires_at(&self) -> u32 {
match self {
Self::Peer(value) => value.expires_at,
Self::Merged(value) => value.expires_at,
}
}
}
/// Serializes a [`ValueRef`] by delegating to the inner boxed type; no
/// extra constructor id is written for the enum itself.
impl TlWrite for ValueRef<'_> {
    type Repr = tl_proto::Boxed;

    fn max_size_hint(&self) -> usize {
        match *self {
            Self::Peer(ref peer) => peer.max_size_hint(),
            Self::Merged(ref merged) => merged.max_size_hint(),
        }
    }

    fn write_to<P>(&self, packet: &mut P)
    where
        P: tl_proto::TlPacket,
    {
        match *self {
            Self::Peer(ref peer) => peer.write_to(packet),
            Self::Merged(ref merged) => merged.write_to(packet),
        }
    }
}
/// Deserializes a [`ValueRef`] that borrows from the input, dispatching on
/// the leading constructor id.
impl<'a> TlRead<'a> for ValueRef<'a> {
    type Repr = tl_proto::Boxed;

    fn read_from(packet: &mut &'a [u8]) -> tl_proto::TlResult<Self> {
        // Read the constructor id from a copy of the slice reference so
        // that `packet` is not advanced: the inner boxed readers expect
        // the id to still be in front of them.
        let mut lookahead: &'a [u8] = *packet;
        let constructor = u32::read_from(&mut lookahead)?;

        // The owned types' ids are used for dispatch; the borrowed types
        // are declared with the same TL ids.
        if constructor == PeerValue::TL_ID {
            PeerValueRef::read_from(packet).map(Self::Peer)
        } else if constructor == MergedValue::TL_ID {
            MergedValueRef::read_from(packet).map(Self::Merged)
        } else {
            Err(tl_proto::TlError::UnknownConstructor)
        }
    }
}
/// Response carrying a list of peers, serialized as boxed `dht.nodesFound`.
#[derive(Debug, Clone, TlRead, TlWrite)]
#[tl(boxed, id = "dht.nodesFound", scheme = "proto.tl")]
pub struct NodeResponse {
/// Peers returned to the caller.
// NOTE(review): `VecWithMaxLen::<20>` presumably caps the list at 20
// entries on (de)serialization; confirm against `tycho_util::tl`.
#[tl(with = "tl::VecWithMaxLen::<20>")]
pub nodes: Vec<Arc<PeerInfo>>,
}
/// Response to a value lookup: either the value itself or a list of peers.
#[derive(Debug, Clone, TlRead, TlWrite)]
#[tl(boxed, scheme = "proto.tl")]
pub enum ValueResponse {
/// TL constructor `dht.valueFound`: the requested value.
#[tl(id = "dht.valueFound")]
Found(Box<Value>),
/// TL constructor `dht.valueNotFound`: a list of peers instead of a value.
// NOTE(review): `VecWithMaxLen::<20>` presumably caps the list at 20
// entries; confirm against `tycho_util::tl`.
#[tl(id = "dht.valueNotFound")]
NotFound(#[tl(with = "tl::VecWithMaxLen::<20>")] Vec<Arc<PeerInfo>>),
}
/// Write-only variant of [`ValueResponse`] whose found value is kept as raw
/// bytes; its `TlWrite` impl uses the same `dht.valueFound` /
/// `dht.valueNotFound` constructor ids.
#[derive(Debug, Clone)]
pub enum ValueResponseRaw {
/// Bytes appended verbatim after the `dht.valueFound` id.
// NOTE(review): presumably an already TL-serialized `Value`; confirm
// against the code that builds this variant.
Found(Bytes),
/// Peers written after the `dht.valueNotFound` id.
NotFound(Vec<Arc<PeerInfo>>),
}
/// Serializes the response as a 4-byte constructor id followed by the
/// payload of the active variant.
impl TlWrite for ValueResponseRaw {
    type Repr = tl_proto::Boxed;

    fn max_size_hint(&self) -> usize {
        let payload = match self {
            Self::Found(raw) => raw.max_size_hint(),
            Self::NotFound(peers) => peers.max_size_hint(),
        };
        // 4 bytes for the constructor id written in front of the payload.
        4 + payload
    }

    fn write_to<P>(&self, packet: &mut P)
    where
        P: tl_proto::TlPacket,
    {
        const FOUND_TL_ID: u32 = tl_proto::id!("dht.valueFound", scheme = "proto.tl");
        const NOT_FOUND_TL_ID: u32 = tl_proto::id!("dht.valueNotFound", scheme = "proto.tl");

        let constructor = match self {
            Self::Found(_) => FOUND_TL_ID,
            Self::NotFound(_) => NOT_FOUND_TL_ID,
        };
        packet.write_u32(constructor);

        match self {
            // The bytes are appended as-is via `write_raw_slice`.
            Self::Found(raw) => packet.write_raw_slice(raw),
            Self::NotFound(peers) => peers.write_to(packet),
        }
    }
}
/// Response carrying a single [`PeerInfo`], serialized as boxed
/// `dht.nodeInfoFound`.
#[derive(Debug, Clone, TlRead, TlWrite)]
#[tl(boxed, id = "dht.nodeInfoFound", scheme = "proto.tl")]
pub struct NodeInfoResponse {
/// The peer info being returned.
pub info: PeerInfo,
}
/// TL types for the DHT calls `dht.withPeerInfo`, `dht.store`,
/// `dht.findNode`, `dht.findValue` and `dht.getNodeInfo`.
pub mod rpc {
    use super::*;

    /// Boxed `dht.withPeerInfo` wrapper around a [`PeerInfo`].
    #[derive(Debug, Clone, TlRead, TlWrite)]
    #[tl(boxed, id = "dht.withPeerInfo", scheme = "proto.tl")]
    #[repr(transparent)]
    pub struct WithPeerInfo {
        pub peer_info: PeerInfo,
    }

    impl WithPeerInfo {
        /// Reinterprets a `PeerInfo` reference as a `WithPeerInfo`
        /// reference without copying.
        pub fn wrap(value: &'_ PeerInfo) -> &'_ Self {
            let raw: *const PeerInfo = value;
            // SAFETY: `WithPeerInfo` is `#[repr(transparent)]` over its
            // only field of type `PeerInfo`, so both share one layout.
            // `raw` comes from a valid shared reference, and the result
            // is bound to that reference's lifetime.
            unsafe { &*raw.cast::<Self>() }
        }
    }

    /// Boxed `dht.store` call carrying an owned [`Value`].
    #[derive(Debug, Clone, TlRead, TlWrite)]
    #[tl(boxed, id = "dht.store", scheme = "proto.tl")]
    #[repr(transparent)]
    pub struct Store {
        pub value: Value,
    }

    /// Borrowed counterpart of [`Store`] with the same TL id.
    #[derive(Debug, Clone, TlRead, TlWrite)]
    #[tl(boxed, id = "dht.store", scheme = "proto.tl")]
    #[repr(transparent)]
    pub struct StoreRef<'tl> {
        pub value: ValueRef<'tl>,
    }

    impl<'tl> StoreRef<'tl> {
        /// Reinterprets a `ValueRef` reference as a `StoreRef` reference
        /// without copying.
        pub fn wrap<'a>(value: &'a ValueRef<'tl>) -> &'a Self {
            let raw: *const ValueRef<'tl> = value;
            // SAFETY: `StoreRef<'tl>` is `#[repr(transparent)]` over its
            // only field of type `ValueRef<'tl>`, so both share one
            // layout. `raw` comes from a valid shared reference, and the
            // result is bound to lifetime `'a` of that reference.
            unsafe { &*raw.cast::<Self>() }
        }
    }

    /// Boxed `dht.findNode` call.
    #[derive(Debug, Clone, TlRead, TlWrite)]
    #[tl(boxed, id = "dht.findNode", scheme = "proto.tl")]
    pub struct FindNode {
        /// 32-byte key to search for.
        pub key: [u8; 32],
        // NOTE(review): presumably the maximum number of nodes to return;
        // confirm against the request handler.
        pub k: u32,
    }

    /// Boxed `dht.findValue` call.
    #[derive(Debug, Clone, TlRead, TlWrite)]
    #[tl(boxed, id = "dht.findValue", scheme = "proto.tl")]
    pub struct FindValue {
        /// 32-byte key to search for.
        pub key: [u8; 32],
        // NOTE(review): presumably the maximum number of nodes to return
        // when the value is not found; confirm against the request handler.
        pub k: u32,
    }

    /// Boxed `dht.getNodeInfo` call; carries no fields.
    #[derive(Debug, Clone, TlRead, TlWrite)]
    #[tl(boxed, id = "dht.getNodeInfo", scheme = "proto.tl")]
    pub struct GetNodeInfo;
}