cdk_signatory/proto/
client.rs

1use std::path::Path;
2
3use cdk_common::error::Error;
4use cdk_common::{BlindSignature, BlindedMessage, Proof};
5use tonic::transport::{Certificate, Channel, ClientTlsConfig, Identity};
6
7use crate::proto::signatory_client::SignatoryClient;
8use crate::signatory::{RotateKeyArguments, Signatory, SignatoryKeySet, SignatoryKeysets};
9
10/// A client for the Signatory service.
11pub struct SignatoryRpcClient {
12    client: SignatoryClient<tonic::transport::Channel>,
13    url: String,
14}
15
16#[derive(thiserror::Error, Debug)]
17/// Client Signatory Error
18pub enum ClientError {
19    /// Transport error
20    #[error(transparent)]
21    Transport(#[from] tonic::transport::Error),
22
23    /// IO-related errors
24    #[error(transparent)]
25    Io(#[from] std::io::Error),
26
27    /// Signatory Error
28    #[error(transparent)]
29    Signatory(#[from] cdk_common::error::Error),
30
31    /// Invalid URL
32    #[error("Invalid URL")]
33    InvalidUrl,
34}
35
36impl SignatoryRpcClient {
37    /// Create a new RemoteSigner from a tonic transport channel.
38    pub async fn new<A: AsRef<Path>>(url: String, tls_dir: Option<A>) -> Result<Self, ClientError> {
39        #[cfg(not(target_arch = "wasm32"))]
40        if rustls::crypto::CryptoProvider::get_default().is_none() {
41            let _ = rustls::crypto::ring::default_provider().install_default();
42        }
43
44        let channel = if let Some(tls_dir) = tls_dir {
45            let tls_dir = tls_dir.as_ref();
46            let server_root_ca_cert = std::fs::read_to_string(tls_dir.join("ca.pem"))?;
47            let server_root_ca_cert = Certificate::from_pem(server_root_ca_cert);
48            let client_cert = std::fs::read_to_string(tls_dir.join("client.pem"))?;
49            let client_key = std::fs::read_to_string(tls_dir.join("client.key"))?;
50            let client_identity = Identity::from_pem(client_cert, client_key);
51            let tls = ClientTlsConfig::new()
52                .ca_certificate(server_root_ca_cert)
53                .identity(client_identity);
54
55            Channel::from_shared(url.clone())
56                .map_err(|_| ClientError::InvalidUrl)?
57                .tls_config(tls)?
58                .connect()
59                .await?
60        } else {
61            Channel::from_shared(url.clone())
62                .map_err(|_| ClientError::InvalidUrl)?
63                .connect()
64                .await?
65        };
66
67        Ok(Self {
68            client: SignatoryClient::new(channel),
69            url,
70        })
71    }
72}
73
/// Unwraps a tonic `Response` whose message carries an optional `error` field.
///
/// The response is consumed with `into_inner()`. If its `error` field is `Some`, that value is
/// converted with `.into()` and returned as `Err` from the *enclosing* function or closure.
/// Otherwise the requested field is produced:
///
/// - `handle_error!(resp, field, scalar)`: yields `field` unchanged.
/// - `handle_error!(resp, field)`: `field` must be an `Option`; yields its inner value, or
///   returns `Err(Error::Custom("Internal error"))` (via `?`) from the enclosing function or
///   closure when it is `None`.
///
/// Because both arms use `return` / `?`, the call site must be inside a function or closure
/// whose return type is a `Result` compatible with `Error`.
macro_rules! handle_error {
    ($x:expr, $y:ident, scalar) => {{
        let mut obj = $x.into_inner();
        if let Some(err) = obj.error.take() {
            return Err(err.into());
        }

        obj.$y
    }};
    ($x:expr, $y:ident) => {{
        let mut obj = $x.into_inner();
        if let Some(err) = obj.error.take() {
            return Err(err.into());
        }

        obj.$y
            .take()
            // Build the error lazily so the success path does not allocate a String.
            .ok_or_else(|| Error::Custom("Internal error".to_owned()))?
    }};
}
94
95#[async_trait::async_trait]
96impl Signatory for SignatoryRpcClient {
97    fn name(&self) -> String {
98        format!("Rpc Signatory {}", self.url)
99    }
100
101    #[tracing::instrument(skip_all)]
102    async fn blind_sign(&self, request: Vec<BlindedMessage>) -> Result<Vec<BlindSignature>, Error> {
103        let req = super::BlindedMessages {
104            blinded_messages: request
105                .into_iter()
106                .map(|blind_message| blind_message.into())
107                .collect(),
108            operation: super::Operation::Unspecified.into(),
109            correlation_id: "".to_owned(),
110        };
111
112        self.client
113            .clone()
114            .blind_sign(req)
115            .await
116            .map(|response| {
117                handle_error!(response, sigs)
118                    .blind_signatures
119                    .into_iter()
120                    .map(|blinded_signature| blinded_signature.try_into())
121                    .collect()
122            })
123            .map_err(|e| Error::Custom(e.to_string()))?
124    }
125
126    #[tracing::instrument(skip_all)]
127    async fn verify_proofs(&self, proofs: Vec<Proof>) -> Result<(), Error> {
128        let req: super::Proofs = proofs.into();
129        self.client
130            .clone()
131            .verify_proofs(req)
132            .await
133            .map(|response| {
134                if handle_error!(response, success, scalar) {
135                    Ok(())
136                } else {
137                    Err(Error::SignatureMissingOrInvalid)
138                }
139            })
140            .map_err(|e| Error::Custom(e.to_string()))?
141    }
142
143    #[tracing::instrument(skip_all)]
144    async fn keysets(&self) -> Result<SignatoryKeysets, Error> {
145        self.client
146            .clone()
147            .keysets(super::EmptyRequest {})
148            .await
149            .map(|response| handle_error!(response, keysets).try_into())
150            .map_err(|e| Error::Custom(e.to_string()))?
151    }
152
153    #[tracing::instrument(skip(self))]
154    async fn rotate_keyset(&self, args: RotateKeyArguments) -> Result<SignatoryKeySet, Error> {
155        let req: super::RotationRequest = args.into();
156        self.client
157            .clone()
158            .rotate_keyset(req)
159            .await
160            .map(|response| handle_error!(response, keyset).try_into())
161            .map_err(|e| Error::Custom(e.to_string()))?
162    }
163}