1use std::{error::Error, fmt, str};
4
/// The standard Base64 alphabet: the character at index `i` encodes the 6-bit value `i`.
const TABLE: &str = concat!(
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
    "abcdefghijklmnopqrstuvwxyz",
    "0123456789",
    "+/"
);
6
/// Errors reported by Base64 encoding and decoding.
///
/// The enum carries no data, so it is also `Copy` and `Hash`: it can be passed by
/// value and used as a map/set key.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum Base64Error {
    /// The encoded input's length is not a multiple of four.
    ///
    /// (The misspelled variant name is kept for backward compatibility.)
    InvalidDataLenght,
    /// The input is not valid Base64 (a character outside the alphabet or misplaced
    /// `=` padding), or the decoded bytes are not valid UTF-8.
    InvalidBase64Data,
    /// The input could not be encoded; also the target of the `From` conversions for
    /// `Utf8Error`, `ParseIntError` and `Box<dyn Error>`.
    EncodingError,
}
16
17impl Error for Base64Error {}
18
19impl fmt::Display for Base64Error {
20 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21 f.write_str("Base64 error : ")?;
22 f.write_str(match self {
23 Base64Error::InvalidDataLenght => "Invalid input data length",
24 Base64Error::InvalidBase64Data => "Invalid base64 data",
25 Base64Error::EncodingError => "Cannot encode input data",
26 })
27 }
28}
29
30impl From<std::string::FromUtf8Error> for Base64Error {
31 fn from(_e: std::string::FromUtf8Error) -> Base64Error {
32 Base64Error::InvalidBase64Data
33 }
34}
35
36impl From<std::str::Utf8Error> for Base64Error {
37 fn from(_e: std::str::Utf8Error) -> Base64Error {
38 Base64Error::EncodingError
39 }
40}
41
42impl From<std::num::ParseIntError> for Base64Error {
43 fn from(_e: std::num::ParseIntError) -> Base64Error {
44 Base64Error::EncodingError
45 }
46}
47
48impl From<Box<dyn Error>> for Base64Error {
49 fn from(_e: Box<dyn Error>) -> Base64Error {
50 Base64Error::EncodingError
51 }
52}
53
/// Base64 encoding and decoding with the alphabet in `TABLE` (`A-Z`, `a-z`, `0-9`,
/// `+`, `/`) and `=` padding.
pub trait Base64 {
    /// Encodes `self` as padded Base64 text.
    fn encode(&self) -> Result<String, Base64Error>;
    /// Decodes `self` as Base64 text and returns the payload as a UTF-8 `String`.
    ///
    /// Implementations in this file report a length that is not a multiple of four as
    /// `InvalidDataLenght` and characters outside the alphabet as `InvalidBase64Data`.
    fn decode(&self) -> Result<String, Base64Error>;
}
58
59impl Base64 for String {
60 fn encode(&self) -> Result<String, Base64Error> {
69 let a = self.as_bytes();
70
71 let mut octal = String::new();
72 let mut i = 0;
73
74 let inputlenmod = a.len() % 3;
76 let blockstoprocess = if inputlenmod == 0 {
77 a.len()
78 } else {
79 a.len() - inputlenmod
80 };
81 let padding = if inputlenmod != 0 {
82 3 - (a.len() - blockstoprocess)
83 } else {
84 0
85 };
86
87 while i < blockstoprocess {
89 octal.push_str(
90 format!("{:o}", u32::from_be_bytes([0, a[i], a[i + 1], a[i + 2]])).as_str(),
91 );
92 i += 3;
93 }
94
95 match padding {
96 1 => {
97 octal
98 .push_str(format!("{:o}", u32::from_be_bytes([0, a[i], a[i + 1], 0])).as_str());
99 }
100 2 => {
101 octal.push_str(format!("{:o}", u32::from_be_bytes([0, a[i], 0, 0])).as_str());
102 }
103 _ => {}
104 };
105
106 let sextets = octal
108 .as_bytes()
109 .chunks(2)
110 .map(str::from_utf8)
111 .map(|u| {
112 u.map_err::<Box<dyn Error>, _>(|e| e.into())
113 .and_then(|u| usize::from_str_radix(u, 8).map_err(|e| e.into()))
114 })
115 .collect::<Result<Vec<_>, _>>()?;
116
117 let mut result = String::new();
118
119 for i in 0..(sextets.len() - padding) {
120 result.push_str(&TABLE[sextets[i]..(sextets[i] + 1)]);
121 }
122 match padding {
123 1 => result.push('='),
124 2 => result.push_str("=="),
125 _ => {}
126 };
127 Ok(result)
128 }
129
130 fn decode(&self) -> Result<String, Base64Error> {
139 let mut encoded_data = self.to_owned();
140 let padding = encoded_data.matches('=').count();
141
142 if encoded_data.len() % 4 != 0 {
143 return Err(Base64Error::InvalidDataLenght);
144 };
145
146 for _ in 0..padding {
148 encoded_data.pop();
149 }
150
151 for _ in 0..padding {
152 encoded_data.push('A');
153 }
154
155 let octal = encoded_data
157 .chars()
158 .map(|c| format!("{:02o}", TABLE.find(c).unwrap_or(65))) .collect::<Vec<String>>();
160
161 let mut octalsextets = Vec::new();
163 let mut n = 0;
164 while n < encoded_data.len() {
165 let mut s = String::new();
166 for i in 0..4 {
167 if octal[n + i] == "101" {
168 return Err(Base64Error::InvalidBase64Data);
169 } s.push_str(octal[n + i].as_str());
171 }
172 n += 4;
173 octalsextets.push(s);
174 }
175
176 let decimal = octalsextets
178 .iter()
179 .map(|s| usize::from_str_radix(s, 8))
180 .collect::<Result<Vec<_>, _>>()?;
181
182 let mut bytes: Vec<u8> = Vec::new();
184 for i in 0..decimal.len() {
185 let a = decimal[i].to_be_bytes();
186 bytes.push(a[5]);
187 bytes.push(a[6]);
188 bytes.push(a[7]);
189 }
190
191 for _ in 0..padding {
193 bytes.pop();
194 }
195
196 let result = String::from_utf8(bytes)?;
197 Ok(result)
198 }
199}
200
201impl Base64 for Vec<u8> {
202 fn encode(&self) -> Result<String, Base64Error> {
211 let table = TABLE.as_bytes();
212
213 let mut input_buffer = Vec::new();
214 let l = self.len();
215 let mut i = 0;
216
217 let inputlenmod = l % 3;
219 let blockstoprocess = if inputlenmod == 0 { l } else { l - inputlenmod };
220
221 let padding = if inputlenmod != 0 {
222 3 - (l - blockstoprocess)
223 } else {
224 0
225 };
226
227 let mut base64_buffer: Vec<u8> = Vec::new();
228
229 while i < blockstoprocess {
232 input_buffer.push(u32::from_be_bytes([0, self[i], self[i + 1], self[i + 2]]));
233 i += 3;
234 }
235
236 match padding {
237 1 => {
238 input_buffer.push(u32::from_be_bytes([0, self[i], self[i + 1], 0]));
239 }
240 2 => {
241 input_buffer.push(u32::from_be_bytes([0, self[i], 0, 0]));
242 }
243 _ => {}
244 };
245
246 i = 0;
248 while i < input_buffer.len() {
249 let t0 = ((input_buffer[i] & 0xFC0000) >> 18) as u8;
250 let t1 = ((input_buffer[i] & 0x3F000) >> 12) as u8;
251 let t2 = ((input_buffer[i] & 0xFC0) >> 6) as u8;
252 let t3 = (input_buffer[i] & 0x3F) as u8;
253 base64_buffer.push(table[t0 as usize]);
254 base64_buffer.push(table[t1 as usize]);
255 base64_buffer.push(table[t2 as usize]);
256 base64_buffer.push(table[t3 as usize]);
257 i = i + 1;
258 }
259
260 let mut result = String::from_utf8(base64_buffer)?;
261 match padding {
262 1 => {
263 result.pop();
264 result.push('=');
265 }
266 2 => {
267 result.pop();
268 result.pop();
269 result.push_str("==");
270 }
271 _ => {}
272 };
273
274 Ok(result)
275 }
276
277 fn decode(&self) -> Result<String, Base64Error> {
286 let temp_string = self.to_owned();
287 let mut encoded_data = String::from_utf8(temp_string)?;
288 let padding = encoded_data.matches('=').count();
289
290 if encoded_data.len() % 4 != 0 {
291 return Err(Base64Error::InvalidDataLenght);
292 };
293
294 for _ in 0..padding {
296 encoded_data.pop();
297 }
298
299 for _ in 0..padding {
300 encoded_data.push('A');
301 }
302
303 let octal = encoded_data
305 .chars()
306 .map(|c| format!("{:02o}", TABLE.find(c).unwrap_or(65))) .collect::<Vec<String>>();
308
309 let mut octalsextets = Vec::new();
311 let mut n = 0;
312 while n < encoded_data.len() {
313 let mut s = String::new();
314 for i in 0..4 {
315 if octal[n + i] == "101" {
316 return Err(Base64Error::InvalidBase64Data);
317 } s.push_str(octal[n + i].as_str());
319 }
320 n += 4;
321 octalsextets.push(s);
322 }
323
324 let decimal = octalsextets
326 .iter()
327 .map(|s| usize::from_str_radix(s, 8))
328 .collect::<Result<Vec<_>, _>>()?;
329
330 let mut bytes: Vec<u8> = Vec::new();
332 for i in 0..decimal.len() {
333 let a = decimal[i].to_be_bytes();
334 bytes.push(a[5]);
335 bytes.push(a[6]);
336 bytes.push(a[7]);
337 }
338
339 for _ in 0..padding {
341 bytes.pop();
342 }
343
344 let result = String::from_utf8(bytes)?;
345 Ok(result)
346 }
347}
348
// Known-answer tests for both `String` and `Vec<u8>` inputs, covering every padding
// length plus the two decode error paths (bad length, character outside the alphabet).
#[cfg(test)]
mod tests {
    use crate::Base64;

