1use alloc::string::String;
8use alloc::vec::Vec;
9use core::cmp::min;
10use core::iter;
11use core::mem;
12use core::ops::Bound;
13use core::ops::{Deref, RangeBounds};
14use core::time::Duration;
15
16use lru_time_cache::LruCache;
17
18mod block_value;
19
20use crate::error::HandlingError;
21use crate::{CoapOption, CoapRequest, MessageClass, Packet, ResponseType};
22pub use block_value::BlockValue;
23
/// Slack, in bytes, added to a message's current non-payload size when
/// working out how much payload fits; it leaves room for the Block1/Block2
/// options this handler may attach.
const BLOCK_OPTIONS_MAX_LENGTH: usize = 12;

/// Largest number of elements a single Block1 request may grow the request
/// reassembly buffer by (passed as `maximum_reserve_len` to
/// `extending_splice`).
const MAXIMUM_UNCOMMITTED_BUFFER_RESERVE_LENGTH: usize = 16_384;

/// Default value of `BlockHandlerConfig::max_total_message_size`, in bytes.
/// NOTE(review): presumably chosen to match the 1152-byte message size bound
/// suggested by RFC 7252 §4.6 — confirm.
const DEFAULT_MAX_TOTAL_MESSAGE_SIZE: usize = 1_152;
41
/// Transparent handler for CoAP block-wise transfers (Block1/Block2 options).
///
/// Pass each incoming request through [`BlockHandler::intercept_request`]
/// and, when that does not answer it, pass the application's response
/// through [`BlockHandler::intercept_response`]. State needed between the
/// individual block exchanges is cached per request key.
pub struct BlockHandler<Endpoint: Ord + Clone> {
    /// Limits used when deciding whether and how to split payloads.
    config: BlockHandlerConfig,

    /// Per-request block-wise state, keyed by method, path and requester;
    /// entries expire after `config.cache_expiry_duration`.
    states: LruCache<RequestCacheKey<Endpoint>, BlockState>,
}
53
/// Configuration for [`BlockHandler`].
pub struct BlockHandlerConfig {
    /// Upper bound, in bytes, on the size of an encoded message. The handler
    /// derives from it the largest block size that keeps messages within the
    /// bound, and splits payloads that would not fit.
    pub max_total_message_size: usize,

