Skip to main content

kagi_sync/infrastructure/
remote_client.rs

1use crate::domain::envelope::{RequestPlaintext, ResponseEnvelope, verify_response_mac};
2use crate::domain::remote_config::ServerKeyResponse;
3use crate::infrastructure::remote_envelope::{decrypt_response, encrypt_request, parse_recipient};
4use age::x25519;
5use base64::Engine;
6use kagi_domain::error::DomainError;
7use reqwest::Client;
8use serde::{Deserialize, Serialize};
9use thiserror::Error;
10use url::Url;
11
/// Client for a single remote, bound to the server key fetched when it was constructed.
pub struct RemoteClient {
    /// HTTP client used for every request sent to the remote.
    client: Client,
    /// Base URL of the remote as supplied by the caller; trailing `/` characters are
    /// trimmed whenever a request path is appended.
    remote_url: String,
    /// Server recipient parsed from the server-key response; requests are encrypted to it.
    server_recipient: x25519::Recipient,
    /// Identifier of the server key, copied into every outgoing request envelope.
    server_key_id: String,
    /// Fingerprint reported by the server-key response; compared against the expected
    /// value when the client is built with fingerprint pinning.
    fingerprint: String,
}
19
20fn is_localhost_url(url: &str) -> bool {
21    let parsed = match Url::parse(url) {
22        Ok(u) => u,
23        Err(_) => return false,
24    };
25    if parsed.scheme() != "http" {
26        return false;
27    }
28    match parsed.host() {
29        Some(url::Host::Domain("localhost")) => true,
30        Some(url::Host::Ipv4(ip)) if ip.is_loopback() => true,
31        Some(url::Host::Ipv6(ip)) if ip.is_loopback() => true,
32        _ => false,
33    }
34}
35
/// Maximum number of characters of an HTTP body that is echoed into an error message.
const HTTP_ERROR_BODY_PREVIEW_LEN: usize = 512;

