Skip to main content

heddle_client/grpc_hosted/
mod.rs

1//! Hosted gRPC client for the transport rewrite.
2
3mod content;
4mod helpers;
5mod hydration;
6mod sync;
7mod user;
8
9use crypto::{Ed25519Signer, Signer};
10use grpc::heddle::v1::{
11    KeypairProof, MintBiscuitRequest, auth_service_client::AuthServiceClient,
12    content_service_client::ContentServiceClient,
13    hosted_user_service_client::HostedUserServiceClient, mint_biscuit_request::Proof,
14    repo_sync_service_client::RepoSyncServiceClient,
15};
16use proto::ProtocolError;
17use repo::Repository;
18use tonic::{
19    Request,
20    metadata::MetadataValue,
21    transport::{Certificate, Channel, ClientTlsConfig, Endpoint},
22};
23
24use crate::credentials;
25use cli_shared::ClientConfig;
26
27pub struct HostedGrpcClient {
28    pub(super) inner: RepoSyncServiceClient<Channel>,
29    pub(super) user: HostedUserServiceClient<Channel>,
30    pub(super) auth: AuthServiceClient<Channel>,
31    pub(super) content: ContentServiceClient<Channel>,
32    pub(super) token_header: Option<MetadataValue<tonic::metadata::Ascii>>,
33    transport: helpers::HostedTransportPolicy,
34    pub(super) auth_proof_key_pem: Option<String>,
35    /// The key used to look up this server's credential in the credential
36    /// store.  When set, `auto_rotate_if_needed` will use it to read and
37    /// update `~/.heddle/credentials.toml` transparently.
38    server_key: Option<String>,
39}
40
41impl HostedGrpcClient {
42    pub async fn connect(
43        addr: std::net::SocketAddr,
44        config: &ClientConfig,
45    ) -> Result<Self, ProtocolError> {
46        let scheme = if config.tls_enabled { "https" } else { "http" };
47        let mut endpoint = Endpoint::from_shared(format!("{scheme}://{addr}"))
48            .map_err(|err| ProtocolError::InvalidState(err.to_string()))?;
49        if config.tls_enabled {
50            let mut tls = ClientTlsConfig::new();
51            if let Some(domain_name) = &config.tls_domain_name {
52                tls = tls.domain_name(domain_name.clone());
53            }
54            if let Some(ca_pem) = &config.tls_ca_certificate_pem {
55                tls = tls.ca_certificate(Certificate::from_pem(ca_pem.as_bytes()));
56            }
57            endpoint = endpoint
58                .tls_config(tls)
59                .map_err(|err| ProtocolError::InvalidState(err.to_string()))?;
60        }
61        let channel = endpoint
62            .connect()
63            .await
64            .map_err(|err| ProtocolError::Io(std::io::Error::other(err.to_string())))?;
65        let token_header = config
66            .token
67            .as_ref()
68            .map(|token| MetadataValue::try_from(format!("Bearer {}", token.id)))
69            .transpose()
70            .map_err(|err| ProtocolError::AuthenticationFailed(err.to_string()))?;
71        let transport = helpers::HostedTransportPolicy::from_client_config(config);
72        Ok(Self {
73            inner: RepoSyncServiceClient::new(channel.clone()),
74            user: HostedUserServiceClient::new(channel.clone()),
75            auth: AuthServiceClient::new(channel.clone()),
76            content: ContentServiceClient::new(channel),
77            token_header,
78            transport,
79            auth_proof_key_pem: config.auth_proof_key_pem.clone(),
80            server_key: config.server_key.clone(),
81        })
82    }
83
84    pub(super) fn apply_auth<T>(&self, request: &mut Request<T>) -> Result<(), ProtocolError> {
85        if let Some(token) = &self.token_header {
86            request
87                .metadata_mut()
88                .insert("authorization", token.clone());
89            if let Some(pem) = &self.auth_proof_key_pem {
90                let signer = Ed25519Signer::from_pem(pem)
91                    .map_err(|err| ProtocolError::AuthenticationFailed(err.to_string()))?;
92                let raw = token
93                    .to_str()
94                    .map_err(|err| ProtocolError::AuthenticationFailed(err.to_string()))?;
95                let bearer = raw
96                    .strip_prefix("Bearer ")
97                    .or_else(|| raw.strip_prefix("bearer "))
98                    .unwrap_or(raw);
99                let proof_ts = std::time::SystemTime::now()
100                    .duration_since(std::time::UNIX_EPOCH)
101                    .map_err(|err| ProtocolError::AuthenticationFailed(err.to_string()))?
102                    .as_secs()
103                    .to_string();
104                let signature = signer
105                    .sign(format!("{bearer}|{proof_ts}").as_bytes())
106                    .map_err(|err| ProtocolError::AuthenticationFailed(err.to_string()))?;
107                use base64::Engine;
108                let encoded = base64::engine::general_purpose::STANDARD.encode(signature);
109                let proof = MetadataValue::try_from(encoded)
110                    .map_err(|err| ProtocolError::AuthenticationFailed(err.to_string()))?;
111                request.metadata_mut().insert("x-heddle-proof", proof);
112                let proof_ts = MetadataValue::try_from(proof_ts)
113                    .map_err(|err| ProtocolError::AuthenticationFailed(err.to_string()))?;
114                request.metadata_mut().insert("x-heddle-proof-ts", proof_ts);
115            }
116        }
117        Ok(())
118    }
119
120    /// Transparently rotate the credential for this client if it is near expiry.
121    ///
122    /// No-ops if `server_key` was not set on `ClientConfig` at construction
123    /// time, or if no credential is stored for the server, or if the token is
124    /// not within 10 minutes of expiry.
125    pub async fn auto_rotate_if_needed(&mut self) {
126        let server_key = match &self.server_key {
127            Some(k) => k.clone(),
128            None => return,
129        };
130        self.rotate_credential_for_server(&server_key).await;
131    }
132
133    async fn rotate_credential_for_server(&mut self, server_key: &str) {
134        // Load the stored credential.
135        let cred = match credentials::resolve_credential_for_server(server_key) {
136            Ok(Some(c)) => c,
137            Ok(None) => return,
138            Err(err) => {
139                tracing::warn!("credential rotation: failed to load credential: {err}");
140                return;
141            }
142        };
143
144        // Check whether the Biscuit's stored expiry is within the
145        // rotation window.
146        if !credentials::token_needs_rotation(&cred) {
147            return;
148        }
149
150        // We need both `credential_id` (the public key id the server
151        // will look up) and `private_key_pem` (to sign the renewal
152        // proof). Older credentials without one or the other can't
153        // self-renew; the user falls back to `heddle auth login`.
154        let public_key_id = match &cred.credential_id {
155            Some(id) => id.clone(),
156            None => {
157                tracing::debug!("credential rotation: no credential_id stored, skipping");
158                return;
159            }
160        };
161        let private_key_pem = match &cred.private_key_pem {
162            Some(pem) => pem.clone(),
163            None => {
164                tracing::debug!("credential rotation: no private_key_pem stored, skipping");
165                return;
166            }
167        };
168
169        // Sign the canonical renewal challenge:
170        //   "{timestamp}\n{public_key_id}\n{requested_scope}"
171        // Empty `requested_scope` == reuse the keypair owner's
172        // original scope. The server clamps anyway, so a permissive
173        // hint is fine.
174        let signer = match Ed25519Signer::from_pem(&private_key_pem) {
175            Ok(s) => s,
176            Err(err) => {
177                tracing::warn!("credential rotation: failed to load signing key: {err}");
178                return;
179            }
180        };
181        let timestamp = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
182            Ok(d) => d.as_secs(),
183            Err(err) => {
184                tracing::warn!("credential rotation: clock skew: {err}");
185                return;
186            }
187        };
188        let canonical = format!("{timestamp}\n{public_key_id}\n");
189        let signature = match signer.sign(canonical.as_bytes()) {
190            Ok(sig) => sig,
191            Err(err) => {
192                tracing::warn!("credential rotation: failed to sign challenge: {err}");
193                return;
194            }
195        };
196
197        let mut request = Request::new(MintBiscuitRequest {
198            subject: cred.subject.clone(),
199            requested_scope: String::new(),
200            user_agent: String::new(),
201            ip: String::new(),
202            proof: Some(Proof::Keypair(KeypairProof {
203                public_key_id,
204                timestamp,
205                signature,
206            })),
207            client_operation_id: String::new(),
208        });
209        // MintBiscuit is unauthenticated — the proof is the auth.
210        // We deliberately skip `apply_auth` here.
211        let _ = &mut request;
212
213        let response = match self.auth.mint_biscuit(request).await {
214            Ok(r) => r.into_inner(),
215            Err(status) => {
216                tracing::warn!(
217                    "credential rotation: MintBiscuit failed: {} — continuing with existing token",
218                    status.message()
219                );
220                return;
221            }
222        };
223
224        // Format the new expiry as RFC 3339.
225        let expires_at_secs = response
226            .expires_at
227            .as_ref()
228            .map(|t| t.seconds.max(0))
229            .unwrap_or(0);
230        let new_expires_at = if expires_at_secs > 0 {
231            chrono::DateTime::from_timestamp(expires_at_secs, 0)
232                .map(|dt| dt.to_rfc3339())
233                .unwrap_or_else(|| expires_at_secs.to_string())
234        } else {
235            String::new()
236        };
237
238        tracing::debug!(
239            "credential rotation: rotated successfully, new expiry: {}",
240            new_expires_at
241        );
242
243        // Persist the updated credential. The keypair stays the
244        // same — that's the whole point of the keypair-based renewal
245        // model. We replace `token` (the Biscuit) and bump
246        // `expires_at` to the fresh window.
247        let updated = credentials::ServerCredential {
248            token: response.token.clone(),
249            subject: if response.subject.is_empty() {
250                cred.subject.clone()
251            } else {
252                response.subject
253            },
254            device_id: cred.device_id.clone(),
255            credential_id: cred.credential_id.clone(),
256            private_key_pem: Some(private_key_pem),
257            expires_at: if new_expires_at.is_empty() {
258                cred.expires_at.clone()
259            } else {
260                Some(new_expires_at)
261            },
262        };
263
264        if let Err(err) = credentials::store_server_credential(server_key, updated) {
265            tracing::warn!("credential rotation: failed to persist updated credential: {err}");
266            // Don't bail — the in-memory update below still improves the session.
267        }
268
269        // Update the in-memory token header so the remaining RPCs on this
270        // client instance use the fresh token.
271        match MetadataValue::try_from(format!("Bearer {}", response.token)) {
272            Ok(header) => self.token_header = Some(header),
273            Err(err) => {
274                tracing::warn!("credential rotation: failed to set new token header: {err}");
275            }
276        }
277    }
278
279    pub(super) async fn sync_remote_markers(
280        &mut self,
281        repo: &Repository,
282        repo_path: &str,
283        pushed_state: objects::object::ChangeId,
284    ) -> Result<(), ProtocolError> {
285        let remote_markers = self
286            .list_refs(repo_path)
287            .await?
288            .into_iter()
289            .filter(|entry| !entry.is_thread)
290            .map(|entry| (entry.name, entry.change_id))
291            .collect::<std::collections::HashMap<_, _>>();
292        for marker in repo.refs().list_markers()? {
293            let Some(change_id) = repo.refs().get_marker(&marker)? else {
294                continue;
295            };
296            if !proto::is_ancestor(repo.store(), change_id, pushed_state)? {
297                continue;
298            }
299
300            let old_value = remote_markers.get(&marker).copied();
301            if old_value == Some(change_id) {
302                continue;
303            }
304
305            let result = self
306                .update_ref(repo_path, &marker, false, old_value, change_id, true, None)
307                .await?;
308            if !result.success {
309                return Err(ProtocolError::InvalidState(
310                    result
311                        .error
312                        .unwrap_or_else(|| format!("failed to sync marker '{marker}'")),
313                ));
314            }
315        }
316        Ok(())
317    }
318
319    pub(super) async fn sync_local_markers(
320        &mut self,
321        repo: &Repository,
322        repo_path: &str,
323    ) -> Result<(), ProtocolError> {
324        let remote_markers = self.list_refs(repo_path).await?;
325        for marker in remote_markers.into_iter().filter(|entry| !entry.is_thread) {
326            if !repo.store().has_state(&marker.change_id)? {
327                continue;
328            }
329            match repo.refs().get_marker(&marker.name)? {
330                Some(existing) if existing == marker.change_id => {}
331                Some(existing) => repo.refs().set_marker_cas(
332                    &marker.name,
333                    refs::RefExpectation::Value(existing),
334                    &marker.change_id,
335                )?,
336                None => repo.refs().create_marker(&marker.name, &marker.change_id)?,
337            }
338        }
339        Ok(())
340    }
341}
342
343pub use hydration::PullMaterialization;