    /// Expiry duration of the LRU cache that holds per-request block-wise
    /// state (partially received request payloads, cached responses).
    pub cache_expiry_duration: Duration,
}
72
73impl Default for BlockHandlerConfig {
74 fn default() -> Self {
75 Self {
76 max_total_message_size: DEFAULT_MAX_TOTAL_MESSAGE_SIZE,
77 cache_expiry_duration: Duration::from_secs(120),
78 }
79 }
80}
81
82impl<Endpoint: Ord + Clone> BlockHandler<Endpoint> {
83 pub fn new(config: BlockHandlerConfig) -> Self {
86 Self {
87 states: LruCache::with_expiry_duration(
88 config.cache_expiry_duration,
89 ),
90 config,
91 }
92 }
93
94 pub fn intercept_request(
101 &mut self,
102 request: &mut CoapRequest<Endpoint>,
103 ) -> Result<bool, HandlingError> {
104 let state = self
105 .states
106 .entry(request.deref().into())
107 .or_insert(BlockState::default());
108 let block1_handled = Self::maybe_handle_request_block1(
109 request,
110 self.config.max_total_message_size,
111 state,
112 )?;
113 if block1_handled {
114 return Ok(true);
115 }
116
117 let block2_handled =
118 Self::maybe_handle_request_block2(request, state)?;
119 if block2_handled {
120 return Ok(true);
121 }
122
123 Ok(false)
124 }
125
126 fn maybe_handle_request_block1(
127 request: &mut CoapRequest<Endpoint>,
128 max_total_message_size: usize,
129 state: &mut BlockState,
130 ) -> Result<bool, HandlingError> {
131 let request_block1 = request
132 .message
133 .get_first_option_as::<BlockValue>(CoapOption::Block1)
134 .and_then(|x| x.ok());
135 let maybe_response_block1 = Self::negotiate_block_size_if_necessary(
136 request_block1.as_ref(),
137 Self::compute_message_size_hack(&mut request.message),
138 request.message.payload.len(),
139 max_total_message_size,
140 )?;
141
142 match (request_block1, maybe_response_block1) {
143 (Some(request_block1), Some(response_block1)) => {
144 if state.cached_request_payload.is_none() {
145 state.cached_request_payload = Some(Vec::new());
146 }
147 let cached_payload =
148 state.cached_request_payload.as_mut().unwrap();
149
150 let payload_offset =
151 usize::from(request_block1.num) * request_block1.size();
152 extending_splice(
153 cached_payload,
154 payload_offset..payload_offset + request_block1.size(),
155 request.message.payload.iter().copied(),
156 MAXIMUM_UNCOMMITTED_BUFFER_RESERVE_LENGTH,
157 )
158 .map_err(HandlingError::internal)?;
159
160 if request_block1.more {
161 let response = request
162 .response
163 .as_mut()
164 .ok_or_else(HandlingError::not_handled)?;
165 response
166 .message
167 .add_option_as(CoapOption::Block1, response_block1);
168 response.message.header.code =
169 MessageClass::Response(ResponseType::Continue);
170 Ok(true)
171 } else {
172 let cached_payload =
173 mem::take(&mut state.cached_request_payload).unwrap();
174 request.message.payload = cached_payload;
175
176 let response = request
180 .response
181 .as_mut()
182 .ok_or_else(HandlingError::not_handled)?;
183 response
184 .message
185 .add_option_as(CoapOption::Block1, response_block1);
186
187 Ok(false)
188 }
189 }
190 (None, Some(response_block1)) => {
191 let response = request
192 .response
193 .as_mut()
194 .ok_or_else(HandlingError::not_handled)?;
195 response
196 .message
197 .add_option_as(CoapOption::Block1, response_block1);
198 response.message.header.code = MessageClass::Response(
199 ResponseType::RequestEntityTooLarge,
200 );
201 Ok(true)
202 }
203 _ => Ok(false),
204 }
205 }
206
207 fn maybe_handle_request_block2(
208 request: &mut CoapRequest<Endpoint>,
209 state: &mut BlockState,
210 ) -> Result<bool, HandlingError> {
211 let maybe_block2 = request
212 .message
213 .get_first_option_as::<BlockValue>(CoapOption::Block2)
214 .and_then(|x| x.ok());
215 state.last_request_block2.clone_from(&maybe_block2);
216
217 if let Some(block2) = maybe_block2 {
218 if let Some(ref response) = state.cached_response {
219 let has_more_chunks = Self::maybe_serve_cached_response(
220 request, block2, response,
221 )?;
222 if !has_more_chunks {
223 state.cached_response = None
224 }
225 return Ok(true);
226 }
227 }
228
229 Ok(false)
230 }
231
232 fn maybe_serve_cached_response(
233 request: &mut CoapRequest<Endpoint>,
234 request_block2: BlockValue,
235 cached_response: &Packet,
236 ) -> Result<bool, HandlingError> {
237 let response = request
238 .response
239 .as_mut()
240 .ok_or_else(HandlingError::not_handled)?;
241
242 Self::packet_clone_limited(&mut response.message, cached_response);
243
244 let cached_payload = &cached_response.payload;
245
246 let request_block_size = request_block2.size();
247 let mut chunks = cached_payload
248 .chunks(request_block_size)
249 .skip(usize::from(request_block2.num));
250
251 let cached_payload_chunk = chunks.next().ok_or_else(|| {
252 HandlingError::bad_request(format!(
253 "num={}, block_size={}",
254 request_block2.num,
255 request_block2.size()
256 ))
257 })?;
258
259 let response_payload = &mut response.message.payload;
260 response_payload.clear();
261 response_payload.extend(cached_payload_chunk);
262
263 let has_more_chunks = chunks.next().is_some();
264 let response_block2 = BlockValue {
265 more: has_more_chunks,
266 ..request_block2
267 };
268
269 response.message.set_options_as::<BlockValue>(
270 CoapOption::Block2,
271 [response_block2].into(),
272 );
273
274 Ok(has_more_chunks)
275 }
276
277 fn packet_clone_limited(dst: &mut Packet, src: &Packet) {
280 dst.header.set_version(src.header.get_version());
281 dst.header.set_type(src.header.get_type());
282 dst.header.code = src.header.code;
283 for (&option, value) in src.options() {
284 dst.set_option(CoapOption::from(option), value.clone());
285 }
286 }
287
288 pub fn intercept_response(
297 &mut self,
298 request: &mut CoapRequest<Endpoint>,
299 ) -> Result<bool, HandlingError> {
300 let state = self
301 .states
302 .entry(request.deref().into())
303 .or_insert(BlockState::default());
304 if let Some(ref mut response) = request.response {
305 if response.message.get_option(CoapOption::Block2).is_none() {
308 if let Some(request_block2) =
309 Self::negotiate_block_size_if_necessary(
310 state.last_request_block2.as_ref(),
311 Self::compute_message_size_hack(&mut response.message),
312 response.message.payload.len(),
313 self.config.max_total_message_size,
314 )?
315 {
316 let cached_response = response.message.clone();
317 let has_more_chunks = Self::maybe_serve_cached_response(
318 request,
319 request_block2,
320 &cached_response,
321 )?;
322 if has_more_chunks {
323 state.cached_response = Some(cached_response);
324 return Ok(true);
325 }
326 }
327 }
328 }
329
330 Ok(false)
331 }
332
333 fn compute_message_size_hack(packet: &mut Packet) -> usize {
336 let moved_payload = mem::take(&mut packet.payload);
337 let size_sans_payload = packet
338 .to_bytes()
339 .expect("Internal error encoding packet")
340 .len();
341 packet.payload = moved_payload;
342
343 size_sans_payload + packet.payload.len()
344 }
345
346 fn negotiate_block_size_if_necessary(
347 request_block: Option<&BlockValue>,
348 message_size: usize,
349 total_payload_size: usize,
350 max_total_message_size: usize,
351 ) -> Result<Option<BlockValue>, HandlingError> {
352 let max_non_payload_size =
353 (message_size + BLOCK_OPTIONS_MAX_LENGTH) - total_payload_size;
354 let max_block_size = max_total_message_size
355 .checked_sub(max_non_payload_size)
356 .ok_or_else(|| {
357 HandlingError::internal(format!(
358 "Message too large to encode at any block size: {} exceeds {}",
359 max_total_message_size,
360 max_non_payload_size))
361 })?;
362
363 let maybe_response_block = match request_block {
364 Some(request_block) => {
365 let negotiated_block_size =
368 min(request_block.size(), max_block_size);
369
370 let reply_start_offset =
371 usize::from(request_block.num) * request_block.size();
372 let reply_end_offset =
373 reply_start_offset + negotiated_block_size;
374
375 let num = reply_start_offset / negotiated_block_size;
376 let more = reply_end_offset < total_payload_size;
377
378 Some(BlockValue::new(num, more, negotiated_block_size))
379 }
380 None => {
381 if total_payload_size < max_block_size {
382 None
386 } else {
387 Some(BlockValue::new(
390 0,
391 true, max_block_size,
393 ))
394 }
395 }
396 };
397
398 match maybe_response_block {
399 Some(block) => block.map(Some).map_err(HandlingError::internal),
400 None => Ok(None),
401 }
402 }
403}
404
/// Like [`Vec::splice`], but allows `range` to end past the current end of
/// `dst`.
///
/// If the range ends beyond `dst.len()`, `dst` is first grown with
/// `T::default()` values up to the range end, and then the range is replaced
/// with the items of `replace_with`. An unbounded end (`start..`) splices
/// through the current end of `dst` without growing it.
///
/// As with [`Vec::splice`], the replacement is completed when the returned
/// iterator is dropped.
///
/// # Errors
///
/// Returns a message describing the problem, leaving `dst` unchanged, if
/// growing `dst` would add more than `maximum_reserve_len` elements. It also
/// does so if an inclusive range ends at `usize::MAX`.
///
/// # Panics
///
/// Panics under the same conditions as [`Vec::splice`], for example when the
/// range start is greater than its end.
pub fn extending_splice<R, I, T>(
    dst: &mut Vec<T>,
    range: R,
    replace_with: I,
    maximum_reserve_len: usize,
) -> Result<alloc::vec::Splice<I::IntoIter>, String>
where
    R: RangeBounds<usize>,
    I: IntoIterator<Item = T>,
    T: Default + Clone,
{
    let end_index_plus_1 = match range.end_bound() {
        // `usize::MAX + 1` would overflow; report it instead of panicking or
        // wrapping.
        Bound::Included(&included) => included
            .checked_add(1)
            .ok_or_else(|| format!("range end overflows usize: {}", included))?,
        Bound::Excluded(&excluded) => excluded,
        // Nothing past the current end is addressed, so no growth is needed.
        Bound::Unbounded => dst.len(),
    };

    if let Some(extend_len) = end_index_plus_1.checked_sub(dst.len()) {
        if extend_len > maximum_reserve_len {
            return Err(format!(
                "extend_len={}, maximum_extend_len={}",
                extend_len, maximum_reserve_len
            ));
        }
        dst.extend(iter::repeat(T::default()).take(extend_len));
    }

    Ok(dst.splice(range, replace_with))
}
439
/// Key under which [`BlockHandler`] stores block-wise state for a request.
///
/// It combines the request method, the path and the requesting endpoint.
/// NOTE(review): no other request options (e.g. query parameters) are part of
/// the key, so requests differing only in those share state — confirm this
/// is intended.
#[derive(Ord, PartialOrd, Eq, PartialEq, Clone)]
pub struct RequestCacheKey<Endpoint: Ord + Clone> {
    /// Request method, stored as the `u8` value of its `MessageClass`.
    request_type_ord: u8,
    /// Path from `get_path_as_vec`, or empty if that failed.
    path: Vec<String>,
    /// Source endpoint of the request, if any.
    requester: Option<Endpoint>,
}
448
449impl<Endpoint: Ord + Clone> From<&CoapRequest<Endpoint>>
450 for RequestCacheKey<Endpoint>
451{
452 fn from(request: &CoapRequest<Endpoint>) -> Self {
453 Self {
454 request_type_ord: u8::from(MessageClass::Request(
455 *request.get_method(),
456 )),
457 path: request.get_path_as_vec().unwrap_or_default(),
458 requester: request.source.clone(),
459 }
460 }
461}
462
/// Block-wise transfer state that [`BlockHandler`] keeps per request key.
#[derive(Debug, Clone, Default)]
pub struct BlockState {
    /// Block2 option of the most recent request, if it carried a valid one.
    /// Used to negotiate the block size of the next response.
    last_request_block2: Option<BlockValue>,

