use std::io;
use std::fs::File;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic;
use std::sync::atomic::{AtomicI64, AtomicU32, AtomicU64};
use std::task::{Context, Poll};
use std::time::Duration;
use chrono::{TimeZone, Utc};
use futures_util::pin_mut;
use futures_util::future::{select, Either};
use log::info;
use log::{debug, error, warn};
use pin_project_lite::pin_project;
use rpki::rtr::client::{Client, PayloadError, PayloadTarget, PayloadUpdate};
use rpki::rtr::payload::{Action, Payload, Timing};
use rpki::rtr::state::State;
use serde::Deserialize;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpStream;
use tokio::time::{timeout_at, Instant};
use crate::manager::WaitPoint;
use crate::metrics;
use crate::comms::{Gate, GateMetrics, GateStatus, Terminated};
use crate::manager::Component;
use crate::metrics::{Metric, MetricType, MetricUnit};
use crate::payload;
use crate::payload::Update;
/// Configuration of an RTR client unit that talks to its server over plain TCP.
///
/// The value is deserialized via serde. When `retry` is missing from the
/// input, `Tcp::default_retry` supplies the value.
#[derive(Clone, Debug, Deserialize)]
pub struct Tcp {
/// The address passed to `TcpStream::connect` whenever a connection is opened.
remote: String,
/// Number of seconds to wait before reconnecting after a failed or lost
/// connection.
#[serde(default = "Tcp::default_retry")]
retry: u64,
}
impl Tcp {
    /// The retry interval, in seconds, used when the configuration omits `retry`.
    fn default_retry() -> u64 {
        60
    }

