1use argon2::{Algorithm, Argon2, Params, Version};
14use hkdf::Hkdf;
15use rand::RngCore;
16use sha2::Sha256;
17use zeroize::{Zeroize, Zeroizing};
18
19use crate::error::JoyError;
20
21#[derive(Clone)]
23pub struct Salt([u8; 32]);
24
25impl Salt {
26 pub fn as_bytes(&self) -> &[u8; 32] {
27 &self.0
28 }
29
30 pub fn to_hex(&self) -> String {
31 hex::encode(self.0)
32 }
33
34 pub fn from_hex(s: &str) -> Result<Self, JoyError> {
35 let bytes =
36 hex::decode(s).map_err(|e| JoyError::AuthFailed(format!("invalid salt: {e}")))?;
37 let arr: [u8; 32] = bytes
38 .try_into()
39 .map_err(|_| JoyError::AuthFailed("salt must be 32 bytes".into()))?;
40 Ok(Self(arr))
41 }
42}
43
44pub struct DerivedKey(Zeroizing<[u8; 32]>);
46
47impl DerivedKey {
48 pub fn as_bytes(&self) -> &[u8; 32] {
49 &self.0
50 }
51}
52
53pub fn generate_salt() -> Salt {
55 let mut bytes = [0u8; 32];
56 rand::thread_rng().fill_bytes(&mut bytes);
57 Salt(bytes)
58}
59
60pub fn derive_key(passphrase: &str, salt: &Salt) -> Result<DerivedKey, JoyError> {
65 #[cfg(any(feature = "fast-kdf", debug_assertions))]
66 let params = Params::new(256, 1, 1, Some(32))
67 .map_err(|e| JoyError::AuthFailed(format!("argon2 params: {e}")))?;
68 #[cfg(not(any(feature = "fast-kdf", debug_assertions)))]
69 let params = Params::new(65536, 3, 4, Some(32))
70 .map_err(|e| JoyError::AuthFailed(format!("argon2 params: {e}")))?;
71 let argon2 = Argon2::new(Algorithm::Argon2id, Version::V0x13, params);
72
73 let mut output = Zeroizing::new([0u8; 32]);
74 argon2
75 .hash_password_into(passphrase.as_bytes(), salt.as_bytes(), output.as_mut())
76 .map_err(|e| JoyError::AuthFailed(format!("key derivation failed: {e}")))?;
77
78 Ok(DerivedKey(output))
79}
80
81pub fn derive_delegation_seed(
99 identity_key: &DerivedKey,
100 salt: &Salt,
101 project_id: &str,
102 ai_member_id: &str,
103) -> [u8; 32] {
104 let hk = Hkdf::<Sha256>::new(Some(salt.as_bytes()), identity_key.as_bytes());
105 let mut info = Vec::with_capacity(16 + project_id.len() + 1 + ai_member_id.len());
106 info.extend_from_slice(b"joy-delegation:");
107 info.extend_from_slice(project_id.as_bytes());
108 info.push(b':');
109 info.extend_from_slice(ai_member_id.as_bytes());
110 let mut out = [0u8; 32];
111 hk.expand(&info, &mut out)
112 .expect("HKDF-SHA256 expand to 32 bytes never fails");
113 out
114}
115
116pub fn validate_passphrase(passphrase: &str) -> Result<(), JoyError> {
118 let word_count = passphrase.split_whitespace().count();
119 if word_count < 6 {
120 return Err(JoyError::PassphraseTooShort);
121 }
122 Ok(())
123}
124
125impl Drop for Salt {
126 fn drop(&mut self) {
127 self.0.zeroize();
128 }
129}
130
// Unit tests for salt handling, Argon2 key derivation, passphrase validation,
// and delegation-seed derivation.
#[cfg(test)]
mod tests {
    use super::*;

    // Six words, so it passes `validate_passphrase`
    const TEST_PASSPHRASE: &str = "correct horse battery staple extra words";

    #[test]
    fn salt_is_random() {
        let s1 = generate_salt();
        let s2 = generate_salt();
        assert_ne!(s1.as_bytes(), s2.as_bytes());
    }

    #[test]
    fn salt_hex_roundtrip() {
        let salt = generate_salt();
        let hex = salt.to_hex();
        let parsed = Salt::from_hex(&hex).unwrap();
        assert_eq!(salt.as_bytes(), parsed.as_bytes());
    }

    #[test]
    fn salt_from_hex_rejects_invalid_hex() {
        assert!(matches!(Salt::from_hex("not hex"), Err(JoyError::AuthFailed(_))));
        assert!(matches!(Salt::from_hex("zz".repeat(32).as_str()), Err(JoyError::AuthFailed(_))));
    }

