1use core::ops::{self, Range};
2
3use num_traits::ToPrimitive;
4
5use crate::str::StrKind;
6use crate::wtf8::{CodePoint, Wtf8, Wtf8Buf};
7
8pub trait StrBuffer: AsRef<Wtf8> {
9 fn is_compatible_with(&self, kind: StrKind) -> bool {
10 let s = self.as_ref();
11 match kind {
12 StrKind::Ascii => s.is_ascii(),
13 StrKind::Utf8 => s.is_utf8(),
14 StrKind::Wtf8 => true,
15 }
16 }
17}
18
/// Items shared by the encode and decode contexts: the error type they report
/// and the buffer types used for replacement data.
pub trait CodecContext: Sized {
    /// Error type produced when encoding or decoding fails.
    type Error;
    /// String buffer type; it must be viewable as WTF-8 through [`StrBuffer`].
    type StrBuf: StrBuffer;
    /// Byte buffer type; it must be viewable as a byte slice.
    type BytesBuf: AsRef<[u8]>;

    /// Turns an owned WTF-8 buffer into this context's string buffer type.
    fn string(&self, s: Wtf8Buf) -> Self::StrBuf;
    /// Turns an owned byte vector into this context's byte buffer type.
    fn bytes(&self, b: Vec<u8>) -> Self::BytesBuf;
}
27
28pub trait EncodeContext: CodecContext {
29 fn full_data(&self) -> &Wtf8;
30 fn data_len(&self) -> StrSize;
31
32 fn remaining_data(&self) -> &Wtf8;
33 fn position(&self) -> StrSize;
34
35 fn restart_from(&mut self, pos: StrSize) -> Result<(), Self::Error>;
36
37 fn error_encoding(&self, range: Range<StrSize>, reason: Option<&str>) -> Self::Error;
38
39 fn handle_error<E>(
40 &mut self,
41 errors: &E,
42 range: Range<StrSize>,
43 reason: Option<&str>,
44 ) -> Result<EncodeReplace<Self>, Self::Error>
45 where
46 E: EncodeErrorHandler<Self>,
47 {
48 let (replace, restart) = errors.handle_encode_error(self, range, reason)?;
49 self.restart_from(restart)?;
50 Ok(replace)
51 }
52}
53
54pub trait DecodeContext: CodecContext {
55 fn full_data(&self) -> &[u8];
56
57 fn remaining_data(&self) -> &[u8];
58 fn position(&self) -> usize;
59
60 fn advance(&mut self, by: usize);
61
62 fn restart_from(&mut self, pos: usize) -> Result<(), Self::Error>;
63
64 fn error_decoding(&self, byte_range: Range<usize>, reason: Option<&str>) -> Self::Error;
65
66 fn handle_error<E>(
67 &mut self,
68 errors: &E,
69 byte_range: Range<usize>,
70 reason: Option<&str>,
71 ) -> Result<Self::StrBuf, Self::Error>
72 where
73 E: DecodeErrorHandler<Self>,
74 {
75 let (replace, restart) = errors.handle_decode_error(self, byte_range, reason)?;
76 self.restart_from(restart)?;
77 Ok(replace)
78 }
79}
80
/// Strategy for dealing with a run of code points that cannot be encoded.
pub trait EncodeErrorHandler<Ctx: EncodeContext> {
    /// Handles the unencodable `range` of the context's input.
    ///
    /// On success returns the replacement to emit together with the position
    /// that `EncodeContext::handle_error` passes to `restart_from`.
    fn handle_encode_error(
        &self,
        ctx: &mut Ctx,
        range: Range<StrSize>,
        reason: Option<&str>,
    ) -> Result<(EncodeReplace<Ctx>, StrSize), Ctx::Error>;
}
/// Strategy for dealing with a run of bytes that cannot be decoded.
pub trait DecodeErrorHandler<Ctx: DecodeContext> {
    /// Handles the undecodable `byte_range` of the context's input.
    ///
    /// On success returns the replacement string together with the byte
    /// position that `DecodeContext::handle_error` passes to `restart_from`.
    fn handle_decode_error(
        &self,
        ctx: &mut Ctx,
        byte_range: Range<usize>,
        reason: Option<&str>,
    ) -> Result<(Ctx::StrBuf, usize), Ctx::Error>;
}
97
/// Replacement produced by an encode error handler.
pub enum EncodeReplace<Ctx: CodecContext> {
    /// A string replacement; the encoders in this file check that it is
    /// representable in the target encoding before emitting it.
    Str(Ctx::StrBuf),
    /// Raw bytes; the encoders in this file append them to the output as-is.
    Bytes(Ctx::BytesBuf),
}
102
/// A size or offset within a string, tracked in two units at once.
#[derive(Copy, Clone, Default, Debug)]
pub struct StrSize {
    /// Amount measured in bytes.
    pub bytes: usize,
    /// Amount measured in code points.
    pub chars: usize,
}
108
109fn iter_code_points(w: &Wtf8) -> impl Iterator<Item = (StrSize, CodePoint)> {
110 w.code_point_indices()
111 .enumerate()
112 .map(|(chars, (bytes, c))| (StrSize { bytes, chars }, c))
113}
114
115impl ops::Add for StrSize {
116 type Output = Self;
117 fn add(self, rhs: Self) -> Self::Output {
118 Self {
119 bytes: self.bytes + rhs.bytes,
120 chars: self.chars + rhs.chars,
121 }
122 }
123}
124
125impl ops::AddAssign for StrSize {
126 fn add_assign(&mut self, rhs: Self) {
127 self.bytes += rhs.bytes;
128 self.chars += rhs.chars;
129 }
130}
131
/// Outcome of a failed decode attempt over a byte slice.
struct DecodeError<'a> {
    /// The leading part of the input that decoded successfully.
    valid_prefix: &'a str,
    /// The input bytes that follow `valid_prefix`, starting at the failure.
    rest: &'a [u8],
    /// Length in bytes of the offending sequence; when `None`, the generic
    /// decode loop extends the error range to the end of the input.
    err_len: Option<usize>,
}
137
/// Builds a [`DecodeError`] by splitting `v` at `valid_up_to` into the valid
/// prefix and the remaining bytes.
///
/// # Safety
///
/// `valid_up_to` must not exceed `v.len()`, and `v[..valid_up_to]` must be
/// valid UTF-8.
const unsafe fn make_decode_err(
    v: &[u8],
    valid_up_to: usize,
    err_len: Option<usize>,
) -> DecodeError<'_> {
    // SAFETY: the caller guarantees that `valid_up_to <= v.len()`.
    let (valid_prefix, rest) = unsafe { v.split_at_unchecked(valid_up_to) };
    // SAFETY: the caller guarantees that the prefix is valid UTF-8.
    let valid_prefix = unsafe { core::str::from_utf8_unchecked(valid_prefix) };
    DecodeError {
        valid_prefix,
        rest,
        err_len,
    }
}
153
/// Verdict of a codec-specific classifier about a decode failure.
enum HandleResult<'a> {
    /// Stop decoding without raising an error: the generic decode loop
    /// advances past the valid prefix and returns what it has so far.
    Done,
    /// Report the failure to the error handler.
    Error {
        /// Length in bytes of the offending sequence; `None` makes the error
        /// range extend to the end of the input.
        err_len: Option<usize>,
        /// Reason text forwarded to the error handler.
        reason: &'a str,
    },
}
161
162fn decode_utf8_compatible<Ctx, E, DecodeF, ErrF>(
163 mut ctx: Ctx,
164 errors: &E,
165 decode: DecodeF,
166 handle_error: ErrF,
167) -> Result<(Wtf8Buf, usize), Ctx::Error>
168where
169 Ctx: DecodeContext,
170 E: DecodeErrorHandler<Ctx>,
171 DecodeF: Fn(&[u8]) -> Result<&str, DecodeError<'_>>,
172 ErrF: Fn(&[u8], Option<usize>) -> HandleResult<'static>,
173{
174 if ctx.remaining_data().is_empty() {
175 return Ok((Wtf8Buf::new(), 0));
176 }
177 let mut out = Wtf8Buf::with_capacity(ctx.remaining_data().len());
178 loop {
179 match decode(ctx.remaining_data()) {
180 Ok(decoded) => {
181 out.push_str(decoded);
182 ctx.advance(decoded.len());
183 break;
184 }
185 Err(e) => {
186 out.push_str(e.valid_prefix);
187 match handle_error(e.rest, e.err_len) {
188 HandleResult::Done => {
189 ctx.advance(e.valid_prefix.len());
190 break;
191 }
192 HandleResult::Error { err_len, reason } => {
193 let err_start = ctx.position() + e.valid_prefix.len();
194 let err_end = match err_len {
195 Some(len) => err_start + len,
196 None => ctx.full_data().len(),
197 };
198 let err_range = err_start..err_end;
199 let replace = ctx.handle_error(errors, err_range, Some(reason))?;
200 out.push_wtf8(replace.as_ref());
201 continue;
202 }
203 }
204 }
205 }
206 }
207 Ok((out, ctx.position()))
208}
209
210#[inline]
211fn encode_utf8_compatible<Ctx, E>(
212 mut ctx: Ctx,
213 errors: &E,
214 err_reason: &str,
215 target_kind: StrKind,
216) -> Result<Vec<u8>, Ctx::Error>
217where
218 Ctx: EncodeContext,
219 E: EncodeErrorHandler<Ctx>,
220{
221 let mut out = Vec::<u8>::with_capacity(ctx.remaining_data().len());
224 loop {
225 let data = ctx.remaining_data();
226 let mut iter = iter_code_points(data);
227 let Some((i, _)) = iter.find(|(_, c)| !target_kind.can_encode(*c)) else {
228 break;
229 };
230
231 out.extend_from_slice(&ctx.remaining_data().as_bytes()[..i.bytes]);
232
233 let err_start = ctx.position() + i;
234 let err_end = match { iter }.find(|(_, c)| target_kind.can_encode(*c)) {
236 Some((i, _)) => ctx.position() + i,
237 None => ctx.data_len(),
238 };
239
240 let range = err_start..err_end;
241 let replace = ctx.handle_error(errors, range.clone(), Some(err_reason))?;
242 match replace {
243 EncodeReplace::Str(s) => {
244 if s.is_compatible_with(target_kind) {
245 out.extend_from_slice(s.as_ref().as_bytes());
246 } else {
247 return Err(ctx.error_encoding(range, Some(err_reason)));
248 }
249 }
250 EncodeReplace::Bytes(b) => {
251 out.extend_from_slice(b.as_ref());
252 }
253 }
254 }
255 out.extend_from_slice(ctx.remaining_data().as_bytes());
256 Ok(out)
257}
258
259pub mod errors {
260 use crate::str::UnicodeEscapeCodepoint;
261
262 use super::*;
263 use core::fmt::Write;
264
265 #[derive(Clone, Copy)]
266 pub struct Strict;
267
268 impl<Ctx: EncodeContext> EncodeErrorHandler<Ctx> for Strict {
269 fn handle_encode_error(
270 &self,
271 ctx: &mut Ctx,
272 range: Range<StrSize>,
273 reason: Option<&str>,
274 ) -> Result<(EncodeReplace<Ctx>, StrSize), Ctx::Error> {
275 Err(ctx.error_encoding(range, reason))
276 }
277 }
278
279 impl<Ctx: DecodeContext> DecodeErrorHandler<Ctx> for Strict {
280 fn handle_decode_error(
281 &self,
282 ctx: &mut Ctx,
283 byte_range: Range<usize>,
284 reason: Option<&str>,
285 ) -> Result<(Ctx::StrBuf, usize), Ctx::Error> {
286 Err(ctx.error_decoding(byte_range, reason))
287 }
288 }
289
290 #[derive(Clone, Copy)]
291 pub struct Ignore;
292
293 impl<Ctx: EncodeContext> EncodeErrorHandler<Ctx> for Ignore {
294 fn handle_encode_error(
295 &self,
296 ctx: &mut Ctx,
297 range: Range<StrSize>,
298 _reason: Option<&str>,
299 ) -> Result<(EncodeReplace<Ctx>, StrSize), Ctx::Error> {
300 Ok((EncodeReplace::Bytes(ctx.bytes(b"".into())), range.end))
301 }
302 }
303
304 impl<Ctx: DecodeContext> DecodeErrorHandler<Ctx> for Ignore {
305 fn handle_decode_error(
306 &self,
307 ctx: &mut Ctx,
308 byte_range: Range<usize>,
309 _reason: Option<&str>,
310 ) -> Result<(Ctx::StrBuf, usize), Ctx::Error> {
311 Ok((ctx.string("".into()), byte_range.end))
312 }
313 }
314
315 #[derive(Clone, Copy)]
316 pub struct Replace;
317
318 impl<Ctx: EncodeContext> EncodeErrorHandler<Ctx> for Replace {
319 fn handle_encode_error(
320 &self,
321 ctx: &mut Ctx,
322 range: Range<StrSize>,
323 _reason: Option<&str>,
324 ) -> Result<(EncodeReplace<Ctx>, StrSize), Ctx::Error> {
325 let replace = "?".repeat(range.end.chars - range.start.chars);
326 Ok((EncodeReplace::Str(ctx.string(replace.into())), range.end))
327 }
328 }
329
330 impl<Ctx: DecodeContext> DecodeErrorHandler<Ctx> for Replace {
331 fn handle_decode_error(
332 &self,
333 ctx: &mut Ctx,
334 byte_range: Range<usize>,
335 _reason: Option<&str>,
336 ) -> Result<(Ctx::StrBuf, usize), Ctx::Error> {
337 Ok((
338 ctx.string(char::REPLACEMENT_CHARACTER.to_string().into()),
339 byte_range.end,
340 ))
341 }
342 }
343
344 #[derive(Clone, Copy)]
345 pub struct XmlCharRefReplace;
346
347 impl<Ctx: EncodeContext> EncodeErrorHandler<Ctx> for XmlCharRefReplace {
348 fn handle_encode_error(
349 &self,
350 ctx: &mut Ctx,
351 range: Range<StrSize>,
352 _reason: Option<&str>,
353 ) -> Result<(EncodeReplace<Ctx>, StrSize), Ctx::Error> {
354 let err_str = &ctx.full_data()[range.start.bytes..range.end.bytes];
355 let num_chars = range.end.chars - range.start.chars;
356 let mut out = String::with_capacity(num_chars * 6);
358 for c in err_str.code_points() {
359 write!(out, "&#{};", c.to_u32()).unwrap()
360 }
361 Ok((EncodeReplace::Str(ctx.string(out.into())), range.end))
362 }
363 }
364
365 #[derive(Clone, Copy)]
366 pub struct BackslashReplace;
367
368 impl<Ctx: EncodeContext> EncodeErrorHandler<Ctx> for BackslashReplace {
369 fn handle_encode_error(
370 &self,
371 ctx: &mut Ctx,
372 range: Range<StrSize>,
373 _reason: Option<&str>,
374 ) -> Result<(EncodeReplace<Ctx>, StrSize), Ctx::Error> {
375 let err_str = &ctx.full_data()[range.start.bytes..range.end.bytes];
376 let num_chars = range.end.chars - range.start.chars;
377 let mut out = String::with_capacity(num_chars * 4);
379 for c in err_str.code_points() {
380 write!(out, "{}", UnicodeEscapeCodepoint(c)).unwrap();
381 }
382 Ok((EncodeReplace::Str(ctx.string(out.into())), range.end))
383 }
384 }
385
386 impl<Ctx: DecodeContext> DecodeErrorHandler<Ctx> for BackslashReplace {
387 fn handle_decode_error(
388 &self,
389 ctx: &mut Ctx,
390 byte_range: Range<usize>,
391 _reason: Option<&str>,
392 ) -> Result<(Ctx::StrBuf, usize), Ctx::Error> {
393 let err_bytes = &ctx.full_data()[byte_range.clone()];
394 let mut replace = String::with_capacity(4 * err_bytes.len());
395 for &c in err_bytes {
396 write!(replace, "\\x{c:02x}").unwrap();
397 }
398 Ok((ctx.string(replace.into()), byte_range.end))
399 }
400 }
401
402 #[derive(Clone, Copy)]
403 pub struct NameReplace;
404
405 impl<Ctx: EncodeContext> EncodeErrorHandler<Ctx> for NameReplace {
406 fn handle_encode_error(
407 &self,
408 ctx: &mut Ctx,
409 range: Range<StrSize>,
410 _reason: Option<&str>,
411 ) -> Result<(EncodeReplace<Ctx>, StrSize), Ctx::Error> {
412 let err_str = &ctx.full_data()[range.start.bytes..range.end.bytes];
413 let num_chars = range.end.chars - range.start.chars;
414 let mut out = String::with_capacity(num_chars * 4);
415 for c in err_str.code_points() {
416 let c_u32 = c.to_u32();
417 if let Some(c_name) = c.to_char().and_then(unicode_names2::name) {
418 write!(out, "\\N{{{c_name}}}").unwrap();
419 } else if c_u32 >= 0x10000 {
420 write!(out, "\\U{c_u32:08x}").unwrap();
421 } else if c_u32 >= 0x100 {
422 write!(out, "\\u{c_u32:04x}").unwrap();
423 } else {
424 write!(out, "\\x{c_u32:02x}").unwrap();
425 }
426 }
427 Ok((EncodeReplace::Str(ctx.string(out.into())), range.end))
428 }
429 }
430
431 #[derive(Clone, Copy)]
432 pub struct SurrogateEscape;
433
434 impl<Ctx: EncodeContext> EncodeErrorHandler<Ctx> for SurrogateEscape {
435 fn handle_encode_error(
436 &self,
437 ctx: &mut Ctx,
438 range: Range<StrSize>,
439 reason: Option<&str>,
440 ) -> Result<(EncodeReplace<Ctx>, StrSize), Ctx::Error> {
441 let err_str = &ctx.full_data()[range.start.bytes..range.end.bytes];
442 let num_chars = range.end.chars - range.start.chars;
443 let mut out = Vec::with_capacity(num_chars);
444 let mut pos = range.start;
445 for ch in err_str.code_points() {
446 let ch_u32 = ch.to_u32();
447 if !(0xdc80..=0xdcff).contains(&ch_u32) {
448 if out.is_empty() {
449 return Err(ctx.error_encoding(range, reason));
451 }
452 return Ok((EncodeReplace::Bytes(ctx.bytes(out)), pos));
454 }
455 out.push((ch_u32 - 0xdc00) as u8);
456 pos += StrSize {
457 bytes: ch.len_wtf8(),
458 chars: 1,
459 };
460 }
461 Ok((EncodeReplace::Bytes(ctx.bytes(out)), range.end))
462 }
463 }
464
465 impl<Ctx: DecodeContext> DecodeErrorHandler<Ctx> for SurrogateEscape {
466 fn handle_decode_error(
467 &self,
468 ctx: &mut Ctx,
469 byte_range: Range<usize>,
470 reason: Option<&str>,
471 ) -> Result<(Ctx::StrBuf, usize), Ctx::Error> {
472 let err_bytes = &ctx.full_data()[byte_range.clone()];
473 let mut consumed = 0;
474 let mut replace = Wtf8Buf::with_capacity(4 * byte_range.len());
475 while consumed < 4 && consumed < byte_range.len() {
476 let c = err_bytes[consumed] as u16;
477 if c < 128 {
479 break;
480 }
481 replace.push(CodePoint::from(0xdc00 + c));
482 consumed += 1;
483 }
484 if consumed == 0 {
485 return Err(ctx.error_decoding(byte_range, reason));
486 }
487 Ok((ctx.string(replace), byte_range.start + consumed))
488 }
489 }
490}
491
492pub mod utf8 {
493 use super::*;
494
    /// Name of this codec.
    pub const ENCODING_NAME: &str = "utf-8";
496
497 #[inline]
498 pub fn encode<Ctx, E>(ctx: Ctx, errors: &E) -> Result<Vec<u8>, Ctx::Error>
499 where
500 Ctx: EncodeContext,
501 E: EncodeErrorHandler<Ctx>,
502 {
503 encode_utf8_compatible(ctx, errors, "surrogates not allowed", StrKind::Utf8)
504 }
505
506 pub fn decode<Ctx: DecodeContext, E: DecodeErrorHandler<Ctx>>(
507 ctx: Ctx,
508 errors: &E,
509 final_decode: bool,
510 ) -> Result<(Wtf8Buf, usize), Ctx::Error> {
511 decode_utf8_compatible(
512 ctx,
513 errors,
514 |v| {
515 core::str::from_utf8(v).map_err(|e| {
516 unsafe { make_decode_err(v, e.valid_up_to(), e.error_len()) }
519 })
520 },
521 |rest, err_len| {
522 let first_err = rest[0];
523 if matches!(first_err, 0x80..=0xc1 | 0xf5..=0xff) {
524 HandleResult::Error {
525 err_len: Some(1),
526 reason: "invalid start byte",
527 }
528 } else if err_len.is_none() {
529 if final_decode {
531 HandleResult::Error {
532 err_len,
533 reason: "unexpected end of data",
534 }
535 } else {
536 HandleResult::Done
537 }
538 } else if !final_decode && matches!(rest, [0xed, 0xa0..=0xbf]) {
539 HandleResult::Done
541 } else {
542 HandleResult::Error {
543 err_len,
544 reason: "invalid continuation byte",
545 }
546 }
547 },
548 )
549 }
550}
551
552pub mod latin_1 {
553 use super::*;
554
    /// Name of this codec.
    pub const ENCODING_NAME: &str = "latin-1";

    /// Reason text given to the error handler when a code point does not fit
    /// in one byte.
    const ERR_REASON: &str = "ordinal not in range(256)";
558
559 #[inline]
560 pub fn encode<Ctx, E>(mut ctx: Ctx, errors: &E) -> Result<Vec<u8>, Ctx::Error>
561 where
562 Ctx: EncodeContext,
563 E: EncodeErrorHandler<Ctx>,
564 {
565 let mut out = Vec::<u8>::new();
566 loop {
567 let data = ctx.remaining_data();
568 let mut iter = iter_code_points(ctx.remaining_data());
569 let Some((i, ch)) = iter.find(|(_, c)| !c.is_ascii()) else {
570 break;
571 };
572 out.extend_from_slice(&data.as_bytes()[..i.bytes]);
573 let err_start = ctx.position() + i;
574 if let Some(byte) = ch.to_u32().to_u8() {
575 drop(iter);
576 out.push(byte);
577 ctx.restart_from(err_start + StrSize { bytes: 2, chars: 1 })?;
579 } else {
580 let err_end = match { iter }.find(|(_, c)| c.to_u32() <= 255) {
582 Some((i, _)) => ctx.position() + i,
583 None => ctx.data_len(),
584 };
585 let err_range = err_start..err_end;
586 let replace = ctx.handle_error(errors, err_range.clone(), Some(ERR_REASON))?;
587 match replace {
588 EncodeReplace::Str(s) => {
589 if s.as_ref().code_points().any(|c| c.to_u32() > 255) {
590 return Err(ctx.error_encoding(err_range, Some(ERR_REASON)));
591 }
592 out.extend(s.as_ref().code_points().map(|c| c.to_u32() as u8));
593 }
594 EncodeReplace::Bytes(b) => {
595 out.extend_from_slice(b.as_ref());
596 }
597 }
598 }
599 }
600 out.extend_from_slice(ctx.remaining_data().as_bytes());
601 Ok(out)
602 }
603
604 pub fn decode<Ctx: DecodeContext, E: DecodeErrorHandler<Ctx>>(
605 ctx: Ctx,
606 _errors: &E,
607 ) -> Result<(Wtf8Buf, usize), Ctx::Error> {
608 let out: String = ctx.remaining_data().iter().map(|c| *c as char).collect();
609 let out_len = out.len();
610 Ok((out.into(), out_len))
611 }
612}
613
614pub mod ascii {
615 use super::*;
616 use ::ascii::AsciiStr;
617
    /// Name of this codec.
    pub const ENCODING_NAME: &str = "ascii";

    /// Reason text given to the error handler for input outside the ASCII range.
    const ERR_REASON: &str = "ordinal not in range(128)";
621
622 #[inline]
623 pub fn encode<Ctx, E>(ctx: Ctx, errors: &E) -> Result<Vec<u8>, Ctx::Error>
624 where
625 Ctx: EncodeContext,
626 E: EncodeErrorHandler<Ctx>,
627 {
628 encode_utf8_compatible(ctx, errors, ERR_REASON, StrKind::Ascii)
629 }
630
631 pub fn decode<Ctx: DecodeContext, E: DecodeErrorHandler<Ctx>>(
632 ctx: Ctx,
633 errors: &E,
634 ) -> Result<(Wtf8Buf, usize), Ctx::Error> {
635 decode_utf8_compatible(
636 ctx,
637 errors,
638 |v| {
639 AsciiStr::from_ascii(v).map(|s| s.as_str()).map_err(|e| {
640 unsafe { make_decode_err(v, e.valid_up_to(), Some(1)) }
643 })
644 },
645 |_rest, err_len| HandleResult::Error {
646 err_len,
647 reason: ERR_REASON,
648 },
649 )
650 }
651}