    /// Full response whose remaining Block2 blocks are still to be served.
    /// Set only while more blocks remain.
    cached_response: Option<Packet>,

    /// Buffer into which the blocks of a Block1 request payload are
    /// reassembled until the final block arrives.
    cached_request_payload: Option<Vec<u8>>,
}
484
#[cfg(test)]
mod tests {
    use alloc::{borrow::ToOwned, collections::LinkedList};

    use crate::option_value::OptionValueString;
    use crate::{CoapResponse, RequestType, ResponseType};

    use super::*;

    /// Single dummy endpoint used as the source of every test request.
    #[derive(Ord, PartialOrd, Eq, PartialEq, Clone)]
    enum TestEndpoint {
        TestClient,
    }

    /// An oversized response is split into Block2 blocks. Later blocks are
    /// served from the cache without invoking the application, and together
    /// they reassemble into the original payload.
    #[test]
    fn test_cached_response_with_blocks() {
        let block = "0123456789\n";

        let mut harness = TestServerHarness::new(32);

        let expected_payload = block.repeat(8).into_bytes();
        let delivered_payload = expected_payload.clone();

        // The first request has no Block2 option; the application produces
        // the whole payload and intercept_response is expected to split it.
        let mut sent_req = create_get_request("test", 1, None);
        let mut received_response = harness
            .exchange_messages(&mut sent_req, move |received_request| {
                let sent_response =
                    received_request.response.as_mut().unwrap();
                sent_response.message.header.code =
                    MessageClass::Response(ResponseType::Content);
                sent_response.message.payload = delivered_payload;
                InterceptPolicy::Expected
            })
            .unwrap();

        let mut received_payload = Vec::<u8>::new();

        // Keep requesting the next block until the server clears `more`;
        // evaluates to the number of the last block.
        let total_blocks = loop {
            received_payload.extend(received_response.message.payload.clone());

            let received_block = received_response
                .message
                .get_first_option_as::<BlockValue>(CoapOption::Block2)
                .unwrap()
                .unwrap();
            let block_size = received_block.size();
            let block_num = received_block.num;

            if !received_block.more {
                break block_num;
            }

            let sent_block = BlockValue::new(
                usize::from(block_num + 1),
                false,
                block_size,
            )
            .unwrap();
            let mut next_sent_req = create_get_request(
                "test",
                received_response.message.header.message_id + 1,
                Some(sent_block),
            );

            received_response = harness
                .exchange_messages_using_cache(&mut next_sent_req)
                .unwrap();

            // A block served from the cache must carry the message ID of the
            // request it answers, not that of the original response.
            assert_eq!(
                received_response.message.header.message_id,
                next_sent_req.message.header.message_id
            );
        };

        // The payload must actually have been split.
        assert!(total_blocks > 1);

        assert_eq!(
            String::from_utf8(received_payload).unwrap(),
            String::from_utf8(expected_payload).unwrap()
        );

        // Once the transfer has completed, a new request with a Block2 option
        // reaches the application again. Its small response passes through
        // without being cached.
        let mut followup_req = create_get_request("test", u16::MAX, None);
        let followup_block2 = BlockValue::new(0, false, 16).unwrap();
        followup_req
            .message
            .add_option_as::<BlockValue>(CoapOption::Block2, followup_block2);
        let followup_response = harness
            .exchange_messages(&mut followup_req, move |received_request| {
                let sent_response =
                    received_request.response.as_mut().unwrap();
                sent_response.message.header.code =
                    MessageClass::Response(ResponseType::Content);
                sent_response.message.payload = "small".as_bytes().to_vec();
                InterceptPolicy::NotExpected
            })
            .unwrap();

        assert_eq!(
            String::from_utf8(followup_response.message.payload).unwrap(),
            "small".to_owned()
        );
    }