    #[test]
    fn salt_from_hex_rejects_wrong_length() {
        // Valid hex in each case, but not exactly 32 bytes once decoded
        assert!(matches!(Salt::from_hex(""), Err(JoyError::AuthFailed(_))));
        assert!(matches!(Salt::from_hex(&"00".repeat(31)), Err(JoyError::AuthFailed(_))));
        assert!(matches!(Salt::from_hex(&"00".repeat(33)), Err(JoyError::AuthFailed(_))));
        assert!(Salt::from_hex(&"00".repeat(32)).is_ok());
    }

    // In debug test builds `derive_key` uses its cheap cost settings, so these stay fast
    #[test]
    fn derive_deterministic() {
        let salt =
            Salt::from_hex("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef")
                .unwrap();
        let k1 = derive_key(TEST_PASSPHRASE, &salt).unwrap();
        let k2 = derive_key(TEST_PASSPHRASE, &salt).unwrap();
        assert_eq!(k1.as_bytes(), k2.as_bytes());
    }

    #[test]
    fn derive_different_salt() {
        let s1 = generate_salt();
        let s2 = generate_salt();
        let k1 = derive_key(TEST_PASSPHRASE, &s1).unwrap();
        let k2 = derive_key(TEST_PASSPHRASE, &s2).unwrap();
        assert_ne!(k1.as_bytes(), k2.as_bytes());
    }

    #[test]
    fn derive_different_passphrase() {
        let salt = generate_salt();
        let k1 = derive_key("one two three four five six", &salt).unwrap();
        let k2 = derive_key("seven eight nine ten eleven twelve", &salt).unwrap();
        assert_ne!(k1.as_bytes(), k2.as_bytes());
    }

    #[test]
    fn passphrase_too_short() {
        assert!(validate_passphrase("one two three").is_err());
        assert!(validate_passphrase("one two three four five").is_err());
    }

    #[test]
    fn passphrase_valid() {
        assert!(validate_passphrase("one two three four five six").is_ok());
        assert!(validate_passphrase("a b c d e f g h").is_ok());
    }

    #[test]
    fn passphrase_counts_words_not_whitespace() {
        // Runs of mixed whitespace act as a single separator
        assert!(validate_passphrase("  one\ttwo  three\nfour five six  ").is_ok());
        assert!(matches!(
            validate_passphrase("one     two     three"),
            Err(JoyError::PassphraseTooShort)
        ));
        assert!(matches!(validate_passphrase(""), Err(JoyError::PassphraseTooShort)));
        assert!(matches!(validate_passphrase("   \t\n "), Err(JoyError::PassphraseTooShort)));
    }

    /// Builds a `DerivedKey` from fixed bytes so the HKDF tests skip Argon2.
    fn fixed_identity_key() -> DerivedKey {
        let bytes = Zeroizing::new([7u8; 32]);
        DerivedKey(bytes)
    }

    #[test]
    fn delegation_seed_is_deterministic() {
        let salt = generate_salt();
        let id = fixed_identity_key();
        let s1 = derive_delegation_seed(&id, &salt, "JOY", "ai:claude@joy");
        let s2 = derive_delegation_seed(&id, &salt, "JOY", "ai:claude@joy");
        assert_eq!(s1, s2);
    }

    #[test]
    fn delegation_seed_changes_with_salt() {
        let id = fixed_identity_key();
        let s1 = derive_delegation_seed(&id, &generate_salt(), "JOY", "ai:claude@joy");
        let s2 = derive_delegation_seed(&id, &generate_salt(), "JOY", "ai:claude@joy");
        assert_ne!(s1, s2);
    }

    #[test]
    fn delegation_seed_is_domain_separated_by_project() {
        let salt = generate_salt();
        let id = fixed_identity_key();
        let s1 = derive_delegation_seed(&id, &salt, "JOY", "ai:claude@joy");
        let s2 = derive_delegation_seed(&id, &salt, "OTHER", "ai:claude@joy");
        assert_ne!(s1, s2);
    }

    #[test]
    fn delegation_seed_is_domain_separated_by_member() {
        let salt = generate_salt();
        let id = fixed_identity_key();
        let s1 = derive_delegation_seed(&id, &salt, "JOY", "ai:claude@joy");
        let s2 = derive_delegation_seed(&id, &salt, "JOY", "ai:qwen@joy");
        assert_ne!(s1, s2);
    }

    #[test]
    fn delegation_seed_changes_with_identity_key() {
        let salt = generate_salt();
        let id_a = fixed_identity_key();
        let id_b = {
            let bytes = Zeroizing::new([8u8; 32]);
            DerivedKey(bytes)
        };
        let s1 = derive_delegation_seed(&id_a, &salt, "JOY", "ai:claude@joy");
        let s2 = derive_delegation_seed(&id_b, &salt, "JOY", "ai:claude@joy");
        assert_ne!(s1, s2);
    }
}