1use crate::protocol::validate;
2use crate::{HttpGatewayResponseBody, ResponseBodyStream};
3use bytes::Bytes;
4use candid::Principal;
5use futures::{stream, Stream, StreamExt, TryStreamExt};
6use http_body::Frame;
7use http_body_util::{BodyExt, Full};
8use ic_agent::{Agent, AgentError};
9use ic_http_certification::{HttpRequest, HttpResponse, StatusCode};
10use ic_response_verification::MAX_VERIFICATION_VERSION;
11use ic_utils::interfaces::http_request::HeaderField;
12use ic_utils::{
13 call::SyncCall,
14 interfaces::http_request::{
15 HttpRequestCanister, HttpRequestStreamingCallbackAny, HttpResponse as AgentResponse,
16 StreamingCallbackHttpResponse, StreamingStrategy, Token,
17 },
18};
19
/// Upper bound on the number of frames (including the initial, already-fetched
/// frame) produced by the lazily streamed bodies built in `create_body_stream`
/// and `create_206_body_stream`.
const MAX_HTTP_REQUEST_STREAM_CALLBACK_CALL_COUNT: usize = 1000;

/// Number of streaming-callback chunks `get_body_and_streaming_body` fetches
/// eagerly before deciding whether the remainder must be streamed lazily.
const MAX_VERIFIED_HTTP_REQUEST_STREAM_CALLBACK_CALL_COUNT: usize = 4;