/// Returns `body` shortened to at most [`HTTP_ERROR_BODY_PREVIEW_LEN`] characters.
///
/// A body that fits is returned unchanged; a longer one is cut after the limit and
/// suffixed with `...`. The limit is counted in `char`s, so the cut never lands inside
/// a multi-byte UTF-8 sequence.
fn summarize_http_error_body(body: &str) -> String {
    // The byte offset of the first character past the limit (if any) is exactly where
    // the preview must end. Looking it up stops scanning at the limit instead of
    // counting every character of a potentially large body.
    match body.char_indices().nth(HTTP_ERROR_BODY_PREVIEW_LEN) {
        Some((cut, _)) => format!("{}...", &body[..cut]),
        None => body.to_owned(),
    }
}
46
/// Errors returned by the token/member management calls of `RemoteClient`.
#[derive(Error, Debug)]
pub enum ClientError {
    /// The remote rejected the request with code `auth_failed`.
    #[error("invalid token")]
    InvalidToken,
    /// The remote rejected the request with code `not_found`.
    #[error("project not found")]
    ProjectNotFound,
    /// The remote rejected the request with code `conflict`.
    #[error("project state conflict")]
    ProjectStateConflict,
    /// Any other failure; carries the stringified underlying error.
    #[error("request failed: {0}")]
    RequestFailed(String),
}
58
/// Member details sent to the remote under the `join_request` key of a join request.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MemberJoinRequest {
    /// Identifier of the member asking to join.
    pub member_id: String,
    /// Name of the member.
    pub name: String,
    /// Recipient string of the member (presumably an age x25519 recipient — confirm).
    pub recipient: String,
    /// Signing public key of the member (encoding not visible here — confirm).
    pub signing_public_key: String,
}
66
/// Result of a successful join request; deserialized from the response `data` and
/// currently carries no fields.
#[derive(Deserialize, Debug, Clone)]
pub struct JoinResponse {}
69
/// Result of a token-issue request, deserialized from the response `data`.
#[derive(Deserialize, Debug, Clone)]
pub struct TokenIssueResponse {
    /// Identifier of the issued token.
    pub token_id: String,
    /// The issued project token.
    pub project_token: String,
    // Not read anywhere in this file; presumably kept so the field stays part of the
    // expected response shape — confirm before removing.
    #[allow(dead_code)]
    pub status: String,
}
77
78pub fn validate_http_transport(remote_url: &str, allow_insecure: bool) -> Result<(), DomainError> {
79    if remote_url.starts_with("http://") && !is_localhost_url(remote_url) && !allow_insecure {
80        let env_override = std::env::var("KAGI_ALLOW_INSECURE_HTTP")
81            .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
82            .unwrap_or(false);
83        if !env_override {
84            return Err(DomainError::RemoteProtocolError(
85                "HTTP remotes are only allowed for localhost. Use HTTPS or pass --allow-insecure-http for local testing.".into(),
86            ));
87        }
88    }
89    Ok(())
90}
91
92impl RemoteClient {
93    pub async fn new(remote_url: String, allow_insecure: bool) -> Result<Self, DomainError> {
94        validate_http_transport(&remote_url, allow_insecure)?;
95        let client = if is_localhost_url(&remote_url) {
96            Client::builder().no_proxy().build().map_err(|e| {
97                DomainError::RemoteProtocolError(format!("failed to build HTTP client: {e}"))
98            })?
99        } else {
100            Client::new()
101        };
102        let url = format!("{}/v1/server-key", remote_url.trim_end_matches('/'));
103        let server_key: ServerKeyResponse = client
104            .get(&url)
105            .send()
106            .await
107            .map_err(|e| {
108                DomainError::ServerUnavailable(format!("failed to fetch server key: {e}"))
109            })?
110            .json()
111            .await
112            .map_err(|e| {
113                DomainError::ServerUnavailable(format!("invalid server key response: {e}"))
114            })?;
115
116        let server_recipient = parse_recipient(&server_key.recipient)?;
117        Ok(Self {
118            client,
119            remote_url,
120            server_recipient,
121            server_key_id: server_key.server_key_id,
122            fingerprint: server_key.fingerprint,
123        })
124    }
125
126    pub async fn new_pinned(
127        remote_url: String,
128        expected_fingerprint: &str,
129        allow_insecure: bool,
130    ) -> Result<Self, DomainError> {
131        let remote = Self::new(remote_url, allow_insecure).await?;
132        if remote.fingerprint != expected_fingerprint {
133            return Err(DomainError::RemoteProtocolError(format!(
134                "server fingerprint mismatch: expected {}, got {}",
135                expected_fingerprint, remote.fingerprint
136            )));
137        }
138        Ok(remote)
139    }
140
141    pub fn fingerprint(&self) -> &str {
142        &self.fingerprint
143    }
144
145    pub fn server_key_id(&self) -> &str {
146        &self.server_key_id
147    }
148
149    pub async fn send_request(
150        &self,
151        plaintext: &RequestPlaintext,
152        local_identity: &x25519::Identity,
153    ) -> Result<serde_json::Value, DomainError> {
154        let local_recipient = local_identity.to_public();
155        let mut envelope = encrypt_request(plaintext, &self.server_recipient, &local_recipient)?;
156        envelope.server_key_id = self.server_key_id.clone();
157
158        let url = format!(
159            "{}{}",
160            self.remote_url.trim_end_matches('/'),
161            plaintext.path
162        );
163        let response = self
164            .client
165            .post(&url)
166            .json(&envelope)
167            .send()
168            .await
169            .map_err(|e| DomainError::ServerUnavailable(format!("request failed: {e}")))?;
170        let status = response.status();
171
172        let response_text = response
173            .text()
174            .await
175            .map_err(|e| DomainError::ServerUnavailable(format!("invalid response body: {e}")))?;
176        if !status.is_success() {
177            return Err(DomainError::RemoteProtocolError(format!(
178                "request failed with status {status}: {}",
179                summarize_http_error_body(&response_text),
180            )));
181        }
182
183        let response_envelope: ResponseEnvelope =
184            serde_json::from_str(&response_text).map_err(|e| {
185                DomainError::ServerUnavailable(format!(
186                    "invalid response: {e} | raw: {response_text}"
187                ))
188            })?;
189
190        if response_envelope.request_id != plaintext.request_id {
191            return Err(DomainError::RemoteProtocolError(
192                "response request_id mismatch".into(),
193            ));
194        }
195        let mac_key = plaintext
196            .token
197            .as_deref()
198            .or(plaintext.claim_secret.as_deref());
199        if let Some(key) = mac_key {
200            let mac = response_envelope.mac.as_deref().ok_or_else(|| {
201                DomainError::RemoteProtocolError("missing response authentication mac".into())
202            })?;
203            if !verify_response_mac(
204                key,
205                &plaintext.request_id,
206                &response_envelope.ciphertext,
207                mac,
208            ) {
209                return Err(DomainError::RemoteProtocolError(
210                    "invalid response authentication mac".into(),
211                ));
212            }
213        }
214
215        let ciphertext = base64::engine::general_purpose::URL_SAFE_NO_PAD
216            .decode(&response_envelope.ciphertext)
217            .map_err(|e| DomainError::DecryptFailed(e.to_string()))?;
218        let decrypted = decrypt_response(&ciphertext, local_identity)?;
219        if decrypted.get("request_id").and_then(|v| v.as_str())
220            != Some(plaintext.request_id.as_str())
221        {
222            return Err(DomainError::RemoteProtocolError(
223                "decrypted response request_id mismatch".into(),
224            ));
225        }
226
227        if !decrypted
228            .get("ok")
229            .and_then(serde_json::Value::as_bool)
230            .unwrap_or(false)
231        {
232            let error = decrypted.get("error").cloned().unwrap_or_default();
233            let code = error
234                .get("code")
235                .and_then(|v| v.as_str())
236                .unwrap_or("unknown");
237            let message = error
238                .get("message")
239                .and_then(|v| v.as_str())
240                .unwrap_or("unknown error");
241            return Err(DomainError::RemoteRejected {
242                code: code.to_string(),
243                message: message.to_string(),
244            });
245        }
246
247        Ok(decrypted.get("data").cloned().unwrap_or_default())
248    }
249
250    fn map_request_error(e: DomainError) -> ClientError {
251        if let DomainError::RemoteRejected { ref code, .. } = e {
252            if code == "auth_failed" {
253                return ClientError::InvalidToken;
254            }
255            if code == "not_found" {
256                return ClientError::ProjectNotFound;
257            }
258            if code == "conflict" {
259                return ClientError::ProjectStateConflict;
260            }
261        }
262        ClientError::RequestFailed(e.to_string())
263    }
264
265    pub async fn get_token_from_claim_secret(
266        &self,
267        project_id: &str,
268        member_id: &str,
269        claim_secret: &str,
270        identity: &x25519::Identity,
271    ) -> Result<String, DomainError> {
272        let request_id = format!(r"kgr_{}", nanoid::nanoid!(12));
273        let plaintext = RequestPlaintext {
274            version: 1,
275            request_id: request_id.clone(),
276            issued_at: time::OffsetDateTime::now_utc()
277                .format(&time::format_description::well_known::Rfc3339)
278                .unwrap(),
279            operation: "pull".into(),
280            method: "POST".into(),
281            path: format!("/v1/projects/{project_id}/pull"),
282            project_id: Some(project_id.to_string()),
283            token: None,
284            claim_secret: Some(claim_secret.to_string()),
285            payload: serde_json::json!({
286                "member_id": member_id,
287            }),
288        };
289        let data = self.send_request(&plaintext, identity).await?;
290        if let Some(wrapped_b64) = data.get("wrapped_project_token").and_then(|v| v.as_str()) {
291            let wrapped = base64::engine::general_purpose::URL_SAFE_NO_PAD
292                .decode(wrapped_b64)
293                .map_err(|e| {
294                    DomainError::RemoteProtocolError(format!("invalid wrapped token: {e}"))
295                })?;
296            let decrypted = crate::infrastructure::remote_envelope::decrypt_bytes(
297                &wrapped, identity,
298            )
299            .map_err(|e| {
300                DomainError::RemoteProtocolError(format!("failed to decrypt wrapped token: {e}"))
301            })?;
302            String::from_utf8(decrypted)
303                .map_err(|e| DomainError::RemoteProtocolError(format!("invalid token: {e}")))
304        } else {
305            Err(DomainError::ProjectTokenUnavailable(
306                "no project token available; ask an active member/admin to approve this member, then run `kagi remote pull`"
307                    .into(),
308            ))
309        }
310    }
311
312    pub async fn send_member_join_request(
313        &self,
314        project_id: &str,
315        token: &str,
316        join_request: &MemberJoinRequest,
317        identity: &x25519::Identity,
318    ) -> Result<JoinResponse, ClientError> {
319        let request_id = format!(r"kgr_{}", nanoid::nanoid!(12));
320        let plaintext = RequestPlaintext {
321            version: 1,
322            request_id: request_id.clone(),
323            issued_at: time::OffsetDateTime::now_utc()
324                .format(&time::format_description::well_known::Rfc3339)
325                .unwrap(),
326            operation: "join".into(),
327            method: "POST".into(),
328            path: format!("/v1/projects/{project_id}/join"),
329            project_id: Some(project_id.to_string()),
330            token: Some(token.to_string()),
331            claim_secret: None,
332            payload: serde_json::json!({
333                "join_request": {
334                    "member_id": join_request.member_id,
335                    "name": join_request.name,
336                    "recipient": join_request.recipient,
337                    "signing_public_key": join_request.signing_public_key,
338                }
339            }),
340        };
341        let data = self
342            .send_request(&plaintext, identity)
343            .await
344            .map_err(Self::map_request_error)?;
345        serde_json::from_value(data).map_err(|e| ClientError::RequestFailed(e.to_string()))
346    }
347
348    pub async fn send_member_token_issue(
349        &self,
350        project_id: &str,
351        token: &str,
352        member_id: &str,
353        identity: &x25519::Identity,
354    ) -> Result<TokenIssueResponse, ClientError> {
355        let request_id = format!(r"kgr_{}", nanoid::nanoid!(12));
356        let plaintext = RequestPlaintext {
357            version: 1,
358            request_id: request_id.clone(),
359            issued_at: time::OffsetDateTime::now_utc()
360                .format(&time::format_description::well_known::Rfc3339)
361                .unwrap(),
362            operation: "token_issue".into(),
363            method: "POST".into(),
364            path: format!("/v1/projects/{project_id}/tokens/issue"),
365            project_id: Some(project_id.to_string()),
366            token: Some(token.to_string()),
367            claim_secret: None,
368            payload: serde_json::json!({
369                "member_id": member_id,
370                "capabilities": ["pull", "push"],
371            }),
372        };
373        let data = self
374            .send_request(&plaintext, identity)
375            .await
376            .map_err(Self::map_request_error)?;
377        serde_json::from_value(data).map_err(|e| ClientError::RequestFailed(e.to_string()))
378    }
379
380    pub async fn send_list_tokens(
381        &self,
382        project_id: &str,
383        token: &str,
384        identity: &x25519::Identity,
385    ) -> Result<serde_json::Value, ClientError> {
386        let request_id = format!(r"kgr_{}", nanoid::nanoid!(12));
387        let plaintext = RequestPlaintext {
388            version: 1,
389            request_id: request_id.clone(),
390            issued_at: time::OffsetDateTime::now_utc()
391                .format(&time::format_description::well_known::Rfc3339)
392                .unwrap(),
393            operation: "token_list".into(),
394            method: "POST".into(),
395            path: format!("/v1/projects/{project_id}/tokens/list"),
396            project_id: Some(project_id.to_string()),
397            token: Some(token.to_string()),
398            claim_secret: None,
399            payload: serde_json::json!({}),
400        };
401        self.send_request(&plaintext, identity)
402            .await
403            .map_err(Self::map_request_error)
404    }
405
406    pub async fn send_revoke_tokens(
407        &self,
408        project_id: &str,
409        token: &str,
410        token_ids: &[String],
411        identity: &x25519::Identity,
412    ) -> Result<serde_json::Value, ClientError> {
413        let request_id = format!(r"kgr_{}", nanoid::nanoid!(12));
414        let plaintext = RequestPlaintext {
415            version: 1,
416            request_id: request_id.clone(),
417            issued_at: time::OffsetDateTime::now_utc()
418                .format(&time::format_description::well_known::Rfc3339)
419                .unwrap(),
420            operation: "token_revoke".into(),
421            method: "POST".into(),
422            path: format!("/v1/projects/{project_id}/tokens/revoke"),
423            project_id: Some(project_id.to_string()),
424            token: Some(token.to_string()),
425            claim_secret: None,
426            payload: serde_json::json!({
427                "token_ids": token_ids,
428            }),
429        };
430        self.send_request(&plaintext, identity)
431            .await
432            .map_err(Self::map_request_error)
433    }
434
435    pub async fn send_audit_query(
436        &self,
437        token: &str,
438        project_id: Option<&str>,
439        limit: i64,
440        identity: &x25519::Identity,
441    ) -> Result<serde_json::Value, ClientError> {
442        let request_id = format!(r"kgr_{}", nanoid::nanoid!(12));
443        let mut payload = serde_json::json!({
444            "limit": limit.clamp(1, 500),
445        });
446        if let Some(pid) = project_id {
447            payload["project_id"] = serde_json::json!(pid);
448        }
449        let plaintext = RequestPlaintext {
450            version: 1,
451            request_id: request_id.clone(),
452            issued_at: time::OffsetDateTime::now_utc()
453                .format(&time::format_description::well_known::Rfc3339)
454                .unwrap(),
455            operation: "audit".into(),
456            method: "POST".into(),
457            path: "/v1/audit".to_string(),
458            project_id: None,
459            token: Some(token.to_string()),
460            claim_secret: None,
461            payload,
462        };
463        self.send_request(&plaintext, identity)
464            .await
465            .map_err(Self::map_request_error)
466    }
467}
468
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serializes the tests in this module that modify process environment variables.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`], recovering from poisoning.
    ///
    /// The mutex guards no data, so a panic in one environment test must not make
    /// every later environment test fail on a poisoned lock.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Sets or removes an environment variable and restores its previous state on drop.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<OsString>,
    }

    impl EnvVarGuard {
        /// Sets `key` to `value`, remembering what it held before.
        fn set(key: &'static str, value: &str) -> Self {
            let previous = std::env::var_os(key);
            // SAFETY: tests in this module create the guard only while holding
            // `ENV_LOCK`, so they do not mutate the environment concurrently.
            // NOTE(review): threads outside this module are not excluded.
            unsafe {
                std::env::set_var(key, value);
            }
            Self { key, previous }
        }

        /// Removes `key`, remembering what it held before.
        fn unset(key: &'static str) -> Self {
            let previous = std::env::var_os(key);
            // SAFETY: see `EnvVarGuard::set`.
            unsafe {
                std::env::remove_var(key);
            }
            Self { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: the guard is declared after the lock guard in each test, so it
            // is dropped first and `ENV_LOCK` is still held here.
            unsafe {
                match &self.previous {
                    Some(previous) => std::env::set_var(self.key, previous),
                    None => std::env::remove_var(self.key),
                }
            }
        }
    }

    #[test]
    fn test_is_localhost_url_localhost() {
        assert!(is_localhost_url("http://localhost:13816"));
        assert!(is_localhost_url("http://localhost:8787"));
    }

    #[test]
    fn test_is_localhost_url_127_0_0_1() {
        assert!(is_localhost_url("http://127.0.0.1:13816"));
        assert!(is_localhost_url("http://127.0.0.1:8787"));
    }

    #[test]
    fn test_is_localhost_url_ipv6_loopback() {
        assert!(is_localhost_url("http://[::1]:13816"));
    }

    #[test]
    fn test_is_localhost_url_rejects_non_loopback() {
        assert!(!is_localhost_url("http://example.com"));
        assert!(!is_localhost_url("http://192.168.1.1:13816"));
        assert!(!is_localhost_url("http://10.0.0.1:13816"));
    }

    #[test]
    fn test_is_localhost_url_rejects_https() {
        assert!(!is_localhost_url("https://localhost:13816"));
        assert!(!is_localhost_url("https://127.0.0.1:13816"));
    }

    #[test]
    fn test_validate_http_transport_blocks_non_localhost_http() {
        let _lock = lock_env();
        let _env = EnvVarGuard::unset("KAGI_ALLOW_INSECURE_HTTP");
        let result = validate_http_transport("http://example.com", false);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("HTTP remotes are only allowed for localhost")
        );
    }

    #[test]
    fn test_validate_http_transport_allows_localhost_http() {
        assert!(validate_http_transport("http://127.0.0.1:13816", false).is_ok());
        assert!(validate_http_transport("http://localhost:13816", false).is_ok());
    }

    #[test]
    fn test_validate_http_transport_allows_https_anywhere() {
        assert!(validate_http_transport("https://example.com", false).is_ok());
        assert!(validate_http_transport("https://kagi.example.com", false).is_ok());
    }

    #[test]
    fn test_validate_http_transport_allows_insecure_with_flag() {
        assert!(validate_http_transport("http://example.com", true).is_ok());
    }

    #[test]
    fn test_validate_http_transport_allows_insecure_with_env() {
        let _lock = lock_env();
        let _env = EnvVarGuard::set("KAGI_ALLOW_INSECURE_HTTP", "1");
        assert!(validate_http_transport("http://example.com", false).is_ok());
    }

    #[test]
    fn test_summarize_http_error_body_short() {
        assert_eq!(summarize_http_error_body("error body"), "error body");
    }

    #[test]
    fn test_summarize_http_error_body_long() {
        let body = "x".repeat(1024);
        let summary = summarize_http_error_body(&body);
        assert_eq!(summary.len(), 515);
        assert!(summary.ends_with("..."));
    }

    #[test]
    fn test_summarize_http_error_body_exact_limit() {
        let body = "x".repeat(HTTP_ERROR_BODY_PREVIEW_LEN);
        assert_eq!(summarize_http_error_body(&body), body);
    }

    #[test]
    fn test_summarize_http_error_body_multibyte() {
        // The limit is counted in characters, so multi-byte input must be cut on a
        // character boundary rather than a byte offset.
        let body = "é".repeat(600);
        let summary = summarize_http_error_body(&body);
        assert_eq!(summary.chars().count(), HTTP_ERROR_BODY_PREVIEW_LEN + 3);
        assert!(summary.ends_with("..."));
    }
}