1use crate::error::{EncryptionError, KittyError};
2use aes_gcm::{
3 aead::{Aead, AeadCore, KeyInit},
4 Aes256Gcm,
5};
6use rand_core::OsRng;
7use sha2::{Digest, Sha256};
8use std::fs;
9use std::path::Path;
10use x25519_dalek::{PublicKey, StaticSecret};
11
12pub struct Encryptor {
18 kitty_public_key: PublicKey,
19}
20
21impl Encryptor {
22 pub fn new() -> Result<Self, EncryptionError> {
23 let kitty_public_key = Self::load_kitty_public_key()?;
24 Ok(Self { kitty_public_key })
25 }
26
27 pub fn new_with_public_key(public_key: Option<&str>) -> Result<Self, EncryptionError> {
28 let kitty_public_key = if let Some(pk) = public_key {
29 Self::parse_public_key(pk)?
30 } else {
31 let key_bytes = Self::read_kitty_public_key()?;
32 Self::bytes_to_public_key(&key_bytes)?
33 };
34
35 Ok(Self { kitty_public_key })
36 }
37
38 fn load_kitty_public_key() -> Result<PublicKey, EncryptionError> {
39 let key_bytes = Self::read_kitty_public_key()?;
40 Self::bytes_to_public_key(&key_bytes)
41 }
42
43 fn parse_public_key(key_str: &str) -> Result<PublicKey, EncryptionError> {
44 let key_data = key_str.strip_prefix("1:").ok_or_else(|| {
45 EncryptionError::InvalidPublicKey("Missing version prefix".to_string())
46 })?;
47 let key_bytes = base85::decode(key_data)
48 .map_err(|e| EncryptionError::InvalidPublicKey(e.to_string()))?;
49 Self::bytes_to_public_key(&key_bytes)
50 }
51
52 fn bytes_to_public_key(key_bytes: &[u8]) -> Result<PublicKey, EncryptionError> {
53 if key_bytes.len() < 32 {
54 return Err(EncryptionError::PublicKeyTooShort {
55 expected: 32,
56 actual: key_bytes.len(),
57 });
58 }
59
60 let mut key_array = [0u8; 32];
61 key_array.copy_from_slice(&key_bytes[..32]);
62 Ok(PublicKey::from(key_array))
63 }
64
65 fn read_kitty_public_key() -> Result<Vec<u8>, EncryptionError> {
74 if let Ok(key_str) = std::env::var("KITTY_PUBLIC_KEY") {
75 let key_data = key_str.strip_prefix("1:").ok_or_else(|| {
76 EncryptionError::InvalidPublicKey("Missing version prefix".to_string())
77 })?;
78 return base85::decode(key_data)
79 .map_err(|e| EncryptionError::InvalidPublicKey(e.to_string()));
80 }
81
82 let default_path = format!(
83 "{}/.config/kitty/key.pub",
84 std::env::var("HOME").unwrap_or_else(|_| ".".to_string())
85 );
86
87 let key_path = Path::new(&default_path);
88 if !key_path.exists() {
89 return Err(EncryptionError::MissingPublicKey);
90 }
91
92 let key_bytes =
93 fs::read(&key_path).map_err(|e| EncryptionError::InvalidPublicKey(e.to_string()))?;
94
95 Ok(key_bytes)
96 }
97
98 pub fn encrypt_command(
99 &self,
100 payload: serde_json::Value,
101 ) -> Result<serde_json::Value, KittyError> {
102 let payload_str = serde_json::to_string(&payload)
103 .map_err(|e| EncryptionError::EncryptionFailed(e.to_string()))?;
104
105 let payload_bytes = payload_str.as_bytes();
106
107 let secret = StaticSecret::random_from_rng(&mut OsRng);
108 let public_key = PublicKey::from(&secret);
109 let shared_secret = secret.diffie_hellman(&self.kitty_public_key);
110
111 let mut hasher = Sha256::new();
112 hasher.update(shared_secret.as_bytes());
113 let encryption_key = hasher.finalize();
114
115 let cipher = Aes256Gcm::new_from_slice(&encryption_key)
116 .map_err(|e| EncryptionError::EncryptionFailed(e.to_string()))?;
117 let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
118
119 let ciphertext = cipher
120 .encrypt(&nonce, payload_bytes)
121 .map_err(|e| EncryptionError::EncryptionFailed(e.to_string()))?;
122
123 let tag = &ciphertext[ciphertext.len() - 16..];
124 let encrypted_data = &ciphertext[..ciphertext.len() - 16];
125
126 let result = serde_json::json!({
127 "version": "0.43.1",
128 "iv": base85::encode(&nonce),
129 "tag": base85::encode(tag),
130 "pubkey": base85::encode(public_key.as_bytes()),
131 "encrypted": base85::encode(encrypted_data),
132 });
133
134 Ok(result)
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsStr;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes the tests in this module that read or write process-wide environment
    /// variables. The default test harness runs tests on parallel threads, so without this the
    /// tests race on `KITTY_PUBLIC_KEY` and on the environment itself.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`], recovering from poisoning so one failed test does not cascade.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Sets an environment variable. Callers must hold [`ENV_LOCK`].
    fn set_env(key: &str, value: impl AsRef<OsStr>) {
        // SAFETY: every test in this module that touches the environment holds ENV_LOCK, so no
        // other thread of this module reads or writes the environment concurrently.
        unsafe {
            std::env::set_var(key, value);
        }
    }

    /// Removes an environment variable. Callers must hold [`ENV_LOCK`].
    fn remove_env(key: &str) {
        // SAFETY: same invariant as `set_env` — the caller holds ENV_LOCK.
        unsafe {
            std::env::remove_var(key);
        }
    }

    /// Formats a public key the way `KITTY_PUBLIC_KEY` carries it: `"1:<base85>"`.
    fn encode_key(public_key: &PublicKey) -> String {
        format!("1:{}", base85::encode(public_key.as_bytes()))
    }

    #[test]
    fn test_load_kitty_public_key_missing() {
        let _env = lock_env();
        let saved_home = std::env::var_os("HOME");
        remove_env("KITTY_PUBLIC_KEY");
        // Point HOME at a directory without .config/kitty/key.pub so that a real key on the
        // machine running the tests cannot turn this into a false failure.
        set_env(
            "HOME",
            std::env::temp_dir().join("kitty-encryptor-test-nonexistent-home"),
        );

        let result = Encryptor::new();

        // Restore HOME before asserting so a failure does not leak the override.
        match saved_home {
            Some(home) => set_env("HOME", home),
            None => remove_env("HOME"),
        }
        assert!(matches!(result, Err(EncryptionError::MissingPublicKey)));
    }

    #[test]
    fn test_load_kitty_public_key_invalid() {
        let _env = lock_env();
        set_env("KITTY_PUBLIC_KEY", "invalid base85");
        let result = Encryptor::new();
        assert!(matches!(result, Err(EncryptionError::InvalidPublicKey(_))));
    }

    #[test]
    fn test_load_kitty_public_key_too_short() {
        let _env = lock_env();
        set_env(
            "KITTY_PUBLIC_KEY",
            format!("1:{}", base85::encode(&[1u8, 2, 3])),
        );
        let result = Encryptor::new();
        assert!(matches!(
            result,
            Err(EncryptionError::PublicKeyTooShort { .. })
        ));
    }

    #[test]
    fn test_new_with_public_key() {
        let secret = StaticSecret::random_from_rng(&mut OsRng);
        let public_key = PublicKey::from(&secret);

        let encryptor = Encryptor::new_with_public_key(Some(&encode_key(&public_key)));
        assert!(encryptor.is_ok());
    }

    #[test]
    fn test_new_with_public_key_invalid() {
        let encryptor = Encryptor::new_with_public_key(Some("invalid base85"));
        assert!(matches!(
            encryptor,
            Err(EncryptionError::InvalidPublicKey(_))
        ));
    }

    #[test]
    fn test_new_with_public_key_none() {
        let _env = lock_env();
        let secret = StaticSecret::random_from_rng(&mut OsRng);
        let public_key = PublicKey::from(&secret);
        set_env("KITTY_PUBLIC_KEY", encode_key(&public_key));

        let encryptor = Encryptor::new_with_public_key(None);
        assert!(encryptor.is_ok());
    }

    #[test]
    fn test_encrypt_command() {
        let _env = lock_env();
        let secret = StaticSecret::random_from_rng(&mut OsRng);
        let public_key = PublicKey::from(&secret);
        set_env("KITTY_PUBLIC_KEY", encode_key(&public_key));

        let encryptor = Encryptor::new().unwrap();
        let payload =
            serde_json::json!({"cmd": "ls", "password": "test", "timestamp": 1234567890});

        let encrypted = encryptor.encrypt_command(payload.clone()).unwrap();
        let obj = encrypted.as_object().expect("envelope is a JSON object");
        for key in ["version", "iv", "tag", "pubkey", "encrypted"] {
            assert!(obj.contains_key(key), "missing envelope field {key}");
        }
        assert_eq!(encrypted["version"], "0.43.1");

        // Decrypt as the receiving side would. This proves the envelope is actually usable,
        // not merely well-shaped.
        let field = |name: &str| -> Vec<u8> {
            let text = encrypted[name]
                .as_str()
                .expect("envelope fields are strings");
            base85::decode(text).expect("envelope fields are valid base85")
        };
        let iv = field("iv");
        let tag = field("tag");
        assert_eq!(iv.len(), 12);
        assert_eq!(tag.len(), 16);

        let sender: [u8; 32] = field("pubkey").try_into().expect("pubkey is 32 bytes");
        let shared = secret.diffie_hellman(&PublicKey::from(sender));
        let cipher = Aes256Gcm::new_from_slice(&Sha256::digest(shared.as_bytes())).unwrap();

        // Obtain a correctly typed nonce buffer and overwrite it with the transmitted IV.
        let mut nonce = Aes256Gcm::generate_nonce(&mut OsRng);
        nonce.copy_from_slice(&iv);

        let mut sealed = field("encrypted");
        sealed.extend_from_slice(&tag);
        let plaintext = cipher
            .decrypt(&nonce, sealed.as_slice())
            .expect("ciphertext authenticates under the derived key");
        let decoded: serde_json::Value = serde_json::from_slice(&plaintext).unwrap();
        assert_eq!(decoded, payload);
    }
}