/// Capacity passed to `StreamExt::buffered` wherever chunk streams are consumed.
const STREAM_CALLBACK_BUFFER: usize = 2;
28
/// Canister `http_request` response as decoded by `ic_utils`, parameterised with
/// the `Token` and `HttpRequestStreamingCallbackAny` streaming types used
/// throughout this module.
pub type AgentResponseAny = AgentResponse<Token, HttpRequestStreamingCallbackAny>;
30
31pub async fn get_body_and_streaming_body(
32 agent: &Agent,
33 response: &AgentResponseAny,
34) -> Result<HttpGatewayResponseBody, AgentError> {
35 let Some(StreamingStrategy::Callback(callback_strategy)) = response.streaming_strategy.clone()
37 else {
38 return Ok(HttpGatewayResponseBody::Right(Full::from(
39 response.body.clone(),
40 )));
41 };
42
43 let (streamed_body, token) = create_stream(
44 agent.clone(),
45 callback_strategy.callback.clone(),
46 Some(callback_strategy.token),
47 )
48 .take(MAX_VERIFIED_HTTP_REQUEST_STREAM_CALLBACK_CALL_COUNT)
49 .map(|x| async move { x })
50 .buffered(STREAM_CALLBACK_BUFFER)
51 .try_fold(
52 (vec![], None::<Token>),
53 |mut accum, (mut body, token)| async move {
54 accum.0.append(&mut body);
55 accum.1 = token;
56
57 Ok(accum)
58 },
59 )
60 .await?;
61
62 let streamed_body = [response.body.clone(), streamed_body].concat();
63
64 if token.is_some() {
68 let body_stream = create_body_stream(
69 agent.clone(),
70 callback_strategy.callback,
71 token,
72 streamed_body,
73 );
74
75 return Ok(HttpGatewayResponseBody::Left(body_stream));
76 };
77
78 Ok(HttpGatewayResponseBody::Right(Full::from(streamed_body)))
82}
83
84fn create_body_stream(
85 agent: Agent,
86 callback: HttpRequestStreamingCallbackAny,
87 token: Option<Token>,
88 initial_body: Vec<u8>,
89) -> ResponseBodyStream {
90 let chunks_stream = create_stream(agent, callback, token)
91 .map(|chunk| chunk.map(|(body, _)| Frame::data(Bytes::from(body))));
92
93 let body_stream = stream::once(async move { Ok(Frame::data(Bytes::from(initial_body))) })
94 .chain(chunks_stream)
95 .take(MAX_HTTP_REQUEST_STREAM_CALLBACK_CALL_COUNT)
96 .map(|x| async move { x })
97 .buffered(STREAM_CALLBACK_BUFFER);
98
99 ResponseBodyStream::new(Box::pin(body_stream))
100}
101
102fn create_stream(
103 agent: Agent,
104 callback: HttpRequestStreamingCallbackAny,
105 token: Option<Token>,
106) -> impl Stream<Item = Result<(Vec<u8>, Option<Token>), AgentError>> {
107 futures::stream::try_unfold(
108 (agent, callback, token),
109 |(agent, callback, token)| async move {
110 let Some(token) = token else {
111 return Ok(None);
112 };
113
114 let canister = HttpRequestCanister::create(&agent, callback.0.principal);
115 match canister
116 .http_request_stream_callback(&callback.0.method, token)
117 .call()
118 .await
119 {
120 Ok((StreamingCallbackHttpResponse { body, token },)) => {
121 Ok(Some(((body, token.clone()), (agent, callback, token))))
122 }
123 Err(e) => Err(e),
124 }
125 },
126 )
127}
128
/// Progress of a ranged (HTTP 206) download whose remaining chunks are fetched by
/// re-issuing the original request with a `Range` header (see `create_206_stream`).
#[derive(Clone, Debug)]
struct StreamState<'a> {
    /// The original request. Each follow-up query repeats it with an added
    /// `Range: bytes=<fetched_length>-` header.
    pub http_request: HttpRequest<'a>,
    /// Canister that is queried for the follow-up chunks.
    pub canister_id: Principal,
    /// Total length of the asset, as declared by the `Content-Range` header.
    pub total_length: usize,
    /// Number of bytes obtained so far. The next chunk is requested from this offset.
    pub fetched_length: usize,
    /// Forwarded to `validate` for every follow-up chunk.
    pub skip_verification: bool,
}
137
138pub async fn get_206_stream_response_body_and_total_length(
139 agent: &Agent,
140 http_request: HttpRequest<'static>,
141 canister_id: Principal,
142 response_headers: &Vec<HeaderField<'static>>,
143 response_206_body: HttpGatewayResponseBody,
144 skip_verification: bool,
145) -> Result<(HttpGatewayResponseBody, usize), AgentError> {
146 let HttpGatewayResponseBody::Right(body) = response_206_body else {
147 return Err(AgentError::InvalidHttpResponse(
148 "Expected full 206 response".to_string(),
149 ));
150 };
151 let streamed_body = body
153 .collect()
154 .await
155 .expect("missing streamed chunk body")
156 .to_bytes()
157 .to_vec();
158 let stream_state = get_initial_stream_state(
159 http_request,
160 canister_id,
161 response_headers,
162 skip_verification,
163 )?;
164 let content_length = stream_state.total_length;
165
166 let body_stream = create_206_body_stream(agent.clone(), stream_state, streamed_body);
167 Ok((HttpGatewayResponseBody::Left(body_stream), content_length))
168}
169
/// Numeric parts of a `Content-Range: bytes <range_begin>-<range_end>/<total_length>`
/// header value.
#[derive(Debug)]
struct ContentRangeValues {
    /// Offset of the first byte in the range.
    pub range_begin: usize,
    /// Offset of the last byte in the range (inclusive).
    pub range_end: usize,
    /// Total length of the resource the range is taken from.
    pub total_length: usize,
}
176
177fn parse_content_range_header_str(
178 content_range_str: &str,
179) -> Result<ContentRangeValues, AgentError> {
180 let str_value = content_range_str.trim();
182 if !str_value.starts_with("bytes ") {
183 return Err(AgentError::InvalidHttpResponse(format!(
184 "Invalid Content-Range header '{}'",
185 content_range_str
186 )));
187 }
188 let str_value = str_value.trim_start_matches("bytes ");
189
190 let str_value_parts = str_value.split('-').collect::<Vec<_>>();
191 if str_value_parts.len() != 2 {
192 return Err(AgentError::InvalidHttpResponse(format!(
193 "Invalid bytes spec in Content-Range header '{}'",
194 content_range_str
195 )));
196 }
197 let range_begin = str_value_parts[0].parse::<usize>().map_err(|e| {
198 AgentError::InvalidHttpResponse(format!(
199 "Invalid range_begin in '{}': {}",
200 content_range_str, e
201 ))
202 })?;
203
204 let other_value_parts = str_value_parts[1].split('/').collect::<Vec<_>>();
205 if other_value_parts.len() != 2 {
206 return Err(AgentError::InvalidHttpResponse(format!(
207 "Invalid bytes spec in Content-Range header '{}'",
208 content_range_str
209 )));
210 }
211 let range_end = other_value_parts[0].parse::<usize>().map_err(|e| {
212 AgentError::InvalidHttpResponse(format!(
213 "Invalid range_end in '{}': {}",
214 content_range_str, e
215 ))
216 })?;
217 let total_length = other_value_parts[1].parse::<usize>().map_err(|e| {
218 AgentError::InvalidHttpResponse(format!(
219 "Invalid total_length in '{}': {}",
220 content_range_str, e
221 ))
222 })?;
223
224 let rv = ContentRangeValues {
225 range_begin,
226 range_end,
227 total_length,
228 };
229 if rv.range_begin > rv.range_end
230 || rv.range_begin >= rv.total_length
231 || rv.range_end >= rv.total_length
232 {
233 Err(AgentError::InvalidHttpResponse(format!(
234 "inconsistent Content-Range header {}: {:?}",
235 content_range_str, rv
236 )))
237 } else {
238 Ok(rv)
239 }
240}
241
242fn get_content_range_header_str(
243 response_headers: &Vec<HeaderField<'static>>,
244) -> Result<String, AgentError> {
245 for HeaderField(name, value) in response_headers {
246 if name.eq_ignore_ascii_case(http::header::CONTENT_RANGE.as_ref()) {
247 return Ok(value.to_string());
248 }
249 }
250 Err(AgentError::InvalidHttpResponse(
251 "missing Content-Range header in 206 response".to_string(),
252 ))
253}
254
255fn get_content_range_values(
256 response_headers: &Vec<HeaderField<'static>>,
257 fetched_length: usize,
258) -> Result<ContentRangeValues, AgentError> {
259 let str_value = get_content_range_header_str(response_headers)?;
260 let range_values = parse_content_range_header_str(&str_value)?;
261
262 if range_values.range_begin > fetched_length {
263 return Err(AgentError::InvalidHttpResponse(format!(
264 "chunk out-of-order: range_begin={} is larger than expected begin={} ",
265 range_values.range_begin, fetched_length
266 )));
267 }
268 if range_values.range_end < fetched_length {
269 return Err(AgentError::InvalidHttpResponse(format!(
270 "chunk out-of-order: range_end={} is smaller than length fetched so far={} ",
271 range_values.range_begin, fetched_length
272 )));
273 }
274 Ok(range_values)
275}
276
277fn get_initial_stream_state<'a>(
278 http_request: HttpRequest<'a>,
279 canister_id: Principal,
280 response_headers: &Vec<HeaderField<'static>>,
281 skip_verification: bool,
282) -> Result<StreamState<'a>, AgentError> {
283 let range_values = get_content_range_values(response_headers, 0)?;
284
285 Ok(StreamState {
286 http_request,
287 canister_id,
288 total_length: range_values.total_length,
289 fetched_length: range_values
290 .range_end
291 .saturating_sub(range_values.range_begin)
292 + 1,
293 skip_verification,
294 })
295}
296
297fn create_206_body_stream(
298 agent: Agent,
299 stream_state: StreamState<'static>,
300 initial_body: Vec<u8>,
301) -> ResponseBodyStream {
302 let chunks_stream = create_206_stream(agent, Some(stream_state))
303 .map(|chunk| chunk.map(|(body, _)| Frame::data(Bytes::from(body))));
304
305 let body_stream = stream::once(async move { Ok(Frame::data(Bytes::from(initial_body))) })
306 .chain(chunks_stream)
307 .take(MAX_HTTP_REQUEST_STREAM_CALLBACK_CALL_COUNT)
308 .map(|x| async move { x })
309 .buffered(STREAM_CALLBACK_BUFFER);
310
311 ResponseBodyStream::new(Box::pin(body_stream))
312}
313
314fn create_206_stream(
315 agent: Agent,
316 maybe_stream_state: Option<StreamState>,
317) -> impl Stream<Item = Result<(Vec<u8>, Option<StreamState>), AgentError>> {
318 futures::stream::try_unfold(
319 (agent, maybe_stream_state),
320 |(agent, maybe_stream_state)| async move {
321 let Some(stream_state) = maybe_stream_state else {
322 return Ok(None);
323 };
324 let canister = HttpRequestCanister::create(&agent, stream_state.canister_id);
325 let next_chunk_begin = stream_state.fetched_length;
326
327 let range_header = ("Range".to_string(), format!("bytes={}-", next_chunk_begin));
328 let mut updated_headers = stream_state.http_request.headers().to_vec();
329 updated_headers.push(range_header.clone());
330 let headers = updated_headers
331 .iter()
332 .map(|(name, value)| HeaderField(name.into(), value.into()))
333 .collect::<Vec<HeaderField>>()
334 .into_iter();
335 let query_result = canister
336 .http_request(
337 &stream_state.http_request.method(),
338 &stream_state.http_request.url(),
339 headers,
340 &stream_state.http_request.body(),
341 Some(&u16::from(MAX_VERIFICATION_VERSION)),
342 )
343 .call()
344 .await;
345 let agent_response = match query_result {
346 Ok((response,)) => response,
347 Err(e) => return Err(e),
348 };
349 let range_values =
350 get_content_range_values(&agent_response.headers, stream_state.fetched_length)?;
351 let new_bytes_begin = stream_state
352 .fetched_length
353 .saturating_sub(range_values.range_begin);
354 let chunk_length = range_values
355 .range_end
356 .saturating_sub(stream_state.fetched_length)
357 + 1;
358 let current_fetched_length = stream_state.fetched_length + chunk_length;
359 if agent_response.streaming_strategy.is_some() {
361 return Err(AgentError::InvalidHttpResponse(
362 "unexpected StreamingStrategy".to_string(),
363 ));
364 }
365
366 let Ok(status_code) = StatusCode::from_u16(agent_response.status_code) else {
367 return Err(AgentError::InvalidHttpResponse(format!(
368 "Invalid canister response status code: {}",
369 agent_response.status_code
370 )));
371 };
372 let response = HttpResponse::builder()
373 .with_status_code(status_code)
374 .with_headers(
375 agent_response
376 .headers
377 .iter()
378 .map(|HeaderField(k, v)| (k.to_string(), v.to_string()))
379 .collect(),
380 )
381 .with_body(agent_response.body.clone())
382 .build();
383 let mut http_request = stream_state.http_request.clone();
384 http_request.headers_mut().push(range_header);
385 let validation_result = validate(
386 &agent,
387 &stream_state.canister_id,
388 http_request,
389 response,
390 stream_state.skip_verification,
391 );
392
393 if let Err(e) = validation_result {
394 return Err(AgentError::InvalidHttpResponse(format!(
395 "CertificateVerificationFailed for a chunk starting at {}, error: {}",
396 stream_state.fetched_length, e
397 )));
398 }
399 let maybe_new_state = if current_fetched_length < stream_state.total_length {
400 Some(StreamState {
401 fetched_length: current_fetched_length,
402 ..stream_state
403 })
404 } else {
405 None
406 };
407 Ok(Some((
408 (
409 agent_response.body[new_bytes_begin..].to_vec(),
410 maybe_new_state.clone(),
411 ),
412 (agent, maybe_new_state),
413 )))
414 },
415 )
416}
417
#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use std::borrow::Cow;

    /// GET request with one custom header and a one-byte body, shared by the
    /// `get_initial_stream_state` tests.
    fn sample_http_request() -> HttpRequest<'static> {
        HttpRequest::get("http://example.com/some_file")
            .with_headers(vec![("Xyz".to_string(), "some value".to_string())])
            .with_body(vec![42])
            .build()
    }

    /// Canister id shared by the `get_initial_stream_state` tests.
    fn sample_canister_id() -> Principal {
        Principal::from_slice(&[1, 2, 3, 4])
    }

    /// Response headers consisting of a single `name: value` pair.
    fn single_header(name: &'static str, value: &'static str) -> Vec<HeaderField<'static>> {
        vec![HeaderField(Cow::from(name), Cow::from(value))]
    }

    #[test]
    fn should_parse_content_range_header_str() {
        // (range_begin, range_end, total_length)
        let cases: [(usize, usize, usize); 3] =
            [(0, 0, 1), (100, 2000, 3000), (10_000, 300_000, 500_000)];
        for (range_begin, range_end, total_length) in cases {
            let input = format!("bytes {}-{}/{}", range_begin, range_end, total_length);
            let parsed = parse_content_range_header_str(&input)
                .unwrap_or_else(|_| panic!("failed parsing '{}'", input));
            assert_eq!(range_begin, parsed.range_begin);
            assert_eq!(range_end, parsed.range_end);
            assert_eq!(total_length, parsed.total_length);
        }
    }

    #[test]
    fn should_fail_parse_content_range_header_str_on_malformed_input() {
        let malformed_inputs = [
            "byte 1-2/3",
            "bites 2-4/8",
            "bytes 100-200/asdf",
            "bytes 12345",
            "something else",
            "bytes dead-beef/123456",
        ];
        for input in malformed_inputs {
            assert_matches!(
                parse_content_range_header_str(input),
                Err(e) if e.to_string().contains("Invalid ")
            );
        }
    }

    #[test]
    fn should_fail_parse_content_range_header_str_on_inconsistent_input() {
        let inconsistent_inputs = ["bytes 100-200/190", "bytes 200-150/400", "bytes 100-110/40"];
        for input in inconsistent_inputs {
            assert_matches!(
                parse_content_range_header_str(input),
                Err(e) if e.to_string().contains("inconsistent Content-Range header")
            );
        }
    }

    #[test]
    fn should_get_initial_stream_state() {
        let http_request = sample_http_request();
        let canister_id = sample_canister_id();
        let response_headers = single_header("Content-Range", "bytes 0-2/10");
        let skip_verification = false;

        let state = get_initial_stream_state(
            http_request.clone(),
            canister_id,
            &response_headers,
            skip_verification,
        )
        .expect("failed constructing StreamState");

        assert_eq!(state.http_request, http_request);
        assert_eq!(state.canister_id, canister_id);
        assert_eq!(state.fetched_length, 3);
        assert_eq!(state.total_length, 10);
        assert_eq!(state.skip_verification, skip_verification);
    }

    #[test]
    fn should_fail_get_initial_stream_state_without_content_range_header() {
        let response_headers = single_header("other header", "other value");
        let result = get_initial_stream_state(
            sample_http_request(),
            sample_canister_id(),
            &response_headers,
            false,
        );
        assert_matches!(result, Err(e) if e.to_string().contains("missing Content-Range header"));
    }

    #[test]
    fn should_fail_get_initial_stream_state_with_malformed_content_range_header() {
        let response_headers = single_header("Content-Range", "bytes 42/10");
        let result = get_initial_stream_state(
            sample_http_request(),
            sample_canister_id(),
            &response_headers,
            false,
        );
        assert_matches!(
            result,
            Err(e) if e.to_string().contains("Invalid bytes spec in Content-Range header")
        );
    }

    #[test]
    fn should_fail_get_initial_stream_state_with_inconsistent_content_range_header() {
        let response_headers = single_header("Content-Range", "bytes 40-100/90");
        let result = get_initial_stream_state(
            sample_http_request(),
            sample_canister_id(),
            &response_headers,
            false,
        );
        assert_matches!(
            result,
            Err(e) if e.to_string().contains("inconsistent Content-Range header")
        );
    }
}