    /// Runs the unit until the gate reports termination.
    ///
    /// The method first waits on `gate.process_until(waitpoint.ready())` and
    /// then hands control to `RtrClient::run`. The connector given to the
    /// client opens a `TcpStream` to `remote` and wraps it in an
    /// `RtrTcpStream` so that transferred bytes are counted in the shared
    /// metrics.
    ///
    /// # Errors
    ///
    /// Returns `Terminated` when the gate or the client loop reports
    /// termination.
    pub async fn run(
        self, component: Component, gate: Gate, mut waitpoint: WaitPoint,
    ) -> Result<(), Terminated> {
        let rtr_metrics = Arc::new(RtrMetrics::new(&gate));

        gate.process_until(waitpoint.ready()).await?;

        let retry_secs = self.retry;
        RtrClient::run(
            component, waitpoint, gate, retry_secs, rtr_metrics.clone(),
            || async {
                let sock = TcpStream::connect(&self.remote).await?;
                Ok(RtrTcpStream { sock, metrics: rtr_metrics.clone() })
            }
        ).await
    }
}
/// State of a running RTR client.
///
/// `Connect` is the factory invoked to produce a future for each new
/// connection attempt.
#[derive(Debug)]
struct RtrClient<Connect> {
/// Called once per connection attempt to create the connect future.
connect: Connect,
/// Seconds to wait between connection attempts.
retry: u64,
/// The most recent status returned by `gate.process()`.
status: GateStatus,
/// Metrics shared with the socket wrapper and the metrics registry.
metrics: Arc<RtrMetrics>,
/// Initialized empty; not read or modified anywhere else in the visible code.
cache: Vec<Payload>,
}
impl<Connect> RtrClient<Connect> {
fn new(connect: Connect, retry: u64, metrics: Arc<RtrMetrics>) -> Self {
RtrClient {
connect,
retry,
status: Default::default(),
metrics,
cache: Vec::new(),
}
}
}
impl<Connect, ConnectFut, Socket> RtrClient<Connect>
where
Connect: FnMut() -> ConnectFut,
ConnectFut: Future<Output = Result<Socket, io::Error>>,
Socket: AsyncRead + AsyncWrite + Unpin,
{
/// Runs the connect/update/reconnect cycle until the gate terminates.
///
/// Each round connects to the server, then repeatedly fetches updates and
/// forwards them through `gate.update_data`. When fetching fails with an
/// I/O error the connection is abandoned, the target is recovered from the
/// client, and a new round starts after `retry` seconds.
///
/// The function never returns `Ok`; it ends with `Err(Terminated)` once
/// `update` or `retry_wait` reports termination.
async fn run(
    mut component: Component,
    waitpoint: WaitPoint,
    mut gate: Gate,
    retry: u64,
    metrics: Arc<RtrMetrics>,
    connect: Connect,
) -> Result<(), Terminated> {
    let mut target = RtrTarget::new(component.name().clone());
    component.register_metrics(metrics.clone());
    let mut runner = Self::new(connect, retry, metrics);

    // The "running" signal is sent from a separate task after a fixed
    // five second delay, so it does not hold up the connect loop below.
    tokio::spawn(async {
        debug!("waiting 5 secs to give rtr-in a headstart");
        tokio::time::sleep(Duration::from_secs(5)).await;
        debug!("waiting done, calling waitpoint.running()");
        waitpoint.running().await;
    });

    loop {
        debug!("Unit {}: Connecting ...", component.name());

        // `connect` consumes the target and hands it back on failure.
        let mut client = match runner.connect(target, &mut gate).await {
            Err(returned) => {
                info!(
                    "Unit {}: Connection failed, retrying in {retry}s",
                    returned.name
                );
                runner.retry_wait(&mut gate).await?;
                target = returned;
                continue;
            }
            Ok(client) => client,
        };

        // Pump updates until the connection breaks or the gate terminates.
        loop {
            match runner.update(&mut client, &mut gate).await {
                Ok(Ok(Some(update))) => {
                    gate.update_data(update).await;
                }
                Ok(Ok(None)) => {}
                Ok(Err(err)) => {
                    warn!(
                        "Unit {}: RTR client disconnected: {}",
                        client.target().name, err,
                    );
                    debug!(
                        "Unit {}: awaiting reconnect.",
                        client.target().name,
                    );
                    break;
                }
                Err(_) => {
                    debug!(
                        "Unit {}: RTR client terminated.",
                        client.target().name
                    );
                    return Err(Terminated);
                }
            }
        }

        // Recover the target from the dead client for the next attempt.
        target = client.into_target();
        runner.retry_wait(&mut gate).await?;
    }
}
/// Opens a new connection and wraps it in an RTR `Client`.
///
/// While the connect future is pending, the gate keeps being processed and
/// every status it reports is stored in `self.status`.
///
/// # Errors
///
/// Returns the unchanged `target` if `gate.process()` fails or if the
/// connect future resolves to an error (which is logged as a warning).
async fn connect(
    &mut self, target: RtrTarget, gate: &mut Gate,
) -> Result<Client<Socket, RtrTarget>, RtrTarget> {
    let outcome = {
        let connecting = (self.connect)();
        pin_mut!(connecting);

        loop {
            let gate_event = gate.process();
            pin_mut!(gate_event);

            match select(gate_event, connecting).await {
                Either::Right((res, _)) => break res,
                Either::Left((Ok(status), pending)) => {
                    // `select` took the connect future; take it back so the
                    // next round keeps polling the same attempt.
                    self.status = status;
                    connecting = pending;
                }
                Either::Left((Err(_), _)) => return Err(target),
            }
        }
    };

    match outcome {
        Ok(sock) => Ok(Client::new(sock, target, None)),
        Err(err) => {
            warn!(
                "Unit {}: failed to connect to server: {}",
                target.name, err
            );
            Err(target)
        }
    }
}
/// Fetches the next update from the server while keeping the gate serviced.
///
/// The outer `Result` is `Err(Terminated)` when `gate.process()` fails.
/// The inner `Result` carries the I/O error of a failed `client.update()`;
/// on success it holds the update wrapped in `payload::Update::Rtr`.
///
/// After a successful update, and if the client reports a state, the
/// session, serial and current Unix timestamp are stored in the metrics.
// NOTE(review): the reason for the clippy allow is not visible here;
// presumably one of the `&mut` parameters is only used immutably - confirm.
#[allow(clippy::needless_pass_by_ref_mut)]
async fn update(
    &mut self, client: &mut Client<Socket, RtrTarget>, gate: &mut Gate
) -> Result<Result<Option<payload::Update>, io::Error>, Terminated> {
    let fetch = async {
        let update = client.update().await?;
        let state = client.state();
        Ok((state, Some(update)))
    };
    pin_mut!(fetch);

    loop {
        let gate_event = gate.process();
        pin_mut!(gate_event);

        match select(gate_event, fetch).await {
            Either::Left((Ok(status), pending)) => {
                // Keep polling the same fetch after handling the gate.
                self.status = status;
                fetch = pending;
            }
            Either::Left((Err(_), _)) => return Err(Terminated),
            Either::Right((outcome, _)) => {
                let (state, update) = match outcome {
                    Ok(parts) => parts,
                    Err(err) => return Ok(Err(err)),
                };

                if let Some(state) = state {
                    let metrics = &self.metrics;
                    metrics.session.store(
                        state.session().into(),
                        atomic::Ordering::Relaxed
                    );
                    metrics.serial.store(
                        state.serial().into(),
                        atomic::Ordering::Relaxed
                    );
                    metrics.updated.store(
                        Utc::now().timestamp(),
                        atomic::Ordering::Relaxed
                    );
                }

                return Ok(Ok(update.map(payload::Update::Rtr)));
            }
        }
    }
}
/// Waits for `self.retry` seconds while continuing to process the gate.
///
/// Every status reported by the gate during the wait is stored in
/// `self.status`.
///
/// # Errors
///
/// Returns `Terminated` if `gate.process()` fails before the wait is over.
async fn retry_wait(
    &mut self, gate: &mut Gate
) -> Result<(), Terminated> {
    debug!("in retry_wait");
    let deadline = Instant::now() + Duration::from_secs(self.retry);

    loop {
        if Instant::now() >= deadline {
            return Ok(());
        }
        match timeout_at(deadline, gate.process()).await {
            // The deadline passed while waiting on the gate.
            Err(_) => return Ok(()),
            Ok(Err(_)) => return Err(Terminated),
            Ok(Ok(status)) => self.status = status,
        }
    }
}
}
/// The payload target handed to the RTR `Client`.
struct RtrTarget {
    /// Verbs that are cloned into every delta update started through
    /// `PayloadTarget::start`. Nothing in the visible code adds to it.
    cache: RtrVerbs,

