1use crate::error::{EncryptionError, KittyError};
2use aes_gcm::{
3 aead::{Aead, AeadCore, KeyInit},
4 Aes256Gcm,
5};
6use rand_core::OsRng;
7use sha2::{Digest, Sha256};
8use std::fs;
9use std::path::Path;
10use x25519_dalek::{PublicKey, StaticSecret};
11
12pub struct Encryptor {
13 kitty_public_key: PublicKey,
14}
15
16impl Encryptor {
17 pub fn new() -> Result<Self, EncryptionError> {
18 let kitty_public_key = Self::load_kitty_public_key()?;
19 Ok(Self { kitty_public_key })
20 }
21
22 pub fn new_with_public_key(public_key: Option<&str>) -> Result<Self, EncryptionError> {
23 let kitty_public_key = if let Some(key_str) = public_key {
24 Self::parse_public_key(key_str)?
25 } else {
26 Self::load_kitty_public_key()?
27 };
28 Ok(Self { kitty_public_key })
29 }
30
31 fn load_kitty_public_key() -> Result<PublicKey, EncryptionError> {
32 let key_bytes = Self::read_kitty_public_key()?;
33 Self::bytes_to_public_key(key_bytes)
34 }
35
36 fn parse_public_key(key_str: &str) -> Result<PublicKey, EncryptionError> {
37 let key_bytes = base85::decode(key_str)
38 .map_err(|e| EncryptionError::InvalidPublicKey(e.to_string()))?;
39 Self::bytes_to_public_key(key_bytes)
40 }
41
42 fn bytes_to_public_key(key_bytes: Vec<u8>) -> Result<PublicKey, EncryptionError> {
43 if key_bytes.len() < 32 {
44 return Err(EncryptionError::PublicKeyTooShort {
45 expected: 32,
46 actual: key_bytes.len(),
47 });
48 }
49
50 let mut key_array = [0u8; 32];
51 key_array.copy_from_slice(&key_bytes[..32]);
52 Ok(PublicKey::from(key_array))
53 }
54
55 fn read_kitty_public_key() -> Result<Vec<u8>, EncryptionError> {
56 if let Ok(key_str) = std::env::var("KITTY_PUBLIC_KEY") {
57 return base85::decode(&key_str)
58 .map_err(|e| EncryptionError::InvalidPublicKey(e.to_string()));
59 }
60
61 let default_path = format!(
62 "{}/.config/kitty/key.pub",
63 std::env::var("HOME").unwrap_or_else(|_| ".".to_string())
64 );
65
66 let key_path = Path::new(&default_path);
67 if !key_path.exists() {
68 return Err(EncryptionError::MissingPublicKey);
69 }
70
71 let key_bytes =
72 fs::read(&key_path).map_err(|e| EncryptionError::InvalidPublicKey(e.to_string()))?;
73
74 Ok(key_bytes)
75 }
76
77 pub fn encrypt_command(
78 &self,
79 payload: serde_json::Value,
80 ) -> Result<serde_json::Value, KittyError> {
81 let payload_str = serde_json::to_string(&payload)
82 .map_err(|e| EncryptionError::EncryptionFailed(e.to_string()))?;
83
84 let payload_bytes = payload_str.as_bytes();
85
86 let secret = StaticSecret::random_from_rng(&mut OsRng);
87 let public_key = PublicKey::from(&secret);
88 let shared_secret = secret.diffie_hellman(&self.kitty_public_key);
89
90 let mut hasher = Sha256::new();
91 hasher.update(shared_secret.as_bytes());
92 let encryption_key = hasher.finalize();
93
94 let cipher = Aes256Gcm::new_from_slice(&encryption_key)
95 .map_err(|e| EncryptionError::EncryptionFailed(e.to_string()))?;
96 let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
97
98 let ciphertext = cipher
99 .encrypt(&nonce, payload_bytes)
100 .map_err(|e| EncryptionError::EncryptionFailed(e.to_string()))?;
101
102 let tag = &ciphertext[ciphertext.len() - 16..];
103 let encrypted_data = &ciphertext[..ciphertext.len() - 16];
104
105 let result = serde_json::json!({
106 "version": "0.43.1",
107 "iv": base85::encode(&nonce),
108 "tag": base85::encode(tag),
109 "pubkey": base85::encode(public_key.as_bytes()),
110 "encrypted": base85::encode(encrypted_data),
111 });
112
113 Ok(result)
114 }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes tests that read or write process environment variables.
    ///
    /// The test harness runs tests on parallel threads, and these tests all go through
    /// `KITTY_PUBLIC_KEY` / `HOME`. Without this lock, one test can observe a value set
    /// by another.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `ENV_LOCK`, recovering from poisoning so one failing test does not
    /// cascade into failures of every other env-dependent test.
    fn env_lock() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Base85-decodes a string field of an encrypted envelope, panicking on failure.
    fn decode_field(obj: &serde_json::Map<String, serde_json::Value>, key: &str) -> Vec<u8> {
        let text = obj[key].as_str().expect("envelope field is a string");
        base85::decode(text).unwrap_or_else(|e| panic!("field {key} is not base85: {e}"))
    }

    #[test]
    fn test_load_kitty_public_key_missing() {
        let _env = env_lock();
        let saved_home = std::env::var_os("HOME");
        // SAFETY: every test in this module that touches the environment holds ENV_LOCK,
        // so no other test thread reads or writes it concurrently.
        unsafe {
            std::env::remove_var("KITTY_PUBLIC_KEY");
            // Point HOME at a directory without a key file so the result does not depend
            // on the developer's real ~/.config/kitty/key.pub.
            std::env::set_var("HOME", "/nonexistent/kitty-encryptor-test-home");
        }
        let result = Encryptor::new();
        // SAFETY: ENV_LOCK is still held (see above).
        unsafe {
            match saved_home {
                Some(home) => std::env::set_var("HOME", home),
                None => std::env::remove_var("HOME"),
            }
        }
        assert!(matches!(result, Err(EncryptionError::MissingPublicKey)));
    }

    #[test]
    fn test_load_kitty_public_key_invalid() {
        let _env = env_lock();
        // SAFETY: ENV_LOCK serializes all environment access in this module.
        unsafe {
            std::env::set_var("KITTY_PUBLIC_KEY", "invalid base85");
        }
        let result = Encryptor::new();
        assert!(matches!(result, Err(EncryptionError::InvalidPublicKey(_))));
    }

    #[test]
    fn test_load_kitty_public_key_too_short() {
        let _env = env_lock();
        let short_key = base85::encode(&[1u8, 2, 3]);
        // SAFETY: ENV_LOCK serializes all environment access in this module.
        unsafe {
            std::env::set_var("KITTY_PUBLIC_KEY", short_key);
        }
        let result = Encryptor::new();
        assert!(matches!(
            result,
            Err(EncryptionError::PublicKeyTooShort { .. })
        ));
    }

    #[test]
    fn test_new_with_public_key() {
        let secret = StaticSecret::random_from_rng(&mut OsRng);
        let public_key = PublicKey::from(&secret);
        let public_key_str = base85::encode(public_key.as_bytes());

        let encryptor = Encryptor::new_with_public_key(Some(&public_key_str));
        assert!(encryptor.is_ok());
    }

    #[test]
    fn test_new_with_public_key_invalid() {
        let encryptor = Encryptor::new_with_public_key(Some("invalid base85"));
        assert!(matches!(
            encryptor,
            Err(EncryptionError::InvalidPublicKey(_))
        ));
    }

    #[test]
    fn test_new_with_public_key_none() {
        let _env = env_lock();
        let secret = StaticSecret::random_from_rng(&mut OsRng);
        let public_key = PublicKey::from(&secret);
        // SAFETY: ENV_LOCK serializes all environment access in this module.
        unsafe {
            std::env::set_var("KITTY_PUBLIC_KEY", base85::encode(public_key.as_bytes()));
        }

        let encryptor = Encryptor::new_with_public_key(None);
        assert!(encryptor.is_ok());
    }

    #[test]
    fn test_encrypt_command() {
        let _env = env_lock();
        let secret = StaticSecret::random_from_rng(&mut OsRng);
        let public_key = PublicKey::from(&secret);
        // SAFETY: ENV_LOCK serializes all environment access in this module.
        unsafe {
            std::env::set_var("KITTY_PUBLIC_KEY", base85::encode(public_key.as_bytes()));
        }

        let encryptor = Encryptor::new().unwrap();
        let payload = serde_json::json!({"cmd": "ls", "password": "test", "timestamp": 1234567890});

        let result = encryptor.encrypt_command(payload.clone());
        assert!(result.is_ok());

        let encrypted = result.unwrap();
        assert!(encrypted.is_object());
        let obj = encrypted.as_object().unwrap();
        for key in ["version", "iv", "tag", "pubkey", "encrypted"] {
            assert!(obj.contains_key(key), "missing envelope field {key}");
        }

        // Decrypt with the recipient's secret to prove the envelope is actually usable,
        // not merely well-shaped.
        let ephemeral: [u8; 32] = decode_field(obj, "pubkey")
            .try_into()
            .expect("ephemeral public key is 32 bytes");
        let shared = secret.diffie_hellman(&PublicKey::from(ephemeral));
        let key = Sha256::digest(shared.as_bytes());
        let cipher = Aes256Gcm::new_from_slice(&key).expect("SHA-256 output is a valid key");
        let iv = decode_field(obj, "iv");
        let mut sealed = decode_field(obj, "encrypted");
        sealed.extend(decode_field(obj, "tag"));
        let plaintext = cipher
            .decrypt(aes_gcm::Nonce::from_slice(&iv), sealed.as_slice())
            .expect("envelope decrypts with the recipient secret");
        let decoded: serde_json::Value = serde_json::from_slice(&plaintext).unwrap();
        assert_eq!(decoded, payload);
    }
}