Skip to main content

heddle_client/grpc_hosted/
mod.rs

1//! Hosted gRPC client for the transport rewrite.
2
3mod content;
4mod helpers;
5mod hydration;
6mod session;
7mod sync;
8mod tree_edit;
9mod user;
10
11use cli_shared::ClientConfig;
12use crypto::{Ed25519Signer, Signer};
13use grpc::heddle::v1::{
14    KeypairProof, MintBiscuitRequest, auth_service_client::AuthServiceClient,
15    content_service_client::ContentServiceClient,
16    hosted_user_service_client::HostedUserServiceClient, mint_biscuit_request::Proof,
17    repo_sync_service_client::RepoSyncServiceClient,
18    tree_edit_service_client::TreeEditServiceClient,
19};
20use objects::{object::MarkerName, store::ObjectStore};
21use repo::Repository;
22use tonic::{
23    Request,
24    metadata::MetadataValue,
25    transport::{Certificate, Channel, ClientTlsConfig, Endpoint},
26};
27use wire::ProtocolError;
28
29use crate::credentials;
30
31pub struct HostedGrpcClient {
32    pub(super) inner: RepoSyncServiceClient<Channel>,
33    pub(super) user: HostedUserServiceClient<Channel>,
34    pub(super) auth: AuthServiceClient<Channel>,
35    pub(super) content: ContentServiceClient<Channel>,
36    pub(super) tree_edit: TreeEditServiceClient<Channel>,
37    pub(super) token_header: Option<MetadataValue<tonic::metadata::Ascii>>,
38    transport: helpers::HostedTransportPolicy,
39    pub(super) auth_proof_key_pem: Option<String>,
40    /// The key used to look up this server's credential in the credential
41    /// store.  When set, `auto_rotate_if_needed` will use it to read and
42    /// update `~/.heddle/credentials.toml` transparently.
43    server_key: Option<String>,
44}
45
46impl HostedGrpcClient {
47    pub async fn connect(
48        addr: std::net::SocketAddr,
49        config: &ClientConfig,
50    ) -> Result<Self, ProtocolError> {
51        let scheme = if config.tls_enabled { "https" } else { "http" };
52        let mut endpoint = Endpoint::from_shared(format!("{scheme}://{addr}"))
53            .map_err(|err| ProtocolError::InvalidState(err.to_string()))?;
54        if config.tls_enabled {
55            let mut tls = ClientTlsConfig::new();
56            if let Some(domain_name) = &config.tls_domain_name {
57                tls = tls.domain_name(domain_name.clone());
58            }
59            if let Some(ca_pem) = &config.tls_ca_certificate_pem {
60                tls = tls.ca_certificate(Certificate::from_pem(ca_pem.as_bytes()));
61            }
62            endpoint = endpoint
63                .tls_config(tls)
64                .map_err(|err| ProtocolError::InvalidState(err.to_string()))?;
65        }
66        let channel = endpoint
67            .connect()
68            .await
69            .map_err(|err| ProtocolError::Io(std::io::Error::other(err.to_string())))?;
70        let token_header = config
71            .token
72            .as_ref()
73            .map(|token| MetadataValue::try_from(format!("Bearer {}", token.id)))
74            .transpose()
75            .map_err(|err| ProtocolError::AuthenticationFailed(err.to_string()))?;
76        let transport = helpers::HostedTransportPolicy::from_client_config(config);
77        Ok(Self {
78            // Bound the single-shot, server-controlled sidecar allocation at
79            // the gRPC decode boundary: tonic rejects an oversized inbound
80            // `PullMessage` before its `redactions_blob`/`state_visibility_blob`
81            // `Vec<u8>` is ever materialized. The post-decode
82            // `check_received_transfer_blob_size` calls are kept as cheap
83            // defense-in-depth, but this is the load-bearing guard.
84            inner: RepoSyncServiceClient::new(channel.clone())
85                .max_decoding_message_size(wire::MAX_PULL_DECODE_MESSAGE_SIZE),
86            user: HostedUserServiceClient::new(channel.clone()),
87            auth: AuthServiceClient::new(channel.clone()),
88            content: ContentServiceClient::new(channel.clone()),
89            tree_edit: TreeEditServiceClient::new(channel),
90            token_header,
91            transport,
92            auth_proof_key_pem: config.auth_proof_key_pem.clone(),
93            server_key: config.server_key.clone(),
94        })
95    }
96
97    pub(super) fn apply_auth<T>(&self, request: &mut Request<T>) -> Result<(), ProtocolError> {
98        if let Some(token) = &self.token_header {
99            request
100                .metadata_mut()
101                .insert("authorization", token.clone());
102            if let Some(pem) = &self.auth_proof_key_pem {
103                let signer = Ed25519Signer::from_pem(pem)
104                    .map_err(|err| ProtocolError::AuthenticationFailed(err.to_string()))?;
105                let raw = token
106                    .to_str()
107                    .map_err(|err| ProtocolError::AuthenticationFailed(err.to_string()))?;
108                let bearer = raw
109                    .strip_prefix("Bearer ")
110                    .or_else(|| raw.strip_prefix("bearer "))
111                    .unwrap_or(raw);
112                let proof_ts = std::time::SystemTime::now()
113                    .duration_since(std::time::UNIX_EPOCH)
114                    .map_err(|err| ProtocolError::AuthenticationFailed(err.to_string()))?
115                    .as_secs()
116                    .to_string();
117                let signature = signer
118                    .sign(format!("{bearer}|{proof_ts}").as_bytes())
119                    .map_err(|err| ProtocolError::AuthenticationFailed(err.to_string()))?;
120                use base64::Engine;
121                let encoded = base64::engine::general_purpose::STANDARD.encode(signature);
122                let proof = MetadataValue::try_from(encoded)
123                    .map_err(|err| ProtocolError::AuthenticationFailed(err.to_string()))?;
124                request.metadata_mut().insert("x-heddle-proof", proof);
125                let proof_ts = MetadataValue::try_from(proof_ts)
126                    .map_err(|err| ProtocolError::AuthenticationFailed(err.to_string()))?;
127                request.metadata_mut().insert("x-heddle-proof-ts", proof_ts);
128            }
129        }
130        Ok(())
131    }
132
133    /// Transparently rotate the credential for this client if it is near expiry.
134    ///
135    /// No-ops if `server_key` was not set on `ClientConfig` at construction
136    /// time, or if no credential is stored for the server, or if the token is
137    /// not within 10 minutes of expiry.
138    pub async fn auto_rotate_if_needed(&mut self) {
139        let server_key = match &self.server_key {
140            Some(k) => k.clone(),
141            None => return,
142        };
143        self.rotate_credential_for_server(&server_key).await;
144    }
145
146    async fn rotate_credential_for_server(&mut self, server_key: &str) {
147        // Load the stored credential.
148        let cred = match credentials::resolve_credential_for_server(server_key) {
149            Ok(Some(c)) => c,
150            Ok(None) => return,
151            Err(err) => {
152                tracing::warn!("credential rotation: failed to load credential: {err}");
153                return;
154            }
155        };
156
157        // Check whether the Biscuit's stored expiry is within the
158        // rotation window.
159        if !credentials::token_needs_rotation(&cred) {
160            return;
161        }
162
163        // We need both `credential_id` (the public key id the server
164        // will look up) and `private_key_pem` (to sign the renewal
165        // proof). Older credentials without one or the other can't
166        // self-renew; the user falls back to `heddle auth login`.
167        let public_key_id = match &cred.credential_id {
168            Some(id) => id.clone(),
169            None => {
170                tracing::debug!("credential rotation: no credential_id stored, skipping");
171                return;
172            }
173        };
174        let private_key_pem = match &cred.private_key_pem {
175            Some(pem) => pem.clone(),
176            None => {
177                tracing::debug!("credential rotation: no private_key_pem stored, skipping");
178                return;
179            }
180        };
181
182        // Sign the canonical renewal challenge:
183        //   "{timestamp}\n{public_key_id}\n{requested_scope}"
184        // Empty `requested_scope` == reuse the keypair owner's
185        // original scope. The server clamps anyway, so a permissive
186        // hint is fine.
187        let signer = match Ed25519Signer::from_pem(&private_key_pem) {
188            Ok(s) => s,
189            Err(err) => {
190                tracing::warn!("credential rotation: failed to load signing key: {err}");
191                return;
192            }
193        };
194        let timestamp = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
195            Ok(d) => d.as_secs(),
196            Err(err) => {
197                tracing::warn!("credential rotation: clock skew: {err}");
198                return;
199            }
200        };
201        let canonical = format!("{timestamp}\n{public_key_id}\n");
202        let signature = match signer.sign(canonical.as_bytes()) {
203            Ok(sig) => sig,
204            Err(err) => {
205                tracing::warn!("credential rotation: failed to sign challenge: {err}");
206                return;
207            }
208        };
209
210        let mut request = Request::new(MintBiscuitRequest {
211            subject: cred.subject.clone(),
212            requested_scope: String::new(),
213            user_agent: String::new(),
214            ip: String::new(),
215            proof: Some(Proof::Keypair(KeypairProof {
216                public_key_id,
217                timestamp,
218                signature,
219            })),
220            client_operation_id: String::new(),
221        });
222        // MintBiscuit is unauthenticated — the proof is the auth.
223        // We deliberately skip `apply_auth` here.
224        let _ = &mut request;
225
226        let response = match self.auth.mint_biscuit(request).await {
227            Ok(r) => r.into_inner(),
228            Err(status) => {
229                tracing::warn!(
230                    "credential rotation: MintBiscuit failed: {} — continuing with existing token",
231                    status.message()
232                );
233                return;
234            }
235        };
236
237        // Format the new expiry as RFC 3339.
238        let expires_at_secs = response
239            .expires_at
240            .as_ref()
241            .map(|t| t.seconds.max(0))
242            .unwrap_or(0);
243        let new_expires_at = if expires_at_secs > 0 {
244            chrono::DateTime::from_timestamp(expires_at_secs, 0)
245                .map(|dt| dt.to_rfc3339())
246                .unwrap_or_else(|| expires_at_secs.to_string())
247        } else {
248            String::new()
249        };
250
251        tracing::debug!(
252            "credential rotation: rotated successfully, new expiry: {}",
253            new_expires_at
254        );
255
256        // Persist the updated credential. The keypair stays the
257        // same — that's the whole point of the keypair-based renewal
258        // model. We replace `token` (the Biscuit) and bump
259        // `expires_at` to the fresh window.
260        let updated = credentials::ServerCredential {
261            token: response.token.clone(),
262            subject: if response.subject.is_empty() {
263                cred.subject.clone()
264            } else {
265                response.subject
266            },
267            device_id: cred.device_id.clone(),
268            credential_id: cred.credential_id.clone(),
269            private_key_pem: Some(private_key_pem),
270            expires_at: if new_expires_at.is_empty() {
271                cred.expires_at.clone()
272            } else {
273                Some(new_expires_at)
274            },
275        };
276
277        if let Err(err) = credentials::store_server_credential(server_key, updated) {
278            tracing::warn!("credential rotation: failed to persist updated credential: {err}");
279            // Don't bail — the in-memory update below still improves the session.
280        }
281
282        // Update the in-memory token header so the remaining RPCs on this
283        // client instance use the fresh token.
284        match MetadataValue::try_from(format!("Bearer {}", response.token)) {
285            Ok(header) => self.token_header = Some(header),
286            Err(err) => {
287                tracing::warn!("credential rotation: failed to set new token header: {err}");
288            }
289        }
290    }
291
292    pub(super) async fn sync_remote_markers(
293        &mut self,
294        repo: &Repository,
295        repo_path: &str,
296        pushed_state: objects::object::ChangeId,
297    ) -> Result<(), ProtocolError> {
298        let remote_markers = self
299            .list_refs(repo_path)
300            .await?
301            .into_iter()
302            .filter(|entry| !entry.is_thread)
303            .map(|entry| (entry.name, entry.change_id))
304            .collect::<std::collections::HashMap<_, _>>();
305        for marker in repo.refs().list_markers()? {
306            let Some(change_id) = repo.refs().get_marker(&marker)? else {
307                continue;
308            };
309            if !wire::is_ancestor(repo.store(), change_id, pushed_state)? {
310                continue;
311            }
312
313            let old_value = remote_markers.get(marker.as_str()).copied();
314            if old_value == Some(change_id) {
315                continue;
316            }
317
318            let result = self
319                .update_ref(repo_path, &marker, false, old_value, change_id, true, None)
320                .await?;
321            if !result.success {
322                return Err(ProtocolError::InvalidState(
323                    result
324                        .error
325                        .unwrap_or_else(|| format!("failed to sync marker '{marker}'")),
326                ));
327            }
328        }
329        Ok(())
330    }
331
332    pub(super) async fn sync_local_markers(
333        &mut self,
334        repo: &Repository,
335        repo_path: &str,
336    ) -> Result<(), ProtocolError> {
337        let remote_markers = self.list_refs(repo_path).await?;
338        for marker in remote_markers.into_iter().filter(|entry| !entry.is_thread) {
339            if !repo.store().has_state(&marker.change_id)? {
340                continue;
341            }
342            let marker_name = MarkerName::from(marker.name.as_str());
343            match repo.refs().get_marker(&marker_name)? {
344                Some(existing) if existing == marker.change_id => {}
345                Some(existing) => repo.refs().set_marker_cas(
346                    &marker_name,
347                    refs::RefExpectation::Value(existing),
348                    &marker.change_id,
349                )?,
350                None => repo.refs().create_marker(&marker_name, &marker.change_id)?,
351            }
352        }
353        Ok(())
354    }
355}
356
357pub use hydration::{LazyHostedHydrator, PullMaterialization, register_hosted_factory};
358pub use session::{HostedAuthMode, HostedSession};
359pub use sync::HostedRefEntry;