    /// Name of the owning unit; used as a prefix in log messages.
    pub name: Arc<str>,
}
/// An ordered list of RTR payload items, each paired with the action
/// (announce or withdraw) to perform for it.
#[derive(Clone, Debug, Default)]
pub struct RtrVerbs {
/// The `(action, payload)` pairs in the order they were pushed.
verbs: Vec<(Action, Payload)>,
}
/// Consuming iteration yields the stored `(Action, Payload)` pairs in order.
impl IntoIterator for RtrVerbs {
    type Item = (Action, Payload);
    type IntoIter = std::vec::IntoIter<(Action, Payload)>;

    fn into_iter(self) -> Self::IntoIter {
        let Self { verbs } = self;
        verbs.into_iter()
    }
}
/// An update being assembled from the items of one RTR exchange.
#[derive(Clone, Debug)]
pub enum RtrUpdate {
/// Started after a reset. Only announcements are accepted; a withdrawal
/// is rejected as corrupt by `push_update`.
Full(RtrVerbs),
/// Started without a reset. Items are recorded with the action they
/// arrived with.
Delta(RtrVerbs),
}
impl RtrTarget {
    /// Creates a target for the unit called `name` with an empty cache.
    pub fn new(name: Arc<str>) -> Self {
        let cache = RtrVerbs::default();
        Self { cache, name }
    }
}
impl PayloadTarget for RtrTarget {
type Update = RtrUpdate;
fn start(&mut self, reset: bool) -> Self::Update {
if reset {
debug!("RTR reset/ full dump");
RtrUpdate::Full(Default::default())
} else {
debug!("RTR delta");
RtrUpdate::Delta(self.cache.clone())
}
}
fn apply(
&mut self, update: Self::Update, timing: Timing
) -> Result<(), PayloadError> {
todo!()
}
}
impl PayloadUpdate for RtrUpdate {
    /// Records one payload item in the update.
    ///
    /// A delta update stores the item with the given action. A full update
    /// stores it as an announcement.
    ///
    /// # Errors
    ///
    /// Returns `PayloadError::Corrupt` if a withdrawal is pushed into a
    /// full update.
    fn push_update(
        &mut self, action: Action, payload: Payload
    ) -> Result<(), PayloadError> {
        match self {
            RtrUpdate::Delta(delta) => delta.verbs.push((action, payload)),
            RtrUpdate::Full(_) if action == Action::Withdraw => {
                return Err(PayloadError::Corrupt);
            }
            RtrUpdate::Full(full) => {
                full.verbs.push((Action::Announce, payload))
            }
        }
        Ok(())
    }
}
/// Metrics collected for one RTR client unit.
///
/// `RtrMetrics::new` initializes `session`, `serial` and `updated` to the
/// sentinel values `u32::MAX`, `u32::MAX` and `i64::MIN`; the metrics
/// output skips a value while it still equals its sentinel. Note that the
/// derived `Default` zero-initializes the atomics instead.
#[derive(Debug, Default)]
struct RtrMetrics {
/// The metrics of the unit's gate, emitted alongside the RTR metrics.
gate: Arc<GateMetrics>,
/// Session ID stored after the last successful update.
session: AtomicU32,
/// Serial number stored after the last successful update.
serial: AtomicU32,
/// Unix timestamp in seconds of the last successful update.
updated: AtomicI64,
/// Total number of bytes read from the socket.
bytes_read: AtomicU64,
/// Total number of bytes written to the socket.
bytes_written: AtomicU64,
}
impl RtrMetrics {
fn new(gate: &Gate) -> Self {
RtrMetrics {
gate: gate.metrics(),
session: u32::MAX.into(),
serial: u32::MAX.into(),
updated: i64::MIN.into(),
bytes_read: 0.into(),
bytes_written: 0.into(),
}
}
fn inc_bytes_read(&self, count: u64) {
self.bytes_read.fetch_add(count, atomic::Ordering::Relaxed);
}
fn inc_bytes_written(&self, count: u64) {
self.bytes_written.fetch_add(count, atomic::Ordering::Relaxed);
}
}
/// Metric descriptors and formatting helpers used when emitting the metrics.
impl RtrMetrics {
/// Descriptor for the session ID of the last successful update.
const SESSION_METRIC: Metric = Metric::new(
"session_id", "the session ID of the last successful update",
MetricType::Text, MetricUnit::Info
);
/// Descriptor for the serial number of the last successful update.
const SERIAL_METRIC: Metric = Metric::new(
"serial", "the serial number of the last successful update",
MetricType::Counter, MetricUnit::Total
);
/// Descriptor for the number of seconds since the last successful update.
// NOTE(review): declared as a counter although the emitted value is
// "now minus last update" and so can drop after an update - confirm the
// metric type is intended.
const UPDATED_AGO_METRIC: Metric = Metric::new(
"since_last_rtr_update",
"the number of seconds since last successful update",
MetricType::Counter, MetricUnit::Total
);
/// Descriptor for the time of the last successful update.
const UPDATED_METRIC: Metric = Metric::new(
"rtr_updated", "the time of the last successful update",
MetricType::Text, MetricUnit::Info
);
/// Descriptor for the number of bytes read.
const BYTES_READ_METRIC: Metric = Metric::new(
"bytes_read", "the number of bytes read",
MetricType::Counter, MetricUnit::Total,
);
/// Descriptor for the number of bytes written.
const BYTES_WRITTEN_METRIC: Metric = Metric::new(
"bytes_written", "the number of bytes written",
MetricType::Counter, MetricUnit::Total,
);
/// chrono format items for a timestamp laid out as
/// `year-month-dayThour:minute:secondZ` with zero-padded numbers.
const ISO_DATE: &'static [chrono::format::Item<'static>] = &[
chrono::format::Item::Numeric(
chrono::format::Numeric::Year, chrono::format::Pad::Zero
),
chrono::format::Item::Literal("-"),
chrono::format::Item::Numeric(
chrono::format::Numeric::Month, chrono::format::Pad::Zero
),
chrono::format::Item::Literal("-"),
chrono::format::Item::Numeric(
chrono::format::Numeric::Day, chrono::format::Pad::Zero
),
chrono::format::Item::Literal("T"),
chrono::format::Item::Numeric(
chrono::format::Numeric::Hour, chrono::format::Pad::Zero
),
chrono::format::Item::Literal(":"),
chrono::format::Item::Numeric(
chrono::format::Numeric::Minute, chrono::format::Pad::Zero
),
chrono::format::Item::Literal(":"),
chrono::format::Item::Numeric(
chrono::format::Numeric::Second, chrono::format::Pad::Zero
),
chrono::format::Item::Literal("Z"),
];
}
impl metrics::Source for RtrMetrics {
    /// Appends the gate metrics followed by the RTR metrics to `target`.
    ///
    /// Session, serial and update time are skipped while they still hold
    /// their sentinel values; the byte counters are always emitted.
    // NOTE(review): `u32::MAX` doubles as the "unset" marker for the
    // serial, so a serial that really equals `u32::MAX` is not reported.
    fn append(&self, unit_name: &str, target: &mut metrics::Target) {
        const RELAXED: atomic::Ordering = atomic::Ordering::Relaxed;
        let unit = Some(unit_name);

        self.gate.append(unit_name, target);

        match self.session.load(RELAXED) {
            u32::MAX => {}
            session => {
                target.append_simple(&Self::SESSION_METRIC, unit, session);
            }
        }

        match self.serial.load(RELAXED) {
            u32::MAX => {}
            serial => {
                target.append_simple(&Self::SERIAL_METRIC, unit, serial);
            }
        }

        // Only a stored, unambiguous timestamp produces the two time metrics.
        let last_update = match self.updated.load(RELAXED) {
            i64::MIN => None,
            secs => Utc.timestamp_opt(secs, 0).single(),
        };
        if let Some(updated) = last_update {
            let ago = Utc::now().signed_duration_since(updated);
            target.append_simple(
                &Self::UPDATED_AGO_METRIC, unit, ago.num_seconds()
            );
            target.append_simple(
                &Self::UPDATED_METRIC, unit,
                updated.format_with_items(Self::ISO_DATE.iter())
            );
        }

        let bytes_read = self.bytes_read.load(RELAXED);
        target.append_simple(&Self::BYTES_READ_METRIC, unit, bytes_read);

        let bytes_written = self.bytes_written.load(RELAXED);
        target.append_simple(&Self::BYTES_WRITTEN_METRIC, unit, bytes_written);
    }
}
// A `TcpStream` wrapper whose `AsyncRead`/`AsyncWrite` implementations
// record the number of transferred bytes in the shared `RtrMetrics`.
pin_project! {
struct RtrTcpStream {
// The wrapped socket; marked `#[pin]` so the I/O impls can forward to it.
#[pin] sock: TcpStream,
// Counters updated on every successful read and write.
metrics: Arc<RtrMetrics>,
}
}
impl AsyncRead for RtrTcpStream {
    /// Forwards the read to the socket and, when it completes successfully,
    /// adds the growth of the buffer's filled region to the read counter.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>
    ) -> Poll<Result<(), io::Error>> {
        let this = self.project();
        let filled_before = buf.filled().len();

        let outcome = this.sock.poll_read(cx, buf);
        if matches!(outcome, Poll::Ready(Ok(()))) {
            let gained = buf.filled().len().saturating_sub(filled_before);
            this.metrics.inc_bytes_read(gained as u64);
        }
        outcome
    }
}
impl AsyncWrite for RtrTcpStream {
    /// Forwards the write to the socket and adds the number of bytes the
    /// socket accepted to the written counter.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8]
    ) -> Poll<Result<usize, io::Error>> {
        let this = self.project();

        let outcome = this.sock.poll_write(cx, buf);
        if let Poll::Ready(Ok(written)) = &outcome {
            this.metrics.inc_bytes_written(*written as u64);
        }
        outcome
    }

    /// Forwards the flush to the socket.
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>
    ) -> Poll<Result<(), io::Error>> {
        let this = self.project();
        this.sock.poll_flush(cx)
    }

    /// Forwards the shutdown to the socket.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>
    ) -> Poll<Result<(), io::Error>> {
        let this = self.project();
        this.sock.poll_shutdown(cx)
    }
}