// ockam_api 0.93.0
//
// Ockam's request-response API
use crate::nodes::NodeManager;
use minicbor::{Decode, Decoder, Encode};
use ockam_core::errcode::{Kind, Origin};
use ockam_core::{Any, NeutralMessage, Routed};
use ockam_multiaddr::MultiAddr;
use ockam_node::{Context, MessageSendReceiveOptions};
use ockam_transport_tcp::{TlsCertificate, TlsCertificateProvider};
use std::fmt::{Debug, Display, Formatter};
use std::sync::Weak;
use std::time::Duration;
use tonic::async_trait;

/// TLS certificate provider that asks a remote service, reached through the
/// node manager, for a certificate.
#[derive(Clone)]
pub(crate) struct ProjectCertificateProvider {
    /// Weak handle to the node manager; it is upgraded on every certificate
    /// request, which fails if the manager has already been dropped.
    node_manager: Weak<NodeManager>,
    /// Address the certificate request is sent to.
    to: MultiAddr,
}

impl Debug for ProjectCertificateProvider {
    /// Debug output shows only the destination address; the node manager
    /// handle is intentionally left out of the representation.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self { to, .. } = self;
        f.write_fmt(format_args!("ProjectCertificateProvider {{ to: {:?} }}", to))
    }
}

impl ProjectCertificateProvider {
    pub fn new(node_manager: Weak<NodeManager>, to: MultiAddr) -> Self {
        Self { node_manager, to }
    }
}

impl Display for ProjectCertificateProvider {
    /// Human-readable description naming the address certificates are
    /// retrieved from.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let destination = &self.to;
        f.write_fmt(format_args!("certificate retrieval from: {}", destination))
    }
}

/// Request body sent to the certificate provider; it carries no fields.
#[derive(Encode, Decode)]
struct CertificateRequest {}

/// Reply received from the certificate provider.
///
/// Encoded as a CBOR map keyed by the numeric field indices below.
#[derive(Debug, Encode, Decode)]
#[rustfmt::skip]
#[cbor(map)]
struct CertificateResponse {
    /// Status reported by the provider.
    #[n(1)] kind: ReplyKind,
    /// The certificate, when one is included in the reply; a `Ready` reply
    /// without it is treated as invalid by the caller.
    #[n(2)] certificate: Option<TlsCertificate>,
}

/// Status carried by a [`CertificateResponse`].
///
/// Encoded on the wire as the bare variant index.
#[derive(Debug, Encode, Decode, PartialEq)]
#[rustfmt::skip]
#[cbor(index_only)]
enum ReplyKind {
    /// A certificate is expected to be present in the reply.
    #[n(0)] Ready,
    /// No certificate is available yet; surfaced to the caller as an error.
    #[n(1)] NotReady,
    /// The provider reports the request as unsupported; surfaced to the
    /// caller as an error.
    #[n(2)] Unsupported,
}

#[async_trait]
impl TlsCertificateProvider for ProjectCertificateProvider {
    async fn get_certificate(&self, context: &Context) -> ockam_core::Result<TlsCertificate> {
        debug!("requesting TLS certificate from: {}", self.to);
        let node_manager = self.node_manager.upgrade().ok_or_else(|| {
            ockam_core::Error::new(Origin::Transport, Kind::Invalid, "NodeManager shut down")
        })?;
        let connection = {
            node_manager
                .make_connection(
                    context,
                    &self.to,
                    node_manager.node_identifier.clone(),
                    None,
                    None,
                )
                .await?
        };

        let options = MessageSendReceiveOptions::new().with_timeout(Duration::from_secs(30));

        let payload = {
            let mut buffer = Vec::new();
            minicbor::Encoder::new(&mut buffer).encode(&CertificateRequest {})?;
            buffer
        };

        let reply: Routed<Any> = context
            .send_and_receive_extended(connection.route()?, NeutralMessage::from(payload), options)
            .await?;

        let payload = reply.into_payload();
        let reply: CertificateResponse = Decoder::new(&payload).decode()?;

        match reply.kind {
            ReplyKind::Ready => {
                if let Some(certificate) = reply.certificate {
                    Ok(certificate)
                } else {
                    Err(ockam_core::Error::new(
                        Origin::Transport,
                        Kind::Invalid,
                        "invalid reply from certificate provider",
                    ))
                }
            }
            ReplyKind::Unsupported => Err(ockam_core::Error::new(
                Origin::Transport,
                Kind::NotReady,
                "certificate",
            )),
            ReplyKind::NotReady => Err(ockam_core::Error::new(
                Origin::Transport,
                Kind::NotReady,
                "certificate",
            )),
        }
    }
}

#[cfg(test)]
mod test {
    use super::*;

    /// Decodes a hex-encoded CBOR payload into a `CertificateResponse`,
    /// panicking if either the hex or the CBOR is malformed.
    fn decode_reply(encoded: &str) -> CertificateResponse {
        let bytes = hex::decode(encoded).unwrap();
        Decoder::new(&bytes).decode().unwrap()
    }

    // A reply carrying only the status field decodes with no certificate.
    #[test]
    fn check_orchestrator_encoding_not_ready() {
        let CertificateResponse { kind, certificate } = decode_reply("A10101");
        assert_eq!(kind, ReplyKind::NotReady);
        assert!(certificate.is_none());
    }

    // A reply carrying both status and certificate decodes into both fields.
    #[test]
    fn check_orchestrator_encoding_ready() {
        let expected = TlsCertificate {
            full_chain_pem: b"full_chain".to_vec(),
            private_key_pem: b"private_key".to_vec(),
        };
        let CertificateResponse { kind, certificate } =
            decode_reply("a2010002a2014a66756c6c5f636861696e024b707269766174655f6b6579");
        assert_eq!(kind, ReplyKind::Ready);
        assert_eq!(certificate, Some(expected));
    }
}