sia_storage 0.9.1

SDK for interacting with a Sia network indexer
Documentation
use crate::time::{Elapsed, Instant, timeout};

use bytes::Bytes;
use core::fmt::Debug;
use ed25519_dalek::{SignatureError, VerifyingKey};
use log::debug;
use std::collections::HashMap;
use std::num::ParseIntError;
use std::sync::{Arc, RwLock};
use thiserror::{self, Error};
use tokio::net::{TcpStream, lookup_host};
use tokio::sync::OnceCell;

use crate::rhp4::HostEndpoint;
use crate::time::Duration;

use super::{Error as TransportError, Transport};
use sia_core::rhp4::protocol::{RPCReadSector, RPCSettings, RPCWriteSector};
use sia_core::rhp4::{AccountToken, HostPrices};
use sia_core::signing::{PrivateKey, PublicKey};
use sia_core::types::Hash256;
use sia_core::types::v2::Protocol;
use sia_mux::{Mux, Stream};

/// Errors that can occur while establishing a SiaMux connection to a host or
/// opening a stream on it.
#[derive(Debug, Error)]
pub enum ConnectError {
    /// An I/O error, for example a failed host name lookup.
    #[error("connect error: {0}")]
    Io(#[from] std::io::Error),

    /// An error returned by `sia_mux::dial` while performing the mux handshake.
    #[error("mux dial error: {0}")]
    Dial(#[from] sia_mux::DialError),

    /// An error reported by the SiaMux layer.
    // NOTE(review): presumably produced when an established session cannot
    // open a stream — confirm against `sia_mux::Mux::dial_stream`.
    #[error("mux error: {0}")]
    Mux(#[from] sia_mux::MuxError),

    /// The announced address has no `:` separating host and port; carries the
    /// offending address.
    #[error("invalid address: {0}")]
    InvalidAddress(String),

    /// Connecting did not complete within the allowed time.
    #[error("timeout error: {0}")]
    Elapsed(#[from] Elapsed),

    /// The port part of an announced address is not a valid integer port.
    #[error("invalid port: {0}")]
    InvalidPort(#[from] ParseIntError),

    /// The host's public key could not be decoded as an ed25519 verifying key.
    #[error("invalid public key: {0}")]
    InvalidPublicKey(#[from] SignatureError),

    /// None of the host's SiaMux addresses produced a connection.
    #[error("no endpoint")]
    NoEndpoint,
}

/// Shareable, lazily-initialised slot holding the mux connection for one host.
///
/// The outer `Arc` lets a caller keep using the slot after the map lock has
/// been released; the `OnceCell` lets concurrent callers wait on a single
/// dial attempt instead of each dialing their own connection.
type ConnCell = Arc<OnceCell<Arc<Mux>>>;

/// SiaMux-based transport that caches one multiplexed connection per host.
///
/// Cloning is cheap: clones share the same connection cache.
#[derive(Clone)]
pub struct Client {
    // Connection slots keyed by host public key; created on first use.
    open_conns: Arc<RwLock<HashMap<PublicKey, ConnCell>>>,
}

impl Default for Client {
    fn default() -> Self {
        Self::new()
    }
}

impl Client {
    /// Creates a client with an empty connection cache.
    pub fn new() -> Self {
        Self {
            open_conns: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Dials a brand-new SiaMux connection to `host`.
    ///
    /// Every announced address whose protocol is `Protocol::SiaMux` is tried
    /// in order. For each one, every resolved socket address is tried in turn.
    /// The first TCP connect whose mux handshake succeeds against the host's
    /// public key is returned.
    ///
    /// # Errors
    ///
    /// - [`ConnectError::InvalidPublicKey`] if the host key cannot be decoded
    ///   as an ed25519 verifying key.
    /// - The first address-level failure (missing `:`, bad port, or a failed
    ///   lookup), if no address produced a connection.
    /// - [`ConnectError::NoEndpoint`] if there is no SiaMux address, or every
    ///   TCP connect / mux handshake attempt failed.
    async fn new_conn(&self, host: &HostEndpoint) -> Result<Mux, ConnectError> {
        let host_bytes: [u8; 32] = host.public_key.into();
        let verifying_key = VerifyingKey::from_bytes(&host_bytes)?;

        // A malformed or unresolvable address must not stop us from trying
        // the host's other addresses. Keep the first such error and report
        // it only if nothing else works.
        let mut first_err: Option<ConnectError> = None;

        for addr in &host.addresses {
            if addr.protocol != Protocol::SiaMux {
                continue;
            }

            let resolved = async {
                let (host_addr, port) = split_host_port(&addr.address)?;
                Ok::<_, ConnectError>(lookup_host((host_addr, port)).await?)
            }
            .await;
            let resolved_addrs = match resolved {
                Ok(addrs) => addrs,
                Err(e) => {
                    debug!(
                        "failed to resolve siamux address {} for {}: {e}",
                        addr.address, host.public_key
                    );
                    if first_err.is_none() {
                        first_err = Some(e);
                    }
                    continue;
                }
            };

            for socket in resolved_addrs {
                let tcp = match TcpStream::connect(socket).await {
                    Ok(tcp) => tcp,
                    Err(e) => {
                        debug!("TCP connect failed to {} ({socket}): {e}", addr.address);
                        continue;
                    }
                };
                match sia_mux::dial(tcp, &verifying_key).await {
                    Ok(mux_conn) => {
                        debug!(
                            "established siamux connection to {} via {socket}",
                            host.public_key
                        );
                        return Ok(mux_conn);
                    }
                    Err(e) => {
                        debug!(
                            "mux handshake failed to {} via {socket}: {e}",
                            host.public_key
                        );
                    }
                }
            }
        }
        Err(first_err.unwrap_or(ConnectError::NoEndpoint))
    }

    /// Opens a new stream to `host`, reusing the cached connection when possible.
    ///
    /// Connections are cached per host public key in a shared `OnceCell`.
    /// Concurrent callers for the same host wait on one dial attempt rather
    /// than each dialing. Dialing is bounded by a 10 second timeout.
    ///
    /// If the cached connection fails to open a stream, its cache entry is
    /// evicted so that a later call dials afresh. The current call still
    /// returns the error.
    ///
    /// # Errors
    ///
    /// Any error from dialing the connection (see `new_conn`),
    /// [`ConnectError::Elapsed`] on timeout, or the error from opening the
    /// stream.
    async fn host_stream(&self, host: &HostEndpoint) -> Result<Stream, ConnectError> {
        // Take the read guard in its own statement so it is dropped before the
        // write lock is requested. Inside an `if let` scrutinee it would (in
        // edition 2021) live through the `else` branch and deadlock.
        let cached = self
            .open_conns
            .read()
            .unwrap()
            .get(&host.public_key)
            .cloned();
        let cell = match cached {
            Some(cell) => cell,
            None => self
                .open_conns
                .write()
                .unwrap()
                .entry(host.public_key)
                .or_insert_with(|| Arc::new(OnceCell::new()))
                .clone(),
        };

        let conn = cell
            .get_or_try_init(|| async {
                let mux = timeout(Duration::from_secs(10), self.new_conn(host))
                    .await
                    .inspect_err(|e| {
                        debug!("siamux connection to {} timed out: {e}", host.public_key);
                    })??;
                debug!("created new siamux connection to {}", host.public_key);
                Ok::<_, ConnectError>(Arc::new(mux))
            })
            .await?
            .clone();

        let stream = conn.dial_stream().inspect_err(|_| {
            // Evict only if the map still holds this exact slot. Another task
            // may already have replaced it with a fresh, working connection.
            let mut conns = self.open_conns.write().unwrap();
            if conns
                .get(&host.public_key)
                .is_some_and(|current| Arc::ptr_eq(current, &cell))
            {
                conns.remove(&host.public_key);
            }
        })?;
        Ok(stream)
    }
}

/// Splits a SiaMux `host:port` address into the host to resolve and the port.
///
/// Brackets around an IPv6 literal (for example `[::1]:9984`) are stripped.
/// `lookup_host` only treats a bare IP literal as an address; a bracketed one
/// would be sent to the name resolver and fail.
///
/// # Errors
///
/// [`ConnectError::InvalidAddress`] if there is no `:` separator, or
/// [`ConnectError::InvalidPort`] if the port does not parse as a `u16`.
fn split_host_port(address: &str) -> Result<(&str, u16), ConnectError> {
    let (host, port) = address
        .rsplit_once(':')
        .ok_or_else(|| ConnectError::InvalidAddress(address.to_string()))?;
    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);
    Ok((host, port.parse()?))
}

impl Transport for Client {
    async fn host_prices(
        &self,
        host: &HostEndpoint,
    ) -> Result<(HostPrices, Duration), TransportError> {
        let mut stream = self
            .host_stream(host)
            .await
            .map_err(|e| TransportError::Transport(e.to_string()))?;
        let start = Instant::now();
        let resp = RPCSettings::send_request(&mut stream)
            .await?
            .complete(&mut stream)
            .await?;
        Ok((resp.settings.prices, start.elapsed()))
    }

    async fn write_sector(
        &self,
        host: &HostEndpoint,
        prices: HostPrices,
        account_key: &PrivateKey,
        data: Bytes,
    ) -> Result<(Hash256, Duration), TransportError> {
        let token = AccountToken::new(account_key, host.public_key);
        let mut stream = self
            .host_stream(host)
            .await
            .map_err(|e| TransportError::Transport(e.to_string()))?;
        let start = Instant::now();
        let resp = RPCWriteSector::send_request(&mut stream, prices, token, data)
            .await?
            .complete(&mut stream)
            .await?;
        Ok((resp.root, start.elapsed()))
    }

    async fn read_sector(
        &self,
        host: &HostEndpoint,
        prices: HostPrices,
        account_key: &PrivateKey,
        root: Hash256,
        offset: usize,
        length: usize,
    ) -> Result<(Bytes, Duration), TransportError> {
        let token = AccountToken::new(account_key, host.public_key);
        let mut stream = self
            .host_stream(host)
            .await
            .map_err(|e| TransportError::Transport(e.to_string()))?;
        let start = Instant::now();
        let resp = RPCReadSector::send_request(&mut stream, prices, token, root, offset, length)
            .await?
            .complete(&mut stream)
            .await?;
        Ok((resp.data, start.elapsed()))
    }
}