    /// A request payload too large for the configured message size, sent
    /// without a Block1 option, is rejected with RequestEntityTooLarge. The
    /// rejection carries a Block1 option with `more` set.
    #[test]
    fn test_server_asserts_block1_encoding_required() {
        let block = "0123456789\n";

        let mut harness = TestServerHarness::new(32);

        let full_payload = block.repeat(8).into_bytes();

        let mut sent_request =
            create_put_request("test", 1, &full_payload, None);
        let received_response = harness
            .exchange_messages_using_cache(&mut sent_request)
            .unwrap();

        assert_eq!(
            received_response.message.header.code,
            MessageClass::Response(ResponseType::RequestEntityTooLarge)
        );
        let received_block = received_response
            .message
            .get_first_option_as::<BlockValue>(CoapOption::Block1)
            .expect("Must respond with Block1 option")
            .expect("Must provide valid Block1 option");
        assert!(received_block.more);
    }

    /// A payload sent as Block1 blocks gets Continue for each intermediate
    /// block. With the final block, the application receives the complete
    /// payload.
    #[test]
    fn test_cached_request_happy_path() {
        let block = "0123456789\n";

        let mut harness = TestServerHarness::new(32);

        let sent_payload = block.repeat(8).into_bytes();
        let expected_payload = sent_payload.clone();

        let block_size = 16;

        let chunks = sent_payload.chunks(block_size);
        let total_chunks = chunks.len();

        for (num, chunk) in chunks.enumerate() {
            let has_more_chunks = num + 1 < total_chunks;

            let block =
                BlockValue::new(num, has_more_chunks, block_size).unwrap();
            let mut sent_request =
                create_put_request("test", 1, chunk, Some(block));

            let received_response = if has_more_chunks {
                // Intermediate blocks are answered by the handler itself.
                let received_response = harness
                    .exchange_messages_using_cache(&mut sent_request)
                    .unwrap();
                assert_eq!(
                    received_response.message.header.code,
                    MessageClass::Response(ResponseType::Continue)
                );
                received_response
            } else {
                // The final block reaches the application with the complete
                // reassembled payload.
                let received_response = harness
                    .exchange_messages(&mut sent_request, |received_request| {
                        assert_eq!(
                            String::from_utf8(
                                received_request.message.payload.clone()
                            )
                            .unwrap(),
                            String::from_utf8(expected_payload.clone())
                                .unwrap()
                        );
                        let sent_response =
                            received_request.response.as_mut().unwrap();
                        sent_response.message.header.code =
                            MessageClass::Response(ResponseType::Changed);
                        InterceptPolicy::NotExpected
                    })
                    .unwrap();
                assert_eq!(
                    received_response.message.header.code,
                    MessageClass::Response(ResponseType::Changed)
                );
                received_response
            };

            let received_block = received_response
                .message
                .get_first_option_as::<BlockValue>(CoapOption::Block1)
                .unwrap()
                .unwrap();

            // Every response echoes a Block1 option with the block size used.
            assert_eq!(received_block.size(), block_size);
        }
    }