    #[test]
    fn encode_works() {
        let encoded = String::from("Je t'aime ma chérie").encode();
        assert_eq!(encoded, Ok(String::from("SmUgdCdhaW1lIG1hIGNow6lyaWU=")));
    }

    #[test]
    fn encode_no_padding() {
        let encoded = String::from("Man").encode();
        assert_eq!(encoded, Ok(String::from("TWFu")));
    }

    #[test]
    fn encode_one_padding() {
        let encoded = String::from("Ma").encode();
        assert_eq!(encoded, Ok(String::from("TWE=")));
    }

    #[test]
    fn encode_two_padding() {
        let encoded = String::from("M").encode();
        assert_eq!(encoded, Ok(String::from("TQ==")));
    }

    #[test]
    fn decode_works() {
        let decoded = String::from("Sm95ZXV4IGFubml2ZXJzYWlyZSAh").decode();
        assert_eq!(decoded, Ok(String::from("Joyeux anniversaire !")));
    }

    #[test]
    fn datalength_check() {
        // Three characters cannot form a complete four-character group.
        let decoded = String::from("TWF").decode();
        assert_eq!(decoded, Err(crate::Base64Error::InvalidDataLenght));
    }

    #[test]
    fn validb64data_check() {
        // '$' is not part of the Base64 alphabet.
        let decoded = String::from("TWF$").decode();
        assert_eq!(decoded, Err(crate::Base64Error::InvalidBase64Data));
    }

