Skip to main content

kagi_sync/infrastructure/
remote_client.rs

1use crate::domain::envelope::{RequestPlaintext, ResponseEnvelope, verify_response_mac};
2use crate::domain::remote_config::ServerKeyResponse;
3use crate::infrastructure::remote_envelope::{decrypt_response, encrypt_request, parse_recipient};
4use age::x25519;
5use base64::Engine;
6use kagi_domain::error::DomainError;
7use reqwest::Client;
8use serde::{Deserialize, Serialize};
9use thiserror::Error;
10use url::Url;
11
/// Client for a remote sync server that exchanges encrypted request/response
/// envelopes over HTTP(S).
///
/// Built by [`RemoteClient::new`] or [`RemoteClient::new_pinned`], which fetch
/// the server's public key material from `/v1/server-key` up front.
pub struct RemoteClient {
    /// HTTP client reused for every request made through this instance.
    client: Client,
    /// Base URL of the remote exactly as supplied by the caller; trailing
    /// slashes are trimmed each time an endpoint URL is built.
    remote_url: String,
    /// Server public key, parsed from the server-key response, that requests
    /// are encrypted to.
    server_recipient: x25519::Recipient,
    /// Key identifier reported by the server; stamped onto every outgoing
    /// request envelope.
    server_key_id: String,
    /// Fingerprint reported by the server; compared against the expected
    /// value in `new_pinned`.
    fingerprint: String,
}
19
20fn is_localhost_url(url: &str) -> bool {
21    let parsed = match Url::parse(url) {
22        Ok(u) => u,
23        Err(_) => return false,
24    };
25    if parsed.scheme() != "http" {
26        return false;
27    }
28    match parsed.host() {
29        Some(url::Host::Domain("localhost")) => true,
30        Some(url::Host::Ipv4(ip)) if ip.is_loopback() => true,
31        Some(url::Host::Ipv6(ip)) if ip.is_loopback() => true,
32        _ => false,
33    }
34}
35
/// Errors returned by the token/membership/audit helper methods on
/// `RemoteClient`.
///
/// The first three variants are produced by `RemoteClient::map_request_error`
/// from the `code` of a remote rejection; everything else is carried as
/// `RequestFailed` with the underlying error rendered to a string.
#[derive(Error, Debug)]
pub enum ClientError {
    /// The remote rejected the request with code `auth_failed`.
    #[error("invalid token")]
    InvalidToken,
    /// The remote rejected the request with code `not_found`.
    #[error("project not found")]
    ProjectNotFound,
    /// The remote rejected the request with code `conflict`.
    #[error("project state conflict")]
    ProjectStateConflict,
    /// Any other failure: transport, envelope validation, an unrecognised
    /// rejection code, or a response payload that failed to deserialize.
    #[error("request failed: {0}")]
    RequestFailed(String),
}
47
/// Details of a member asking to join a project.
///
/// `RemoteClient::send_member_join_request` copies these four fields into the
/// `join_request` object of the request payload.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct MemberJoinRequest {
    /// Identifier of the joining member.
    pub member_id: String,
    /// Name of the joining member (presumably for display; confirm against the server).
    pub name: String,
    /// NOTE(review): presumably the member's age/x25519 recipient (public key)
    /// in string form; the encoding is not visible in this file.
    pub recipient: String,
    /// NOTE(review): presumably the member's signature-verification public
    /// key; algorithm and encoding are not visible in this file.
    pub signing_public_key: String,
}
55
/// Result of a successful join request.
///
/// Carries no fields; `RemoteClient::send_member_join_request` deserializes
/// the response `data` value into it.
#[derive(Deserialize, Debug, Clone)]
pub struct JoinResponse {}
58
/// Result of a token-issue request, deserialized from the response `data`
/// value by `RemoteClient::send_member_token_issue`.
#[derive(Deserialize, Debug, Clone)]
pub struct TokenIssueResponse {
    /// Identifier of the newly issued token.
    pub token_id: String,
    /// NOTE(review): presumably the token secret itself; whether it arrives
    /// wrapped/encoded is not visible in this file.
    pub project_token: String,
    // NOTE(review): the lint is presumably silenced because nothing in the
    // crate reads `status` yet; confirm, or drop the attribute.
    #[allow(dead_code)]
    /// Status string reported by the server for the issued token.
    pub status: String,
}
66
67pub fn validate_http_transport(remote_url: &str, allow_insecure: bool) -> Result<(), DomainError> {
68    if remote_url.starts_with("http://") && !is_localhost_url(remote_url) && !allow_insecure {
69        let env_override = std::env::var("KAGI_ALLOW_INSECURE_HTTP")
70            .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
71            .unwrap_or(false);
72        if !env_override {
73            return Err(DomainError::StoreCorrupted(
74                "HTTP remotes are only allowed for localhost. Use HTTPS or pass --allow-insecure-http for local testing.".into(),
75            ));
76        }
77    }
78    Ok(())
79}
80
81impl RemoteClient {
82    pub async fn new(remote_url: String, allow_insecure: bool) -> Result<Self, DomainError> {
83        validate_http_transport(&remote_url, allow_insecure)?;
84        let client = if is_localhost_url(&remote_url) {
85            Client::builder().no_proxy().build().map_err(|e| {
86                DomainError::StoreCorrupted(format!("failed to build HTTP client: {}", e))
87            })?
88        } else {
89            Client::new()
90        };
91        let url = format!("{}/v1/server-key", remote_url.trim_end_matches('/'));
92        let server_key: ServerKeyResponse = client
93            .get(&url)
94            .send()
95            .await
96            .map_err(|e| DomainError::StoreCorrupted(format!("failed to fetch server key: {}", e)))?
97            .json()
98            .await
99            .map_err(|e| {
100                DomainError::StoreCorrupted(format!("invalid server key response: {}", e))
101            })?;
102
103        let server_recipient = parse_recipient(&server_key.recipient)?;
104        Ok(Self {
105            client,
106            remote_url,
107            server_recipient,
108            server_key_id: server_key.server_key_id,
109            fingerprint: server_key.fingerprint,
110        })
111    }
112
113    pub async fn new_pinned(
114        remote_url: String,
115        expected_fingerprint: &str,
116        allow_insecure: bool,
117    ) -> Result<Self, DomainError> {
118        let remote = Self::new(remote_url, allow_insecure).await?;
119        if remote.fingerprint != expected_fingerprint {
120            return Err(DomainError::StoreCorrupted(format!(
121                "server fingerprint mismatch: expected {}, got {}",
122                expected_fingerprint, remote.fingerprint
123            )));
124        }
125        Ok(remote)
126    }
127
128    pub fn fingerprint(&self) -> &str {
129        &self.fingerprint
130    }
131
132    pub fn server_key_id(&self) -> &str {
133        &self.server_key_id
134    }
135
136    pub async fn send_request(
137        &self,
138        plaintext: &RequestPlaintext,
139        local_identity: &x25519::Identity,
140    ) -> Result<serde_json::Value, DomainError> {
141        let local_recipient = local_identity.to_public();
142        let mut envelope = encrypt_request(plaintext, &self.server_recipient, &local_recipient)?;
143        envelope.server_key_id = self.server_key_id.clone();
144
145        let url = format!(
146            "{}{}",
147            self.remote_url.trim_end_matches('/'),
148            plaintext.path
149        );
150        let response = self
151            .client
152            .post(&url)
153            .json(&envelope)
154            .send()
155            .await
156            .map_err(|e| DomainError::StoreCorrupted(format!("request failed: {}", e)))?;
157
158        let response_text = response
159            .text()
160            .await
161            .map_err(|e| DomainError::StoreCorrupted(format!("invalid response body: {}", e)))?;
162        let response_envelope: ResponseEnvelope =
163            serde_json::from_str(&response_text).map_err(|e| {
164                DomainError::StoreCorrupted(format!(
165                    "invalid response: {} | raw: {}",
166                    e, response_text
167                ))
168            })?;
169
170        if response_envelope.request_id != plaintext.request_id {
171            return Err(DomainError::StoreCorrupted(
172                "response request_id mismatch".into(),
173            ));
174        }
175        let mac_key = plaintext
176            .token
177            .as_deref()
178            .or(plaintext.claim_secret.as_deref());
179        if let Some(key) = mac_key {
180            let mac = response_envelope.mac.as_deref().ok_or_else(|| {
181                DomainError::StoreCorrupted("missing response authentication mac".into())
182            })?;
183            if !verify_response_mac(
184                key,
185                &plaintext.request_id,
186                &response_envelope.ciphertext,
187                mac,
188            ) {
189                return Err(DomainError::StoreCorrupted(
190                    "invalid response authentication mac".into(),
191                ));
192            }
193        }
194
195        let ciphertext = base64::engine::general_purpose::URL_SAFE_NO_PAD
196            .decode(&response_envelope.ciphertext)
197            .map_err(|e| DomainError::DecryptFailed(e.to_string()))?;
198        let decrypted = decrypt_response(&ciphertext, local_identity)?;
199        if decrypted.get("request_id").and_then(|v| v.as_str())
200            != Some(plaintext.request_id.as_str())
201        {
202            return Err(DomainError::StoreCorrupted(
203                "decrypted response request_id mismatch".into(),
204            ));
205        }
206
207        if !decrypted
208            .get("ok")
209            .and_then(|v| v.as_bool())
210            .unwrap_or(false)
211        {
212            let error = decrypted.get("error").cloned().unwrap_or_default();
213            let code = error
214                .get("code")
215                .and_then(|v| v.as_str())
216                .unwrap_or("unknown");
217            let message = error
218                .get("message")
219                .and_then(|v| v.as_str())
220                .unwrap_or("unknown error");
221            return Err(DomainError::RemoteRejected {
222                code: code.to_string(),
223                message: message.to_string(),
224            });
225        }
226
227        Ok(decrypted.get("data").cloned().unwrap_or_default())
228    }
229
230    fn map_request_error(e: DomainError) -> ClientError {
231        if let DomainError::RemoteRejected { ref code, .. } = e {
232            if code == "auth_failed" {
233                return ClientError::InvalidToken;
234            }
235            if code == "not_found" {
236                return ClientError::ProjectNotFound;
237            }
238            if code == "conflict" {
239                return ClientError::ProjectStateConflict;
240            }
241        }
242        ClientError::RequestFailed(e.to_string())
243    }
244
245    pub async fn get_token_from_claim_secret(
246        &self,
247        project_id: &str,
248        member_id: &str,
249        claim_secret: &str,
250        identity: &x25519::Identity,
251    ) -> Result<String, DomainError> {
252        let request_id = format!("kgr_{}", nanoid::nanoid!(12));
253        let plaintext = RequestPlaintext {
254            version: 1,
255            request_id: request_id.clone(),
256            issued_at: time::OffsetDateTime::now_utc()
257                .format(&time::format_description::well_known::Rfc3339)
258                .unwrap(),
259            operation: "pull".into(),
260            method: "POST".into(),
261            path: format!("/v1/projects/{}/pull", project_id),
262            project_id: Some(project_id.to_string()),
263            token: None,
264            claim_secret: Some(claim_secret.to_string()),
265            payload: serde_json::json!({
266                "member_id": member_id,
267            }),
268        };
269        let data = self.send_request(&plaintext, identity).await?;
270        if let Some(wrapped_b64) = data.get("wrapped_project_token").and_then(|v| v.as_str()) {
271            let wrapped = base64::engine::general_purpose::URL_SAFE_NO_PAD
272                .decode(wrapped_b64)
273                .map_err(|e| {
274                    DomainError::StoreCorrupted(format!("invalid wrapped token: {}", e))
275                })?;
276            let decrypted = crate::infrastructure::remote_envelope::decrypt_bytes(
277                &wrapped, identity,
278            )
279            .map_err(|e| {
280                DomainError::StoreCorrupted(format!("failed to decrypt wrapped token: {}", e))
281            })?;
282            String::from_utf8(decrypted)
283                .map_err(|e| DomainError::StoreCorrupted(format!("invalid token: {}", e)))
284        } else {
285            Err(DomainError::ProjectTokenUnavailable(
286                "no project token available; ask an active member/admin to approve this member, then run `kagi pull`"
287                    .into(),
288            ))
289        }
290    }
291
292    pub async fn send_member_join_request(
293        &self,
294        project_id: &str,
295        token: &str,
296        join_request: &MemberJoinRequest,
297        identity: &x25519::Identity,
298    ) -> Result<JoinResponse, ClientError> {
299        let request_id = format!("kgr_{}", nanoid::nanoid!(12));
300        let plaintext = RequestPlaintext {
301            version: 1,
302            request_id: request_id.clone(),
303            issued_at: time::OffsetDateTime::now_utc()
304                .format(&time::format_description::well_known::Rfc3339)
305                .unwrap(),
306            operation: "join".into(),
307            method: "POST".into(),
308            path: format!("/v1/projects/{}/join", project_id),
309            project_id: Some(project_id.to_string()),
310            token: Some(token.to_string()),
311            claim_secret: None,
312            payload: serde_json::json!({
313                "join_request": {
314                    "member_id": join_request.member_id,
315                    "name": join_request.name,
316                    "recipient": join_request.recipient,
317                    "signing_public_key": join_request.signing_public_key,
318                }
319            }),
320        };
321        let data = self
322            .send_request(&plaintext, identity)
323            .await
324            .map_err(Self::map_request_error)?;
325        serde_json::from_value(data).map_err(|e| ClientError::RequestFailed(e.to_string()))
326    }
327
328    pub async fn send_member_token_issue(
329        &self,
330        project_id: &str,
331        token: &str,
332        member_id: &str,
333        identity: &x25519::Identity,
334    ) -> Result<TokenIssueResponse, ClientError> {
335        let request_id = format!("kgr_{}", nanoid::nanoid!(12));
336        let plaintext = RequestPlaintext {
337            version: 1,
338            request_id: request_id.clone(),
339            issued_at: time::OffsetDateTime::now_utc()
340                .format(&time::format_description::well_known::Rfc3339)
341                .unwrap(),
342            operation: "token_issue".into(),
343            method: "POST".into(),
344            path: format!("/v1/projects/{}/tokens/issue", project_id),
345            project_id: Some(project_id.to_string()),
346            token: Some(token.to_string()),
347            claim_secret: None,
348            payload: serde_json::json!({
349                "member_id": member_id,
350                "capabilities": ["pull", "push"],
351            }),
352        };
353        let data = self
354            .send_request(&plaintext, identity)
355            .await
356            .map_err(Self::map_request_error)?;
357        serde_json::from_value(data).map_err(|e| ClientError::RequestFailed(e.to_string()))
358    }
359
360    pub async fn send_list_tokens(
361        &self,
362        project_id: &str,
363        token: &str,
364        identity: &x25519::Identity,
365    ) -> Result<serde_json::Value, ClientError> {
366        let request_id = format!("kgr_{}", nanoid::nanoid!(12));
367        let plaintext = RequestPlaintext {
368            version: 1,
369            request_id: request_id.clone(),
370            issued_at: time::OffsetDateTime::now_utc()
371                .format(&time::format_description::well_known::Rfc3339)
372                .unwrap(),
373            operation: "token_list".into(),
374            method: "POST".into(),
375            path: format!("/v1/projects/{}/tokens/list", project_id),
376            project_id: Some(project_id.to_string()),
377            token: Some(token.to_string()),
378            claim_secret: None,
379            payload: serde_json::json!({}),
380        };
381        self.send_request(&plaintext, identity)
382            .await
383            .map_err(Self::map_request_error)
384    }
385
386    pub async fn send_revoke_tokens(
387        &self,
388        project_id: &str,
389        token: &str,
390        token_ids: &[String],
391        identity: &x25519::Identity,
392    ) -> Result<serde_json::Value, ClientError> {
393        let request_id = format!("kgr_{}", nanoid::nanoid!(12));
394        let plaintext = RequestPlaintext {
395            version: 1,
396            request_id: request_id.clone(),
397            issued_at: time::OffsetDateTime::now_utc()
398                .format(&time::format_description::well_known::Rfc3339)
399                .unwrap(),
400            operation: "token_revoke".into(),
401            method: "POST".into(),
402            path: format!("/v1/projects/{}/tokens/revoke", project_id),
403            project_id: Some(project_id.to_string()),
404            token: Some(token.to_string()),
405            claim_secret: None,
406            payload: serde_json::json!({
407                "token_ids": token_ids,
408            }),
409        };
410        self.send_request(&plaintext, identity)
411            .await
412            .map_err(Self::map_request_error)
413    }
414
415    pub async fn send_audit_query(
416        &self,
417        token: &str,
418        project_id: Option<&str>,
419        limit: i64,
420        identity: &x25519::Identity,
421    ) -> Result<serde_json::Value, ClientError> {
422        let request_id = format!("kgr_{}", nanoid::nanoid!(12));
423        let mut payload = serde_json::json!({
424            "limit": limit.clamp(1, 500),
425        });
426        if let Some(pid) = project_id {
427            payload["project_id"] = serde_json::json!(pid);
428        }
429        let plaintext = RequestPlaintext {
430            version: 1,
431            request_id: request_id.clone(),
432            issued_at: time::OffsetDateTime::now_utc()
433                .format(&time::format_description::well_known::Rfc3339)
434                .unwrap(),
435            operation: "audit".into(),
436            method: "POST".into(),
437            path: "/v1/audit".to_string(),
438            project_id: None,
439            token: Some(token.to_string()),
440            claim_secret: None,
441            payload,
442        };
443        self.send_request(&plaintext, identity)
444            .await
445            .map_err(Self::map_request_error)
446    }
447}
448
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::Mutex;

    /// Held by every test that changes the process environment, so those
    /// tests never run concurrently with each other.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Environment variable that `validate_http_transport` treats as an override.
    const INSECURE_HTTP_ENV: &str = "KAGI_ALLOW_INSECURE_HTTP";

    /// Scoped environment-variable change: remembers the value present when
    /// it was created and restores it (or removes the variable) on drop.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<OsString>,
    }

    impl EnvVarGuard {
        /// Sets `key` to `value` until the guard is dropped.
        fn set(key: &'static str, value: &str) -> Self {
            Self::replace(key, Some(value))
        }

        /// Removes `key` until the guard is dropped.
        fn unset(key: &'static str) -> Self {
            Self::replace(key, None)
        }

        /// Records the current value of `key`, then sets or removes it.
        fn replace(key: &'static str, value: Option<&str>) -> Self {
            let previous = std::env::var_os(key);
            // SAFETY: the tests in this module create guards only while
            // holding `ENV_LOCK`, so they do not race each other. Threads
            // outside this module are not covered by that lock.
            unsafe {
                match value {
                    Some(v) => std::env::set_var(key, v),
                    None => std::env::remove_var(key),
                }
            }
            Self { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: in these tests the guard is declared after the
            // `ENV_LOCK` guard and therefore dropped before it, so the
            // restore still happens under the lock.
            unsafe {
                match &self.previous {
                    Some(original) => std::env::set_var(self.key, original),
                    None => std::env::remove_var(self.key),
                }
            }
        }
    }

    #[test]
    fn test_is_localhost_url_localhost() {
        for url in ["http://localhost:13816", "http://localhost:8787"] {
            assert!(is_localhost_url(url));
        }
    }

    #[test]
    fn test_is_localhost_url_127_0_0_1() {
        for url in ["http://127.0.0.1:13816", "http://127.0.0.1:8787"] {
            assert!(is_localhost_url(url));
        }
    }

    #[test]
    fn test_is_localhost_url_ipv6_loopback() {
        assert!(is_localhost_url("http://[::1]:13816"));
    }

    #[test]
    fn test_is_localhost_url_rejects_non_loopback() {
        for url in [
            "http://example.com",
            "http://192.168.1.1:13816",
            "http://10.0.0.1:13816",
        ] {
            assert!(!is_localhost_url(url));
        }
    }

    #[test]
    fn test_is_localhost_url_rejects_https() {
        for url in ["https://localhost:13816", "https://127.0.0.1:13816"] {
            assert!(!is_localhost_url(url));
        }
    }

    #[test]
    fn test_validate_http_transport_blocks_non_localhost_http() {
        let _lock = ENV_LOCK.lock().unwrap();
        let _env = EnvVarGuard::unset(INSECURE_HTTP_ENV);

        let outcome = validate_http_transport("http://example.com", false);
        assert!(outcome.is_err());

        let rendered = outcome.unwrap_err().to_string();
        assert!(rendered.contains("HTTP remotes are only allowed for localhost"));
    }

    #[test]
    fn test_validate_http_transport_allows_localhost_http() {
        for url in ["http://127.0.0.1:13816", "http://localhost:13816"] {
            assert!(validate_http_transport(url, false).is_ok());
        }
    }

    #[test]
    fn test_validate_http_transport_allows_https_anywhere() {
        for url in ["https://example.com", "https://kagi.example.com"] {
            assert!(validate_http_transport(url, false).is_ok());
        }
    }

    #[test]
    fn test_validate_http_transport_allows_insecure_with_flag() {
        assert!(validate_http_transport("http://example.com", true).is_ok());
    }

    #[test]
    fn test_validate_http_transport_allows_insecure_with_env() {
        let _lock = ENV_LOCK.lock().unwrap();
        let _env = EnvVarGuard::set(INSECURE_HTTP_ENV, "1");
        assert!(validate_http_transport("http://example.com", false).is_ok());
    }
}