    /// Drives a `BlockHandler` the way a server would. It calls
    /// intercept_request first, then, unless the handler answered on its own,
    /// the application callback followed by intercept_response.
    struct TestServerHarness {
        handler: BlockHandler<TestEndpoint>,
    }

    impl TestServerHarness {
        /// Creates a harness with the given message size limit. The cache
        /// expiry (u32::MAX milliseconds) is far longer than any test runs.
        pub fn new(max_message_size: usize) -> Self {
            TestServerHarness {
                handler: BlockHandler::new(BlockHandlerConfig {
                    max_total_message_size: max_message_size,
                    cache_expiry_duration: Duration::from_millis(
                        u32::MAX.into(),
                    ),
                }),
            }
        }

        /// Sends a request that intercept_request must answer by itself
        /// (returning `true`). The application is not invoked, and the
        /// response prepared by the handler is returned.
        pub fn exchange_messages_using_cache(
            &mut self,
            sent_request: &mut CoapRequest<TestEndpoint>,
        ) -> Option<CoapResponse> {
            self.exchange_messages_internal(sent_request, true, |_| {
                InterceptPolicy::DoNotInvoke
            })
        }

        /// Sends a request that intercept_request must pass on (returning
        /// `false`). `response_generator` plays the application and reports
        /// what intercept_response is expected to return.
        pub fn exchange_messages<F>(
            &mut self,
            sent_request: &mut CoapRequest<TestEndpoint>,
            response_generator: F,
        ) -> Option<CoapResponse>
        where
            F: FnOnce(&mut CoapRequest<TestEndpoint>) -> InterceptPolicy,
        {
            self.exchange_messages_internal(
                sent_request,
                false,
                response_generator,
            )
        }

