Skip to main content

heddle_client/grpc_hosted/
mod.rs

1//! Hosted gRPC client for the transport rewrite.
2
3mod content;
4mod helpers;
5mod hydration;
6mod session;
7mod sync;
8mod user;
9
10use cli_shared::ClientConfig;
11use crypto::{Ed25519Signer, Signer};
12use grpc::heddle::v1::{
13    KeypairProof, MintBiscuitRequest, auth_service_client::AuthServiceClient,
14    content_service_client::ContentServiceClient,
15    hosted_user_service_client::HostedUserServiceClient, mint_biscuit_request::Proof,
16    repo_sync_service_client::RepoSyncServiceClient,
17};
18use objects::{object::MarkerName, store::ObjectStore};
19use wire::ProtocolError;
20use repo::Repository;
21use tonic::{
22    Request,
23    metadata::MetadataValue,
24    transport::{Certificate, Channel, ClientTlsConfig, Endpoint},
25};
26
27use crate::credentials;
28
29pub struct HostedGrpcClient {
30    pub(super) inner: RepoSyncServiceClient<Channel>,
31    pub(super) user: HostedUserServiceClient<Channel>,
32    pub(super) auth: AuthServiceClient<Channel>,
33    pub(super) content: ContentServiceClient<Channel>,
34    pub(super) token_header: Option<MetadataValue<tonic::metadata::Ascii>>,
35    transport: helpers::HostedTransportPolicy,
36    pub(super) auth_proof_key_pem: Option<String>,
37    /// The key used to look up this server's credential in the credential
38    /// store.  When set, `auto_rotate_if_needed` will use it to read and
39    /// update `~/.heddle/credentials.toml` transparently.
40    server_key: Option<String>,
41}
42
43impl HostedGrpcClient {
44    pub async fn connect(
45        addr: std::net::SocketAddr,
46        config: &ClientConfig,
47    ) -> Result<Self, ProtocolError> {
48        let scheme = if config.tls_enabled { "https" } else { "http" };
49        let mut endpoint = Endpoint::from_shared(format!("{scheme}://{addr}"))
50            .map_err(|err| ProtocolError::InvalidState(err.to_string()))?;
51        if config.tls_enabled {
52            let mut tls = ClientTlsConfig::new();
53            if let Some(domain_name) = &config.tls_domain_name {
54                tls = tls.domain_name(domain_name.clone());
55            }
56            if let Some(ca_pem) = &config.tls_ca_certificate_pem {
57                tls = tls.ca_certificate(Certificate::from_pem(ca_pem.as_bytes()));
58            }
59            endpoint = endpoint
60                .tls_config(tls)
61                .map_err(|err| ProtocolError::InvalidState(err.to_string()))?;
62        }
63        let channel = endpoint
64            .connect()
65            .await
66            .map_err(|err| ProtocolError::Io(std::io::Error::other(err.to_string())))?;
67        let token_header = config
68            .token
69            .as_ref()
70            .map(|token| MetadataValue::try_from(format!("Bearer {}", token.id)))
71            .transpose()
72            .map_err(|err| ProtocolError::AuthenticationFailed(err.to_string()))?;
73        let transport = helpers::HostedTransportPolicy::from_client_config(config);
74        Ok(Self {
75            // Bound the single-shot, server-controlled sidecar allocation at
76            // the gRPC decode boundary: tonic rejects an oversized inbound
77            // `PullMessage` before its `redactions_blob`/`state_visibility_blob`
78            // `Vec<u8>` is ever materialized. The post-decode
79            // `check_received_transfer_blob_size` calls are kept as cheap
80            // defense-in-depth, but this is the load-bearing guard.
81            inner: RepoSyncServiceClient::new(channel.clone())
82                .max_decoding_message_size(wire::MAX_PULL_DECODE_MESSAGE_SIZE),
83            user: HostedUserServiceClient::new(channel.clone()),
84            auth: AuthServiceClient::new(channel.clone()),
85            content: ContentServiceClient::new(channel),
86            token_header,
87            transport,
88            auth_proof_key_pem: config.auth_proof_key_pem.clone(),
89            server_key: config.server_key.clone(),
90        })
91    }
92
93    pub(super) fn apply_auth<T>(&self, request: &mut Request<T>) -> Result<(), ProtocolError> {
94        if let Some(token) = &self.token_header {
95            request
96                .metadata_mut()
97                .insert("authorization", token.clone());
98            if let Some(pem) = &self.auth_proof_key_pem {
99                let signer = Ed25519Signer::from_pem(pem)
100                    .map_err(|err| ProtocolError::AuthenticationFailed(err.to_string()))?;
101                let raw = token
102                    .to_str()
103                    .map_err(|err| ProtocolError::AuthenticationFailed(err.to_string()))?;
104                let bearer = raw
105                    .strip_prefix("Bearer ")
106                    .or_else(|| raw.strip_prefix("bearer "))
107                    .unwrap_or(raw);
108                let proof_ts = std::time::SystemTime::now()
109                    .duration_since(std::time::UNIX_EPOCH)
110                    .map_err(|err| ProtocolError::AuthenticationFailed(err.to_string()))?
111                    .as_secs()
112                    .to_string();
113                let signature = signer
114                    .sign(format!("{bearer}|{proof_ts}").as_bytes())
115                    .map_err(|err| ProtocolError::AuthenticationFailed(err.to_string()))?;
116                use base64::Engine;
117                let encoded = base64::engine::general_purpose::STANDARD.encode(signature);
118                let proof = MetadataValue::try_from(encoded)
119                    .map_err(|err| ProtocolError::AuthenticationFailed(err.to_string()))?;
120                request.metadata_mut().insert("x-heddle-proof", proof);
121                let proof_ts = MetadataValue::try_from(proof_ts)
122                    .map_err(|err| ProtocolError::AuthenticationFailed(err.to_string()))?;
123                request.metadata_mut().insert("x-heddle-proof-ts", proof_ts);
124            }
125        }
126        Ok(())
127    }
128
129    /// Transparently rotate the credential for this client if it is near expiry.
130    ///
131    /// No-ops if `server_key` was not set on `ClientConfig` at construction
132    /// time, or if no credential is stored for the server, or if the token is
133    /// not within 10 minutes of expiry.
134    pub async fn auto_rotate_if_needed(&mut self) {
135        let server_key = match &self.server_key {
136            Some(k) => k.clone(),
137            None => return,
138        };
139        self.rotate_credential_for_server(&server_key).await;
140    }
141
142    async fn rotate_credential_for_server(&mut self, server_key: &str) {
143        // Load the stored credential.
144        let cred = match credentials::resolve_credential_for_server(server_key) {
145            Ok(Some(c)) => c,
146            Ok(None) => return,
147            Err(err) => {
148                tracing::warn!("credential rotation: failed to load credential: {err}");
149                return;
150            }
151        };
152
153        // Check whether the Biscuit's stored expiry is within the
154        // rotation window.
155        if !credentials::token_needs_rotation(&cred) {
156            return;
157        }
158
159        // We need both `credential_id` (the public key id the server
160        // will look up) and `private_key_pem` (to sign the renewal
161        // proof). Older credentials without one or the other can't
162        // self-renew; the user falls back to `heddle auth login`.
163        let public_key_id = match &cred.credential_id {
164            Some(id) => id.clone(),
165            None => {
166                tracing::debug!("credential rotation: no credential_id stored, skipping");
167                return;
168            }
169        };
170        let private_key_pem = match &cred.private_key_pem {
171            Some(pem) => pem.clone(),
172            None => {
173                tracing::debug!("credential rotation: no private_key_pem stored, skipping");
174                return;
175            }
176        };
177
178        // Sign the canonical renewal challenge:
179        //   "{timestamp}\n{public_key_id}\n{requested_scope}"
180        // Empty `requested_scope` == reuse the keypair owner's
181        // original scope. The server clamps anyway, so a permissive
182        // hint is fine.
183        let signer = match Ed25519Signer::from_pem(&private_key_pem) {
184            Ok(s) => s,
185            Err(err) => {
186                tracing::warn!("credential rotation: failed to load signing key: {err}");
187                return;
188            }
189        };
190        let timestamp = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
191            Ok(d) => d.as_secs(),
192            Err(err) => {
193                tracing::warn!("credential rotation: clock skew: {err}");
194                return;
195            }
196        };
197        let canonical = format!("{timestamp}\n{public_key_id}\n");
198        let signature = match signer.sign(canonical.as_bytes()) {
199            Ok(sig) => sig,
200            Err(err) => {
201                tracing::warn!("credential rotation: failed to sign challenge: {err}");
202                return;
203            }
204        };
205
206        let mut request = Request::new(MintBiscuitRequest {
207            subject: cred.subject.clone(),
208            requested_scope: String::new(),
209            user_agent: String::new(),
210            ip: String::new(),
211            proof: Some(Proof::Keypair(KeypairProof {
212                public_key_id,
213                timestamp,
214                signature,
215            })),
216            client_operation_id: String::new(),
217        });
218        // MintBiscuit is unauthenticated — the proof is the auth.
219        // We deliberately skip `apply_auth` here.
220        let _ = &mut request;
221
222        let response = match self.auth.mint_biscuit(request).await {
223            Ok(r) => r.into_inner(),
224            Err(status) => {
225                tracing::warn!(
226                    "credential rotation: MintBiscuit failed: {} — continuing with existing token",
227                    status.message()
228                );
229                return;
230            }
231        };
232
233        // Format the new expiry as RFC 3339.
234        let expires_at_secs = response
235            .expires_at
236            .as_ref()
237            .map(|t| t.seconds.max(0))
238            .unwrap_or(0);
239        let new_expires_at = if expires_at_secs > 0 {
240            chrono::DateTime::from_timestamp(expires_at_secs, 0)
241                .map(|dt| dt.to_rfc3339())
242                .unwrap_or_else(|| expires_at_secs.to_string())
243        } else {
244            String::new()
245        };
246
247        tracing::debug!(
248            "credential rotation: rotated successfully, new expiry: {}",
249            new_expires_at
250        );
251
252        // Persist the updated credential. The keypair stays the
253        // same — that's the whole point of the keypair-based renewal
254        // model. We replace `token` (the Biscuit) and bump
255        // `expires_at` to the fresh window.
256        let updated = credentials::ServerCredential {
257            token: response.token.clone(),
258            subject: if response.subject.is_empty() {
259                cred.subject.clone()
260            } else {
261                response.subject
262            },
263            device_id: cred.device_id.clone(),
264            credential_id: cred.credential_id.clone(),
265            private_key_pem: Some(private_key_pem),
266            expires_at: if new_expires_at.is_empty() {
267                cred.expires_at.clone()
268            } else {
269                Some(new_expires_at)
270            },
271        };
272
273        if let Err(err) = credentials::store_server_credential(server_key, updated) {
274            tracing::warn!("credential rotation: failed to persist updated credential: {err}");
275            // Don't bail — the in-memory update below still improves the session.
276        }
277
278        // Update the in-memory token header so the remaining RPCs on this
279        // client instance use the fresh token.
280        match MetadataValue::try_from(format!("Bearer {}", response.token)) {
281            Ok(header) => self.token_header = Some(header),
282            Err(err) => {
283                tracing::warn!("credential rotation: failed to set new token header: {err}");
284            }
285        }
286    }
287
288    pub(super) async fn sync_remote_markers(
289        &mut self,
290        repo: &Repository,
291        repo_path: &str,
292        pushed_state: objects::object::ChangeId,
293    ) -> Result<(), ProtocolError> {
294        let remote_markers = self
295            .list_refs(repo_path)
296            .await?
297            .into_iter()
298            .filter(|entry| !entry.is_thread)
299            .map(|entry| (entry.name, entry.change_id))
300            .collect::<std::collections::HashMap<_, _>>();
301        for marker in repo.refs().list_markers()? {
302            let Some(change_id) = repo.refs().get_marker(&marker)? else {
303                continue;
304            };
305            if !wire::is_ancestor(repo.store(), change_id, pushed_state)? {
306                continue;
307            }
308
309            let old_value = remote_markers.get(marker.as_str()).copied();
310            if old_value == Some(change_id) {
311                continue;
312            }
313
314            let result = self
315                .update_ref(repo_path, &marker, false, old_value, change_id, true, None)
316                .await?;
317            if !result.success {
318                return Err(ProtocolError::InvalidState(
319                    result
320                        .error
321                        .unwrap_or_else(|| format!("failed to sync marker '{marker}'")),
322                ));
323            }
324        }
325        Ok(())
326    }
327
328    pub(super) async fn sync_local_markers(
329        &mut self,
330        repo: &Repository,
331        repo_path: &str,
332    ) -> Result<(), ProtocolError> {
333        let remote_markers = self.list_refs(repo_path).await?;
334        for marker in remote_markers.into_iter().filter(|entry| !entry.is_thread) {
335            if !repo.store().has_state(&marker.change_id)? {
336                continue;
337            }
338            let marker_name = MarkerName::from(marker.name.as_str());
339            match repo.refs().get_marker(&marker_name)? {
340                Some(existing) if existing == marker.change_id => {}
341                Some(existing) => repo.refs().set_marker_cas(
342                    &marker_name,
343                    refs::RefExpectation::Value(existing),
344                    &marker.change_id,
345                )?,
346                None => repo.refs().create_marker(&marker_name, &marker.change_id)?,
347            }
348        }
349        Ok(())
350    }
351}
352
353pub use hydration::{LazyHostedHydrator, PullMaterialization, register_hosted_factory};
354pub use session::{HostedAuthMode, HostedSession};