    #[test]
    fn encode_u8_no_padding() {
        let input: Vec<u8> = b"Man".to_vec();
        assert_eq!(input.encode(), Ok(String::from("TWFu")));
    }

    #[test]
    fn encode_u8_one_padding() {
        let input: Vec<u8> = b"Ma".to_vec();
        assert_eq!(input.encode(), Ok(String::from("TWE=")));
    }

    #[test]
    fn encode_u8_two_padding() {
        let input: Vec<u8> = b"M".to_vec();
        assert_eq!(input.encode(), Ok(String::from("TQ==")));
    }

    #[test]
    fn encode_u8() {
        let input: Vec<u8> = b"light work.".to_vec();
        assert_eq!(input.encode(), Ok(String::from("bGlnaHQgd29yay4=")));
    }

    #[test]
    fn decode_u8() {
        let input: Vec<u8> = b"bGlnaHQgd29y".to_vec();
        assert_eq!(input.decode(), Ok(String::from("light wor")));
    }

    #[test]
    fn decode_u8_one_padding() {
        let input: Vec<u8> = b"bGlnaHQgd29yay4=".to_vec();
        assert_eq!(input.decode(), Ok(String::from("light work.")));
    }

    #[test]
    fn decode_u8_two_padding() {
        let input: Vec<u8> = b"bGlnaHQgd29yaw==".to_vec();
        assert_eq!(input.decode(), Ok(String::from("light work")));
    }
}