        /// Shared implementation of the two exchange helpers above.
        fn exchange_messages_internal<F>(
            &mut self,
            sent_request: &mut CoapRequest<TestEndpoint>,
            expect_intercept_request: bool,
            response_generator: F,
        ) -> Option<CoapResponse>
        where
            F: FnOnce(&mut CoapRequest<TestEndpoint>) -> InterceptPolicy,
        {
            assert_eq!(
                self.handler.intercept_request(sent_request).unwrap(),
                expect_intercept_request
            );

            // The "application" works on a copy of the intercepted request.
            let mut received_request = sent_request.clone();
            match response_generator(&mut received_request) {
                InterceptPolicy::DoNotInvoke => sent_request.response.clone(),
                policy => {
                    assert_eq!(
                        self.handler
                            .intercept_response(&mut received_request)
                            .unwrap(),
                        match policy {
                            InterceptPolicy::Expected => true,
                            InterceptPolicy::NotExpected => false,
                            _ => panic!(),
                        }
                    );

                    received_request.response
                }
            }
        }
    }

    /// What the test application expects from intercept_response.
    #[derive(Debug, Copy, Clone)]
    enum InterceptPolicy {
        /// intercept_response must return `true`.
        Expected,
        /// intercept_response must return `false`.
        NotExpected,
        /// intercept_response must not be called at all.
        DoNotInvoke,
    }

    /// Builds a GET request for `path`, optionally carrying a Block2 option.
    fn create_get_request(
        path: &str,
        mid: u16,
        block2: Option<BlockValue>,
    ) -> CoapRequest<TestEndpoint> {
        create_request(RequestType::Get, path, mid, None, block2)
    }

    /// Builds a PUT request with `payload`, optionally carrying a Block1
    /// option.
    fn create_put_request(
        path: &str,
        mid: u16,
        payload: &[u8],
        block1: Option<BlockValue>,
    ) -> CoapRequest<TestEndpoint> {
        let mut request =
            create_request(RequestType::Put, path, mid, block1, None);
        request.message.payload.extend(payload);
        request
    }

    /// Builds a request from `TestEndpoint::TestClient` with the given method,
    /// '/'-separated path, message ID and optional block options.
    fn create_request(
        method: RequestType,
        path: &str,
        mid: u16,
        block1: Option<BlockValue>,
        block2: Option<BlockValue>,
    ) -> CoapRequest<TestEndpoint> {
        let mut packet = Packet::new();
        packet.header.code = MessageClass::Request(method);

        let uri_path: LinkedList<_> = path
            .split('/')
            .map(|x| OptionValueString(x.to_owned()))
            .collect();
        packet.set_options_as(CoapOption::UriPath, uri_path);

        let options =
            vec![(CoapOption::Block1, block1), (CoapOption::Block2, block2)];
        for (key, value) in options {
            if let Some(value) = value {
                packet.add_option_as(key, value);
            }
        }

        packet.header.message_id = mid;
        packet.payload = Vec::new();
        CoapRequest::<TestEndpoint>::from_packet(
            packet,
            TestEndpoint::TestClient,
        )
    }
}