1use std::{collections::HashMap, io};
2
3use age::{Identity, Recipient};
4use age_core::format::{FileKey, Stanza};
5use age_plugin::{
6 identity::{self, IdentityPluginV1},
7 recipient::{self, RecipientPluginV1},
8 Callbacks,
9};
10use bincode::{config, Decode, Encode};
11
12use tlock_age::{internal::STANZA_TAG, Header};
13
/// Name of the environment variable read by the recipient plugin's `wrap_file_keys` to
/// obtain the target round. When it cannot be read (unset or not valid Unicode), the user
/// is prompted for the round instead.
pub const ROUND_ENV: &str = "ROUND";
16
17#[derive(Debug, Encode, Decode, PartialEq, Clone)]
18pub struct RecipientInfo {
24 hash: Vec<u8>,
25 public_key_bytes: Vec<u8>,
26 genesis_time: u64,
27 period: u64,
28}
29
30impl RecipientInfo {
31 pub fn new(hash: &[u8], public_key_bytes: &[u8], genesis_time: u64, period: u64) -> Self {
32 Self {
33 hash: hash.to_vec(),
34 public_key_bytes: public_key_bytes.to_vec(),
35 genesis_time,
36 period,
37 }
38 }
39
40 fn serialize(&self) -> Vec<u8> {
41 bincode::encode_to_vec(self, config::standard()).unwrap()
42 }
43
44 fn deserialize(data: &[u8]) -> Self {
45 let (result, _) = bincode::decode_from_slice(data, config::standard()).unwrap();
46 result
47 }
48
49 pub fn hash(&self) -> Vec<u8> {
50 self.hash.clone()
51 }
52 pub fn public_key_bytes(&self) -> Vec<u8> {
53 self.public_key_bytes.clone()
54 }
55 pub fn genesis_time(&self) -> u64 {
56 self.genesis_time
57 }
58 pub fn period(&self) -> u64 {
59 self.period
60 }
61}
62
63struct RecipientPlugin {
64 plugin_name: String,
65 info: Option<RecipientInfo>,
66 parse_round: fn(&RecipientInfo, &str) -> u64,
67}
68
69impl RecipientPlugin {
70 pub fn new(plugin_name: &str, parse_round: fn(&RecipientInfo, &str) -> u64) -> Self {
71 Self {
72 plugin_name: plugin_name.to_owned(),
73 info: None,
74 parse_round,
75 }
76 }
77
78 pub fn plugin_name(&self) -> String {
79 self.plugin_name.clone()
80 }
81
82 pub fn info(&self) -> Option<RecipientInfo> {
83 self.info.clone()
84 }
85
86 pub fn parse_round(&self, round: &str) -> u64 {
87 (self.parse_round)(&self.info().unwrap(), round)
88 }
89}
90
91impl RecipientPluginV1 for RecipientPlugin {
92 fn add_recipient(
93 &mut self,
94 index: usize,
95 plugin_name: &str,
96 bytes: &[u8],
97 ) -> Result<(), recipient::Error> {
98 if plugin_name == self.plugin_name() {
99 let chain = RecipientInfo::deserialize(bytes);
100 self.info = Some(chain);
101 Ok(())
102 } else {
103 Err(recipient::Error::Recipient {
104 index,
105 message: "unsupported plugin".to_owned(),
106 })
107 }
108 }
109
110 fn add_identity(
111 &mut self,
112 _index: usize,
113 _plugin_name: &str,
114 _bytes: &[u8],
115 ) -> Result<(), recipient::Error> {
116 todo!()
117 }
118
119 fn wrap_file_keys(
120 &mut self,
121 file_keys: Vec<FileKey>,
122 mut callbacks: impl Callbacks<recipient::Error>,
123 ) -> io::Result<Result<Vec<Vec<Stanza>>, Vec<recipient::Error>>> {
124 let round = if let Ok(round) = std::env::var(ROUND_ENV) {
125 round
126 } else {
127 let prompt_message = "Enter decryption round: ";
128 match callbacks.request_public(prompt_message) {
129 Ok(round) => round.unwrap_or("".to_owned()),
130 Err(err) => return Err(err),
131 }
132 };
133 let round = self.parse_round(&round);
134
135 let info = self.info().unwrap();
136
137 let recipient =
138 tlock_age::internal::Recipient::new(&info.hash, &info.public_key_bytes, round);
139 Ok(Ok(file_keys
140 .into_iter()
141 .map(|file_key| recipient.wrap_file_key(&file_key).unwrap())
142 .collect()))
143 }
144}
145
/// The kind of data an [`IdentityInfo`] carries, as reported by [`IdentityInfo::format`].
// The ALL-CAPS variant names are public API; renaming them to `Raw`/`Http` would break
// callers, so the acronym lint is silenced instead.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IdentityFormat {
    /// The identity stores a signature directly.
    RAW,
    /// The identity stores a URL, and the signature is obtained through the plugin's
    /// `get_signature` callback.
    HTTP,
}
153
154#[derive(Debug, Encode, Decode, PartialEq, Clone)]
155pub enum IdentityInfo {
157 RawIdentityInfo(RawIdentityInfo),
158 HTTPIdentityInfo(HTTPIdentityInfo),
159}
160
161impl IdentityInfo {
162 fn serialize(&self) -> Vec<u8> {
163 bincode::encode_to_vec(self, config::standard()).unwrap()
164 }
165
166 fn deserialize(data: &[u8]) -> Self {
167 let (result, _) = bincode::decode_from_slice(data, config::standard()).unwrap();
168 result
169 }
170
171 pub fn format(&self) -> IdentityFormat {
172 match self {
173 Self::RawIdentityInfo(_) => IdentityFormat::RAW,
174 Self::HTTPIdentityInfo(_) => IdentityFormat::HTTP,
175 }
176 }
177}
178
179impl From<RawIdentityInfo> for IdentityInfo {
180 fn from(value: RawIdentityInfo) -> Self {
181 IdentityInfo::RawIdentityInfo(value)
182 }
183}
184
185impl From<HTTPIdentityInfo> for IdentityInfo {
186 fn from(value: HTTPIdentityInfo) -> Self {
187 IdentityInfo::HTTPIdentityInfo(value)
188 }
189}
190
191#[derive(Debug, Encode, Decode, PartialEq, Clone)]
192pub struct RawIdentityInfo {
193 signature: Vec<u8>,
194}
195
196impl RawIdentityInfo {
197 pub fn new(signature: &[u8]) -> Self {
198 Self {
199 signature: signature.to_vec(),
200 }
201 }
202}
203
204#[derive(Debug, Encode, Decode, PartialEq, Clone)]
205pub struct HTTPIdentityInfo {
206 url: String,
207}
208
209impl HTTPIdentityInfo {
210 pub fn new(url: &str) -> Self {
211 Self {
212 url: url.to_owned(),
213 }
214 }
215}
216
217struct IdentityPlugin {
218 plugin_name: String,
219 info: Option<IdentityInfo>,
220 get_signature: fn(url: &str, header: &Header) -> Vec<u8>,
221}
222
223impl IdentityPlugin {
224 pub fn new(
225 plugin_name: &str,
226 get_signature: fn(url: &str, header: &Header) -> Vec<u8>,
227 ) -> Self {
228 Self {
229 plugin_name: plugin_name.to_owned(),
230 info: None,
231 get_signature,
232 }
233 }
234}
235
236impl IdentityPluginV1 for IdentityPlugin {
237 fn add_identity(
238 &mut self,
239 index: usize,
240 plugin_name: &str,
241 bytes: &[u8],
242 ) -> Result<(), identity::Error> {
243 if plugin_name == self.plugin_name.as_str() {
244 let info = IdentityInfo::deserialize(bytes);
245 self.info = Some(info);
246 Ok(())
247 } else {
248 Err(identity::Error::Identity {
249 index,
250 message: "unsupported plugin".to_owned(),
251 })
252 }
253 }
254
255 fn unwrap_file_keys(
256 &mut self,
257 files: Vec<Vec<Stanza>>,
258 _callbacks: impl Callbacks<identity::Error>,
259 ) -> io::Result<HashMap<usize, Result<FileKey, Vec<identity::Error>>>> {
260 let mut file_keys = HashMap::with_capacity(files.len());
261
262 for (file, stanzas) in files.iter().enumerate() {
263 for (_stanza_index, stanza) in stanzas.iter().enumerate() {
264 if stanza.tag != STANZA_TAG {
265 continue;
266 }
267 if stanza.args.len() != 2 {
268 continue; }
270 let [round, hash] = [stanza.args[0].clone(), stanza.args[1].clone()];
271 let round = round.parse().unwrap();
272 let hash = hex::decode(hash).unwrap();
273 let header = Header::new(round, &hash);
274
275 let signature = match self.info.as_ref().unwrap() {
276 IdentityInfo::HTTPIdentityInfo(info) => {
277 (self.get_signature)(info.url.as_str(), &header)
278 }
279 IdentityInfo::RawIdentityInfo(info) => info.signature.clone(),
280 };
281 let identity = tlock_age::internal::Identity::new(&hash, &signature);
282
283 let file_key = identity.unwrap_stanza(stanza).unwrap();
284 let r = file_key.map_err(|e| {
285 vec![identity::Error::Identity {
286 index: file,
287 message: format!("{e}"),
288 }]
289 });
290
291 file_keys.entry(file).or_insert_with(|| r);
292 }
293 }
294 Ok(file_keys)
295 }
296}
297
298pub fn run_state_machine(
301 state_machine: String,
302 plugin_name: &str,
303 parse_round: fn(&RecipientInfo, &str) -> u64,
304 get_signature: fn(&str, &Header) -> Vec<u8>,
305) -> io::Result<()> {
306 age_plugin::run_state_machine(
308 &state_machine,
309 || RecipientPlugin::new(plugin_name, parse_round),
310 || IdentityPlugin::new(plugin_name, get_signature),
311 )
312}
313
314pub fn print_new_identity(plugin_name: &str, identity: &IdentityInfo, recipient: &RecipientInfo) {
316 age_plugin::print_new_identity(plugin_name, &identity.serialize(), &recipient.serialize())
317}