Skip to main content

cdk_signatory/proto/
client.rs

1use std::path::Path;
2
3use cdk_common::error::Error;
4use cdk_common::grpc::VERSION_HEADER;
5use cdk_common::{BlindSignature, BlindedMessage, Proof};
6use tonic::metadata::MetadataValue;
7use tonic::transport::{Certificate, Channel, ClientTlsConfig, Identity};
8
9use crate::proto::signatory_client::SignatoryClient;
10use crate::signatory::{RotateKeyArguments, Signatory, SignatoryKeySet, SignatoryKeysets};
11
12/// A client for the Signatory service.
13#[allow(missing_debug_implementations)]
14pub struct SignatoryRpcClient {
15    client: SignatoryClient<tonic::transport::Channel>,
16    url: String,
17}
18
19#[derive(thiserror::Error, Debug)]
20/// Client Signatory Error
21pub enum ClientError {
22    /// Transport error
23    #[error(transparent)]
24    Transport(#[from] tonic::transport::Error),
25
26    /// IO-related errors
27    #[error(transparent)]
28    Io(#[from] std::io::Error),
29
30    /// Signatory Error
31    #[error(transparent)]
32    Signatory(#[from] cdk_common::error::Error),
33
34    /// Invalid URL
35    #[error("Invalid URL")]
36    InvalidUrl,
37}
38
39/// Helper function to add version header to a request
40fn with_version_header<T>(mut request: tonic::Request<T>) -> tonic::Request<T> {
41    request.metadata_mut().insert(
42        VERSION_HEADER,
43        MetadataValue::from_static(cdk_common::SIGNATORY_PROTOCOL_VERSION),
44    );
45    request
46}
47
48impl SignatoryRpcClient {
49    /// Create a new RemoteSigner from a tonic transport channel.
50    pub async fn new<A: AsRef<Path>>(url: String, tls_dir: Option<A>) -> Result<Self, ClientError> {
51        #[cfg(not(target_arch = "wasm32"))]
52        if rustls::crypto::CryptoProvider::get_default().is_none() {
53            let _ = rustls::crypto::ring::default_provider().install_default();
54        }
55
56        let channel = if let Some(tls_dir) = tls_dir {
57            let tls_dir = tls_dir.as_ref();
58            let server_root_ca_cert = std::fs::read_to_string(tls_dir.join("ca.pem"))?;
59            let server_root_ca_cert = Certificate::from_pem(server_root_ca_cert);
60            let client_cert = std::fs::read_to_string(tls_dir.join("client.pem"))?;
61            let client_key = std::fs::read_to_string(tls_dir.join("client.key"))?;
62            let client_identity = Identity::from_pem(client_cert, client_key);
63            let tls = ClientTlsConfig::new()
64                .ca_certificate(server_root_ca_cert)
65                .identity(client_identity);
66
67            Channel::from_shared(url.clone())
68                .map_err(|_| ClientError::InvalidUrl)?
69                .tls_config(tls)?
70                .connect()
71                .await?
72        } else {
73            Channel::from_shared(url.clone())
74                .map_err(|_| ClientError::InvalidUrl)?
75                .connect()
76                .await?
77        };
78
79        Ok(Self {
80            client: SignatoryClient::new(channel),
81            url,
82        })
83    }
84}
85
/// Unwrap a tonic response whose message carries an optional `error` field.
///
/// Both forms call `into_inner()` on `$x`. If the message's `error` is set, the
/// *enclosing function* returns `Err(err.into())`.
///
/// - `handle_error!(resp, field, scalar)` evaluates to `message.field` as-is.
/// - `handle_error!(resp, field)` treats `field` as an `Option`, evaluates to its
///   contents, and makes the enclosing function return
///   `Err(Error::Custom("Internal error"))` via `?` when it is `None`.
macro_rules! handle_error {
    ($x:expr, $y:ident, scalar) => {{
        let mut obj = $x.into_inner();
        if let Some(err) = obj.error.take() {
            return Err(err.into());
        }

        obj.$y
    }};
    ($x:expr, $y:ident) => {{
        let mut obj = $x.into_inner();
        if let Some(err) = obj.error.take() {
            return Err(err.into());
        }

        // `ok_or_else` only allocates the fallback message on the failure path.
        obj.$y
            .take()
            .ok_or_else(|| Error::Custom("Internal error".to_owned()))?
    }};
}
106
107#[async_trait::async_trait]
108impl Signatory for SignatoryRpcClient {
109    fn name(&self) -> String {
110        format!("Rpc Signatory {}", self.url)
111    }
112
113    #[tracing::instrument(skip_all)]
114    async fn blind_sign(&self, request: Vec<BlindedMessage>) -> Result<Vec<BlindSignature>, Error> {
115        let req = super::BlindedMessages {
116            blinded_messages: request
117                .into_iter()
118                .map(|blind_message| blind_message.into())
119                .collect(),
120            operation: super::Operation::Unspecified.into(),
121            correlation_id: "".to_owned(),
122        };
123
124        self.client
125            .clone()
126            .blind_sign(with_version_header(tonic::Request::new(req)))
127            .await
128            .map(|response| {
129                handle_error!(response, sigs)
130                    .blind_signatures
131                    .into_iter()
132                    .map(|blinded_signature| blinded_signature.try_into())
133                    .collect()
134            })
135            .map_err(|e| Error::Custom(e.to_string()))?
136    }
137
138    #[tracing::instrument(skip_all)]
139    async fn verify_proofs(&self, proofs: Vec<Proof>) -> Result<(), Error> {
140        let req: super::Proofs = proofs.into();
141        self.client
142            .clone()
143            .verify_proofs(with_version_header(tonic::Request::new(req)))
144            .await
145            .map(|response| {
146                if handle_error!(response, success, scalar) {
147                    Ok(())
148                } else {
149                    Err(Error::SignatureMissingOrInvalid)
150                }
151            })
152            .map_err(|e| Error::Custom(e.to_string()))?
153    }
154
155    #[tracing::instrument(skip_all)]
156    async fn keysets(&self) -> Result<SignatoryKeysets, Error> {
157        self.client
158            .clone()
159            .keysets(with_version_header(tonic::Request::new(
160                super::EmptyRequest {},
161            )))
162            .await
163            .map(|response| handle_error!(response, keysets).try_into())
164            .map_err(|e| Error::Custom(e.to_string()))?
165    }
166
167    #[tracing::instrument(skip(self))]
168    async fn rotate_keyset(&self, args: RotateKeyArguments) -> Result<SignatoryKeySet, Error> {
169        let req: super::RotationRequest = args.into();
170        self.client
171            .clone()
172            .rotate_keyset(with_version_header(tonic::Request::new(req)))
173            .await
174            .map(|response| handle_error!(response, keyset).try_into())
175            .map_err(|e| Error::Custom(e.to_string()))?
176    }
177}